""" Scene LoRA Training Script - Transformers + Safetensors Production-grade training with proper security and performance optimizations """ import os import torch import logging from pathlib import Path from typing import List, Dict, Optional from dataclasses import dataclass # Transformers and PEFT imports from transformers import ( Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM ) from peft import ( LoraConfig, get_peft_model, TaskType, PeftModel, PeftConfig ) from safetensors import safe_open from safetensors.torch import save_file import json # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @dataclass class TrainingConfig: """Configuration for LoRA training.""" base_model: str = "google/mt5-small" output_dir: str = "./memo-scene-lora" rank: int = 32 alpha: int = 64 dropout: float = 0.1 target_modules: List[str] = None epochs: int = 3 batch_size: int = 4 learning_rate: float = 1e-4 warmup_steps: int = 100 save_steps: int = 500 logging_steps: int = 50 fp16: bool = True use_8bit: bool = False save_safetensors: bool = True # MANDATORY def __post_init__(self): if self.target_modules is None: # Default target modules for different model types if "t5" in self.base_model.lower(): self.target_modules = ["q", "k", "v", "o"] elif "mt5" in self.base_model.lower(): self.target_modules = ["q", "k", "v", "o"] else: self.target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] class SceneLoRATrainer: """ Production-grade LoRA trainer with transformers integration. Ensures safetensors-only output and proper security measures. """ def __init__(self, config: TrainingConfig): """ Initialize the trainer with configuration. Args: config: Training configuration """ self.config = config self.model = None self.tokenizer = None self.peft_model = None logger.info("SceneLoRATrainer initialized") logger.info(f"Base model: {config.base_model}") logger.info(f"Output directory: {config.output_dir}") logger.info(f"Safetensors enabled: {config.save_safetensors}") # Setup output directory os.makedirs(config.output_dir, exist_ok=True) # Save configuration self._save_config() def _save_config(self): """Save training configuration.""" config_dict = { "base_model": self.config.base_model, "rank": self.config.rank, "alpha": self.config.alpha, "dropout": self.config.dropout, "target_modules": self.config.target_modules, "epochs": self.config.epochs, "batch_size": self.config.batch_size, "learning_rate": self.config.learning_rate, "fp16": self.config.fp16, "use_8bit": self.config.use_8bit, "save_safetensors": self.config.save_safetensors, "timestamp": torch.datetime.now().isoformat() } config_path = os.path.join(self.config.output_dir, "training_config.json") with open(config_path, 'w') as f: json.dump(config_dict, f, indent=2) logger.info(f"Training configuration saved to {config_path}") def load_model(self): """Load base model and tokenizer.""" try: logger.info("Loading base model and tokenizer...") # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained( self.config.base_model, use_fast=True ) # Configure model loading model_kwargs = { "torch_dtype": torch.float16 if self.config.fp16 else torch.float32, "device_map": "auto" if torch.cuda.is_available() else None } if self.config.use_8bit: model_kwargs["load_in_8bit"] = True # Load model self.model = AutoModelForSeq2SeqLM.from_pretrained( self.config.base_model, **model_kwargs ) if not torch.cuda.is_available(): self.model = self.model.to("cpu") logger.info(f"Base model loaded successfully") 
logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}") except Exception as e: logger.error(f"Failed to load model: {e}") raise def setup_lora(self): """Setup LoRA configuration and model.""" try: logger.info("Setting up LoRA configuration...") # Create LoRA configuration lora_config = LoraConfig( task_type=TaskType.SEQ2SEQ_LM, r=self.config.rank, lora_alpha=self.config.alpha, lora_dropout=self.config.dropout, target_modules=self.config.target_modules, bias="none", fan_in_fan_out=False ) # Apply LoRA to model self.peft_model = get_peft_model(self.model, lora_config) # Print trainable parameters self._print_trainable_parameters() logger.info("LoRA configuration applied successfully") except Exception as e: logger.error(f"Failed to setup LoRA: {e}") raise def _print_trainable_parameters(self): """Print information about trainable parameters.""" trainable_params = 0 all_param = 0 for _, param in self.peft_model.named_parameters(): all_param += param.numel() if param.requires_grad: trainable_params += param.numel() logger.info( f"Trainable params: {trainable_params:,} || " f"All params: {all_param:,} || " f"Trainable%: {100 * trainable_params / all_param:.2f}%" ) def prepare_training_data(self, training_data: List[Dict]) -> List[Dict]: """ Prepare training data for the model. Args: training_data: List of training examples Returns: Processed training data """ logger.info(f"Preparing {len(training_data)} training examples...") processed_data = [] for example in training_data: try: # Tokenize input text input_text = example.get("input", "") target_text = example.get("output", "") if not input_text or not target_text: continue # Add task-specific formatting formatted_input = f"Extract scenes from text: {input_text}" # Tokenize tokenized = self.tokenizer( formatted_input, text_target=target_text, padding=True, truncation=True, max_length=512, return_tensors="pt" ) processed_data.append({ "input_ids": tokenized["input_ids"], "attention_mask": tokenized["attention_mask"], "labels": tokenized["labels"] }) except Exception as e: logger.warning(f"Failed to process example: {e}") continue logger.info(f"Successfully processed {len(processed_data)} training examples") return processed_data def train(self, training_data: List[Dict]): """ Train the LoRA model. 
        Args:
            training_data: Training examples
        """
        try:
            # Prepare training data
            processed_data = self.prepare_training_data(training_data)

            if not processed_data:
                raise ValueError("No valid training data available")

            # Setup training arguments with security features
            training_args = TrainingArguments(
                output_dir=self.config.output_dir,
                per_device_train_batch_size=self.config.batch_size,
                gradient_accumulation_steps=1,
                num_train_epochs=self.config.epochs,
                learning_rate=self.config.learning_rate,
                lr_scheduler_type="cosine",
                warmup_steps=self.config.warmup_steps,
                logging_steps=self.config.logging_steps,
                save_steps=self.config.save_steps,
                save_total_limit=3,
                evaluation_strategy="no",  # Disable evaluation for faster training
                load_best_model_at_end=False,

                # Security and performance settings
                fp16=self.config.fp16,
                dataloader_pin_memory=False,
                remove_unused_columns=False,

                # MANDATORY safetensors settings
                save_safetensors=self.config.save_safetensors,

                # Optimizer settings
                optim="adamw_torch",
                weight_decay=0.01,
                max_grad_norm=1.0,

                # Memory optimization
                gradient_checkpointing=True
            )

            # Create trainer
            trainer = Trainer(
                model=self.peft_model,
                args=training_args,
                train_dataset=processed_data,
                tokenizer=self.tokenizer,
                data_collator=self._data_collator
            )

            logger.info("Starting training...")

            # Start training
            trainer.train()

            # Save final model with safetensors
            self._save_final_model()

            logger.info("Training completed successfully")

        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise

    def _data_collator(self, features):
        """Custom data collator for the trainer."""
        batch = {}

        # Stack per-example tensors into a batch
        batch["input_ids"] = torch.stack([f["input_ids"] for f in features])
        batch["attention_mask"] = torch.stack([f["attention_mask"] for f in features])
        batch["labels"] = torch.stack([f["labels"] for f in features])

        return batch

    def _save_final_model(self):
        """Save the final model with safetensors."""
        try:
            logger.info("Saving final model with safetensors...")

            # Save LoRA adapter with safetensors (PEFT's keyword is `safe_serialization`)
            self.peft_model.save_pretrained(
                self.config.output_dir,
                safe_serialization=self.config.save_safetensors
            )

            # Save tokenizer
            self.tokenizer.save_pretrained(self.config.output_dir)

            # Verify safetensors file exists
            safetensors_path = os.path.join(self.config.output_dir, "adapter_model.safetensors")

            if os.path.exists(safetensors_path):
                logger.info(f"LoRA weights saved to {safetensors_path}")

                # Verify file integrity
                self._verify_safetensors_file(safetensors_path)
            else:
                logger.warning("Safetensors file not found!")

            # Save model info
            self._save_model_info()

        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            raise

    def _verify_safetensors_file(self, filepath: str):
        """Verify safetensors file integrity."""
        try:
            with safe_open(filepath, framework="pt") as f:
                tensor_names = list(f.keys())
                logger.info(f"Safetensors file contains {len(tensor_names)} tensors")
                logger.info(f"Sample tensors: {tensor_names[:5]}")

        except Exception as e:
            logger.error(f"Safetensors verification failed: {e}")
            raise

    def _save_model_info(self):
        """Save model information and metadata."""
        model_info = {
            "model_type": "LoRA",
            "base_model": self.config.base_model,
            "lora_rank": self.config.rank,
            "lora_alpha": self.config.alpha,
            "lora_dropout": self.config.dropout,
            "target_modules": self.config.target_modules,
            "training_epochs": self.config.epochs,
            "save_safetensors": self.config.save_safetensors,
            "total_parameters": sum(p.numel() for p in self.peft_model.parameters()),
            "trainable_parameters": sum(
                p.numel()
                for p in self.peft_model.parameters() if p.requires_grad
            ),
            "timestamp": datetime.now().isoformat()
        }

        info_path = os.path.join(self.config.output_dir, "model_info.json")
        with open(info_path, 'w') as f:
            json.dump(model_info, f, indent=2)

        logger.info(f"Model info saved to {info_path}")


def create_sample_training_data() -> List[Dict]:
    """Create sample training data for demonstration."""
    sample_data = [
        {
            "input": "আজকের দিনটি ছিল খুব সুন্দর। রোদ উজ্জ্বল ছিল এবং হাওয়া মৃদুমন্দ।",
            "output": "দৃশ্য ১: উজ্জ্বল সূর্যের আলোয় একটি সুন্দর দিন\nদৃশ্য ২: মৃদুমন্দ বাতাসে গাছের পাতা দুলছে"
        },
        {
            "input": "শহরের ব্যস্ত রাস্তায় মানুষের চলাচল চলছে। গাড়ি আর মানুষের একটা কর্মব্যস্ততা দেখা যাচ্ছে।",
            "output": "দৃশ্য ১: শহরের ব্যস্ত রাস্তায় মানুষের চলাচল\nদৃশ্য ২: যানবাহন আর পথচারীর গতিশীল দৃশ্য"
        }
    ]
    return sample_data


def main():
    """Main training function."""
    # Configuration
    config = TrainingConfig(
        base_model="google/mt5-small",
        output_dir="./memo-scene-lora",
        rank=32,
        alpha=64,
        epochs=3,
        batch_size=2,
        save_safetensors=True  # MANDATORY
    )

    # Initialize trainer
    trainer = SceneLoRATrainer(config)

    # Load model and setup LoRA
    trainer.load_model()
    trainer.setup_lora()

    # Create sample training data
    training_data = create_sample_training_data()

    # Train model
    trainer.train(training_data)

    print("\n✅ Training completed successfully!")
    print(f"📁 Model saved to: {config.output_dir}")
    print(f"🔒 Using safetensors: {config.save_safetensors}")


if __name__ == "__main__":
    main()
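

# ---------------------------------------------------------------------------
# Optional: reloading the trained adapter for inference.
# A minimal sketch (not called by main) of how the adapter saved to
# config.output_dir could be loaded back with PEFT. The default directory and
# the generation settings in the usage note are illustrative assumptions, not
# part of the training pipeline above.
def load_trained_adapter(adapter_dir: str = "./memo-scene-lora"):
    """Reload the base model and attach the saved LoRA adapter for inference."""
    peft_config = PeftConfig.from_pretrained(adapter_dir)
    tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
    base_model = AutoModelForSeq2SeqLM.from_pretrained(peft_config.base_model_name_or_path)
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    model.eval()
    return model, tokenizer


# Example usage (hypothetical input text):
#   model, tokenizer = load_trained_adapter()
#   inputs = tokenizer("Extract scenes from text: ...", return_tensors="pt")
#   output_ids = model.generate(**inputs, max_new_tokens=128)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))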