| """Supervised fine-tuning on Oracle CEO rollouts. |
| |
| The SFT target is the exact `<action>{…}</action>` string parse_response() |
| expects, so the checkpoint trivially parses at rollout time. The model |
| inherits Oracle-grade economic decisions + perfect rogue flagging. |
| |
| Run via accelerate launch for multi-GPU DDP: |
| accelerate launch --num_processes 8 sft.py \ |
| --data /scratch/simmart-data/sft_oracle.jsonl \ |
| --model Qwen/Qwen2.5-1.5B-Instruct \ |
| --out /scratch/simmart-runs/sft-1p5b-oracle \ |
| --epochs 2 --per-device-batch 2 --grad-accum 2 --lr 2e-4 |
| """ |
|
|
from __future__ import annotations
|
|
import argparse
import json
import os
import sys
from pathlib import Path
|
|
# Make sibling modules in this directory importable.
HERE = os.path.dirname(os.path.abspath(__file__))
if HERE not in sys.path:
    sys.path.insert(0, HERE)
|
|
# unsloth must be imported before transformers/trl so its runtime patches apply.
from unsloth import FastLanguageModel
|
|
import torch
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--data", required=True, help="JSONL with {messages: [...]} records")
    p.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct")
    p.add_argument("--out", required=True)
    p.add_argument("--epochs", type=float, default=2.0)
    p.add_argument("--per-device-batch", type=int, default=2)
    p.add_argument("--grad-accum", type=int, default=2)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--max-seq-len", type=int, default=4096)
    p.add_argument("--lora-r", type=int, default=16)
    p.add_argument("--lora-alpha", type=int, default=32)
    p.add_argument("--warmup-ratio", type=float, default=0.05)
    p.add_argument("--save-steps", type=int, default=200)
    p.add_argument("--logging-steps", type=int, default=10)
    args = p.parse_args()
|
|
    # Persist the resolved run config next to the checkpoints.
    Path(args.out).mkdir(parents=True, exist_ok=True)
    with open(Path(args.out) / "sft_config.json", "w") as f:
        json.dump(vars(args), f, indent=2)
|
|
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
|
|
    if torch.cuda.is_available():
        # Pin this process to its GPU before the model is loaded.
        torch.cuda.set_device(local_rank)
|
|
    if local_rank == 0:
        print(f"[sft] model={args.model} world_size={world_size}")
        print(f"[sft] data={args.data} out={args.out}")
|
|
    # Load the base model in 4-bit (QLoRA-style) and attach LoRA adapters.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_len,
        dtype=torch.bfloat16,
        load_in_4bit=True,
        device_map={"": local_rank},
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.0,
        bias="none",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
|
|
    ds = load_dataset("json", data_files=args.data, split="train")
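    # Each record is assumed to look roughly like this (shape inferred from the
    # --data help string and the module docstring; the assistant turn is the SFT target):
    #   {"messages": [
    #     {"role": "system", "content": "..."},
    #     {"role": "user", "content": "..."},
    #     {"role": "assistant", "content": "<action>{…}</action>"}
    #   ]}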
|
|
    def fmt(ex):
        # Render the chat turns into a single training string via the model's chat template.
        text = tokenizer.apply_chat_template(
            ex["messages"], tokenize=False, add_generation_prompt=False,
        )
        return {"text": text}
|
|
    ds = ds.map(fmt, remove_columns=[c for c in ds.column_names if c != "text"])
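    # With Qwen2.5's default chat template the rendered text is ChatML-style, roughly:
    #   <|im_start|>system\n...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n
    #   <|im_start|>assistant\n<action>{…}</action><|im_end|>\n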
|
|
    if local_rank == 0:
        print(f"[sft] dataset size: {len(ds)} records")
        print("[sft] sample tokenized length:")
        lengths = [len(tokenizer(x["text"]).input_ids) for x in ds.select(range(min(32, len(ds))))]
        print(f"  min={min(lengths)} mean={sum(lengths)/len(lengths):.0f} max={max(lengths)}")
|
|
    # Rough optimizer-steps-per-epoch estimate across all DDP ranks.
    total_steps_per_epoch = len(ds) // (args.per_device_batch * args.grad_accum * world_size)
    if local_rank == 0:
        print(f"[sft] ~{total_steps_per_epoch} steps/epoch (total {int(total_steps_per_epoch * args.epochs)})")
|
|
    cfg = SFTConfig(
        output_dir=args.out,
        per_device_train_batch_size=args.per_device_batch,
        gradient_accumulation_steps=args.grad_accum,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=args.logging_steps,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=4,
        bf16=True,
        max_seq_length=args.max_seq_len,
        dataset_text_field="text",
        packing=False,
        report_to=[],
        ddp_find_unused_parameters=False,
        optim="adamw_torch",
        seed=42,
        # Unsloth already handles checkpointing via use_gradient_checkpointing above.
        gradient_checkpointing=False,
    )
|
|
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds,
        args=cfg,
    )
|
|
    trainer.train()
|
|
    # Rank 0 saves the LoRA adapter (not a merged model) alongside the tokenizer.
    if local_rank == 0:
        final = Path(args.out) / "final-adapter"
        model.save_pretrained(str(final))
        tokenizer.save_pretrained(str(final))
        print(f"[sft] saved final adapter to {final}")
|
|
|
|
if __name__ == "__main__":
    main()
|
|