WorldDisasterLM-8B / scripts /train_production.py
drdeveloper88's picture
Upload WorldDisasterLM-8B source code: FastAPI backend, training pipeline, 11-language support
495526b
Raw
History Blame Contribute Delete
3.1 kB
"""
Production QLoRA training launcher for WorldDisasterLM, with a full command-line interface.
Usage
-----
# Minimal (uses all defaults)
python scripts/train_production.py
# Full options
python scripts/train_production.py \\
--dataset data/processed/instruction_dataset.jsonl \\
--base-model meta-llama/Llama-3.1-8B-Instruct \\
--output checkpoints/worlddisasterlm-qlora \\
--epochs 3 \\
--lora-r 16 \\
--batch-size 2 \\
--grad-accum 8 \\
--report-to wandb
# Consumer GPU (e.g. RTX 4090, 24 GB VRAM): smaller micro-batches and shorter sequences
python scripts/train_production.py \\
--batch-size 1 --grad-accum 16 --max-seq-length 1024
"""
from __future__ import annotations
import argparse
import logging
# Module logger for this launcher script.
logger = logging.getLogger(__name__)
# Configure the root logger once, at import time, at INFO level with a
# "timestamp | level | message" line format.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train WorldDisasterLM with QLoRA")
# Model / data
parser.add_argument("--base-model", default="meta-llama/Llama-3.1-8B-Instruct")
parser.add_argument("--dataset", default="data/processed/instruction_dataset.jsonl")
parser.add_argument("--output", default="checkpoints/worlddisasterlm-qlora")
parser.add_argument("--max-seq-length", type=int, default=2048)
# Training
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--learning-rate", type=float, default=2e-4)
parser.add_argument("--batch-size", type=int, default=2)
parser.add_argument("--grad-accum", type=int, default=8)
parser.add_argument("--warmup-ratio", type=float, default=0.03)
# LoRA
parser.add_argument("--lora-r", type=int, default=16)
parser.add_argument("--lora-alpha", type=int, default=32)
parser.add_argument("--lora-dropout", type=float, default=0.05)
# Hardware
parser.add_argument("--no-4bit", action="store_true", help="Disable 4-bit quantization (needs more VRAM)")
parser.add_argument("--fp16", action="store_true", help="Use fp16 instead of bf16")
# Tracking
parser.add_argument("--report-to", choices=["mlflow", "wandb", "none"], default="none")
parser.add_argument("--seed", type=int, default=42)
return parser.parse_args()
def main() -> None:
    """Parse the CLI, translate it into a ``QLoRAConfig`` and launch training."""
    cli = parse_args()

    # Deferred until after argument parsing, so ``--help`` and CLI errors
    # exit without importing the training module.
    from worlddisasterlm.training.train_qlora import QLoRAConfig, train

    use_fp16 = cli.fp16
    config = QLoRAConfig(
        # Model / data
        base_model=cli.base_model,
        dataset_path=cli.dataset,
        output_dir=cli.output,
        max_seq_length=cli.max_seq_length,
        # 4-bit quantization stays on unless --no-4bit was given.
        use_4bit=not cli.no_4bit,
        # Optimization schedule
        epochs=cli.epochs,
        learning_rate=cli.learning_rate,
        warmup_ratio=cli.warmup_ratio,
        per_device_train_batch_size=cli.batch_size,
        gradient_accumulation_steps=cli.grad_accum,
        # LoRA adapter hyper-parameters
        lora_r=cli.lora_r,
        lora_alpha=cli.lora_alpha,
        lora_dropout=cli.lora_dropout,
        # Exactly one mixed-precision mode is enabled: bf16 by default, fp16 on request.
        fp16=use_fp16,
        bf16=not use_fp16,
        # Tracking / reproducibility
        report_to=cli.report_to,
        seed=cli.seed,
    )

    logger.info("Effective training config: %s", config)
    train(config)


if __name__ == "__main__":
    main()