"""
Standard LoRA Training Module

Fine-tune Qwen3-4B using standard LoRA (full precision) with PEFT + TRL.
Use this for training on larger GPUs without quantization.

Example usage:
    from src.training.train_lora import train_lora

    train_lora(
        train_dataset_path="data/training/train.jsonl",
        val_dataset_path="data/training/validation.jsonl",
        output_dir="./outputs",
        push_to_hub=True,
        hub_model_id="username/ceo-voice-model",
    )
"""

import os
from dataclasses import dataclass, field, fields, replace
from pathlib import Path
from typing import Optional

from loguru import logger

try:
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
    )
    from peft import LoraConfig, get_peft_model
    from trl import SFTTrainer, SFTConfig
    from datasets import Dataset

    TRAINING_AVAILABLE = True
except ImportError as e:
    # The heavy ML stack is optional at import time so this module (e.g.
    # LoRAConfig) can be imported without it; train_lora() checks this flag
    # and raises an actionable error instead.
    TRAINING_AVAILABLE = False
    logger.warning(f"Training dependencies not available: {e}")


@dataclass
class LoRAConfig:
    """Configuration for standard LoRA training."""

    # Model configuration
    # NOTE(review): confirm this repo id exists on the Hugging Face Hub; the
    # published Qwen3 4B instruct checkpoint may carry a date suffix.
    base_model: str = "Qwen/Qwen3-4B-Instruct"
    max_seq_length: int = 2048
    torch_dtype: str = "bfloat16"  # or "float16", "float32"

    # LoRA configuration
    lora_r: int = 64
    lora_alpha: int = 128
    lora_dropout: float = 0.05
    # Module names that receive LoRA adapters (attention + MLP projections).
    target_modules: list = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ])

    # Training configuration
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 2
    gradient_accumulation_steps: int = 8
    learning_rate: float = 2e-4
    weight_decay: float = 0.01
    warmup_ratio: float = 0.03
    lr_scheduler_type: str = "cosine"

    # Optimization (mixed-precision flags are passed straight to SFTConfig)
    fp16: bool = False
    bf16: bool = True
    gradient_checkpointing: bool = True
    optim: str = "adamw_torch"

    # Logging and saving
    logging_steps: int = 10
    save_steps: int = 100
    eval_steps: int = 100
    save_total_limit: int = 3

    # Hub configuration
    push_to_hub: bool = False
    hub_model_id: Optional[str] = None
    # Secret: excluded from repr() so it never shows up in logs/tracebacks.
    hub_token: Optional[str] = field(default=None, repr=False)

    def to_dict(self) -> dict:
        """Return a summary of the main settings, suitable for logging.

        Only a subset of fields is included; ``hub_token`` is deliberately
        omitted so the result is safe to log.
        """
        return {
            "base_model": self.base_model,
            "max_seq_length": self.max_seq_length,
            "torch_dtype": self.torch_dtype,
            "lora_r": self.lora_r,
            "lora_alpha": self.lora_alpha,
            "lora_dropout": self.lora_dropout,
            # Copy so callers cannot mutate the config through the summary.
            "target_modules": list(self.target_modules),
            "num_train_epochs": self.num_train_epochs,
            "per_device_train_batch_size": self.per_device_train_batch_size,
            "gradient_accumulation_steps": self.gradient_accumulation_steps,
            "learning_rate": self.learning_rate,
        }


def get_lora_config(config: LoRAConfig) -> "LoraConfig":
    """Build the PEFT ``LoraConfig`` for causal-LM fine-tuning.

    Args:
        config: Training configuration supplying rank, alpha, dropout and
            target modules.

    Returns:
        A PEFT ``LoraConfig`` with no bias training and task type CAUSAL_LM.
    """
    return LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )


def get_torch_dtype(dtype_str: str):
    """Convert a dtype name to the corresponding ``torch.dtype``.

    Args:
        dtype_str: One of ``"float16"``, ``"bfloat16"``, ``"float32"``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If ``dtype_str`` is not a supported name (silently
            substituting a different precision would hide config typos).
    """
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    try:
        return dtype_map[dtype_str]
    except KeyError:
        raise ValueError(
            f"Unsupported torch_dtype {dtype_str!r}; "
            f"expected one of {sorted(dtype_map)}"
        ) from None


