"""
Shaped reward function for the SQL debug RL environment.

All rewards are clamped to the closed interval [_SCORE_MIN, _SCORE_MAX]
(0.05 to 0.95) before being returned, so they always lie strictly
inside (0, 1).
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional

from rl.types import ErrorClass
from rl.error_classifier import error_severity

# Inclusive bounds applied to every reward returned by this module.
_SCORE_MIN = 0.05
_SCORE_MAX = 0.95


def _clamp(x: Optional[float]) -> float:
    """Clamp *x* to [_SCORE_MIN, _SCORE_MAX], i.e. strictly inside (0, 1).

    ``None`` and NaN are mapped to ``_SCORE_MIN``. Infinite values are
    clamped to the nearest bound like any other out-of-range number.
    """
    if x is None:
        return _SCORE_MIN
    value = float(x)
    if math.isnan(value):
        return _SCORE_MIN
    return max(_SCORE_MIN, min(_SCORE_MAX, value))


@dataclass
class GraderInput:
    """Description of a single step to be graded."""

    success: bool
    attempt_number: int  # 1-indexed
    current_error_class: Optional[ErrorClass]  # None if success
    previous_error_class: Optional[ErrorClass]  # None on first attempt


@dataclass
class RewardBreakdown:
    """Unclamped components that were summed to form a step reward."""

    base: float
    attempt_penalty: float
    severity_bonus: float
    change_bonus: float


@dataclass
class GraderOutput:
    """Clamped step reward together with its unclamped components."""

    reward: float
    breakdown: RewardBreakdown


def compute_reward(inp: GraderInput) -> GraderOutput:
    """Compute the shaped reward for one step.

    On success the raw reward is 1.0 minus 0.1 for every attempt after
    the first. On failure it is a base of -0.1, minus 0.05 per attempt,
    plus shaping terms that apply only when both the previous and the
    current error class are known:

    * severity bonus: +0.2 if ``error_severity`` of the current error is
      lower than that of the previous one, -0.1 if it is higher;
    * change bonus: +0.1 if the error class differs from the previous one.

    The ``reward`` field is the raw sum passed through ``_clamp``; the
    ``breakdown`` holds the unclamped components.
    """
    # attempt_number is documented as 1-indexed. Floor it at 1 so that a
    # zero or negative value cannot turn the attempt penalty into a bonus.
    attempts = max(1, inp.attempt_number)

    if inp.success:
        base = 1.0
        attempt_penalty = -0.1 * (attempts - 1)
        raw = base + attempt_penalty
        return GraderOutput(
            reward=_clamp(raw),
            breakdown=RewardBreakdown(
                base=base,
                attempt_penalty=attempt_penalty,
                severity_bonus=0.0,
                change_bonus=0.0,
            ),
        )

    # Failed step — base penalty + potential shaping
    base = -0.1
    attempt_penalty = -0.05 * attempts
    severity_bonus = 0.0
    change_bonus = 0.0

    # Shaping needs two errors to compare, so it is skipped on the first
    # attempt (no previous error) or when the current class is missing.
    if (
        inp.previous_error_class is not None
        and inp.current_error_class is not None
    ):
        prev_sev = error_severity(inp.previous_error_class)
        curr_sev = error_severity(inp.current_error_class)
        if curr_sev < prev_sev:
            severity_bonus = 0.2
        elif curr_sev > prev_sev:
            severity_bonus = -0.1

        if inp.current_error_class != inp.previous_error_class:
            change_bonus = 0.1

    raw = base + attempt_penalty + severity_bonus + change_bonus
    return GraderOutput(
        reward=_clamp(raw),
        breakdown=RewardBreakdown(
            base=base,
            attempt_penalty=attempt_penalty,
            severity_bonus=severity_bonus,
            change_bonus=change_bonus,
        ),
    )


def compute_episode_reward(step_rewards: list[float], success: bool) -> float:
    """Compute the total episode reward, clamped to [0.05, 0.95].

    Returns the clamped arithmetic mean of *step_rewards*, or
    ``_SCORE_MIN`` when the list is empty.

    *success* is currently not used in the calculation; it is kept so
    existing callers do not need to change.
    """
    if not step_rewards:
        return _SCORE_MIN
    avg = sum(step_rewards) / len(step_rewards)
    return _clamp(avg)