"""
Scene LoRA Training Script - Transformers + Safetensors
Production-grade training with proper security and performance optimizations
"""
import os
import json
import logging
from datetime import datetime
from typing import List, Dict, Optional
from dataclasses import dataclass

import torch

# Transformers and PEFT imports
from transformers import (
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
)
from safetensors import safe_open

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class TrainingConfig:
    """Configuration for LoRA training."""
    base_model: str = "google/mt5-small"
    output_dir: str = "./memo-scene-lora"
    rank: int = 32
    alpha: int = 64
    dropout: float = 0.1
    target_modules: Optional[List[str]] = None
    epochs: int = 3
    batch_size: int = 4
    learning_rate: float = 1e-4
    warmup_steps: int = 100
    save_steps: int = 500
    logging_steps: int = 50
    fp16: bool = True
    use_8bit: bool = False
    save_safetensors: bool = True  # MANDATORY

    def __post_init__(self):
        if self.target_modules is None:
            # Default target modules: T5/mT5 attention layers use the short
            # projection names; most other architectures use the *_proj convention.
            if "t5" in self.base_model.lower():
                self.target_modules = ["q", "k", "v", "o"]
            else:
                self.target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]


class SceneLoRATrainer:
    """
    Production-grade LoRA trainer with transformers integration.
    Ensures safetensors-only output and proper security measures.
    """

    def __init__(self, config: TrainingConfig):
        """
        Initialize the trainer with configuration.

        Args:
            config: Training configuration
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.peft_model = None

        logger.info("SceneLoRATrainer initialized")
        logger.info(f"Base model: {config.base_model}")
        logger.info(f"Output directory: {config.output_dir}")
        logger.info(f"Safetensors enabled: {config.save_safetensors}")

        # Setup output directory and persist the configuration
        os.makedirs(config.output_dir, exist_ok=True)
        self._save_config()

    def _save_config(self):
        """Save training configuration."""
        config_dict = {
            "base_model": self.config.base_model,
            "rank": self.config.rank,
            "alpha": self.config.alpha,
            "dropout": self.config.dropout,
            "target_modules": self.config.target_modules,
            "epochs": self.config.epochs,
            "batch_size": self.config.batch_size,
            "learning_rate": self.config.learning_rate,
            "fp16": self.config.fp16,
            "use_8bit": self.config.use_8bit,
            "save_safetensors": self.config.save_safetensors,
            "timestamp": datetime.now().isoformat(),
        }
        config_path = os.path.join(self.config.output_dir, "training_config.json")
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=2)
        logger.info(f"Training configuration saved to {config_path}")

    def load_model(self):
        """Load base model and tokenizer."""
        try:
            logger.info("Loading base model and tokenizer...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.base_model,
                use_fast=True
            )

            # Configure model loading
            model_kwargs = {
                "torch_dtype": torch.float16 if self.config.fp16 else torch.float32,
                "device_map": "auto" if torch.cuda.is_available() else None,
            }
            if self.config.use_8bit:
                # 8-bit loading requires the bitsandbytes package
                model_kwargs["load_in_8bit"] = True

            # Load model
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.config.base_model,
                **model_kwargs
            )
            if not torch.cuda.is_available():
                self.model = self.model.to("cpu")

            logger.info("Base model loaded successfully")
            logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def setup_lora(self):
        """Setup LoRA configuration and model."""
        try:
            logger.info("Setting up LoRA configuration...")

            # Create LoRA configuration
            lora_config = LoraConfig(
                task_type=TaskType.SEQ_2_SEQ_LM,
                r=self.config.rank,
                lora_alpha=self.config.alpha,
                lora_dropout=self.config.dropout,
                target_modules=self.config.target_modules,
                bias="none",
                fan_in_fan_out=False
            )

            # Apply LoRA to model
            self.peft_model = get_peft_model(self.model, lora_config)

            # Print trainable parameters
            self._print_trainable_parameters()
            logger.info("LoRA configuration applied successfully")
        except Exception as e:
            logger.error(f"Failed to setup LoRA: {e}")
            raise

    def _print_trainable_parameters(self):
        """Print information about trainable parameters."""
        trainable_params = 0
        all_param = 0
        for _, param in self.peft_model.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        logger.info(
            f"Trainable params: {trainable_params:,} || "
            f"All params: {all_param:,} || "
            f"Trainable%: {100 * trainable_params / all_param:.2f}%"
        )

    def prepare_training_data(self, training_data: List[Dict]) -> List[Dict]:
        """
        Prepare training data for the model.

        Args:
            training_data: List of training examples

        Returns:
            Processed training data
        """
        logger.info(f"Preparing {len(training_data)} training examples...")
        processed_data = []
        for example in training_data:
            try:
                input_text = example.get("input", "")
                target_text = example.get("output", "")
                if not input_text or not target_text:
                    continue

                # Add task-specific formatting
                formatted_input = f"Extract scenes from text: {input_text}"

                # Tokenize input and target; pad to a fixed length so the
                # collator can stack examples into a batch
                tokenized = self.tokenizer(
                    formatted_input,
                    text_target=target_text,
                    padding="max_length",
                    truncation=True,
                    max_length=512,
                    return_tensors="pt"
                )

                labels = tokenized["labels"].squeeze(0)
                # Ignore padded positions when computing the loss
                labels[labels == self.tokenizer.pad_token_id] = -100

                processed_data.append({
                    "input_ids": tokenized["input_ids"].squeeze(0),
                    "attention_mask": tokenized["attention_mask"].squeeze(0),
                    "labels": labels,
                })
            except Exception as e:
                logger.warning(f"Failed to process example: {e}")
                continue

        logger.info(f"Successfully processed {len(processed_data)} training examples")
        return processed_data

    def train(self, training_data: List[Dict]):
        """
        Train the LoRA model.

        Args:
            training_data: Training examples
        """
        try:
            # Prepare training data
            processed_data = self.prepare_training_data(training_data)
            if not processed_data:
                raise ValueError("No valid training data available")

            # Setup training arguments
            training_args = TrainingArguments(
                output_dir=self.config.output_dir,
                per_device_train_batch_size=self.config.batch_size,
                gradient_accumulation_steps=1,
                num_train_epochs=self.config.epochs,
                learning_rate=self.config.learning_rate,
                lr_scheduler_type="cosine",
                warmup_steps=self.config.warmup_steps,
                logging_steps=self.config.logging_steps,
                save_steps=self.config.save_steps,
                save_total_limit=3,
                evaluation_strategy="no",  # Disable evaluation for faster training
                load_best_model_at_end=False,
                # Performance settings
                fp16=self.config.fp16,
                dataloader_pin_memory=False,
                remove_unused_columns=False,
                # MANDATORY safetensors settings
                save_safetensors=self.config.save_safetensors,
                # Optimizer settings
                optim="adamw_torch",
                weight_decay=0.01,
                max_grad_norm=1.0,
                # Memory optimization
                gradient_checkpointing=True
            )

            # Create trainer
            trainer = Trainer(
                model=self.peft_model,
                args=training_args,
                train_dataset=processed_data,
                tokenizer=self.tokenizer,
                data_collator=self._data_collator
            )

            logger.info("Starting training...")
            trainer.train()

            # Save final model with safetensors
            self._save_final_model()
            logger.info("Training completed successfully")
        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise

    def _data_collator(self, features):
        """Custom data collator: stack pre-padded examples into a batch."""
        batch = {}
        batch["input_ids"] = torch.stack([f["input_ids"] for f in features])
        batch["attention_mask"] = torch.stack([f["attention_mask"] for f in features])
        batch["labels"] = torch.stack([f["labels"] for f in features])
        return batch

    def _save_final_model(self):
        """Save the final model with safetensors."""
        try:
            logger.info("Saving final model with safetensors...")

            # Save LoRA adapter with safetensors serialization
            self.peft_model.save_pretrained(
                self.config.output_dir,
                safe_serialization=self.config.save_safetensors
            )

            # Save tokenizer
            self.tokenizer.save_pretrained(self.config.output_dir)

            # Verify safetensors file exists
            safetensors_path = os.path.join(self.config.output_dir, "adapter_model.safetensors")
            if os.path.exists(safetensors_path):
                logger.info(f"LoRA weights saved to {safetensors_path}")
                # Verify file integrity
                self._verify_safetensors_file(safetensors_path)
            else:
                logger.warning("Safetensors file not found!")

            # Save model info
            self._save_model_info()
        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            raise

    def _verify_safetensors_file(self, filepath: str):
        """Verify safetensors file integrity."""
        try:
            with safe_open(filepath, framework="pt") as f:
                tensor_names = list(f.keys())
            logger.info(f"Safetensors file contains {len(tensor_names)} tensors")
            logger.info(f"Sample tensors: {tensor_names[:5]}")
        except Exception as e:
            logger.error(f"Safetensors verification failed: {e}")
            raise

    def _save_model_info(self):
        """Save model information and metadata."""
        model_info = {
            "model_type": "LoRA",
            "base_model": self.config.base_model,
            "lora_rank": self.config.rank,
            "lora_alpha": self.config.alpha,
            "lora_dropout": self.config.dropout,
            "target_modules": self.config.target_modules,
            "training_epochs": self.config.epochs,
            "save_safetensors": self.config.save_safetensors,
            "total_parameters": sum(p.numel() for p in self.peft_model.parameters()),
            "trainable_parameters": sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad),
            "timestamp": datetime.now().isoformat(),
        }
        info_path = os.path.join(self.config.output_dir, "model_info.json")
        with open(info_path, "w") as f:
            json.dump(model_info, f, indent=2)
        logger.info(f"Model info saved to {info_path}")


def create_sample_training_data() -> List[Dict]:
    """Create sample training data for demonstration."""
    sample_data = [
        {
            "input": "আজকের দিনটি ছিল খুব সুন্দর। রোদ উজ্জ্বল ছিল এবং হাওয়া মৃদুমন্দ।",
            "output": "দৃশ্য ১: উজ্জ্বল সূর্যের আলোয় একটি সুন্দর দিন\nদৃশ্য ২: মৃদুমন্দ বাতাসে গাছের পাতা দুলছে"
        },
        {
            "input": "শহরের ব্যস্ত রাস্তায় মানুষের চলাচল চলছে। গাড়ি আর মানুষের একটা কর্মব্যস্ততা দেখা যাচ্ছে।",
            "output": "দৃশ্য ১: শহরের ব্যস্ত রাস্তায় মানুষের চলাচল\nদৃশ্য ২: যানবাহন আর পথচারীর গতিশীল দৃশ্য"
        }
    ]
    return sample_data


def main():
    """Main training function."""
    # Configuration
    config = TrainingConfig(
        base_model="google/mt5-small",
        output_dir="./memo-scene-lora",
        rank=32,
        alpha=64,
        epochs=3,
        batch_size=2,
        save_safetensors=True  # MANDATORY
    )

    # Initialize trainer
    trainer = SceneLoRATrainer(config)

    # Load model and setup LoRA
    trainer.load_model()
    trainer.setup_lora()

    # Create sample training data
    training_data = create_sample_training_data()

    # Train model
    trainer.train(training_data)

    print("\n✅ Training completed successfully!")
    print(f"📁 Model saved to: {config.output_dir}")
    print(f"🔒 Using safetensors: {config.save_safetensors}")


if __name__ == "__main__":
    main()
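

# Illustrative usage (not executed): a minimal sketch of loading the trained
# adapter for inference. The paths below assume the defaults used in main();
# adjust them to your own output directory.
#
#   from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#   from peft import PeftModel
#
#   base = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
#   model = PeftModel.from_pretrained(base, "./memo-scene-lora")
#   tokenizer = AutoTokenizer.from_pretrained("./memo-scene-lora")
#   inputs = tokenizer("Extract scenes from text: ...", return_tensors="pt")
#   print(tokenizer.decode(model.generate(**inputs)[0], skip_special_tokens=True))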