Spaces:
Paused
Paused
| """ | |
Grader for the PayOps environment.

Reward design (v2 — trajectory-based)
--------------------------------------
The grader rewards *correct intermediate reasoning* as well as the final
call, so agents receive a dense learning signal across the full trajectory.

Terminal action credit is split:
    Correct final action                     → +1.00
    Partial-credit adjacent action           → fraction × 1.00
    approve when should be reject/escalate   → −1.00 (worst mistake)
    approve when should be flag/hold         → −0.50
    reject when should be approve            → −0.50
    any other wrong terminal action          → −0.25

Skip-investigation penalty (hard / critical tasks only):
    Applies when the agent issued zero investigation sub-actions on a task
    whose requires_investigation set is non-empty:
        • Wrong terminal action   → credit × 0.50
        • Correct terminal action → credit × 0.80
    Correct actions that skip investigation still earn partial credit,
    but the full reward requires proper investigation first.

Investigation sub-action bonuses (independent and stackable):
    First use of each sub-action listed in task.requires_investigation → +0.15
    Flag identification: agent used inspect on a task that has
    non-empty task.key_flags                                           → +0.20

Duplicate investigation penalty:
    Each repeat of the same sub-action on the same task → −0.05

Modifiers:
    Difficulty weight (multiplies the per-task total):
        easy=1.0, medium=1.2, hard=1.5, critical=2.0
    Confidence (≥0.8) AND correct → +0.10
    Confidence (≥0.8) AND wrong   → −0.10
    Regulatory bonus (file_sar before terminal on regulatory task) → +0.20
    Budget overspend penalty: max(0, spent − limit) × 0.10

Normalised episode score: total / max_possible, clamped to [0.001, 0.999]
(the platform rejects exactly 0.0 and 1.0).
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Optional, Set, Tuple | |
| from payops_env.tasks import ACTION_COSTS, PayOpsTask | |
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# -- Terminal-action credit ---------------------------------------------------
TERMINAL_CORRECT = 1.0  # exact match with task.correct_action
FULL_CREDIT = TERMINAL_CORRECT  # backward-compatible alias

WRONG_APPROVE_FRAUD = -1.0  # approved where reject/escalate was expected
WRONG_APPROVE_CAUTION = -0.5  # approved where flag/hold was expected
WRONG_REJECT_GOOD = -0.5  # rejected where approve was expected
WRONG_DEFAULT = -0.25  # every other wrong terminal action

# -- Trajectory bonuses / penalties -------------------------------------------
INVESTIGATION_BONUS = 0.15  # first use of each eligible investigation sub-action
FLAG_IDENTIFICATION_BONUS = 0.20  # 'inspect' used on a task that has key_flags
TIME_PENALTY_PER_EXTRA_STEP = 0.05  # each repeat of a sub-action on the same task
CONFIDENCE_CORRECT_BONUS = 0.10  # confidence >= 0.8 and the call was correct
CONFIDENCE_WRONG_PENALTY = -0.10  # confidence >= 0.8 and wrong (stored negative)
REGULATORY_BONUS = 0.20  # file_sar used before the terminal on a regulatory task
BUDGET_OVERSPEND_PENALTY = 0.10  # per unit of budget spent beyond the limit

# -- Skip-investigation multipliers -------------------------------------------
# Used on hard/critical tasks that list requires_investigation when the agent
# issued zero investigation sub-actions before its terminal decision.
SKIP_INVESTIGATION_MULTIPLIER = 0.50  # wrong terminal
SKIP_INV_CORRECT_MULTIPLIER = 0.80  # correct terminal

DIFFICULTY_WEIGHT: Dict[str, float] = dict(
    easy=1.0,
    medium=1.2,
    hard=1.5,
    critical=2.0,
)

# Any action in this set is an investigation sub-action; every other action
# is treated by the grader as a terminal decision.
INVESTIGATION_ACTIONS: Set[str] = {
    "inspect",
    "request_docs",
    "verify_kyc",
    "contact_sender",
    "file_sar",
}

# Per-task raw ceiling used for normalisation (before difficulty weighting):
# a correct terminal only. Trajectory bonuses are NOT part of it, so an
# episode's total can exceed the ceiling; grade_episode clamps the ratio.
_MAX_TASK_RAW = TERMINAL_CORRECT
| # --------------------------------------------------------------------------- | |
| # Per-step helpers | |
| # --------------------------------------------------------------------------- | |
def _is_investigation(action_type: str) -> bool:
    """Return True when *action_type* is one of INVESTIGATION_ACTIONS.

    Any other action is handled by the grader as a terminal decision.
    """
    is_sub_action = action_type in INVESTIGATION_ACTIONS
    return is_sub_action
def _base_terminal_reward(action_type: str, task: PayOpsTask) -> float:
    """
    Return the unweighted reward for a terminal action on *task*.

    Precedence:
      1. exact match with ``task.correct_action``      → TERMINAL_CORRECT
      2. key in ``task.partial_credit_actions``        → TERMINAL_CORRECT × its fraction
      3. "approve" where reject/escalate was expected  → WRONG_APPROVE_FRAUD
      4. "approve" where flag/hold was expected        → WRONG_APPROVE_CAUTION
      5. "reject" where approve was expected           → WRONG_REJECT_GOOD
      6. anything else                                 → WRONG_DEFAULT
    """
    expected = task.correct_action
    if action_type == expected:
        return TERMINAL_CORRECT

    partial = task.partial_credit_actions
    if action_type in partial:
        return TERMINAL_CORRECT * partial[action_type]

    if action_type == "approve":
        if expected in ("reject", "escalate"):
            return WRONG_APPROVE_FRAUD
        if expected in ("flag", "hold"):
            return WRONG_APPROVE_CAUTION
        return WRONG_DEFAULT

    if action_type == "reject" and expected == "approve":
        return WRONG_REJECT_GOOD
    return WRONG_DEFAULT
def _apply_skip_investigation_penalty(
    base: float,
    correct: bool,
    task: PayOpsTask,
    investigated: bool,
) -> float:
    """
    Apply the skip-investigation penalty to a terminal reward.

    The penalty only applies when all of the following hold:
      * the task is hard or critical;
      * its ``requires_investigation`` set is non-empty;
      * ``investigated`` is False, i.e. the agent issued zero investigation
        sub-actions before the terminal decision.

    When it applies, positive credit is scaled by SKIP_INV_CORRECT_MULTIPLIER
    for a correct terminal, or by SKIP_INVESTIGATION_MULTIPLIER for a
    partial-credit terminal.

    Negative rewards are returned unchanged. Multiplying them by a factor < 1
    would *shrink* the penalty for a wrong call. An uninvestigated mistake
    would then cost less than an investigated one, which defeats the purpose
    of the penalty.

    Used by both step_reward and _grade_single_task so the two stay consistent.
    """
    requires_inv = getattr(task, "requires_investigation", set())
    if investigated or not requires_inv:
        return base
    if task.difficulty not in ("hard", "critical"):
        return base
    if base <= 0.0:
        return base
    if correct:
        return base * SKIP_INV_CORRECT_MULTIPLIER
    return base * SKIP_INVESTIGATION_MULTIPLIER


def step_reward(
    action_type: str,
    task: PayOpsTask,
    inspected_already: bool = False,
    investigation_done: bool = True,
) -> float:
    """
    Single-step reward used by the real-time environment (step_async).

    Investigation sub-actions earn INVESTIGATION_BONUS, or 0.0 when
    ``inspected_already`` is True. Unlike _grade_single_task, this does not
    check membership in ``task.requires_investigation``. It also does not
    charge the duplicate-use time penalty.

    Any other action is terminal and earns the base terminal reward.
    ``investigation_done`` must be False when the agent issued zero
    investigation sub-actions for this task. In that case the same
    skip-investigation penalty used by grade_episode is applied, so the
    per-step terminal reward matches the episode grader.
    """
    if _is_investigation(action_type):
        return 0.0 if inspected_already else INVESTIGATION_BONUS
    base = _base_terminal_reward(action_type, task)
    return _apply_skip_investigation_penalty(
        base,
        correct=action_type == task.correct_action,
        task=task,
        investigated=investigation_done,
    )


# ---------------------------------------------------------------------------
# Extended per-task grader (used by grade_episode)
# ---------------------------------------------------------------------------
@dataclass
class TaskGradeDetail:
    """
    Per-task grading result produced by _grade_single_task.

    _grade_single_task fills the numeric reward fields rounded to 4 decimals.
    ``total_reward`` is the difficulty-weighted per-task total.
    ``reward_breakdown`` mirrors the components as a plain dict, which
    grade_episode copies into its per-task report.
    """

    task_id: str
    difficulty: str
    weight: float  # DIFFICULTY_WEIGHT for this task's difficulty
    correct_action: str
    terminal_action: str
    investigation_actions_used: List[str]  # sub-actions issued before the terminal
    base_reward: float  # terminal reward after any skip-investigation penalty
    investigation_bonus: float
    flag_id_bonus: float
    time_penalty: float  # positive magnitude; subtracted from the total
    confidence_modifier: float
    regulatory_bonus: float
    total_reward: float  # weight × sum of the components above
    correct: bool  # terminal_action == correct_action (partial credit is not "correct")
    reward_breakdown: Dict[str, float] = field(default_factory=dict)


def _grade_single_task(
    terminal_action: str,
    investigation_actions: List[str],  # sub-actions used BEFORE terminal
    task: PayOpsTask,
    agent_confidence: Optional[float] = None,
) -> TaskGradeDetail:
    """
    Grade one task: the terminal decision plus the investigation trajectory before it.

    Args:
        terminal_action: The non-investigation action that closed the task.
        investigation_actions: Investigation sub-actions issued for this task
            before ``terminal_action``, in order; duplicates are allowed.
        task: The task being graded.
        agent_confidence: Confidence attached to the terminal action, or None.

    Returns:
        A TaskGradeDetail whose ``total_reward`` is
        weight × (base + investigation + flag_id − time_penalty
                  + confidence + regulatory).
    """
    weight = DIFFICULTY_WEIGHT.get(task.difficulty, 1.0)
    correct = terminal_action == task.correct_action

    # ── terminal credit & skip-investigation penalty ─────────────────────────
    base = _apply_skip_investigation_penalty(
        _base_terminal_reward(terminal_action, task),
        correct=correct,
        task=task,
        investigated=bool(investigation_actions),
    )

    # ── investigation trajectory bonus & duplicate-use time penalty ──────────
    # The first use of an eligible sub-action earns INVESTIGATION_BONUS.
    # Every repeat of any sub-action, eligible or not, costs
    # TIME_PENALTY_PER_EXTRA_STEP.
    eligible = getattr(task, "requires_investigation", set())
    inv_bonus = 0.0
    time_pen = 0.0
    seen_counts: Dict[str, int] = {}
    for inv_action in investigation_actions:
        seen_counts[inv_action] = seen_counts.get(inv_action, 0) + 1
        if seen_counts[inv_action] > 1:
            time_pen += TIME_PENALTY_PER_EXTRA_STEP
        elif inv_action in eligible:
            inv_bonus += INVESTIGATION_BONUS

    # ── flag-identification bonus ────────────────────────────────────────────
    # Awarded when the task has key_flags and the agent used 'inspect'.
    # Observation contents are not checked here.
    # NOTE(review): presumably key_flags are always a subset of the task's
    # flags; confirm this in payops_env.tasks.
    key_flags = getattr(task, "key_flags", [])
    flag_id = (
        FLAG_IDENTIFICATION_BONUS
        if key_flags and "inspect" in investigation_actions
        else 0.0
    )

    # ── confidence modifier ──────────────────────────────────────────────────
    conf_mod = 0.0
    if agent_confidence is not None and agent_confidence >= 0.8:
        conf_mod = CONFIDENCE_CORRECT_BONUS if correct else CONFIDENCE_WRONG_PENALTY

    # ── regulatory bonus ─────────────────────────────────────────────────────
    reg_bonus = 0.0
    if getattr(task, "regulatory_action", False) and "file_sar" in investigation_actions:
        reg_bonus = REGULATORY_BONUS

    raw_total = base + inv_bonus + flag_id - time_pen + conf_mod + reg_bonus
    total = weight * raw_total

    return TaskGradeDetail(
        task_id=task.task_id,
        difficulty=task.difficulty,
        weight=weight,
        correct_action=task.correct_action,
        terminal_action=terminal_action,
        # Copy so later mutation of the caller's list cannot alter the record.
        investigation_actions_used=list(investigation_actions),
        base_reward=round(base, 4),
        investigation_bonus=round(inv_bonus, 4),
        flag_id_bonus=round(flag_id, 4),
        time_penalty=round(time_pen, 4),
        confidence_modifier=round(conf_mod, 4),
        regulatory_bonus=round(reg_bonus, 4),
        total_reward=round(total, 4),
        correct=correct,
        reward_breakdown={
            "base": round(base, 4),
            "weight": weight,
            "investigation": round(inv_bonus, 4),
            "flag_id": round(flag_id, 4),
            "time_penalty": round(-time_pen, 4),
            "confidence": round(conf_mod, 4),
            "regulatory": round(reg_bonus, 4),
            "weighted_total": round(total, 4),
        },
    )
| # --------------------------------------------------------------------------- | |
| # Episode grader | |
| # --------------------------------------------------------------------------- | |
@dataclass
class EpisodeResult:
    """Aggregate result of grading a whole episode, as returned by grade_episode."""

    total_reward: float  # sum of weighted per-task rewards minus the budget penalty
    max_possible_reward: float  # sum over tasks of difficulty weight × correct-terminal credit
    normalised_score: float  # total / max_possible, clamped to [0.001, 0.999]
    per_task_rewards: List[dict]  # one report dict per task, in task order
    budget_spent: float
    budget_overspend: float  # max(0, budget_spent − budget_limit)
    budget_penalty: float
    passed: bool  # normalised_score >= 0.5
def grade_episode(
    actions: List[str],
    tasks: List[PayOpsTask],
    confidences: Optional[List[Optional[float]]] = None,
    budget_limit: float = 5.0,
) -> EpisodeResult:
    """
    Grade a complete episode.

    ``actions`` is the flat list of every action taken, with investigation
    sub-actions interspersed between terminal decisions. Each terminal action
    closes the next task in ``tasks``. The investigation sub-actions
    collected since the previous terminal are credited to that task.
    Terminal actions beyond the last task stop grading.

    ``confidences`` is aligned index-for-index with ``actions``. A shorter
    list is padded with ``None`` (no confidence reported), so trailing
    actions are still graded and charged rather than silently dropped.

    Every processed action's cost from ACTION_COSTS counts toward
    ``budget_limit``. Overspend is penalised at BUDGET_OVERSPEND_PENALTY per
    unit. Tasks the agent never reached score 0 and are marked incorrect.

    Returns an EpisodeResult whose normalised_score is clamped to
    [0.001, 0.999].
    """
    # Pad (never truncate) confidences so the zip below visits every action.
    conf_list: List[Optional[float]] = list(confidences) if confidences is not None else []
    if len(conf_list) < len(actions):
        conf_list.extend([None] * (len(actions) - len(conf_list)))

    per_task_details: List[TaskGradeDetail] = []
    budget_spent = 0.0
    task_idx = 0
    pending_inv: List[str] = []

    for action, conf in zip(actions, conf_list):
        budget_spent += ACTION_COSTS.get(action, 0.0)
        if _is_investigation(action):
            pending_inv.append(action)
            continue
        if task_idx >= len(tasks):
            # More terminal decisions than tasks: stop grading here.
            break
        per_task_details.append(
            _grade_single_task(action, pending_inv, tasks[task_idx], agent_confidence=conf)
        )
        pending_inv = []
        task_idx += 1

    # Tasks the agent never reached get a neutral zero. They are not penalised
    # and never count as correct, even when the placeholder "hold" happens to
    # be the task's correct action.
    for task in tasks[task_idx:]:
        detail = _grade_single_task("hold", [], task, agent_confidence=None)
        detail.base_reward = 0.0
        detail.total_reward = 0.0
        detail.correct = False
        # With no investigation and no confidence, the other components are
        # already 0. Only base and the weighted total need zeroing.
        detail.reward_breakdown = {
            **detail.reward_breakdown,
            "base": 0.0,
            "weighted_total": 0.0,
        }
        per_task_details.append(detail)

    # ── budget overspend penalty ─────────────────────────────────────────────
    budget_overspend = max(0.0, budget_spent - budget_limit)
    budget_penalty = round(budget_overspend * BUDGET_OVERSPEND_PENALTY, 4)

    total = sum(d.total_reward for d in per_task_details) - budget_penalty

    # Max possible = sum over tasks of difficulty weight × a correct terminal
    # (_MAX_TASK_RAW). Trajectory bonuses are not included, so a thoroughly
    # investigated episode can exceed it; the clamp below bounds the ratio.
    max_possible = sum(
        DIFFICULTY_WEIGHT.get(t.difficulty, 1.0) * _MAX_TASK_RAW
        for t in tasks
    )

    # Strict open interval (0, 1): the platform rejects exactly 0.0 and 1.0.
    if max_possible > 0:
        normalised = max(0.001, min(0.999, total / max_possible))
    else:
        normalised = 0.001
    # Round once so `passed` agrees with the reported normalised_score.
    normalised = round(normalised, 4)

    # per_task_details always has exactly len(tasks) entries (the loop above
    # fills in unreached tasks), so zip pairs every detail with its task.
    per_task_rewards = [
        {
            "task_id": d.task_id,
            "difficulty": d.difficulty,
            "weight": d.weight,
            "terminal_action": d.terminal_action,
            "correct_action": d.correct_action,
            "investigation_used": d.investigation_actions_used,
            "correct": d.correct,
            "reward_breakdown": d.reward_breakdown,
            "weighted_reward": d.total_reward,
            # Grader config lets platform validators (and server/app.py) find
            # grader definitions per task without a separate API call.
            "grader": t.grader,
        }
        for d, t in zip(per_task_details, tasks)
    ]

    return EpisodeResult(
        total_reward=round(total, 4),
        max_possible_reward=round(max_possible, 4),
        normalised_score=normalised,
        per_task_rewards=per_task_rewards,
        budget_spent=round(budget_spent, 4),
        budget_overspend=round(budget_overspend, 4),
        budget_penalty=budget_penalty,
        passed=normalised >= 0.5,
    )
| # --------------------------------------------------------------------------- | |
| # Convenience wrapper used by the environment | |
| # --------------------------------------------------------------------------- | |
def grade(
    action_type: str,
    task: PayOpsTask,
    inspected_already: bool = False,
    investigation_done: bool = True,
) -> float:
    """
    Return the single-step reward used inside environment.step_async.

    Backward-compatible alias for step_reward; every argument is forwarded
    unchanged.
    """
    reward = step_reward(action_type, task, inspected_already, investigation_done)
    return reward