Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| """ | |
| Train Model CLI | |
| Unified CLI for fine-tuning the CEO voice model using QLoRA or LoRA. | |
| Designed to run on Hugging Face infrastructure. | |
| Usage: | |
| python scripts/train_model.py --dataset data/training/train.jsonl --output-repo username/model | |
| Environment variables: | |
| HF_TOKEN - Hugging Face token for pushing to Hub | |
| """ | |
import argparse
import json
import os
import sys
from pathlib import Path

# Make the repo root importable so `src.training.*` resolves when this
# script is run directly from scripts/ (one level below the project root).
sys.path.insert(0, str(Path(__file__).parent.parent))

from rich.console import Console
from rich.table import Table
from rich.prompt import Confirm

# Single shared Rich console for all CLI output in this script.
console = Console()
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the training CLI."""
    parser = argparse.ArgumentParser(
        description="Fine-tune the CEO voice model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Train with QLoRA (recommended for A10G/T4)
  python scripts/train_model.py \\
      --dataset data/training/train.jsonl \\
      --val-dataset data/training/validation.jsonl \\
      --output-repo username/ceo-voice-model \\
      --epochs 3

  # Train with standard LoRA (for larger GPUs)
  python scripts/train_model.py \\
      --dataset data/training/train.jsonl \\
      --method lora \\
      --output-repo username/ceo-voice-model

  # Custom hyperparameters
  python scripts/train_model.py \\
      --dataset data/training/train.jsonl \\
      --lora-r 32 \\
      --learning-rate 1e-4 \\
      --batch-size 8

  # Local training without Hub push
  python scripts/train_model.py \\
      --dataset data/training/train.jsonl \\
      --output ./local_outputs \\
      --no-push

Environment:
  HF_TOKEN - Hugging Face token (required for --push-to-hub)
""",
    )

    # Data arguments
    parser.add_argument(
        "--dataset",
        required=True,
        help="Path to training dataset (JSONL)",
    )
    parser.add_argument(
        "--val-dataset",
        help="Path to validation dataset (JSONL)",
    )

    # Method selection
    parser.add_argument(
        "--method",
        choices=["qlora", "lora"],
        default="qlora",
        help="Training method (default: qlora)",
    )

    # Model arguments
    parser.add_argument(
        "--base-model",
        default="Qwen/Qwen3-4B-Instruct",
        help="Base model (default: Qwen/Qwen3-4B-Instruct)",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=2048,
        help="Maximum sequence length (default: 2048)",
    )

    # LoRA arguments
    parser.add_argument("--lora-r", type=int, default=64, help="LoRA rank (default: 64)")
    parser.add_argument("--lora-alpha", type=int, default=128, help="LoRA alpha (default: 128)")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout (default: 0.05)")

    # Training arguments
    parser.add_argument("--epochs", type=int, default=3, help="Training epochs (default: 3)")
    parser.add_argument("--batch-size", type=int, default=4, help="Batch size (default: 4)")
    parser.add_argument("--grad-accum", type=int, default=4, help="Gradient accumulation (default: 4)")
    parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate (default: 2e-4)")
    parser.add_argument("--warmup-ratio", type=float, default=0.03, help="Warmup ratio (default: 0.03)")

    # Output arguments
    parser.add_argument(
        "--output",
        default="./outputs",
        help="Local output directory (default: ./outputs)",
    )
    parser.add_argument(
        "--output-repo",
        help="Hugging Face Hub repo ID for model",
    )
    parser.add_argument(
        "--no-push",
        action="store_true",
        help="Don't push to Hub (local only)",
    )

    # Other arguments
    parser.add_argument("--resume", help="Resume from checkpoint")
    parser.add_argument("--config", help="JSON config file (overrides CLI args)")
    parser.add_argument("--yes", "-y", action="store_true", help="Skip confirmation")
    return parser


