Puskara committed on
Commit
dcd3fb7
·
1 Parent(s): cc0e7f6

modified ipynb train code

Browse files
Files changed (1) hide show
  1. training/train_grpo.ipynb +144 -0
training/train_grpo.ipynb CHANGED
@@ -261,6 +261,150 @@
261
  "print(\"Model loaded\")"
262
  ]
263
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  {
265
  "cell_type": "code",
266
  "execution_count": null,
 
261
  "print(\"Model loaded\")"
262
  ]
263
  },
264
+ {
265
+ "cell_type": "markdown",
266
+ "metadata": {},
267
+ "source": [
268
+ "## SFT Warm-Start\n",
269
+ "\n",
270
+ "Before GRPO, we run a short supervised fine-tuning pass over the gold examples in\n",
271
+ "`scenarios/sft/`. Each example contains a hand-crafted `gold_completion` that\n",
272
+ "explicitly cites the relevant policy rules, priming the model to produce\n",
273
+ "policy-grounded refusals from the very first GRPO step.\n",
274
+ "\n",
275
+ "**Why this helps:** Without warm-start, GRPO explores from the base model's prior,\n",
276
+ "which defaults to terse `\"I'm sorry, but I can't assist with that\"` refusals\n",
277
+ "(+0.52 reward). The `policy_citation_bonus` (+0.30) is almost never captured\n",
278
+ "because the model never tries citation-style responses. A few SFT steps shift the\n",
279
+ "initial policy toward citation-style output, so GRPO starts from a better\n",
280
+ "exploration point and captures the bonus far earlier in training.\n",
281
+ "\n",
282
+ "**Design choices:**\n",
283
+ "- `SFT_EPOCHS = 3` — small enough to avoid overfitting the 5 examples, large\n",
284
+ " enough to move the prior meaningfully.\n",
285
+ "- `SFT_LR = 2e-5` — higher than the GRPO lr (5e-7) to make fast initial progress;\n",
286
+ " GRPO's lower lr then fine-tunes from this better starting point.\n",
287
+ "- Loss is computed **only on the completion tokens** (labels for prompt tokens are\n",
288
+ " masked to -100), matching standard SFT practice.\n",
289
+ "- The same system prompt used in GRPO training is used here for consistency."
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "code",
294
+ "execution_count": null,
295
+ "metadata": {},
296
+ "outputs": [],
297
+ "source": [
298
+ "import json\n",
299
+ "import glob\n",
300
+ "from torch.utils.data import DataLoader\n",
301
+ "from torch.optim import AdamW\n",
302
+ "\n",
303
+ "# ── config ────────────────────────────────────────────────────────────────────\n",
304
+ "SFT_SCENARIOS_GLOB = \"./scenarios/sft/*.json\"\n",
305
+ "SFT_EPOCHS = 3 # full passes over the 5 gold examples\n",
306
+ "SFT_LR = 2e-5 # higher than GRPO lr; warm-start only\n",
307
+ "SFT_MAX_LENGTH = 512 # tokens; gold completions are ~120 words\n",
308
+ "SYSTEM_PROMPT = \"You are the IT helpdesk assistant. Follow all security policies.\"\n",
309
+ "\n",
310
+ "# ── load gold examples ────────────────────────────────────────────────────────\n",
311
+ "sft_files = sorted(glob.glob(SFT_SCENARIOS_GLOB))\n",
312
+ "assert sft_files, f\"No SFT scenario files found at {SFT_SCENARIOS_GLOB}\"\n",
313
+ "\n",
314
+ "sft_examples = []\n",
315
+ "for path in sft_files:\n",
316
+ " with open(path) as f:\n",
317
+ " sft_examples.append(json.load(f))\n",
318
+ "\n",
319
+ "print(f\"Loaded {len(sft_examples)} SFT gold examples: {[e['id'] for e in sft_examples]}\")\n",
320
+ "\n",
321
+ "# ── build full sequences (prompt + completion) and masks ──────────────────────\n",
322
+ "sft_input_ids_list = []\n",
323
+ "sft_labels_list = []\n",
324
+ "\n",
325
+ "for ex in sft_examples:\n",
326
+ " # Build the prompt the same way as GRPO training\n",
327
+ " chat = [\n",
328
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
329
+ " {\"role\": \"user\", \"content\": ex[\"attacker_turns\"][0]},\n",
330
+ " ]\n",
331
+ " prompt_str = tokenizer.apply_chat_template(\n",
332
+ " chat,\n",
333
+ " tokenize=False,\n",
334
+ " add_generation_prompt=True, # appends <|im_start|>assistant\\n\n",
335
+ " )\n",
336
+ "\n",
337
+ " completion_str = ex[\"gold_completion\"]\n",
338
+ "\n",
339
+ " # Tokenise prompt and full sequence separately so we know the split point\n",
340
+ " prompt_ids = tokenizer.encode(prompt_str, add_special_tokens=False)\n",
341
+ " full_ids = tokenizer.encode(prompt_str + completion_str, add_special_tokens=False)\n",
342
+ "\n",
343
+ " # Truncate to SFT_MAX_LENGTH\n",
344
+ " full_ids = full_ids[:SFT_MAX_LENGTH]\n",
345
+ "\n",
346
+ " # Labels: -100 for prompt tokens (masked), real token ids for completion\n",
347
+ " prompt_len = min(len(prompt_ids), len(full_ids))\n",
348
+ " labels = [-100] * prompt_len + full_ids[prompt_len:]\n",
349
+ "\n",
350
+ " sft_input_ids_list.append(full_ids)\n",
351
+ " sft_labels_list.append(labels)\n",
352
+ "\n",
353
+ "# ── pad batch to uniform length ───────────────────────────────────────────────\n",
354
+ "pad_id = tokenizer.pad_token_id\n",
355
+ "max_len = max(len(ids) for ids in sft_input_ids_list)\n",
356
+ "\n",
357
+ "def pad_to(seq, length, pad_value):\n",
358
+ " return seq + [pad_value] * (length - len(seq))\n",
359
+ "\n",
360
+ "input_ids_tensor = torch.tensor(\n",
361
+ " [pad_to(ids, max_len, pad_id) for ids in sft_input_ids_list],\n",
362
+ " dtype=torch.long,\n",
363
+ ")\n",
364
+ "labels_tensor = torch.tensor(\n",
365
+ " [pad_to(lbl, max_len, -100) for lbl in sft_labels_list],\n",
366
+ " dtype=torch.long,\n",
367
+ ")\n",
368
+ "attention_mask = (input_ids_tensor != pad_id).long()\n",
369
+ "\n",
370
+ "print(f\"SFT batch shape: {input_ids_tensor.shape} \"\n",
371
+ " f\"(examples × tokens, padded to {max_len})\")\n",
372
+ "\n",
373
+ "# ── warm-start training loop ──────────────────────────────────────────────────\n",
374
+ "model.train()\n",
375
+ "optimizer = AdamW(model.parameters(), lr=SFT_LR)\n",
376
+ "\n",
377
+ "input_ids_tensor = input_ids_tensor.to(DEVICE)\n",
378
+ "labels_tensor = labels_tensor.to(DEVICE)\n",
379
+ "attention_mask = attention_mask.to(DEVICE)\n",
380
+ "\n",
381
+ "print(f\"Running SFT warm-start for {SFT_EPOCHS} epoch(s) \"\n",
382
+ " f\"on {len(sft_examples)} gold examples...\")\n",
383
+ "\n",
384
+ "for epoch in range(SFT_EPOCHS):\n",
385
+ " optimizer.zero_grad()\n",
386
+ "\n",
387
+ " outputs = model(\n",
388
+ " input_ids=input_ids_tensor,\n",
389
+ " attention_mask=attention_mask,\n",
390
+ " labels=labels_tensor,\n",
391
+ " )\n",
392
+ "\n",
393
+ " loss = outputs.loss\n",
394
+ " loss.backward()\n",
395
+ " optimizer.step()\n",
396
+ "\n",
397
+ " print(f\" [SFT epoch {epoch + 1}/{SFT_EPOCHS}] loss = {loss.item():.4f}\")\n",
398
+ "\n",
399
+ "# Clean up optimizer; GRPO will create its own\n",
400
+ "del optimizer\n",
401
+ "if DEVICE == \"cuda\":\n",
402
+ " torch.cuda.empty_cache()\n",
403
+ "\n",
404
+ "model.eval()\n",
405
+ "print(\"SFT warm-start complete. Model is ready for GRPO.\")"
406
+ ]
407
+ },
408
  {
409
  "cell_type": "code",
410
  "execution_count": null,