feat: add training pipeline with SFT and RLVR support for Qwen 2.5-3B-Instruct
Browse files- notebooks/training.ipynb +707 -46
- training/train_rlvr.py +13 -3
- training/train_sft.py +2 -1
notebooks/training.ipynb
CHANGED
|
@@ -2,27 +2,37 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "markdown",
|
|
|
|
| 5 |
"metadata": {},
|
| 6 |
"source": [
|
| 7 |
-
"# 🏢 CORP-ENV · Qwen 2.5-
|
| 8 |
"\n",
|
| 9 |
-
"**End-to-end reproducible notebook** for training a Qwen 2.5-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
"\n",
|
| 11 |
"CORP-ENV is a multi-agent corporate decision environment where a Master Agent governs a **Shared Workspace Document (SWD)** across long-horizon planning episodes, coordinating frozen worker agents. Rewards measure SWD integrity, task completion, milestone adherence, reasoning density, and LLM-judge scores.\n",
|
| 12 |
"\n",
|
| 13 |
"| Component | Detail |\n",
|
| 14 |
"|---|---|\n",
|
| 15 |
-
"| **Base model** | `Qwen/Qwen2.5-
|
| 16 |
"| **SFT script** | `training/train_sft.py` |\n",
|
| 17 |
"| **RLVR script** | `training/train_rlvr.py` |\n",
|
| 18 |
"| **Tasks** | E1 Launch Readiness, M1 Budget Reallocation, H1 Acquisition Defence |\n",
|
| 19 |
-
"| **Runtime** | Colab
|
| 20 |
"\n",
|
| 21 |
"---"
|
| 22 |
]
|
| 23 |
},
|
| 24 |
{
|
| 25 |
"cell_type": "markdown",
|
|
|
|
| 26 |
"metadata": {},
|
| 27 |
"source": [
|
| 28 |
"## 1️⃣ Setup & Installation"
|
|
@@ -31,49 +41,92 @@
|
|
| 31 |
{
|
| 32 |
"cell_type": "code",
|
| 33 |
"execution_count": null,
|
|
|
|
| 34 |
"metadata": {},
|
| 35 |
"outputs": [],
|
| 36 |
"source": [
|
| 37 |
"import os\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
"\n",
|
| 39 |
"# ===== Configuration =====\n",
|
| 40 |
"REPO_URL = \"https://huggingface.co/spaces/Navigam/corp-env\" # Change to your repo\n",
|
| 41 |
-
"BASE_MODEL = \"Qwen/Qwen2.5-
|
| 42 |
"HF_ORG_OR_USER = \"Navigam\" # Your HF username/org\n",
|
| 43 |
"\n",
|
| 44 |
-
"# SFT hyperparameters\n",
|
| 45 |
-
"SFT_MAX_STEPS = 30
|
| 46 |
"SFT_EPOCHS = 2.0\n",
|
| 47 |
"SFT_LR = 2e-4\n",
|
| 48 |
"SFT_BATCH_SIZE = 1\n",
|
| 49 |
"SFT_GRAD_ACCUM = 8\n",
|
|
|
|
| 50 |
"\n",
|
| 51 |
-
"# RLVR hyperparameters\n",
|
| 52 |
"RLVR_ROUNDS = 3\n",
|
| 53 |
-
"RLVR_MAX_PROMPTS = 128\n",
|
| 54 |
-
"RLVR_N_SAMPLES = 8\n",
|
| 55 |
"RLVR_TEMPERATURE = 0.7\n",
|
|
|
|
|
|
|
| 56 |
"\n",
|
| 57 |
"# Eval\n",
|
| 58 |
"EVAL_EPISODES = 3\n",
|
| 59 |
-
"EVAL_MAX_STEPS = 30"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
]
|
| 61 |
},
|
| 62 |
{
|
| 63 |
"cell_type": "code",
|
| 64 |
"execution_count": null,
|
|
|
|
| 65 |
"metadata": {},
|
| 66 |
"outputs": [],
|
| 67 |
"source": [
|
| 68 |
-
"#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
"!git clone {REPO_URL} corp_gym 2>/dev/null || echo 'Repo already cloned'\n",
|
| 70 |
"%cd corp_gym\n",
|
| 71 |
-
"!pip install -
|
| 72 |
-
"!pip install -e \".[training,plots]\""
|
| 73 |
]
|
| 74 |
},
|
| 75 |
{
|
| 76 |
"cell_type": "markdown",
|
|
|
|
| 77 |
"metadata": {},
|
| 78 |
"source": [
|
| 79 |
"## 2️⃣ Hugging Face Login (optional)"
|
|
@@ -82,6 +135,7 @@
|
|
| 82 |
{
|
| 83 |
"cell_type": "code",
|
| 84 |
"execution_count": null,
|
|
|
|
| 85 |
"metadata": {},
|
| 86 |
"outputs": [],
|
| 87 |
"source": [
|
|
@@ -91,6 +145,325 @@
|
|
| 91 |
},
|
| 92 |
{
|
| 93 |
"cell_type": "markdown",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
"metadata": {},
|
| 95 |
"source": [
|
| 96 |
"## 3️⃣ Environment Validation\n",
|
|
@@ -101,6 +474,7 @@
|
|
| 101 |
{
|
| 102 |
"cell_type": "code",
|
| 103 |
"execution_count": null,
|
|
|
|
| 104 |
"metadata": {},
|
| 105 |
"outputs": [],
|
| 106 |
"source": [
|
|
@@ -110,6 +484,7 @@
|
|
| 110 |
},
|
| 111 |
{
|
| 112 |
"cell_type": "markdown",
|
|
|
|
| 113 |
"metadata": {},
|
| 114 |
"source": [
|
| 115 |
"## 4️⃣ Data Preparation\n",
|
|
@@ -120,6 +495,7 @@
|
|
| 120 |
{
|
| 121 |
"cell_type": "code",
|
| 122 |
"execution_count": null,
|
|
|
|
| 123 |
"metadata": {},
|
| 124 |
"outputs": [],
|
| 125 |
"source": [
|
|
@@ -148,27 +524,56 @@
|
|
| 148 |
{
|
| 149 |
"cell_type": "code",
|
| 150 |
"execution_count": null,
|
|
|
|
| 151 |
"metadata": {},
|
| 152 |
"outputs": [],
|
| 153 |
"source": [
|
| 154 |
-
"# Check data stats\n",
|
| 155 |
-
"import json\n",
|
| 156 |
-
"from pathlib import Path\n",
|
| 157 |
-
"\n",
|
| 158 |
"sft_path = Path(\"data/sft/e1_m1_h1_examples.jsonl\")\n",
|
| 159 |
"if sft_path.exists():\n",
|
| 160 |
" lines = [json.loads(l) for l in sft_path.read_text().strip().splitlines() if l.strip()]\n",
|
| 161 |
" print(f\"\\n✅ SFT dataset: {len(lines)} examples\")\n",
|
| 162 |
-
" # Count by number of messages\n",
|
| 163 |
" turn_counts = [len(ex['messages']) for ex in lines]\n",
|
| 164 |
" print(f\" Avg turns per example: {sum(turn_counts)/len(turn_counts):.1f}\")\n",
|
| 165 |
" print(f\" Min/Max turns: {min(turn_counts)} / {max(turn_counts)}\")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
"else:\n",
|
| 167 |
" print(\"❌ SFT dataset not found. Check data preparation above.\")"
|
| 168 |
]
|
| 169 |
},
|
| 170 |
{
|
| 171 |
"cell_type": "markdown",
|
|
|
|
| 172 |
"metadata": {},
|
| 173 |
"source": [
|
| 174 |
"## 5️⃣ Baseline Evaluation\n",
|
|
@@ -179,6 +584,7 @@
|
|
| 179 |
{
|
| 180 |
"cell_type": "code",
|
| 181 |
"execution_count": null,
|
|
|
|
| 182 |
"metadata": {},
|
| 183 |
"outputs": [],
|
| 184 |
"source": [
|
|
@@ -186,39 +592,142 @@
|
|
| 186 |
"!python eval.py --policy oracle --label oracle --episodes {EVAL_EPISODES}"
|
| 187 |
]
|
| 188 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
{
|
| 190 |
"cell_type": "markdown",
|
|
|
|
| 191 |
"metadata": {},
|
| 192 |
"source": [
|
| 193 |
"## 6️⃣ SFT Training (Unsloth + TRL)\n",
|
| 194 |
"\n",
|
| 195 |
-
"Fine-tune Qwen 2.5-
|
| 196 |
"\n",
|
| 197 |
"- Uses `unsloth.FastLanguageModel` for 4-bit QLoRA\n",
|
| 198 |
"- Uses `trl.SFTTrainer` with messages-format conversational SFT\n",
|
| 199 |
-
"- LoRA `r=32`, targets all attention + MLP projections"
|
|
|
|
| 200 |
]
|
| 201 |
},
|
| 202 |
{
|
| 203 |
"cell_type": "code",
|
| 204 |
"execution_count": null,
|
|
|
|
| 205 |
"metadata": {},
|
| 206 |
"outputs": [],
|
| 207 |
"source": [
|
|
|
|
|
|
|
|
|
|
| 208 |
"!python training/train_sft.py \\\n",
|
| 209 |
" --model {BASE_MODEL} \\\n",
|
| 210 |
" --data data/sft/e1_m1_h1_examples.jsonl \\\n",
|
| 211 |
" --output outputs/sft_adapter \\\n",
|
|
|
|
| 212 |
" --max-steps {SFT_MAX_STEPS} \\\n",
|
| 213 |
" --epochs {SFT_EPOCHS} \\\n",
|
| 214 |
" --lr {SFT_LR} \\\n",
|
| 215 |
" --batch-size {SFT_BATCH_SIZE} \\\n",
|
| 216 |
" --grad-accum {SFT_GRAD_ACCUM} \\\n",
|
| 217 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
]
|
| 219 |
},
|
| 220 |
{
|
| 221 |
"cell_type": "markdown",
|
|
|
|
| 222 |
"metadata": {},
|
| 223 |
"source": [
|
| 224 |
"## 7️⃣ Evaluate SFT Adapter"
|
|
@@ -227,20 +736,58 @@
|
|
| 227 |
{
|
| 228 |
"cell_type": "code",
|
| 229 |
"execution_count": null,
|
|
|
|
| 230 |
"metadata": {},
|
| 231 |
"outputs": [],
|
| 232 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
"!python eval.py \\\n",
|
| 234 |
" --policy hf \\\n",
|
| 235 |
" --label sft \\\n",
|
| 236 |
" --model {BASE_MODEL} \\\n",
|
| 237 |
" --adapter outputs/sft_adapter \\\n",
|
| 238 |
" --episodes {EVAL_EPISODES} \\\n",
|
| 239 |
-
" --max-steps {EVAL_MAX_STEPS}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
]
|
| 241 |
},
|
| 242 |
{
|
| 243 |
"cell_type": "markdown",
|
|
|
|
| 244 |
"metadata": {},
|
| 245 |
"source": [
|
| 246 |
"## 8️⃣ RLVR Training (Rejection-Sampling FT)\n",
|
|
@@ -252,15 +799,25 @@
|
|
| 252 |
"4. SFT on that curated set\n",
|
| 253 |
"5. Repeating for multiple outer rounds\n",
|
| 254 |
"\n",
|
| 255 |
-
"This avoids the zero-variance gradient problem seen with GRPO on CORP-ENV."
|
|
|
|
|
|
|
| 256 |
]
|
| 257 |
},
|
| 258 |
{
|
| 259 |
"cell_type": "code",
|
| 260 |
"execution_count": null,
|
|
|
|
| 261 |
"metadata": {},
|
| 262 |
"outputs": [],
|
| 263 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
"!python training/train_rlvr.py \\\n",
|
| 265 |
" --model {BASE_MODEL} \\\n",
|
| 266 |
" --adapter outputs/sft_adapter \\\n",
|
|
@@ -270,15 +827,40 @@
|
|
| 270 |
" --n-samples {RLVR_N_SAMPLES} \\\n",
|
| 271 |
" --temperature {RLVR_TEMPERATURE} \\\n",
|
| 272 |
" --max-prompts {RLVR_MAX_PROMPTS} \\\n",
|
|
|
|
|
|
|
| 273 |
" --strict-json \\\n",
|
| 274 |
" --use-stub-workers \\\n",
|
| 275 |
" --disable-llm-judge \\\n",
|
| 276 |
-
" --stats-file results/runs/rlvr_qwen2.
|
| 277 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
]
|
| 279 |
},
|
| 280 |
{
|
| 281 |
"cell_type": "markdown",
|
|
|
|
| 282 |
"metadata": {},
|
| 283 |
"source": [
|
| 284 |
"## 9️⃣ Evaluate RLVR Adapter"
|
|
@@ -287,72 +869,149 @@
|
|
| 287 |
{
|
| 288 |
"cell_type": "code",
|
| 289 |
"execution_count": null,
|
|
|
|
| 290 |
"metadata": {},
|
| 291 |
"outputs": [],
|
| 292 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
"!python eval.py \\\n",
|
| 294 |
" --policy hf \\\n",
|
| 295 |
" --label rlvr \\\n",
|
| 296 |
" --model {BASE_MODEL} \\\n",
|
| 297 |
" --adapter outputs/rlvr_adapter \\\n",
|
| 298 |
" --episodes {EVAL_EPISODES} \\\n",
|
| 299 |
-
" --max-steps {EVAL_MAX_STEPS}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
]
|
| 301 |
},
|
| 302 |
{
|
| 303 |
"cell_type": "markdown",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
"metadata": {},
|
|
|
|
| 305 |
"source": [
|
| 306 |
-
"## 📊
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
"\n",
|
| 308 |
-
"
|
| 309 |
-
"
|
| 310 |
-
"
|
| 311 |
-
"
|
| 312 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
]
|
| 314 |
},
|
| 315 |
{
|
| 316 |
"cell_type": "code",
|
| 317 |
"execution_count": null,
|
|
|
|
| 318 |
"metadata": {},
|
| 319 |
"outputs": [],
|
| 320 |
"source": [
|
|
|
|
| 321 |
"!python plot_results.py \\\n",
|
| 322 |
" --inputs results/runs \\\n",
|
| 323 |
-
" --output-dir results/
|
| 324 |
]
|
| 325 |
},
|
| 326 |
{
|
| 327 |
"cell_type": "code",
|
| 328 |
"execution_count": null,
|
|
|
|
| 329 |
"metadata": {},
|
| 330 |
"outputs": [],
|
| 331 |
"source": [
|
| 332 |
"from IPython.display import Image, display, Markdown\n",
|
| 333 |
-
"from pathlib import Path\n",
|
| 334 |
"\n",
|
| 335 |
-
"plot_dir = Path(\"results/
|
| 336 |
"if not plot_dir.exists():\n",
|
| 337 |
" plot_dir = Path(\"results/model_compare_qwen25_fresh_no_grpo_ep5rlvr\")\n",
|
| 338 |
"\n",
|
| 339 |
-
"
|
| 340 |
-
"
|
| 341 |
-
"
|
|
|
|
| 342 |
"\n",
|
| 343 |
-
"# Show summary table\n",
|
| 344 |
-
"summary_md = plot_dir / \"comparison_summary.md\"\n",
|
| 345 |
-
"if summary_md.exists():\n",
|
| 346 |
-
"
|
|
|
|
|
|
|
| 347 |
]
|
| 348 |
},
|
| 349 |
{
|
| 350 |
"cell_type": "markdown",
|
|
|
|
| 351 |
"metadata": {},
|
| 352 |
"source": [
|
| 353 |
"## 📋 Results Summary\n",
|
| 354 |
"\n",
|
| 355 |
-
"Expected progression for Qwen 2.5-
|
| 356 |
"\n",
|
| 357 |
"| Stage | E1 Terminal Reward | M1 Terminal Reward | H1 Terminal Reward | M1 Success |\n",
|
| 358 |
"|-------|-------------------|-------------------|-------------------|------------|\n",
|
|
@@ -361,7 +1020,9 @@
|
|
| 361 |
"| SFT | 0.910 | 0.943 | 0.889 | 100% |\n",
|
| 362 |
"| RLVR | 0.910 | 0.932 | 0.779 | 80% |\n",
|
| 363 |
"\n",
|
| 364 |
-
"> **Key takeaway**: SFT dramatically improves M1 (budget reallocation) from 0% to 100% success rate. RLVR maintains strong performance while reducing reliance on fixed trajectories."
|
|
|
|
|
|
|
| 365 |
]
|
| 366 |
}
|
| 367 |
],
|
|
@@ -384,4 +1045,4 @@
|
|
| 384 |
},
|
| 385 |
"nbformat": 4,
|
| 386 |
"nbformat_minor": 5
|
| 387 |
-
}
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "markdown",
|
| 5 |
+
"id": "23a31c02",
|
| 6 |
"metadata": {},
|
| 7 |
"source": [
|
| 8 |
+
"# 🏢 CORP-ENV · Qwen 2.5-3B-Instruct — SFT + RLVR Training\n",
|
| 9 |
"\n",
|
| 10 |
+
"**End-to-end reproducible notebook** for training a Qwen 2.5-3B-Instruct agent on CORP-ENV using Supervised Fine-Tuning (SFT) followed by Rejection-Sampling RL on Verifiable Rewards (RLVR).\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"### ⚡ Optimized for Google Colab T4 (16 GB VRAM)\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"This notebook is configured to run end-to-end on a **free-tier T4 GPU**:\n",
|
| 15 |
+
"- 4-bit QLoRA quantization to fit 7B model in ~4 GB VRAM\n",
|
| 16 |
+
"- **FP16** precision (T4 lacks BF16 hardware support)\n",
|
| 17 |
+
"- Reduced sequence lengths (4096 tokens) and RLVR samples (4 per prompt)\n",
|
| 18 |
+
"- Inline visualizations after every training and evaluation step\n",
|
| 19 |
"\n",
|
| 20 |
"CORP-ENV is a multi-agent corporate decision environment where a Master Agent governs a **Shared Workspace Document (SWD)** across long-horizon planning episodes, coordinating frozen worker agents. Rewards measure SWD integrity, task completion, milestone adherence, reasoning density, and LLM-judge scores.\n",
|
| 21 |
"\n",
|
| 22 |
"| Component | Detail |\n",
|
| 23 |
"|---|---|\n",
|
| 24 |
+
"| **Base model** | `Qwen/Qwen2.5-3B-Instruct` |\n",
|
| 25 |
"| **SFT script** | `training/train_sft.py` |\n",
|
| 26 |
"| **RLVR script** | `training/train_rlvr.py` |\n",
|
| 27 |
"| **Tasks** | E1 Launch Readiness, M1 Budget Reallocation, H1 Acquisition Defence |\n",
|
| 28 |
+
"| **Runtime** | ✅ Google Colab T4 / Lightning AI H100 / Any CUDA GPU |\n",
|
| 29 |
"\n",
|
| 30 |
"---"
|
| 31 |
]
|
| 32 |
},
|
| 33 |
{
|
| 34 |
"cell_type": "markdown",
|
| 35 |
+
"id": "15d441af",
|
| 36 |
"metadata": {},
|
| 37 |
"source": [
|
| 38 |
"## 1️⃣ Setup & Installation"
|
|
|
|
| 41 |
{
|
| 42 |
"cell_type": "code",
|
| 43 |
"execution_count": null,
|
| 44 |
+
"id": "e9394fab",
|
| 45 |
"metadata": {},
|
| 46 |
"outputs": [],
|
| 47 |
"source": [
|
| 48 |
"import os\n",
|
| 49 |
+
"import torch\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"# ===== GPU Detection & Configuration =====\n",
|
| 52 |
+
"if torch.cuda.is_available():\n",
|
| 53 |
+
" gpu_name = torch.cuda.get_device_name(0)\n",
|
| 54 |
+
" gpu_mem = torch.cuda.get_device_properties(0).total_mem / 1e9\n",
|
| 55 |
+
" has_bf16 = torch.cuda.is_bf16_supported()\n",
|
| 56 |
+
" print(f\"🖥️ GPU: {gpu_name} ({gpu_mem:.1f} GB)\")\n",
|
| 57 |
+
" print(f\" BF16 support: {'✅ Yes' if has_bf16 else '❌ No (using FP16)'}\")\n",
|
| 58 |
+
"else:\n",
|
| 59 |
+
" raise RuntimeError(\"❌ No GPU detected! Enable GPU in Colab: Runtime → Change runtime type → T4 GPU\")\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"# Auto-detect hardware constraints\n",
|
| 62 |
+
"LOW_MEMORY = gpu_mem < 20.0 # e.g., T4 (16GB), RTX 4080 (16GB) need smaller batches/sequences\n",
|
| 63 |
+
"USE_FP16 = not has_bf16 # e.g., T4 and V100 dont support BF16\n",
|
| 64 |
"\n",
|
| 65 |
"# ===== Configuration =====\n",
|
| 66 |
"REPO_URL = \"https://huggingface.co/spaces/Navigam/corp-env\" # Change to your repo\n",
|
| 67 |
+
"BASE_MODEL = \"Qwen/Qwen2.5-3B-Instruct\"\n",
|
| 68 |
"HF_ORG_OR_USER = \"Navigam\" # Your HF username/org\n",
|
| 69 |
"\n",
|
| 70 |
+
"# SFT hyperparameters (T4-optimized)\n",
|
| 71 |
+
"SFT_MAX_STEPS = 30 # Quick judge smoke; set -1 for full-epoch training\n",
|
| 72 |
"SFT_EPOCHS = 2.0\n",
|
| 73 |
"SFT_LR = 2e-4\n",
|
| 74 |
"SFT_BATCH_SIZE = 1\n",
|
| 75 |
"SFT_GRAD_ACCUM = 8\n",
|
| 76 |
+
" \"SFT_MAX_SEQ_LEN = 3072 if LOW_MEMORY else 8192 # Reduced for <20GB VRAM\\n\",\n",
|
| 77 |
"\n",
|
| 78 |
+
"# RLVR hyperparameters (T4-optimized)\n",
|
| 79 |
"RLVR_ROUNDS = 3\n",
|
| 80 |
+
"RLVR_MAX_PROMPTS = 32 if LOW_MEMORY else 128 # Fewer prompts to fit in T4 time/memory\n",
|
| 81 |
+
" \"RLVR_N_SAMPLES = 4 if LOW_MEMORY else 8 # Fewer samples per prompt\\n\",\n",
|
| 82 |
"RLVR_TEMPERATURE = 0.7\n",
|
| 83 |
+
" \"RLVR_MAX_PROMPT_LEN = 3072 if LOW_MEMORY else 8192\\n\",\n",
|
| 84 |
+
"RLVR_MAX_COMPLETION_LEN = 512\n",
|
| 85 |
"\n",
|
| 86 |
"# Eval\n",
|
| 87 |
"EVAL_EPISODES = 3\n",
|
| 88 |
+
"EVAL_MAX_STEPS = 30\n",
|
| 89 |
+
"\n",
|
| 90 |
+
"# FP16 flag for training scripts\n",
|
| 91 |
+
"FP16_FLAG = \"--fp16\" if USE_FP16 else \"\"\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"print(f\"\\n📋 Config: model={BASE_MODEL}, fp16={USE_FP16}, seq_len={SFT_MAX_SEQ_LEN}\")\n",
|
| 94 |
+
"print(f\" RLVR: rounds={RLVR_ROUNDS}, prompts={RLVR_MAX_PROMPTS}, samples={RLVR_N_SAMPLES}\")"
|
| 95 |
]
|
| 96 |
},
|
| 97 |
{
|
| 98 |
"cell_type": "code",
|
| 99 |
"execution_count": null,
|
| 100 |
+
"id": "1fccadd9",
|
| 101 |
"metadata": {},
|
| 102 |
"outputs": [],
|
| 103 |
"source": [
|
| 104 |
+
"# ===== Install dependencies (Colab-optimized) =====\n",
|
| 105 |
+
"# Unsloth requires a specific install path for Colab\n",
|
| 106 |
+
"import subprocess, sys\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"# Check if running in Colab\n",
|
| 109 |
+
"IN_COLAB = 'google.colab' in sys.modules\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"if IN_COLAB:\n",
|
| 112 |
+
" print(\"🔧 Installing Unsloth for Colab...\")\n",
|
| 113 |
+
" !pip install -q --no-deps trl peft accelerate bitsandbytes triton\n",
|
| 114 |
+
" !pip install -q \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
|
| 115 |
+
" !pip install -q --no-deps unsloth_zoo\n",
|
| 116 |
+
" !pip install -q xformers\n",
|
| 117 |
+
"else:\n",
|
| 118 |
+
" print(\"🔧 Installing from pyproject.toml...\")\n",
|
| 119 |
+
" !pip install -q -U pip\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"# Clone and install CORP-ENV\n",
|
| 122 |
"!git clone {REPO_URL} corp_gym 2>/dev/null || echo 'Repo already cloned'\n",
|
| 123 |
"%cd corp_gym\n",
|
| 124 |
+
"!pip install -q -e \".[training,plots]\""
|
|
|
|
| 125 |
]
|
| 126 |
},
|
| 127 |
{
|
| 128 |
"cell_type": "markdown",
|
| 129 |
+
"id": "076d342b",
|
| 130 |
"metadata": {},
|
| 131 |
"source": [
|
| 132 |
"## 2️⃣ Hugging Face Login (optional)"
|
|
|
|
| 135 |
{
|
| 136 |
"cell_type": "code",
|
| 137 |
"execution_count": null,
|
| 138 |
+
"id": "df0904d7",
|
| 139 |
"metadata": {},
|
| 140 |
"outputs": [],
|
| 141 |
"source": [
|
|
|
|
| 145 |
},
|
| 146 |
{
|
| 147 |
"cell_type": "markdown",
|
| 148 |
+
"id": "7d4a001c",
|
| 149 |
+
"metadata": {},
|
| 150 |
+
"source": [
|
| 151 |
+
"## 📊 Visualization Utilities\n",
|
| 152 |
+
"\n",
|
| 153 |
+
"Helper functions for inline charts after every eval and training step."
|
| 154 |
+
]
|
| 155 |
+
},
|
| 156 |
+
{
|
| 157 |
+
"cell_type": "code",
|
| 158 |
+
"execution_count": null,
|
| 159 |
+
"id": "3930908e",
|
| 160 |
+
"metadata": {},
|
| 161 |
+
"outputs": [],
|
| 162 |
+
"source": [
|
| 163 |
+
"import json\n",
|
| 164 |
+
"import matplotlib.pyplot as plt\n",
|
| 165 |
+
"import matplotlib.ticker as mticker\n",
|
| 166 |
+
"import numpy as np\n",
|
| 167 |
+
"from pathlib import Path\n",
|
| 168 |
+
"from collections import defaultdict\n",
|
| 169 |
+
"from IPython.display import display, Markdown, HTML\n",
|
| 170 |
+
"\n",
|
| 171 |
+
"# ---- Plotting style ----\n",
|
| 172 |
+
"plt.rcParams.update({\n",
|
| 173 |
+
" 'figure.facecolor': '#0d1117',\n",
|
| 174 |
+
" 'axes.facecolor': '#161b22',\n",
|
| 175 |
+
" 'axes.edgecolor': '#30363d',\n",
|
| 176 |
+
" 'axes.labelcolor': '#c9d1d9',\n",
|
| 177 |
+
" 'text.color': '#c9d1d9',\n",
|
| 178 |
+
" 'xtick.color': '#8b949e',\n",
|
| 179 |
+
" 'ytick.color': '#8b949e',\n",
|
| 180 |
+
" 'grid.color': '#21262d',\n",
|
| 181 |
+
" 'font.family': 'sans-serif',\n",
|
| 182 |
+
" 'font.size': 11,\n",
|
| 183 |
+
"})\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"PALETTE = {\n",
|
| 186 |
+
" 'baseline': '#8b949e',\n",
|
| 187 |
+
" 'oracle': '#a371f7',\n",
|
| 188 |
+
" 'sft': '#3fb950',\n",
|
| 189 |
+
" 'rlvr': '#f0883e',\n",
|
| 190 |
+
" 'e1_launch_readiness': '#58a6ff',\n",
|
| 191 |
+
" 'm1_budget_reallocation': '#d2a8ff',\n",
|
| 192 |
+
" 'h1_acquisition_defence': '#7ee787',\n",
|
| 193 |
+
"}\n",
|
| 194 |
+
"TASK_SHORT = {\n",
|
| 195 |
+
" 'e1_launch_readiness': 'E1 Launch',\n",
|
| 196 |
+
" 'm1_budget_reallocation': 'M1 Budget',\n",
|
| 197 |
+
" 'h1_acquisition_defence': 'H1 Acquisition',\n",
|
| 198 |
+
"}\n",
|
| 199 |
+
"\n",
|
| 200 |
+
"def load_eval_jsonl(path):\n",
|
| 201 |
+
" \"\"\"Load evaluation JSONL file.\"\"\"\n",
|
| 202 |
+
" rows = []\n",
|
| 203 |
+
" p = Path(path)\n",
|
| 204 |
+
" if p.is_dir():\n",
|
| 205 |
+
" for f in sorted(p.rglob('*_eval.jsonl')):\n",
|
| 206 |
+
" rows.extend(load_eval_jsonl(f))\n",
|
| 207 |
+
" for f in sorted(p.rglob('eval.jsonl')):\n",
|
| 208 |
+
" rows.extend(load_eval_jsonl(f))\n",
|
| 209 |
+
" return rows\n",
|
| 210 |
+
" if p.exists():\n",
|
| 211 |
+
" for line in p.read_text(encoding='utf-8').strip().splitlines():\n",
|
| 212 |
+
" if line.strip():\n",
|
| 213 |
+
" rows.append(json.loads(line))\n",
|
| 214 |
+
" return rows\n",
|
| 215 |
+
"\n",
|
| 216 |
+
"def plot_eval_dashboard(rows, title=\"Evaluation Results\"):\n",
|
| 217 |
+
" \"\"\"Create a 2x2 dashboard of evaluation metrics.\"\"\"\n",
|
| 218 |
+
" if not rows:\n",
|
| 219 |
+
" print(\"⚠️ No evaluation data to plot.\")\n",
|
| 220 |
+
" return\n",
|
| 221 |
+
"\n",
|
| 222 |
+
" # Group by task\n",
|
| 223 |
+
" by_task = defaultdict(list)\n",
|
| 224 |
+
" for r in rows:\n",
|
| 225 |
+
" by_task[r['task_id']].append(r)\n",
|
| 226 |
+
"\n",
|
| 227 |
+
" tasks = sorted(by_task.keys())\n",
|
| 228 |
+
" task_labels = [TASK_SHORT.get(t, t) for t in tasks]\n",
|
| 229 |
+
"\n",
|
| 230 |
+
" # Compute metrics\n",
|
| 231 |
+
" avg_reward = [np.mean([r['terminal_reward'] for r in by_task[t]]) for t in tasks]\n",
|
| 232 |
+
" avg_pass = [np.mean([r['verifier_pass_rate'] for r in by_task[t]]) for t in tasks]\n",
|
| 233 |
+
" success_rate = [np.mean([1 if r.get('success') else 0 for r in by_task[t]]) for t in tasks]\n",
|
| 234 |
+
" avg_steps = [np.mean([r.get('steps', 0) for r in by_task[t]]) for t in tasks]\n",
|
| 235 |
+
"\n",
|
| 236 |
+
" fig, axes = plt.subplots(2, 2, figsize=(14, 10))\n",
|
| 237 |
+
" fig.suptitle(title, fontsize=18, fontweight='bold', color='#f0f6fc', y=0.98)\n",
|
| 238 |
+
"\n",
|
| 239 |
+
" # -- Terminal Reward --\n",
|
| 240 |
+
" ax = axes[0, 0]\n",
|
| 241 |
+
" colors = [PALETTE.get(t, '#58a6ff') for t in tasks]\n",
|
| 242 |
+
" bars = ax.bar(task_labels, avg_reward, color=colors, edgecolor='#30363d', linewidth=0.8)\n",
|
| 243 |
+
" for bar, val in zip(bars, avg_reward):\n",
|
| 244 |
+
" ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,\n",
|
| 245 |
+
" f'{val:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold', color='#f0f6fc')\n",
|
| 246 |
+
" ax.set_title('Terminal Reward', fontsize=13, fontweight='bold')\n",
|
| 247 |
+
" ax.set_ylim(0, 1.15)\n",
|
| 248 |
+
" ax.grid(axis='y', alpha=0.3)\n",
|
| 249 |
+
"\n",
|
| 250 |
+
" # -- Verifier Pass Rate --\n",
|
| 251 |
+
" ax = axes[0, 1]\n",
|
| 252 |
+
" bars = ax.bar(task_labels, avg_pass, color=colors, edgecolor='#30363d', linewidth=0.8)\n",
|
| 253 |
+
" for bar, val in zip(bars, avg_pass):\n",
|
| 254 |
+
" ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,\n",
|
| 255 |
+
" f'{val:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold', color='#f0f6fc')\n",
|
| 256 |
+
" ax.set_title('Verifier Pass Rate', fontsize=13, fontweight='bold')\n",
|
| 257 |
+
" ax.set_ylim(0, 1.15)\n",
|
| 258 |
+
" ax.grid(axis='y', alpha=0.3)\n",
|
| 259 |
+
"\n",
|
| 260 |
+
" # -- Success Rate --\n",
|
| 261 |
+
" ax = axes[1, 0]\n",
|
| 262 |
+
" bars = ax.bar(task_labels, success_rate, color=colors, edgecolor='#30363d', linewidth=0.8)\n",
|
| 263 |
+
" for bar, val in zip(bars, success_rate):\n",
|
| 264 |
+
" ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,\n",
|
| 265 |
+
" f'{val:.0%}', ha='center', va='bottom', fontsize=10, fontweight='bold', color='#f0f6fc')\n",
|
| 266 |
+
" ax.set_title('Success Rate', fontsize=13, fontweight='bold')\n",
|
| 267 |
+
" ax.set_ylim(0, 1.25)\n",
|
| 268 |
+
" ax.yaxis.set_major_formatter(mticker.PercentFormatter(1.0))\n",
|
| 269 |
+
" ax.grid(axis='y', alpha=0.3)\n",
|
| 270 |
+
"\n",
|
| 271 |
+
" # -- Avg Steps --\n",
|
| 272 |
+
" ax = axes[1, 1]\n",
|
| 273 |
+
" bars = ax.bar(task_labels, avg_steps, color=colors, edgecolor='#30363d', linewidth=0.8)\n",
|
| 274 |
+
" for bar, val in zip(bars, avg_steps):\n",
|
| 275 |
+
" ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.3,\n",
|
| 276 |
+
" f'{val:.1f}', ha='center', va='bottom', fontsize=10, fontweight='bold', color='#f0f6fc')\n",
|
| 277 |
+
" ax.set_title('Average Steps per Episode', fontsize=13, fontweight='bold')\n",
|
| 278 |
+
" ax.grid(axis='y', alpha=0.3)\n",
|
| 279 |
+
"\n",
|
| 280 |
+
" for ax in axes.flat:\n",
|
| 281 |
+
" ax.spines['top'].set_visible(False)\n",
|
| 282 |
+
" ax.spines['right'].set_visible(False)\n",
|
| 283 |
+
"\n",
|
| 284 |
+
" fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
|
| 285 |
+
" plt.show()\n",
|
| 286 |
+
"\n",
|
| 287 |
+
" # Print summary table\n",
|
| 288 |
+
" display(Markdown(\"### 📋 Summary Table\"))\n",
|
| 289 |
+
" header = \"| Task | Terminal Reward | Verifier Pass | Success Rate | Avg Steps |\"\n",
|
| 290 |
+
" sep = \"|------|---------------|--------------|-------------|----------|\"\n",
|
| 291 |
+
" lines = [header, sep]\n",
|
| 292 |
+
" for i, t in enumerate(tasks):\n",
|
| 293 |
+
" lines.append(f\"| {TASK_SHORT.get(t, t)} | {avg_reward[i]:.3f} | {avg_pass[i]:.3f} | {success_rate[i]:.0%} | {avg_steps[i]:.1f} |\")\n",
|
| 294 |
+
" display(Markdown('\\n'.join(lines)))\n",
|
| 295 |
+
"\n",
|
| 296 |
+
"\n",
|
| 297 |
+
"def plot_reward_traces(rows, title=\"Reward Traces\"):\n",
|
| 298 |
+
" \"\"\"Plot reward curves over episode steps.\"\"\"\n",
|
| 299 |
+
" traces_by_task = defaultdict(list)\n",
|
| 300 |
+
" for r in rows:\n",
|
| 301 |
+
" trace = r.get('reward_trace', [])\n",
|
| 302 |
+
" if trace:\n",
|
| 303 |
+
" traces_by_task[r['task_id']].append([float(x) for x in trace])\n",
|
| 304 |
+
"\n",
|
| 305 |
+
" if not traces_by_task:\n",
|
| 306 |
+
" return\n",
|
| 307 |
+
"\n",
|
| 308 |
+
" fig, ax = plt.subplots(figsize=(12, 5))\n",
|
| 309 |
+
" for task_id, traces in sorted(traces_by_task.items()):\n",
|
| 310 |
+
" max_len = max(len(t) for t in traces)\n",
|
| 311 |
+
" means = []\n",
|
| 312 |
+
" for idx in range(max_len):\n",
|
| 313 |
+
" vals = [t[idx] for t in traces if idx < len(t)]\n",
|
| 314 |
+
" means.append(np.mean(vals))\n",
|
| 315 |
+
" xs = range(1, max_len + 1)\n",
|
| 316 |
+
" color = PALETTE.get(task_id, '#58a6ff')\n",
|
| 317 |
+
" ax.plot(xs, means, marker='o', linewidth=2.2, markersize=4,\n",
|
| 318 |
+
" label=TASK_SHORT.get(task_id, task_id), color=color)\n",
|
| 319 |
+
" if len(traces) > 1:\n",
|
| 320 |
+
" mins = [min(t[i] for t in traces if i < len(t)) for i in range(max_len)]\n",
|
| 321 |
+
" maxs = [max(t[i] for t in traces if i < len(t)) for i in range(max_len)]\n",
|
| 322 |
+
" ax.fill_between(xs, mins, maxs, alpha=0.15, color=color)\n",
|
| 323 |
+
"\n",
|
| 324 |
+
" ax.set_title(title, fontsize=15, fontweight='bold')\n",
|
| 325 |
+
" ax.set_xlabel('Environment Step')\n",
|
| 326 |
+
" ax.set_ylabel('Step Reward')\n",
|
| 327 |
+
" ax.axhline(0, color='#484f58', linewidth=0.8, alpha=0.5)\n",
|
| 328 |
+
" ax.legend(frameon=False, fontsize=10)\n",
|
| 329 |
+
" ax.spines['top'].set_visible(False)\n",
|
| 330 |
+
" ax.spines['right'].set_visible(False)\n",
|
| 331 |
+
" ax.grid(axis='y', alpha=0.3)\n",
|
| 332 |
+
" fig.tight_layout()\n",
|
| 333 |
+
" plt.show()\n",
|
| 334 |
+
"\n",
|
| 335 |
+
"\n",
|
| 336 |
+
"def plot_stage_comparison(all_evals, metric='terminal_reward', title='Model Stage Comparison'):\n",
|
| 337 |
+
" \"\"\"Compare multiple evaluation stages side-by-side.\"\"\"\n",
|
| 338 |
+
" if not all_evals:\n",
|
| 339 |
+
" return\n",
|
| 340 |
+
"\n",
|
| 341 |
+
" stages = list(all_evals.keys())\n",
|
| 342 |
+
" all_tasks = sorted({r['task_id'] for rows in all_evals.values() for r in rows})\n",
|
| 343 |
+
" task_labels = [TASK_SHORT.get(t, t) for t in all_tasks]\n",
|
| 344 |
+
"\n",
|
| 345 |
+
" x = np.arange(len(all_tasks))\n",
|
| 346 |
+
" width = 0.8 / max(len(stages), 1)\n",
|
| 347 |
+
"\n",
|
| 348 |
+
" fig, ax = plt.subplots(figsize=(max(10, len(all_tasks) * 3), 6))\n",
|
| 349 |
+
" for idx, stage in enumerate(stages):\n",
|
| 350 |
+
" rows = all_evals[stage]\n",
|
| 351 |
+
" by_task = defaultdict(list)\n",
|
| 352 |
+
" for r in rows:\n",
|
| 353 |
+
" by_task[r['task_id']].append(float(r.get(metric, 0)))\n",
|
| 354 |
+
" vals = [np.mean(by_task.get(t, [0])) for t in all_tasks]\n",
|
| 355 |
+
" offsets = x - 0.4 + width/2 + idx * width\n",
|
| 356 |
+
" color = PALETTE.get(stage, f'C{idx}')\n",
|
| 357 |
+
" bars = ax.bar(offsets, vals, width, label=stage.upper(), color=color,\n",
|
| 358 |
+
" edgecolor='#30363d', linewidth=0.8)\n",
|
| 359 |
+
" for bar, val in zip(bars, vals):\n",
|
| 360 |
+
" if val > 0:\n",
|
| 361 |
+
" ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.015,\n",
|
| 362 |
+
" f'{val:.2f}', ha='center', va='bottom', fontsize=9,\n",
|
| 363 |
+
" fontweight='bold', color='#f0f6fc')\n",
|
| 364 |
+
"\n",
|
| 365 |
+
" ax.set_title(title, fontsize=16, fontweight='bold', color='#f0f6fc')\n",
|
| 366 |
+
" ax.set_xticks(x)\n",
|
| 367 |
+
" ax.set_xticklabels(task_labels)\n",
|
| 368 |
+
" ax.set_ylabel(metric.replace('_', ' ').title())\n",
|
| 369 |
+
" ax.set_ylim(0, 1.15)\n",
|
| 370 |
+
" ax.legend(frameon=False, fontsize=10, loc='upper center', bbox_to_anchor=(0.5, -0.08), ncol=len(stages))\n",
|
| 371 |
+
" ax.spines['top'].set_visible(False)\n",
|
| 372 |
+
" ax.spines['right'].set_visible(False)\n",
|
| 373 |
+
" ax.grid(axis='y', alpha=0.3)\n",
|
| 374 |
+
" fig.tight_layout()\n",
|
| 375 |
+
" plt.show()\n",
|
| 376 |
+
"\n",
|
| 377 |
+
"\n",
|
| 378 |
+
"def plot_rlvr_stats(stats_file):\n",
|
| 379 |
+
" \"\"\"Plot RLVR training stats per round.\"\"\"\n",
|
| 380 |
+
" p = Path(stats_file)\n",
|
| 381 |
+
" if not p.exists():\n",
|
| 382 |
+
" print(f\"⚠️ Stats file not found: {stats_file}\")\n",
|
| 383 |
+
" return\n",
|
| 384 |
+
"\n",
|
| 385 |
+
" stats = [json.loads(line) for line in p.read_text().strip().splitlines() if line.strip()]\n",
|
| 386 |
+
" if not stats:\n",
|
| 387 |
+
" return\n",
|
| 388 |
+
"\n",
|
| 389 |
+
" rounds = [s['round'] for s in stats]\n",
|
| 390 |
+
" keep_rates = [s['keep_rate'] for s in stats]\n",
|
| 391 |
+
" mean_best = [s['mean_best_reward'] for s in stats]\n",
|
| 392 |
+
" mean_any = [s['mean_sample_reward'] for s in stats]\n",
|
| 393 |
+
" kept_counts = [int(s['prompts_kept']) for s in stats]\n",
|
| 394 |
+
"\n",
|
| 395 |
+
" fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
|
| 396 |
+
" fig.suptitle('RLVR Training Progress', fontsize=16, fontweight='bold', color='#f0f6fc', y=1.02)\n",
|
| 397 |
+
"\n",
|
| 398 |
+
" # Keep rate\n",
|
| 399 |
+
" ax = axes[0]\n",
|
| 400 |
+
" ax.plot(rounds, keep_rates, marker='o', linewidth=2.5, color='#3fb950', markersize=8)\n",
|
| 401 |
+
" ax.fill_between(rounds, keep_rates, alpha=0.15, color='#3fb950')\n",
|
| 402 |
+
" ax.set_title('Keep Rate per Round', fontweight='bold')\n",
|
| 403 |
+
" ax.set_xlabel('Round')\n",
|
| 404 |
+
" ax.set_ylabel('Keep Rate')\n",
|
| 405 |
+
" ax.set_ylim(0, 1.05)\n",
|
| 406 |
+
" ax.yaxis.set_major_formatter(mticker.PercentFormatter(1.0))\n",
|
| 407 |
+
" ax.grid(alpha=0.3)\n",
|
| 408 |
+
"\n",
|
| 409 |
+
" # Reward progression\n",
|
| 410 |
+
" ax = axes[1]\n",
|
| 411 |
+
" ax.plot(rounds, mean_best, marker='s', linewidth=2.5, color='#f0883e', markersize=8, label='Best')\n",
|
| 412 |
+
" ax.plot(rounds, mean_any, marker='D', linewidth=2.5, color='#58a6ff', markersize=7, label='Any sample')\n",
|
| 413 |
+
" ax.set_title('Mean Reward per Round', fontweight='bold')\n",
|
| 414 |
+
" ax.set_xlabel('Round')\n",
|
| 415 |
+
" ax.set_ylabel('Reward')\n",
|
| 416 |
+
" ax.legend(frameon=False)\n",
|
| 417 |
+
" ax.grid(alpha=0.3)\n",
|
| 418 |
+
"\n",
|
| 419 |
+
" # Prompts kept\n",
|
| 420 |
+
" ax = axes[2]\n",
|
| 421 |
+
" ax.bar(rounds, kept_counts, color='#a371f7', edgecolor='#30363d', linewidth=0.8)\n",
|
| 422 |
+
" for r, c in zip(rounds, kept_counts):\n",
|
| 423 |
+
" ax.text(r, c + 0.5, str(c), ha='center', fontweight='bold', fontsize=11, color='#f0f6fc')\n",
|
| 424 |
+
" ax.set_title('Prompts Kept (Winners)', fontweight='bold')\n",
|
| 425 |
+
" ax.set_xlabel('Round')\n",
|
| 426 |
+
" ax.set_ylabel('Count')\n",
|
| 427 |
+
" ax.grid(axis='y', alpha=0.3)\n",
|
| 428 |
+
"\n",
|
| 429 |
+
" for ax in axes:\n",
|
| 430 |
+
" ax.spines['top'].set_visible(False)\n",
|
| 431 |
+
" ax.spines['right'].set_visible(False)\n",
|
| 432 |
+
" ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))\n",
|
| 433 |
+
"\n",
|
| 434 |
+
" fig.tight_layout()\n",
|
| 435 |
+
" plt.show()\n",
|
| 436 |
+
"\n",
|
| 437 |
+
" # Print per-round summary\n",
|
| 438 |
+
" display(Markdown(\"### 📋 RLVR Round Summary\"))\n",
|
| 439 |
+
" header = \"| Round | Keep Rate | Mean Best Reward | Mean Any Reward | Prompts Kept | Time (s) |\"\n",
|
| 440 |
+
" sep = \"|-------|-----------|-----------------|----------------|-------------|----------|\"\n",
|
| 441 |
+
" lines = [header, sep]\n",
|
| 442 |
+
" for s in stats:\n",
|
| 443 |
+
" lines.append(f\"| {s['round']} | {s['keep_rate']:.1%} | {s['mean_best_reward']:.3f} | {s['mean_sample_reward']:.3f} | {int(s['prompts_kept'])} | {s['seconds']:.0f} |\")\n",
|
| 444 |
+
" display(Markdown('\\n'.join(lines)))\n",
|
| 445 |
+
"\n",
|
| 446 |
+
"\n",
|
| 447 |
+
"def gpu_status():\n",
|
| 448 |
+
" \"\"\"Print current GPU memory usage.\"\"\"\n",
|
| 449 |
+
" if torch.cuda.is_available():\n",
|
| 450 |
+
" alloc = torch.cuda.memory_allocated() / 1e9\n",
|
| 451 |
+
" cached = torch.cuda.memory_reserved() / 1e9\n",
|
| 452 |
+
" total = torch.cuda.get_device_properties(0).total_mem / 1e9\n",
|
| 453 |
+
" pct = alloc / total * 100\n",
|
| 454 |
+
" bar_len, filled = 20, int(pct / 5)\n",
|
| 455 |
+
" bar = '█' * filled + '░' * (bar_len - filled)\n",
|
| 456 |
+
" print(f\"🖥️ GPU Memory: [{bar}] {alloc:.1f}/{total:.1f} GB ({pct:.0f}%) | Cached: {cached:.1f} GB\")\n",
|
| 457 |
+
"\n",
|
| 458 |
+
"\n",
|
| 459 |
+
"# Collect all eval results for final comparison\n",
|
| 460 |
+
"ALL_EVALS = {}\n",
|
| 461 |
+
"print(\"✅ Visualization utilities loaded.\")"
|
| 462 |
+
]
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"cell_type": "markdown",
|
| 466 |
+
"id": "43b92bf5",
|
| 467 |
"metadata": {},
|
| 468 |
"source": [
|
| 469 |
"## 3️⃣ Environment Validation\n",
|
|
|
|
| 474 |
{
|
| 475 |
"cell_type": "code",
|
| 476 |
"execution_count": null,
|
| 477 |
+
"id": "71cfe355",
|
| 478 |
"metadata": {},
|
| 479 |
"outputs": [],
|
| 480 |
"source": [
|
|
|
|
| 484 |
},
|
| 485 |
{
|
| 486 |
"cell_type": "markdown",
|
| 487 |
+
"id": "0275c763",
|
| 488 |
"metadata": {},
|
| 489 |
"source": [
|
| 490 |
"## 4️⃣ Data Preparation\n",
|
|
|
|
| 495 |
{
|
| 496 |
"cell_type": "code",
|
| 497 |
"execution_count": null,
|
| 498 |
+
"id": "85901039",
|
| 499 |
"metadata": {},
|
| 500 |
"outputs": [],
|
| 501 |
"source": [
|
|
|
|
| 524 |
{
|
| 525 |
"cell_type": "code",
|
| 526 |
"execution_count": null,
|
| 527 |
+
"id": "eb6a7997",
|
| 528 |
"metadata": {},
|
| 529 |
"outputs": [],
|
| 530 |
"source": [
|
| 531 |
+
"# Check data stats & visualize\n",
|
|
|
|
|
|
|
|
|
|
| 532 |
"sft_path = Path(\"data/sft/e1_m1_h1_examples.jsonl\")\n",
|
| 533 |
"if sft_path.exists():\n",
|
| 534 |
" lines = [json.loads(l) for l in sft_path.read_text().strip().splitlines() if l.strip()]\n",
|
| 535 |
" print(f\"\\n✅ SFT dataset: {len(lines)} examples\")\n",
|
|
|
|
| 536 |
" turn_counts = [len(ex['messages']) for ex in lines]\n",
|
| 537 |
" print(f\" Avg turns per example: {sum(turn_counts)/len(turn_counts):.1f}\")\n",
|
| 538 |
" print(f\" Min/Max turns: {min(turn_counts)} / {max(turn_counts)}\")\n",
|
| 539 |
+
"\n",
|
| 540 |
+
" # Visualize data distribution\n",
|
| 541 |
+
" fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n",
|
| 542 |
+
" fig.suptitle('SFT Dataset Overview', fontsize=14, fontweight='bold', color='#f0f6fc')\n",
|
| 543 |
+
"\n",
|
| 544 |
+
" # Turns histogram\n",
|
| 545 |
+
" axes[0].hist(turn_counts, bins=range(min(turn_counts), max(turn_counts)+2),\n",
|
| 546 |
+
" color='#58a6ff', edgecolor='#30363d', alpha=0.85)\n",
|
| 547 |
+
" axes[0].set_title('Message Turns per Example', fontweight='bold')\n",
|
| 548 |
+
" axes[0].set_xlabel('Number of Turns')\n",
|
| 549 |
+
" axes[0].set_ylabel('Count')\n",
|
| 550 |
+
" axes[0].grid(axis='y', alpha=0.3)\n",
|
| 551 |
+
"\n",
|
| 552 |
+
" # Role distribution\n",
|
| 553 |
+
" role_counts = defaultdict(int)\n",
|
| 554 |
+
" for ex in lines:\n",
|
| 555 |
+
" for msg in ex['messages']:\n",
|
| 556 |
+
" role_counts[msg['role']] += 1\n",
|
| 557 |
+
" roles = list(role_counts.keys())\n",
|
| 558 |
+
" counts = list(role_counts.values())\n",
|
| 559 |
+
" role_colors = ['#a371f7', '#3fb950', '#58a6ff', '#f0883e'][:len(roles)]\n",
|
| 560 |
+
" axes[1].barh(roles, counts, color=role_colors, edgecolor='#30363d')\n",
|
| 561 |
+
" axes[1].set_title('Messages by Role', fontweight='bold')\n",
|
| 562 |
+
" axes[1].set_xlabel('Count')\n",
|
| 563 |
+
" axes[1].grid(axis='x', alpha=0.3)\n",
|
| 564 |
+
"\n",
|
| 565 |
+
" for ax in axes:\n",
|
| 566 |
+
" ax.spines['top'].set_visible(False)\n",
|
| 567 |
+
" ax.spines['right'].set_visible(False)\n",
|
| 568 |
+
" fig.tight_layout()\n",
|
| 569 |
+
" plt.show()\n",
|
| 570 |
"else:\n",
|
| 571 |
" print(\"❌ SFT dataset not found. Check data preparation above.\")"
|
| 572 |
]
|
| 573 |
},
|
| 574 |
{
|
| 575 |
"cell_type": "markdown",
|
| 576 |
+
"id": "8c529b78",
|
| 577 |
"metadata": {},
|
| 578 |
"source": [
|
| 579 |
"## 5️⃣ Baseline Evaluation\n",
|
|
|
|
| 584 |
{
|
| 585 |
"cell_type": "code",
|
| 586 |
"execution_count": null,
|
| 587 |
+
"id": "9c5db0c1",
|
| 588 |
"metadata": {},
|
| 589 |
"outputs": [],
|
| 590 |
"source": [
|
|
|
|
| 592 |
"!python eval.py --policy oracle --label oracle --episodes {EVAL_EPISODES}"
|
| 593 |
]
|
| 594 |
},
|
| 595 |
+
{
|
| 596 |
+
"cell_type": "code",
|
| 597 |
+
"execution_count": null,
|
| 598 |
+
"id": "f106aaed",
|
| 599 |
+
"metadata": {},
|
| 600 |
+
"outputs": [],
|
| 601 |
+
"source": [
|
| 602 |
+
"# 📊 Visualize baseline results\n",
|
| 603 |
+
"display(Markdown(\"## 📊 Baseline Results\"))\n",
|
| 604 |
+
"\n",
|
| 605 |
+
"baseline_rows = load_eval_jsonl(\"results/runs\")\n",
|
| 606 |
+
"baseline_only = [r for r in baseline_rows if r.get('model_stage') == 'baseline']\n",
|
| 607 |
+
"oracle_only = [r for r in baseline_rows if r.get('model_stage') == 'oracle']\n",
|
| 608 |
+
"\n",
|
| 609 |
+
"if baseline_only:\n",
|
| 610 |
+
" display(Markdown(\"### 🔹 Scripted Weak Baseline\"))\n",
|
| 611 |
+
" plot_eval_dashboard(baseline_only, title=\"Scripted Weak Baseline\")\n",
|
| 612 |
+
" plot_reward_traces(baseline_only, title=\"Baseline Reward Traces\")\n",
|
| 613 |
+
" ALL_EVALS['baseline'] = baseline_only\n",
|
| 614 |
+
"\n",
|
| 615 |
+
"if oracle_only:\n",
|
| 616 |
+
" display(Markdown(\"### 🔹 Oracle Policy\"))\n",
|
| 617 |
+
" plot_eval_dashboard(oracle_only, title=\"Oracle Policy\")\n",
|
| 618 |
+
" plot_reward_traces(oracle_only, title=\"Oracle Reward Traces\")\n",
|
| 619 |
+
" ALL_EVALS['oracle'] = oracle_only\n",
|
| 620 |
+
"\n",
|
| 621 |
+
"# Side-by-side comparison if both exist\n",
|
| 622 |
+
"if baseline_only and oracle_only:\n",
|
| 623 |
+
" plot_stage_comparison(\n",
|
| 624 |
+
" {'baseline': baseline_only, 'oracle': oracle_only},\n",
|
| 625 |
+
" metric='terminal_reward',\n",
|
| 626 |
+
" title='Baseline vs Oracle — Terminal Reward'\n",
|
| 627 |
+
" )\n",
|
| 628 |
+
"gpu_status()"
|
| 629 |
+
]
|
| 630 |
+
},
|
| 631 |
{
|
| 632 |
"cell_type": "markdown",
|
| 633 |
+
"id": "3011f739",
|
| 634 |
"metadata": {},
|
| 635 |
"source": [
|
| 636 |
"## 6️⃣ SFT Training (Unsloth + TRL)\n",
|
| 637 |
"\n",
|
| 638 |
+
"Fine-tune Qwen 2.5-3B-Instruct with LoRA using verified CORP-ENV trajectories.\n",
|
| 639 |
"\n",
|
| 640 |
"- Uses `unsloth.FastLanguageModel` for 4-bit QLoRA\n",
|
| 641 |
"- Uses `trl.SFTTrainer` with messages-format conversational SFT\n",
|
| 642 |
+
"- LoRA `r=32`, targets all attention + MLP projections\n",
|
| 643 |
+
"- **FP16 on T4** (auto-detected), BF16 on Ampere+ GPUs"
|
| 644 |
]
|
| 645 |
},
|
| 646 |
{
|
| 647 |
"cell_type": "code",
|
| 648 |
"execution_count": null,
|
| 649 |
+
"id": "cb76d631",
|
| 650 |
"metadata": {},
|
| 651 |
"outputs": [],
|
| 652 |
"source": [
|
| 653 |
+
"gpu_status()\n",
|
| 654 |
+
"print(f\"\\n🚀 Starting SFT training ({FP16_FLAG or 'bf16'} precision)...\\n\")\n",
|
| 655 |
+
"\n",
|
| 656 |
"!python training/train_sft.py \\\n",
|
| 657 |
" --model {BASE_MODEL} \\\n",
|
| 658 |
" --data data/sft/e1_m1_h1_examples.jsonl \\\n",
|
| 659 |
" --output outputs/sft_adapter \\\n",
|
| 660 |
+
" --max-seq-length {SFT_MAX_SEQ_LEN} \\\n",
|
| 661 |
" --max-steps {SFT_MAX_STEPS} \\\n",
|
| 662 |
" --epochs {SFT_EPOCHS} \\\n",
|
| 663 |
" --lr {SFT_LR} \\\n",
|
| 664 |
" --batch-size {SFT_BATCH_SIZE} \\\n",
|
| 665 |
" --grad-accum {SFT_GRAD_ACCUM} \\\n",
|
| 666 |
+
" {FP16_FLAG}\n",
|
| 667 |
+
"\n",
|
| 668 |
+
"gpu_status()\n",
|
| 669 |
+
"print(\"\\n✅ SFT training complete!\")"
|
| 670 |
+
]
|
| 671 |
+
},
|
| 672 |
+
{
|
| 673 |
+
"cell_type": "code",
|
| 674 |
+
"execution_count": null,
|
| 675 |
+
"id": "3755df03",
|
| 676 |
+
"metadata": {},
|
| 677 |
+
"outputs": [],
|
| 678 |
+
"source": [
|
| 679 |
+
"# 📊 Visualize SFT training logs\n",
|
| 680 |
+
"display(Markdown(\"## 📊 SFT Training Summary\"))\n",
|
| 681 |
+
"\n",
|
| 682 |
+
"# Check for trainer_state.json\n",
|
| 683 |
+
"state_file = Path(\"outputs/sft_adapter/trainer_state.json\")\n",
|
| 684 |
+
"if state_file.exists():\n",
|
| 685 |
+
" state = json.loads(state_file.read_text())\n",
|
| 686 |
+
" log_history = state.get('log_history', [])\n",
|
| 687 |
+
" if log_history:\n",
|
| 688 |
+
" steps = [l['step'] for l in log_history if 'loss' in l]\n",
|
| 689 |
+
" losses = [l['loss'] for l in log_history if 'loss' in l]\n",
|
| 690 |
+
" lrs = [l.get('learning_rate', 0) for l in log_history if 'loss' in l]\n",
|
| 691 |
+
"\n",
|
| 692 |
+
" fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 693 |
+
" fig.suptitle('SFT Training Curves', fontsize=16, fontweight='bold', color='#f0f6fc')\n",
|
| 694 |
+
"\n",
|
| 695 |
+
" # Loss curve\n",
|
| 696 |
+
" axes[0].plot(steps, losses, linewidth=2.5, color='#f0883e', marker='o', markersize=5)\n",
|
| 697 |
+
" axes[0].set_title('Training Loss', fontweight='bold')\n",
|
| 698 |
+
" axes[0].set_xlabel('Step')\n",
|
| 699 |
+
" axes[0].set_ylabel('Loss')\n",
|
| 700 |
+
" axes[0].grid(alpha=0.3)\n",
|
| 701 |
+
"\n",
|
| 702 |
+
" # Learning rate schedule\n",
|
| 703 |
+
" axes[1].plot(steps, lrs, linewidth=2.5, color='#3fb950', marker='s', markersize=4)\n",
|
| 704 |
+
" axes[1].set_title('Learning Rate Schedule', fontweight='bold')\n",
|
| 705 |
+
" axes[1].set_xlabel('Step')\n",
|
| 706 |
+
" axes[1].set_ylabel('Learning Rate')\n",
|
| 707 |
+
" axes[1].ticklabel_format(axis='y', style='scientific', scilimits=(-4, -4))\n",
|
| 708 |
+
" axes[1].grid(alpha=0.3)\n",
|
| 709 |
+
"\n",
|
| 710 |
+
" for ax in axes:\n",
|
| 711 |
+
" ax.spines['top'].set_visible(False)\n",
|
| 712 |
+
" ax.spines['right'].set_visible(False)\n",
|
| 713 |
+
" fig.tight_layout()\n",
|
| 714 |
+
" plt.show()\n",
|
| 715 |
+
"\n",
|
| 716 |
+
" print(f\"\\n📈 Final loss: {losses[-1]:.4f} at step {steps[-1]}\")\n",
|
| 717 |
+
"else:\n",
|
| 718 |
+
" print(\"⚠️ No trainer_state.json found; training logs unavailable.\")\n",
|
| 719 |
+
"\n",
|
| 720 |
+
"# Check adapter files\n",
|
| 721 |
+
"adapter_dir = Path(\"outputs/sft_adapter\")\n",
|
| 722 |
+
"if adapter_dir.exists():\n",
|
| 723 |
+
" files = list(adapter_dir.glob(\"*\"))\n",
|
| 724 |
+
" total_mb = sum(f.stat().st_size for f in files if f.is_file()) / 1e6\n",
|
| 725 |
+
" print(f\"💾 Adapter saved: {len(files)} files, {total_mb:.1f} MB total\")"
|
| 726 |
]
|
| 727 |
},
|
| 728 |
{
|
| 729 |
"cell_type": "markdown",
|
| 730 |
+
"id": "cd078c28",
|
| 731 |
"metadata": {},
|
| 732 |
"source": [
|
| 733 |
"## 7️⃣ Evaluate SFT Adapter"
|
|
|
|
| 736 |
{
|
| 737 |
"cell_type": "code",
|
| 738 |
"execution_count": null,
|
| 739 |
+
"id": "50594aef",
|
| 740 |
"metadata": {},
|
| 741 |
"outputs": [],
|
| 742 |
"source": [
|
| 743 |
+
"# Clear GPU memory before loading eval model\n",
|
| 744 |
+
"import gc\n",
|
| 745 |
+
"gc.collect()\n",
|
| 746 |
+
"torch.cuda.empty_cache()\n",
|
| 747 |
+
"gpu_status()\n",
|
| 748 |
+
"\n",
|
| 749 |
"!python eval.py \\\n",
|
| 750 |
" --policy hf \\\n",
|
| 751 |
" --label sft \\\n",
|
| 752 |
" --model {BASE_MODEL} \\\n",
|
| 753 |
" --adapter outputs/sft_adapter \\\n",
|
| 754 |
" --episodes {EVAL_EPISODES} \\\n",
|
| 755 |
+
" --max-steps {EVAL_MAX_STEPS}\n",
|
| 756 |
+
"\n",
|
| 757 |
+
"gpu_status()"
|
| 758 |
+
]
|
| 759 |
+
},
|
| 760 |
+
{
|
| 761 |
+
"cell_type": "code",
|
| 762 |
+
"execution_count": null,
|
| 763 |
+
"id": "37bc9dd8",
|
| 764 |
+
"metadata": {},
|
| 765 |
+
"outputs": [],
|
| 766 |
+
"source": [
|
| 767 |
+
"# 📊 Visualize SFT evaluation results\n",
|
| 768 |
+
"display(Markdown(\"## 📊 SFT Evaluation Results\"))\n",
|
| 769 |
+
"\n",
|
| 770 |
+
"sft_rows = [r for r in load_eval_jsonl(\"results/runs\") if r.get('model_stage') == 'sft']\n",
|
| 771 |
+
"if sft_rows:\n",
|
| 772 |
+
" plot_eval_dashboard(sft_rows, title=\"SFT Adapter Evaluation\")\n",
|
| 773 |
+
" plot_reward_traces(sft_rows, title=\"SFT Reward Traces\")\n",
|
| 774 |
+
" ALL_EVALS['sft'] = sft_rows\n",
|
| 775 |
+
"\n",
|
| 776 |
+
" # Compare baseline → SFT\n",
|
| 777 |
+
" display(Markdown(\"### 📈 Improvement: Baseline → SFT\"))\n",
|
| 778 |
+
" comparison = {k: v for k, v in ALL_EVALS.items() if k in ('baseline', 'oracle', 'sft')}\n",
|
| 779 |
+
" if len(comparison) > 1:\n",
|
| 780 |
+
" plot_stage_comparison(comparison, metric='terminal_reward',\n",
|
| 781 |
+
" title='Baseline → SFT — Terminal Reward Comparison')\n",
|
| 782 |
+
" plot_stage_comparison(comparison, metric='verifier_pass_rate',\n",
|
| 783 |
+
" title='Baseline → SFT — Verifier Pass Rate')\n",
|
| 784 |
+
"else:\n",
|
| 785 |
+
" print(\"⚠️ No SFT eval results found.\")"
|
| 786 |
]
|
| 787 |
},
|
| 788 |
{
|
| 789 |
"cell_type": "markdown",
|
| 790 |
+
"id": "d9671fe6",
|
| 791 |
"metadata": {},
|
| 792 |
"source": [
|
| 793 |
"## 8️⃣ RLVR Training (Rejection-Sampling FT)\n",
|
|
|
|
| 799 |
"4. SFT on that curated set\n",
|
| 800 |
"5. Repeating for multiple outer rounds\n",
|
| 801 |
"\n",
|
| 802 |
+
"This avoids the zero-variance gradient problem seen with GRPO on CORP-ENV.\n",
|
| 803 |
+
"\n",
|
| 804 |
+
"> ⚡ **T4 Note**: Using `--fp16` and reduced `--n-samples` / `--max-prompts` to fit in 16 GB VRAM."
|
| 805 |
]
|
| 806 |
},
|
| 807 |
{
|
| 808 |
"cell_type": "code",
|
| 809 |
"execution_count": null,
|
| 810 |
+
"id": "5be0f8be",
|
| 811 |
"metadata": {},
|
| 812 |
"outputs": [],
|
| 813 |
"source": [
|
| 814 |
+
"# Clear GPU memory\n",
|
| 815 |
+
"gc.collect()\n",
|
| 816 |
+
"torch.cuda.empty_cache()\n",
|
| 817 |
+
"gpu_status()\n",
|
| 818 |
+
"\n",
|
| 819 |
+
"print(f\"\\n🚀 Starting RLVR training ({RLVR_ROUNDS} rounds, {RLVR_N_SAMPLES} samples/prompt)...\\n\")\n",
|
| 820 |
+
"\n",
|
| 821 |
"!python training/train_rlvr.py \\\n",
|
| 822 |
" --model {BASE_MODEL} \\\n",
|
| 823 |
" --adapter outputs/sft_adapter \\\n",
|
|
|
|
| 827 |
" --n-samples {RLVR_N_SAMPLES} \\\n",
|
| 828 |
" --temperature {RLVR_TEMPERATURE} \\\n",
|
| 829 |
" --max-prompts {RLVR_MAX_PROMPTS} \\\n",
|
| 830 |
+
" --max-prompt-length {RLVR_MAX_PROMPT_LEN} \\\n",
|
| 831 |
+
" --max-completion-length {RLVR_MAX_COMPLETION_LEN} \\\n",
|
| 832 |
" --strict-json \\\n",
|
| 833 |
" --use-stub-workers \\\n",
|
| 834 |
" --disable-llm-judge \\\n",
|
| 835 |
+
" --stats-file results/runs/rlvr_qwen2.5_3b_stats.jsonl \\\n",
|
| 836 |
+
" {FP16_FLAG}\n",
|
| 837 |
+
"\n",
|
| 838 |
+
"gpu_status()\n",
|
| 839 |
+
"print(\"\\n✅ RLVR training complete!\")"
|
| 840 |
+
]
|
| 841 |
+
},
|
| 842 |
+
{
|
| 843 |
+
"cell_type": "code",
|
| 844 |
+
"execution_count": null,
|
| 845 |
+
"id": "f71e3401",
|
| 846 |
+
"metadata": {},
|
| 847 |
+
"outputs": [],
|
| 848 |
+
"source": [
|
| 849 |
+
"# 📊 Visualize RLVR training progress\n",
|
| 850 |
+
"display(Markdown(\"## 📊 RLVR Training Progress\"))\n",
|
| 851 |
+
"plot_rlvr_stats(\"results/runs/rlvr_qwen2.5_3b_stats.jsonl\")\n",
|
| 852 |
+
"\n",
|
| 853 |
+
"# Check adapter files\n",
|
| 854 |
+
"rlvr_dir = Path(\"outputs/rlvr_adapter\")\n",
|
| 855 |
+
"if rlvr_dir.exists():\n",
|
| 856 |
+
" files = list(rlvr_dir.glob(\"*\"))\n",
|
| 857 |
+
" total_mb = sum(f.stat().st_size for f in files if f.is_file()) / 1e6\n",
|
| 858 |
+
" print(f\"\\n💾 RLVR adapter saved: {len(files)} files, {total_mb:.1f} MB total\")"
|
| 859 |
]
|
| 860 |
},
|
| 861 |
{
|
| 862 |
"cell_type": "markdown",
|
| 863 |
+
"id": "32503cf5",
|
| 864 |
"metadata": {},
|
| 865 |
"source": [
|
| 866 |
"## 9️⃣ Evaluate RLVR Adapter"
|
|
|
|
| 869 |
{
|
| 870 |
"cell_type": "code",
|
| 871 |
"execution_count": null,
|
| 872 |
+
"id": "a756f408",
|
| 873 |
"metadata": {},
|
| 874 |
"outputs": [],
|
| 875 |
"source": [
|
| 876 |
+
"# Clear GPU memory\n",
|
| 877 |
+
"gc.collect()\n",
|
| 878 |
+
"torch.cuda.empty_cache()\n",
|
| 879 |
+
"gpu_status()\n",
|
| 880 |
+
"\n",
|
| 881 |
"!python eval.py \\\n",
|
| 882 |
" --policy hf \\\n",
|
| 883 |
" --label rlvr \\\n",
|
| 884 |
" --model {BASE_MODEL} \\\n",
|
| 885 |
" --adapter outputs/rlvr_adapter \\\n",
|
| 886 |
" --episodes {EVAL_EPISODES} \\\n",
|
| 887 |
+
" --max-steps {EVAL_MAX_STEPS}\n",
|
| 888 |
+
"\n",
|
| 889 |
+
"gpu_status()"
|
| 890 |
+
]
|
| 891 |
+
},
|
| 892 |
+
{
|
| 893 |
+
"cell_type": "code",
|
| 894 |
+
"execution_count": null,
|
| 895 |
+
"id": "daf5526c",
|
| 896 |
+
"metadata": {},
|
| 897 |
+
"outputs": [],
|
| 898 |
+
"source": [
|
| 899 |
+
"# 📊 Visualize RLVR evaluation results\n",
|
| 900 |
+
"display(Markdown(\"## 📊 RLVR Evaluation Results\"))\n",
|
| 901 |
+
"\n",
|
| 902 |
+
"rlvr_rows = [r for r in load_eval_jsonl(\"results/runs\") if r.get('model_stage') == 'rlvr']\n",
|
| 903 |
+
"if rlvr_rows:\n",
|
| 904 |
+
" plot_eval_dashboard(rlvr_rows, title=\"RLVR Adapter Evaluation\")\n",
|
| 905 |
+
" plot_reward_traces(rlvr_rows, title=\"RLVR Reward Traces\")\n",
|
| 906 |
+
" ALL_EVALS['rlvr'] = rlvr_rows\n",
|
| 907 |
+
"\n",
|
| 908 |
+
" # Compare SFT → RLVR\n",
|
| 909 |
+
" display(Markdown(\"### 📈 Improvement: SFT → RLVR\"))\n",
|
| 910 |
+
" if 'sft' in ALL_EVALS:\n",
|
| 911 |
+
" plot_stage_comparison(\n",
|
| 912 |
+
" {'sft': ALL_EVALS['sft'], 'rlvr': rlvr_rows},\n",
|
| 913 |
+
" metric='terminal_reward',\n",
|
| 914 |
+
" title='SFT → RLVR — Terminal Reward'\n",
|
| 915 |
+
" )\n",
|
| 916 |
+
"else:\n",
|
| 917 |
+
" print(\"⚠️ No RLVR eval results found.\")"
|
| 918 |
]
|
| 919 |
},
|
| 920 |
{
|
| 921 |
"cell_type": "markdown",
|
| 922 |
+
"id": "e96ce765",
|
| 923 |
+
"metadata": {},
|
| 924 |
+
"source": [
|
| 925 |
+
"## 📊 Final Comparison: All Model Stages\n",
|
| 926 |
+
"\n",
|
| 927 |
+
"Side-by-side comparison of all evaluated model stages."
|
| 928 |
+
]
|
| 929 |
+
},
|
| 930 |
+
{
|
| 931 |
+
"cell_type": "code",
|
| 932 |
+
"execution_count": null,
|
| 933 |
+
"id": "d92ae920",
|
| 934 |
"metadata": {},
|
| 935 |
+
"outputs": [],
|
| 936 |
"source": [
|
| 937 |
+
"display(Markdown(\"## 📊 Full Pipeline Comparison: Baseline → Oracle → SFT → RLVR\"))\n",
|
| 938 |
+
"\n",
|
| 939 |
+
"if ALL_EVALS:\n",
|
| 940 |
+
" # Terminal Reward comparison\n",
|
| 941 |
+
" plot_stage_comparison(ALL_EVALS, metric='terminal_reward',\n",
|
| 942 |
+
" title='Terminal Reward — All Model Stages')\n",
|
| 943 |
"\n",
|
| 944 |
+
" # Verifier Pass Rate comparison\n",
|
| 945 |
+
" plot_stage_comparison(ALL_EVALS, metric='verifier_pass_rate',\n",
|
| 946 |
+
" title='Verifier Pass Rate — All Model Stages')\n",
|
| 947 |
+
"\n",
|
| 948 |
+
" # Build final comparison table\n",
|
| 949 |
+
" display(Markdown(\"### 📋 Final Results Table\"))\n",
|
| 950 |
+
" header = \"| Stage | Task | Terminal Reward | Verifier Pass | Success Rate |\"\n",
|
| 951 |
+
" sep = \"|-------|------|---------------|--------------|-------------|\"\n",
|
| 952 |
+
" lines = [header, sep]\n",
|
| 953 |
+
" for stage_name, stage_rows in ALL_EVALS.items():\n",
|
| 954 |
+
" by_task = defaultdict(list)\n",
|
| 955 |
+
" for r in stage_rows:\n",
|
| 956 |
+
" by_task[r['task_id']].append(r)\n",
|
| 957 |
+
" for task_id in sorted(by_task.keys()):\n",
|
| 958 |
+
" task_rows = by_task[task_id]\n",
|
| 959 |
+
" avg_r = np.mean([r['terminal_reward'] for r in task_rows])\n",
|
| 960 |
+
" avg_p = np.mean([r['verifier_pass_rate'] for r in task_rows])\n",
|
| 961 |
+
" succ = np.mean([1 if r.get('success') else 0 for r in task_rows])\n",
|
| 962 |
+
" lines.append(f\"| {stage_name.upper()} | {TASK_SHORT.get(task_id, task_id)} | {avg_r:.3f} | {avg_p:.3f} | {succ:.0%} |\")\n",
|
| 963 |
+
" display(Markdown('\\n'.join(lines)))\n",
|
| 964 |
+
"else:\n",
|
| 965 |
+
" print(\"⚠️ No evaluation data collected. Run the evaluation cells above.\")"
|
| 966 |
]
|
| 967 |
},
|
| 968 |
{
|
| 969 |
"cell_type": "code",
|
| 970 |
"execution_count": null,
|
| 971 |
+
"id": "b37b7da9",
|
| 972 |
"metadata": {},
|
| 973 |
"outputs": [],
|
| 974 |
"source": [
|
| 975 |
+
"# Also generate plots via plot_results.py for file-based output\n",
|
| 976 |
"!python plot_results.py \\\n",
|
| 977 |
" --inputs results/runs \\\n",
|
| 978 |
+
" --output-dir results/model_compare_qwen25_3b"
|
| 979 |
]
|
| 980 |
},
|
| 981 |
{
|
| 982 |
"cell_type": "code",
|
| 983 |
"execution_count": null,
|
| 984 |
+
"id": "3313ec66",
|
| 985 |
"metadata": {},
|
| 986 |
"outputs": [],
|
| 987 |
"source": [
|
| 988 |
"from IPython.display import Image, display, Markdown\n",
|
|
|
|
| 989 |
"\n",
|
| 990 |
+
"plot_dir = Path(\"results/model_compare_qwen25_3b\")\n",
|
| 991 |
"if not plot_dir.exists():\n",
|
| 992 |
" plot_dir = Path(\"results/model_compare_qwen25_fresh_no_grpo_ep5rlvr\")\n",
|
| 993 |
"\n",
|
| 994 |
+
"if plot_dir.exists():\n",
|
| 995 |
+
" for png in sorted(plot_dir.glob(\"*.png\")):\n",
|
| 996 |
+
" display(Markdown(f\"### {png.stem.replace('_', ' ').title()}\"))\n",
|
| 997 |
+
" display(Image(filename=str(png), width=800))\n",
|
| 998 |
"\n",
|
| 999 |
+
" # Show summary table\n",
|
| 1000 |
+
" summary_md = plot_dir / \"comparison_summary.md\"\n",
|
| 1001 |
+
" if summary_md.exists():\n",
|
| 1002 |
+
" display(Markdown(summary_md.read_text()))\n",
|
| 1003 |
+
"else:\n",
|
| 1004 |
+
" print(\"⚠️ No plot directory found.\")"
|
| 1005 |
]
|
| 1006 |
},
|
| 1007 |
{
|
| 1008 |
"cell_type": "markdown",
|
| 1009 |
+
"id": "a638d546",
|
| 1010 |
"metadata": {},
|
| 1011 |
"source": [
|
| 1012 |
"## 📋 Results Summary\n",
|
| 1013 |
"\n",
|
| 1014 |
+
"Expected progression for Qwen 2.5-3B-Instruct on CORP-ENV:\n",
|
| 1015 |
"\n",
|
| 1016 |
"| Stage | E1 Terminal Reward | M1 Terminal Reward | H1 Terminal Reward | M1 Success |\n",
|
| 1017 |
"|-------|-------------------|-------------------|-------------------|------------|\n",
|
|
|
|
| 1020 |
"| SFT | 0.910 | 0.943 | 0.889 | 100% |\n",
|
| 1021 |
"| RLVR | 0.910 | 0.932 | 0.779 | 80% |\n",
|
| 1022 |
"\n",
|
| 1023 |
+
"> **Key takeaway**: SFT dramatically improves M1 (budget reallocation) from 0% to 100% success rate. RLVR maintains strong performance while reducing reliance on fixed trajectories.\n",
|
| 1024 |
+
"\n",
|
| 1025 |
+
"> **T4 Note**: Results may differ slightly on T4 due to FP16 precision (vs BF16) and reduced RLVR sampling. For best results, use the full hyperparameters on an A100/H100."
|
| 1026 |
]
|
| 1027 |
}
|
| 1028 |
],
|
|
|
|
| 1045 |
},
|
| 1046 |
"nbformat": 4,
|
| 1047 |
"nbformat_minor": 5
|
| 1048 |
+
}
|
training/train_rlvr.py
CHANGED
|
@@ -236,6 +236,7 @@ def sft_on_winners(
|
|
| 236 |
epochs: float,
|
| 237 |
max_steps: int,
|
| 238 |
max_seq_length: int,
|
|
|
|
| 239 |
) -> None:
|
| 240 |
"""Run a single SFT pass over the curated (prompt, best_completion) set."""
|
| 241 |
from datasets import Dataset
|
|
@@ -267,7 +268,8 @@ def sft_on_winners(
|
|
| 267 |
"save_steps": 10_000,
|
| 268 |
"save_total_limit": 1,
|
| 269 |
"optim": "adamw_8bit",
|
| 270 |
-
"bf16":
|
|
|
|
| 271 |
"report_to": "none",
|
| 272 |
"dataset_text_field": "text",
|
| 273 |
"push_to_hub": False,
|
|
@@ -376,6 +378,11 @@ def main() -> None:
|
|
| 376 |
action="store_true",
|
| 377 |
help="Disable LLM judge scoring for deterministic verifier-only runs.",
|
| 378 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
args = parser.parse_args()
|
| 380 |
|
| 381 |
if args.use_stub_workers:
|
|
@@ -405,10 +412,11 @@ def main() -> None:
|
|
| 405 |
print(f"Built {len(full_rows)} prompts from {args.examples}")
|
| 406 |
|
| 407 |
max_seq_len = args.max_prompt_length + args.max_completion_length
|
|
|
|
| 408 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 409 |
model_name=args.model,
|
| 410 |
max_seq_length=max_seq_len,
|
| 411 |
-
dtype=
|
| 412 |
load_in_4bit=True,
|
| 413 |
)
|
| 414 |
if getattr(tokenizer, "pad_token", None) is None and getattr(
|
|
@@ -438,9 +446,10 @@ def main() -> None:
|
|
| 438 |
random_state=args.seed,
|
| 439 |
)
|
| 440 |
|
|
|
|
| 441 |
for p in model.parameters():
|
| 442 |
if p.requires_grad and p.dtype == torch.float32:
|
| 443 |
-
p.data = p.data.to(
|
| 444 |
|
| 445 |
stats_path = Path(args.stats_file) if args.stats_file else None
|
| 446 |
if stats_path:
|
|
@@ -490,6 +499,7 @@ def main() -> None:
|
|
| 490 |
epochs=args.inner_epochs,
|
| 491 |
max_steps=args.inner_max_steps,
|
| 492 |
max_seq_length=max_seq_len,
|
|
|
|
| 493 |
)
|
| 494 |
|
| 495 |
Path(args.output).mkdir(parents=True, exist_ok=True)
|
|
|
|
| 236 |
epochs: float,
|
| 237 |
max_steps: int,
|
| 238 |
max_seq_length: int,
|
| 239 |
+
use_fp16: bool = False,
|
| 240 |
) -> None:
|
| 241 |
"""Run a single SFT pass over the curated (prompt, best_completion) set."""
|
| 242 |
from datasets import Dataset
|
|
|
|
| 268 |
"save_steps": 10_000,
|
| 269 |
"save_total_limit": 1,
|
| 270 |
"optim": "adamw_8bit",
|
| 271 |
+
"bf16": (not use_fp16) and torch.cuda.is_available(),
|
| 272 |
+
"fp16": use_fp16 and torch.cuda.is_available(),
|
| 273 |
"report_to": "none",
|
| 274 |
"dataset_text_field": "text",
|
| 275 |
"push_to_hub": False,
|
|
|
|
| 378 |
action="store_true",
|
| 379 |
help="Disable LLM judge scoring for deterministic verifier-only runs.",
|
| 380 |
)
|
| 381 |
+
parser.add_argument(
|
| 382 |
+
"--fp16",
|
| 383 |
+
action="store_true",
|
| 384 |
+
help="Use fp16 instead of bf16 (required for T4 GPUs which lack bf16 support).",
|
| 385 |
+
)
|
| 386 |
args = parser.parse_args()
|
| 387 |
|
| 388 |
if args.use_stub_workers:
|
|
|
|
| 412 |
print(f"Built {len(full_rows)} prompts from {args.examples}")
|
| 413 |
|
| 414 |
max_seq_len = args.max_prompt_length + args.max_completion_length
|
| 415 |
+
load_dtype = torch.float16 if args.fp16 else torch.bfloat16
|
| 416 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 417 |
model_name=args.model,
|
| 418 |
max_seq_length=max_seq_len,
|
| 419 |
+
dtype=load_dtype,
|
| 420 |
load_in_4bit=True,
|
| 421 |
)
|
| 422 |
if getattr(tokenizer, "pad_token", None) is None and getattr(
|
|
|
|
| 446 |
random_state=args.seed,
|
| 447 |
)
|
| 448 |
|
| 449 |
+
cast_dtype = torch.float16 if args.fp16 else torch.bfloat16
|
| 450 |
for p in model.parameters():
|
| 451 |
if p.requires_grad and p.dtype == torch.float32:
|
| 452 |
+
p.data = p.data.to(cast_dtype)
|
| 453 |
|
| 454 |
stats_path = Path(args.stats_file) if args.stats_file else None
|
| 455 |
if stats_path:
|
|
|
|
| 499 |
epochs=args.inner_epochs,
|
| 500 |
max_steps=args.inner_max_steps,
|
| 501 |
max_seq_length=max_seq_len,
|
| 502 |
+
use_fp16=args.fp16,
|
| 503 |
)
|
| 504 |
|
| 505 |
Path(args.output).mkdir(parents=True, exist_ok=True)
|
training/train_sft.py
CHANGED
|
@@ -210,10 +210,11 @@ def main() -> None:
|
|
| 210 |
if args.dataset_num_proc == 0 and "dataset_num_proc" in allowed:
|
| 211 |
args = argparse.Namespace(**{**vars(args), "dataset_num_proc": None})
|
| 212 |
|
|
|
|
| 213 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 214 |
model_name=args.model,
|
| 215 |
max_seq_length=args.max_seq_length,
|
| 216 |
-
dtype=
|
| 217 |
load_in_4bit=True,
|
| 218 |
)
|
| 219 |
if getattr(tokenizer, "pad_token", None) is None and getattr(
|
|
|
|
| 210 |
if args.dataset_num_proc == 0 and "dataset_num_proc" in allowed:
|
| 211 |
args = argparse.Namespace(**{**vars(args), "dataset_num_proc": None})
|
| 212 |
|
| 213 |
+
load_dtype = torch.float16 if args.fp16 else None
|
| 214 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 215 |
model_name=args.model,
|
| 216 |
max_seq_length=args.max_seq_length,
|
| 217 |
+
dtype=load_dtype,
|
| 218 |
load_in_4bit=True,
|
| 219 |
)
|
| 220 |
if getattr(tokenizer, "pad_token", None) is None and getattr(
|