Vighnesh committed on
Commit
2e680c9
·
1 Parent(s): 55ff252

Sync notebook Cell 7 graders with graders.py fixes #2 #3 #4 #5 — smoke test passes

Browse files
Files changed (1) hide show
  1. train_grpo.ipynb +1 -136
train_grpo.ipynb CHANGED
@@ -493,142 +493,7 @@
493
  "metadata": {},
494
  "outputs": [],
495
  "source": [
496
- "# ─────────────────────────────────────────────────────────────────\n",
497
- "# Reward functions — exact mirror of graders.py\n",
498
- "# grade_task1 / grade_task2 / grade_task3 / loop_penalty\n",
499
- "# ─────────────────────────────────────────────────────────────────\n",
500
- "\n",
501
- "# Partial-credit action pairs (from graders.py)\n",
502
- "_PARTIAL_CREDIT_PAIRS = {frozenset({'reply', 'escalate'})}\n",
503
- "\n",
504
- "# Keyword lists (from graders.py)\n",
505
- "_KEYWORD_REWARDS = {\n",
506
- " 'billing': ['refund', 'charge', 'invoice', 'payment', 'billing'],\n",
507
- " 'account': ['password', 'login', 'account', 'cancel', 'subscription'],\n",
508
- " 'technical': ['engineering', 'escalate', 'bug', 'crash', 'error', 'fix'],\n",
509
- " 'refund': ['refund', 'return', 'credit', 'process'],\n",
510
- " 'general': ['hours', 'contact', 'phone', 'information', 'help'],\n",
511
- "}\n",
512
- "\n",
513
- "def _reply_quality(reply_text, category):\n",
514
- " \"\"\"Exact copy of graders._reply_quality: 0.0–0.5 keyword score.\"\"\"\n",
515
- " if not reply_text: return 0.0\n",
516
- " hits = sum(1 for kw in _KEYWORD_REWARDS.get(category, []) if kw in reply_text.lower())\n",
517
- " return min(0.5, hits * 0.1)\n",
518
- "\n",
519
- "def _grade_task1(at, cat, correct_cat):\n",
520
- " \"\"\"Exact copy of graders.grade_task1.\"\"\"\n",
521
- " return 1.0 if (at == 'classify' and cat == correct_cat) else 0.0\n",
522
- "\n",
523
- "def _grade_task2(at, correct_action, step, cat, correct_cat):\n",
524
- " \"\"\"Exact copy of graders.grade_task2 + classify step.\"\"\"\n",
525
- " if step == 0:\n",
526
- " # classify step: partial credit for correct category\n",
527
- " if at == 'classify' and cat == correct_cat: return 0.3\n",
528
- " if at == 'classify': return 0.1\n",
529
- " return 0.0\n",
530
- " # action step\n",
531
- " if at == correct_action: return 1.0\n",
532
- " if frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS: return 0.5\n",
533
- " if at == 'close': return 0.0\n",
534
- " return 0.0\n",
535
- "\n",
536
- "def _grade_task3(at, cat, correct_cat, correct_action, reply, step, steps_taken=2, max_steps=5):\n",
537
- " \"\"\"Exact copy of graders.grade_task3.\"\"\"\n",
538
- " if step == 0:\n",
539
- " # classification step only\n",
540
- " return 0.20 if (at == 'classify' and cat == correct_cat) else 0.0\n",
541
- " # resolution step: 0.40 action + up to 0.50 reply + 0.15 efficiency\n",
542
- " score = 0.0\n",
543
- " classified_correctly = True # step-1 means step-0 already happened\n",
544
- " score += 0.20 # classification credit carried from step 0\n",
545
- " action_correct = (at == correct_action)\n",
546
- " action_partial = (frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS)\n",
547
- " if action_correct: score += 0.40\n",
548
- " elif action_partial: score += 0.20\n",
549
- " score += _reply_quality(reply, cat) # 0.0–0.5\n",
550
- " # efficiency bonus (assume 2 steps taken for step-1 samples)\n",
551
- " resolved = action_correct or action_partial\n",
552
- " if resolved and steps_taken <= max_steps:\n",
553
- " efficiency = max(0.0, (max_steps - steps_taken) / (max_steps - 1))\n",
554
- " score += 0.15 * efficiency\n",
555
- " return round(min(1.0, score), 4)\n",
556
- "\n",
557
- "def _loop_penalty(step_count, max_steps=10):\n",
558
- " \"\"\"Exact copy of graders.loop_penalty.\"\"\"\n",
559
- " return -0.05 * (step_count - max_steps) if step_count > max_steps else 0.0\n",
560
- "\n",
561
- "def _local_reward(completion, task_id, seed, step=0):\n",
562
- " \"\"\"Full reward using exact graders.py logic. No API calls needed.\"\"\"\n",
563
- " ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]\n",
564
- " action = _safe_parse(completion)\n",
565
- " if not isinstance(action, dict): action = {'action_type': '', 'category': '', 'reply_text': ''}\n",
566
- " at = action.get('action_type', '')\n",
567
- " cat = action.get('category', '')\n",
568
- " reply = action.get('reply_text', '') or ''\n",
569
- " correct_cat = ticket['category']\n",
570
- " correct_action = ticket['correct_action']\n",
571
- "\n",
572
- " if task_id == 1:\n",
573
- " return _grade_task1(at, cat, correct_cat)\n",
574
- " elif task_id == 2:\n",
575
- " return _grade_task2(at, correct_action, step, cat, correct_cat)\n",
576
- " else: # task 3\n",
577
- " return _grade_task3(at, cat, correct_cat, correct_action, reply, step)\n",
578
- "\n",
579
- "def env_reward_fn(prompts, completions, **kwargs):\n",
580
- " \"\"\"Primary reward: exact graders.py logic, no API calls.\"\"\"\n",
581
- " task_ids = kwargs.get('task_id', [1] * len(completions))\n",
582
- " seeds = kwargs.get('seed', [42] * len(completions))\n",
583
- " steps = kwargs.get('step', [0] * len(completions))\n",
584
- " rewards = []\n",
585
- " for i, completion in enumerate(completions):\n",
586
- " tid = int(task_ids[i]) if hasattr(task_ids, '__getitem__') else 1\n",
587
- " seed = int(seeds[i]) if hasattr(seeds, '__getitem__') else 42\n",
588
- " step = int(steps[i]) if hasattr(steps, '__getitem__') else 0\n",
589
- " actual_seed = seed % 10000 if seed >= 10000 else seed\n",
590
- " r = _local_reward(completion, tid, actual_seed, step)\n",
591
- " # apply loop penalty if step is high\n",
592
- " r += _loop_penalty(step)\n",
593
- " rewards.append(r)\n",
594
- " return rewards\n",
595
- "\n",
596
- "def format_reward_fn(prompts, completions, **kwargs):\n",
597
- " \"\"\"Format bonus/penalty: valid action_type = +0.15/+0.20, invalid = -0.20.\"\"\"\n",
598
- " rewards = []\n",
599
- " for completion in completions:\n",
600
- " action = _safe_parse(completion)\n",
601
- " if not isinstance(action, dict): action = {'action_type': '', 'category': '', 'reply_text': ''}\n",
602
- " at = action.get('action_type', '')\n",
603
- " if at in ('classify', 'reply', 'escalate', 'close'):\n",
604
- " bonus = 0.15\n",
605
- " if at == 'classify' and action.get('category') in ('billing','technical','account','general','refund'):\n",
606
- " bonus = 0.20\n",
607
- " rewards.append(bonus)\n",
608
- " else:\n",
609
- " rewards.append(-0.20)\n",
610
- " return rewards\n",
611
- "\n",
612
- "# Print ticket map\n",
613
- "print('Reward functions synced to graders.py')\n",
614
- "print('Ticket map (seed % len):')\n",
615
- "for _i in range(6):\n",
616
- " _tt = ALL_TICKETS[_i]\n",
617
- " print(f' [{_i}] {_tt[\"id\"]} cat={_tt[\"category\"]} action={_tt[\"correct_action\"]}')\n",
618
- "\n",
619
- "# Sanity: seed=0->B001(billing,reply), seed=22->T001(technical,escalate)\n",
620
- "_t0 = ALL_TICKETS[0] # B001 billing reply\n",
621
- "_t22 = ALL_TICKETS[22] # T001 technical escalate\n",
622
- "r1 = _local_reward(json.dumps({'action_type':'classify','category':_t0['category']}), 1, 0, 0)\n",
623
- "r2 = _local_reward(json.dumps({'action_type':'classify','category':_t0['category']}), 2, 0, 0)\n",
624
- "r3 = _local_reward(json.dumps({'action_type':'escalate'}), 2, 0, 1)\n",
625
- "r4 = _local_reward(json.dumps({'action_type':_t22['correct_action'],'reply_text':'escalating this crash bug error to engineering team for a fix'}), 3, 22, 1)\n",
626
- "r5 = format_reward_fn(prompts=['x'], completions=[json.dumps({'action_type':'respond'})])[0]\n",
627
- "print(f'task1 correct classify: {r1} (expect 1.0)')\n",
628
- "print(f'task2 step0 correct classify: {r2} (expect 0.3)')\n",
629
- "print(f'task2 step1 partial escalate: {r3} (expect 0.5)')\n",
630
- "print(f'task3 step1 correct+keywords: {r4} (expect 0.87+)')\n",
631
- "print(f'bad format penalty: {r5} (expect -0.2)')\n"
632
  ]
