| """Reward / grading module. | |
| Implements the hybrid reward used by the ESC environment: | |
| step_reward = clip(immediate + future_oriented - penalties, 0, 1) | |
| - immediate : stage-appropriate empathy/validation/open-question signal | |
| - future_oriented : RLFF-ESC style lookahead — projects the oracle policy | |
| k steps forward from the *post-action* state and | |
| compares the projected resolution_score against the | |
| pre-action ceiling. Rewards actions that *preserve or | |
| advance* the attainable resolution, not just ones | |
| that look good this turn. | |
| - penalties : dismissive language, premature advice, repetitive | |
| bare replies, interrogation. | |
| This shaping gives the agent dense, varying signal across the trajectory | |
| (required by the rubric: "signal over the full trajectory, not just | |
| binary end-of-episode"). | |
| """ | |
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict

from .seeker import (
    Features,
    SeekerState,
    Stage,
    resolution_score,
    simulate_oracle_rollout,
    stage_progress,
)

# Hyper-parameters, tuned to keep the step reward in [0, 1] under normal play.
LOOKAHEAD_K = 3
W_IMMEDIATE = 0.45
W_FUTURE = 0.55
DISMISSIVE_PENALTY = 0.6
PREMATURE_ADVICE_PENALTY = 0.25
BARE_PENALTY = 0.15
INTERROGATION_PENALTY = 0.15
REPETITION_PENALTY = 0.18
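# Note: W_IMMEDIATE + W_FUTURE = 1.0, so the weighted blend of the two reward
# terms already lies in [0, 1]; only the penalty stack can push a step below
# zero before clipping.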


@dataclass
class GradeBreakdown:
    """Per-step grading result returned by grade_step."""

    value: float
    immediate: float
    future_oriented: float
    penalties: float
    components: Dict[str, float]


def _stage_fit_score(stage: Stage, f: Features) -> float:
    """How appropriate are the agent's features for the current stage?"""
    if stage in (Stage.OPENING, Stage.EXPLORING):
        # Reward empathy + open questions; punish early advice strongly.
        fit = 0.5 * min(1.0, f.empathy) + 0.3 * min(1.0, f.open_question) + 0.2 * min(1.0, f.validation)
        fit -= 0.4 * min(1.0, f.advice)
    elif stage == Stage.REFLECTING:
        fit = 0.5 * min(1.0, f.validation) + 0.4 * min(1.0, f.empathy) + 0.1 * min(1.0, f.open_question)
        fit -= 0.2 * min(1.0, f.advice)
    elif stage == Stage.PLANNING:
        # Advice is finally okay here.
        fit = 0.4 * min(1.0, f.open_question) + 0.3 * min(1.0, f.advice) + 0.3 * min(1.0, f.empathy)
    else:  # CLOSING
        fit = 0.5 * min(1.0, f.empathy) + 0.3 * min(1.0, f.safety) + 0.2 * min(1.0, f.validation)
    return max(0.0, min(1.0, fit))
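
# Illustrative call (hypothetical feature values; the Features fields are
# assumed to be [0, 1] intensity signals): an empathic open question in
# OPENING lands near the top of the range, while bare advice is pushed to 0.
#   f = Features(empathy=0.8, open_question=0.7, validation=0.3,
#                advice=0.0, safety=0.0)
#   _stage_fit_score(Stage.OPENING, f)
#   # 0.5*0.8 + 0.3*0.7 + 0.2*0.3 - 0.4*0.0 = 0.67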


def _immediate_reward(pre_state: SeekerState, post_state: SeekerState, f: Features) -> float:
    """Turn-level reward: stage fit + trust delta + distress delta."""
    stage_fit = _stage_fit_score(pre_state.stage, f)
    trust_delta = max(0.0, post_state.trust - pre_state.trust)
    distress_relief = max(0.0, pre_state.distress - post_state.distress)
    stage_advance = max(
        0.0, stage_progress(post_state.stage) - stage_progress(pre_state.stage)
    )
    reveal_bonus = 0.2 if (post_state.revealed and not pre_state.revealed) else 0.0
    return max(
        0.0,
        min(
            1.0,
            0.45 * stage_fit
            + 0.20 * trust_delta * 2.0  # scale small deltas
            + 0.20 * distress_relief * 2.0
            + 0.10 * stage_advance
            + 0.05 * 1.0  # small baseline for any non-destructive turn
            + reveal_bonus,
        ),
    )
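
# Worked example (illustrative numbers): stage_fit = 0.67, trust up 0.1,
# distress down 0.05, no stage change, no reveal:
#   0.45*0.67 + 0.20*0.1*2.0 + 0.20*0.05*2.0 + 0.0 + 0.05 = 0.4115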


def _future_oriented_reward(pre_state: SeekerState, post_state: SeekerState) -> float:
    """RLFF-ESC style: does this action *preserve / advance* future resolution?

    We roll the oracle policy k steps from both the pre- and post-action states
    and take the (clipped) delta. Positive delta = the action moved the
    attainable future forward; negative = the agent damaged trajectory
    potential and must recover.
    """
    pre_ceiling = simulate_oracle_rollout(pre_state.snapshot(), LOOKAHEAD_K)
    post_ceiling = simulate_oracle_rollout(post_state.snapshot(), LOOKAHEAD_K)
    delta = post_ceiling - pre_ceiling
    # Map delta in roughly [-0.4, +0.4] to [0, 1], with 0.5 at delta = 0.
    return max(0.0, min(1.0, 0.5 + 1.25 * delta))
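
# Worked mapping: delta = +0.4 -> 0.5 + 1.25*0.4 = 1.0 (full credit);
# delta = 0.0 -> 0.5 (neutral: the attainable future was preserved);
# delta = -0.4 -> 0.0 (the action damaged trajectory potential).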


def _penalties(flags: Dict[str, bool], f: Features) -> float:
    """Sum the flat penalties for whichever behaviour flags were raised this turn."""
    p = 0.0
    if flags.get("dismissed"):
        p += DISMISSIVE_PENALTY
    if flags.get("advice_too_early"):
        p += PREMATURE_ADVICE_PENALTY
    if flags.get("bare_reply"):
        p += BARE_PENALTY
    if flags.get("interrogated"):
        p += INTERROGATION_PENALTY
    if flags.get("repetitive"):
        p += REPETITION_PENALTY
    return p
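
# Penalties stack: a reply flagged as both dismissive and repetitive accrues
# 0.6 + 0.18 = 0.78, and the full stack (0.6 + 0.25 + 0.15 + 0.15 + 0.18 =
# 1.33) zeroes out even a perfect turn.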


def grade_step(
    pre_state: SeekerState,
    post_state: SeekerState,
    features: Features,
    flags: Dict[str, bool],
) -> GradeBreakdown:
    """Grade a single agent turn and return the full reward breakdown."""
    imm = _immediate_reward(pre_state, post_state, features)
    fut = _future_oriented_reward(pre_state, post_state)
    pen = _penalties(flags, features)
    combined = W_IMMEDIATE * imm + W_FUTURE * fut - pen
    value = max(0.0, min(1.0, combined))
    components = {
        "stage_fit": _stage_fit_score(pre_state.stage, features),
        "trust_delta": post_state.trust - pre_state.trust,
        "distress_delta": pre_state.distress - post_state.distress,
        "resolution_score_post": resolution_score(post_state),
        "pre_oracle_ceiling": simulate_oracle_rollout(pre_state.snapshot(), LOOKAHEAD_K),
        "post_oracle_ceiling": simulate_oracle_rollout(post_state.snapshot(), LOOKAHEAD_K),
    }
    return GradeBreakdown(
        value=value,
        immediate=imm,
        future_oriented=fut,
        penalties=pen,
        components=components,
    )
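
# Usage sketch (hypothetical driver; the environment supplies the real
# SeekerState snapshots, extracted Features, and behaviour flags):
#   breakdown = grade_step(pre, post, feats, {"bare_reply": True})
#   reward = breakdown.value            # shaped step reward in [0, 1]
#   print(breakdown.components)         # dense per-turn diagnostics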


def final_task_score(
    cumulative_reward: float,
    steps_taken: int,
    max_turns: int,
    final_state: SeekerState,
    success_threshold: float,
    completed: bool,
) -> Dict[str, float]:
    """Compute the final [0, 1] task score used by the grader."""
    # Component 1: average shaped reward over the trajectory (already in [0, 1]).
    avg_reward = cumulative_reward / max(1, steps_taken)
    # Component 2: final resolution_score.
    final_res = resolution_score(final_state)
    # Component 3: efficiency. Finishing sooner is slightly better, but never
    # punitive: flat 1.0 if at most 60% of the turn budget is used, decaying
    # linearly to a floor of 0.7 at the full budget.
    usage = steps_taken / max_turns
    efficiency = 1.0 if usage <= 0.6 else max(0.7, 1.0 - 0.75 * (usage - 0.6))
    completion = 1.0 if completed else 0.0
    score = (
        0.30 * avg_reward
        + 0.45 * final_res
        + 0.10 * efficiency
        + 0.15 * completion
    )
    score = max(0.0, min(1.0, score))
    return {
        "score": score,
        "avg_reward": avg_reward,
        "final_resolution": final_res,
        "efficiency": efficiency,
        "completion": completion,
        "success": 1.0 if (completed and score >= success_threshold) else 0.0,
    }
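
# Worked example (illustrative numbers): avg_reward = 0.6, final_res = 0.8,
# 20 of 25 turns used (usage = 0.8, so efficiency = 1.0 - 0.75*0.2 = 0.85),
# completed:
#   score = 0.30*0.6 + 0.45*0.8 + 0.10*0.85 + 0.15*1.0 = 0.775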