"""train.py — top-level entry-point, delegates to QLoRA production pipeline. For full CLI options use: python scripts/train_production.py --help """ from __future__ import annotations import argparse import logging logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s") logger = logging.getLogger(__name__) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Train WorldDisasterLM-8B (QLoRA)") parser.add_argument("--base-model", default="meta-llama/Llama-3.1-8B-Instruct") parser.add_argument("--dataset", default="data/processed/instruction_dataset.jsonl") parser.add_argument("--output", default="checkpoints/worlddisasterlm-qlora") parser.add_argument("--epochs", type=int, default=3) parser.add_argument("--learning-rate", type=float, default=2e-4) parser.add_argument("--batch-size", type=int, default=2) parser.add_argument("--grad-accum", type=int, default=8) parser.add_argument("--lora-r", type=int, default=16) parser.add_argument("--report-to", choices=["mlflow", "wandb", "none"], default="none") return parser.parse_args() def main() -> None: args = parse_args() try: from worlddisasterlm.training.train_qlora import QLoRAConfig, train except ImportError: # Graceful fallback if GPU stack (torch/bitsandbytes) not installed logger.warning( "QLoRA dependencies not available. Using lightweight stub training. " "Install with: pip install torch bitsandbytes peft trl" ) from worlddisasterlm.training.fine_tune import TrainingConfig, run_training # type: ignore[import] run_training(TrainingConfig( base_model=args.base_model, dataset_path=args.dataset, output_dir=args.output, epochs=args.epochs, learning_rate=args.learning_rate, batch_size=args.batch_size, )) return config = QLoRAConfig( base_model=args.base_model, dataset_path=args.dataset, output_dir=args.output, epochs=args.epochs, learning_rate=args.learning_rate, per_device_train_batch_size=args.batch_size, gradient_accumulation_steps=args.grad_accum, lora_r=args.lora_r, report_to=args.report_to, ) train(config) if __name__ == "__main__": main()