# NOTE: the "Spaces: / Paused / Paused" lines here were web-page scrape
# residue (Hugging Face Space status chrome), not part of the source module.
| """ | |
| Standard LoRA Training Module | |
| Fine-tune Qwen3-4B using standard LoRA (full precision) with PEFT + TRL. | |
| Use this for training on larger GPUs without quantization. | |
| Example usage: | |
| from src.training.train_lora import train_lora | |
| train_lora( | |
| train_dataset_path="data/training/train.jsonl", | |
| val_dataset_path="data/training/validation.jsonl", | |
| output_dir="./outputs", | |
| push_to_hub=True, | |
| hub_model_id="username/ceo-voice-model", | |
| ) | |
| """ | |
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from loguru import logger

# Heavy training dependencies (torch/transformers/peft/trl/datasets) are
# optional at import time so this module can still be imported for config
# inspection on machines that don't have them installed.
try:
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
    )
    from peft import LoraConfig, get_peft_model
    from trl import SFTTrainer, SFTConfig
    from datasets import Dataset

    TRAINING_AVAILABLE = True
except ImportError as e:
    # Flag is checked by train_lora(), which raises a helpful error then.
    TRAINING_AVAILABLE = False
    logger.warning(f"Training dependencies not available: {e}")
@dataclass
class LoRAConfig:
    """Configuration for standard LoRA training.

    BUGFIX: the ``@dataclass`` decorator was missing. Without it,
    ``field(default_factory=...)`` leaves a raw ``Field`` object as the
    class attribute, and ``LoRAConfig(base_model=...)`` (as used by the
    CLI in ``main``) raises ``TypeError`` because a plain class accepts
    no constructor keywords.
    """

    # Model configuration
    base_model: str = "Qwen/Qwen3-4B-Instruct"
    max_seq_length: int = 2048
    torch_dtype: str = "bfloat16"  # or "float16", "float32"

    # LoRA configuration
    lora_r: int = 64
    lora_alpha: int = 128
    lora_dropout: float = 0.05
    # default_factory gives each instance its own list (mutable default).
    target_modules: list = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ])

    # Training configuration
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 2
    gradient_accumulation_steps: int = 8
    learning_rate: float = 2e-4
    weight_decay: float = 0.01
    warmup_ratio: float = 0.03
    lr_scheduler_type: str = "cosine"

    # Optimization
    fp16: bool = False
    bf16: bool = True
    gradient_checkpointing: bool = True
    optim: str = "adamw_torch"

    # Logging and saving
    logging_steps: int = 10
    save_steps: int = 100
    eval_steps: int = 100
    save_total_limit: int = 3

    # Hub configuration
    push_to_hub: bool = False
    hub_model_id: Optional[str] = None
    hub_token: Optional[str] = None

    def to_dict(self) -> dict:
        """Convert the key training parameters to a plain dictionary.

        Intentionally omits ``hub_token`` so credentials never end up in
        logs (``train_lora`` logs this dict).
        """
        return {
            "base_model": self.base_model,
            "max_seq_length": self.max_seq_length,
            "torch_dtype": self.torch_dtype,
            "lora_r": self.lora_r,
            "lora_alpha": self.lora_alpha,
            "lora_dropout": self.lora_dropout,
            "target_modules": self.target_modules,
            "num_train_epochs": self.num_train_epochs,
            "per_device_train_batch_size": self.per_device_train_batch_size,
            "gradient_accumulation_steps": self.gradient_accumulation_steps,
            "learning_rate": self.learning_rate,
        }
def get_lora_config(config: LoRAConfig) -> "LoraConfig":
    """Build the PEFT ``LoraConfig`` from the training configuration.

    Args:
        config: Training configuration holding rank/alpha/dropout/targets.

    Returns:
        A ``LoraConfig`` for a causal-LM task with no bias terms trained.
    """
    adapter_settings = {
        "r": config.lora_r,
        "lora_alpha": config.lora_alpha,
        "lora_dropout": config.lora_dropout,
        "target_modules": config.target_modules,
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }
    return LoraConfig(**adapter_settings)
