# src/training/train_lora.py
"""
Standard LoRA Training Module
Fine-tune Qwen3-4B using standard LoRA (full precision) with PEFT + TRL.
Use this when the GPU has enough VRAM to train without quantization.
Example usage:
from src.training.train_lora import train_lora
train_lora(
train_dataset_path="data/training/train.jsonl",
val_dataset_path="data/training/validation.jsonl",
output_dir="./outputs",
push_to_hub=True,
hub_model_id="username/ceo-voice-model",
)
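
After training, the saved adapter can be loaded for inference. Minimal sketch
(illustrative only; the base model ID and paths below assume the defaults used
by this module and should match your actual training run):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-4B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
    )
    model = PeftModel.from_pretrained(base, "./outputs/final_adapter")
    tokenizer = AutoTokenizer.from_pretrained("./outputs/final_adapter")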
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from loguru import logger
try:
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig
from datasets import Dataset
TRAINING_AVAILABLE = True
except ImportError as e:
TRAINING_AVAILABLE = False
logger.warning(f"Training dependencies not available: {e}")
@dataclass
class LoRAConfig:
"""Configuration for standard LoRA training."""
# Model configuration
base_model: str = "Qwen/Qwen3-4B-Instruct"
max_seq_length: int = 2048
torch_dtype: str = "bfloat16" # or "float16", "float32"
# LoRA configuration
lora_r: int = 64
lora_alpha: int = 128
lora_dropout: float = 0.05
    target_modules: list[str] = field(default_factory=lambda: [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
])
# Training configuration
num_train_epochs: int = 3
per_device_train_batch_size: int = 2
per_device_eval_batch_size: int = 2
gradient_accumulation_steps: int = 8
learning_rate: float = 2e-4
weight_decay: float = 0.01
warmup_ratio: float = 0.03
lr_scheduler_type: str = "cosine"
# Optimization
fp16: bool = False
bf16: bool = True
gradient_checkpointing: bool = True
optim: str = "adamw_torch"
# Logging and saving
logging_steps: int = 10
save_steps: int = 100
eval_steps: int = 100
save_total_limit: int = 3
# Hub configuration
push_to_hub: bool = False
hub_model_id: Optional[str] = None
hub_token: Optional[str] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
"base_model": self.base_model,
"max_seq_length": self.max_seq_length,
"torch_dtype": self.torch_dtype,
"lora_r": self.lora_r,
"lora_alpha": self.lora_alpha,
"lora_dropout": self.lora_dropout,
"target_modules": self.target_modules,
"num_train_epochs": self.num_train_epochs,
"per_device_train_batch_size": self.per_device_train_batch_size,
"gradient_accumulation_steps": self.gradient_accumulation_steps,
"learning_rate": self.learning_rate,
}
def get_lora_config(config: LoRAConfig) -> "LoraConfig":
"""Get LoRA configuration."""
return LoraConfig(
r=config.lora_r,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
target_modules=config.target_modules,
bias="none",
task_type="CAUSAL_LM",
)
def get_torch_dtype(dtype_str: str):
"""Convert string to torch dtype."""
dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
return dtype_map.get(dtype_str, torch.bfloat16)
def load_model_and_tokenizer(config: LoRAConfig):
"""
Load the base model and tokenizer without quantization.
Args:
config: LoRA configuration
Returns:
Tuple of (model, tokenizer)
"""
logger.info(f"Loading model: {config.base_model}")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
config.base_model,
trust_remote_code=True,
padding_side="right",
)
# Ensure special tokens
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Get torch dtype
torch_dtype = get_torch_dtype(config.torch_dtype)
# Load model without quantization
model = AutoModelForCausalLM.from_pretrained(
config.base_model,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch_dtype,
)
# Enable gradient checkpointing
    if config.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        # KV caching is incompatible with gradient checkpointing during training
        model.config.use_cache = False
logger.info(f"Model loaded: {model.dtype}")
return model, tokenizer
def train_lora(
train_dataset_path: str | Path,
val_dataset_path: Optional[str | Path] = None,
output_dir: str = "./outputs",
config: Optional[LoRAConfig] = None,
push_to_hub: bool = False,
hub_model_id: Optional[str] = None,
hub_token: Optional[str] = None,
resume_from_checkpoint: Optional[str] = None,
) -> str:
"""
Run standard LoRA fine-tuning on the voice model.
Args:
train_dataset_path: Path to training JSONL
val_dataset_path: Path to validation JSONL
output_dir: Directory for outputs
config: LoRA configuration (uses defaults if None)
push_to_hub: Whether to push to HF Hub
hub_model_id: Hub repository ID
hub_token: HF token
resume_from_checkpoint: Checkpoint path to resume from
Returns:
Path to saved adapter
"""
if not TRAINING_AVAILABLE:
raise ImportError(
"Training dependencies not available. Install with:\n"
"pip install torch transformers peft trl datasets"
)
# Use default config if not provided
if config is None:
config = LoRAConfig()
# Override hub settings
if push_to_hub:
config.push_to_hub = True
if hub_model_id:
config.hub_model_id = hub_model_id
# Use provided token or fall back to environment variable
config.hub_token = hub_token or os.environ.get("HF_TOKEN")
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
logger.info("Starting LoRA training (full precision)")
logger.info(f"Config: {config.to_dict()}")
# Load datasets
from .prepare_dataset import load_jsonl, format_chat_template
logger.info(f"Loading training data: {train_dataset_path}")
train_data = load_jsonl(train_dataset_path)
val_data = None
if val_dataset_path:
logger.info(f"Loading validation data: {val_dataset_path}")
val_data = load_jsonl(val_dataset_path)
# Load model and tokenizer
model, tokenizer = load_model_and_tokenizer(config)
# Format datasets
def format_example(example):
text = format_chat_template(example["messages"], tokenizer)
return {"text": text}
train_formatted = [format_example(ex) for ex in train_data]
train_dataset = Dataset.from_list(train_formatted)
eval_dataset = None
if val_data:
val_formatted = [format_example(ex) for ex in val_data]
eval_dataset = Dataset.from_list(val_formatted)
# Get LoRA config
lora_config = get_lora_config(config)
# Apply LoRA to model
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Training arguments
training_args = SFTConfig(
output_dir=str(output_dir),
num_train_epochs=config.num_train_epochs,
per_device_train_batch_size=config.per_device_train_batch_size,
per_device_eval_batch_size=config.per_device_eval_batch_size,
gradient_accumulation_steps=config.gradient_accumulation_steps,
learning_rate=config.learning_rate,
weight_decay=config.weight_decay,
warmup_ratio=config.warmup_ratio,
lr_scheduler_type=config.lr_scheduler_type,
fp16=config.fp16,
bf16=config.bf16,
gradient_checkpointing=config.gradient_checkpointing,
optim=config.optim,
logging_steps=config.logging_steps,
save_steps=config.save_steps,
eval_steps=config.eval_steps if eval_dataset else None,
eval_strategy="steps" if eval_dataset else "no",
save_total_limit=config.save_total_limit,
load_best_model_at_end=True if eval_dataset else False,
metric_for_best_model="eval_loss" if eval_dataset else None,
greater_is_better=False,
push_to_hub=config.push_to_hub,
hub_model_id=config.hub_model_id,
hub_token=config.hub_token,
report_to=["tensorboard"],
max_seq_length=config.max_seq_length,
dataset_text_field="text",
packing=False,
)
# Create trainer
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
)
# Train
logger.info("Starting training...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
# Save final model
final_path = output_dir / "final_adapter"
logger.info(f"Saving adapter to: {final_path}")
trainer.save_model(str(final_path))
tokenizer.save_pretrained(str(final_path))
# Push to hub if configured
if config.push_to_hub and config.hub_model_id:
logger.info(f"Pushing to Hub: {config.hub_model_id}")
trainer.push_to_hub()
logger.info("Training complete!")
return str(final_path)
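

# Illustrative follow-up, not part of the training flow above: a minimal sketch
# showing how the adapter saved by train_lora() could be merged into the base
# model for standalone deployment. The function name and defaults here are
# assumptions; adjust them to your setup.
def merge_adapter_example(
    adapter_path: str,
    merged_output_dir: str,
    base_model: str = "Qwen/Qwen3-4B-Instruct",
) -> str:
    """Merge a trained LoRA adapter into its base model and save the result."""
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    # merge_and_unload() folds the LoRA weights into the base model
    merged = PeftModel.from_pretrained(base, adapter_path).merge_and_unload()
    merged.save_pretrained(merged_output_dir)
    AutoTokenizer.from_pretrained(adapter_path).save_pretrained(merged_output_dir)
    return merged_output_dir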
def main():
"""CLI entry point for LoRA training."""
import argparse
import json
parser = argparse.ArgumentParser(
description="Fine-tune Qwen3-4B with standard LoRA (full precision)",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Basic training
python train_lora.py --train data/training/train.jsonl --output ./outputs
# With validation and hub push
python train_lora.py \\
--train data/training/train.jsonl \\
--val data/training/validation.jsonl \\
--output ./outputs \\
--push-to-hub \\
--hub-model-id username/ceo-voice-model
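# With a JSON config file (keys map to LoRAConfig field names; values shown are illustrative)
python train_lora.py --train data/training/train.jsonl --config lora_config.json
# where lora_config.json might contain, e.g.:
#   {"lora_r": 32, "lora_alpha": 64, "learning_rate": 1e-4, "num_train_epochs": 2}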
Note: Standard LoRA requires more VRAM than QLoRA. Use QLoRA for
constrained GPU environments.
""",
)
# Data arguments
parser.add_argument("--train", required=True, help="Training JSONL file")
parser.add_argument("--val", help="Validation JSONL file")
parser.add_argument("--output", default="./outputs", help="Output directory")
# Model arguments
parser.add_argument(
"--base-model",
default="Qwen/Qwen3-4B-Instruct",
help="Base model name",
)
parser.add_argument(
"--max-seq-length",
type=int,
default=2048,
help="Maximum sequence length",
)
parser.add_argument(
"--dtype",
choices=["float16", "bfloat16", "float32"],
default="bfloat16",
help="Torch dtype for model",
)
# LoRA arguments
parser.add_argument("--lora-r", type=int, default=64, help="LoRA rank")
parser.add_argument("--lora-alpha", type=int, default=128, help="LoRA alpha")
parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
# Training arguments
parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
parser.add_argument("--batch-size", type=int, default=2, help="Batch size")
parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation")
parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate")
# Hub arguments
parser.add_argument("--push-to-hub", action="store_true", help="Push to HF Hub")
parser.add_argument("--hub-model-id", help="Hub model ID")
# Other arguments
parser.add_argument("--resume", help="Resume from checkpoint")
parser.add_argument("--config", help="JSON config file")
args = parser.parse_args()
# Build config
config = LoRAConfig(
base_model=args.base_model,
max_seq_length=args.max_seq_length,
torch_dtype=args.dtype,
lora_r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
num_train_epochs=args.epochs,
per_device_train_batch_size=args.batch_size,
gradient_accumulation_steps=args.grad_accum,
learning_rate=args.learning_rate,
)
# Override with JSON config if provided
if args.config:
with open(args.config, "r") as f:
config_data = json.load(f)
for key, value in config_data.items():
if hasattr(config, key):
setattr(config, key, value)
# Run training
adapter_path = train_lora(
train_dataset_path=args.train,
val_dataset_path=args.val,
output_dir=args.output,
config=config,
push_to_hub=args.push_to_hub,
hub_model_id=args.hub_model_id,
resume_from_checkpoint=args.resume,
)
print(f"\nTraining complete!")
print(f"Adapter saved to: {adapter_path}")
if __name__ == "__main__":
main()