"""SFT warm-start trainer for the defender role.

Designed to run in Colab or Kaggle on a single T4/L4 GPU. On a local CPU it
would take hours, so this is a *script* (rather than a notebook cell) that can
be launched with ``python -m train.sft_warmstart`` from a GPU machine; the
README points reviewers at the matching Colab.

The model: Qwen2.5-3B-Instruct (a good float16 instruction follower that fits
under 8 GB of VRAM with Unsloth's 4-bit loader). The adapter weights are saved
to ``checkpoints/defender_sft_adapter`` and re-loaded at the top of the GRPO
notebook.

Why we do this *before* GRPO: with a cold model, P(format-compliant response)
~= 0 on the first GRPO batch, which means the verifier reward sees only the
format-violation floor. The ~600 SFT examples push that to >=80%, which is the
bare minimum for GRPO to find a gradient.
"""

from __future__ import annotations

import argparse
import os
import sys

# Directory containing this script. Its parent (presumably the repo root, given
# the ``train.sft_warmstart`` module path) is put first on sys.path so project
# packages resolve regardless of the current working directory; main() also
# resolves relative --data/--out paths against that same parent directory.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_HERE))

# Unsloth must be imported BEFORE trl/transformers/peft so its
# monkey-patches (incl. the eos_token resolution and the SFTTrainer
# formatting_func validator) actually apply. A late import silently
# leaves SFTConfig's eos_token as a literal placeholder string
# (e.g. "<EOS_TOKEN>") instead of a real vocab token.
try:
    # Imported first purely for its side effects (Unsloth patches
    # trl/transformers on import). Optional here so that argument parsing and
    # ``--help`` still work on machines without the GPU stack.
    import unsloth  # noqa: F401
except ImportError:
    pass


def main() -> None:
    """Fine-tune a LoRA adapter on the defender SFT data and save it.

    Parses CLI flags, loads ``--model`` in 4-bit through Unsloth, attaches a
    rank-16 LoRA adapter to the q/k/v/o and gate/up/down projections, trains
    it with trl's ``SFTTrainer`` on the JSONL file at ``--data`` (each record
    must carry a ``messages`` chat transcript), and writes the adapter plus
    tokenizer to ``--out``. Relative ``--data``/``--out`` paths are resolved
    against the parent of this script's directory, not the CWD.

    Exits via ``sys.exit`` with an installation hint if the GPU-only
    dependencies (unsloth, trl, datasets) cannot be imported.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", default="data/sft_train.jsonl")
    parser.add_argument("--model", default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--max-seq-len", type=int, default=2048)
    parser.add_argument("--out", default="checkpoints/defender_sft_adapter")
    args = parser.parse_args()

    # Heavy, GPU-only deps are imported lazily, after argument parsing.
    try:
        from datasets import load_dataset
        from trl import SFTConfig, SFTTrainer
        from unsloth import FastLanguageModel, is_bfloat16_supported
    except ImportError as exc:
        sys.exit(
            "This script requires unsloth, transformers, trl, and datasets, "
            "which are GPU-only deps. Install with:\n"
            "  pip install --upgrade pip\n"
            "  pip install 'unsloth @ git+https://github.com/unslothai/unsloth.git' unsloth_zoo\n"
            "  pip install --no-deps trl peft accelerate bitsandbytes\n"
            f"(import failed with: {exc})"
        )

    # End-of-turn token of the Qwen2.5 chat template used by the default model.
    chat_eos = "<|im_end|>"

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_len,
        dtype=None,  # let Unsloth auto-select float16 / bfloat16 for the GPU
        load_in_4bit=True,
    )

    # Unsloth occasionally hands back a tokenizer whose eos_token is a
    # placeholder that is not in the Qwen2.5 vocab (missing, empty, or the
    # literal "<EOS_TOKEN>"); fix it before trl's SFTConfig vocab check
    # rejects it.
    if tokenizer.eos_token in (None, "", "<EOS_TOKEN>"):
        tokenizer.eos_token = chat_eos

    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=32,
        lora_dropout=0.0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    repo_root = os.path.dirname(_HERE)
    data_path = os.path.join(repo_root, args.data)
    out_dir = os.path.join(repo_root, args.out)
    ds = load_dataset("json", data_files=data_path, split="train")

    def formatting_func(batch):
        """Render chat transcripts to training text; always returns list[str].

        trl 0.24's SFTTrainer (with unsloth's wrapper) calls this with a
        *batched* dict-of-lists, where ``batch["messages"]`` is a list of
        transcripts; older trl called it per-example, where it is a single
        transcript (a list of message dicts). Both shapes are accepted.
        """
        messages = batch.get("messages")
        is_batched = (
            isinstance(messages, list)
            and bool(messages)
            and isinstance(messages[0], list)
        )
        transcripts = messages if is_batched else [batch["messages"]]
        return [
            tokenizer.apply_chat_template(t, tokenize=False, add_generation_prompt=False)
            for t in transcripts
        ]

    # T4 (Turing) GPUs have no bfloat16 support, so a hard-coded bf16=True
    # makes transformers refuse to start training on one of the GPUs this
    # script targets. Use bf16 where supported (e.g. L4) and fp16 otherwise,
    # which should also agree with the dtype Unsloth picked for dtype=None.
    use_bf16 = is_bfloat16_supported()

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=ds,
        formatting_func=formatting_func,
        args=SFTConfig(
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            num_train_epochs=args.epochs,
            learning_rate=args.lr,
            logging_steps=10,
            warmup_steps=20,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="cosine",
            seed=3407,
            output_dir=out_dir,
            save_strategy="epoch",
            report_to="none",
            bf16=use_bf16,
            fp16=not use_bf16,
            max_length=args.max_seq_len,
            packing=False,
            eos_token=chat_eos,
        ),
    )
    trainer.train()

    model.save_pretrained(out_dir)
    tokenizer.save_pretrained(out_dir)
    print(f"Saved adapter to {out_dir}")


if __name__ == "__main__":
    main()