def _apply_json_overrides(config, config_path: str) -> None:
    """Overwrite matching attributes on *config* from a JSON file.

    Keys that do not correspond to an existing attribute on the config
    object are silently ignored, so one JSON file can serve both the
    QLoRA and LoRA config classes.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        overrides = json.load(f)
    for key, value in overrides.items():
        if hasattr(config, key):
            setattr(config, key, value)


def _count_examples(path: Path) -> int:
    """Return the number of non-blank lines in a JSONL file.

    Blank lines are skipped so a trailing newline does not inflate the
    reported example count.
    """
    with open(path, "r", encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


def main() -> int:
    """CLI entry point: validate inputs, show the run config, and train.

    Returns:
        Process exit code: 0 on success or user cancellation, 1 on any
        error (missing files, missing HF_TOKEN, missing deps, training
        failure).
    """
    args = _build_parser().parse_args()

    console.print("\n[bold blue]AI Executive - Model Training[/bold blue]")
    console.print("=" * 50)

    # Validate input datasets before doing anything heavyweight.
    dataset_path = Path(args.dataset)
    if not dataset_path.exists():
        console.print(f"[red]Error:[/red] Dataset not found: {dataset_path}")
        return 1

    val_path = None
    if args.val_dataset:
        val_path = Path(args.val_dataset)
        if not val_path.exists():
            console.print(f"[red]Error:[/red] Validation dataset not found: {val_path}")
            return 1

    # Pushing to the Hub requires both a target repo and an auth token.
    push_to_hub = args.output_repo and not args.no_push
    if push_to_hub:
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            console.print("[red]Error:[/red] HF_TOKEN not found in environment")
            console.print("\nSet it with:")
            console.print("  export HF_TOKEN=your_token_here")
            return 1

    # Display the effective configuration for the user to review.
    console.print("\n[yellow]Training Configuration[/yellow]")
    table = Table(show_header=False, box=None)
    table.add_column(style="dim", width=25)
    table.add_column(style="white")
    table.add_row("Method:", args.method.upper())
    table.add_row("Base model:", args.base_model)
    table.add_row("Training data:", str(dataset_path))
    if val_path:
        table.add_row("Validation data:", str(val_path))
    table.add_row("", "")
    table.add_row("LoRA rank:", str(args.lora_r))
    table.add_row("LoRA alpha:", str(args.lora_alpha))
    table.add_row("Epochs:", str(args.epochs))
    table.add_row("Batch size:", str(args.batch_size))
    table.add_row("Gradient accumulation:", str(args.grad_accum))
    table.add_row("Effective batch:", str(args.batch_size * args.grad_accum))
    table.add_row("Learning rate:", str(args.learning_rate))
    table.add_row("Max sequence length:", str(args.max_seq_length))
    table.add_row("", "")
    table.add_row("Local output:", args.output)
    if push_to_hub:
        table.add_row("Hub repo:", args.output_repo)
    else:
        table.add_row("Hub push:", "Disabled")
    console.print(table)

    # Report dataset sizes.
    train_count = _count_examples(dataset_path)
    console.print(f"\n[dim]Training examples: {train_count}[/dim]")
    if val_path:
        val_count = _count_examples(val_path)
        console.print(f"[dim]Validation examples: {val_count}[/dim]")

    # Interactive confirmation unless --yes was passed.
    if not args.yes:
        console.print()
        if not Confirm.ask("Start training?"):
            console.print("[dim]Cancelled.[/dim]")
            return 0

    # Import the training module lazily: torch/transformers/peft are heavy
    # and may be absent, in which case we report instead of crashing.
    console.print("\n[yellow]Initializing training...[/yellow]")
    try:
        # Both trainers share the same config fields and call signature,
        # so pick the implementation once and run a single code path.
        if args.method == "qlora":
            from src.training.train_qlora import QLoRAConfig as config_cls
            from src.training.train_qlora import train_qlora as train_fn
        else:
            from src.training.train_lora import LoRAConfig as config_cls
            from src.training.train_lora import train_lora as train_fn

        config = config_cls(
            base_model=args.base_model,
            max_seq_length=args.max_seq_length,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            num_train_epochs=args.epochs,
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            learning_rate=args.learning_rate,
            warmup_ratio=args.warmup_ratio,
        )
        # A JSON config file, if given, takes precedence over CLI flags.
        if args.config:
            _apply_json_overrides(config, args.config)

        adapter_path = train_fn(
            train_dataset_path=str(dataset_path),
            val_dataset_path=str(val_path) if val_path else None,
            output_dir=args.output,
            config=config,
            push_to_hub=push_to_hub,
            hub_model_id=args.output_repo,
            resume_from_checkpoint=args.resume,
        )
    except ImportError as e:
        console.print(f"[red]Error:[/red] Missing dependencies: {e}")
        console.print("\nInstall with:")
        console.print("  pip install torch transformers peft trl bitsandbytes datasets")
        return 1
    except Exception as e:
        console.print(f"[red]Training failed:[/red] {e}")
        import traceback

        traceback.print_exc()
        return 1

    # Success summary.
    console.print("\n" + "=" * 50)
    console.print("[bold green]Training complete![/bold green]")
    console.print(f"\nAdapter saved to: {adapter_path}")
    if push_to_hub:
        console.print(f"Model pushed to: https://huggingface.co/{args.output_repo}")

    console.print("\n[dim]Next steps:[/dim]")
    if push_to_hub:
        console.print(f"[dim]  - Load model: AutoModel.from_pretrained('{args.output_repo}')[/dim]")
    console.print("[dim]  - Run inference: python app/app.py[/dim]")
    console.print(f"[dim]  - Evaluate: python scripts/evaluate_model.py --model {adapter_path}[/dim]")
    return 0
if __name__ == "__main__":
    # Use sys.exit rather than the builtin exit(): the latter is a `site`
    # module convenience that is not guaranteed to exist (e.g. python -S,
    # embedded or frozen interpreters).
    sys.exit(main())