opensoc-env / train /sft_warmstart.py
shivam2k3's picture
sft_warmstart: import unsloth first; batched formatting_func
99d0d29
"""SFT warm-start trainer for the defender role.
Designed to run in Colab or Kaggle with a single T4/L4 GPU. Locally on
CPU it would take hours; we therefore make this a *script* (rather than
a notebook cell) so it can be run with `python -m train.sft_warmstart` on
a GPU machine, and the README points reviewers at the matching Colab.
The model: Qwen2.5-3B-Instruct (good float16 instruction follower under
8GB VRAM with Unsloth's 4-bit loader). The adapter weights are saved to
``checkpoints/defender_sft_adapter`` and re-loaded at the top of the
GRPO notebook.
Why we do this *before* GRPO: with a cold model, P(format-compliant
response) ~= 0 on the first GRPO batch, which means the verifier reward
sees only the format-violation floor. The ~600 SFT examples push that
to >=80%, which is the bare minimum for GRPO to find gradient.
"""
from __future__ import annotations
import argparse
import os
import sys
# Absolute path of the directory that contains this file.
_HERE = os.path.dirname(os.path.abspath(__file__))
# Put the parent of that directory at the front of sys.path.
# NOTE(review): presumably so repo-level packages resolve when this file
# is run directly as a script -- confirm against the repo layout.
sys.path.insert(0, os.path.dirname(_HERE))
# Unsloth must be imported BEFORE trl/transformers/peft so its
# monkey-patches (incl. the eos_token resolution and the SFTTrainer
# formatting_func validator) actually apply. A late import silently
# leaves SFTConfig's eos_token as the literal placeholder "<EOS_TOKEN>".
try:
    import unsloth # noqa: F401
except ImportError:
    # Deliberately ignored here: main() imports unsloth again and exits
    # with install instructions if it (or trl/datasets) is missing.
    pass
def main() -> None:
    """Run the defender SFT warm-start and save the LoRA adapter.

    Parses command-line options, loads the base model in 4-bit through
    Unsloth, attaches a LoRA adapter, fine-tunes it on the chat-format
    JSONL dataset with trl's ``SFTTrainer``, and writes the adapter and
    tokenizer to the output directory.

    Command-line options:
        --data         Dataset path, joined onto the parent of this
                       file's directory (default ``data/sft_train.jsonl``).
        --model        Model name passed to Unsloth's loader.
        --epochs       Number of training epochs.
        --batch-size   Per-device train batch size.
        --grad-accum   Gradient accumulation steps.
        --lr           Learning rate.
        --max-seq-len  Maximum sequence length for loading and training.
        --out          Output directory, joined onto the parent of this
                       file's directory.

    Raises:
        SystemExit: if datasets, trl, or unsloth cannot be imported.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", default="data/sft_train.jsonl")
    parser.add_argument("--model", default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--max-seq-len", type=int, default=2048)
    parser.add_argument("--out", default="checkpoints/defender_sft_adapter")
    args = parser.parse_args()

    # Heavy dependencies are imported lazily so that a machine without
    # them gets install instructions instead of a bare traceback.
    try:
        from datasets import load_dataset
        from trl import SFTConfig, SFTTrainer
        from unsloth import FastLanguageModel, is_bfloat16_supported
    except ImportError as exc:
        sys.exit(
            "This script requires unsloth, transformers, trl, and datasets, "
            "which are GPU-only deps. Install with:\n"
            " pip install --upgrade pip\n"
            " pip install 'unsloth @ git+https://github.com/unslothai/unsloth.git' unsloth_zoo\n"
            " pip install --no-deps trl peft accelerate bitsandbytes\n"
            f"(import failed with: {exc})"
        )

    # Relative --data / --out values are resolved against the parent of
    # this file's directory.
    repo_root = os.path.dirname(_HERE)
    # End-of-turn token used both as the tokenizer fallback and in SFTConfig.
    eos_token = "<|im_end|>"

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_len,
        dtype=None,
        load_in_4bit=True,
    )
    # unsloth occasionally hands back a tokenizer whose eos_token is the
    # literal placeholder "<EOS_TOKEN>" (not in the Qwen2.5 vocab); fix it
    # before trl's SFTConfig vocab check rejects it.
    if tokenizer.eos_token in (None, "<EOS_TOKEN>"):
        tokenizer.eos_token = eos_token

    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=32,
        lora_dropout=0.0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    data_path = os.path.join(repo_root, args.data)
    ds = load_dataset("json", data_files=data_path, split="train")

    def formatting_func(batch):
        """Render ``batch["messages"]`` to chat-template text, as list[str]."""
        # trl 0.24's SFTTrainer (with unsloth's wrapper) calls
        # formatting_func with a *batched* dict-of-lists and expects a
        # list[str] back; older trl called it per-example.
        messages = batch.get("messages")
        is_batched = (
            isinstance(messages, list)
            and bool(messages)
            and isinstance(messages[0], list)
        )
        # A single conversation is wrapped so both shapes render the same way.
        conversations = messages if is_batched else [batch["messages"]]
        return [
            tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=False
            )
            for conversation in conversations
        ]

    # The T4 named in the module docstring has no bfloat16 support, so a
    # hard-coded bf16=True is rejected there; fall back to fp16 instead.
    use_bf16 = is_bfloat16_supported()

    out_dir = os.path.join(repo_root, args.out)
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=ds,
        formatting_func=formatting_func,
        args=SFTConfig(
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            num_train_epochs=args.epochs,
            learning_rate=args.lr,
            logging_steps=10,
            warmup_steps=20,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="cosine",
            seed=3407,
            output_dir=out_dir,
            save_strategy="epoch",
            report_to="none",
            bf16=use_bf16,
            fp16=not use_bf16,
            max_length=args.max_seq_len,
            packing=False,
            eos_token=eos_token,
        ),
    )
    trainer.train()

    # Persist the adapter and tokenizer side by side in the output directory.
    model.save_pretrained(out_dir)
    tokenizer.save_pretrained(out_dir)
    print(f"Saved adapter to {out_dir}")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()