ai_exec / scripts /train_model.py
Chaitanya-aitf's picture
Upload 38 files
45ee481 verified
#!/usr/bin/env python3
"""
Train Model CLI
Unified CLI for fine-tuning the CEO voice model using QLoRA or LoRA.
Designed to run on Hugging Face infrastructure.
Usage:
python scripts/train_model.py --dataset data/training/train.jsonl --output-repo username/model
Environment variables:
HF_TOKEN - Hugging Face token for pushing to Hub
"""
import argparse
import json
import os
import sys
from pathlib import Path
# Put the repository root (the parent of this scripts/ directory) on sys.path
# so that `src.training...` can be imported when the script is run directly,
# e.g. `python scripts/train_model.py ...`.
sys.path.insert(0, str(Path(__file__).parent.parent))
from rich.console import Console
from rich.table import Table
from rich.prompt import Confirm
# Shared rich console used for all user-facing output in this script.
console = Console()
def _positive_int(text):
    """argparse ``type=``: parse *text* as an int and require it to be > 0."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if not value > 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def _positive_float(text):
    """argparse ``type=``: parse *text* as a float and require it to be > 0."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
    # `not value > 0` (rather than `value <= 0`) also rejects NaN.
    if not value > 0:
        raise argparse.ArgumentTypeError(f"must be a positive number, got {value}")
    return value


