"""Supervised fine-tuning on Oracle CEO rollouts.
The SFT target is the exact `<action>{…}</action>` string parse_response()
expects, so the checkpoint trivially parses at rollout time. The model
inherits Oracle-grade economic decisions + perfect rogue flagging.
Run via accelerate launch for multi-GPU DDP:
accelerate launch --num_processes 8 sft.py \
--data /scratch/simmart-data/sft_oracle.jsonl \
--model Qwen/Qwen2.5-1.5B-Instruct \
--out /scratch/simmart-runs/sft-1p5b-oracle \
--epochs 2 --per-device-batch 2 --grad-accum 2 --lr 2e-4
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
HERE = os.path.dirname(os.path.abspath(__file__))
if HERE not in sys.path:
    sys.path.insert(0, HERE)
# Unsloth must be imported before transformers/peft to install its patches
from unsloth import FastLanguageModel # noqa: E402
import torch # noqa: E402
from datasets import load_dataset # noqa: E402
from trl import SFTConfig, SFTTrainer # noqa: E402
def main():
    p = argparse.ArgumentParser()
    p.add_argument("--data", required=True, help="JSONL with {messages: [...]} records")
    p.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct")
    p.add_argument("--out", required=True)
    p.add_argument("--epochs", type=float, default=2.0)
    p.add_argument("--per-device-batch", type=int, default=2)
    p.add_argument("--grad-accum", type=int, default=2)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--max-seq-len", type=int, default=4096)
    p.add_argument("--lora-r", type=int, default=16)
    p.add_argument("--lora-alpha", type=int, default=32)
    p.add_argument("--warmup-ratio", type=float, default=0.05)
    p.add_argument("--save-steps", type=int, default=200)
    p.add_argument("--logging-steps", type=int, default=10)
    args = p.parse_args()

    Path(args.out).mkdir(parents=True, exist_ok=True)
    with open(Path(args.out) / "sft_config.json", "w") as f:
        json.dump(vars(args), f, indent=2)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
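    # Assumed launcher behaviour: `accelerate launch` (like torchrun) spawns one
    # process per GPU and exports LOCAL_RANK / WORLD_SIZE into each process, so
    # the fallbacks above (0 and 1) only apply to plain single-process runs.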
    # Pin this rank to its local GPU before loading any CUDA tensors,
    # otherwise Unsloth / bitsandbytes puts everything on cuda:0 and DDP
    # blows up with "loaded on different device".
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device_str = f"cuda:{local_rank}"

    if local_rank == 0:
        print(f"[sft] model={args.model} world_size={world_size}")
        print(f"[sft] data={args.data} out={args.out}")

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_len,
        dtype=torch.bfloat16,
        load_in_4bit=True,
        device_map={"": local_rank},
    )
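    # Memory note (general QLoRA behaviour, hedged rather than measured here):
    # load_in_4bit keeps the frozen base weights in bitsandbytes 4-bit format while
    # activations and the LoRA adapter added below train in bf16, which is what lets
    # the 1.5B model plus optimizer state fit on a single GPU per rank.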
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.0,
        bias="none",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
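    # Background note (standard LoRA, not specific to this repo): each targeted
    # projection W is effectively trained as W + (lora_alpha / r) * B @ A with
    # low-rank factors A and B, so the defaults above (r=16, alpha=32) scale the
    # adapter update by a factor of 2 while the 4-bit base weights stay frozen.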
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ds = load_dataset("json", data_files=args.data, split="train")

    def fmt(ex):
        text = tokenizer.apply_chat_template(
            ex["messages"], tokenize=False, add_generation_prompt=False,
        )
        return {"text": text}

    ds = ds.map(fmt, remove_columns=[c for c in ds.column_names if c != "text"])
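    # Rough shape of the rendered text for Qwen2.5-Instruct's ChatML-style template
    # (illustrative, not an exact transcript of the template):
    #
    #   <|im_start|>system\n...<|im_end|>\n
    #   <|im_start|>user\n...<|im_end|>\n
    #   <|im_start|>assistant\n<action>{...}</action><|im_end|>\n
    #
    # add_generation_prompt=False keeps the assistant turn (the SFT target) inside
    # the string instead of leaving an open assistant header.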
    if local_rank == 0:
        print(f"[sft] dataset size: {len(ds)} records")
        print("[sft] sample tokenized lengths:")
        lengths = [len(tokenizer(x["text"]).input_ids) for x in ds.select(range(min(32, len(ds))))]
        print(f"    min={min(lengths)} mean={sum(lengths)/len(lengths):.0f} max={max(lengths)}")

    total_steps_per_epoch = len(ds) // (args.per_device_batch * args.grad_accum * world_size)
    if local_rank == 0:
        print(f"[sft] ~{total_steps_per_epoch} steps/epoch (total {int(total_steps_per_epoch * args.epochs)})")
    cfg = SFTConfig(
        output_dir=args.out,
        per_device_train_batch_size=args.per_device_batch,
        gradient_accumulation_steps=args.grad_accum,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=args.logging_steps,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=4,
        bf16=True,
        max_seq_length=args.max_seq_len,
        dataset_text_field="text",
        packing=False,
        report_to=[],
        ddp_find_unused_parameters=False,
        optim="adamw_torch",
        seed=42,
        gradient_checkpointing=False,  # Unsloth handles this via use_gradient_checkpointing
    )
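    # Scheduler note: with warmup_ratio=0.05 the HF Trainer warms the learning rate
    # up linearly over the first ~5% of optimizer steps (~31 of the ~624 steps in the
    # worked example above), then follows a cosine decay toward ~0 by the final step.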
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds,
        args=cfg,
    )
    trainer.train()

    if local_rank == 0:
        final = Path(args.out) / "final-adapter"
        model.save_pretrained(str(final))
        tokenizer.save_pretrained(str(final))
        print(f"[sft] saved final adapter to {final}")
if __name__ == "__main__":
main()
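# Rollout-time loading sketch (an assumption, not part of this script): the saved
# directory holds a PEFT LoRA adapter, so one way to consume it is to stack it on
# the same base model with peft, e.g.:
#
#   from peft import PeftModel
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
#   model = PeftModel.from_pretrained(base, "<out>/final-adapter")
#   tokenizer = AutoTokenizer.from_pretrained("<out>/final-adapter")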