def get_torch_dtype(dtype_str: str):
    """Translate a dtype name into the matching torch dtype.

    Any name other than "float16"/"bfloat16"/"float32" silently falls
    back to ``torch.bfloat16``, mirroring the default model precision.
    """
    if dtype_str == "float16":
        return torch.float16
    if dtype_str == "float32":
        return torch.float32
    # "bfloat16" and any unrecognized value both land here.
    return torch.bfloat16
def load_model_and_tokenizer(config: LoRAConfig):
    """
    Load the base model and tokenizer without quantization.

    Args:
        config: LoRA configuration

    Returns:
        Tuple of (model, tokenizer)
    """
    logger.info(f"Loading model: {config.base_model}")

    # Resolve the requested precision up front.
    model_dtype = get_torch_dtype(config.torch_dtype)

    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model,
        trust_remote_code=True,
        padding_side="right",
    )
    if tokenizer.pad_token is None:
        # No pad token defined for this model: reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        config.base_model,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=model_dtype,
    )

    if config.gradient_checkpointing:
        # Trades recomputation for activation memory during backprop.
        model.gradient_checkpointing_enable()

    logger.info(f"Model loaded: {model.dtype}")
    return model, tokenizer
def train_lora(
    train_dataset_path: str | Path,
    val_dataset_path: Optional[str | Path] = None,
    output_dir: str = "./outputs",
    config: Optional[LoRAConfig] = None,
    push_to_hub: bool = False,
    hub_model_id: Optional[str] = None,
    hub_token: Optional[str] = None,
    resume_from_checkpoint: Optional[str] = None,
) -> str:
    """
    Run standard LoRA fine-tuning on the voice model.

    Args:
        train_dataset_path: Path to training JSONL
        val_dataset_path: Path to validation JSONL
        output_dir: Directory for outputs
        config: LoRA configuration (uses defaults if None)
        push_to_hub: Whether to push to HF Hub
        hub_model_id: Hub repository ID
        hub_token: HF token (falls back to the HF_TOKEN env var)
        resume_from_checkpoint: Checkpoint path to resume from

    Returns:
        Path to saved adapter

    Raises:
        ImportError: If the optional training dependencies are missing.

    Note:
        When ``push_to_hub`` is True, the caller's ``config`` object is
        mutated in place (its hub fields are overwritten).
    """
    if not TRAINING_AVAILABLE:
        raise ImportError(
            "Training dependencies not available. Install with:\n"
            "pip install torch transformers peft trl datasets"
        )

    # Use default config if not provided
    if config is None:
        config = LoRAConfig()

    # Override hub settings (mutates the passed-in config; see docstring)
    if push_to_hub:
        config.push_to_hub = True
        if hub_model_id:
            config.hub_model_id = hub_model_id
        # Use provided token or fall back to environment variable
        config.hub_token = hub_token or os.environ.get("HF_TOKEN")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("Starting LoRA training (full precision)")
    # to_dict() omits hub_token, so no credential is logged here.
    logger.info(f"Config: {config.to_dict()}")

    # Load datasets (local import keeps the dependency lazy)
    from .prepare_dataset import load_jsonl, format_chat_template

    logger.info(f"Loading training data: {train_dataset_path}")
    train_data = load_jsonl(train_dataset_path)

    val_data = None
    if val_dataset_path:
        logger.info(f"Loading validation data: {val_dataset_path}")
        val_data = load_jsonl(val_dataset_path)

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(config)

    # Format datasets: render each example's chat messages into a single
    # "text" field using the tokenizer's chat template.
    def format_example(example):
        text = format_chat_template(example["messages"], tokenizer)
        return {"text": text}

    train_formatted = [format_example(ex) for ex in train_data]
    train_dataset = Dataset.from_list(train_formatted)

    eval_dataset = None
    if val_data:
        val_formatted = [format_example(ex) for ex in val_data]
        eval_dataset = Dataset.from_list(val_formatted)

    # Get LoRA config
    lora_config = get_lora_config(config)

    # Apply LoRA to model (wraps it in a PeftModel; only adapters train)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Training arguments. Eval-related settings are only enabled when a
    # validation set was provided.
    # NOTE(review): max_seq_length / dataset_text_field / packing live on
    # SFTConfig (not SFTTrainer) only in recent TRL versions — confirm
    # against the pinned trl version.
    training_args = SFTConfig(
        output_dir=str(output_dir),
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        lr_scheduler_type=config.lr_scheduler_type,
        fp16=config.fp16,
        bf16=config.bf16,
        gradient_checkpointing=config.gradient_checkpointing,
        optim=config.optim,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        eval_steps=config.eval_steps if eval_dataset else None,
        eval_strategy="steps" if eval_dataset else "no",
        save_total_limit=config.save_total_limit,
        load_best_model_at_end=True if eval_dataset else False,
        # eval_loss: lower is better
        metric_for_best_model="eval_loss" if eval_dataset else None,
        greater_is_better=False,
        push_to_hub=config.push_to_hub,
        hub_model_id=config.hub_model_id,
        hub_token=config.hub_token,
        report_to=["tensorboard"],
        max_seq_length=config.max_seq_length,
        dataset_text_field="text",
        packing=False,
    )

    # Create trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
    )

    # Train (resume_from_checkpoint=None starts from scratch)
    logger.info("Starting training...")
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Save final model (adapter weights + tokenizer, not the merged base)
    final_path = output_dir / "final_adapter"
    logger.info(f"Saving adapter to: {final_path}")
    trainer.save_model(str(final_path))
    tokenizer.save_pretrained(str(final_path))

    # Push to hub if configured
    if config.push_to_hub and config.hub_model_id:
        logger.info(f"Pushing to Hub: {config.hub_model_id}")
        trainer.push_to_hub()

    logger.info("Training complete!")
    return str(final_path)