def _build_parser():
    """Create the argument parser for the training CLI."""
    parser = argparse.ArgumentParser(
        description="Fine-tune the CEO voice model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Train with QLoRA (recommended for A10G/T4)
  python scripts/train_model.py \\
      --dataset data/training/train.jsonl \\
      --val-dataset data/training/validation.jsonl \\
      --output-repo username/ceo-voice-model \\
      --epochs 3

  # Train with standard LoRA (for larger GPUs)
  python scripts/train_model.py \\
      --dataset data/training/train.jsonl \\
      --method lora \\
      --output-repo username/ceo-voice-model

  # Custom hyperparameters
  python scripts/train_model.py \\
      --dataset data/training/train.jsonl \\
      --lora-r 32 \\
      --learning-rate 1e-4 \\
      --batch-size 8

  # Local training without Hub push
  python scripts/train_model.py \\
      --dataset data/training/train.jsonl \\
      --output ./local_outputs \\
      --no-push

Environment:
  HF_TOKEN - Hugging Face token (required when --output-repo is given
             without --no-push)
""",
    )

    # Required arguments
    parser.add_argument(
        "--dataset",
        required=True,
        help="Path to training dataset (JSONL)",
    )

    # Optional data arguments
    parser.add_argument(
        "--val-dataset",
        help="Path to validation dataset (JSONL)",
    )

    # Method selection
    parser.add_argument(
        "--method",
        choices=["qlora", "lora"],
        default="qlora",
        help="Training method (default: qlora)",
    )

    # Model arguments
    parser.add_argument(
        "--base-model",
        default="Qwen/Qwen3-4B-Instruct",
        help="Base model (default: Qwen/Qwen3-4B-Instruct)",
    )
    parser.add_argument(
        "--max-seq-length",
        type=_positive_int,
        default=2048,
        help="Maximum sequence length (default: 2048)",
    )

    # LoRA arguments
    parser.add_argument("--lora-r", type=_positive_int, default=64, help="LoRA rank (default: 64)")
    parser.add_argument("--lora-alpha", type=_positive_int, default=128, help="LoRA alpha (default: 128)")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout (default: 0.05)")

    # Training arguments
    parser.add_argument("--epochs", type=_positive_int, default=3, help="Training epochs (default: 3)")
    parser.add_argument("--batch-size", type=_positive_int, default=4, help="Batch size (default: 4)")
    parser.add_argument("--grad-accum", type=_positive_int, default=4, help="Gradient accumulation (default: 4)")
    parser.add_argument("--learning-rate", type=_positive_float, default=2e-4, help="Learning rate (default: 2e-4)")
    parser.add_argument("--warmup-ratio", type=float, default=0.03, help="Warmup ratio (default: 0.03)")

    # Output arguments
    parser.add_argument(
        "--output",
        default="./outputs",
        help="Local output directory (default: ./outputs)",
    )
    parser.add_argument(
        "--output-repo",
        help="Hugging Face Hub repo ID for model",
    )
    parser.add_argument(
        "--no-push",
        action="store_true",
        help="Don't push to Hub (local only)",
    )

    # Other arguments
    parser.add_argument("--resume", help="Resume from checkpoint")
    parser.add_argument("--config", help="JSON config file (overrides CLI args)")
    parser.add_argument("--yes", "-y", action="store_true", help="Skip confirmation")
    return parser


def _count_examples(path):
    """Return the number of non-blank lines (records) in a JSONL file."""
    # Binary mode: lines are only counted, so no decoding is needed and a
    # platform-default encoding cannot raise UnicodeDecodeError.
    with open(path, "rb") as f:
        return sum(1 for line in f if line.strip())


def _load_config_overrides(config_path):
    """Load the --config JSON file and return its top-level object.

    Raises:
        OSError: the file cannot be opened or read.
        ValueError: the content is not valid JSON (json.JSONDecodeError is a
            ValueError) or is not a JSON object.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        raise ValueError(
            f"config file must contain a JSON object, got {type(data).__name__}"
        )
    return data


def _apply_overrides(config, overrides):
    """Set each key of *overrides* on *config*; warn about keys it does not have."""
    for key, value in overrides.items():
        if hasattr(config, key):
            setattr(config, key, value)
        else:
            console.print(f"[yellow]Warning:[/yellow] Ignoring unknown config key: {key}")


def _print_config_table(args, dataset_path, val_path, push_to_hub, overrides):
    """Render the CLI training settings, plus the keys overridden by --config."""
    console.print("\n[yellow]Training Configuration[/yellow]")
    table = Table(show_header=False, box=None)
    table.add_column(style="dim", width=25)
    table.add_column(style="white")
    table.add_row("Method:", args.method.upper())
    table.add_row("Base model:", args.base_model)
    table.add_row("Training data:", str(dataset_path))
    if val_path:
        table.add_row("Validation data:", str(val_path))
    table.add_row("", "")
    table.add_row("LoRA rank:", str(args.lora_r))
    table.add_row("LoRA alpha:", str(args.lora_alpha))
    table.add_row("LoRA dropout:", str(args.lora_dropout))
    table.add_row("Epochs:", str(args.epochs))
    table.add_row("Batch size:", str(args.batch_size))
    table.add_row("Gradient accumulation:", str(args.grad_accum))
    table.add_row("Effective batch:", str(args.batch_size * args.grad_accum))
    table.add_row("Learning rate:", str(args.learning_rate))
    table.add_row("Warmup ratio:", str(args.warmup_ratio))
    table.add_row("Max sequence length:", str(args.max_seq_length))
    if overrides:
        # --config values are applied after the CLI values shown above, so
        # make it visible which of them will be replaced.
        table.add_row("Config file:", str(args.config))
        table.add_row("Config overrides:", ", ".join(sorted(overrides)))
    table.add_row("", "")
    table.add_row("Local output:", args.output)
    if push_to_hub:
        table.add_row("Hub repo:", args.output_repo)
    else:
        table.add_row("Hub push:", "Disabled")
    console.print(table)


def _run_training(args, dataset_path, val_path, push_to_hub, overrides):
    """Import the selected trainer, build its config and run training.

    Returns whatever the trainer returns; the caller reports it as the
    adapter location. ImportError propagates when training dependencies
    are missing.
    """
    # Imported lazily so that --help and input validation work without the
    # heavy training stack installed.
    if args.method == "qlora":
        from src.training.train_qlora import train_qlora as train_fn, QLoRAConfig as config_cls
    else:
        from src.training.train_lora import train_lora as train_fn, LoRAConfig as config_cls

    config = config_cls(
        base_model=args.base_model,
        max_seq_length=args.max_seq_length,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
    )
    # JSON config values take precedence over CLI arguments.
    _apply_overrides(config, overrides)

    return train_fn(
        train_dataset_path=str(dataset_path),
        val_dataset_path=str(val_path) if val_path else None,
        output_dir=args.output,
        config=config,
        push_to_hub=push_to_hub,
        hub_model_id=args.output_repo,
        resume_from_checkpoint=args.resume,
    )


def main(argv=None):
    """Parse CLI arguments, confirm the run and fine-tune the model.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as argparse does.

    Returns:
        Process exit code: 0 on success or when the user declines the
        confirmation, 1 on any error, 130 if training is interrupted with
        Ctrl-C. Invalid arguments make argparse raise SystemExit(2).
    """
    args = _build_parser().parse_args(argv)

    console.print("\n[bold blue]AI Executive - Model Training[/bold blue]")
    console.print("=" * 50)

    # Validate inputs. is_file() rather than exists(): a directory would
    # pass exists() and then fail when opened.
    dataset_path = Path(args.dataset)
    if not dataset_path.is_file():
        console.print(f"[red]Error:[/red] Dataset not found: {dataset_path}")
        return 1

    val_path = None
    if args.val_dataset:
        val_path = Path(args.val_dataset)
        if not val_path.is_file():
            console.print(f"[red]Error:[/red] Validation dataset not found: {val_path}")
            return 1

    # Load --config up front so that a bad file is reported as such, before
    # confirmation and model loading, instead of as a training failure.
    overrides = {}
    if args.config:
        try:
            overrides = _load_config_overrides(args.config)
        except (OSError, ValueError) as e:
            console.print(f"[red]Error:[/red] Could not load config file {args.config}: {e}")
            return 1

    # Push only when a target repo was given and pushing was not disabled.
    push_to_hub = bool(args.output_repo) and not args.no_push
    if push_to_hub and not os.environ.get("HF_TOKEN"):
        console.print("[red]Error:[/red] HF_TOKEN not found in environment")
        console.print("\nSet it with:")
        console.print(" export HF_TOKEN=your_token_here")
        return 1

    _print_config_table(args, dataset_path, val_path, push_to_hub, overrides)

    train_count = _count_examples(dataset_path)
    console.print(f"\n[dim]Training examples: {train_count}[/dim]")
    if val_path:
        console.print(f"[dim]Validation examples: {_count_examples(val_path)}[/dim]")
    if train_count == 0:
        # Fail fast instead of after the (slow) model download/load.
        console.print(f"[red]Error:[/red] Training dataset is empty: {dataset_path}")
        return 1

    # Confirm
    if not args.yes:
        console.print()
        try:
            confirmed = Confirm.ask("Start training?")
        except EOFError:
            # No interactive stdin (e.g. CI or a job runner) and no --yes.
            console.print(
                "[red]Error:[/red] No interactive input available; "
                "pass --yes to skip confirmation."
            )
            return 1
        if not confirmed:
            console.print("[dim]Cancelled.[/dim]")
            return 0

    console.print("\n[yellow]Initializing training...[/yellow]")
    try:
        adapter_path = _run_training(args, dataset_path, val_path, push_to_hub, overrides)
    except ImportError as e:
        console.print(f"[red]Error:[/red] Missing dependencies: {e}")
        console.print("\nInstall with:")
        console.print(" pip install torch transformers peft trl bitsandbytes datasets")
        return 1
    except KeyboardInterrupt:
        console.print("\n[yellow]Training interrupted.[/yellow]")
        return 130
    except Exception as e:
        # Top-level boundary: report any training failure with its traceback.
        console.print(f"[red]Training failed:[/red] {e}")
        import traceback
        traceback.print_exc()
        return 1

    # Success
    console.print("\n" + "=" * 50)
    console.print("[bold green]Training complete![/bold green]")
    console.print(f"\nAdapter saved to: {adapter_path}")
    if push_to_hub:
        console.print(f"Model pushed to: https://huggingface.co/{args.output_repo}")

    console.print("\n[dim]Next steps:[/dim]")
    if push_to_hub:
        # A causal-LM head is needed for generation; plain AutoModel lacks it.
        console.print(
            f"[dim] - Load model: AutoModelForCausalLM.from_pretrained('{args.output_repo}')[/dim]"
        )
    console.print("[dim] - Run inference: python app/app.py[/dim]")
    console.print(f"[dim] - Evaluate: python scripts/evaluate_model.py --model {adapter_path}[/dim]")
    return 0
if __name__ == "__main__":
    # sys.exit rather than the site-provided exit() builtin, which is meant
    # for the interactive shell and is missing when Python runs with -S.
    sys.exit(main())