File size: 3,096 Bytes
495526b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
Production training launcher with full CLI.

Usage
-----
  # Minimal (uses all defaults)
  python scripts/train_production.py

  # Full options
  python scripts/train_production.py \\
      --dataset data/processed/instruction_dataset.jsonl \\
      --base-model meta-llama/Llama-3.1-8B-Instruct \\
      --output checkpoints/worlddisasterlm-qlora \\
      --epochs 3 \\
      --lora-r 16 \\
      --batch-size 2 \\
      --grad-accum 8 \\
      --report-to wandb

  # Consumer GPU (RTX 4090 24 GB)
  python scripts/train_production.py \\
      --batch-size 1 --grad-accum 16 --max-seq-length 1024
"""

from __future__ import annotations

import argparse
import logging

# Log line layout: "<timestamp> | <LEVEL> | <message>".
_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(message)s"

# Root logging is configured when this script module is loaded, at INFO level,
# so progress messages from this launcher are visible on the console.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train WorldDisasterLM with QLoRA")

    # Model / data
    parser.add_argument("--base-model", default="meta-llama/Llama-3.1-8B-Instruct")
    parser.add_argument("--dataset", default="data/processed/instruction_dataset.jsonl")
    parser.add_argument("--output", default="checkpoints/worlddisasterlm-qlora")
    parser.add_argument("--max-seq-length", type=int, default=2048)

    # Training
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--learning-rate", type=float, default=2e-4)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--grad-accum", type=int, default=8)
    parser.add_argument("--warmup-ratio", type=float, default=0.03)

    # LoRA
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.05)

    # Hardware
    parser.add_argument("--no-4bit", action="store_true", help="Disable 4-bit quantization (needs more VRAM)")
    parser.add_argument("--fp16", action="store_true", help="Use fp16 instead of bf16")

    # Tracking
    parser.add_argument("--report-to", choices=["mlflow", "wandb", "none"], default="none")
    parser.add_argument("--seed", type=int, default=42)

    return parser.parse_args()


def main() -> None:
    """Build a ``QLoRAConfig`` from the command line and run ``train`` with it.

    Each CLI option maps onto one ``QLoRAConfig`` keyword. The only derived
    values are ``use_4bit`` (the inverse of ``--no-4bit``) and the
    bf16/fp16 pair: ``bf16`` is on unless ``--fp16`` is given.
    """
    cli = parse_args()

    # Imported only after argument parsing succeeds, so ``--help`` and
    # invalid arguments exit before the training module is imported.
    from worlddisasterlm.training.train_qlora import QLoRAConfig, train

    use_fp16 = cli.fp16
    settings = dict(
        base_model=cli.base_model,
        output_dir=cli.output,
        dataset_path=cli.dataset,
        max_seq_length=cli.max_seq_length,
        use_4bit=not cli.no_4bit,
        epochs=cli.epochs,
        learning_rate=cli.learning_rate,
        per_device_train_batch_size=cli.batch_size,
        gradient_accumulation_steps=cli.grad_accum,
        warmup_ratio=cli.warmup_ratio,
        lora_r=cli.lora_r,
        lora_alpha=cli.lora_alpha,
        lora_dropout=cli.lora_dropout,
        bf16=not use_fp16,
        fp16=use_fp16,
        report_to=cli.report_to,
        seed=cli.seed,
    )
    config = QLoRAConfig(**settings)

    logger.info("Effective training config: %s", config)
    train(config)


# Script entry point: importing this module defines the CLI without starting a run.
if __name__ == "__main__":
    main()