633
  },
634
  {
 
493
  "metadata": {},
494
  "outputs": [],
495
  "source": [
496
+ "# -----------------------------------------------------------------\n# Reward functions — synced with graders.py (fixes #2 #3 #4 #5)\n# DO NOT EDIT INLINE — keep in sync with graders.py manually.\n# FALLBACK ONLY — if graders.py is importable, prefer that instead.\n# -----------------------------------------------------------------\nimport re as _re, json\n\n# Partial-credit action pairs (from graders.py)\n_PARTIAL_CREDIT_PAIRS = {frozenset({\"reply\", \"escalate\"})}\n\n# Broad category keywords — 0.03 each (from graders.py)\n_KEYWORD_REWARDS = {\n \"billing\": [\"refund\", \"charge\", \"invoice\", \"payment\", \"billing\"],\n \"account\": [\"password\", \"login\", \"account\", \"cancel\", \"subscription\"],\n \"technical\": [\"engineering\", \"escalate\", \"bug\", \"crash\", \"error\", \"fix\"],\n \"refund\": [\"refund\", \"return\", \"credit\", \"process\"],\n \"general\": [\"hours\", \"contact\", \"phone\", \"information\", \"help\"],\n}\n\ndef _reply_quality(reply_text, category, resolution_hint=\"\"):\n \"\"\"\n Synced with graders._reply_quality (fix #2 + #4).\n Two-tier keyword scoring, case-insensitive, punctuation-stripped:\n category keyword hit -> 0.03 each (broad relevance)\n hint keyword hit -> 0.05 each (specific resolution)\n Cap: 0.25. 
Total grade_task3 weights: 0.20+0.40+0.25+0.15 = 1.00\n \"\"\"\n if not reply_text:\n return 0.0\n cleaned = _re.sub(r\"[^\\w\\s]\", \" \", reply_text.lower())\n category_score = sum(0.03 for kw in _KEYWORD_REWARDS.get(category, []) if kw in cleaned)\n hint_score = 0.0\n if resolution_hint:\n hint_words = set(_re.sub(r\"[^\\w\\s]\", \" \", resolution_hint.lower()).split())\n hint_words = {w for w in hint_words if len(w) > 3}\n hint_score = sum(0.05 for w in hint_words if w in cleaned)\n return round(min(0.25, category_score + hint_score), 4)\n\ndef _grade_task1(at, cat, correct_cat):\n \"\"\"Synced with graders.grade_task1.\"\"\"\n return 1.0 if (at == \"classify\" and cat == correct_cat) else 0.0\n\ndef _grade_task2(at, correct_action, step, cat, correct_cat, cls_credit=0.0):\n \"\"\"\n Synced with graders.grade_task2 + support_environment Task2 (fix #5).\n step=0: classify -> returns 0.3 credit (correct) or 0.0 (wrong)\n step=1: action scaled to 0.7 max + cls_credit, clamped to 1.0\n \"\"\"\n if step == 0:\n if at == \"classify\" and cat == correct_cat:\n return 0.3\n return 0.0\n if at == correct_action:\n action_score = 1.0\n elif frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS:\n action_score = 0.5\n else:\n action_score = 0.0\n return round(min(1.0, action_score * 0.7 + cls_credit), 4)\n\ndef _grade_task3(at, cat, correct_cat, correct_action, reply, step,\n classified_correctly=False, steps_taken=2, max_steps=5,\n resolution_hint=\"\"):\n \"\"\"\n Synced with graders.grade_task3 (fix #3 + #4).\n step=0: classify only, returns 0.10 if correct (no free 0.20)\n step=1: full resolution using real classified_correctly flag\n Weights: 0.20 classify + 0.40 action + 0.25 reply + 0.15 efficiency = 1.00\n \"\"\"\n if step == 0:\n return 0.10 if (at == \"classify\" and cat == correct_cat) else 0.0\n score = 0.0\n if classified_correctly:\n score += 0.20\n action_correct = (at == correct_action)\n action_partial = (not action_correct) and (frozenset({at, 
correct_action}) in _PARTIAL_CREDIT_PAIRS)\n if action_correct:\n score += 0.40\n elif action_partial:\n score += 0.20\n score += _reply_quality(reply, cat, resolution_hint)\n resolved = action_correct or action_partial\n if resolved and steps_taken <= max_steps:\n efficiency = max(0.0, (max_steps - steps_taken) / (max_steps - 1))\n score += 0.15 * efficiency\n return round(min(1.0, score), 4)\n\ndef _loop_penalty(step_count, max_steps=10):\n \"\"\"Synced with graders.loop_penalty.\"\"\"\n return -0.05 * (step_count - max_steps) if step_count > max_steps else 0.0\n\n# -----------------------------------------------------------------\n# SMOKE TEST — runs at cell execution, fails loudly if desynced\n# -----------------------------------------------------------------\ndef _smoke_test():\n # fix #2: perfect score = 1.0\n perfect = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\",\n \"refund charge invoice payment billing apologize duplicate\",\n step=1, classified_correctly=True, steps_taken=1, max_steps=5,\n resolution_hint=\"apologize and initiate refund for duplicate charge\")\n assert perfect == 1.0, f\"Perfect score failed: {perfect}\"\n\n # fix #2: cap at 0.25\n rq = _reply_quality(\"refund charge invoice payment billing apologize duplicate initiate\",\n \"billing\", \"apologize and initiate refund for duplicate charge\")\n assert rq == 0.25, f\"Reply cap failed: {rq}\"\n\n # fix #2: punctuation stripping\n rq2 = _reply_quality(\"Refund! Charge. 
Invoice?\", \"billing\", \"\")\n rq3 = _reply_quality(\"refund charge invoice\", \"billing\", \"\")\n assert rq2 == rq3, f\"Punctuation mismatch: {rq2} != {rq3}\"\n\n # fix #3: wrong classify gets no 0.20 bonus\n wrong_cls = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\", \"refund charge\",\n step=1, classified_correctly=False, steps_taken=1, max_steps=5)\n right_cls = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\", \"refund charge\",\n step=1, classified_correctly=True, steps_taken=1, max_steps=5)\n assert right_cls > wrong_cls, f\"Fix #3 failed: {right_cls} not > {wrong_cls}\"\n\n # fix #5: correct classify + correct action > wrong classify + correct action\n t2_good = _grade_task2(\"reply\", \"reply\", 1, \"billing\", \"billing\", cls_credit=0.3)\n t2_bad = _grade_task2(\"reply\", \"reply\", 1, \"billing\", \"billing\", cls_credit=0.0)\n assert t2_good > t2_bad, f\"Fix #5 failed: {t2_good} not > {t2_bad}\"\n assert t2_good == 1.0, f\"Fix #5 max failed: {t2_good}\"\n\n print(\"[SMOKE TEST PASSED] All 4 grader fixes verified in notebook env\")\n\n_smoke_test()\nprint(\"Reward functions ready.\")\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
  ]
498
  },
499
  {