payops_env / grader.py
padmapriyagosakan's picture
fix: clamp all grader scores to strict (0,1) open interval
71ea0d8
Raw
History Blame Contribute Delete
15.6 kB
"""
Grader for the PayOps environment.
Reward design (v2 — trajectory-based)
--------------------------------------
The grader rewards *correct intermediate reasoning* as well as the final
call, so agents receive a dense learning signal across the full trajectory.
Terminal action credit is now split:
Correct final action → +1.00
Partial-credit adjacent action → fraction × 1.00
approve when should be reject/escalate → −1.00 (worst mistake)
approve when should be flag/hold → −0.50
reject when should be approve → −0.50
any other wrong terminal action → −0.25
Skip-investigation penalty (hard / critical tasks only):
Agent issued zero investigation sub-actions on a task that has
requires_investigation:
• Wrong terminal action → credit × 0.50
• Correct terminal action → credit × 0.80
Correct actions that skip investigation still earn partial credit,
but the full reward requires proper investigation first.
Investigation sub-action bonuses (per eligible, first use only):
Used one of task.requires_investigation → +0.15
Flag identification: agent used inspect on a task with non-empty key_flags → +0.20
(Both bonuses are independent and stackable.)
Duplicate investigation penalty:
Same sub-action on same task more than once → −0.05
Modifiers:
Difficulty weight: easy=1.0, medium=1.2, hard=1.5, critical=2.0
Confidence (≥0.8) AND correct → +0.10
Confidence (≥0.8) AND wrong → −0.10
Regulatory bonus (file_sar before terminal on regulatory task) → +0.20
Budget overspend penalty: −(spent − limit) × 0.10, only when spent exceeds the limit
Normalised episode score: total / max_possible, clamped to [0.001, 0.999] (strictly inside (0, 1)).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple
from payops_env.tasks import ACTION_COSTS, PayOpsTask
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Terminal-action credit. A correct final call earns full credit; each kind
# of wrong call earns a negative amount reflecting how harmful it is
# (see _base_terminal_reward for which case maps to which constant).
TERMINAL_CORRECT = 1.0
FULL_CREDIT = TERMINAL_CORRECT  # backward-compatible alias
WRONG_APPROVE_FRAUD = -1.0  # approve when reject/escalate was expected
WRONG_APPROVE_CAUTION = -0.5  # approve when flag/hold was expected
WRONG_REJECT_GOOD = -0.5  # reject when approve was expected
WRONG_DEFAULT = -0.25  # every other wrong terminal action

# Trajectory shaping.
INVESTIGATION_BONUS = 0.15  # first use of an eligible investigation sub-action
FLAG_IDENTIFICATION_BONUS = 0.20  # 'inspect' used on a task that has key_flags
TIME_PENALTY_PER_EXTRA_STEP = 0.05  # each repeat of the same sub-action on a task
CONFIDENCE_CORRECT_BONUS = 0.10  # confidence >= 0.8 and the call was correct
CONFIDENCE_WRONG_PENALTY = -0.10  # confidence >= 0.8 and the call was wrong
REGULATORY_BONUS = 0.20  # file_sar issued before the terminal on a regulatory task
BUDGET_OVERSPEND_PENALTY = 0.10  # per unit of budget spent above the limit

# Skip-investigation multipliers: used on hard/critical tasks that declare
# requires_investigation when the agent issued no investigation sub-action
# before its terminal call.
SKIP_INVESTIGATION_MULTIPLIER = 0.50  # terminal action was not the correct one
SKIP_INV_CORRECT_MULTIPLIER = 0.80  # terminal action was the correct one

# Reward multiplier per task difficulty (unknown difficulties default to 1.0
# at the call sites).
DIFFICULTY_WEIGHT: Dict[str, float] = dict(
    easy=1.0,
    medium=1.2,
    hard=1.5,
    critical=2.0,
)

# Sub-actions that gather information instead of closing a task.
INVESTIGATION_ACTIONS: Set[str] = set(
    ("inspect", "request_docs", "verify_kyc", "contact_sender", "file_sar")
)

# Per-task normalisation ceiling at weight 1.0: a correct terminal call only.
# Bonuses are not counted, so strong trajectories can push the normalised
# score past 1.0 before it is clamped.
_MAX_TASK_RAW = TERMINAL_CORRECT
# ---------------------------------------------------------------------------
# Per-step helpers
# ---------------------------------------------------------------------------
def _is_investigation(action_type: str) -> bool:
    """Return True when *action_type* is an investigation sub-action (not a terminal call)."""
    is_sub_action = action_type in INVESTIGATION_ACTIONS
    return is_sub_action
def _base_terminal_reward(action_type: str, task: PayOpsTask) -> float:
    """
    Score a terminal decision against the task's expected action.

    No difficulty weight, bonus or penalty is applied here. The cases are
    checked in this order:

    * exact match                     -> TERMINAL_CORRECT
    * listed partial-credit action    -> TERMINAL_CORRECT * its fraction
    * approve vs. reject/escalate     -> WRONG_APPROVE_FRAUD
    * approve vs. flag/hold           -> WRONG_APPROVE_CAUTION
    * reject vs. approve              -> WRONG_REJECT_GOOD
    * anything else                   -> WRONG_DEFAULT
    """
    expected = task.correct_action
    if action_type == expected:
        return TERMINAL_CORRECT

    partial = task.partial_credit_actions
    if action_type in partial:
        return TERMINAL_CORRECT * partial[action_type]

    if action_type == "approve":
        if expected in ("reject", "escalate"):
            return WRONG_APPROVE_FRAUD
        if expected in ("flag", "hold"):
            return WRONG_APPROVE_CAUTION
        return WRONG_DEFAULT

    if action_type == "reject" and expected == "approve":
        return WRONG_REJECT_GOOD
    return WRONG_DEFAULT
def _apply_skip_investigation_penalty(
    base: float,
    correct: bool,
    task: PayOpsTask,
    investigated: bool,
) -> float:
    """
    Reduce positive terminal credit when a task was decided without investigating.

    The penalty applies only when all of these hold:
    * the task declares a non-empty ``requires_investigation``;
    * its difficulty is ``"hard"`` or ``"critical"``;
    * the agent issued no investigation sub-action for it (``investigated`` is False).

    Positive credit is multiplied by ``SKIP_INV_CORRECT_MULTIPLIER`` for the
    correct terminal action and by ``SKIP_INVESTIGATION_MULTIPLIER`` otherwise.

    Zero or negative rewards are returned unchanged. Multiplying a negative
    reward by a factor below 1 would *shrink* the penalty, which would make
    skipping investigation cheaper than investigating and still being wrong.
    """
    if investigated or base <= 0:
        return base
    if not getattr(task, "requires_investigation", set()):
        return base
    if task.difficulty not in ("hard", "critical"):
        return base
    multiplier = SKIP_INV_CORRECT_MULTIPLIER if correct else SKIP_INVESTIGATION_MULTIPLIER
    return base * multiplier


def step_reward(
    action_type: str,
    task: PayOpsTask,
    inspected_already: bool = False,
    investigation_done: bool = True,
) -> float:
    """
    Single-step reward used by the real-time environment (step_async).

    Investigation sub-actions earn ``INVESTIGATION_BONUS`` unless
    ``inspected_already`` is True, in which case they earn 0.0.

    Terminal actions earn the base terminal reward. ``investigation_done``
    must be False when the agent issued zero investigation sub-actions for
    this task. The same skip-investigation penalty used by grade_episode is
    then applied here, so per-step and episode terminal rewards agree.

    NOTE(review): unlike _grade_single_task, this awards the investigation
    bonus to any first-use sub-action, not only those listed in
    ``requires_investigation``. It also returns 0.0 rather than
    -TIME_PENALTY_PER_EXTRA_STEP for repeats.
    """
    if _is_investigation(action_type):
        return 0.0 if inspected_already else INVESTIGATION_BONUS
    base = _base_terminal_reward(action_type, task)
    return _apply_skip_investigation_penalty(
        base,
        action_type == task.correct_action,
        task,
        investigated=investigation_done,
    )
# ---------------------------------------------------------------------------
# Extended per-task grader (used by grade_episode)
# ---------------------------------------------------------------------------
@dataclass
class TaskGradeDetail:
    """
    Per-task grading result.

    ``total_reward`` is the difficulty-weighted sum of the component rewards.
    ``reward_breakdown`` repeats the components for reporting.
    """

    task_id: str
    difficulty: str
    weight: float  # DIFFICULTY_WEIGHT for ``difficulty`` (1.0 if unknown)
    correct_action: str
    terminal_action: str
    investigation_actions_used: List[str]  # sub-actions issued before the terminal call
    base_reward: float  # terminal credit after the skip-investigation penalty
    investigation_bonus: float
    flag_id_bonus: float
    time_penalty: float  # positive magnitude; subtracted from the total
    confidence_modifier: float
    regulatory_bonus: float
    total_reward: float  # weight × (base + bonuses − time_penalty + confidence)
    correct: bool  # terminal_action == correct_action
    reward_breakdown: Dict[str, float] = field(default_factory=dict)
def _grade_single_task(
    terminal_action: str,
    investigation_actions: List[str],  # sub-actions used BEFORE terminal
    task: PayOpsTask,
    agent_confidence: Optional[float] = None,
) -> TaskGradeDetail:
    """
    Grade one task: its terminal call plus the sub-actions issued before it.

    Args:
        terminal_action: The non-investigation action that closed the task.
        investigation_actions: Investigation sub-actions issued for this
            task, in order, before ``terminal_action``.
        task: The task being graded.
        agent_confidence: Confidence attached to the terminal action, if any.

    Returns:
        A TaskGradeDetail whose ``total_reward`` is
        ``weight * (base + investigation + flag_id - time_penalty
        + confidence + regulatory)``. Components are rounded to 4 places.
    """
    weight = DIFFICULTY_WEIGHT.get(task.difficulty, 1.0)
    correct = terminal_action == task.correct_action
    eligible = getattr(task, "requires_investigation", set())

    # ── terminal credit with skip-investigation penalty ─────────────────────
    base = _apply_skip_investigation_penalty(
        _base_terminal_reward(terminal_action, task),
        correct,
        task,
        investigated=bool(investigation_actions),
    )

    # ── investigation trajectory bonus & time penalty ───────────────────────
    # First use of an eligible sub-action earns the bonus; every repeat of any
    # sub-action costs TIME_PENALTY_PER_EXTRA_STEP.
    inv_bonus = 0.0
    time_pen = 0.0
    seen_counts: Dict[str, int] = {}
    for inv_action in investigation_actions:
        seen_counts[inv_action] = seen_counts.get(inv_action, 0) + 1
        if seen_counts[inv_action] > 1:
            time_pen += TIME_PENALTY_PER_EXTRA_STEP
        elif inv_action in eligible:
            inv_bonus += INVESTIGATION_BONUS

    # ── flag-identification bonus ───────────────────────────────────────────
    # Awarded when the task has key_flags and the agent used 'inspect'. The
    # key_flags are assumed to always be present in the episode's flags, so
    # no subset check is made here.
    flag_id = 0.0
    key_flags = getattr(task, "key_flags", [])
    if key_flags and "inspect" in investigation_actions:
        flag_id = FLAG_IDENTIFICATION_BONUS

    # ── confidence modifier ─────────────────────────────────────────────────
    conf_mod = 0.0
    if agent_confidence is not None and agent_confidence >= 0.8:
        conf_mod = CONFIDENCE_CORRECT_BONUS if correct else CONFIDENCE_WRONG_PENALTY

    # ── regulatory bonus ────────────────────────────────────────────────────
    reg_bonus = 0.0
    if getattr(task, "regulatory_action", False) and "file_sar" in investigation_actions:
        reg_bonus = REGULATORY_BONUS

    raw_total = base + inv_bonus + flag_id - time_pen + conf_mod + reg_bonus
    total = weight * raw_total
    return TaskGradeDetail(
        task_id=task.task_id,
        difficulty=task.difficulty,
        weight=weight,
        correct_action=task.correct_action,
        terminal_action=terminal_action,
        # Copy so later mutation of the caller's list cannot alter the report.
        investigation_actions_used=list(investigation_actions),
        base_reward=round(base, 4),
        investigation_bonus=round(inv_bonus, 4),
        flag_id_bonus=round(flag_id, 4),
        time_penalty=round(time_pen, 4),
        confidence_modifier=round(conf_mod, 4),
        regulatory_bonus=round(reg_bonus, 4),
        total_reward=round(total, 4),
        correct=correct,
        reward_breakdown={
            "base": round(base, 4),
            "weight": weight,
            "investigation": round(inv_bonus, 4),
            "flag_id": round(flag_id, 4),
            "time_penalty": round(-time_pen, 4),
            "confidence": round(conf_mod, 4),
            "regulatory": round(reg_bonus, 4),
            "weighted_total": round(total, 4),
        },
    )
# ---------------------------------------------------------------------------
# Episode grader
# ---------------------------------------------------------------------------
@dataclass
class EpisodeResult:
    """Outcome of grading a full episode with ``grade_episode``."""

    total_reward: float  # sum of weighted per-task rewards minus budget_penalty
    max_possible_reward: float  # sum over tasks of difficulty weight × _MAX_TASK_RAW
    normalised_score: float  # clamped to [0.001, 0.999]; never exactly 0.0 or 1.0
    per_task_rewards: List[dict]  # one report dict per task, in task order
    budget_spent: float
    budget_overspend: float  # max(0, budget_spent - budget_limit)
    budget_penalty: float  # budget_overspend × BUDGET_OVERSPEND_PENALTY
    passed: bool # normalised_score >= 0.5
def _unreached_task_detail(task: PayOpsTask) -> TaskGradeDetail:
    """
    Return a neutral, zero-reward grade for a task the agent never reached.

    Every reward component is 0.0 and ``correct`` is False. The per-task report
    therefore agrees with ``total_reward``, even when ``task.correct_action``
    equals the placeholder ``"hold"`` label.
    """
    weight = DIFFICULTY_WEIGHT.get(task.difficulty, 1.0)
    return TaskGradeDetail(
        task_id=task.task_id,
        difficulty=task.difficulty,
        weight=weight,
        correct_action=task.correct_action,
        # Placeholder label kept for backward compatibility with earlier
        # reports; the agent did not actually act on this task.
        terminal_action="hold",
        investigation_actions_used=[],
        base_reward=0.0,
        investigation_bonus=0.0,
        flag_id_bonus=0.0,
        time_penalty=0.0,
        confidence_modifier=0.0,
        regulatory_bonus=0.0,
        total_reward=0.0,
        correct=False,
        reward_breakdown={
            "base": 0.0,
            "weight": weight,
            "investigation": 0.0,
            "flag_id": 0.0,
            "time_penalty": 0.0,
            "confidence": 0.0,
            "regulatory": 0.0,
            "weighted_total": 0.0,
        },
    )


def grade_episode(
    actions: List[str],
    tasks: List[PayOpsTask],
    confidences: Optional[List[Optional[float]]] = None,
    budget_limit: float = 5.0,
) -> EpisodeResult:
    """
    Grade a complete episode.

    ``actions`` is the flat list of all actions taken, including investigation
    sub-actions interspersed between terminal decisions. Each terminal
    (non-investigation) action closes the current task. It is graded together
    with the sub-actions issued since the previous terminal. A terminal action
    arriving after every task has been closed stops grading.

    ``confidences`` is aligned index-by-index with ``actions``; only the value
    attached to a terminal action is used. A missing or shorter list is padded
    with ``None``, so no action is dropped.

    Tasks the agent never reached score a neutral 0.0.

    Returns an EpisodeResult whose normalised_score is clamped to
    [0.001, 0.999], strictly inside the open interval (0, 1).
    """
    # zip() stops at the shorter input, so pad confidences to cover every action.
    padded_conf: List[Optional[float]] = list(confidences or [])
    if len(padded_conf) < len(actions):
        padded_conf.extend([None] * (len(actions) - len(padded_conf)))

    per_task_details: List[TaskGradeDetail] = []
    budget_spent = 0.0
    task_idx = 0
    pending_inv: List[str] = []

    for action, conf in zip(actions, padded_conf):
        # Charged before the range check below, so a surplus terminal action
        # that ends grading is still costed.
        budget_spent += ACTION_COSTS.get(action, 0.0)
        if _is_investigation(action):
            pending_inv.append(action)
            continue
        if task_idx >= len(tasks):
            break
        per_task_details.append(
            _grade_single_task(action, pending_inv, tasks[task_idx], agent_confidence=conf)
        )
        # Rebind rather than clear: the detail just built may reference this list.
        pending_inv = []
        task_idx += 1

    # Tasks the agent never reached get a neutral zero (no severe penalty).
    for task in tasks[task_idx:]:
        per_task_details.append(_unreached_task_detail(task))

    # ── budget overspend penalty ─────────────────────────────────────────────
    budget_overspend = max(0.0, budget_spent - budget_limit)
    budget_penalty = round(budget_overspend * BUDGET_OVERSPEND_PENALTY, 4)
    total = sum(d.total_reward for d in per_task_details) - budget_penalty

    # Max possible = a correct terminal call (_MAX_TASK_RAW) × difficulty weight
    # for each task. Bonuses are excluded, so strong trajectories can exceed
    # 1.0 before the clamp below.
    max_possible = sum(
        DIFFICULTY_WEIGHT.get(t.difficulty, 1.0) * _MAX_TASK_RAW
        for t in tasks
    )

    # Strict open interval (0, 1) — platform rejects exactly 0.0 and 1.0
    if max_possible > 0:
        normalised = max(0.001, min(0.999, total / max_possible))
    else:
        normalised = 0.001

    # per_task_details always has exactly len(tasks) entries: the loop above
    # fills in any tasks the agent never reached, so the zip is safe.
    per_task_rewards = [
        {
            "task_id": d.task_id,
            "difficulty": d.difficulty,
            "weight": d.weight,
            "terminal_action": d.terminal_action,
            "correct_action": d.correct_action,
            "investigation_used": d.investigation_actions_used,
            "correct": d.correct,
            "reward_breakdown": d.reward_breakdown,
            "weighted_reward": d.total_reward,
            # Grader config: lets platform validators (and server/app.py) find
            # grader definitions per task without needing a separate API call.
            "grader": t.grader,
        }
        for d, t in zip(per_task_details, tasks)
    ]
    return EpisodeResult(
        total_reward=round(total, 4),
        max_possible_reward=round(max_possible, 4),
        normalised_score=round(normalised, 4),
        per_task_rewards=per_task_rewards,
        budget_spent=round(budget_spent, 4),
        budget_overspend=round(budget_overspend, 4),
        budget_penalty=budget_penalty,
        passed=normalised >= 0.5,
    )
# ---------------------------------------------------------------------------
# Convenience wrapper used by the environment
# ---------------------------------------------------------------------------
def grade(
    action_type: str,
    task: PayOpsTask,
    inspected_already: bool = False,
    investigation_done: bool = True,
) -> float:
    """
    Single-step reward used inside environment.step_async.

    Thin alias for ``step_reward``: every argument is forwarded as-is and
    its result is returned unchanged.
    """
    flags = dict(
        inspected_already=inspected_already,
        investigation_done=investigation_done,
    )
    return step_reward(action_type, task, **flags)