def main():
    """CLI entry point for LoRA training."""
    import argparse
    import json

    parser = argparse.ArgumentParser(
        description="Fine-tune Qwen3-4B with standard LoRA (full precision)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic training
  python train_lora.py --train data/training/train.jsonl --output ./outputs

  # With validation and hub push
  python train_lora.py \\
    --train data/training/train.jsonl \\
    --val data/training/validation.jsonl \\
    --output ./outputs \\
    --push-to-hub \\
    --hub-model-id username/ceo-voice-model

Note: Standard LoRA requires more VRAM than QLoRA. Use QLoRA for
constrained GPU environments.
""",
    )

    # Data arguments
    parser.add_argument("--train", required=True, help="Training JSONL file")
    parser.add_argument("--val", help="Validation JSONL file")
    parser.add_argument("--output", default="./outputs", help="Output directory")

    # Model arguments
    parser.add_argument(
        "--base-model",
        default="Qwen/Qwen3-4B-Instruct",
        help="Base model name",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=2048,
        help="Maximum sequence length",
    )
    parser.add_argument(
        "--dtype",
        choices=["float16", "bfloat16", "float32"],
        default="bfloat16",
        help="Torch dtype for model",
    )

    # LoRA arguments
    parser.add_argument("--lora-r", type=int, default=64, help="LoRA rank")
    parser.add_argument("--lora-alpha", type=int, default=128, help="LoRA alpha")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")

    # Training arguments
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=2, help="Batch size")
    parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation")
    parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate")

    # Hub arguments
    parser.add_argument("--push-to-hub", action="store_true", help="Push to HF Hub")
    parser.add_argument("--hub-model-id", help="Hub model ID")

    # Other arguments
    parser.add_argument("--resume", help="Resume from checkpoint")
    parser.add_argument("--config", help="JSON config file")

    args = parser.parse_args()

    # Build config from CLI flags
    config = LoRAConfig(
        base_model=args.base_model,
        max_seq_length=args.max_seq_length,
        torch_dtype=args.dtype,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
    )

    # Override with JSON config if provided.
    # NOTE: applied after the CLI values, so JSON keys take precedence over
    # CLI flags; unknown keys are silently ignored.
    if args.config:
        with open(args.config, "r") as f:
            config_data = json.load(f)
        for key, value in config_data.items():
            if hasattr(config, key):
                setattr(config, key, value)

    # Run training
    adapter_path = train_lora(
        train_dataset_path=args.train,
        val_dataset_path=args.val,
        output_dir=args.output,
        config=config,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        resume_from_checkpoint=args.resume,
    )

    print(f"\nTraining complete!")
    print(f"Adapter saved to: {adapter_path}")


if __name__ == "__main__":
    main()