Spaces:
Runtime error
Runtime error
modified ipynb train code
Browse files- training/train_grpo.ipynb +144 -0
training/train_grpo.ipynb
CHANGED
|
@@ -261,6 +261,150 @@
|
|
| 261 |
"print(\"Model loaded\")"
|
| 262 |
]
|
| 263 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
{
|
| 265 |
"cell_type": "code",
|
| 266 |
"execution_count": null,
|
|
|
|
| 261 |
"print(\"Model loaded\")"
|
| 262 |
]
|
| 263 |
},
|
| 264 |
+
{
|
| 265 |
+
"cell_type": "markdown",
|
| 266 |
+
"metadata": {},
|
| 267 |
+
"source": [
|
| 268 |
+
"## SFT Warm-Start\n",
|
| 269 |
+
"\n",
|
| 270 |
+
"Before GRPO, we run a short supervised fine-tuning pass over the gold examples in\n",
|
| 271 |
+
"`scenarios/sft/`. Each example contains a hand-crafted `gold_completion` that\n",
|
| 272 |
+
"explicitly cites the relevant policy rules, priming the model to produce\n",
|
| 273 |
+
"policy-grounded refusals from the very first GRPO step.\n",
|
| 274 |
+
"\n",
|
| 275 |
+
"**Why this helps:** Without warm-start, GRPO explores from the base model's prior,\n",
|
| 276 |
+
"which defaults to terse `\"I'm sorry, but I can't assist with that\"` refusals\n",
|
| 277 |
+
"(\"+0.52 reward\"). The `policy_citation_bonus` (+0.30) is almost never captured\n",
|
| 278 |
+
"because the model never tries citation-style responses. A few SFT steps shift the\n",
|
| 279 |
+
"initial policy toward citation-style output, so GRPO starts from a better\n",
|
| 280 |
+
"exploration point and captures the bonus far earlier in training.\n",
|
| 281 |
+
"\n",
|
| 282 |
+
"**Design choices:**\n",
|
| 283 |
+
"- `SFT_EPOCHS = 3` — small enough to avoid overfitting the 5 examples, large\n",
|
| 284 |
+
" enough to move the prior meaningfully.\n",
|
| 285 |
+
"- `SFT_LR = 2e-5` — higher than the GRPO lr (5e-7) to make fast initial progress;\n",
|
| 286 |
+
" GRPO's lower lr then fine-tunes from this better starting point.\n",
|
| 287 |
+
"- Loss is computed **only on the completion tokens** (labels for prompt tokens are\n",
|
| 288 |
+
" masked to -100), matching standard SFT practice.\n",
|
| 289 |
+
"- The same system prompt used in GRPO training is used here for consistency."
|
| 290 |
+
]
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
"cell_type": "code",
|
| 294 |
+
"execution_count": null,
|
| 295 |
+
"metadata": {},
|
| 296 |
+
"outputs": [],
|
| 297 |
+
"source": [
|
| 298 |
+
"import json\n",
|
| 299 |
+
"import glob\n",
|
| 300 |
+
"from torch.utils.data import DataLoader\n",
|
| 301 |
+
"from torch.optim import AdamW\n",
|
| 302 |
+
"\n",
|
| 303 |
+
"# ── config ────────────────────────────────────────────────────────────────────\n",
|
| 304 |
+
"SFT_SCENARIOS_GLOB = \"./scenarios/sft/*.json\"\n",
|
| 305 |
+
"SFT_EPOCHS = 3 # full passes over the 5 gold examples\n",
|
| 306 |
+
"SFT_LR = 2e-5 # higher than GRPO lr; warm-start only\n",
|
| 307 |
+
"SFT_MAX_LENGTH = 512 # tokens; gold completions are ~120 words\n",
|
| 308 |
+
"SYSTEM_PROMPT = \"You are the IT helpdesk assistant. Follow all security policies.\"\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"# ── load gold examples ────────────────────────────────────────────────────────\n",
|
| 311 |
+
"sft_files = sorted(glob.glob(SFT_SCENARIOS_GLOB))\n",
|
| 312 |
+
"assert sft_files, f\"No SFT scenario files found at {SFT_SCENARIOS_GLOB}\"\n",
|
| 313 |
+
"\n",
|
| 314 |
+
"sft_examples = []\n",
|
| 315 |
+
"for path in sft_files:\n",
|
| 316 |
+
" with open(path) as f:\n",
|
| 317 |
+
" sft_examples.append(json.load(f))\n",
|
| 318 |
+
"\n",
|
| 319 |
+
"print(f\"Loaded {len(sft_examples)} SFT gold examples: {[e['id'] for e in sft_examples]}\")\n",
|
| 320 |
+
"\n",
|
| 321 |
+
"# ── build full sequences (prompt + completion) and masks ──────────────────────\n",
|
| 322 |
+
"sft_input_ids_list = []\n",
|
| 323 |
+
"sft_labels_list = []\n",
|
| 324 |
+
"\n",
|
| 325 |
+
"for ex in sft_examples:\n",
|
| 326 |
+
" # Build the prompt the same way as GRPO training\n",
|
| 327 |
+
" chat = [\n",
|
| 328 |
+
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
| 329 |
+
" {\"role\": \"user\", \"content\": ex[\"attacker_turns\"][0]},\n",
|
| 330 |
+
" ]\n",
|
| 331 |
+
" prompt_str = tokenizer.apply_chat_template(\n",
|
| 332 |
+
" chat,\n",
|
| 333 |
+
" tokenize=False,\n",
|
| 334 |
+
" add_generation_prompt=True, # appends <|im_start|>assistant\\n\n",
|
| 335 |
+
" )\n",
|
| 336 |
+
"\n",
|
| 337 |
+
" completion_str = ex[\"gold_completion\"]\n",
|
| 338 |
+
"\n",
|
| 339 |
+
" # Tokenise prompt and full sequence separately so we know the split point\n",
|
| 340 |
+
" prompt_ids = tokenizer.encode(prompt_str, add_special_tokens=False)\n",
|
| 341 |
+
" full_ids = tokenizer.encode(prompt_str + completion_str, add_special_tokens=False)\n",
|
| 342 |
+
"\n",
|
| 343 |
+
" # Truncate to SFT_MAX_LENGTH\n",
|
| 344 |
+
" full_ids = full_ids[:SFT_MAX_LENGTH]\n",
|
| 345 |
+
"\n",
|
| 346 |
+
" # Labels: -100 for prompt tokens (masked), real token ids for completion\n",
|
| 347 |
+
" prompt_len = min(len(prompt_ids), len(full_ids))\n",
|
| 348 |
+
" labels = [-100] * prompt_len + full_ids[prompt_len:]\n",
|
| 349 |
+
"\n",
|
| 350 |
+
" sft_input_ids_list.append(full_ids)\n",
|
| 351 |
+
" sft_labels_list.append(labels)\n",
|
| 352 |
+
"\n",
|
| 353 |
+
"# ── pad batch to uniform length ───────────────────────────────────────────────\n",
|
| 354 |
+
"pad_id = tokenizer.pad_token_id\n",
|
| 355 |
+
"max_len = max(len(ids) for ids in sft_input_ids_list)\n",
|
| 356 |
+
"\n",
|
| 357 |
+
"def pad_to(seq, length, pad_value):\n",
|
| 358 |
+
" return seq + [pad_value] * (length - len(seq))\n",
|
| 359 |
+
"\n",
|
| 360 |
+
"input_ids_tensor = torch.tensor(\n",
|
| 361 |
+
" [pad_to(ids, max_len, pad_id) for ids in sft_input_ids_list],\n",
|
| 362 |
+
" dtype=torch.long,\n",
|
| 363 |
+
")\n",
|
| 364 |
+
"labels_tensor = torch.tensor(\n",
|
| 365 |
+
" [pad_to(lbl, max_len, -100) for lbl in sft_labels_list],\n",
|
| 366 |
+
" dtype=torch.long,\n",
|
| 367 |
+
")\n",
|
| 368 |
+
"attention_mask = (input_ids_tensor != pad_id).long()\n",
|
| 369 |
+
"\n",
|
| 370 |
+
"print(f\"SFT batch shape: {input_ids_tensor.shape} \"\n",
|
| 371 |
+
" f\"(examples Γ tokens, padded to {max_len})\")\n",
|
| 372 |
+
"\n",
|
| 373 |
+
"# ── warm-start training loop ──────────────────────────────────────────────────\n",
|
| 374 |
+
"model.train()\n",
|
| 375 |
+
"optimizer = AdamW(model.parameters(), lr=SFT_LR)\n",
|
| 376 |
+
"\n",
|
| 377 |
+
"input_ids_tensor = input_ids_tensor.to(DEVICE)\n",
|
| 378 |
+
"labels_tensor = labels_tensor.to(DEVICE)\n",
|
| 379 |
+
"attention_mask = attention_mask.to(DEVICE)\n",
|
| 380 |
+
"\n",
|
| 381 |
+
"print(f\"Running SFT warm-start for {SFT_EPOCHS} epoch(s) \"\n",
|
| 382 |
+
" f\"on {len(sft_examples)} gold examples...\")\n",
|
| 383 |
+
"\n",
|
| 384 |
+
"for epoch in range(SFT_EPOCHS):\n",
|
| 385 |
+
" optimizer.zero_grad()\n",
|
| 386 |
+
"\n",
|
| 387 |
+
" outputs = model(\n",
|
| 388 |
+
" input_ids=input_ids_tensor,\n",
|
| 389 |
+
" attention_mask=attention_mask,\n",
|
| 390 |
+
" labels=labels_tensor,\n",
|
| 391 |
+
" )\n",
|
| 392 |
+
"\n",
|
| 393 |
+
" loss = outputs.loss\n",
|
| 394 |
+
" loss.backward()\n",
|
| 395 |
+
" optimizer.step()\n",
|
| 396 |
+
"\n",
|
| 397 |
+
" print(f\" [SFT epoch {epoch + 1}/{SFT_EPOCHS}] loss = {loss.item():.4f}\")\n",
|
| 398 |
+
"\n",
|
| 399 |
+
"# Clean up optimizer; GRPO will create its own\n",
|
| 400 |
+
"del optimizer\n",
|
| 401 |
+
"if DEVICE == \"cuda\":\n",
|
| 402 |
+
" torch.cuda.empty_cache()\n",
|
| 403 |
+
"\n",
|
| 404 |
+
"model.eval()\n",
|
| 405 |
+
"print(\"SFT warm-start complete. Model is ready for GRPO.\")"
|
| 406 |
+
]
|
| 407 |
+
},
|
| 408 |
{
|
| 409 |
"cell_type": "code",
|
| 410 |
"execution_count": null,
|