"""
Scene LoRA Training Script - Transformers + Safetensors

Production-grade training with proper security and performance optimizations.
"""

import json
import logging
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional

import torch
from peft import LoraConfig, TaskType, get_peft_model
from safetensors import safe_open
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
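
# Note: written against the transformers/peft APIs current when this script
# was authored. Newer transformers versions rename some arguments used below
# (e.g. Trainer's `tokenizer=` becomes `processing_class=`, and
# TrainingArguments' `evaluation_strategy` becomes `eval_strategy`).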


@dataclass
class TrainingConfig:
    """Configuration for LoRA training."""

    base_model: str = "google/mt5-small"
    output_dir: str = "./memo-scene-lora"
    rank: int = 32
    alpha: int = 64
    dropout: float = 0.1
    target_modules: Optional[List[str]] = None
    epochs: int = 3
    batch_size: int = 4
    learning_rate: float = 1e-4
    warmup_steps: int = 100
    save_steps: int = 500
    logging_steps: int = 50
    fp16: bool = True  # mixed precision; requires a CUDA device
    use_8bit: bool = False
    save_safetensors: bool = True

    def __post_init__(self):
        if self.target_modules is None:
            # T5 and mT5 both name their attention projections q, k, v, o
            # ("mt5" contains "t5", so a single check covers both); most
            # other architectures use the *_proj naming convention instead.
            if "t5" in self.base_model.lower():
                self.target_modules = ["q", "k", "v", "o"]
            else:
                self.target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]


class SceneLoRATrainer:
    """
    Production-grade LoRA trainer with transformers integration.
    Ensures safetensors-only output and proper security measures.
    """

    def __init__(self, config: TrainingConfig):
        """
        Initialize the trainer with configuration.

        Args:
            config: Training configuration
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.peft_model = None

        logger.info("SceneLoRATrainer initialized")
        logger.info(f"Base model: {config.base_model}")
        logger.info(f"Output directory: {config.output_dir}")
        logger.info(f"Safetensors enabled: {config.save_safetensors}")

        os.makedirs(config.output_dir, exist_ok=True)
        self._save_config()

    def _save_config(self):
        """Save the training configuration alongside the model outputs."""
        config_dict = {
            "base_model": self.config.base_model,
            "rank": self.config.rank,
            "alpha": self.config.alpha,
            "dropout": self.config.dropout,
            "target_modules": self.config.target_modules,
            "epochs": self.config.epochs,
            "batch_size": self.config.batch_size,
            "learning_rate": self.config.learning_rate,
            "fp16": self.config.fp16,
            "use_8bit": self.config.use_8bit,
            "save_safetensors": self.config.save_safetensors,
            # datetime, not torch, provides timestamps
            "timestamp": datetime.now().isoformat(),
        }

        config_path = os.path.join(self.config.output_dir, "training_config.json")
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=2)

        logger.info(f"Training configuration saved to {config_path}")

    def load_model(self):
        """Load base model and tokenizer."""
        try:
            logger.info("Loading base model and tokenizer...")

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.base_model,
                use_fast=True,
            )

            model_kwargs = {
                "torch_dtype": torch.float16 if self.config.fp16 else torch.float32,
                "device_map": "auto" if torch.cuda.is_available() else None,
            }

            if self.config.use_8bit:
                model_kwargs["load_in_8bit"] = True

            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.config.base_model,
                **model_kwargs,
            )

            if not torch.cuda.is_available():
                self.model = self.model.to("cpu")

            logger.info("Base model loaded successfully")
            logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def setup_lora(self):
        """Set up LoRA configuration and wrap the base model."""
        try:
            logger.info("Setting up LoRA configuration...")

            lora_config = LoraConfig(
                # peft spells this enum member SEQ_2_SEQ_LM
                task_type=TaskType.SEQ_2_SEQ_LM,
                r=self.config.rank,
                lora_alpha=self.config.alpha,
                lora_dropout=self.config.dropout,
                target_modules=self.config.target_modules,
                bias="none",
                fan_in_fan_out=False,
            )

            self.peft_model = get_peft_model(self.model, lora_config)
            self._print_trainable_parameters()

            logger.info("LoRA configuration applied successfully")

        except Exception as e:
            logger.error(f"Failed to set up LoRA: {e}")
            raise

    def _print_trainable_parameters(self):
        """Log information about trainable parameters."""
        trainable_params = 0
        all_param = 0

        for _, param in self.peft_model.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()

        logger.info(
            f"Trainable params: {trainable_params:,} || "
            f"All params: {all_param:,} || "
            f"Trainable%: {100 * trainable_params / all_param:.2f}%"
        )

    def prepare_training_data(self, training_data: List[Dict]) -> List[Dict]:
        """
        Prepare training data for the model.

        Args:
            training_data: List of training examples

        Returns:
            Processed training data
        """
        logger.info(f"Preparing {len(training_data)} training examples...")

        processed_data = []
        for example in training_data:
            try:
                input_text = example.get("input", "")
                target_text = example.get("output", "")

                if not input_text or not target_text:
                    continue

                formatted_input = f"Extract scenes from text: {input_text}"

                # Pad every example to the same fixed length so the collator
                # can stack tensors of identical shape ("longest"-style
                # per-example padding would leave lengths uneven).
                tokenized = self.tokenizer(
                    formatted_input,
                    text_target=target_text,
                    padding="max_length",
                    truncation=True,
                    max_length=512,
                    return_tensors="pt",
                )

                # Mask pad tokens in the labels so the loss ignores them.
                labels = tokenized["labels"].squeeze(0)
                labels[labels == self.tokenizer.pad_token_id] = -100

                processed_data.append({
                    # Drop the batch dimension added by return_tensors="pt".
                    "input_ids": tokenized["input_ids"].squeeze(0),
                    "attention_mask": tokenized["attention_mask"].squeeze(0),
                    "labels": labels,
                })

            except Exception as e:
                logger.warning(f"Failed to process example: {e}")
                continue

        logger.info(f"Successfully processed {len(processed_data)} training examples")
        return processed_data

    def train(self, training_data: List[Dict]):
        """
        Train the LoRA model.

        Args:
            training_data: Training examples
        """
        try:
            processed_data = self.prepare_training_data(training_data)

            if not processed_data:
                raise ValueError("No valid training data available")

            training_args = TrainingArguments(
                output_dir=self.config.output_dir,
                per_device_train_batch_size=self.config.batch_size,
                gradient_accumulation_steps=1,
                num_train_epochs=self.config.epochs,
                learning_rate=self.config.learning_rate,
                lr_scheduler_type="cosine",
                warmup_steps=self.config.warmup_steps,
                logging_steps=self.config.logging_steps,
                save_steps=self.config.save_steps,
                save_total_limit=3,
                # No eval set is provided, so evaluation and best-model
                # selection are disabled.
                evaluation_strategy="no",
                load_best_model_at_end=False,
                fp16=self.config.fp16,
                dataloader_pin_memory=False,
                remove_unused_columns=False,
                save_safetensors=self.config.save_safetensors,
                optim="adamw_torch",
                weight_decay=0.01,
                max_grad_norm=1.0,
                gradient_checkpointing=True,
            )

            trainer = Trainer(
                model=self.peft_model,
                args=training_args,
                train_dataset=processed_data,
                tokenizer=self.tokenizer,
                data_collator=self._data_collator,
            )

            logger.info("Starting training...")
            trainer.train()

            self._save_final_model()

            logger.info("Training completed successfully")

        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise

    def _data_collator(self, features):
        """Custom data collator: every feature was padded to the same fixed
        length, so same-shape tensors can simply be stacked."""
        batch = {}
        batch["input_ids"] = torch.stack([f["input_ids"] for f in features])
        batch["attention_mask"] = torch.stack([f["attention_mask"] for f in features])
        batch["labels"] = torch.stack([f["labels"] for f in features])
        return batch

    def _save_final_model(self):
        """Save the final model with safetensors."""
        try:
            logger.info("Saving final model with safetensors...")

            # peft's save_pretrained takes `safe_serialization`
            # (not `save_safetensors`) to select the checkpoint format.
            self.peft_model.save_pretrained(
                self.config.output_dir,
                safe_serialization=self.config.save_safetensors,
            )

            self.tokenizer.save_pretrained(self.config.output_dir)

            safetensors_path = os.path.join(self.config.output_dir, "adapter_model.safetensors")
            if os.path.exists(safetensors_path):
                logger.info(f"LoRA weights saved to {safetensors_path}")
                self._verify_safetensors_file(safetensors_path)
            else:
                logger.warning("Safetensors file not found!")

            self._save_model_info()

        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            raise

    def _verify_safetensors_file(self, filepath: str):
        """Verify safetensors file integrity."""
        try:
            with safe_open(filepath, framework="pt") as f:
                tensor_names = list(f.keys())
                logger.info(f"Safetensors file contains {len(tensor_names)} tensors")
                logger.info(f"Sample tensors: {tensor_names[:5]}")
        except Exception as e:
            logger.error(f"Safetensors verification failed: {e}")
            raise

    def _save_model_info(self):
        """Save model information and metadata."""
        model_info = {
            "model_type": "LoRA",
            "base_model": self.config.base_model,
            "lora_rank": self.config.rank,
            "lora_alpha": self.config.alpha,
            "lora_dropout": self.config.dropout,
            "target_modules": self.config.target_modules,
            "training_epochs": self.config.epochs,
            "save_safetensors": self.config.save_safetensors,
            "total_parameters": sum(p.numel() for p in self.peft_model.parameters()),
            "trainable_parameters": sum(
                p.numel() for p in self.peft_model.parameters() if p.requires_grad
            ),
            "timestamp": datetime.now().isoformat(),
        }

        info_path = os.path.join(self.config.output_dir, "model_info.json")
        with open(info_path, "w") as f:
            json.dump(model_info, f, indent=2)

        logger.info(f"Model info saved to {info_path}")


def create_sample_training_data() -> List[Dict]:
    """Create sample (Bengali) training data for demonstration."""
    sample_data = [
        {
            # Input: "Today was a very beautiful day. The sun was bright
            # and the breeze was gentle."
            "input": "আজকের দিনটি ছিল খুব সুন্দর। রোদ উজ্জ্বল ছিল এবং হাওয়া মৃদুমন্দ।",
            # Output: "Scene 1: A beautiful day in bright sunlight /
            # Scene 2: Leaves swaying in a gentle breeze"
            "output": "দৃশ্য ১: উজ্জ্বল সূর্যের আলোয় একটি সুন্দর দিন\nদৃশ্য ২: মৃদুমন্দ বাতাসে গাছের পাতা দুলছে",
        },
        {
            # Input: "People are moving along the city's busy streets.
            # Cars and people make for a bustling scene."
            "input": "শহরের ব্যস্ত রাস্তায় মানুষের চলাচল চলছে। গাড়ি আর মানুষের একটা কর্মব্যস্ততা দেখা যাচ্ছে।",
            # Output: "Scene 1: People moving along the city's busy streets /
            # Scene 2: A dynamic scene of vehicles and pedestrians"
            "output": "দৃশ্য ১: শহরের ব্যস্ত রাস্তায় মানুষের চলাচল\nদৃশ্য ২: যানবাহন আর পথচারীর গতিশীল দৃশ্য",
        },
    ]
    return sample_data
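

def load_training_data_from_jsonl(path: str) -> List[Dict]:
    """Minimal sketch (not part of the original pipeline): load real training
    examples from a JSONL file, one {"input": ..., "output": ...} object per
    line, mirroring the format of create_sample_training_data(). The path and
    function name are illustrative."""
    examples = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                examples.append(json.loads(line))
    return examples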


def main():
    """Main training function."""
    config = TrainingConfig(
        base_model="google/mt5-small",
        output_dir="./memo-scene-lora",
        rank=32,
        alpha=64,
        epochs=3,
        batch_size=2,
        save_safetensors=True,
    )

    trainer = SceneLoRATrainer(config)
    trainer.load_model()
    trainer.setup_lora()

    training_data = create_sample_training_data()
    trainer.train(training_data)

    print("\n✅ Training completed successfully!")
    print(f"📁 Model saved to: {config.output_dir}")
    print(f"🔒 Using safetensors: {config.save_safetensors}")


if __name__ == "__main__":
    main()
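
# Example (sketch, assuming the defaults above): loading the trained adapter
# for inference after this script has run.
#
#   from peft import PeftModel
#   from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#
#   base = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
#   model = PeftModel.from_pretrained(base, "./memo-scene-lora")
#   tok = AutoTokenizer.from_pretrained("./memo-scene-lora")
#   inputs = tok("Extract scenes from text: ...", return_tensors="pt")
#   out = model.generate(**inputs, max_new_tokens=128)
#   print(tok.decode(out[0], skip_special_tokens=True))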