""" Production training launcher with full CLI. Usage ----- # Minimal (uses all defaults) python scripts/train_production.py # Full options python scripts/train_production.py \\ --dataset data/processed/instruction_dataset.jsonl \\ --base-model meta-llama/Llama-3.1-8B-Instruct \\ --output checkpoints/worlddisasterlm-qlora \\ --epochs 3 \\ --lora-r 16 \\ --batch-size 2 \\ --grad-accum 8 \\ --report-to wandb # Consumer GPU (RTX 4090 24 GB) python scripts/train_production.py \\ --batch-size 1 --grad-accum 16 --max-seq-length 1024 """ from __future__ import annotations import argparse import logging logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s") logger = logging.getLogger(__name__) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Train WorldDisasterLM with QLoRA") # Model / data parser.add_argument("--base-model", default="meta-llama/Llama-3.1-8B-Instruct") parser.add_argument("--dataset", default="data/processed/instruction_dataset.jsonl") parser.add_argument("--output", default="checkpoints/worlddisasterlm-qlora") parser.add_argument("--max-seq-length", type=int, default=2048) # Training parser.add_argument("--epochs", type=int, default=3) parser.add_argument("--learning-rate", type=float, default=2e-4) parser.add_argument("--batch-size", type=int, default=2) parser.add_argument("--grad-accum", type=int, default=8) parser.add_argument("--warmup-ratio", type=float, default=0.03) # LoRA parser.add_argument("--lora-r", type=int, default=16) parser.add_argument("--lora-alpha", type=int, default=32) parser.add_argument("--lora-dropout", type=float, default=0.05) # Hardware parser.add_argument("--no-4bit", action="store_true", help="Disable 4-bit quantization (needs more VRAM)") parser.add_argument("--fp16", action="store_true", help="Use fp16 instead of bf16") # Tracking parser.add_argument("--report-to", choices=["mlflow", "wandb", "none"], default="none") parser.add_argument("--seed", type=int, default=42) return parser.parse_args() def main() -> None: args = parse_args() from worlddisasterlm.training.train_qlora import QLoRAConfig, train config = QLoRAConfig( base_model=args.base_model, output_dir=args.output, dataset_path=args.dataset, max_seq_length=args.max_seq_length, use_4bit=not args.no_4bit, epochs=args.epochs, learning_rate=args.learning_rate, per_device_train_batch_size=args.batch_size, gradient_accumulation_steps=args.grad_accum, warmup_ratio=args.warmup_ratio, lora_r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, bf16=not args.fp16, fp16=args.fp16, report_to=args.report_to, seed=args.seed, ) logger.info("Effective training config: %s", config) train(config) if __name__ == "__main__": main()