Vighnesh committed on
Commit
93164f3
·
1 Parent(s): 2e680c9

Sync notebook _local_reward: wire resolution_hint + classified_correctly into task3 reward, cls_credit into task2 step-1

Browse files
Files changed (1) hide show
  1. train_grpo.ipynb +1 -1
train_grpo.ipynb CHANGED
@@ -493,7 +493,7 @@
493
  "metadata": {},
494
  "outputs": [],
495
  "source": [
496
- "# -----------------------------------------------------------------\n# Reward functions β€” synced with graders.py (fixes #2 #3 #4 #5)\n# DO NOT EDIT INLINE β€” keep in sync with graders.py manually.\n# FALLBACK ONLY β€” if graders.py is importable, prefer that instead.\n# -----------------------------------------------------------------\nimport re as _re, json\n\n# Partial-credit action pairs (from graders.py)\n_PARTIAL_CREDIT_PAIRS = {frozenset({\"reply\", \"escalate\"})}\n\n# Broad category keywords β€” 0.03 each (from graders.py)\n_KEYWORD_REWARDS = {\n \"billing\": [\"refund\", \"charge\", \"invoice\", \"payment\", \"billing\"],\n \"account\": [\"password\", \"login\", \"account\", \"cancel\", \"subscription\"],\n \"technical\": [\"engineering\", \"escalate\", \"bug\", \"crash\", \"error\", \"fix\"],\n \"refund\": [\"refund\", \"return\", \"credit\", \"process\"],\n \"general\": [\"hours\", \"contact\", \"phone\", \"information\", \"help\"],\n}\n\ndef _reply_quality(reply_text, category, resolution_hint=\"\"):\n \"\"\"\n Synced with graders._reply_quality (fix #2 + #4).\n Two-tier keyword scoring, case-insensitive, punctuation-stripped:\n category keyword hit -> 0.03 each (broad relevance)\n hint keyword hit -> 0.05 each (specific resolution)\n Cap: 0.25. 
Total grade_task3 weights: 0.20+0.40+0.25+0.15 = 1.00\n \"\"\"\n if not reply_text:\n return 0.0\n cleaned = _re.sub(r\"[^\\w\\s]\", \" \", reply_text.lower())\n category_score = sum(0.03 for kw in _KEYWORD_REWARDS.get(category, []) if kw in cleaned)\n hint_score = 0.0\n if resolution_hint:\n hint_words = set(_re.sub(r\"[^\\w\\s]\", \" \", resolution_hint.lower()).split())\n hint_words = {w for w in hint_words if len(w) > 3}\n hint_score = sum(0.05 for w in hint_words if w in cleaned)\n return round(min(0.25, category_score + hint_score), 4)\n\ndef _grade_task1(at, cat, correct_cat):\n \"\"\"Synced with graders.grade_task1.\"\"\"\n return 1.0 if (at == \"classify\" and cat == correct_cat) else 0.0\n\ndef _grade_task2(at, correct_action, step, cat, correct_cat, cls_credit=0.0):\n \"\"\"\n Synced with graders.grade_task2 + support_environment Task2 (fix #5).\n step=0: classify -> returns 0.3 credit (correct) or 0.0 (wrong)\n step=1: action scaled to 0.7 max + cls_credit, clamped to 1.0\n \"\"\"\n if step == 0:\n if at == \"classify\" and cat == correct_cat:\n return 0.3\n return 0.0\n if at == correct_action:\n action_score = 1.0\n elif frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS:\n action_score = 0.5\n else:\n action_score = 0.0\n return round(min(1.0, action_score * 0.7 + cls_credit), 4)\n\ndef _grade_task3(at, cat, correct_cat, correct_action, reply, step,\n classified_correctly=False, steps_taken=2, max_steps=5,\n resolution_hint=\"\"):\n \"\"\"\n Synced with graders.grade_task3 (fix #3 + #4).\n step=0: classify only, returns 0.10 if correct (no free 0.20)\n step=1: full resolution using real classified_correctly flag\n Weights: 0.20 classify + 0.40 action + 0.25 reply + 0.15 efficiency = 1.00\n \"\"\"\n if step == 0:\n return 0.10 if (at == \"classify\" and cat == correct_cat) else 0.0\n score = 0.0\n if classified_correctly:\n score += 0.20\n action_correct = (at == correct_action)\n action_partial = (not action_correct) and (frozenset({at, 
correct_action}) in _PARTIAL_CREDIT_PAIRS)\n if action_correct:\n score += 0.40\n elif action_partial:\n score += 0.20\n score += _reply_quality(reply, cat, resolution_hint)\n resolved = action_correct or action_partial\n if resolved and steps_taken <= max_steps:\n efficiency = max(0.0, (max_steps - steps_taken) / (max_steps - 1))\n score += 0.15 * efficiency\n return round(min(1.0, score), 4)\n\ndef _loop_penalty(step_count, max_steps=10):\n \"\"\"Synced with graders.loop_penalty.\"\"\"\n return -0.05 * (step_count - max_steps) if step_count > max_steps else 0.0\n\n# -----------------------------------------------------------------\n# SMOKE TEST β€” runs at cell execution, fails loudly if desynced\n# -----------------------------------------------------------------\ndef _smoke_test():\n # fix #2: perfect score = 1.0\n perfect = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\",\n \"refund charge invoice payment billing apologize duplicate\",\n step=1, classified_correctly=True, steps_taken=1, max_steps=5,\n resolution_hint=\"apologize and initiate refund for duplicate charge\")\n assert perfect == 1.0, f\"Perfect score failed: {perfect}\"\n\n # fix #2: cap at 0.25\n rq = _reply_quality(\"refund charge invoice payment billing apologize duplicate initiate\",\n \"billing\", \"apologize and initiate refund for duplicate charge\")\n assert rq == 0.25, f\"Reply cap failed: {rq}\"\n\n # fix #2: punctuation stripping\n rq2 = _reply_quality(\"Refund! Charge. 
Invoice?\", \"billing\", \"\")\n rq3 = _reply_quality(\"refund charge invoice\", \"billing\", \"\")\n assert rq2 == rq3, f\"Punctuation mismatch: {rq2} != {rq3}\"\n\n # fix #3: wrong classify gets no 0.20 bonus\n wrong_cls = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\", \"refund charge\",\n step=1, classified_correctly=False, steps_taken=1, max_steps=5)\n right_cls = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\", \"refund charge\",\n step=1, classified_correctly=True, steps_taken=1, max_steps=5)\n assert right_cls > wrong_cls, f\"Fix #3 failed: {right_cls} not > {wrong_cls}\"\n\n # fix #5: correct classify + correct action > wrong classify + correct action\n t2_good = _grade_task2(\"reply\", \"reply\", 1, \"billing\", \"billing\", cls_credit=0.3)\n t2_bad = _grade_task2(\"reply\", \"reply\", 1, \"billing\", \"billing\", cls_credit=0.0)\n assert t2_good > t2_bad, f\"Fix #5 failed: {t2_good} not > {t2_bad}\"\n assert t2_good == 1.0, f\"Fix #5 max failed: {t2_good}\"\n\n print(\"[SMOKE TEST PASSED] All 4 grader fixes verified in notebook env\")\n\n_smoke_test()\nprint(\"Reward functions ready.\")\n"
497
  ]