def load_model_and_tokenizer(config: LoRAConfig):
    """
    Load the base model and tokenizer without quantization.

    Args:
        config: LoRA configuration

    Returns:
        Tuple of (model, tokenizer)

    Raises:
        ValueError: If ``config.torch_dtype`` is not a supported dtype name.
    """
    logger.info(f"Loading model: {config.base_model}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model,
        trust_remote_code=True,
        padding_side="right",
    )

    # Batched training needs a pad token; reuse EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Get torch dtype
    torch_dtype = get_torch_dtype(config.torch_dtype)

    # Load model without quantization
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch_dtype,
    )

    if config.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        # The KV cache is useless during training and conflicts with
        # gradient checkpointing.
        model.config.use_cache = False

    logger.info(f"Model loaded: {model.dtype}")

    return model, tokenizer


def _import_dataset_helpers():
    """Import ``load_jsonl`` and ``format_chat_template`` from ``prepare_dataset``.

    Uses a relative import when this file is loaded as part of a package and
    a plain import when it is executed directly as a script
    (``python train_lora.py``), where relative imports are not allowed.
    """
    if __package__:
        from .prepare_dataset import load_jsonl, format_chat_template
    else:
        from prepare_dataset import load_jsonl, format_chat_template
    return load_jsonl, format_chat_template


def _to_text_dataset(records, tokenizer, format_chat_template) -> "Dataset":
    """Render each record's ``messages`` via the chat template into a ``text`` column."""
    return Dataset.from_list(
        [{"text": format_chat_template(ex["messages"], tokenizer)} for ex in records]
    )


def _build_training_args(
    config: LoRAConfig, output_dir: Path, has_eval: bool
) -> "SFTConfig":
    """Translate ``config`` into TRL ``SFTConfig`` training arguments.

    Evaluation, best-model tracking and ``eval_steps`` are enabled only when
    ``has_eval`` is true.
    """
    return SFTConfig(
        output_dir=str(output_dir),
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        lr_scheduler_type=config.lr_scheduler_type,
        fp16=config.fp16,
        bf16=config.bf16,
        gradient_checkpointing=config.gradient_checkpointing,
        optim=config.optim,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        eval_steps=config.eval_steps if has_eval else None,
        eval_strategy="steps" if has_eval else "no",
        save_total_limit=config.save_total_limit,
        load_best_model_at_end=has_eval,
        metric_for_best_model="eval_loss" if has_eval else None,
        greater_is_better=False,
        push_to_hub=config.push_to_hub,
        hub_model_id=config.hub_model_id,
        hub_token=config.hub_token,
        report_to=["tensorboard"],
        # NOTE(review): newer TRL releases rename this to `max_length`;
        # confirm against the pinned TRL version.
        max_seq_length=config.max_seq_length,
        dataset_text_field="text",
        packing=False,
    )


def train_lora(
    train_dataset_path: str | Path,
    val_dataset_path: Optional[str | Path] = None,
    output_dir: str = "./outputs",
    config: Optional[LoRAConfig] = None,
    push_to_hub: bool = False,
    hub_model_id: Optional[str] = None,
    hub_token: Optional[str] = None,
    resume_from_checkpoint: Optional[str] = None,
) -> str:
    """
    Run standard LoRA fine-tuning on the voice model.

    Args:
        train_dataset_path: Path to training JSONL
        val_dataset_path: Path to validation JSONL
        output_dir: Directory for outputs
        config: LoRA configuration (uses defaults if None). Not mutated; the
            hub overrides below are applied to a copy.
        push_to_hub: Whether to push to HF Hub
        hub_model_id: Hub repository ID
        hub_token: HF token. Falls back to ``config.hub_token``, then to the
            ``HF_TOKEN`` environment variable.
        resume_from_checkpoint: Checkpoint path to resume from

    Returns:
        Path to saved adapter

    Raises:
        ImportError: If the training dependencies are not installed.
        ValueError: If the training dataset is empty or the configured
            dtype is unsupported.
    """
    if not TRAINING_AVAILABLE:
        raise ImportError(
            "Training dependencies not available. Install with:\n"
            "pip install torch transformers peft trl datasets"
        )

    # Work on a copy so the caller's config object is left untouched.
    config = LoRAConfig() if config is None else replace(config)

    # Override hub settings
    if push_to_hub:
        config.push_to_hub = True
    if hub_model_id:
        config.hub_model_id = hub_model_id
    # Precedence: explicit argument > token already on the config > env var.
    config.hub_token = hub_token or config.hub_token or os.environ.get("HF_TOKEN")

    if config.push_to_hub and not config.hub_model_id:
        logger.warning(
            "push_to_hub is enabled but hub_model_id is not set; "
            "the final explicit push after training will be skipped."
        )

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("Starting LoRA training (full precision)")
    logger.info(f"Config: {config.to_dict()}")

    # Load datasets
    load_jsonl, format_chat_template = _import_dataset_helpers()

    logger.info(f"Loading training data: {train_dataset_path}")
    train_data = load_jsonl(train_dataset_path)
    # Fail before the (expensive) model load rather than inside the trainer.
    if not train_data:
        raise ValueError(f"Training dataset is empty: {train_dataset_path}")

    val_data = None
    if val_dataset_path:
        logger.info(f"Loading validation data: {val_dataset_path}")
        val_data = load_jsonl(val_dataset_path)

    # Load model and tokenizer (the tokenizer is needed for chat formatting)
    model, tokenizer = load_model_and_tokenizer(config)

    # Format datasets
    train_dataset = _to_text_dataset(train_data, tokenizer, format_chat_template)
    eval_dataset = (
        _to_text_dataset(val_data, tokenizer, format_chat_template) if val_data else None
    )

    # Apply LoRA to model
    model = get_peft_model(model, get_lora_config(config))
    model.print_trainable_parameters()

    training_args = _build_training_args(
        config, output_dir, has_eval=eval_dataset is not None
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
    )

    # Train
    logger.info("Starting training...")
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Save final model
    final_path = output_dir / "final_adapter"
    logger.info(f"Saving adapter to: {final_path}")
    trainer.save_model(str(final_path))
    tokenizer.save_pretrained(str(final_path))

    # Push to hub if configured
    if config.push_to_hub and config.hub_model_id:
        logger.info(f"Pushing to Hub: {config.hub_model_id}")
        trainer.push_to_hub()

    logger.info("Training complete!")

    return str(final_path)


