ar9avg's picture
Nuclear clamp: every reward source in the codebase now returns (0.05, 0.95)
719c147
"""
Shaped reward function for the SQL debug RL environment.
All rewards are clamped to [0.05, 0.95], which lies strictly inside (0, 1), before being returned.
"""
from __future__ import annotations
from typing import Optional
from dataclasses import dataclass
from rl.types import ErrorClass
from rl.error_classifier import error_severity
# Inclusive bounds applied to every reward this module returns.
_SCORE_MIN = 0.05
_SCORE_MAX = 0.95


def _clamp(x: float) -> float:
    """Restrict a score to the closed interval [_SCORE_MIN, _SCORE_MAX].

    Args:
        x: Raw score. ``None`` and NaN are accepted and map to the floor.

    Returns:
        ``float(x)`` limited to the interval, or ``_SCORE_MIN`` when ``x``
        is ``None`` or NaN.
    """
    # NaN is the only float that compares unequal to itself.
    if x is None or x != x:
        return _SCORE_MIN
    value = float(x)
    # Cap from above first, then lift from below.
    capped = value if value < _SCORE_MAX else _SCORE_MAX
    return capped if capped > _SCORE_MIN else _SCORE_MIN
@dataclass
class GraderInput:
    """Description of one attempt, as consumed by ``compute_reward``."""

    # Whether the attempt succeeded.
    success: bool
    # Position of the attempt, counted from 1.
    attempt_number: int
    # Error class of this attempt; None if the attempt succeeded.
    current_error_class: Optional[ErrorClass]
    # Error class of the preceding attempt; None on the first attempt.
    previous_error_class: Optional[ErrorClass]
@dataclass
class RewardBreakdown:
    """Unclamped components that are summed to form a step reward."""

    # Starting value chosen by the success / failure outcome.
    base: float
    # Deduction that grows with the attempt number.
    attempt_penalty: float
    # Adjustment for the error severity moving relative to the previous attempt.
    severity_bonus: float
    # Adjustment for the error class differing from the previous attempt.
    change_bonus: float
@dataclass
class GraderOutput:
    """Result of grading one attempt."""

    # Final reward for the step.
    reward: float
    # Individual components the reward was derived from.
    breakdown: RewardBreakdown
def compute_reward(inp: GraderInput) -> GraderOutput:
    """Score one debugging attempt and itemise how the score was formed.

    A successful attempt starts from 1.0 and loses 0.1 for every attempt
    beyond the first. A failed attempt starts from -0.1, loses 0.05 per
    attempt and, when both the previous and the current error class are
    known, is adjusted by +0.2 / -0.1 for a lower / higher
    ``error_severity`` and by +0.1 when the error class changed.

    Args:
        inp: Outcome of the attempt being graded.

    Returns:
        A ``GraderOutput`` whose ``reward`` is the summed components passed
        through ``_clamp`` and whose ``breakdown`` holds the unclamped
        components.
    """
    if inp.success:
        parts = RewardBreakdown(
            base=1.0,
            attempt_penalty=-0.1 * (inp.attempt_number - 1),
            severity_bonus=0.0,
            change_bonus=0.0,
        )
        return GraderOutput(
            reward=_clamp(parts.base + parts.attempt_penalty),
            breakdown=parts,
        )

    # Failed step: fixed penalty, per-attempt penalty, optional shaping.
    # NOTE(review): the raw failure score is at most 0.2 - 0.05 * attempt_number,
    # so after clamping many failed steps end up at the floor.
    previous = inp.previous_error_class
    current = inp.current_error_class
    severity_adjustment = 0.0
    class_changed_bonus = 0.0

    # Shaping needs two failures to compare against each other.
    if previous is not None and current is not None:
        severity_before = error_severity(previous)
        severity_now = error_severity(current)
        if severity_now < severity_before:
            severity_adjustment = 0.2
        elif severity_now > severity_before:
            severity_adjustment = -0.1
        if current != previous:
            class_changed_bonus = 0.1

    parts = RewardBreakdown(
        base=-0.1,
        attempt_penalty=-0.05 * inp.attempt_number,
        severity_bonus=severity_adjustment,
        change_bonus=class_changed_bonus,
    )
    total = (
        parts.base
        + parts.attempt_penalty
        + parts.severity_bonus
        + parts.change_bonus
    )
    return GraderOutput(reward=_clamp(total), breakdown=parts)
def compute_episode_reward(step_rewards: list[float], success: bool) -> float:
    """Return the clamped mean of an episode's per-step rewards.

    Args:
        step_rewards: Rewards of the individual steps of the episode.
        success: Not used by the current implementation.

    Returns:
        ``_SCORE_MIN`` when ``step_rewards`` is empty, otherwise the
        arithmetic mean of ``step_rewards`` passed through ``_clamp``.
    """
    if step_rewards:
        mean_reward = sum(step_rewards) / len(step_rewards)
        return _clamp(mean_reward)
    # Nothing to average: fall back to the floor.
    return _SCORE_MIN