498
  },
499
  {
 
493
  "metadata": {},
494
  "outputs": [],
495
  "source": [
496
+ "# -----------------------------------------------------------------\n# Reward functions β€” synced with graders.py (fixes #2 #3 #4 #5)\n# DO NOT EDIT INLINE β€” keep in sync with graders.py manually.\n# FALLBACK ONLY β€” if graders.py is importable, prefer that instead.\n# -----------------------------------------------------------------\nimport re as _re, json\n\n# Partial-credit action pairs (from graders.py)\n_PARTIAL_CREDIT_PAIRS = {frozenset({\"reply\", \"escalate\"})}\n\n# Broad category keywords β€” 0.03 each (from graders.py)\n_KEYWORD_REWARDS = {\n \"billing\": [\"refund\", \"charge\", \"invoice\", \"payment\", \"billing\"],\n \"account\": [\"password\", \"login\", \"account\", \"cancel\", \"subscription\"],\n \"technical\": [\"engineering\", \"escalate\", \"bug\", \"crash\", \"error\", \"fix\"],\n \"refund\": [\"refund\", \"return\", \"credit\", \"process\"],\n \"general\": [\"hours\", \"contact\", \"phone\", \"information\", \"help\"],\n}\n\ndef _reply_quality(reply_text, category, resolution_hint=\"\"):\n \"\"\"\n Synced with graders._reply_quality (fix #2 + #4).\n Two-tier keyword scoring, case-insensitive, punctuation-stripped:\n category keyword hit -> 0.03 each (broad relevance)\n hint keyword hit -> 0.05 each (specific resolution)\n Cap: 0.25. 
Total grade_task3 weights: 0.20+0.40+0.25+0.15 = 1.00\n \"\"\"\n if not reply_text:\n return 0.0\n cleaned = _re.sub(r\"[^\\w\\s]\", \" \", reply_text.lower())\n category_score = sum(0.03 for kw in _KEYWORD_REWARDS.get(category, []) if kw in cleaned)\n hint_score = 0.0\n if resolution_hint:\n hint_words = set(_re.sub(r\"[^\\w\\s]\", \" \", resolution_hint.lower()).split())\n hint_words = {w for w in hint_words if len(w) > 3}\n hint_score = sum(0.05 for w in hint_words if w in cleaned)\n return round(min(0.25, category_score + hint_score), 4)\n\ndef _grade_task1(at, cat, correct_cat):\n \"\"\"Synced with graders.grade_task1.\"\"\"\n return 1.0 if (at == \"classify\" and cat == correct_cat) else 0.0\n\ndef _grade_task2(at, correct_action, step, cat, correct_cat, cls_credit=0.0):\n \"\"\"\n Synced with graders.grade_task2 + support_environment Task2 (fix #5).\n step=0: classify -> returns 0.3 credit (correct) or 0.0 (wrong)\n step=1: action scaled to 0.7 max + cls_credit, clamped to 1.0\n \"\"\"\n if step == 0:\n if at == \"classify\" and cat == correct_cat:\n return 0.3\n return 0.0\n if at == correct_action:\n action_score = 1.0\n elif frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS:\n action_score = 0.5\n else:\n action_score = 0.0\n return round(min(1.0, action_score * 0.7 + cls_credit), 4)\n\ndef _grade_task3(at, cat, correct_cat, correct_action, reply, step,\n classified_correctly=False, steps_taken=2, max_steps=5,\n resolution_hint=\"\"):\n \"\"\"\n Synced with graders.grade_task3 (fix #3 + #4).\n step=0: classify only, returns 0.10 if correct (no free 0.20)\n step=1: full resolution using real classified_correctly flag\n Weights: 0.20 classify + 0.40 action + 0.25 reply + 0.15 efficiency = 1.00\n \"\"\"\n if step == 0:\n return 0.10 if (at == \"classify\" and cat == correct_cat) else 0.0\n score = 0.0\n if classified_correctly:\n score += 0.20\n action_correct = (at == correct_action)\n action_partial = (not action_correct) and (frozenset({at, 
correct_action}) in _PARTIAL_CREDIT_PAIRS)\n if action_correct:\n score += 0.40\n elif action_partial:\n score += 0.20\n score += _reply_quality(reply, cat, resolution_hint)\n resolved = action_correct or action_partial\n if resolved and steps_taken <= max_steps:\n efficiency = max(0.0, (max_steps - steps_taken) / (max_steps - 1))\n score += 0.15 * efficiency\n return round(min(1.0, score), 4)\n\ndef _loop_penalty(step_count, max_steps=10):\n \"\"\"Synced with graders.loop_penalty.\"\"\"\n return -0.05 * (step_count - max_steps) if step_count > max_steps else 0.0\n\n# -----------------------------------------------------------------\n# SMOKE TEST β€” runs at cell execution, fails loudly if desynced\n# -----------------------------------------------------------------\ndef _smoke_test():\n # fix #2: perfect score = 1.0\n perfect = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\",\n \"refund charge invoice payment billing apologize duplicate\",\n step=1, classified_correctly=True, steps_taken=1, max_steps=5,\n resolution_hint=\"apologize and initiate refund for duplicate charge\")\n assert perfect == 1.0, f\"Perfect score failed: {perfect}\"\n\n # fix #2: cap at 0.25\n rq = _reply_quality(\"refund charge invoice payment billing apologize duplicate initiate\",\n \"billing\", \"apologize and initiate refund for duplicate charge\")\n assert rq == 0.25, f\"Reply cap failed: {rq}\"\n\n # fix #2: punctuation stripping\n rq2 = _reply_quality(\"Refund! Charge. 
Invoice?\", \"billing\", \"\")\n rq3 = _reply_quality(\"refund charge invoice\", \"billing\", \"\")\n assert rq2 == rq3, f\"Punctuation mismatch: {rq2} != {rq3}\"\n\n # fix #3: wrong classify gets no 0.20 bonus\n wrong_cls = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\", \"refund charge\",\n step=1, classified_correctly=False, steps_taken=1, max_steps=5)\n right_cls = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\", \"refund charge\",\n step=1, classified_correctly=True, steps_taken=1, max_steps=5)\n assert right_cls > wrong_cls, f\"Fix #3 failed: {right_cls} not > {wrong_cls}\"\n\n # fix #5: correct classify + correct action > wrong classify + correct action\n t2_good = _grade_task2(\"reply\", \"reply\", 1, \"billing\", \"billing\", cls_credit=0.3)\n t2_bad = _grade_task2(\"reply\", \"reply\", 1, \"billing\", \"billing\", cls_credit=0.0)\n assert t2_good > t2_bad, f\"Fix #5 failed: {t2_good} not > {t2_bad}\"\n assert t2_good == 1.0, f\"Fix #5 max failed: {t2_good}\"\n\n print(\"[SMOKE TEST PASSED] All 4 grader fixes verified in notebook env\")\n\n_smoke_test()\nprint(\"Reward functions ready.\")\n\n\ndef _local_reward(completion, task_id, seed, step=0, cls_credit=0.0):\n \"\"\"Full reward using exact graders.py logic. 
No API calls needed.\"\"\"\n ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]\n action = _safe_parse(completion)\n if not isinstance(action, dict):\n action = {'action_type': '', 'category': '', 'reply_text': ''}\n at = action.get('action_type', '')\n cat = action.get('category', '')\n reply = action.get('reply_text', '') or ''\n correct_cat = ticket['category']\n correct_action = ticket['correct_action']\n hint = ticket.get('resolution_hint', '')\n\n if task_id == 1:\n return _grade_task1(at, cat, correct_cat)\n elif task_id == 2:\n return _grade_task2(at, correct_action, step, cat, correct_cat,\n cls_credit=cls_credit)\n else: # task 3\n # step-1 rows are constructed with correct category hardcoded in prompt context\n # (see dataset builder β€” current_category=ticket['category'] always).\n # classified_correctly=True here reflects dataset construction, not agent behaviour.\n # Classification credit (0.20) is awarded for context consistency, not earned accuracy.\n classified_correctly = (step == 1) or (at == \"classify\" and cat == correct_cat)\n return _grade_task3(at, cat, correct_cat, correct_action, reply, step,\n classified_correctly=classified_correctly,\n resolution_hint=hint)\n\n\ndef env_reward_fn(prompts, completions, **kwargs):\n \"\"\"Primary reward: exact graders.py logic, no API calls.\"\"\"\n task_ids = kwargs.get('task_id', [1] * len(completions))\n seeds = kwargs.get('seed', [42] * len(completions))\n steps = kwargs.get('step', [0] * len(completions))\n rewards = []\n for i, completion in enumerate(completions):\n tid = int(task_ids[i]) if hasattr(task_ids, '__getitem__') else 1\n seed = int(seeds[i]) if hasattr(seeds, '__getitem__') else 42\n step = int(steps[i]) if hasattr(steps, '__getitem__') else 0\n actual_seed = seed % 10000 if seed >= 10000 else seed\n # For Task 2 step-1, pass the classification credit earned at step-0.\n # Dataset builder hard-codes correct category at step-1 context,\n # so full classify credit (0.3) always applies for 
task2 step-1.\n cls_credit = 0.3 if (tid == 2 and step == 1) else 0.0\n r = _local_reward(completion, tid, actual_seed, step, cls_credit=cls_credit)\n r += _loop_penalty(step)\n rewards.append(r)\n return rewards\n\n\ndef format_reward_fn(prompts, completions, **kwargs):\n \"\"\"Format bonus/penalty: valid action_type = +0.15/+0.20, invalid = -0.20.\"\"\"\n rewards = []\n for completion in completions:\n action = _safe_parse(completion)\n if not isinstance(action, dict):\n action = {'action_type': '', 'category': '', 'reply_text': ''}\n at = action.get('action_type', '')\n if at in ('classify', 'reply', 'escalate', 'close'):\n bonus = 0.15\n if at == 'classify' and action.get('category') in (\n 'billing', 'technical', 'account', 'general', 'refund'):\n bonus = 0.20\n rewards.append(bonus)\n else:\n rewards.append(-0.20)\n return rewards\n\n\nprint(\"_local_reward, env_reward_fn, format_reward_fn ready.\")\n"
497
  ]
498
  },
499
  {