File size: 4,617 Bytes
30cf758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b71d1b
 
 
 
 
 
 
30cf758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b71d1b
30cf758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b71d1b
30cf758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Reward function for the SQL Debug Environment.

Reward is computed at every step (not just end of episode).
This provides dense, meaningful signal for RL training.

Reward components:
- correctness:      0.0–0.6  (row-level match vs expected)
- efficiency:       0.0–0.2  (bonus for solving quickly)  
- syntax_progress:  0.0–0.1  (valid SQL even if wrong content)
- schema_bonus:     0.0–0.1  (correct tables/columns referenced)
- penalty:          0.0–0.08 (subtracted: regression or reset 0.05, near step limit 0.03)

Total range: 0.001 to 0.999 (clamped to [MIN_STRICT_SCORE, MAX_STRICT_SCORE], rounded to 4 decimals)
"""
from typing import Optional, List, Dict, Any
from .models import SQLDebugReward

MIN_STRICT_SCORE = 0.001
MAX_STRICT_SCORE = 0.999


def _strict_score(value: float) -> float:
    return round(min(MAX_STRICT_SCORE, max(MIN_STRICT_SCORE, value)), 4)


def compute_reward(
    action_type: str,
    query_result: Optional[Dict[str, Any]],
    grade_score: float,
    steps_taken: int,
    max_steps: int,
    previous_best_score: float,
    schema_tables: List[str],
    submitted_query: Optional[str] = None,
) -> SQLDebugReward:
    """
    Compute the full reward for a single step.

    The reward is the sum of four positive components (correctness,
    efficiency, syntax progress, schema bonus) minus a penalty, clamped and
    rounded by ``_strict_score``.

    Args:
        action_type: The action taken this step. Recognised values are
            "submit_query", "reset_query", "inspect_schema",
            "inspect_error" and "inspect_sample"; any other value earns no
            action-specific component.
        query_result: Result dict from EpisodeDatabase.execute_query(), or
            None. Only the "success" and "error_message" keys are read.
        grade_score: Strict (0, 1) score from the task grader.
        steps_taken: How many steps have been used (1-indexed).
        max_steps: Maximum steps for this task. A non-positive value is
            treated as an exhausted budget for the efficiency bonus.
        previous_best_score: Best grade score seen so far.
        schema_tables: List of valid table names in this task's DB.
        submitted_query: The SQL query string (if action was submit_query).

    Returns:
        A SQLDebugReward carrying the clamped total in ``value``, each
        component, the penalty as a non-negative magnitude, and a
        human-readable ``breakdown`` string.
    """
    correctness = 0.0
    efficiency = 0.0
    syntax_progress = 0.0
    schema_bonus = 0.0
    penalty = 0.0  # deduction magnitude (non-negative)

    if action_type == "submit_query":
        # Correctness: primary signal, worth at most 0.6.
        correctness = min(0.6, grade_score * 0.6)

        # Syntax progress: full credit for a query that executes, partial
        # credit when the failure suggests the statement structure is sound.
        if query_result:
            if query_result.get("success"):
                syntax_progress = 0.1
            else:
                # "error_message" may be present but None; normalise to a
                # string so the substring checks below cannot raise.
                error = str(query_result.get("error_message") or "").lower()
                if "no such column" in error:
                    syntax_progress = 0.03  # Structure is right but wrong column
                elif "no such table" in error:
                    syntax_progress = 0.01

        # Schema bonus: fraction of known tables whose name appears in the
        # query text (case-insensitive substring match).
        if submitted_query and schema_tables:
            query_upper = submitted_query.upper()
            tables_referenced = sum(
                1 for t in schema_tables if t.upper() in query_upper
            )
            schema_bonus = min(0.1, (tables_referenced / len(schema_tables)) * 0.1)

        # Efficiency bonus: only for near-perfect solutions, larger when a
        # smaller share of the step budget has been used.
        if grade_score >= 0.95:
            # Guard against a zero/negative budget instead of raising
            # ZeroDivisionError; such a budget earns the lowest tier.
            steps_fraction = steps_taken / max_steps if max_steps > 0 else 1.0
            if steps_fraction <= 0.3:
                efficiency = 0.2
            elif steps_fraction <= 0.5:
                efficiency = 0.15
            elif steps_fraction <= 0.7:
                efficiency = 0.1
            else:
                efficiency = 0.05

        # Penalty: score regressed by more than 0.1 from the previous best.
        if grade_score < previous_best_score - 0.1:
            penalty = 0.05

    elif action_type == "reset_query":
        # Penalize resetting — agent should be making progress.
        penalty = 0.05

    elif action_type in ("inspect_schema", "inspect_error", "inspect_sample"):
        # Information actions earn a small positive reward to encourage
        # exploring rather than blindly guessing.
        syntax_progress = 0.01

    # Urgency signal: extra penalty when at most two steps remain and the
    # score is still below 0.5. Applies to every action type.
    steps_remaining = max_steps - steps_taken
    if steps_remaining <= 2 and grade_score < 0.5:
        penalty += 0.03

    total_raw = correctness + efficiency + syntax_progress + schema_bonus - penalty
    total = _strict_score(total_raw)

    # The penalty is subtracted, so it is rendered with a minus sign; the
    # figure after "=" is the clamped total, not the raw sum.
    breakdown = (
        f"correctness={correctness:.3f} + "
        f"efficiency={efficiency:.3f} + "
        f"syntax={syntax_progress:.3f} + "
        f"schema={schema_bonus:.3f} - "
        f"penalty={penalty:.3f} = {total:.4f}"
    )

    return SQLDebugReward(
        value=total,
        correctness=correctness,
        efficiency=efficiency,
        syntax_progress=syntax_progress,
        schema_bonus=schema_bonus,
        penalty=penalty,
        breakdown=breakdown,
    )