"""SFT warm-start trainer for the defender role.

Designed to run in Colab or Kaggle on a single T4/L4 GPU. Locally on CPU it
would take hours, so this is a *script* (rather than a notebook cell) that can
be launched with ``python -m train.sft_warmstart`` on any GPU machine; the
README points reviewers at the matching Colab.

The model: Qwen2.5-3B-Instruct (a good float16 instruction follower that fits
under 8GB of VRAM with Unsloth's 4-bit loader). The LoRA adapter weights are
saved to ``checkpoints/defender_sft_adapter`` and re-loaded at the top of the
GRPO notebook.

Why we do this *before* GRPO: with a cold model, P(format-compliant response)
is ~0 on the first GRPO batch, so the verifier reward sees only the
format-violation floor. Every sample then gets the same reward, and there is
nothing for GRPO to learn from. The ~600 SFT examples push format compliance
to >=80%, which is the bare minimum for GRPO to find a useful gradient.
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| import sys |
|
|
# Directory holding this script (``train/`` when launched as
# ``python -m train.sft_warmstart``).
_HERE = os.path.split(os.path.abspath(__file__))[0]
# Prepend the parent of that directory to the import path so top-level
# sibling packages resolve even when the file is executed directly.
sys.path.insert(0, os.path.split(_HERE)[0])
|
|
| |
| |
| |
| |
# Best-effort early import of Unsloth, done at module load before anything
# imports trl/transformers. NOTE(review): Unsloth's docs ask to be imported
# first so its patches take effect. The ImportError is deliberately swallowed
# so this module can still be imported without the GPU-only deps; main()
# re-imports unsloth and exits with install instructions if it is missing.
try:
    import unsloth  # noqa: F401  (imported only for its side effects)
except ImportError:
    pass
|
|
|
|
def main() -> None:
    """Train a LoRA adapter on the defender SFT data and save it to disk.

    Steps:
      1. Parse command-line flags.
      2. Load the base model in 4-bit via Unsloth and attach a LoRA adapter
         (r=16, alpha=32) on the attention and MLP projections.
      3. Render each row's ``messages`` conversation with the tokenizer's chat
         template.
      4. Run ``trl.SFTTrainer``.
      5. Save the adapter and tokenizer to ``--out``.

    ``--data`` and ``--out`` are joined onto the parent of this script's
    directory rather than the current working directory. Absolute paths pass
    through ``os.path.join`` unchanged.

    Mixed precision is chosen per device: bf16 where the GPU supports it,
    fp16 otherwise (e.g. on T4).

    Exits via ``sys.exit`` with install instructions when the GPU-only
    dependencies cannot be imported.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data",
        default="data/sft_train.jsonl",
        help="JSONL file whose rows carry a chat-format 'messages' list.",
    )
    parser.add_argument(
        "--model",
        default="unsloth/Qwen2.5-3B-Instruct",
        help="Base model name or path for FastLanguageModel.",
    )
    parser.add_argument("--epochs", type=int, default=1, help="Training epochs.")
    parser.add_argument(
        "--batch-size", type=int, default=4, help="Per-device train batch size."
    )
    parser.add_argument(
        "--grad-accum", type=int, default=4, help="Gradient accumulation steps."
    )
    parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate.")
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=2048,
        help="Max sequence length for loading and training.",
    )
    parser.add_argument(
        "--out",
        default="checkpoints/defender_sft_adapter",
        help="Output directory for checkpoints and the final adapter.",
    )
    args = parser.parse_args()

    # Heavy GPU-only dependencies are imported lazily so `--help` and a
    # friendly error message work on machines without them.
    try:
        from datasets import load_dataset
        from trl import SFTConfig, SFTTrainer
        from unsloth import FastLanguageModel, is_bfloat16_supported
    except ImportError as exc:
        sys.exit(
            "This script requires unsloth, transformers, trl, and datasets, "
            "which are GPU-only deps. Install with:\n"
            "  pip install --upgrade pip\n"
            "  pip install 'unsloth @ git+https://github.com/unslothai/unsloth.git' unsloth_zoo\n"
            "  pip install --no-deps trl peft accelerate bitsandbytes\n"
            f"(import failed with: {exc})"
        )

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_len,
        dtype=None,  # let Unsloth auto-select the compute dtype for this GPU
        load_in_4bit=True,
    )

    # Guard against a missing eos token or the literal "<EOS_TOKEN>"
    # placeholder: Qwen2.5's chat turns end with "<|im_end|>".
    if tokenizer.eos_token in (None, "<EOS_TOKEN>"):
        tokenizer.eos_token = "<|im_end|>"

    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=32,
        lora_dropout=0.0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    data_path = os.path.join(os.path.dirname(_HERE), args.data)
    ds = load_dataset("json", data_files=data_path, split="train")

    def formatting_func(batch):
        """Render conversations to training text via the chat template.

        Accepts either form of input:
          * a batched input, where ``batch["messages"]`` is a list of
            conversations;
          * a single example, where ``batch["messages"]`` is one conversation
            (a list of message dicts).

        A list of strings is returned in both cases, so either calling
        convention gets the same shape back.
        """
        # A list whose first element is itself a list => batched conversations.
        if (
            isinstance(batch.get("messages"), list)
            and batch["messages"]
            and isinstance(batch["messages"][0], list)
        ):
            return [
                tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=False)
                for m in batch["messages"]
            ]
        return [
            tokenizer.apply_chat_template(
                batch["messages"], tokenize=False, add_generation_prompt=False
            )
        ]

    # T4 (Turing) GPUs, which this script targets, have no bfloat16 support,
    # and transformers rejects bf16=True there. Use bf16 only when the device
    # supports it (e.g. L4/Ampere+); otherwise fall back to fp16.
    use_bf16 = is_bfloat16_supported()

    out_dir = os.path.join(os.path.dirname(_HERE), args.out)
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=ds,
        formatting_func=formatting_func,
        args=SFTConfig(
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            num_train_epochs=args.epochs,
            learning_rate=args.lr,
            logging_steps=10,
            warmup_steps=20,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="cosine",
            seed=3407,
            output_dir=out_dir,
            save_strategy="epoch",
            report_to="none",
            bf16=use_bf16,
            fp16=not use_bf16,
            max_length=args.max_seq_len,
            packing=False,
            eos_token="<|im_end|>",
        ),
    )
    trainer.train()
    # The final adapter lands in the same directory as the per-epoch
    # checkpoints; this is the path the GRPO stage re-loads.
    model.save_pretrained(out_dir)
    tokenizer.save_pretrained(out_dir)
    print(f"Saved adapter to {out_dir}")
|
|
|
|
# Script entry point: train only when executed directly (e.g.
# ``python -m train.sft_warmstart``), not when the module is imported.
if __name__ == "__main__":
    main()
|
|