#!/usr/bin/env python3
"""
Train Model CLI

Unified CLI for fine-tuning the CEO voice model using QLoRA or LoRA.
Designed to run on Hugging Face infrastructure.

Usage:
    python scripts/train_model.py --dataset data/training/train.jsonl --output-repo username/model

Environment variables:
    HF_TOKEN - Hugging Face token for pushing to Hub
"""

import argparse
import json
import os
import sys
from pathlib import Path

# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from rich.console import Console
from rich.table import Table
from rich.prompt import Confirm

console = Console()


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser (all flags, defaults, and help text)."""
    parser = argparse.ArgumentParser(
        description="Fine-tune the CEO voice model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Train with QLoRA (recommended for A10G/T4)
  python scripts/train_model.py \\
    --dataset data/training/train.jsonl \\
    --val-dataset data/training/validation.jsonl \\
    --output-repo username/ceo-voice-model \\
    --epochs 3

  # Train with standard LoRA (for larger GPUs)
  python scripts/train_model.py \\
    --dataset data/training/train.jsonl \\
    --method lora \\
    --output-repo username/ceo-voice-model

  # Custom hyperparameters
  python scripts/train_model.py \\
    --dataset data/training/train.jsonl \\
    --lora-r 32 \\
    --learning-rate 1e-4 \\
    --batch-size 8

  # Local training without Hub push
  python scripts/train_model.py \\
    --dataset data/training/train.jsonl \\
    --output ./local_outputs \\
    --no-push

Environment:
  HF_TOKEN - Hugging Face token (required for --push-to-hub)
""",
    )

    # Required arguments
    parser.add_argument(
        "--dataset",
        required=True,
        help="Path to training dataset (JSONL)",
    )

    # Optional data arguments
    parser.add_argument(
        "--val-dataset",
        help="Path to validation dataset (JSONL)",
    )

    # Method selection
    parser.add_argument(
        "--method",
        choices=["qlora", "lora"],
        default="qlora",
        help="Training method (default: qlora)",
    )

    # Model arguments
    parser.add_argument(
        "--base-model",
        default="Qwen/Qwen3-4B-Instruct",
        help="Base model (default: Qwen/Qwen3-4B-Instruct)",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=2048,
        help="Maximum sequence length (default: 2048)",
    )

    # LoRA arguments
    parser.add_argument("--lora-r", type=int, default=64, help="LoRA rank (default: 64)")
    parser.add_argument("--lora-alpha", type=int, default=128, help="LoRA alpha (default: 128)")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout (default: 0.05)")

    # Training arguments
    parser.add_argument("--epochs", type=int, default=3, help="Training epochs (default: 3)")
    parser.add_argument("--batch-size", type=int, default=4, help="Batch size (default: 4)")
    parser.add_argument("--grad-accum", type=int, default=4, help="Gradient accumulation (default: 4)")
    parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate (default: 2e-4)")
    parser.add_argument("--warmup-ratio", type=float, default=0.03, help="Warmup ratio (default: 0.03)")

    # Output arguments
    parser.add_argument(
        "--output",
        default="./outputs",
        help="Local output directory (default: ./outputs)",
    )
    parser.add_argument(
        "--output-repo",
        help="Hugging Face Hub repo ID for model",
    )
    parser.add_argument(
        "--no-push",
        action="store_true",
        help="Don't push to Hub (local only)",
    )

    # Other arguments
    parser.add_argument("--resume", help="Resume from checkpoint")
    parser.add_argument("--config", help="JSON config file (overrides CLI args)")
    parser.add_argument("--yes", "-y", action="store_true", help="Skip confirmation")

    return parser


