File size: 1,758 Bytes
210535c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
Shaped reward function for the SQL Query Optimizer environment.

Design:
  - Partial credit every step based on grader improvement delta
  - Completion bonus when agent signals is_done and score ≥ threshold
  - Step penalty for each non-final step taken after half of max_steps
  - Invalid action penalty for empty / unparseable queries
"""
from __future__ import annotations

import math


_COMPLETION_THRESHOLD = 0.80  # grader score a done step must reach to earn the bonus
_COMPLETION_BONUS = 0.50      # flat bonus for a done step at/above the threshold
_STEP_PENALTY = 0.02          # subtracted from non-final steps past the halfway mark
_INVALID_PENALTY = 0.10       # invalid-action reward magnitude; also the lower clamp
_DELTA_WEIGHT = 0.50   # weight for grader improvement delta in step reward
_DONE_SCORE_WEIGHT = 0.30     # weight of the final grader score on a done step


def compute_step_reward(
    *,
    grader_score: float,
    prev_grader_score: float,
    step_number: int,
    max_steps: int,
    is_done: bool,
    is_invalid: bool,
) -> float:
    """
    Return the shaped reward for a single step, rounded to 4 decimal places.

    An invalid action short-circuits everything else and returns exactly
    ``-_INVALID_PENALTY`` (-0.10). Otherwise the reward is the sum of:

      1. delta_reward     = _DELTA_WEIGHT * max(0, grader_score - prev_grader_score)
                            (only improvements count; regressions contribute 0)
      2. completion_bonus = _COMPLETION_BONUS, only if is_done and
                            grader_score >= _COMPLETION_THRESHOLD
      3. done_score       = _DONE_SCORE_WEIGHT * grader_score, only if is_done
                            (partial-completion signal, granted with or without the bonus)
      4. step_penalty     = -_STEP_PENALTY, only if not is_done and
                            step_number > ceil(max_steps / 2)

    The sum is clamped to [-_INVALID_PENALTY, 1.0], i.e. [-0.10, 1.0].

    Args:
      grader_score: Grader score after this step.
      prev_grader_score: Grader score before this step.
      step_number: Index of this step, compared against ceil(max_steps / 2).
        Presumably 1-based; confirm with the caller.
      max_steps: Episode step budget.
      is_done: Whether the agent signalled completion on this step.
      is_invalid: Whether the submitted query was empty / unparseable.

    Returns:
      The reward as a float in [-0.10, 1.0].

    NOTE(review): a NaN grader_score on a done step yields a NaN reward,
    because ``min``/``max`` do not filter NaN in the final clamp.
    """
    if is_invalid:
        return -_INVALID_PENALTY

    # Only improvements are rewarded; a drop in score contributes 0 rather than a penalty.
    reward = _DELTA_WEIGHT * max(0.0, grader_score - prev_grader_score)

    if is_done:
        if grader_score >= _COMPLETION_THRESHOLD:
            reward += _COMPLETION_BONUS
        # Proportional partial-completion signal, given even without the bonus.
        reward += _DONE_SCORE_WEIGHT * grader_score

    # Step penalty applies to non-final steps taken after half of max_steps.
    halfway = math.ceil(max_steps / 2)
    if not is_done and step_number > halfway:
        reward -= _STEP_PENALTY

    return round(min(max(reward, -_INVALID_PENALTY), 1.0), 4)