Navigam committed on
Commit
5c8287c
·
1 Parent(s): 978c5b4

feat: add training pipeline with SFT and RLVR support for Qwen 2.5-3B-Instruct
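
The hardware-dependent settings this commit introduces can be summarized in a small sketch. The thresholds and values are copied from the notebook's configuration cell; the `RunConfig` dataclass and the `select_config` function are hypothetical names used only for illustration, since the notebook itself sets module-level constants. The example GPU figures (15.8 GB without BF16, 80 GB with BF16) are stand-ins for a T4-class and an H100-class card, not measured values.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunConfig:
    """Hypothetical container; the notebook uses plain constants instead."""
    use_fp16: bool
    sft_max_seq_len: int
    rlvr_max_prompts: int
    rlvr_n_samples: int
    rlvr_max_prompt_len: int
    fp16_flag: str


def select_config(gpu_mem_gb: float, has_bf16: bool) -> RunConfig:
    """Mirror the notebook's LOW_MEMORY / USE_FP16 switches."""
    low_memory = gpu_mem_gb < 20.0  # same 20 GB cutoff as the notebook
    use_fp16 = not has_bf16         # fall back to FP16 when BF16 is unavailable
    return RunConfig(
        use_fp16=use_fp16,
        sft_max_seq_len=3072 if low_memory else 8192,
        rlvr_max_prompts=32 if low_memory else 128,
        rlvr_n_samples=4 if low_memory else 8,
        rlvr_max_prompt_len=3072 if low_memory else 8192,
        fp16_flag="--fp16" if use_fp16 else "",
    )


# Illustrative inputs only: a small card without BF16 and a large card with it.
small_gpu = select_config(gpu_mem_gb=15.8, has_bf16=False)
large_gpu = select_config(gpu_mem_gb=80.0, has_bf16=True)
print(small_gpu)
print(large_gpu)
```

Keeping the switches in one pure function makes them easy to check without a GPU attached, which is the only reason for the wrapper here.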

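The new `plot_rlvr_stats` helper reads one JSON record per RLVR round with the keys `round`, `keep_rate`, `mean_best_reward`, `mean_sample_reward`, `prompts_kept`, and `seconds`. How `training/train_rlvr.py` computes them is not shown in this diff, so the following is only a sketch of one plausible rejection-sampling summary: the function name `summarize_round`, the `keep_threshold` value, and the exact definitions of each statistic are assumptions. The `seconds` field is omitted.

```python
from statistics import mean


def summarize_round(round_idx, rewards_per_prompt, keep_threshold=0.5):
    """Summarize one rejection-sampling round (assumed definitions).

    rewards_per_prompt: one inner list of verifier rewards per prompt,
    with one reward per sampled completion.
    """
    if not rewards_per_prompt or any(not r for r in rewards_per_prompt):
        raise ValueError("every prompt needs at least one sampled reward")

    # Best-of-N reward for each prompt.
    best = [max(r) for r in rewards_per_prompt]
    # A prompt is "kept" when its best sample clears the (assumed) threshold.
    kept = [b for b in best if b >= keep_threshold]
    all_samples = [x for r in rewards_per_prompt for x in r]

    return {
        "round": round_idx,
        "keep_rate": len(kept) / len(best),
        "mean_best_reward": mean(best),
        "mean_sample_reward": mean(all_samples),
        "prompts_kept": len(kept),
    }


# Three prompts, four sampled completions each (made-up rewards).
stats = summarize_round(
    1,
    [
        [0.1, 0.9, 0.4, 0.2],
        [0.0, 0.3, 0.2, 0.1],
        [0.8, 0.7, 0.6, 0.9],
    ],
)
print(stats)
```

Under these assumptions, a gap between `mean_best_reward` and `mean_sample_reward` indicates how much the best-of-N selection contributes beyond an average sample.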
notebooks/training.ipynb CHANGED
@@ -2,27 +2,37 @@
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
 
5
  "metadata": {},
6
  "source": [
7
- "# 🏢 CORP-ENV · Qwen 2.5-7B-Instruct — SFT + RLVR Training\n",
8
  "\n",
9
- "**End-to-end reproducible notebook** for training a Qwen 2.5-7B-Instruct agent on CORP-ENV using Supervised Fine-Tuning (SFT) followed by Rejection-Sampling RL on Verifiable Rewards (RLVR).\n",
 
 
 
 
 
 
 
 
10
  "\n",
11
  "CORP-ENV is a multi-agent corporate decision environment where a Master Agent governs a **Shared Workspace Document (SWD)** across long-horizon planning episodes, coordinating frozen worker agents. Rewards measure SWD integrity, task completion, milestone adherence, reasoning density, and LLM-judge scores.\n",
12
  "\n",
13
  "| Component | Detail |\n",
14
  "|---|---|\n",
15
- "| **Base model** | `Qwen/Qwen2.5-7B-Instruct` |\n",
16
  "| **SFT script** | `training/train_sft.py` |\n",
17
  "| **RLVR script** | `training/train_rlvr.py` |\n",
18
  "| **Tasks** | E1 Launch Readiness, M1 Budget Reallocation, H1 Acquisition Defence |\n",
19
- "| **Runtime** | Colab GPU / Lightning AI H100 / Any CUDA GPU |\n",
20
  "\n",
21
  "---"
22
  ]
23
  },
24
  {
25
  "cell_type": "markdown",
 
26
  "metadata": {},
27
  "source": [
28
  "## 1️⃣ Setup & Installation"
@@ -31,49 +41,92 @@
31
  {
32
  "cell_type": "code",
33
  "execution_count": null,
 
34
  "metadata": {},
35
  "outputs": [],
36
  "source": [
37
  "import os\n",
38
  "\n",
39
  "# ===== Configuration =====\n",
40
  "REPO_URL = \"https://huggingface.co/spaces/Navigam/corp-env\" # Change to your repo\n",
41
- "BASE_MODEL = \"Qwen/Qwen2.5-7B-Instruct\"\n",
42
  "HF_ORG_OR_USER = \"Navigam\" # Your HF username/org\n",
43
  "\n",
44
- "# SFT hyperparameters\n",
45
- "SFT_MAX_STEPS = 30 # Quick judge smoke; set -1 for full-epoch training\n",
46
  "SFT_EPOCHS = 2.0\n",
47
  "SFT_LR = 2e-4\n",
48
  "SFT_BATCH_SIZE = 1\n",
49
  "SFT_GRAD_ACCUM = 8\n",
 
50
  "\n",
51
- "# RLVR hyperparameters\n",
52
  "RLVR_ROUNDS = 3\n",
53
- "RLVR_MAX_PROMPTS = 128\n",
54
- "RLVR_N_SAMPLES = 8\n",
55
  "RLVR_TEMPERATURE = 0.7\n",
 
 
56
  "\n",
57
  "# Eval\n",
58
  "EVAL_EPISODES = 3\n",
59
- "EVAL_MAX_STEPS = 30"
 
 
 
 
 
 
60
  ]
61
  },
62
  {
63
  "cell_type": "code",
64
  "execution_count": null,
 
65
  "metadata": {},
66
  "outputs": [],
67
  "source": [
68
- "# Clone and install\n",
69
  "!git clone {REPO_URL} corp_gym 2>/dev/null || echo 'Repo already cloned'\n",
70
  "%cd corp_gym\n",
71
- "!pip install -U pip\n",
72
- "!pip install -e \".[training,plots]\""
73
  ]
74
  },
75
  {
76
  "cell_type": "markdown",
 
77
  "metadata": {},
78
  "source": [
79
  "## 2️⃣ Hugging Face Login (optional)"
@@ -82,6 +135,7 @@
82
  {
83
  "cell_type": "code",
84
  "execution_count": null,
 
85
  "metadata": {},
86
  "outputs": [],
87
  "source": [
@@ -91,6 +145,325 @@
91
  },
92
  {
93
  "cell_type": "markdown",
94
  "metadata": {},
95
  "source": [
96
  "## 3️⃣ Environment Validation\n",
@@ -101,6 +474,7 @@
101
  {
102
  "cell_type": "code",
103
  "execution_count": null,
 
104
  "metadata": {},
105
  "outputs": [],
106
  "source": [
@@ -110,6 +484,7 @@
110
  },
111
  {
112
  "cell_type": "markdown",
 
113
  "metadata": {},
114
  "source": [
115
  "## 4️⃣ Data Preparation\n",
@@ -120,6 +495,7 @@
120
  {
121
  "cell_type": "code",
122
  "execution_count": null,
 
123
  "metadata": {},
124
  "outputs": [],
125
  "source": [
@@ -148,27 +524,56 @@
148
  {
149
  "cell_type": "code",
150
  "execution_count": null,
 
151
  "metadata": {},
152
  "outputs": [],
153
  "source": [
154
- "# Check data stats\n",
155
- "import json\n",
156
- "from pathlib import Path\n",
157
- "\n",
158
  "sft_path = Path(\"data/sft/e1_m1_h1_examples.jsonl\")\n",
159
  "if sft_path.exists():\n",
160
  " lines = [json.loads(l) for l in sft_path.read_text().strip().splitlines() if l.strip()]\n",
161
  " print(f\"\\n✅ SFT dataset: {len(lines)} examples\")\n",
162
- " # Count by number of messages\n",
163
  " turn_counts = [len(ex['messages']) for ex in lines]\n",
164
  " print(f\" Avg turns per example: {sum(turn_counts)/len(turn_counts):.1f}\")\n",
165
  " print(f\" Min/Max turns: {min(turn_counts)} / {max(turn_counts)}\")\n",
166
  "else:\n",
167
  " print(\"❌ SFT dataset not found. Check data preparation above.\")"
168
  ]
169
  },
170
  {
171
  "cell_type": "markdown",
 
172
  "metadata": {},
173
  "source": [
174
  "## 5️⃣ Baseline Evaluation\n",
@@ -179,6 +584,7 @@
179
  {
180
  "cell_type": "code",
181
  "execution_count": null,
 
182
  "metadata": {},
183
  "outputs": [],
184
  "source": [
@@ -186,39 +592,142 @@
186
  "!python eval.py --policy oracle --label oracle --episodes {EVAL_EPISODES}"
187
  ]
188
  },
189
  {
190
  "cell_type": "markdown",
 
191
  "metadata": {},
192
  "source": [
193
  "## 6️⃣ SFT Training (Unsloth + TRL)\n",
194
  "\n",
195
- "Fine-tune Qwen 2.5-7B-Instruct with LoRA using verified CORP-ENV trajectories.\n",
196
  "\n",
197
  "- Uses `unsloth.FastLanguageModel` for 4-bit QLoRA\n",
198
  "- Uses `trl.SFTTrainer` with messages-format conversational SFT\n",
199
- "- LoRA `r=32`, targets all attention + MLP projections"
 
200
  ]
201
  },
202
  {
203
  "cell_type": "code",
204
  "execution_count": null,
 
205
  "metadata": {},
206
  "outputs": [],
207
  "source": [
 
 
 
208
  "!python training/train_sft.py \\\n",
209
  " --model {BASE_MODEL} \\\n",
210
  " --data data/sft/e1_m1_h1_examples.jsonl \\\n",
211
  " --output outputs/sft_adapter \\\n",
 
212
  " --max-steps {SFT_MAX_STEPS} \\\n",
213
  " --epochs {SFT_EPOCHS} \\\n",
214
  " --lr {SFT_LR} \\\n",
215
  " --batch-size {SFT_BATCH_SIZE} \\\n",
216
  " --grad-accum {SFT_GRAD_ACCUM} \\\n",
217
- " --push-to-hub {HF_ORG_OR_USER}/corp-env-sft-qwen2.5-7b"
218
  ]
219
  },
220
  {
221
  "cell_type": "markdown",
 
222
  "metadata": {},
223
  "source": [
224
  "## 7️⃣ Evaluate SFT Adapter"
@@ -227,20 +736,58 @@
227
  {
228
  "cell_type": "code",
229
  "execution_count": null,
 
230
  "metadata": {},
231
  "outputs": [],
232
  "source": [
 
 
 
 
 
 
233
  "!python eval.py \\\n",
234
  " --policy hf \\\n",
235
  " --label sft \\\n",
236
  " --model {BASE_MODEL} \\\n",
237
  " --adapter outputs/sft_adapter \\\n",
238
  " --episodes {EVAL_EPISODES} \\\n",
239
- " --max-steps {EVAL_MAX_STEPS}"
240
  ]
241
  },
242
  {
243
  "cell_type": "markdown",
 
244
  "metadata": {},
245
  "source": [
246
  "## 8️⃣ RLVR Training (Rejection-Sampling FT)\n",
@@ -252,15 +799,25 @@
252
  "4. SFT on that curated set\n",
253
  "5. Repeating for multiple outer rounds\n",
254
  "\n",
255
- "This avoids the zero-variance gradient problem seen with GRPO on CORP-ENV."
 
 
256
  ]
257
  },
258
  {
259
  "cell_type": "code",
260
  "execution_count": null,
 
261
  "metadata": {},
262
  "outputs": [],
263
  "source": [
 
 
 
 
 
 
 
264
  "!python training/train_rlvr.py \\\n",
265
  " --model {BASE_MODEL} \\\n",
266
  " --adapter outputs/sft_adapter \\\n",
@@ -270,15 +827,40 @@
270
  " --n-samples {RLVR_N_SAMPLES} \\\n",
271
  " --temperature {RLVR_TEMPERATURE} \\\n",
272
  " --max-prompts {RLVR_MAX_PROMPTS} \\\n",
 
 
273
  " --strict-json \\\n",
274
  " --use-stub-workers \\\n",
275
  " --disable-llm-judge \\\n",
276
- " --stats-file results/runs/rlvr_qwen2.5_7b_stats.jsonl \\\n",
277
- " --push-to-hub {HF_ORG_OR_USER}/corp-env-rlvr-qwen2.5-7b"
278
  ]
279
  },
280
  {
281
  "cell_type": "markdown",
 
282
  "metadata": {},
283
  "source": [
284
  "## 9️⃣ Evaluate RLVR Adapter"
@@ -287,72 +869,149 @@
287
  {
288
  "cell_type": "code",
289
  "execution_count": null,
 
290
  "metadata": {},
291
  "outputs": [],
292
  "source": [
 
 
 
 
 
293
  "!python eval.py \\\n",
294
  " --policy hf \\\n",
295
  " --label rlvr \\\n",
296
  " --model {BASE_MODEL} \\\n",
297
  " --adapter outputs/rlvr_adapter \\\n",
298
  " --episodes {EVAL_EPISODES} \\\n",
299
- " --max-steps {EVAL_MAX_STEPS}"
300
  ]
301
  },
302
  {
303
  "cell_type": "markdown",
 
 
 
 
 
 
 
 
 
 
 
 
304
  "metadata": {},
 
305
  "source": [
306
- "## 📊 Generate Comparison Plots\n",
 
 
 
 
 
307
  "\n",
308
- "Aggregate all eval runs and generate:\n",
309
- "- Terminal reward comparison (grouped bar chart)\n",
310
- "- Verifier pass rate by task\n",
311
- "- Invalid action rate\n",
312
- "- Reward curve over episode steps"
313
  ]
314
  },
315
  {
316
  "cell_type": "code",
317
  "execution_count": null,
 
318
  "metadata": {},
319
  "outputs": [],
320
  "source": [
 
321
  "!python plot_results.py \\\n",
322
  " --inputs results/runs \\\n",
323
- " --output-dir results/model_compare_qwen25_7b"
324
  ]
325
  },
326
  {
327
  "cell_type": "code",
328
  "execution_count": null,
 
329
  "metadata": {},
330
  "outputs": [],
331
  "source": [
332
  "from IPython.display import Image, display, Markdown\n",
333
- "from pathlib import Path\n",
334
  "\n",
335
- "plot_dir = Path(\"results/model_compare_qwen25_7b\")\n",
336
  "if not plot_dir.exists():\n",
337
  " plot_dir = Path(\"results/model_compare_qwen25_fresh_no_grpo_ep5rlvr\")\n",
338
  "\n",
339
- "for png in sorted(plot_dir.glob(\"*.png\")):\n",
340
- " display(Markdown(f\"### {png.stem.replace('_', ' ').title()}\"))\n",
341
- " display(Image(filename=str(png), width=800))\n",
 
342
  "\n",
343
- "# Show summary table\n",
344
- "summary_md = plot_dir / \"comparison_summary.md\"\n",
345
- "if summary_md.exists():\n",
346
- " display(Markdown(summary_md.read_text()))"
 
 
347
  ]
348
  },
349
  {
350
  "cell_type": "markdown",
 
351
  "metadata": {},
352
  "source": [
353
  "## 📋 Results Summary\n",
354
  "\n",
355
- "Expected progression for Qwen 2.5-7B-Instruct on CORP-ENV:\n",
356
  "\n",
357
  "| Stage | E1 Terminal Reward | M1 Terminal Reward | H1 Terminal Reward | M1 Success |\n",
358
  "|-------|-------------------|-------------------|-------------------|------------|\n",
@@ -361,7 +1020,9 @@
361
  "| SFT | 0.910 | 0.943 | 0.889 | 100% |\n",
362
  "| RLVR | 0.910 | 0.932 | 0.779 | 80% |\n",
363
  "\n",
364
- "> **Key takeaway**: SFT dramatically improves M1 (budget reallocation) from 0% to 100% success rate. RLVR maintains strong performance while reducing reliance on fixed trajectories."
 
 
365
  ]
366
  }
367
  ],
@@ -384,4 +1045,4 @@
384
  },
385
  "nbformat": 4,
386
  "nbformat_minor": 5
387
- }
 
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
+ "id": "23a31c02",
6
  "metadata": {},
7
  "source": [
8
+ "# 🏢 CORP-ENV · Qwen 2.5-3B-Instruct — SFT + RLVR Training\n",
9
  "\n",
10
+ "**End-to-end reproducible notebook** for training a Qwen 2.5-3B-Instruct agent on CORP-ENV using Supervised Fine-Tuning (SFT) followed by Rejection-Sampling RL on Verifiable Rewards (RLVR).\n",
11
+ "\n",
12
+ "### ⚡ Optimized for Google Colab T4 (16 GB VRAM)\n",
13
+ "\n",
14
+ "This notebook is configured to run end-to-end on a **free-tier T4 GPU**:\n",
15
+ "- 4-bit QLoRA quantization to keep the 3B model's weights small enough for T4 VRAM\n",
16
+ "- **FP16** precision (T4 lacks BF16 hardware support)\n",
17
+ "- Reduced sequence lengths (3072 tokens) and RLVR samples (4 per prompt)\n",
18
+ "- Inline visualizations after every training and evaluation step\n",
19
  "\n",
20
  "CORP-ENV is a multi-agent corporate decision environment where a Master Agent governs a **Shared Workspace Document (SWD)** across long-horizon planning episodes, coordinating frozen worker agents. Rewards measure SWD integrity, task completion, milestone adherence, reasoning density, and LLM-judge scores.\n",
21
  "\n",
22
  "| Component | Detail |\n",
23
  "|---|---|\n",
24
+ "| **Base model** | `Qwen/Qwen2.5-3B-Instruct` |\n",
25
  "| **SFT script** | `training/train_sft.py` |\n",
26
  "| **RLVR script** | `training/train_rlvr.py` |\n",
27
  "| **Tasks** | E1 Launch Readiness, M1 Budget Reallocation, H1 Acquisition Defence |\n",
28
+ "| **Runtime** | ✅ Google Colab T4 / Lightning AI H100 / Any CUDA GPU |\n",
29
  "\n",
30
  "---"
31
  ]
32
  },
33
  {
34
  "cell_type": "markdown",
35
+ "id": "15d441af",
36
  "metadata": {},
37
  "source": [
38
  "## 1️⃣ Setup & Installation"
 
41
  {
42
  "cell_type": "code",
43
  "execution_count": null,
44
+ "id": "e9394fab",
45
  "metadata": {},
46
  "outputs": [],
47
  "source": [
48
  "import os\n",
49
+ "import torch\n",
50
+ "\n",
51
+ "# ===== GPU Detection & Configuration =====\n",
52
+ "if torch.cuda.is_available():\n",
53
+ " gpu_name = torch.cuda.get_device_name(0)\n",
54
+ " gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
55
+ " has_bf16 = torch.cuda.is_bf16_supported()\n",
56
+ " print(f\"🖥️ GPU: {gpu_name} ({gpu_mem:.1f} GB)\")\n",
57
+ " print(f\" BF16 support: {'✅ Yes' if has_bf16 else '❌ No (using FP16)'}\")\n",
58
+ "else:\n",
59
+ " raise RuntimeError(\"❌ No GPU detected! Enable GPU in Colab: Runtime → Change runtime type → T4 GPU\")\n",
60
+ "\n",
61
+ "# Auto-detect hardware constraints\n",
62
+ "LOW_MEMORY = gpu_mem < 20.0 # e.g., T4 (16GB), RTX 4080 (16GB) need smaller batches/sequences\n",
63
+ "USE_FP16 = not has_bf16 # e.g., T4 and V100 don't support BF16\n",
64
  "\n",
65
  "# ===== Configuration =====\n",
66
  "REPO_URL = \"https://huggingface.co/spaces/Navigam/corp-env\" # Change to your repo\n",
67
+ "BASE_MODEL = \"Qwen/Qwen2.5-3B-Instruct\"\n",
68
  "HF_ORG_OR_USER = \"Navigam\" # Your HF username/org\n",
69
  "\n",
70
+ "# SFT hyperparameters (T4-optimized)\n",
71
+ "SFT_MAX_STEPS = 30 # Quick judge smoke; set -1 for full-epoch training\n",
72
  "SFT_EPOCHS = 2.0\n",
73
  "SFT_LR = 2e-4\n",
74
  "SFT_BATCH_SIZE = 1\n",
75
  "SFT_GRAD_ACCUM = 8\n",
76
+ "SFT_MAX_SEQ_LEN = 3072 if LOW_MEMORY else 8192 # Reduced for <20GB VRAM\n",
77
  "\n",
78
+ "# RLVR hyperparameters (T4-optimized)\n",
79
  "RLVR_ROUNDS = 3\n",
80
+ "RLVR_MAX_PROMPTS = 32 if LOW_MEMORY else 128 # Fewer prompts to fit in T4 time/memory\n",
81
+ "RLVR_N_SAMPLES = 4 if LOW_MEMORY else 8 # Fewer samples per prompt\n",
82
  "RLVR_TEMPERATURE = 0.7\n",
83
+ "RLVR_MAX_PROMPT_LEN = 3072 if LOW_MEMORY else 8192\n",
84
+ "RLVR_MAX_COMPLETION_LEN = 512\n",
85
  "\n",
86
  "# Eval\n",
87
  "EVAL_EPISODES = 3\n",
88
+ "EVAL_MAX_STEPS = 30\n",
89
+ "\n",
90
+ "# FP16 flag for training scripts\n",
91
+ "FP16_FLAG = \"--fp16\" if USE_FP16 else \"\"\n",
92
+ "\n",
93
+ "print(f\"\\n📋 Config: model={BASE_MODEL}, fp16={USE_FP16}, seq_len={SFT_MAX_SEQ_LEN}\")\n",
94
+ "print(f\" RLVR: rounds={RLVR_ROUNDS}, prompts={RLVR_MAX_PROMPTS}, samples={RLVR_N_SAMPLES}\")"
95
  ]
96
  },
97
  {
98
  "cell_type": "code",
99
  "execution_count": null,
100
+ "id": "1fccadd9",
101
  "metadata": {},
102
  "outputs": [],
103
  "source": [
104
+ "# ===== Install dependencies (Colab-optimized) =====\n",
105
+ "# Unsloth requires a specific install path for Colab\n",
106
+ "import subprocess, sys\n",
107
+ "\n",
108
+ "# Check if running in Colab\n",
109
+ "IN_COLAB = 'google.colab' in sys.modules\n",
110
+ "\n",
111
+ "if IN_COLAB:\n",
112
+ " print(\"🔧 Installing Unsloth for Colab...\")\n",
113
+ " !pip install -q --no-deps trl peft accelerate bitsandbytes triton\n",
114
+ " !pip install -q \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
115
+ " !pip install -q --no-deps unsloth_zoo\n",
116
+ " !pip install -q xformers\n",
117
+ "else:\n",
118
+ " print(\"🔧 Installing from pyproject.toml...\")\n",
119
+ " !pip install -q -U pip\n",
120
+ "\n",
121
+ "# Clone and install CORP-ENV\n",
122
  "!git clone {REPO_URL} corp_gym 2>/dev/null || echo 'Repo already cloned'\n",
123
  "%cd corp_gym\n",
124
+ "!pip install -q -e \".[training,plots]\""
 
125
  ]
126
  },
127
  {
128
  "cell_type": "markdown",
129
+ "id": "076d342b",
130
  "metadata": {},
131
  "source": [
132
  "## 2️⃣ Hugging Face Login (optional)"
 
135
  {
136
  "cell_type": "code",
137
  "execution_count": null,
138
+ "id": "df0904d7",
139
  "metadata": {},
140
  "outputs": [],
141
  "source": [
 
145
  },
146
  {
147
  "cell_type": "markdown",
148
+ "id": "7d4a001c",
149
+ "metadata": {},
150
+ "source": [
151
+ "## 📊 Visualization Utilities\n",
152
+ "\n",
153
+ "Helper functions for inline charts after every eval and training step."
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": null,
159
+ "id": "3930908e",
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
+ "import json\n",
164
+ "import matplotlib.pyplot as plt\n",
165
+ "import matplotlib.ticker as mticker\n",
166
+ "import numpy as np\n",
167
+ "from pathlib import Path\n",
168
+ "from collections import defaultdict\n",
169
+ "from IPython.display import display, Markdown, HTML\n",
170
+ "\n",
171
+ "# ---- Plotting style ----\n",
172
+ "plt.rcParams.update({\n",
173
+ " 'figure.facecolor': '#0d1117',\n",
174
+ " 'axes.facecolor': '#161b22',\n",
175
+ " 'axes.edgecolor': '#30363d',\n",
176
+ " 'axes.labelcolor': '#c9d1d9',\n",
177
+ " 'text.color': '#c9d1d9',\n",
178
+ " 'xtick.color': '#8b949e',\n",
179
+ " 'ytick.color': '#8b949e',\n",
180
+ " 'grid.color': '#21262d',\n",
181
+ " 'font.family': 'sans-serif',\n",
182
+ " 'font.size': 11,\n",
183
+ "})\n",
184
+ "\n",
185
+ "PALETTE = {\n",
186
+ " 'baseline': '#8b949e',\n",
187
+ " 'oracle': '#a371f7',\n",
188
+ " 'sft': '#3fb950',\n",
189
+ " 'rlvr': '#f0883e',\n",
190
+ " 'e1_launch_readiness': '#58a6ff',\n",
191
+ " 'm1_budget_reallocation': '#d2a8ff',\n",
192
+ " 'h1_acquisition_defence': '#7ee787',\n",
193
+ "}\n",
194
+ "TASK_SHORT = {\n",
195
+ " 'e1_launch_readiness': 'E1 Launch',\n",
196
+ " 'm1_budget_reallocation': 'M1 Budget',\n",
197
+ " 'h1_acquisition_defence': 'H1 Acquisition',\n",
198
+ "}\n",
199
+ "\n",
200
+ "def load_eval_jsonl(path):\n",
201
+ " \"\"\"Load evaluation JSONL file.\"\"\"\n",
202
+ " rows = []\n",
203
+ " p = Path(path)\n",
204
+ " if p.is_dir():\n",
205
+ " for f in sorted(p.rglob('*_eval.jsonl')):\n",
206
+ " rows.extend(load_eval_jsonl(f))\n",
207
+ " for f in sorted(p.rglob('eval.jsonl')):\n",
208
+ " rows.extend(load_eval_jsonl(f))\n",
209
+ " return rows\n",
210
+ " if p.exists():\n",
211
+ " for line in p.read_text(encoding='utf-8').strip().splitlines():\n",
212
+ " if line.strip():\n",
213
+ " rows.append(json.loads(line))\n",
214
+ " return rows\n",
215
+ "\n",
216
+ "def plot_eval_dashboard(rows, title=\"Evaluation Results\"):\n",
217
+ " \"\"\"Create a 2x2 dashboard of evaluation metrics.\"\"\"\n",
218
+ " if not rows:\n",
219
+ " print(\"⚠️ No evaluation data to plot.\")\n",
220
+ " return\n",
221
+ "\n",
222
+ " # Group by task\n",
223
+ " by_task = defaultdict(list)\n",
224
+ " for r in rows:\n",
225
+ " by_task[r['task_id']].append(r)\n",
226
+ "\n",
227
+ " tasks = sorted(by_task.keys())\n",
228
+ " task_labels = [TASK_SHORT.get(t, t) for t in tasks]\n",
229
+ "\n",
230
+ " # Compute metrics\n",
231
+ " avg_reward = [np.mean([r['terminal_reward'] for r in by_task[t]]) for t in tasks]\n",
232
+ " avg_pass = [np.mean([r['verifier_pass_rate'] for r in by_task[t]]) for t in tasks]\n",
233
+ " success_rate = [np.mean([1 if r.get('success') else 0 for r in by_task[t]]) for t in tasks]\n",
234
+ " avg_steps = [np.mean([r.get('steps', 0) for r in by_task[t]]) for t in tasks]\n",
235
+ "\n",
236
+ " fig, axes = plt.subplots(2, 2, figsize=(14, 10))\n",
237
+ " fig.suptitle(title, fontsize=18, fontweight='bold', color='#f0f6fc', y=0.98)\n",
238
+ "\n",
239
+ " # -- Terminal Reward --\n",
240
+ " ax = axes[0, 0]\n",
241
+ " colors = [PALETTE.get(t, '#58a6ff') for t in tasks]\n",
242
+ " bars = ax.bar(task_labels, avg_reward, color=colors, edgecolor='#30363d', linewidth=0.8)\n",
243
+ " for bar, val in zip(bars, avg_reward):\n",
244
+ " ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,\n",
245
+ " f'{val:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold', color='#f0f6fc')\n",
246
+ " ax.set_title('Terminal Reward', fontsize=13, fontweight='bold')\n",
247
+ " ax.set_ylim(0, 1.15)\n",
248
+ " ax.grid(axis='y', alpha=0.3)\n",
249
+ "\n",
250
+ " # -- Verifier Pass Rate --\n",
251
+ " ax = axes[0, 1]\n",
252
+ " bars = ax.bar(task_labels, avg_pass, color=colors, edgecolor='#30363d', linewidth=0.8)\n",
253
+ " for bar, val in zip(bars, avg_pass):\n",
254
+ " ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,\n",
255
+ " f'{val:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold', color='#f0f6fc')\n",
256
+ " ax.set_title('Verifier Pass Rate', fontsize=13, fontweight='bold')\n",
257
+ " ax.set_ylim(0, 1.15)\n",
258
+ " ax.grid(axis='y', alpha=0.3)\n",
259
+ "\n",
260
+ " # -- Success Rate --\n",
261
+ " ax = axes[1, 0]\n",
262
+ " bars = ax.bar(task_labels, success_rate, color=colors, edgecolor='#30363d', linewidth=0.8)\n",
263
+ " for bar, val in zip(bars, success_rate):\n",
264
+ " ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,\n",
265
+ " f'{val:.0%}', ha='center', va='bottom', fontsize=10, fontweight='bold', color='#f0f6fc')\n",
266
+ " ax.set_title('Success Rate', fontsize=13, fontweight='bold')\n",
267
+ " ax.set_ylim(0, 1.25)\n",
268
+ " ax.yaxis.set_major_formatter(mticker.PercentFormatter(1.0))\n",
269
+ " ax.grid(axis='y', alpha=0.3)\n",
270
+ "\n",
271
+ " # -- Avg Steps --\n",
272
+ " ax = axes[1, 1]\n",
273
+ " bars = ax.bar(task_labels, avg_steps, color=colors, edgecolor='#30363d', linewidth=0.8)\n",
274
+ " for bar, val in zip(bars, avg_steps):\n",
275
+ " ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.3,\n",
276
+ " f'{val:.1f}', ha='center', va='bottom', fontsize=10, fontweight='bold', color='#f0f6fc')\n",
277
+ " ax.set_title('Average Steps per Episode', fontsize=13, fontweight='bold')\n",
278
+ " ax.grid(axis='y', alpha=0.3)\n",
279
+ "\n",
280
+ " for ax in axes.flat:\n",
281
+ " ax.spines['top'].set_visible(False)\n",
282
+ " ax.spines['right'].set_visible(False)\n",
283
+ "\n",
284
+ " fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
285
+ " plt.show()\n",
286
+ "\n",
287
+ " # Print summary table\n",
288
+ " display(Markdown(\"### 📋 Summary Table\"))\n",
289
+ " header = \"| Task | Terminal Reward | Verifier Pass | Success Rate | Avg Steps |\"\n",
290
+ " sep = \"|------|---------------|--------------|-------------|----------|\"\n",
291
+ " lines = [header, sep]\n",
292
+ " for i, t in enumerate(tasks):\n",
293
+ " lines.append(f\"| {TASK_SHORT.get(t, t)} | {avg_reward[i]:.3f} | {avg_pass[i]:.3f} | {success_rate[i]:.0%} | {avg_steps[i]:.1f} |\")\n",
294
+ " display(Markdown('\\n'.join(lines)))\n",
295
+ "\n",
296
+ "\n",
297
+ "def plot_reward_traces(rows, title=\"Reward Traces\"):\n",
298
+ " \"\"\"Plot reward curves over episode steps.\"\"\"\n",
299
+ " traces_by_task = defaultdict(list)\n",
300
+ " for r in rows:\n",
301
+ " trace = r.get('reward_trace', [])\n",
302
+ " if trace:\n",
303
+ " traces_by_task[r['task_id']].append([float(x) for x in trace])\n",
304
+ "\n",
305
+ " if not traces_by_task:\n",
306
+ " return\n",
307
+ "\n",
308
+ " fig, ax = plt.subplots(figsize=(12, 5))\n",
309
+ " for task_id, traces in sorted(traces_by_task.items()):\n",
310
+ " max_len = max(len(t) for t in traces)\n",
311
+ " means = []\n",
312
+ " for idx in range(max_len):\n",
313
+ " vals = [t[idx] for t in traces if idx < len(t)]\n",
314
+ " means.append(np.mean(vals))\n",
315
+ " xs = range(1, max_len + 1)\n",
316
+ " color = PALETTE.get(task_id, '#58a6ff')\n",
317
+ " ax.plot(xs, means, marker='o', linewidth=2.2, markersize=4,\n",
318
+ " label=TASK_SHORT.get(task_id, task_id), color=color)\n",
319
+ " if len(traces) > 1:\n",
320
+ " mins = [min(t[i] for t in traces if i < len(t)) for i in range(max_len)]\n",
321
+ " maxs = [max(t[i] for t in traces if i < len(t)) for i in range(max_len)]\n",
322
+ " ax.fill_between(xs, mins, maxs, alpha=0.15, color=color)\n",
323
+ "\n",
324
+ " ax.set_title(title, fontsize=15, fontweight='bold')\n",
325
+ " ax.set_xlabel('Environment Step')\n",
326
+ " ax.set_ylabel('Step Reward')\n",
327
+ " ax.axhline(0, color='#484f58', linewidth=0.8, alpha=0.5)\n",
328
+ " ax.legend(frameon=False, fontsize=10)\n",
329
+ " ax.spines['top'].set_visible(False)\n",
330
+ " ax.spines['right'].set_visible(False)\n",
331
+ " ax.grid(axis='y', alpha=0.3)\n",
332
+ " fig.tight_layout()\n",
333
+ " plt.show()\n",
334
+ "\n",
335
+ "\n",
336
+ "def plot_stage_comparison(all_evals, metric='terminal_reward', title='Model Stage Comparison'):\n",
337
+ " \"\"\"Compare multiple evaluation stages side-by-side.\"\"\"\n",
338
+ " if not all_evals:\n",
339
+ " return\n",
340
+ "\n",
341
+ " stages = list(all_evals.keys())\n",
342
+ " all_tasks = sorted({r['task_id'] for rows in all_evals.values() for r in rows})\n",
343
+ " task_labels = [TASK_SHORT.get(t, t) for t in all_tasks]\n",
344
+ "\n",
345
+ " x = np.arange(len(all_tasks))\n",
346
+ " width = 0.8 / max(len(stages), 1)\n",
347
+ "\n",
348
+ " fig, ax = plt.subplots(figsize=(max(10, len(all_tasks) * 3), 6))\n",
349
+ " for idx, stage in enumerate(stages):\n",
350
+ " rows = all_evals[stage]\n",
351
+ " by_task = defaultdict(list)\n",
352
+ " for r in rows:\n",
353
+ " by_task[r['task_id']].append(float(r.get(metric, 0)))\n",
354
+ " vals = [np.mean(by_task.get(t, [0])) for t in all_tasks]\n",
355
+ " offsets = x - 0.4 + width/2 + idx * width\n",
356
+ " color = PALETTE.get(stage, f'C{idx}')\n",
357
+ " bars = ax.bar(offsets, vals, width, label=stage.upper(), color=color,\n",
358
+ " edgecolor='#30363d', linewidth=0.8)\n",
359
+ " for bar, val in zip(bars, vals):\n",
360
+ " if val > 0:\n",
361
+ " ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.015,\n",
362
+ " f'{val:.2f}', ha='center', va='bottom', fontsize=9,\n",
363
+ " fontweight='bold', color='#f0f6fc')\n",
364
+ "\n",
365
+ " ax.set_title(title, fontsize=16, fontweight='bold', color='#f0f6fc')\n",
366
+ " ax.set_xticks(x)\n",
367
+ " ax.set_xticklabels(task_labels)\n",
368
+ " ax.set_ylabel(metric.replace('_', ' ').title())\n",
369
+ " ax.set_ylim(0, 1.15)\n",
370
+ " ax.legend(frameon=False, fontsize=10, loc='upper center', bbox_to_anchor=(0.5, -0.08), ncol=len(stages))\n",
371
+ " ax.spines['top'].set_visible(False)\n",
372
+ " ax.spines['right'].set_visible(False)\n",
373
+ " ax.grid(axis='y', alpha=0.3)\n",
374
+ " fig.tight_layout()\n",
375
+ " plt.show()\n",
376
+ "\n",
377
+ "\n",
378
+ "def plot_rlvr_stats(stats_file):\n",
379
+ " \"\"\"Plot RLVR training stats per round.\"\"\"\n",
380
+ " p = Path(stats_file)\n",
381
+ " if not p.exists():\n",
382
+ " print(f\"⚠️ Stats file not found: {stats_file}\")\n",
383
+ " return\n",
384
+ "\n",
385
+ " stats = [json.loads(line) for line in p.read_text().strip().splitlines() if line.strip()]\n",
386
+ " if not stats:\n",
387
+ " return\n",
388
+ "\n",
389
+ " rounds = [s['round'] for s in stats]\n",
390
+ " keep_rates = [s['keep_rate'] for s in stats]\n",
391
+ " mean_best = [s['mean_best_reward'] for s in stats]\n",
392
+ " mean_any = [s['mean_sample_reward'] for s in stats]\n",
393
+ " kept_counts = [int(s['prompts_kept']) for s in stats]\n",
394
+ "\n",
395
+ " fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
396
+ " fig.suptitle('RLVR Training Progress', fontsize=16, fontweight='bold', color='#f0f6fc', y=1.02)\n",
397
+ "\n",
398
+ " # Keep rate\n",
399
+ " ax = axes[0]\n",
400
+ " ax.plot(rounds, keep_rates, marker='o', linewidth=2.5, color='#3fb950', markersize=8)\n",
401
+ " ax.fill_between(rounds, keep_rates, alpha=0.15, color='#3fb950')\n",
402
+ " ax.set_title('Keep Rate per Round', fontweight='bold')\n",
403
+ " ax.set_xlabel('Round')\n",
404
+ " ax.set_ylabel('Keep Rate')\n",
405
+ " ax.set_ylim(0, 1.05)\n",
406
+ " ax.yaxis.set_major_formatter(mticker.PercentFormatter(1.0))\n",
407
+ " ax.grid(alpha=0.3)\n",
408
+ "\n",
409
+ " # Reward progression\n",
410
+ " ax = axes[1]\n",
411
+ " ax.plot(rounds, mean_best, marker='s', linewidth=2.5, color='#f0883e', markersize=8, label='Best')\n",
412
+ " ax.plot(rounds, mean_any, marker='D', linewidth=2.5, color='#58a6ff', markersize=7, label='Any sample')\n",
413
+ " ax.set_title('Mean Reward per Round', fontweight='bold')\n",
414
+ " ax.set_xlabel('Round')\n",
415
+ " ax.set_ylabel('Reward')\n",
416
+ " ax.legend(frameon=False)\n",
417
+ " ax.grid(alpha=0.3)\n",
418
+ "\n",
419
+ " # Prompts kept\n",
420
+ " ax = axes[2]\n",
421
+ " ax.bar(rounds, kept_counts, color='#a371f7', edgecolor='#30363d', linewidth=0.8)\n",
422
+ " for r, c in zip(rounds, kept_counts):\n",
423
+ " ax.text(r, c + 0.5, str(c), ha='center', fontweight='bold', fontsize=11, color='#f0f6fc')\n",
424
+ " ax.set_title('Prompts Kept (Winners)', fontweight='bold')\n",
425
+ " ax.set_xlabel('Round')\n",
426
+ " ax.set_ylabel('Count')\n",
427
+ " ax.grid(axis='y', alpha=0.3)\n",
428
+ "\n",
429
+ " for ax in axes:\n",
430
+ " ax.spines['top'].set_visible(False)\n",
431
+ " ax.spines['right'].set_visible(False)\n",
432
+ " ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))\n",
433
+ "\n",
434
+ " fig.tight_layout()\n",
435
+ " plt.show()\n",
436
+ "\n",
437
+ " # Print per-round summary\n",
438
+ " display(Markdown(\"### 📋 RLVR Round Summary\"))\n",
439
+ " header = \"| Round | Keep Rate | Mean Best Reward | Mean Any Reward | Prompts Kept | Time (s) |\"\n",
440
+ " sep = \"|-------|-----------|-----------------|----------------|-------------|----------|\"\n",
441
+ " lines = [header, sep]\n",
442
+ " for s in stats:\n",
443
+ " lines.append(f\"| {s['round']} | {s['keep_rate']:.1%} | {s['mean_best_reward']:.3f} | {s['mean_sample_reward']:.3f} | {int(s['prompts_kept'])} | {s['seconds']:.0f} |\")\n",
444
+ " display(Markdown('\\n'.join(lines)))\n",
445
+ "\n",
446
+ "\n",
447
+ "def gpu_status():\n",
448
+ " \"\"\"Print current GPU memory usage.\"\"\"\n",
449
+ " if torch.cuda.is_available():\n",
450
+ " alloc = torch.cuda.memory_allocated() / 1e9\n",
451
+ " cached = torch.cuda.memory_reserved() / 1e9\n",
452
+ " total = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
453
+ " pct = alloc / total * 100\n",
454
+ " bar_len, filled = 20, int(pct / 5)\n",
455
+ " bar = '█' * filled + '░' * (bar_len - filled)\n",
456
+ " print(f\"🖥️ GPU Memory: [{bar}] {alloc:.1f}/{total:.1f} GB ({pct:.0f}%) | Cached: {cached:.1f} GB\")\n",
457
+ "\n",
458
+ "\n",
459
+ "# Collect all eval results for final comparison\n",
460
+ "ALL_EVALS = {}\n",
461
+ "print(\"✅ Visualization utilities loaded.\")"
462
+ ]
463
+ },
464
+ {
465
+ "cell_type": "markdown",
466
+ "id": "43b92bf5",
467
  "metadata": {},
468
  "source": [
469
  "## 3️⃣ Environment Validation\n",
 
474
  {
475
  "cell_type": "code",
476
  "execution_count": null,
477
+ "id": "71cfe355",
478
  "metadata": {},
479
  "outputs": [],
480
  "source": [
 
484
  },
485
  {
486
  "cell_type": "markdown",
487
+ "id": "0275c763",
488
  "metadata": {},
489
  "source": [
490
  "## 4️⃣ Data Preparation\n",
 
495
  {
496
  "cell_type": "code",
497
  "execution_count": null,
498
+ "id": "85901039",
499
  "metadata": {},
500
  "outputs": [],
501
  "source": [
 
524
  {
525
  "cell_type": "code",
526
  "execution_count": null,
527
+ "id": "eb6a7997",
528
  "metadata": {},
529
  "outputs": [],
530
  "source": [
531
+ "# Check data stats & visualize\n",
 
 
 
532
  "sft_path = Path(\"data/sft/e1_m1_h1_examples.jsonl\")\n",
533
  "if sft_path.exists():\n",
534
  " lines = [json.loads(l) for l in sft_path.read_text().strip().splitlines() if l.strip()]\n",
535
  " print(f\"\\n✅ SFT dataset: {len(lines)} examples\")\n",
 
536
  " turn_counts = [len(ex['messages']) for ex in lines]\n",
537
  " print(f\" Avg turns per example: {sum(turn_counts)/len(turn_counts):.1f}\")\n",
538
  " print(f\" Min/Max turns: {min(turn_counts)} / {max(turn_counts)}\")\n",
539
+ "\n",
540
+ " # Visualize data distribution\n",
541
+ " fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n",
542
+ " fig.suptitle('SFT Dataset Overview', fontsize=14, fontweight='bold', color='#f0f6fc')\n",
543
+ "\n",
544
+ " # Turns histogram\n",
545
+ " axes[0].hist(turn_counts, bins=range(min(turn_counts), max(turn_counts)+2),\n",
546
+ " color='#58a6ff', edgecolor='#30363d', alpha=0.85)\n",
547
+ " axes[0].set_title('Message Turns per Example', fontweight='bold')\n",
548
+ " axes[0].set_xlabel('Number of Turns')\n",
549
+ " axes[0].set_ylabel('Count')\n",
550
+ " axes[0].grid(axis='y', alpha=0.3)\n",
551
+ "\n",
552
+ " # Role distribution\n",
553
+ " role_counts = defaultdict(int)\n",
554
+ " for ex in lines:\n",
555
+ " for msg in ex['messages']:\n",
556
+ " role_counts[msg['role']] += 1\n",
557
+ " roles = list(role_counts.keys())\n",
558
+ " counts = list(role_counts.values())\n",
559
+ " role_colors = ['#a371f7', '#3fb950', '#58a6ff', '#f0883e'][:len(roles)]\n",
560
+ " axes[1].barh(roles, counts, color=role_colors, edgecolor='#30363d')\n",
561
+ " axes[1].set_title('Messages by Role', fontweight='bold')\n",
562
+ " axes[1].set_xlabel('Count')\n",
563
+ " axes[1].grid(axis='x', alpha=0.3)\n",
564
+ "\n",
565
+ " for ax in axes:\n",
566
+ " ax.spines['top'].set_visible(False)\n",
567
+ " ax.spines['right'].set_visible(False)\n",
568
+ " fig.tight_layout()\n",
569
+ " plt.show()\n",
570
  "else:\n",
571
  " print(\"❌ SFT dataset not found. Check data preparation above.\")"
572
  ]
573
  },
574
  {
575
  "cell_type": "markdown",
576
+ "id": "8c529b78",
577
  "metadata": {},
578
  "source": [
579
  "## 5️⃣ Baseline Evaluation\n",
 
584
  {
585
  "cell_type": "code",
586
  "execution_count": null,
587
+ "id": "9c5db0c1",
588
  "metadata": {},
589
  "outputs": [],
590
  "source": [
 
592
  "!python eval.py --policy oracle --label oracle --episodes {EVAL_EPISODES}"
593
  ]
594
  },
595
+ {
596
+ "cell_type": "code",
597
+ "execution_count": null,
598
+ "id": "f106aaed",
599
+ "metadata": {},
600
+ "outputs": [],
601
+ "source": [
602
+ "# 📊 Visualize baseline results\n",
603
+ "display(Markdown(\"## 📊 Baseline Results\"))\n",
604
+ "\n",
605
+ "baseline_rows = load_eval_jsonl(\"results/runs\")\n",
606
+ "baseline_only = [r for r in baseline_rows if r.get('model_stage') == 'baseline']\n",
607
+ "oracle_only = [r for r in baseline_rows if r.get('model_stage') == 'oracle']\n",
608
+ "\n",
609
+ "if baseline_only:\n",
610
+ " display(Markdown(\"### 🔹 Scripted Weak Baseline\"))\n",
611
+ " plot_eval_dashboard(baseline_only, title=\"Scripted Weak Baseline\")\n",
612
+ " plot_reward_traces(baseline_only, title=\"Baseline Reward Traces\")\n",
613
+ " ALL_EVALS['baseline'] = baseline_only\n",
614
+ "\n",
615
+ "if oracle_only:\n",
616
+ " display(Markdown(\"### 🔹 Oracle Policy\"))\n",
617
+ " plot_eval_dashboard(oracle_only, title=\"Oracle Policy\")\n",
618
+ " plot_reward_traces(oracle_only, title=\"Oracle Reward Traces\")\n",
619
+ " ALL_EVALS['oracle'] = oracle_only\n",
620
+ "\n",
621
+ "# Side-by-side comparison if both exist\n",
622
+ "if baseline_only and oracle_only:\n",
623
+ " plot_stage_comparison(\n",
624
+ " {'baseline': baseline_only, 'oracle': oracle_only},\n",
625
+ " metric='terminal_reward',\n",
626
+ " title='Baseline vs Oracle — Terminal Reward'\n",
627
+ " )\n",
628
+ "gpu_status()"
629
+ ]
630
+ },
631
  {
632
  "cell_type": "markdown",
633
+ "id": "3011f739",
634
  "metadata": {},
635
  "source": [
636
  "## 6️⃣ SFT Training (Unsloth + TRL)\n",
637
  "\n",
638
+ "Fine-tune Qwen 2.5-3B-Instruct with LoRA using verified CORP-ENV trajectories.\n",
639
  "\n",
640
  "- Uses `unsloth.FastLanguageModel` for 4-bit QLoRA\n",
641
  "- Uses `trl.SFTTrainer` with messages-format conversational SFT\n",
642
+ "- LoRA `r=32`, targets all attention + MLP projections\n",
643
+ "- **FP16 on T4** (auto-detected), BF16 on Ampere+ GPUs"
644
  ]
645
  },
646
  {
647
  "cell_type": "code",
648
  "execution_count": null,
649
+ "id": "cb76d631",
650
  "metadata": {},
651
  "outputs": [],
652
  "source": [
653
+ "gpu_status()\n",
654
+ "print(f\"\\n🚀 Starting SFT training ({'fp16' if FP16_FLAG else 'bf16'} precision)...\\n\")\n",
655
+ "\n",
656
  "!python training/train_sft.py \\\n",
657
  " --model {BASE_MODEL} \\\n",
658
  " --data data/sft/e1_m1_h1_examples.jsonl \\\n",
659
  " --output outputs/sft_adapter \\\n",
660
+ " --max-seq-length {SFT_MAX_SEQ_LEN} \\\n",
661
  " --max-steps {SFT_MAX_STEPS} \\\n",
662
  " --epochs {SFT_EPOCHS} \\\n",
663
  " --lr {SFT_LR} \\\n",
664
  " --batch-size {SFT_BATCH_SIZE} \\\n",
665
  " --grad-accum {SFT_GRAD_ACCUM} \\\n",
666
+ " {FP16_FLAG}\n",
667
+ "\n",
668
+ "gpu_status()\n",
669
+ "print(\"\\n✅ SFT training complete!\")"
670
+ ]
671
+ },
672
+ {
673
+ "cell_type": "code",
674
+ "execution_count": null,
675
+ "id": "3755df03",
676
+ "metadata": {},
677
+ "outputs": [],
678
+ "source": [
679
+ "# 📊 Visualize SFT training logs\n",
680
+ "display(Markdown(\"## 📊 SFT Training Summary\"))\n",
681
+ "\n",
682
+ "# Check for trainer_state.json\n",
683
+ "state_file = Path(\"outputs/sft_adapter/trainer_state.json\")\n",
684
+ "if state_file.exists():\n",
685
+ " state = json.loads(state_file.read_text())\n",
686
+ " log_history = state.get('log_history', [])\n",
687
+ " if log_history:\n",
688
+ " steps = [l['step'] for l in log_history if 'loss' in l]\n",
689
+ " losses = [l['loss'] for l in log_history if 'loss' in l]\n",
690
+ " lrs = [l.get('learning_rate', 0) for l in log_history if 'loss' in l]\n",
691
+ "\n",
692
+ " fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
693
+ " fig.suptitle('SFT Training Curves', fontsize=16, fontweight='bold', color='#f0f6fc')\n",
694
+ "\n",
695
+ " # Loss curve\n",
696
+ " axes[0].plot(steps, losses, linewidth=2.5, color='#f0883e', marker='o', markersize=5)\n",
697
+ " axes[0].set_title('Training Loss', fontweight='bold')\n",
698
+ " axes[0].set_xlabel('Step')\n",
699
+ " axes[0].set_ylabel('Loss')\n",
700
+ " axes[0].grid(alpha=0.3)\n",
701
+ "\n",
702
+ " # Learning rate schedule\n",
703
+ " axes[1].plot(steps, lrs, linewidth=2.5, color='#3fb950', marker='s', markersize=4)\n",
704
+ " axes[1].set_title('Learning Rate Schedule', fontweight='bold')\n",
705
+ " axes[1].set_xlabel('Step')\n",
706
+ " axes[1].set_ylabel('Learning Rate')\n",
707
+ " axes[1].ticklabel_format(axis='y', style='scientific', scilimits=(-4, -4))\n",
708
+ " axes[1].grid(alpha=0.3)\n",
709
+ "\n",
710
+ " for ax in axes:\n",
711
+ " ax.spines['top'].set_visible(False)\n",
712
+ " ax.spines['right'].set_visible(False)\n",
713
+ " fig.tight_layout()\n",
714
+ " plt.show()\n",
715
+ "\n",
716
+ " print(f\"\\n📈 Final loss: {losses[-1]:.4f} at step {steps[-1]}\")\n",
717
+ "else:\n",
718
+ " print(\"⚠️ No trainer_state.json found; training logs unavailable.\")\n",
719
+ "\n",
720
+ "# Check adapter files\n",
721
+ "adapter_dir = Path(\"outputs/sft_adapter\")\n",
722
+ "if adapter_dir.exists():\n",
723
+ " files = list(adapter_dir.glob(\"*\"))\n",
724
+ " total_mb = sum(f.stat().st_size for f in files if f.is_file()) / 1e6\n",
725
+ " print(f\"💾 Adapter saved: {len(files)} files, {total_mb:.1f} MB total\")"
726
  ]
727
  },
728
  {
729
  "cell_type": "markdown",
730
+ "id": "cd078c28",
731
  "metadata": {},
732
  "source": [
733
  "## 7️⃣ Evaluate SFT Adapter"
 
736
  {
737
  "cell_type": "code",
738
  "execution_count": null,
739
+ "id": "50594aef",
740
  "metadata": {},
741
  "outputs": [],
742
  "source": [
743
+ "# Clear GPU memory before loading eval model\n",
744
+ "import gc\n",
745
+ "gc.collect()\n",
746
+ "torch.cuda.empty_cache()\n",
747
+ "gpu_status()\n",
748
+ "\n",
749
  "!python eval.py \\\n",
750
  " --policy hf \\\n",
751
  " --label sft \\\n",
752
  " --model {BASE_MODEL} \\\n",
753
  " --adapter outputs/sft_adapter \\\n",
754
  " --episodes {EVAL_EPISODES} \\\n",
755
+ " --max-steps {EVAL_MAX_STEPS}\n",
756
+ "\n",
757
+ "gpu_status()"
758
+ ]
759
+ },
760
+ {
761
+ "cell_type": "code",
762
+ "execution_count": null,
763
+ "id": "37bc9dd8",
764
+ "metadata": {},
765
+ "outputs": [],
766
+ "source": [
767
+ "# 📊 Visualize SFT evaluation results\n",
768
+ "display(Markdown(\"## 📊 SFT Evaluation Results\"))\n",
769
+ "\n",
770
+ "sft_rows = [r for r in load_eval_jsonl(\"results/runs\") if r.get('model_stage') == 'sft']\n",
771
+ "if sft_rows:\n",
772
+ " plot_eval_dashboard(sft_rows, title=\"SFT Adapter Evaluation\")\n",
773
+ " plot_reward_traces(sft_rows, title=\"SFT Reward Traces\")\n",
774
+ " ALL_EVALS['sft'] = sft_rows\n",
775
+ "\n",
776
+ " # Compare baseline → SFT\n",
777
+ " display(Markdown(\"### 📈 Improvement: Baseline → SFT\"))\n",
778
+ " comparison = {k: v for k, v in ALL_EVALS.items() if k in ('baseline', 'oracle', 'sft')}\n",
779
+ " if len(comparison) > 1:\n",
780
+ " plot_stage_comparison(comparison, metric='terminal_reward',\n",
781
+ " title='Baseline → SFT — Terminal Reward Comparison')\n",
782
+ " plot_stage_comparison(comparison, metric='verifier_pass_rate',\n",
783
+ " title='Baseline → SFT — Verifier Pass Rate')\n",
784
+ "else:\n",
785
+ " print(\"⚠️ No SFT eval results found.\")"
786
  ]
787
  },
788
  {
789
  "cell_type": "markdown",
790
+ "id": "d9671fe6",
791
  "metadata": {},
792
  "source": [
793
  "## 8️⃣ RLVR Training (Rejection-Sampling FT)\n",
 
799
  "4. SFT on that curated set\n",
800
  "5. Repeating for multiple outer rounds\n",
801
  "\n",
802
+ "This avoids the zero-variance gradient problem seen with GRPO on CORP-ENV.\n",
803
+ "\n",
804
+ "> ⚡ **T4 Note**: Using `--fp16` and reduced `--n-samples` / `--max-prompts` to fit in 16 GB VRAM."
805
  ]
806
  },
807
  {
808
  "cell_type": "code",
809
  "execution_count": null,
810
+ "id": "5be0f8be",
811
  "metadata": {},
812
  "outputs": [],
813
  "source": [
814
+ "# Clear GPU memory\n",
815
+ "gc.collect()\n",
816
+ "torch.cuda.empty_cache()\n",
817
+ "gpu_status()\n",
818
+ "\n",
819
+ "print(f\"\\n🚀 Starting RLVR training ({RLVR_ROUNDS} rounds, {RLVR_N_SAMPLES} samples/prompt)...\\n\")\n",
820
+ "\n",
821
  "!python training/train_rlvr.py \\\n",
822
  " --model {BASE_MODEL} \\\n",
823
  " --adapter outputs/sft_adapter \\\n",
 
827
  " --n-samples {RLVR_N_SAMPLES} \\\n",
828
  " --temperature {RLVR_TEMPERATURE} \\\n",
829
  " --max-prompts {RLVR_MAX_PROMPTS} \\\n",
830
+ " --max-prompt-length {RLVR_MAX_PROMPT_LEN} \\\n",
831
+ " --max-completion-length {RLVR_MAX_COMPLETION_LEN} \\\n",
832
  " --strict-json \\\n",
833
  " --use-stub-workers \\\n",
834
  " --disable-llm-judge \\\n",
835
+ " --stats-file results/runs/rlvr_qwen2.5_3b_stats.jsonl \\\n",
836
+ " {FP16_FLAG}\n",
837
+ "\n",
838
+ "gpu_status()\n",
839
+ "print(\"\\n✅ RLVR training complete!\")"
840
+ ]
841
+ },
842
+ {
843
+ "cell_type": "code",
844
+ "execution_count": null,
845
+ "id": "f71e3401",
846
+ "metadata": {},
847
+ "outputs": [],
848
+ "source": [
849
+ "# 📊 Visualize RLVR training progress\n",
850
+ "display(Markdown(\"## 📊 RLVR Training Progress\"))\n",
851
+ "plot_rlvr_stats(\"results/runs/rlvr_qwen2.5_3b_stats.jsonl\")\n",
852
+ "\n",
853
+ "# Check adapter files\n",
854
+ "rlvr_dir = Path(\"outputs/rlvr_adapter\")\n",
855
+ "if rlvr_dir.exists():\n",
856
+ " files = list(rlvr_dir.glob(\"*\"))\n",
857
+ " total_mb = sum(f.stat().st_size for f in files if f.is_file()) / 1e6\n",
858
+ " print(f\"\\n💾 RLVR adapter saved: {len(files)} files, {total_mb:.1f} MB total\")"
859
  ]
860
  },
861
  {
862
  "cell_type": "markdown",
863
+ "id": "32503cf5",
864
  "metadata": {},
865
  "source": [
866
  "## 9️⃣ Evaluate RLVR Adapter"
 
869
  {
870
  "cell_type": "code",
871
  "execution_count": null,
872
+ "id": "a756f408",
873
  "metadata": {},
874
  "outputs": [],
875
  "source": [
876
+ "# Clear GPU memory\n",
877
+ "gc.collect()\n",
878
+ "torch.cuda.empty_cache()\n",
879
+ "gpu_status()\n",
880
+ "\n",
881
  "!python eval.py \\\n",
882
  " --policy hf \\\n",
883
  " --label rlvr \\\n",
884
  " --model {BASE_MODEL} \\\n",
885
  " --adapter outputs/rlvr_adapter \\\n",
886
  " --episodes {EVAL_EPISODES} \\\n",
887
+ " --max-steps {EVAL_MAX_STEPS}\n",
888
+ "\n",
889
+ "gpu_status()"
890
+ ]
891
+ },
892
+ {
893
+ "cell_type": "code",
894
+ "execution_count": null,
895
+ "id": "daf5526c",
896
+ "metadata": {},
897
+ "outputs": [],
898
+ "source": [
899
+ "# 📊 Visualize RLVR evaluation results\n",
900
+ "display(Markdown(\"## 📊 RLVR Evaluation Results\"))\n",
901
+ "\n",
902
+ "rlvr_rows = [r for r in load_eval_jsonl(\"results/runs\") if r.get('model_stage') == 'rlvr']\n",
903
+ "if rlvr_rows:\n",
904
+ " plot_eval_dashboard(rlvr_rows, title=\"RLVR Adapter Evaluation\")\n",
905
+ " plot_reward_traces(rlvr_rows, title=\"RLVR Reward Traces\")\n",
906
+ " ALL_EVALS['rlvr'] = rlvr_rows\n",
907
+ "\n",
908
+ " # Compare SFT → RLVR\n",
909
+ " display(Markdown(\"### 📈 Improvement: SFT → RLVR\"))\n",
910
+ " if 'sft' in ALL_EVALS:\n",
911
+ " plot_stage_comparison(\n",
912
+ " {'sft': ALL_EVALS['sft'], 'rlvr': rlvr_rows},\n",
913
+ " metric='terminal_reward',\n",
914
+ " title='SFT → RLVR — Terminal Reward'\n",
915
+ " )\n",
916
+ "else:\n",
917
+ " print(\"⚠️ No RLVR eval results found.\")"
918
  ]
919
  },
920
  {
921
  "cell_type": "markdown",
922
+ "id": "e96ce765",
923
+ "metadata": {},
924
+ "source": [
925
+ "## 📊 Final Comparison: All Model Stages\n",
926
+ "\n",
927
+ "Side-by-side comparison of all evaluated model stages."
928
+ ]
929
+ },
930
+ {
931
+ "cell_type": "code",
932
+ "execution_count": null,
933
+ "id": "d92ae920",
934
  "metadata": {},
935
+ "outputs": [],
936
  "source": [
937
+ "display(Markdown(\"## 📊 Full Pipeline Comparison: Baseline → Oracle → SFT → RLVR\"))\n",
938
+ "\n",
939
+ "if ALL_EVALS:\n",
940
+ " # Terminal Reward comparison\n",
941
+ " plot_stage_comparison(ALL_EVALS, metric='terminal_reward',\n",
942
+ " title='Terminal Reward — All Model Stages')\n",
943
  "\n",
944
+ " # Verifier Pass Rate comparison\n",
945
+ " plot_stage_comparison(ALL_EVALS, metric='verifier_pass_rate',\n",
946
+ " title='Verifier Pass Rate — All Model Stages')\n",
947
+ "\n",
948
+ " # Build final comparison table\n",
949
+ " display(Markdown(\"### 📋 Final Results Table\"))\n",
950
+ " header = \"| Stage | Task | Terminal Reward | Verifier Pass | Success Rate |\"\n",
951
+ " sep = \"|-------|------|---------------|--------------|-------------|\"\n",
952
+ " lines = [header, sep]\n",
953
+ " for stage_name, stage_rows in ALL_EVALS.items():\n",
954
+ " by_task = defaultdict(list)\n",
955
+ " for r in stage_rows:\n",
956
+ " by_task[r['task_id']].append(r)\n",
957
+ " for task_id in sorted(by_task.keys()):\n",
958
+ " task_rows = by_task[task_id]\n",
959
+ " avg_r = np.mean([r['terminal_reward'] for r in task_rows])\n",
960
+ " avg_p = np.mean([r['verifier_pass_rate'] for r in task_rows])\n",
961
+ " succ = np.mean([1 if r.get('success') else 0 for r in task_rows])\n",
962
+ " lines.append(f\"| {stage_name.upper()} | {TASK_SHORT.get(task_id, task_id)} | {avg_r:.3f} | {avg_p:.3f} | {succ:.0%} |\")\n",
963
+ " display(Markdown('\\n'.join(lines)))\n",
964
+ "else:\n",
965
+ " print(\"⚠️ No evaluation data collected. Run the evaluation cells above.\")"
966
  ]
967
  },
968
  {
969
  "cell_type": "code",
970
  "execution_count": null,
971
+ "id": "b37b7da9",
972
  "metadata": {},
973
  "outputs": [],
974
  "source": [
975
+ "# Also generate plots via plot_results.py for file-based output\n",
976
  "!python plot_results.py \\\n",
977
  " --inputs results/runs \\\n",
978
+ " --output-dir results/model_compare_qwen25_3b"
979
  ]
980
  },
981
  {
982
  "cell_type": "code",
983
  "execution_count": null,
984
+ "id": "3313ec66",
985
  "metadata": {},
986
  "outputs": [],
987
  "source": [
988
  "from IPython.display import Image, display, Markdown\n",
 
989
  "\n",
990
+ "plot_dir = Path(\"results/model_compare_qwen25_3b\")\n",
991
  "if not plot_dir.exists():\n",
992
  " plot_dir = Path(\"results/model_compare_qwen25_fresh_no_grpo_ep5rlvr\")\n",
993
  "\n",
994
+ "if plot_dir.exists():\n",
995
+ " for png in sorted(plot_dir.glob(\"*.png\")):\n",
996
+ " display(Markdown(f\"### {png.stem.replace('_', ' ').title()}\"))\n",
997
+ " display(Image(filename=str(png), width=800))\n",
998
  "\n",
999
+ " # Show summary table\n",
1000
+ " summary_md = plot_dir / \"comparison_summary.md\"\n",
1001
+ " if summary_md.exists():\n",
1002
+ " display(Markdown(summary_md.read_text()))\n",
1003
+ "else:\n",
1004
+ " print(\"⚠️ No plot directory found.\")"
1005
  ]
1006
  },
1007
  {
1008
  "cell_type": "markdown",
1009
+ "id": "a638d546",
1010
  "metadata": {},
1011
  "source": [
1012
  "## 📋 Results Summary\n",
1013
  "\n",
1014
+ "Expected progression for Qwen 2.5-3B-Instruct on CORP-ENV:\n",
1015
  "\n",
1016
  "| Stage | E1 Terminal Reward | M1 Terminal Reward | H1 Terminal Reward | M1 Success |\n",
1017
  "|-------|-------------------|-------------------|-------------------|------------|\n",
 
1020
  "| SFT | 0.910 | 0.943 | 0.889 | 100% |\n",
1021
  "| RLVR | 0.910 | 0.932 | 0.779 | 80% |\n",
1022
  "\n",
1023
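+ "> **Key takeaway**: SFT dramatically improves M1 (budget reallocation) from 0% to 100% success rate. RLVR reduces reliance on fixed trajectories and holds E1 (0.910) and M1 reward (0.932 vs 0.943), but in this table it regresses on H1 (0.779 vs 0.889) and on M1 success (80% vs 100%).\n",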
1024
+ "\n",
1025
+ "> **T4 Note**: Results may differ slightly on T4 due to FP16 precision (vs BF16) and reduced RLVR sampling. For best results, use the full hyperparameters on an A100/H100."
1026
  ]
1027
  }
1028
  ],
 
1045
  },
1046
  "nbformat": 4,
1047
  "nbformat_minor": 5
1048
+ }
training/train_rlvr.py CHANGED
@@ -236,6 +236,7 @@ def sft_on_winners(
236
  epochs: float,
237
  max_steps: int,
238
  max_seq_length: int,
 
239
  ) -> None:
240
  """Run a single SFT pass over the curated (prompt, best_completion) set."""
241
  from datasets import Dataset
@@ -267,7 +268,8 @@ def sft_on_winners(
267
  "save_steps": 10_000,
268
  "save_total_limit": 1,
269
  "optim": "adamw_8bit",
270
- "bf16": True,
 
271
  "report_to": "none",
272
  "dataset_text_field": "text",
273
  "push_to_hub": False,
@@ -376,6 +378,11 @@ def main() -> None:
376
  action="store_true",
377
  help="Disable LLM judge scoring for deterministic verifier-only runs.",
378
  )
 
 
 
 
 
379
  args = parser.parse_args()
380
 
381
  if args.use_stub_workers:
@@ -405,10 +412,11 @@ def main() -> None:
405
  print(f"Built {len(full_rows)} prompts from {args.examples}")
406
 
407
  max_seq_len = args.max_prompt_length + args.max_completion_length
 
408
  model, tokenizer = FastLanguageModel.from_pretrained(
409
  model_name=args.model,
410
  max_seq_length=max_seq_len,
411
- dtype=torch.bfloat16,
412
  load_in_4bit=True,
413
  )
414
  if getattr(tokenizer, "pad_token", None) is None and getattr(
@@ -438,9 +446,10 @@ def main() -> None:
438
  random_state=args.seed,
439
  )
440
 
 
441
  for p in model.parameters():
442
  if p.requires_grad and p.dtype == torch.float32:
443
- p.data = p.data.to(torch.bfloat16)
444
 
445
  stats_path = Path(args.stats_file) if args.stats_file else None
446
  if stats_path:
@@ -490,6 +499,7 @@ def main() -> None:
490
  epochs=args.inner_epochs,
491
  max_steps=args.inner_max_steps,
492
  max_seq_length=max_seq_len,
 
493
  )
494
 
495
  Path(args.output).mkdir(parents=True, exist_ok=True)
 
236
  epochs: float,
237
  max_steps: int,
238
  max_seq_length: int,
239
+ use_fp16: bool = False,
240
  ) -> None:
241
  """Run a single SFT pass over the curated (prompt, best_completion) set."""
242
  from datasets import Dataset
 
268
  "save_steps": 10_000,
269
  "save_total_limit": 1,
270
  "optim": "adamw_8bit",
271
+ "bf16": (not use_fp16) and torch.cuda.is_available(),
272
+ "fp16": use_fp16 and torch.cuda.is_available(),
273
  "report_to": "none",
274
  "dataset_text_field": "text",
275
  "push_to_hub": False,
 
378
  action="store_true",
379
  help="Disable LLM judge scoring for deterministic verifier-only runs.",
380
  )
381
+ parser.add_argument(
382
+ "--fp16",
383
+ action="store_true",
384
+ help="Use fp16 instead of bf16 (required for T4 GPUs, which lack bf16 support).",
385
+ )
386
  args = parser.parse_args()
387
 
388
  if args.use_stub_workers:
 
412
  print(f"Built {len(full_rows)} prompts from {args.examples}")
413
 
414
  max_seq_len = args.max_prompt_length + args.max_completion_length
415
+ load_dtype = torch.float16 if args.fp16 else torch.bfloat16
416
  model, tokenizer = FastLanguageModel.from_pretrained(
417
  model_name=args.model,
418
  max_seq_length=max_seq_len,
419
+ dtype=load_dtype,
420
  load_in_4bit=True,
421
  )
422
  if getattr(tokenizer, "pad_token", None) is None and getattr(
 
446
  random_state=args.seed,
447
  )
448
 
449
+ cast_dtype = torch.float16 if args.fp16 else torch.bfloat16
450
  for p in model.parameters():
451
  if p.requires_grad and p.dtype == torch.float32:
452
+ p.data = p.data.to(cast_dtype)
453
 
454
  stats_path = Path(args.stats_file) if args.stats_file else None
455
  if stats_path:
 
499
  epochs=args.inner_epochs,
500
  max_steps=args.inner_max_steps,
501
  max_seq_length=max_seq_len,
502
+ use_fp16=args.fp16,
503
  )
504
 
505
  Path(args.output).mkdir(parents=True, exist_ok=True)
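The precision handling added to `sft_on_winners` above sets `bf16` and `fp16` so that at most one is enabled, and neither is enabled without CUDA. A minimal sketch of that selection logic follows, written without torch so it runs anywhere. The names `precision_flags` and `cuda_available` are hypothetical and used only for illustration; in the script the second input comes from `torch.cuda.is_available()`.

```python
def precision_flags(use_fp16: bool, cuda_available: bool) -> dict:
    """Return trainer precision flags; at most one of bf16/fp16 is True.

    Hypothetical helper mirroring the two dict entries in the diff:
      "bf16": (not use_fp16) and torch.cuda.is_available()
      "fp16": use_fp16 and torch.cuda.is_available()
    """
    return {
        "bf16": (not use_fp16) and cuda_available,
        "fp16": use_fp16 and cuda_available,
    }


if __name__ == "__main__":
    # Enumerate all four input combinations to show the flags never overlap.
    for use_fp16 in (False, True):
        for cuda_available in (False, True):
            flags = precision_flags(use_fp16, cuda_available)
            print(f"use_fp16={use_fp16} cuda={cuda_available} -> {flags}")
```

On a CPU-only machine both flags come out `False`, so the trainer would fall back to full precision rather than fail on an unsupported dtype.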
training/train_sft.py CHANGED
@@ -210,10 +210,11 @@ def main() -> None:
210
  if args.dataset_num_proc == 0 and "dataset_num_proc" in allowed:
211
  args = argparse.Namespace(**{**vars(args), "dataset_num_proc": None})
212
 
 
213
  model, tokenizer = FastLanguageModel.from_pretrained(
214
  model_name=args.model,
215
  max_seq_length=args.max_seq_length,
216
- dtype=None,
217
  load_in_4bit=True,
218
  )
219
  if getattr(tokenizer, "pad_token", None) is None and getattr(
 
210
  if args.dataset_num_proc == 0 and "dataset_num_proc" in allowed:
211
  args = argparse.Namespace(**{**vars(args), "dataset_num_proc": None})
212
 
213
+ load_dtype = torch.float16 if args.fp16 else None
214
  model, tokenizer = FastLanguageModel.from_pretrained(
215
  model_name=args.model,
216
  max_seq_length=args.max_seq_length,
217
+ dtype=load_dtype,
218
  load_in_4bit=True,
219
  )
220
  if getattr(tokenizer, "pad_token", None) is None and getattr(