Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| scripts/validate_rewards.py β Anti-gaming reward validation (Β§4.3). | |
| Three cheating agents, each run against every training ticket. | |
| Asserts each cheater's mean total reward β€ 0.35. | |
| If a cheater exceeds the threshold, prints the offending tickets and why. | |
| Cheaters: | |
| AlwaysEscalate β submit_resolution(escalate=True) on turn 1, no searches. | |
| DumpEverything β fetch every artifact in the corpus, then submit with all | |
| cited (destroys citation precision). | |
| AnswerImmediately β submit_resolution on turn 1, no searches, empty citations. | |
| Usage: | |
| uv run python scripts/validate_rewards.py | |
| """ | |
| import sys | |
| from pathlib import Path | |
| ROOT = Path(__file__).parent.parent | |
| sys.path.insert(0, str(ROOT)) | |
| from server.corpus import Corpus # noqa: E402 | |
| from server import rewards as R # noqa: E402 | |
| THRESHOLD = 0.35 | |
| DATA_DIR = ROOT / "data" | |
| # ------------------------------------------------------------------ # | |
| # Minimal mutable state (mirrors _EpisodeState, no env dependency) # | |
| # ------------------------------------------------------------------ # | |
| class _State: | |
| __slots__ = ( | |
| "episode_id", "step_count", | |
| "target_ticket_id", "gold_resolution", "gold_cited_ids", | |
| "difficulty", "is_unanswerable", | |
| "tools_called", "artifacts_viewed", | |
| "searches_made", "fetches_made", | |
| "submitted", "submitted_resolution", "submitted_citations", | |
| "submitted_confidence", "submitted_escalate", | |
| ) | |
| def __init__(self, ticket: dict): | |
| self.episode_id = ticket.get("ticket_id", "") | |
| self.step_count = 0 | |
| self.target_ticket_id = ticket.get("ticket_id", "") | |
| self.gold_resolution = ticket.get("gold_resolution", "") | |
| self.gold_cited_ids = list(ticket.get("gold_cited_ids", [])) | |
| self.difficulty = ticket.get("difficulty", "medium") | |
| self.is_unanswerable = ticket.get("is_unanswerable", False) | |
| self.tools_called = [] | |
| self.artifacts_viewed = [] | |
| self.searches_made = 0 | |
| self.fetches_made = 0 | |
| self.submitted = False | |
| self.submitted_resolution = None | |
| self.submitted_citations = [] | |
| self.submitted_confidence = 0.5 | |
| self.submitted_escalate = False | |
| # ------------------------------------------------------------------ # | |
| # Cheater strategies # | |
| # ------------------------------------------------------------------ # | |
| def always_escalate(ticket: dict, corpus: Corpus) -> dict: | |
| """Turn 1: submit_resolution(escalate=True). No artifact access.""" | |
| s = _State(ticket) | |
| s.step_count = 1 | |
| s.submitted = True | |
| s.submitted_resolution = "" | |
| s.submitted_citations = [] | |
| s.submitted_confidence = 0.5 | |
| s.submitted_escalate = True | |
| return R.reward_breakdown(s) | |
| def dump_everything(ticket: dict, corpus: Corpus) -> dict: | |
| """Fetch every artifact in the corpus, cite all of them, submit gold resolution. | |
| Maximally adversarial: resolution quality is perfect but citation precision | |
| is destroyed (~1-4 gold IDs out of 235 total β F1 < 0.05).""" | |
| s = _State(ticket) | |
| all_ids: list = [] | |
| for a in corpus._kb: | |
| aid = a.get("article_id", "") | |
| if aid: | |
| all_ids.append(aid) | |
| s.fetches_made += 1 | |
| for t in corpus._tickets: | |
| tid = t.get("ticket_id", "") | |
| if tid: | |
| all_ids.append(tid) | |
| s.fetches_made += 1 | |
| for inc in corpus._incidents: | |
| iid = inc.get("incident_id", "") | |
| if iid: | |
| all_ids.append(iid) | |
| s.fetches_made += 1 | |
| s.artifacts_viewed = list(all_ids) | |
| s.step_count = s.fetches_made | |
| s.submitted = True | |
| # Use gold resolution so correctness is maximal β the only thing that | |
| # can defeat this cheater is citation precision. | |
| s.submitted_resolution = ticket.get("gold_resolution", "resolved") | |
| s.submitted_citations = list(all_ids) # cite everything β precision β 0 | |
| s.submitted_confidence = 0.5 | |
| s.submitted_escalate = False | |
| return R.reward_breakdown(s) | |
| def answer_immediately(ticket: dict, corpus: Corpus) -> dict: | |
| """Turn 1: submit with a generic resolution, no searches, empty citations.""" | |
| s = _State(ticket) | |
| s.step_count = 1 | |
| s.submitted = True | |
| s.submitted_resolution = "The issue has been resolved." | |
| s.submitted_citations = [] | |
| s.submitted_confidence = 0.5 | |
| s.submitted_escalate = False | |
| return R.reward_breakdown(s) | |
| # ------------------------------------------------------------------ # | |
| # Runner # | |
| # ------------------------------------------------------------------ # | |
| CHEATERS = [ | |
| ("AlwaysEscalate", always_escalate), | |
| ("DumpEverything", dump_everything), | |
| ("AnswerImmediately", answer_immediately), | |
| ] | |
| SUB_KEYS = ["primary", "grounding", "efficiency", "calibration", "format", "total"] | |
| def run_cheater(name: str, fn, tickets: list, corpus: Corpus) -> dict: | |
| per_ticket = [(t, fn(t, corpus)) for t in tickets] | |
| n = len(per_ticket) | |
| means = {k: sum(r[k] for _, r in per_ticket) / n for k in SUB_KEYS} if n else {k: 0.0 for k in SUB_KEYS} | |
| over = [(t, r) for t, r in per_ticket if r["total"] > THRESHOLD] | |
| return {"name": name, "means": means, "n": n, "over_threshold": over} | |
| def _why(ticket: dict, breakdown: dict) -> str: | |
| parts = [] | |
| if breakdown["primary"] > 0: | |
| if ticket.get("is_unanswerable"): | |
| parts.append("correct escalation on unanswerable ticket") | |
| else: | |
| parts.append(f"primary={breakdown['primary']:.2f}") | |
| if breakdown["grounding"] > 0.05: | |
| parts.append(f"grounding={breakdown['grounding']:.3f}") | |
| if breakdown["efficiency"] > 0.05: | |
| parts.append(f"efficiency={breakdown['efficiency']:.3f}") | |
| if breakdown["calibration"] > 0.05: | |
| parts.append(f"calibration={breakdown['calibration']:.3f}") | |
| return ", ".join(parts) if parts else "sub-rewards accumulated" | |
| def print_result(result: dict) -> bool: | |
| name = result["name"] | |
| means = result["means"] | |
| n = result["n"] | |
| over = result["over_threshold"] | |
| passed = means["total"] <= THRESHOLD | |
| print(f"\n{'β'*62}") | |
| print(f" {name:<22} n={n} threshold={THRESHOLD}") | |
| print(f"{'β'*62}") | |
| for k in SUB_KEYS[:-1]: | |
| print(f" {k:<14}: {means[k]:.4f}") | |
| mark = "β PASS" if passed else "β FAIL" | |
| print(f" {'TOTAL':<14}: {means['total']:.4f} {mark}") | |
| if not passed: | |
| print(f"\n Tickets driving the mean above {THRESHOLD} ({len(over)} total):") | |
| for t, r in over[:30]: | |
| tid = t.get("ticket_id", "?") | |
| diff = t.get("difficulty", "?") | |
| tag = " [unanswerable]" if t.get("is_unanswerable") else "" | |
| why = _why(t, r) | |
| print(f" {tid} [{diff}]{tag} total={r['total']:.3f} β {why}") | |
| if len(over) > 30: | |
| print(f" ... and {len(over) - 30} more") | |
| return passed | |
| def main() -> None: | |
| corpus = Corpus(DATA_DIR) | |
| tickets = corpus.train_tickets | |
| if not tickets: | |
| print("ERROR: No training tickets found. Add data/train_tickets.json first.") | |
| sys.exit(1) | |
| n_unans = sum(1 for t in tickets if t.get("is_unanswerable")) | |
| print(f"Corpus loaded: {len(corpus._kb)} KB articles, " | |
| f"{len(corpus._tickets)} past tickets, " | |
| f"{len(corpus._incidents)} incidents") | |
| print(f"Training tickets: {len(tickets)} total, {n_unans} unanswerable " | |
| f"({100*n_unans/len(tickets):.0f}%)") | |
| print(f"\nRunning anti-gaming validation β threshold: mean β€ {THRESHOLD}") | |
| results = [run_cheater(name, fn, tickets, corpus) for name, fn in CHEATERS] | |
| all_pass = True | |
| for r in results: | |
| if not print_result(r): | |
| all_pass = False | |
| print(f"\n{'β'*62}") | |
| if all_pass: | |
| print(" ALL CHEATERS BELOW THRESHOLD β reward design is anti-gaming.") | |
| print(f"{'β'*62}") | |
| else: | |
| print(" SOME CHEATERS EXCEEDED THRESHOLD.") | |
| print(" Adjust reward weights in server/rewards.py before training.") | |
| print(f"{'β'*62}") | |
| sys.exit(1) | |
| if __name__ == "__main__": | |
| main() | |