File size: 2,632 Bytes
3c665d2
 
 
719c147
3c665d2
 
 
 
 
 
 
 
 
 
 
719c147
 
 
 
 
 
 
 
 
 
 
3c665d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719c147
3c665d2
719c147
3c665d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719c147
3c665d2
719c147
3c665d2
 
719c147
3c665d2
719c147
3c665d2
 
719c147
3c665d2
 
 
 
 
 
 
 
 
 
719c147
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Shaped reward function for the SQL debug RL environment.

All rewards are clamped to [0.05, 0.95] (strictly inside (0, 1)) before being returned.
"""

from __future__ import annotations

from typing import Optional
from dataclasses import dataclass

from rl.types import ErrorClass
from rl.error_classifier import error_severity


# Lower and upper bounds applied by _clamp; both lie strictly inside (0, 1).
_SCORE_MIN = 0.05
_SCORE_MAX = 0.95


def _clamp(x: float) -> float:
    """Force a score into the closed range [_SCORE_MIN, _SCORE_MAX].

    ``None`` and NaN are treated as "no usable score" and map to
    ``_SCORE_MIN``.  Any other value is converted with ``float`` and then
    bounded on both sides.
    """
    if x is None:
        return _SCORE_MIN
    # NaN is the only float value that compares unequal to itself.
    if x != x:
        return _SCORE_MIN
    upper_bounded = min(_SCORE_MAX, float(x))
    return max(_SCORE_MIN, upper_bounded)


@dataclass
class GraderInput:
    """Description of one attempt, as consumed by ``compute_reward``."""

    # Selects the success or failure scoring branch in compute_reward.
    success: bool
    attempt_number: int                        # 1-indexed
    # The two error classes are only read by compute_reward on a failed
    # attempt, and only when both are present.
    current_error_class: Optional[ErrorClass]  # None if success
    previous_error_class: Optional[ErrorClass] # None on first attempt


@dataclass
class RewardBreakdown:
    """Raw (unclamped) terms that ``compute_reward`` sums to form a reward."""

    # 1.0 for a successful attempt, -0.1 for a failed one.
    base: float
    # -0.1 * (attempt_number - 1) on success, -0.05 * attempt_number on failure.
    attempt_penalty: float
    # 0.2 if error severity decreased, -0.1 if it increased, otherwise 0.0.
    severity_bonus: float
    # 0.1 if the error class differs from the previous one, otherwise 0.0.
    change_bonus: float


@dataclass
class GraderOutput:
    """Result returned by ``compute_reward``."""

    # Final score after the summed terms have been passed through _clamp.
    reward: float
    # The individual terms before clamping.
    breakdown: RewardBreakdown


def compute_reward(inp: GraderInput) -> GraderOutput:
    """Score a single attempt and report how the score was assembled.

    A successful attempt starts from 1.0 and loses 0.1 for every attempt
    that preceded it.  A failed attempt starts from -0.1 and loses 0.05 per
    attempt made so far; when both the previous and the current error class
    are known, two shaping terms are added:

    * severity: +0.2 if ``error_severity`` went down, -0.1 if it went up;
    * change: +0.1 if the error class is different from the previous one.

    The sum of the terms is passed through ``_clamp`` to produce the reward;
    the breakdown carries the raw, unclamped terms.

    Args:
        inp: Description of the attempt being graded.

    Returns:
        The clamped reward together with its per-term breakdown.
    """
    # Shaping terms only ever become non-zero on the failure path.
    severity_bonus = 0.0
    change_bonus = 0.0

    if inp.success:
        base = 1.0
        attempt_penalty = -0.1 * (inp.attempt_number - 1)
    else:
        # Failed step: flat penalty, growing with the attempt count.
        base = -0.1
        attempt_penalty = -0.05 * inp.attempt_number

        prior = inp.previous_error_class
        latest = inp.current_error_class
        if prior is not None and latest is not None:
            prior_rank = error_severity(prior)
            latest_rank = error_severity(latest)

            if latest_rank < prior_rank:
                severity_bonus = 0.2
            elif latest_rank > prior_rank:
                severity_bonus = -0.1

            if latest != prior:
                change_bonus = 0.1

    total = base + attempt_penalty + severity_bonus + change_bonus
    terms = RewardBreakdown(
        base=base,
        attempt_penalty=attempt_penalty,
        severity_bonus=severity_bonus,
        change_bonus=change_bonus,
    )
    return GraderOutput(reward=_clamp(total), breakdown=terms)


def compute_episode_reward(step_rewards: list[float], success: bool) -> float:
    """Return the clamped mean of an episode's per-step rewards.

    Args:
        step_rewards: Rewards collected for each step of the episode.
        success: Accepted as part of the interface; not used in the
            computation here.

    Returns:
        ``_clamp`` applied to the arithmetic mean of ``step_rewards``, or
        ``_SCORE_MIN`` when there are no steps.
    """
    if step_rewards:
        mean_reward = sum(step_rewards) / len(step_rewards)
        return _clamp(mean_reward)
    # No steps recorded: fall back to the lowest score.
    return _SCORE_MIN