Spaces:
Sleeping
Sleeping
File size: 2,632 Bytes
3c665d2 719c147 3c665d2 719c147 3c665d2 719c147 3c665d2 719c147 3c665d2 719c147 3c665d2 719c147 3c665d2 719c147 3c665d2 719c147 3c665d2 719c147 3c665d2 719c147 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 | """
Shaped reward function for the SQL debug RL environment.
All step and episode rewards are clamped to [0.05, 0.95] (strictly inside (0, 1)) before being returned.
"""
from __future__ import annotations
from typing import Optional
from dataclasses import dataclass
from rl.types import ErrorClass
from rl.error_classifier import error_severity
# Floor and ceiling applied to every reward produced by this module. Both lie
# strictly inside (0, 1), so a clamped score is never exactly 0 or 1.
_SCORE_MIN = 0.05
_SCORE_MAX = 0.95


def _clamp(x: float) -> float:
    """Clamp to strictly (0, 1).

    ``None`` and NaN both map to ``_SCORE_MIN``. Any other value is converted
    with ``float()`` and limited to the closed range
    ``[_SCORE_MIN, _SCORE_MAX]``.
    """
    # For floats, ``x != x`` holds only for NaN.
    is_missing = x is None or x != x
    if is_missing:
        return _SCORE_MIN
    value = float(x)
    capped = min(_SCORE_MAX, value)
    return max(_SCORE_MIN, capped)
@dataclass
class GraderInput:
    """Description of a single step, as consumed by ``compute_reward``."""

    success: bool  # whether this step succeeded; selects the success/failure branch
    attempt_number: int  # 1-indexed
    current_error_class: Optional[ErrorClass]  # None if success
    previous_error_class: Optional[ErrorClass]  # None on first attempt
@dataclass
class RewardBreakdown:
    """Unclamped additive components of a step reward.

    ``compute_reward`` sums these fields into a raw value and then clamps it.
    The fields themselves are stored before clamping.
    """

    base: float  # compute_reward: 1.0 on success, -0.1 on failure
    attempt_penalty: float  # compute_reward: <= 0 for attempt_number >= 1
    severity_bonus: float  # compute_reward: +0.2 / -0.1 / 0.0; always 0.0 on success
    change_bonus: float  # compute_reward: +0.1 if the error class changed; 0.0 on success
@dataclass
class GraderOutput:
    """Result of grading one step: the final reward plus its components."""

    reward: float  # as built by compute_reward: the clamped score
    breakdown: RewardBreakdown  # unclamped components that were summed into reward
def _shaping_bonuses(previous: ErrorClass, current: ErrorClass) -> tuple[float, float]:
    """Return ``(severity_bonus, change_bonus)`` for a failed step.

    ``severity_bonus`` depends on how ``error_severity(current)`` compares with
    ``error_severity(previous)``:

    * +0.2 when the current value is smaller;
    * -0.1 when it is larger;
    * 0.0 otherwise.

    ``change_bonus`` is +0.1 whenever the two error classes differ, else 0.0.
    """
    # NOTE(review): a smaller error_severity value is treated as an improvement;
    # confirm that ordering against rl.error_classifier.
    previous_rank = error_severity(previous)
    current_rank = error_severity(current)
    if current_rank < previous_rank:
        severity = 0.2
    elif current_rank > previous_rank:
        severity = -0.1
    else:
        severity = 0.0
    changed = 0.1 if current != previous else 0.0
    return severity, changed


def compute_reward(inp: GraderInput) -> GraderOutput:
    """Grade a single step and return its clamped reward with a breakdown.

    Success: raw = 1.0 - 0.1 * (attempt_number - 1).

    Failure: raw = -0.1 - 0.05 * attempt_number + severity_bonus + change_bonus.
    The two shaping terms are only computed when both the previous and the
    current error classes are known; otherwise they are 0.0.

    The returned ``reward`` is ``_clamp(raw)``. The breakdown keeps the
    unclamped components.

    NOTE(review): with ``_SCORE_MIN = 0.05``, a failed step's raw value is at
    most ``0.2 - 0.05 * attempt_number``. It therefore rarely rises above the
    floor, so the clamp largely flattens the shaping terms.
    """
    if inp.success:
        # Full credit, reduced by 0.1 for every retry after the first attempt.
        base = 1.0
        attempt_penalty = -0.1 * (inp.attempt_number - 1)
        severity_bonus = change_bonus = 0.0
        raw = base + attempt_penalty
    else:
        # Failure: small negative base plus a penalty that grows with attempts.
        base = -0.1
        attempt_penalty = -0.05 * inp.attempt_number
        severity_bonus = change_bonus = 0.0
        previous = inp.previous_error_class
        current = inp.current_error_class
        if previous is not None and current is not None:
            severity_bonus, change_bonus = _shaping_bonuses(previous, current)
        raw = base + attempt_penalty + severity_bonus + change_bonus

    reward = _clamp(raw)
    parts = RewardBreakdown(
        base=base,
        attempt_penalty=attempt_penalty,
        severity_bonus=severity_bonus,
        change_bonus=change_bonus,
    )
    return GraderOutput(reward=reward, breakdown=parts)
def compute_episode_reward(step_rewards: list[float], success: bool) -> float:
    """Return the mean step reward of an episode, clamped to (0, 1).

    Args:
        step_rewards: Per-step rewards collected over the episode.
        success: Currently unused; the result depends only on ``step_rewards``.

    Returns:
        ``_SCORE_MIN`` for an empty episode. Otherwise, ``_clamp`` applied to
        the arithmetic mean of ``step_rewards``.
    """
    # NOTE(review): ``success`` is accepted but ignored. Confirm whether an
    # episode-level success term was intended here.
    if step_rewards:
        mean_reward = sum(step_rewards) / len(step_rewards)
        return _clamp(mean_reward)
    return _SCORE_MIN
|