def _count_examples(path: Path) -> int:
    """Count examples in a JSONL file.

    Blank lines (e.g. a trailing newline) are not examples, so only
    non-empty lines are counted.
    """
    with open(path, "r", encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


def _apply_config_overrides(config, config_path: str) -> None:
    """Overlay key/value pairs from a JSON file onto *config* in place.

    Keys without a matching attribute on the config object are silently
    ignored, mirroring the original override behavior.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = json.load(f)
    for key, value in config_data.items():
        if hasattr(config, key):
            setattr(config, key, value)


def _show_configuration(args, dataset_path: Path, val_path, push_to_hub: bool) -> None:
    """Render the training configuration summary table to the console."""
    console.print("\n[yellow]Training Configuration[/yellow]")
    table = Table(show_header=False, box=None)
    table.add_column(style="dim", width=25)
    table.add_column(style="white")

    table.add_row("Method:", args.method.upper())
    table.add_row("Base model:", args.base_model)
    table.add_row("Training data:", str(dataset_path))
    if val_path:
        table.add_row("Validation data:", str(val_path))
    table.add_row("", "")
    table.add_row("LoRA rank:", str(args.lora_r))
    table.add_row("LoRA alpha:", str(args.lora_alpha))
    table.add_row("Epochs:", str(args.epochs))
    table.add_row("Batch size:", str(args.batch_size))
    table.add_row("Gradient accumulation:", str(args.grad_accum))
    # Effective batch = per-device batch x gradient accumulation steps.
    table.add_row("Effective batch:", str(args.batch_size * args.grad_accum))
    table.add_row("Learning rate:", str(args.learning_rate))
    table.add_row("Max sequence length:", str(args.max_seq_length))
    table.add_row("", "")
    table.add_row("Local output:", args.output)
    if push_to_hub:
        table.add_row("Hub repo:", args.output_repo)
    else:
        table.add_row("Hub push:", "Disabled")

    console.print(table)


def _run_training(args, dataset_path: Path, val_path, push_to_hub: bool) -> str:
    """Import the selected training backend and run it.

    Returns the adapter path reported by the trainer. The QLoRA and LoRA
    backends take identical config fields and call signatures, so the
    kwargs are built once and only the imported classes differ.

    Raises:
        ImportError: if the ML dependencies are not installed.
    """
    # Imports are deferred so --help and validation stay fast and work
    # without the heavy ML stack installed.
    if args.method == "qlora":
        from src.training.train_qlora import train_qlora, QLoRAConfig

        config_cls, train_fn = QLoRAConfig, train_qlora
    else:
        from src.training.train_lora import train_lora, LoRAConfig

        config_cls, train_fn = LoRAConfig, train_lora

    config = config_cls(
        base_model=args.base_model,
        max_seq_length=args.max_seq_length,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
    )

    # Override with JSON config if provided
    if args.config:
        _apply_config_overrides(config, args.config)

    return train_fn(
        train_dataset_path=str(dataset_path),
        val_dataset_path=str(val_path) if val_path else None,
        output_dir=args.output,
        config=config,
        push_to_hub=push_to_hub,
        hub_model_id=args.output_repo,
        resume_from_checkpoint=args.resume,
    )


def main() -> int:
    """CLI entry point. Returns a process exit code (0 success, 1 failure)."""
    args = _build_parser().parse_args()

    console.print("\n[bold blue]AI Executive - Model Training[/bold blue]")
    console.print("=" * 50)

    # Validate inputs
    dataset_path = Path(args.dataset)
    if not dataset_path.exists():
        console.print(f"[red]Error:[/red] Dataset not found: {dataset_path}")
        return 1

    val_path = None
    if args.val_dataset:
        val_path = Path(args.val_dataset)
        if not val_path.exists():
            console.print(f"[red]Error:[/red] Validation dataset not found: {val_path}")
            return 1

    # Check Hub token if pushing
    push_to_hub = args.output_repo and not args.no_push
    if push_to_hub:
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            console.print("[red]Error:[/red] HF_TOKEN not found in environment")
            console.print("\nSet it with:")
            console.print("  export HF_TOKEN=your_token_here")
            return 1

    # Display configuration
    _show_configuration(args, dataset_path, val_path, push_to_hub)

    # Count training examples
    train_count = _count_examples(dataset_path)
    console.print(f"\n[dim]Training examples: {train_count}[/dim]")
    if val_path:
        val_count = _count_examples(val_path)
        console.print(f"[dim]Validation examples: {val_count}[/dim]")

    # Confirm
    if not args.yes:
        console.print()
        if not Confirm.ask("Start training?"):
            console.print("[dim]Cancelled.[/dim]")
            return 0

    console.print("\n[yellow]Initializing training...[/yellow]")
    try:
        adapter_path = _run_training(args, dataset_path, val_path, push_to_hub)
    except ImportError as e:
        console.print(f"[red]Error:[/red] Missing dependencies: {e}")
        console.print("\nInstall with:")
        console.print("  pip install torch transformers peft trl bitsandbytes datasets")
        return 1
    except Exception as e:
        # Broad catch is deliberate at this top-level CLI boundary: report,
        # show the traceback, and exit non-zero.
        console.print(f"[red]Training failed:[/red] {e}")
        import traceback

        traceback.print_exc()
        return 1

    # Success
    console.print("\n" + "=" * 50)
    console.print("[bold green]Training complete![/bold green]")
    console.print(f"\nAdapter saved to: {adapter_path}")
    if push_to_hub:
        console.print(f"Model pushed to: https://huggingface.co/{args.output_repo}")

    console.print("\n[dim]Next steps:[/dim]")
    if push_to_hub:
        console.print(f"[dim]  - Load model: AutoModel.from_pretrained('{args.output_repo}')[/dim]")
    console.print("[dim]  - Run inference: python app/app.py[/dim]")
    console.print(f"[dim]  - Evaluate: python scripts/evaluate_model.py --model {adapter_path}[/dim]")

    return 0


if __name__ == "__main__":
    # sys.exit (not the site-module `exit` helper) is the correct way to
    # propagate the return code; `exit` may be absent in some interpreters.
    sys.exit(main())