"""Reward / grading module.
Implements the hybrid reward used by the ESC environment:
step_reward = clip(W_IMMEDIATE * immediate + W_FUTURE * future_oriented - penalties, 0, 1)
- immediate : stage-appropriate empathy/validation/open-question signal
- future_oriented : RLFF-ESC style lookahead — projects the oracle policy
k steps forward from the *post-action* state and
compares the projected resolution_score against the
pre-action ceiling. Rewards actions that *preserve or
advance* the attainable resolution, not just ones
that look good this turn.
- penalties : dismissive language, premature advice, repetitive
bare replies, interrogation.
This shaping gives the agent dense, varying signal across the trajectory
(required by the rubric: "signal over the full trajectory, not just
binary end-of-episode").
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List
from .seeker import (
Features,
SeekerState,
Stage,
resolution_score,
simulate_oracle_rollout,
stage_progress,
)
# Hyper-parameters — tuned to keep step reward in [0, 1] under normal play.
# Step count handed to simulate_oracle_rollout for every lookahead projection.
LOOKAHEAD_K = 3
# Mixing weights grade_step applies to the immediate and future-oriented terms
# before the penalties are subtracted.
W_IMMEDIATE = 0.45
W_FUTURE = 0.55
# Amounts _penalties adds for each flag that is set; they stack additively and
# the sum is subtracted from the weighted reward before the final clip.
DISMISSIVE_PENALTY = 0.6  # flag "dismissed"
PREMATURE_ADVICE_PENALTY = 0.25  # flag "advice_too_early"
BARE_PENALTY = 0.15  # flag "bare_reply"
INTERROGATION_PENALTY = 0.15  # flag "interrogated"
REPETITION_PENALTY = 0.18  # flag "repetitive"
@dataclass
class GradeBreakdown:
    """Itemised result of grading a single step.

    Produced by ``grade_step``: it carries the final clipped reward together
    with the individual terms that went into it, so callers can inspect why
    a step scored the way it did.
    """

    # Final step reward after combining the terms and clipping to [0, 1].
    value: float
    # Turn-level term (before weighting).
    immediate: float
    # Lookahead term (before weighting).
    future_oriented: float
    # Sum of all penalties that were subtracted.
    penalties: float
    # Named diagnostic values recorded alongside the reward.
    components: Dict[str, float]
def _stage_fit_score(stage: Stage, f: Features) -> float:
    """Score how well the agent's reply features suit the conversation stage.

    Every feature is capped at 1.0 before weighting, and the weighted sum is
    clamped to [0, 1] on the way out.

    Args:
        stage: The stage the seeker was in when the agent replied.
        f: Extracted features of the reply (empathy, validation,
            open_question, advice, safety).

    Returns:
        A stage-appropriateness score in [0, 1].
    """

    def capped(signal: float) -> float:
        # Feature values above 1.0 earn no extra credit.
        return min(1.0, signal)

    def clamp_unit(raw: float) -> float:
        return max(0.0, min(1.0, raw))

    if stage in (Stage.OPENING, Stage.EXPLORING):
        # Early stages: empathy and open questions dominate, and giving
        # advice this soon carries the heaviest deduction.
        raw = 0.5 * capped(f.empathy) + 0.3 * capped(f.open_question) + 0.2 * capped(f.validation)
        raw -= 0.4 * capped(f.advice)
        return clamp_unit(raw)

    if stage == Stage.REFLECTING:
        # Validation leads here; advice is still discouraged, but less so.
        raw = 0.5 * capped(f.validation) + 0.4 * capped(f.empathy) + 0.1 * capped(f.open_question)
        raw -= 0.2 * capped(f.advice)
        return clamp_unit(raw)

    if stage == Stage.PLANNING:
        # Advice now contributes positively instead of being deducted.
        raw = 0.4 * capped(f.open_question) + 0.3 * capped(f.advice) + 0.3 * capped(f.empathy)
        return clamp_unit(raw)

    # Any remaining stage is treated as CLOSING.
    raw = 0.5 * capped(f.empathy) + 0.3 * capped(f.safety) + 0.2 * capped(f.validation)
    return clamp_unit(raw)
def _immediate_reward(pre_state: SeekerState, post_state: SeekerState, f: Features) -> float:
    """Compute the turn-level reward for one agent reply.

    Blends stage fit, trust gained, distress relieved and stage advancement
    with a small constant baseline, adds a bonus when the seeker first
    reveals, and clamps the total to [0, 1].

    Args:
        pre_state: Seeker state before the agent's reply.
        post_state: Seeker state after the agent's reply.
        f: Extracted features of the reply.

    Returns:
        The immediate reward in [0, 1].
    """
    fit = _stage_fit_score(pre_state.stage, f)

    # Only improvements count; a worsening contributes zero, not a negative.
    trust_gain = max(0.0, post_state.trust - pre_state.trust)
    distress_drop = max(0.0, pre_state.distress - post_state.distress)
    progress_gain = max(
        0.0, stage_progress(post_state.stage) - stage_progress(pre_state.stage)
    )

    # One-off bonus on the turn where `revealed` flips from falsy to truthy.
    newly_revealed = post_state.revealed and not pre_state.revealed
    bonus = 0.2 if newly_revealed else 0.0

    total = 0.45 * fit
    total += 0.20 * trust_gain * 2.0  # deltas are small, so they are doubled
    total += 0.20 * distress_drop * 2.0
    total += 0.10 * progress_gain
    total += 0.05 * 1.0  # constant baseline term
    total += bonus

    return max(0.0, min(1.0, total))
def _future_oriented_reward(pre_state: SeekerState, post_state: SeekerState) -> float:
    """Score whether the action preserved or advanced the attainable future.

    RLFF-ESC style lookahead: the oracle is rolled out ``LOOKAHEAD_K`` steps
    from a snapshot of the pre-action state and from a snapshot of the
    post-action state, and the difference of the two results is mapped
    affinely onto [0, 1].

    Args:
        pre_state: Seeker state before the agent's reply.
        post_state: Seeker state after the agent's reply.

    Returns:
        0.5 when the action left the rollout result unchanged, above 0.5 when
        it raised it, below 0.5 when it lowered it; always within [0, 1].
    """
    ceiling_before = simulate_oracle_rollout(pre_state.snapshot(), LOOKAHEAD_K)
    ceiling_after = simulate_oracle_rollout(post_state.snapshot(), LOOKAHEAD_K)
    shift = ceiling_after - ceiling_before

    # Affine map centred on 0.5: a shift of -0.4 lands on 0.0 and +0.4 lands
    # on 1.0; anything beyond that range is clipped.
    rescaled = 0.5 + 1.25 * shift
    upper_bounded = min(1.0, rescaled)
    return max(0.0, upper_bounded)
def _penalties(flags: Dict[str, bool], f: Features) -> float:
p = 0.0
if flags.get("dismissed"):
p += DISMISSIVE_PENALTY
if flags.get("advice_too_early"):
p += PREMATURE_ADVICE_PENALTY
if flags.get("bare_reply"):
p += BARE_PENALTY
if flags.get("interrogated"):
p += INTERROGATION_PENALTY
if flags.get("repetitive"):
p += REPETITION_PENALTY
return p
def grade_step(
    pre_state: SeekerState,
    post_state: SeekerState,
    features: Features,
    flags: Dict[str, bool],
) -> GradeBreakdown:
    """Grade one agent turn and return the itemised reward.

    The step reward is
    ``clip(W_IMMEDIATE * immediate + W_FUTURE * future_oriented - penalties, 0, 1)``.

    Args:
        pre_state: Seeker state before the agent's reply.
        post_state: Seeker state after the agent's reply.
        features: Extracted features of the reply.
        flags: Behaviour flags for the reply, consumed by the penalty term.

    Returns:
        A ``GradeBreakdown`` holding the clipped reward, the three unweighted
        terms, and a dict of diagnostic values.
    """
    immediate_term = _immediate_reward(pre_state, post_state, features)
    future_term = _future_oriented_reward(pre_state, post_state)
    penalty_total = _penalties(flags, features)

    blended = W_IMMEDIATE * immediate_term + W_FUTURE * future_term - penalty_total
    clipped = max(0.0, min(1.0, blended))

    # Diagnostics are recomputed here rather than threaded out of the helpers.
    # NOTE(review): this repeats both oracle rollouts and the stage-fit call;
    # confirm the rollout is deterministic so these match the values behind
    # `future_term`.
    diagnostics: Dict[str, float] = {}
    diagnostics["stage_fit"] = _stage_fit_score(pre_state.stage, features)
    # Signed deltas (not floored at zero like the ones inside the reward).
    diagnostics["trust_delta"] = post_state.trust - pre_state.trust
    diagnostics["distress_delta"] = pre_state.distress - post_state.distress
    diagnostics["resolution_score_post"] = resolution_score(post_state)
    diagnostics["pre_oracle_ceiling"] = simulate_oracle_rollout(
        pre_state.snapshot(), LOOKAHEAD_K
    )
    diagnostics["post_oracle_ceiling"] = simulate_oracle_rollout(
        post_state.snapshot(), LOOKAHEAD_K
    )

    return GradeBreakdown(
        value=clipped,
        immediate=immediate_term,
        future_oriented=future_term,
        penalties=penalty_total,
        components=diagnostics,
    )
def final_task_score(
    cumulative_reward: float,
    steps_taken: int,
    max_turns: int,
    final_state: SeekerState,
    success_threshold: float,
    completed: bool,
) -> Dict[str, float]:
    """Compute the final [0, 1] task score used by the grader.

    The score is a weighted blend of four components:
    0.30 * average step reward + 0.45 * final resolution score
    + 0.10 * efficiency + 0.15 * completion, clamped to [0, 1].

    Args:
        cumulative_reward: Sum of the per-step rewards over the episode.
        steps_taken: Number of steps the agent used. Zero is tolerated.
        max_turns: Turn budget for the episode. Values below 1 are treated
            as 1 so a zero budget cannot raise ``ZeroDivisionError``.
        final_state: Seeker state at the end of the episode.
        success_threshold: Minimum score required for ``success``.
        completed: Whether the episode reached completion.

    Returns:
        A dict with keys ``score``, ``avg_reward``, ``final_resolution``,
        ``efficiency``, ``completion`` and ``success`` (1.0 only when the
        episode completed and the score met the threshold, else 0.0).
    """
    # Component 1: average shaped reward over the trajectory. The divisor is
    # floored at 1 so an episode with no steps yields 0 instead of raising.
    avg_reward = cumulative_reward / max(1, steps_taken)

    # Component 2: final resolution_score.
    final_res = resolution_score(final_state)

    # Component 3: efficiency — finishing sooner is slightly better, but never
    # negative. Flat 1.0 if used <= 60% of budget, linearly decays to 0.7 at
    # max. The budget is floored at 1, mirroring the guard on steps_taken.
    usage = steps_taken / max(1, max_turns)
    efficiency = 1.0 if usage <= 0.6 else max(0.7, 1.0 - 0.75 * (usage - 0.6))

    # Component 4: binary completion credit.
    completion = 1.0 if completed else 0.0

    score = (
        0.30 * avg_reward
        + 0.45 * final_res
        + 0.10 * efficiency
        + 0.15 * completion
    )
    score = max(0.0, min(1.0, score))

    return {
        "score": score,
        "avg_reward": avg_reward,
        "final_resolution": final_res,
        "efficiency": efficiency,
        "completion": completion,
        "success": 1.0 if (completed and score >= success_threshold) else 0.0,
    }
|