def main():
    """CLI entry point for LoRA training.

    CLI flags build a ``LoRAConfig``. If ``--config`` is given, every
    recognised key in that JSON file then overrides the CLI values
    (including CLI defaults). Unknown keys are logged and ignored.
    """
    import argparse
    import json

    parser = argparse.ArgumentParser(
        description="Fine-tune Qwen3-4B with standard LoRA (full precision)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic training
  python train_lora.py --train data/training/train.jsonl --output ./outputs

  # With validation and hub push
  python train_lora.py \\
    --train data/training/train.jsonl \\
    --val data/training/validation.jsonl \\
    --output ./outputs \\
    --push-to-hub \\
    --hub-model-id username/ceo-voice-model

Note: Standard LoRA requires more VRAM than QLoRA.
Use QLoRA for constrained GPU environments.
        """,
    )

    # Data arguments
    parser.add_argument("--train", required=True, help="Training JSONL file")
    parser.add_argument("--val", help="Validation JSONL file")
    parser.add_argument("--output", default="./outputs", help="Output directory")

    # Model arguments
    parser.add_argument(
        "--base-model",
        default="Qwen/Qwen3-4B-Instruct",
        help="Base model name",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=2048,
        help="Maximum sequence length",
    )
    parser.add_argument(
        "--dtype",
        choices=["float16", "bfloat16", "float32"],
        default="bfloat16",
        help="Torch dtype for model",
    )

    # LoRA arguments
    parser.add_argument("--lora-r", type=int, default=64, help="LoRA rank")
    parser.add_argument("--lora-alpha", type=int, default=128, help="LoRA alpha")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")

    # Training arguments
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=2, help="Batch size")
    parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation")
    parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate")

    # Hub arguments
    parser.add_argument("--push-to-hub", action="store_true", help="Push to HF Hub")
    parser.add_argument("--hub-model-id", help="Hub model ID")

    # Other arguments
    parser.add_argument("--resume", help="Resume from checkpoint")
    parser.add_argument("--config", help="JSON config file")

    args = parser.parse_args()

    # Build config
    config = LoRAConfig(
        base_model=args.base_model,
        max_seq_length=args.max_seq_length,
        torch_dtype=args.dtype,
        # Keep the mixed-precision flags consistent with --dtype: float16 is
        # typically chosen on GPUs without bf16 support (where transformers
        # rejects bf16=True), and float32 means full-precision training.
        fp16=args.dtype == "float16",
        bf16=args.dtype == "bfloat16",
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
    )

    # Override with JSON config if provided
    if args.config:
        with open(args.config, "r", encoding="utf-8") as f:
            config_data = json.load(f)
        # Only accept real dataclass fields (hasattr would also match methods
        # such as to_dict).
        known_keys = {fld.name for fld in fields(LoRAConfig)}
        for key, value in config_data.items():
            if key in known_keys:
                setattr(config, key, value)
            else:
                logger.warning(f"Ignoring unknown config key: {key}")

    # Run training
    adapter_path = train_lora(
        train_dataset_path=args.train,
        val_dataset_path=args.val,
        output_dir=args.output,
        config=config,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        resume_from_checkpoint=args.resume,
    )

    print("\nTraining complete!")
    print(f"Adapter saved to: {adapter_path}")


if __name__ == "__main__":
    main()