WorldDisasterLM-8B / train.py
drdeveloper88's picture
Upload WorldDisasterLM-8B source code: FastAPI backend, training pipeline, 11-language support
495526b
Raw
History Blame Contribute Delete
2.42 kB
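"""train.py — top-level training entry point.

Delegates to the QLoRA production pipeline
(``worlddisasterlm.training.train_qlora``) and falls back to the lightweight
``worlddisasterlm.training.fine_tune`` trainer when the QLoRA pipeline cannot
be imported. This script exposes only a subset of the training options; for
the full CLI run:

    python scripts/train_production.py --help
"""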
from __future__ import annotations
import argparse
import logging
# Timestamped "time | LEVEL | message" layout used for all console output.
_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(message)s"

# Configure the root logger at import time; records from this script (and any
# logger that propagates to the root) are emitted at INFO and above.
logging.basicConfig(
    format=_LOG_FORMAT,
    level=logging.INFO,
)

# Logger for this module.
logger = logging.getLogger(__name__)
def _positive_int(value: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {number}")
    return number


def _positive_float(value: str) -> float:
    """argparse ``type=`` converter that accepts only finite floats > 0."""
    try:
        number = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a number, got {value!r}") from None
    # The chained comparison also rejects NaN, which compares False to everything.
    if not 0 < number < float("inf"):
        raise argparse.ArgumentTypeError(f"must be a finite number > 0, got {value!r}")
    return number


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the WorldDisasterLM-8B training run.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        Namespace with ``base_model``, ``dataset``, ``output``, ``epochs``,
        ``learning_rate``, ``batch_size``, ``grad_accum``, ``lora_r`` and
        ``report_to`` attributes.

    Invalid input exits with status 2 via argparse. This includes
    non-positive epochs, batch size, gradient-accumulation steps, LoRA rank,
    and a learning rate that is non-positive or non-finite.
    """
    parser = argparse.ArgumentParser(
        description="Train WorldDisasterLM-8B (QLoRA)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--base-model",
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="Base model to fine-tune.",
    )
    parser.add_argument(
        "--dataset",
        default="data/processed/instruction_dataset.jsonl",
        help="Path to the instruction dataset file.",
    )
    parser.add_argument(
        "--output",
        default="checkpoints/worlddisasterlm-qlora",
        help="Output directory for training artifacts.",
    )
    parser.add_argument(
        "--epochs", type=_positive_int, default=3, help="Number of training epochs."
    )
    parser.add_argument(
        "--learning-rate", type=_positive_float, default=2e-4, help="Learning rate."
    )
    parser.add_argument(
        "--batch-size", type=_positive_int, default=2, help="Per-device training batch size."
    )
    parser.add_argument(
        "--grad-accum",
        type=_positive_int,
        default=8,
        help="Gradient accumulation steps (QLoRA path only).",
    )
    parser.add_argument(
        "--lora-r", type=_positive_int, default=16, help="LoRA rank (QLoRA path only)."
    )
    parser.add_argument(
        "--report-to",
        choices=["mlflow", "wandb", "none"],
        default="none",
        help="Experiment tracker to report to (QLoRA path only).",
    )
    return parser.parse_args(argv)
def _run_fallback_training(args: argparse.Namespace) -> None:
    """Train with the lightweight ``fine_tune`` pipeline instead of QLoRA.

    Only the options passed to ``TrainingConfig`` below are forwarded. The
    QLoRA-specific options are reported in a warning rather than silently
    dropped.

    Args:
        args: Parsed CLI namespace from ``parse_args``.
    """
    from worlddisasterlm.training.fine_tune import TrainingConfig, run_training  # type: ignore[import]

    # These options are not forwarded to the fallback TrainingConfig; make
    # that visible so a user who set them is not misled.
    logger.warning(
        "Fallback trainer does not use --grad-accum=%s, --lora-r=%s or --report-to=%s",
        args.grad_accum,
        args.lora_r,
        args.report_to,
    )
    run_training(
        TrainingConfig(
            base_model=args.base_model,
            dataset_path=args.dataset,
            output_dir=args.output,
            epochs=args.epochs,
            learning_rate=args.learning_rate,
            batch_size=args.batch_size,
        )
    )


def main() -> None:
    """Parse CLI options and run training.

    Uses the QLoRA pipeline (``worlddisasterlm.training.train_qlora``) when it
    can be imported. If that import raises ``ImportError`` (for example because
    torch/bitsandbytes are not installed), training falls back to the
    lightweight ``worlddisasterlm.training.fine_tune`` trainer.
    """
    args = parse_args()
    try:
        from worlddisasterlm.training.train_qlora import QLoRAConfig, train
    except ImportError as exc:
        # Graceful fallback if the GPU stack is missing. The original error is
        # logged because an ImportError caused by a bug inside train_qlora would
        # otherwise look exactly like a missing optional dependency.
        logger.warning(
            "QLoRA dependencies not available (%s). Using lightweight stub training. "
            "Install with: pip install torch bitsandbytes peft trl",
            exc,
        )
        _run_fallback_training(args)
        return

    config = QLoRAConfig(
        base_model=args.base_model,
        dataset_path=args.dataset,
        output_dir=args.output,
        epochs=args.epochs,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        lora_r=args.lora_r,
        report_to=args.report_to,
    )
    train(config)
if __name__ == "__main__":
    # Start training only when run as a script (``python train.py``),
    # not when this module is imported.
    main()