File size: 7,560 Bytes
08b82d0
 
 
 
 
 
 
 
 
54a5bf9
08b82d0
 
 
 
 
 
54a5bf9
 
 
 
 
 
 
 
08b82d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54a5bf9
08b82d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54a5bf9
08b82d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
"""
Multi-component grading system for SQL query evaluation.

Scores agent queries against ground truth with partial credit:
  - syntax_score  (0.1): Query parses and executes without error
  - column_score  (0.2): Fraction of expected columns present
  - row_score     (0.3): Fraction of expected rows matching
  - exact_score   (0.4): Full result set matches ground truth exactly
                         (half credit if every expected row is present but
                         extra rows are also returned)

Total reward per question is in (0.0, 1.0) — strictly between 0 and 1.
"""

from typing import Any, List, Optional, Tuple

from .database import Database, QueryResult

# Epsilon to ensure scores are strictly between 0 and 1 (never exactly 0.0 or 1.0)
_EPS = 0.001


def _clamp_reward(reward: float) -> float:
    """Clamp reward to be strictly within (0, 1)."""
    return min(max(reward, _EPS), 1.0 - _EPS)


def _normalize_value(val: Any) -> Any:
    """Normalize a value for comparison (handle float/int equivalence, None)."""
    if val is None:
        return None
    if isinstance(val, float):
        if val == int(val):
            return int(val)
        return round(val, 2)
    if isinstance(val, str):
        return val.strip()
    return val


def _normalize_row(row: Tuple) -> Tuple:
    """Return a new tuple with ``_normalize_value`` applied to every cell of *row*."""
    return tuple(map(_normalize_value, row))


def _normalize_column_name(col: str) -> str:
    """Normalize column name for comparison (lowercase, strip)."""
    return col.strip().lower()


def _count_matched_rows(
    expected_rows: List[Tuple],
    actual_rows: List[Tuple],
    expected_cols: List[str],
    actual_cols: List[str],
    order_matters: bool,
) -> int:
    """
    Count the expected rows that have a matching actual row (via ``_rows_match``).

    When ``order_matters`` is True, expected row i is compared only with
    actual row i; expected rows with no counterpart at that position are
    unmatched. Otherwise each expected row is greedily paired with the first
    not-yet-used matching actual row, so one actual row never satisfies two
    expected rows.
    """
    if order_matters:
        # zip() stops at the shorter list: surplus expected rows are unmatched.
        return sum(
            1
            for exp_row, act_row in zip(expected_rows, actual_rows)
            if _rows_match(exp_row, act_row, expected_cols, actual_cols)
        )

    matched = 0
    remaining = list(actual_rows)
    for exp_row in expected_rows:
        for j, act_row in enumerate(remaining):
            if _rows_match(exp_row, act_row, expected_cols, actual_cols):
                matched += 1
                # Consume the actual row so duplicates are not double-counted.
                del remaining[j]
                break
    return matched


