File size: 2,416 Bytes
495526b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""train.py — top-level entry-point, delegates to QLoRA production pipeline.

For full CLI options use:
    python scripts/train_production.py --help
"""

from __future__ import annotations

import argparse
import logging

_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(message)s"

# Configure the root logger once, at import time: INFO and above, laid out as
# "timestamp | level | message". Module loggers propagate to the root logger,
# so they share this handler. basicConfig is a no-op if the root logger
# already has handlers.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train WorldDisasterLM-8B (QLoRA)")
    parser.add_argument("--base-model", default="meta-llama/Llama-3.1-8B-Instruct")
    parser.add_argument("--dataset", default="data/processed/instruction_dataset.jsonl")
    parser.add_argument("--output", default="checkpoints/worlddisasterlm-qlora")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--learning-rate", type=float, default=2e-4)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--grad-accum", type=int, default=8)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--report-to", choices=["mlflow", "wandb", "none"], default="none")
    return parser.parse_args()


def main() -> None:
    """Parse CLI arguments and run training.

    Training uses the QLoRA pipeline from
    ``worlddisasterlm.training.train_qlora`` when that module can be imported.
    If the import fails, the underlying error is logged and training falls
    back to ``run_training`` from ``worlddisasterlm.training.fine_tune``. The
    fallback receives only the base model, dataset path, output directory,
    epochs, learning rate and batch size. The QLoRA-only options are reported
    as ignored.
    """
    args = parse_args()

    try:
        from worlddisasterlm.training.train_qlora import QLoRAConfig, train
    except ImportError as exc:
        # Graceful fallback if the GPU stack (torch/bitsandbytes) is not
        # installed. Log the original error: an ImportError raised by a bug
        # inside train_qlora would otherwise look identical to a missing
        # optional dependency.
        logger.warning(
            "QLoRA dependencies not available (%s). Using lightweight stub training. "
            "Install with: pip install torch bitsandbytes peft trl",
            exc,
        )
        # TrainingConfig below is not given these options, so say so instead
        # of silently dropping values the user may have set explicitly.
        logger.warning(
            "Stub training ignores --grad-accum=%d, --lora-r=%d and --report-to=%s.",
            args.grad_accum,
            args.lora_r,
            args.report_to,
        )
        from worlddisasterlm.training.fine_tune import TrainingConfig, run_training  # type: ignore[import]

        run_training(
            TrainingConfig(
                base_model=args.base_model,
                dataset_path=args.dataset,
                output_dir=args.output,
                epochs=args.epochs,
                learning_rate=args.learning_rate,
                batch_size=args.batch_size,
            )
        )
        return

    config = QLoRAConfig(
        base_model=args.base_model,
        dataset_path=args.dataset,
        output_dir=args.output,
        epochs=args.epochs,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        lora_r=args.lora_r,
        report_to=args.report_to,
    )
    train(config)


if __name__ == "__main__":
    # Run training only when this file is executed as a script, not on import.
    main()