Spaces:
Sleeping
Sleeping
Vighnesh commited on
Commit ·
2e680c9
1
Parent(s): 55ff252
Sync notebook Cell 7 graders with graders.py fixes #2 #3 #4 #5 — smoke test passes
Browse files- train_grpo.ipynb +1 -136
train_grpo.ipynb
CHANGED
|
@@ -493,142 +493,7 @@
|
|
| 493 |
"metadata": {},
|
| 494 |
"outputs": [],
|
| 495 |
"source": [
|
| 496 |
-
"#
|
| 497 |
-
"# Reward functions — exact mirror of graders.py\n",
|
| 498 |
-
"# grade_task1 / grade_task2 / grade_task3 / loop_penalty\n",
|
| 499 |
-
"# ─────────────────────────────────────────────────────────────────\n",
|
| 500 |
-
"\n",
|
| 501 |
-
"# Partial-credit action pairs (from graders.py)\n",
|
| 502 |
-
"_PARTIAL_CREDIT_PAIRS = {frozenset({'reply', 'escalate'})}\n",
|
| 503 |
-
"\n",
|
| 504 |
-
"# Keyword lists (from graders.py)\n",
|
| 505 |
-
"_KEYWORD_REWARDS = {\n",
|
| 506 |
-
" 'billing': ['refund', 'charge', 'invoice', 'payment', 'billing'],\n",
|
| 507 |
-
" 'account': ['password', 'login', 'account', 'cancel', 'subscription'],\n",
|
| 508 |
-
" 'technical': ['engineering', 'escalate', 'bug', 'crash', 'error', 'fix'],\n",
|
| 509 |
-
" 'refund': ['refund', 'return', 'credit', 'process'],\n",
|
| 510 |
-
" 'general': ['hours', 'contact', 'phone', 'information', 'help'],\n",
|
| 511 |
-
"}\n",
|
| 512 |
-
"\n",
|
| 513 |
-
"def _reply_quality(reply_text, category):\n",
|
| 514 |
-
" \"\"\"Exact copy of graders._reply_quality: 0.0–0.5 keyword score.\"\"\"\n",
|
| 515 |
-
" if not reply_text: return 0.0\n",
|
| 516 |
-
" hits = sum(1 for kw in _KEYWORD_REWARDS.get(category, []) if kw in reply_text.lower())\n",
|
| 517 |
-
" return min(0.5, hits * 0.1)\n",
|
| 518 |
-
"\n",
|
| 519 |
-
"def _grade_task1(at, cat, correct_cat):\n",
|
| 520 |
-
" \"\"\"Exact copy of graders.grade_task1.\"\"\"\n",
|
| 521 |
-
" return 1.0 if (at == 'classify' and cat == correct_cat) else 0.0\n",
|
| 522 |
-
"\n",
|
| 523 |
-
"def _grade_task2(at, correct_action, step, cat, correct_cat):\n",
|
| 524 |
-
" \"\"\"Exact copy of graders.grade_task2 + classify step.\"\"\"\n",
|
| 525 |
-
" if step == 0:\n",
|
| 526 |
-
" # classify step: partial credit for correct category\n",
|
| 527 |
-
" if at == 'classify' and cat == correct_cat: return 0.3\n",
|
| 528 |
-
" if at == 'classify': return 0.1\n",
|
| 529 |
-
" return 0.0\n",
|
| 530 |
-
" # action step\n",
|
| 531 |
-
" if at == correct_action: return 1.0\n",
|
| 532 |
-
" if frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS: return 0.5\n",
|
| 533 |
-
" if at == 'close': return 0.0\n",
|
| 534 |
-
" return 0.0\n",
|
| 535 |
-
"\n",
|
| 536 |
-
"def _grade_task3(at, cat, correct_cat, correct_action, reply, step, steps_taken=2, max_steps=5):\n",
|
| 537 |
-
" \"\"\"Exact copy of graders.grade_task3.\"\"\"\n",
|
| 538 |
-
" if step == 0:\n",
|
| 539 |
-
" # classification step only\n",
|
| 540 |
-
" return 0.20 if (at == 'classify' and cat == correct_cat) else 0.0\n",
|
| 541 |
-
" # resolution step: 0.40 action + up to 0.50 reply + 0.15 efficiency\n",
|
| 542 |
-
" score = 0.0\n",
|
| 543 |
-
" classified_correctly = True # step-1 means step-0 already happened\n",
|
| 544 |
-
" score += 0.20 # classification credit carried from step 0\n",
|
| 545 |
-
" action_correct = (at == correct_action)\n",
|
| 546 |
-
" action_partial = (frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS)\n",
|
| 547 |
-
" if action_correct: score += 0.40\n",
|
| 548 |
-
" elif action_partial: score += 0.20\n",
|
| 549 |
-
" score += _reply_quality(reply, cat) # 0.0–0.5\n",
|
| 550 |
-
" # efficiency bonus (assume 2 steps taken for step-1 samples)\n",
|
| 551 |
-
" resolved = action_correct or action_partial\n",
|
| 552 |
-
" if resolved and steps_taken <= max_steps:\n",
|
| 553 |
-
" efficiency = max(0.0, (max_steps - steps_taken) / (max_steps - 1))\n",
|
| 554 |
-
" score += 0.15 * efficiency\n",
|
| 555 |
-
" return round(min(1.0, score), 4)\n",
|
| 556 |
-
"\n",
|
| 557 |
-
"def _loop_penalty(step_count, max_steps=10):\n",
|
| 558 |
-
" \"\"\"Exact copy of graders.loop_penalty.\"\"\"\n",
|
| 559 |
-
" return -0.05 * (step_count - max_steps) if step_count > max_steps else 0.0\n",
|
| 560 |
-
"\n",
|
| 561 |
-
"def _local_reward(completion, task_id, seed, step=0):\n",
|
| 562 |
-
" \"\"\"Full reward using exact graders.py logic. No API calls needed.\"\"\"\n",
|
| 563 |
-
" ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]\n",
|
| 564 |
-
" action = _safe_parse(completion)\n",
|
| 565 |
-
" if not isinstance(action, dict): action = {'action_type': '', 'category': '', 'reply_text': ''}\n",
|
| 566 |
-
" at = action.get('action_type', '')\n",
|
| 567 |
-
" cat = action.get('category', '')\n",
|
| 568 |
-
" reply = action.get('reply_text', '') or ''\n",
|
| 569 |
-
" correct_cat = ticket['category']\n",
|
| 570 |
-
" correct_action = ticket['correct_action']\n",
|
| 571 |
-
"\n",
|
| 572 |
-
" if task_id == 1:\n",
|
| 573 |
-
" return _grade_task1(at, cat, correct_cat)\n",
|
| 574 |
-
" elif task_id == 2:\n",
|
| 575 |
-
" return _grade_task2(at, correct_action, step, cat, correct_cat)\n",
|
| 576 |
-
" else: # task 3\n",
|
| 577 |
-
" return _grade_task3(at, cat, correct_cat, correct_action, reply, step)\n",
|
| 578 |
-
"\n",
|
| 579 |
-
"def env_reward_fn(prompts, completions, **kwargs):\n",
|
| 580 |
-
" \"\"\"Primary reward: exact graders.py logic, no API calls.\"\"\"\n",
|
| 581 |
-
" task_ids = kwargs.get('task_id', [1] * len(completions))\n",
|
| 582 |
-
" seeds = kwargs.get('seed', [42] * len(completions))\n",
|
| 583 |
-
" steps = kwargs.get('step', [0] * len(completions))\n",
|
| 584 |
-
" rewards = []\n",
|
| 585 |
-
" for i, completion in enumerate(completions):\n",
|
| 586 |
-
" tid = int(task_ids[i]) if hasattr(task_ids, '__getitem__') else 1\n",
|
| 587 |
-
" seed = int(seeds[i]) if hasattr(seeds, '__getitem__') else 42\n",
|
| 588 |
-
" step = int(steps[i]) if hasattr(steps, '__getitem__') else 0\n",
|
| 589 |
-
" actual_seed = seed % 10000 if seed >= 10000 else seed\n",
|
| 590 |
-
" r = _local_reward(completion, tid, actual_seed, step)\n",
|
| 591 |
-
" # apply loop penalty if step is high\n",
|
| 592 |
-
" r += _loop_penalty(step)\n",
|
| 593 |
-
" rewards.append(r)\n",
|
| 594 |
-
" return rewards\n",
|
| 595 |
-
"\n",
|
| 596 |
-
"def format_reward_fn(prompts, completions, **kwargs):\n",
|
| 597 |
-
" \"\"\"Format bonus/penalty: valid action_type = +0.15/+0.20, invalid = -0.20.\"\"\"\n",
|
| 598 |
-
" rewards = []\n",
|
| 599 |
-
" for completion in completions:\n",
|
| 600 |
-
" action = _safe_parse(completion)\n",
|
| 601 |
-
" if not isinstance(action, dict): action = {'action_type': '', 'category': '', 'reply_text': ''}\n",
|
| 602 |
-
" at = action.get('action_type', '')\n",
|
| 603 |
-
" if at in ('classify', 'reply', 'escalate', 'close'):\n",
|
| 604 |
-
" bonus = 0.15\n",
|
| 605 |
-
" if at == 'classify' and action.get('category') in ('billing','technical','account','general','refund'):\n",
|
| 606 |
-
" bonus = 0.20\n",
|
| 607 |
-
" rewards.append(bonus)\n",
|
| 608 |
-
" else:\n",
|
| 609 |
-
" rewards.append(-0.20)\n",
|
| 610 |
-
" return rewards\n",
|
| 611 |
-
"\n",
|
| 612 |
-
"# Print ticket map\n",
|
| 613 |
-
"print('Reward functions synced to graders.py')\n",
|
| 614 |
-
"print('Ticket map (seed % len):')\n",
|
| 615 |
-
"for _i in range(6):\n",
|
| 616 |
-
" _tt = ALL_TICKETS[_i]\n",
|
| 617 |
-
" print(f' [{_i}] {_tt[\"id\"]} cat={_tt[\"category\"]} action={_tt[\"correct_action\"]}')\n",
|
| 618 |
-
"\n",
|
| 619 |
-
"# Sanity: seed=0->B001(billing,reply), seed=22->T001(technical,escalate)\n",
|
| 620 |
-
"_t0 = ALL_TICKETS[0] # B001 billing reply\n",
|
| 621 |
-
"_t22 = ALL_TICKETS[22] # T001 technical escalate\n",
|
| 622 |
-
"r1 = _local_reward(json.dumps({'action_type':'classify','category':_t0['category']}), 1, 0, 0)\n",
|
| 623 |
-
"r2 = _local_reward(json.dumps({'action_type':'classify','category':_t0['category']}), 2, 0, 0)\n",
|
| 624 |
-
"r3 = _local_reward(json.dumps({'action_type':'escalate'}), 2, 0, 1)\n",
|
| 625 |
-
"r4 = _local_reward(json.dumps({'action_type':_t22['correct_action'],'reply_text':'escalating this crash bug error to engineering team for a fix'}), 3, 22, 1)\n",
|
| 626 |
-
"r5 = format_reward_fn(prompts=['x'], completions=[json.dumps({'action_type':'respond'})])[0]\n",
|
| 627 |
-
"print(f'task1 correct classify: {r1} (expect 1.0)')\n",
|
| 628 |
-
"print(f'task2 step0 correct classify: {r2} (expect 0.3)')\n",
|
| 629 |
-
"print(f'task2 step1 partial escalate: {r3} (expect 0.5)')\n",
|
| 630 |
-
"print(f'task3 step1 correct+keywords: {r4} (expect 0.87+)')\n",
|
| 631 |
-
"print(f'bad format penalty: {r5} (expect -0.2)')\n"
|
| 632 |
]
|
| 633 |
},
|
| 634 |
{
|
|
|
|
| 493 |
"metadata": {},
|
| 494 |
"outputs": [],
|
| 495 |
"source": [
|
| 496 |
+
"# -----------------------------------------------------------------\n# Reward functions — synced with graders.py (fixes #2 #3 #4 #5)\n# DO NOT EDIT INLINE — keep in sync with graders.py manually.\n# FALLBACK ONLY — if graders.py is importable, prefer that instead.\n# -----------------------------------------------------------------\nimport re as _re, json\n\n# Partial-credit action pairs (from graders.py)\n_PARTIAL_CREDIT_PAIRS = {frozenset({\"reply\", \"escalate\"})}\n\n# Broad category keywords — 0.03 each (from graders.py)\n_KEYWORD_REWARDS = {\n \"billing\": [\"refund\", \"charge\", \"invoice\", \"payment\", \"billing\"],\n \"account\": [\"password\", \"login\", \"account\", \"cancel\", \"subscription\"],\n \"technical\": [\"engineering\", \"escalate\", \"bug\", \"crash\", \"error\", \"fix\"],\n \"refund\": [\"refund\", \"return\", \"credit\", \"process\"],\n \"general\": [\"hours\", \"contact\", \"phone\", \"information\", \"help\"],\n}\n\ndef _reply_quality(reply_text, category, resolution_hint=\"\"):\n \"\"\"\n Synced with graders._reply_quality (fix #2 + #4).\n Two-tier keyword scoring, case-insensitive, punctuation-stripped:\n category keyword hit -> 0.03 each (broad relevance)\n hint keyword hit -> 0.05 each (specific resolution)\n Cap: 0.25. Total grade_task3 weights: 0.20+0.40+0.25+0.15 = 1.00\n \"\"\"\n if not reply_text:\n return 0.0\n cleaned = _re.sub(r\"[^\\w\\s]\", \" \", reply_text.lower())\n category_score = sum(0.03 for kw in _KEYWORD_REWARDS.get(category, []) if kw in cleaned)\n hint_score = 0.0\n if resolution_hint:\n hint_words = set(_re.sub(r\"[^\\w\\s]\", \" \", resolution_hint.lower()).split())\n hint_words = {w for w in hint_words if len(w) > 3}\n hint_score = sum(0.05 for w in hint_words if w in cleaned)\n return round(min(0.25, category_score + hint_score), 4)\n\ndef _grade_task1(at, cat, correct_cat):\n \"\"\"Synced with graders.grade_task1.\"\"\"\n return 1.0 if (at == \"classify\" and cat == correct_cat) else 0.0\n\ndef _grade_task2(at, correct_action, step, cat, correct_cat, cls_credit=0.0):\n \"\"\"\n Synced with graders.grade_task2 + support_environment Task2 (fix #5).\n step=0: classify -> returns 0.3 credit (correct) or 0.0 (wrong)\n step=1: action scaled to 0.7 max + cls_credit, clamped to 1.0\n \"\"\"\n if step == 0:\n if at == \"classify\" and cat == correct_cat:\n return 0.3\n return 0.0\n if at == correct_action:\n action_score = 1.0\n elif frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS:\n action_score = 0.5\n else:\n action_score = 0.0\n return round(min(1.0, action_score * 0.7 + cls_credit), 4)\n\ndef _grade_task3(at, cat, correct_cat, correct_action, reply, step,\n classified_correctly=False, steps_taken=2, max_steps=5,\n resolution_hint=\"\"):\n \"\"\"\n Synced with graders.grade_task3 (fix #3 + #4).\n step=0: classify only, returns 0.10 if correct (no free 0.20)\n step=1: full resolution using real classified_correctly flag\n Weights: 0.20 classify + 0.40 action + 0.25 reply + 0.15 efficiency = 1.00\n \"\"\"\n if step == 0:\n return 0.10 if (at == \"classify\" and cat == correct_cat) else 0.0\n score = 0.0\n if classified_correctly:\n score += 0.20\n action_correct = (at == correct_action)\n action_partial = (not action_correct) and (frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS)\n if action_correct:\n score += 0.40\n elif action_partial:\n score += 0.20\n score += _reply_quality(reply, cat, resolution_hint)\n resolved = action_correct or action_partial\n if resolved and steps_taken <= max_steps:\n efficiency = max(0.0, (max_steps - steps_taken) / (max_steps - 1))\n score += 0.15 * efficiency\n return round(min(1.0, score), 4)\n\ndef _loop_penalty(step_count, max_steps=10):\n \"\"\"Synced with graders.loop_penalty.\"\"\"\n return -0.05 * (step_count - max_steps) if step_count > max_steps else 0.0\n\n# -----------------------------------------------------------------\n# SMOKE TEST — runs at cell execution, fails loudly if desynced\n# -----------------------------------------------------------------\ndef _smoke_test():\n # fix #2: perfect score = 1.0\n perfect = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\",\n \"refund charge invoice payment billing apologize duplicate\",\n step=1, classified_correctly=True, steps_taken=1, max_steps=5,\n resolution_hint=\"apologize and initiate refund for duplicate charge\")\n assert perfect == 1.0, f\"Perfect score failed: {perfect}\"\n\n # fix #2: cap at 0.25\n rq = _reply_quality(\"refund charge invoice payment billing apologize duplicate initiate\",\n \"billing\", \"apologize and initiate refund for duplicate charge\")\n assert rq == 0.25, f\"Reply cap failed: {rq}\"\n\n # fix #2: punctuation stripping\n rq2 = _reply_quality(\"Refund! Charge. Invoice?\", \"billing\", \"\")\n rq3 = _reply_quality(\"refund charge invoice\", \"billing\", \"\")\n assert rq2 == rq3, f\"Punctuation mismatch: {rq2} != {rq3}\"\n\n # fix #3: wrong classify gets no 0.20 bonus\n wrong_cls = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\", \"refund charge\",\n step=1, classified_correctly=False, steps_taken=1, max_steps=5)\n right_cls = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\", \"refund charge\",\n step=1, classified_correctly=True, steps_taken=1, max_steps=5)\n assert right_cls > wrong_cls, f\"Fix #3 failed: {right_cls} not > {wrong_cls}\"\n\n # fix #5: correct classify + correct action > wrong classify + correct action\n t2_good = _grade_task2(\"reply\", \"reply\", 1, \"billing\", \"billing\", cls_credit=0.3)\n t2_bad = _grade_task2(\"reply\", \"reply\", 1, \"billing\", \"billing\", cls_credit=0.0)\n assert t2_good > t2_bad, f\"Fix #5 failed: {t2_good} not > {t2_bad}\"\n assert t2_good == 1.0, f\"Fix #5 max failed: {t2_good}\"\n\n print(\"[SMOKE TEST PASSED] All 4 grader fixes verified in notebook env\")\n\n_smoke_test()\nprint(\"Reward functions ready.\")\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
]
|
| 498 |
},
|
| 499 |
{
|