def grade_query(
    db: Database,
    agent_sql: str,
    expected_columns: List[str],
    expected_rows: List[List],
    order_matters: bool = True,
) -> dict:
    """
    Grade an agent's SQL query against expected results.

    The query is executed once via ``db.execute_query`` and scored on four
    weighted components: syntax (0.1), columns (0.2), rows (0.3) and
    exact match (0.4).

    Args:
        db: Active Database instance.
        agent_sql: The SQL query submitted by the agent.
        expected_columns: List of expected column names (compared after
            stripping whitespace and lower-casing).
        expected_rows: List of expected row values (list of lists), with
            values in the same order as ``expected_columns``.
        order_matters: If True, rows are compared position-by-position;
            otherwise each expected row may match any unused actual row.

    Returns:
        Dictionary with:
          - reward: weighted component sum, clamped strictly inside (0, 1)
            (never exactly 0.0 or 1.0); rounded to 4 decimals when the
            query executes
          - syntax_score: float (1.0 if the query executed, else 0.0)
          - column_score: float, fraction of expected columns present
          - row_score: float, fraction of expected rows matched
          - exact_score: float (1.0 exact, 0.5 all rows found plus extra
            rows, else 0.0)
          - query_result: QueryResult object
          - feedback: str describing what was right/wrong
    """
    result = db.execute_query(agent_sql)

    # Component weights (sum to 1.0).
    W_SYNTAX = 0.1
    W_COLUMN = 0.2
    W_ROW = 0.3
    W_EXACT = 0.4

    # --- Syntax Score ---
    if not result.success:
        return {
            "reward": _clamp_reward(0.0),
            "syntax_score": 0.0,
            "column_score": 0.0,
            "row_score": 0.0,
            "exact_score": 0.0,
            "query_result": result,
            "feedback": f"SQL error: {result.error}",
        }

    syntax_score = 1.0

    # --- Column Score ---
    expected_cols_normalized = [_normalize_column_name(c) for c in expected_columns]
    actual_cols_normalized = [_normalize_column_name(c) for c in result.columns]
    actual_col_set = set(actual_cols_normalized)

    if not expected_cols_normalized:
        column_score = 1.0 if not actual_cols_normalized else 0.0
    else:
        matched_cols = sum(1 for c in expected_cols_normalized if c in actual_col_set)
        column_score = matched_cols / len(expected_cols_normalized)

    # --- Row Score ---
    expected_rows_normalized = [_normalize_row(tuple(row)) for row in expected_rows]
    actual_rows_normalized = [_normalize_row(row) for row in result.rows]
    n_expected = len(expected_rows_normalized)
    n_actual = len(actual_rows_normalized)

    if n_expected == 0:
        matched_rows = 0
        row_score = 1.0 if n_actual == 0 else 0.0
    else:
        # Keep the integer count: recovering it later from row_score * n
        # can truncate (e.g. 1/49 * 49 == 0.999...).
        matched_rows = _count_matched_rows(
            expected_rows_normalized,
            actual_rows_normalized,
            expected_cols_normalized,
            actual_cols_normalized,
            order_matters,
        )
        row_score = matched_rows / n_expected

    # --- Exact Score ---
    exact_score = 0.0
    if column_score == 1.0 and row_score == 1.0:
        # Every expected row was found; full credit only without extra rows.
        exact_score = 1.0 if n_actual == n_expected else 0.5

    # --- Total Reward ---
    reward = (
        W_SYNTAX * syntax_score
        + W_COLUMN * column_score
        + W_ROW * row_score
        + W_EXACT * exact_score
    )
    reward = round(_clamp_reward(reward), 4)

    # --- Feedback ---
    # Reaching this point means the query executed successfully.
    feedback_parts = ["Query executed successfully."]
    if column_score < 1.0:
        missing = [c for c in expected_cols_normalized if c not in actual_col_set]
        feedback_parts.append(f"Missing columns: {missing}. Expected: {expected_cols_normalized}, Got: {actual_cols_normalized}")
    if row_score < 1.0:
        feedback_parts.append(
            f"Row match: {row_score:.0%} ({matched_rows}/{n_expected} rows correct). "
            f"Expected {n_expected} rows, got {n_actual}."
        )
    if exact_score == 1.0:
        feedback_parts.append("Perfect match!")
    elif exact_score == 0.5:
        feedback_parts.append(f"All expected rows found but got {n_actual} rows instead of {n_expected} (extra rows).")

    return {
        "reward": reward,
        "syntax_score": syntax_score,
        "column_score": column_score,
        "row_score": row_score,
        "exact_score": exact_score,
        "query_result": result,
        "feedback": " ".join(feedback_parts),
    }


def _rows_match(
    expected_row: Tuple,
    actual_row: Tuple,
    expected_cols: List[str],
    actual_cols: List[str],
) -> bool:
    """
    Check if an actual row matches an expected row.

    Handles column reordering: maps expected columns to actual column positions.
    """
    if len(expected_cols) != len(expected_row):
        return False

    # Build a mapping from expected column index to actual column index
    col_map = {}
    for i, ec in enumerate(expected_cols):
        if ec in actual_cols:
            col_map[i] = actual_cols.index(ec)
        else:
            return False  # Missing column

    for exp_idx, act_idx in col_map.items():
        if act_idx >= len(actual_row):
            return False
        exp_val = expected_row[exp_idx]
        act_val = _normalize_value(actual_row[act_idx])
        if exp_val != act_val:
            # Try numeric comparison with tolerance
            try:
                if abs(float(exp_val) - float(act_val)) < 0.01:
                    continue
            except (TypeError, ValueError):
                pass
            return False

    return True