"""Supervised fine-tuning on Oracle CEO rollouts. The SFT target is the exact `{…}` string parse_response() expects, so the checkpoint trivially parses at rollout time. The model inherits Oracle-grade economic decisions + perfect rogue flagging. Run via accelerate launch for multi-GPU DDP: accelerate launch --num_processes 8 sft.py \ --data /scratch/simmart-data/sft_oracle.jsonl \ --model Qwen/Qwen2.5-1.5B-Instruct \ --out /scratch/simmart-runs/sft-1p5b-oracle \ --epochs 2 --per-device-batch 2 --grad-accum 2 --lr 2e-4 """ from __future__ import annotations import argparse import json import os import sys from pathlib import Path HERE = os.path.dirname(os.path.abspath(__file__)) if HERE not in sys.path: sys.path.insert(0, HERE) # Unsloth must be imported before transformers/peft to install its patches from unsloth import FastLanguageModel # noqa: E402 import torch # noqa: E402 from datasets import load_dataset # noqa: E402 from trl import SFTConfig, SFTTrainer # noqa: E402 def main(): p = argparse.ArgumentParser() p.add_argument("--data", required=True, help="JSONL with {messages: [...]} records") p.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct") p.add_argument("--out", required=True) p.add_argument("--epochs", type=float, default=2.0) p.add_argument("--per-device-batch", type=int, default=2) p.add_argument("--grad-accum", type=int, default=2) p.add_argument("--lr", type=float, default=2e-4) p.add_argument("--max-seq-len", type=int, default=4096) p.add_argument("--lora-r", type=int, default=16) p.add_argument("--lora-alpha", type=int, default=32) p.add_argument("--warmup-ratio", type=float, default=0.05) p.add_argument("--save-steps", type=int, default=200) p.add_argument("--logging-steps", type=int, default=10) args = p.parse_args() Path(args.out).mkdir(parents=True, exist_ok=True) with open(Path(args.out) / "sft_config.json", "w") as f: json.dump(vars(args), f, indent=2) local_rank = int(os.environ.get("LOCAL_RANK", 0)) world_size = int(os.environ.get("WORLD_SIZE", 1)) # Pin this rank to its local GPU before loading any CUDA tensors, # otherwise Unsloth / bitsandbytes puts everything on cuda:0 and DDP # blows up with "loaded on different device". 
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    if local_rank == 0:
        print(f"[sft] model={args.model} world_size={world_size}")
        print(f"[sft] data={args.data} out={args.out}")

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_len,
        dtype=torch.bfloat16,
        load_in_4bit=True,
        device_map={"": local_rank},
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.0,
        bias="none",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ds = load_dataset("json", data_files=args.data, split="train")

    # Render each conversation into a single training string via the chat template.
    def fmt(ex):
        text = tokenizer.apply_chat_template(
            ex["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
        return {"text": text}

    ds = ds.map(fmt, remove_columns=[c for c in ds.column_names if c != "text"])

    if local_rank == 0:
        print(f"[sft] dataset size: {len(ds)} records")
        print("[sft] sample tokenized lengths (first 32 records):")
        lengths = [len(tokenizer(x["text"]).input_ids) for x in ds.select(range(min(32, len(ds))))]
        print(f"  min={min(lengths)} mean={sum(lengths)/len(lengths):.0f} max={max(lengths)}")

    # Rough optimizer-step count, for sanity-checking the schedule.
    total_steps_per_epoch = len(ds) // (args.per_device_batch * args.grad_accum * world_size)
    if local_rank == 0:
        print(f"[sft] ~{total_steps_per_epoch} steps/epoch "
              f"(total {int(total_steps_per_epoch * args.epochs)})")

    cfg = SFTConfig(
        output_dir=args.out,
        per_device_train_batch_size=args.per_device_batch,
        gradient_accumulation_steps=args.grad_accum,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=args.logging_steps,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=4,
        bf16=True,
        max_seq_length=args.max_seq_len,
        dataset_text_field="text",
        packing=False,
        report_to=[],
        ddp_find_unused_parameters=False,
        optim="adamw_torch",
        seed=42,
        gradient_checkpointing=False,  # Unsloth handles this via use_gradient_checkpointing
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds,
        args=cfg,
    )

    trainer.train()

    if local_rank == 0:
        final = Path(args.out) / "final-adapter"
        model.save_pretrained(str(final))
        tokenizer.save_pretrained(str(final))
        print(f"[sft] saved final adapter to {final}")


if __name__ == "__main__":
    main()
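
# Reloading the adapter for rollouts, a minimal sketch assuming the same Unsloth
# stack and the output layout above (the path is illustrative):
#
#   model, tokenizer = FastLanguageModel.from_pretrained(
#       "/scratch/simmart-runs/sft-1p5b-oracle/final-adapter",
#       max_seq_length=4096,
#       load_in_4bit=True,
#   )
#   FastLanguageModel.for_inference(model)  # switch to Unsloth's generation path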