"""
EHRGym SFT Training Script (Unsloth)
=======================================
Fine-tunes any HF chat model on the exported EHRGym trajectory dataset
using Unsloth + QLoRA for 2x faster training and 70% less VRAM.
Usage (on H100):
python scripts/train_sft.py \
--dataset runs/datasets/ehrgym_sft.jsonl \
--model unsloth/Qwen2.5-7B-Instruct \
--output runs/checkpoints/ehrgym-sft \
--epochs 3 --lr 2e-4 --batch-size 4 --grad-accum 4
Quick smoke test (tiny model):
python scripts/train_sft.py \
--dataset runs/datasets/ehrgym_sft.jsonl \
--model unsloth/Qwen2.5-0.5B-Instruct \
--output runs/checkpoints/ehrgym-sft-tiny \
--epochs 1 --batch-size 2 --lora-r 16
Full 16-bit LoRA (no 4-bit quant):
python scripts/train_sft.py ... --no-4bit
Save merged model for vLLM:
python scripts/train_sft.py ... --save-method merged_16bit
"""
from __future__ import annotations

import argparse
import json
import logging
import os
from pathlib import Path
from typing import Any

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="EHRGym SFT fine-tuning (Unsloth)")
    p.add_argument("--dataset", required=True, help="Path to ehrgym_sft.jsonl")
    p.add_argument("--model", default="unsloth/Qwen2.5-7B-Instruct",
                   help="HuggingFace model ID (use the unsloth/ prefix for optimised 4-bit quants)")
    p.add_argument("--output", default="runs/checkpoints/ehrgym-sft",
                   help="Output directory for checkpoints")
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--batch-size", type=int, default=4, help="Per-device train batch size")
    p.add_argument("--grad-accum", type=int, default=4, help="Gradient accumulation steps")
    p.add_argument("--max-seq-len", type=int, default=2048, help="Max sequence length")
    p.add_argument("--lora-r", type=int, default=64, help="LoRA rank")
    p.add_argument("--lora-alpha", type=int, default=128,
                   help="LoRA alpha (default 128, i.e. 2 * the default lora_r)")
    p.add_argument("--lora-dropout", type=float, default=0.0,
                   help="LoRA dropout (Unsloth's optimised kernels work best with 0)")
    p.add_argument("--no-4bit", action="store_true",
                   help="Disable 4-bit quantisation; use 16-bit LoRA instead")
    p.add_argument("--seed", type=int, default=3407)
    p.add_argument("--logging-steps", type=int, default=1)
    p.add_argument("--save-steps", type=int, default=50)
    p.add_argument("--max-steps", type=int, default=-1,
                   help="Override the epoch count with a fixed number of steps (useful for smoke tests)")
    p.add_argument("--save-method", default="lora",
                   choices=["lora", "merged_16bit", "merged_4bit", "mxfp4"],
                   help="How to save the final model")
    p.add_argument("--wandb-project", default=None,
                   help="Weights & Biases project name (optional)")
    return p.parse_args()

# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def load_sft_jsonl(path: str | Path) -> list[dict[str, Any]]:
    """Load JSONL and return a list of dicts with a 'messages' column."""
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            # Keep only the chat messages
            rows.append({"messages": row["messages"]})
    log.info("Loaded %d examples from %s", len(rows), path)
    return rows

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    args = parse_args()

    # ---- Imports (Unsloth patches HF on import, so import it first) ----
    from unsloth import FastLanguageModel
    from unsloth.chat_templates import get_chat_template

    import torch
    from datasets import Dataset
    from trl import SFTTrainer, SFTConfig

    # ---- W&B ----
    if args.wandb_project:
        os.environ["WANDB_PROJECT"] = args.wandb_project
        report_to = "wandb"
    else:
        report_to = "none"
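    # NOTE: report_to="wandb" assumes the `wandb` package is installed and
    # you are logged in (via `wandb login` or WANDB_API_KEY).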
    # ---- Load model + tokenizer via Unsloth ----
    load_in_4bit = not args.no_4bit
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_len,
        load_in_4bit=load_in_4bit,
        dtype=None,  # auto-detect (bf16 on H100)
    )
    log.info("Loaded %s (4-bit=%s)", args.model, load_in_4bit)

    # ---- Apply LoRA adapter via Unsloth ----
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        use_gradient_checkpointing="unsloth",  # Unsloth's optimised checkpointing
        random_state=args.seed,
    )
    model.print_trainable_parameters()

    # ---- Chat template ----
    # Our JSONL uses {"role": "system/user/assistant", "content": "..."} (ChatML-compatible).
    # Detect the right template from the model name.
    model_lower = args.model.lower()
    if "qwen" in model_lower:
        chat_template = "qwen-2.5"
    elif "llama-3" in model_lower or "llama3" in model_lower:
        chat_template = "llama-3.1"
    elif "phi" in model_lower:
        chat_template = "phi-4"
    elif "gemma" in model_lower:
        chat_template = "gemma-3"
    elif "mistral" in model_lower:
        chat_template = "mistral"
    else:
        chat_template = "chatml"
    tokenizer = get_chat_template(tokenizer, chat_template=chat_template)
    log.info("Using chat template: %s", chat_template)
    # ---- Dataset ----
    raw = load_sft_jsonl(args.dataset)
    dataset = Dataset.from_list(raw)

    def formatting_func(examples):
        convos = examples["messages"]
        texts = [
            tokenizer.apply_chat_template(
                convo, tokenize=False, add_generation_prompt=False
            )
            for convo in convos
        ]
        return {"text": texts}

    dataset = dataset.map(formatting_func, batched=True)
    log.info("Dataset ready: %d examples", len(dataset))
    # ---- Training config ----
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    training_args = SFTConfig(
        output_dir=str(output_dir),
        num_train_epochs=args.epochs,
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        bf16=True,
        fp16=False,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=3,
        seed=args.seed,
        report_to=report_to,
        max_seq_length=args.max_seq_len,
        dataset_text_field="text",
        optim="adamw_8bit",
        weight_decay=0.01,
        packing=False,  # disable packing; our examples are short enough
    )
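    # Effective batch size = per-device batch * gradient accumulation steps
    # (4 * 4 = 16 sequences per optimiser step with the defaults).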
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset,
    )

    # ---- Train ----
    log.info("Starting training …")
    trainer_stats = trainer.train()
    log.info("Training complete. Loss: %.4f", trainer_stats.training_loss)
    # ---- Save ----
    if args.save_method == "lora":
        save_dir = str(output_dir / "lora_adapter")
        model.save_pretrained(save_dir)
        tokenizer.save_pretrained(save_dir)
        log.info("Saved LoRA adapter → %s", save_dir)
    else:
        save_dir = str(output_dir / "merged")
        model.save_pretrained_merged(save_dir, tokenizer, save_method=args.save_method)
        log.info("Saved merged model (%s) → %s", args.save_method, save_dir)
log.info("Done ✓")
if __name__ == "__main__":
main()