Spaces:
Sleeping
Sleeping
File size: 7,560 Bytes
08b82d0 54a5bf9 08b82d0 54a5bf9 08b82d0 54a5bf9 08b82d0 54a5bf9 08b82d0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 | """
Multi-component grading system for SQL query evaluation.
Scores agent queries against ground truth with partial credit:
- syntax_score (0.1): Query parses and executes without error
- column_score (0.2): Fraction of expected columns present
- row_score (0.3): Fraction of expected rows matching
- exact_score (0.4): Full result set matches ground truth exactly
  (half credit when every expected row is found but extra rows are also returned)
Total reward per question is clamped to [0.001, 0.999], i.e. strictly between 0 and 1.
"""
from typing import Any, List, Optional, Tuple
from .database import Database, QueryResult
# Margin keeping every reward strictly inside the open interval (0, 1):
# rewards are never exactly 0.0 or 1.0.
_EPS = 1e-3


def _clamp_reward(reward: float) -> float:
    """Limit *reward* to the closed range [_EPS, 1 - _EPS].

    Values already inside the range are returned unchanged.
    """
    lower = _EPS
    upper = 1.0 - _EPS
    if reward < lower:
        return lower
    if reward > upper:
        return upper
    return reward
def _normalize_value(val: Any) -> Any:
    """Normalize a single cell value so that results compare reliably.

    - ``None`` is returned unchanged.
    - Floats with an integral value become ``int`` (so ``3.0`` compares as ``3``);
      other floats are rounded to 2 decimal places. Non-finite floats
      (``inf``/``-inf``/``nan``) are passed through ``round`` unchanged
      instead of raising.
    - Strings have surrounding whitespace stripped.
    - Any other value is returned as-is.
    """
    if val is None:
        return None
    if isinstance(val, float):
        # float.is_integer() is False for inf/nan, whereas int(val) would raise
        # OverflowError/ValueError and abort grading of the whole query.
        if val.is_integer():
            return int(val)
        return round(val, 2)
    if isinstance(val, str):
        return val.strip()
    return val
def _normalize_row(row: Tuple) -> Tuple:
    """Return a new tuple with ``_normalize_value`` applied to every cell of *row*."""
    return tuple(map(_normalize_value, row))
def _normalize_column_name(col: str) -> str:
    """Canonicalize a column name: drop surrounding whitespace, then lowercase it."""
    trimmed = col.strip()
    return trimmed.lower()
def grade_query(
    db: Database,
    agent_sql: str,
    expected_columns: List[str],
    expected_rows: List[List],
    order_matters: bool = True,
) -> dict:
    """
    Grade an agent's SQL query against expected results.

    The reward is a weighted sum of four components (syntax 0.1, columns 0.2,
    rows 0.3, exact 0.4), passed through ``_clamp_reward`` and rounded to
    4 decimal places.

    Args:
        db: Active Database instance; only ``db.execute_query`` is used.
        agent_sql: The SQL query submitted by the agent.
        expected_columns: List of expected column names (compared after
            stripping whitespace and lowercasing).
        expected_rows: List of expected row values (list of lists), with
            values in the same order as ``expected_columns``.
        order_matters: If True, rows are compared position-by-position;
            otherwise each expected row is greedily paired with a distinct
            matching actual row.

    Returns:
        Dictionary with:
        - reward: float strictly inside (0.0, 1.0), clamped via ``_clamp_reward``
        - syntax_score: float (0.0 if the query failed, else 1.0)
        - column_score: float, fraction of expected columns present
        - row_score: float, fraction of expected rows matched
        - exact_score: float (1.0, 0.5 when extra rows were returned, else 0.0)
        - query_result: QueryResult object
        - feedback: str describing what was right/wrong
    """
    result = db.execute_query(agent_sql)

    # Component weights (sum to 1.0).
    W_SYNTAX = 0.1
    W_COLUMN = 0.2
    W_ROW = 0.3
    W_EXACT = 0.4

    # --- Syntax Score ---
    if not result.success:
        # A query that fails to execute earns no component credit; the reward
        # is still clamped so it stays strictly above 0.
        return {
            "reward": _clamp_reward(0.0),
            "syntax_score": 0.0,
            "column_score": 0.0,
            "row_score": 0.0,
            "exact_score": 0.0,
            "query_result": result,
            "feedback": f"SQL error: {result.error}",
        }
    syntax_score = 1.0

    # --- Column Score ---
    expected_cols_normalized = [_normalize_column_name(c) for c in expected_columns]
    actual_cols_normalized = [_normalize_column_name(c) for c in result.columns]
    missing_cols = [c for c in expected_cols_normalized if c not in actual_cols_normalized]
    if not expected_cols_normalized:
        # Nothing expected: full credit only if nothing was returned either.
        column_score = 1.0 if not actual_cols_normalized else 0.0
    else:
        matched_cols = len(expected_cols_normalized) - len(missing_cols)
        column_score = matched_cols / len(expected_cols_normalized)

    # --- Row Score ---
    expected_rows_normalized = [_normalize_row(tuple(row)) for row in expected_rows]
    actual_rows_normalized = [_normalize_row(row) for row in result.rows]
    matched_rows = _count_matched_rows(
        expected_rows_normalized,
        actual_rows_normalized,
        expected_cols_normalized,
        actual_cols_normalized,
        order_matters,
    )
    n_expected = len(expected_rows_normalized)
    n_actual = len(actual_rows_normalized)
    if not expected_rows_normalized:
        row_score = 1.0 if not actual_rows_normalized else 0.0
    else:
        row_score = matched_rows / n_expected

    # --- Exact Score ---
    exact_score = 0.0
    if column_score == 1.0 and row_score == 1.0:
        # Every expected row matched; full credit only without extra rows.
        exact_score = 1.0 if n_actual == n_expected else 0.5

    # --- Total Reward ---
    reward = (
        W_SYNTAX * syntax_score
        + W_COLUMN * column_score
        + W_ROW * row_score
        + W_EXACT * exact_score
    )
    reward = round(_clamp_reward(reward), 4)

    # --- Feedback ---
    # Execution succeeded if we got this far.
    feedback_parts = ["Query executed successfully."]
    if column_score < 1.0:
        feedback_parts.append(
            f"Missing columns: {missing_cols}. Expected: {expected_cols_normalized}, Got: {actual_cols_normalized}"
        )
    if row_score < 1.0:
        # Use the integer match count: int(row_score * n) can truncate
        # (e.g. 0.29 * 100 == 28.999...) and under-report correct rows.
        feedback_parts.append(
            f"Row match: {row_score:.0%} ({matched_rows}/{n_expected} rows correct). "
            f"Expected {n_expected} rows, got {n_actual}."
        )
    if exact_score == 1.0:
        feedback_parts.append("Perfect match!")
    elif exact_score == 0.5:
        feedback_parts.append(f"All expected rows found but got {n_actual} rows instead of {n_expected} (extra rows).")

    return {
        "reward": reward,
        "syntax_score": syntax_score,
        "column_score": column_score,
        "row_score": row_score,
        "exact_score": exact_score,
        "query_result": result,
        "feedback": " ".join(feedback_parts),
    }


def _count_matched_rows(
    expected_rows: List[Tuple],
    actual_rows: List[Tuple],
    expected_cols: List[str],
    actual_cols: List[str],
    order_matters: bool,
) -> int:
    """Return how many expected rows are matched (via ``_rows_match``) by actual rows."""
    if order_matters:
        # Position-by-position: surplus actual rows are ignored and missing
        # actual rows count as misses.
        return sum(
            1
            for exp_row, act_row in zip(expected_rows, actual_rows)
            if _rows_match(exp_row, act_row, expected_cols, actual_cols)
        )
    # Order-insensitive: greedily pair each expected row with the first unused
    # matching actual row, so one actual row satisfies at most one expected row.
    remaining = list(actual_rows)
    matched = 0
    for exp_row in expected_rows:
        for j, act_row in enumerate(remaining):
            if _rows_match(exp_row, act_row, expected_cols, actual_cols):
                matched += 1
                del remaining[j]
                break
    return matched
def _rows_match(
    expected_row: Tuple,
    actual_row: Tuple,
    expected_cols: List[str],
    actual_cols: List[str],
) -> bool:
    """
    Check if an actual row matches an expected row.

    Handles column reordering: each expected column is located by name in
    ``actual_cols`` (first occurrence wins) and the values at those positions
    are compared; extra actual columns are ignored.

    A value matches when it equals the expected value after
    ``_normalize_value`` is applied to the actual value, or when both values
    convert to ``float`` and differ by less than 0.01 (so numeric-looking
    strings can match numbers). ``expected_row`` is compared as given, so
    callers should normalize it beforehand.

    Returns False if ``expected_row`` and ``expected_cols`` differ in length,
    if any expected column is absent from ``actual_cols``, or if
    ``actual_row`` is too short for a mapped column position.
    """
    if len(expected_cols) != len(expected_row):
        return False

    # Resolve every expected column to its position in the actual result
    # before comparing any values; a missing column is an immediate mismatch.
    positions = []
    for col in expected_cols:
        if col not in actual_cols:
            return False  # Missing column
        positions.append(actual_cols.index(col))

    for exp_val, act_idx in zip(expected_row, positions):
        if act_idx >= len(actual_row):
            return False
        act_val = _normalize_value(actual_row[act_idx])
        if exp_val != act_val:
            # Fall back to numeric comparison within an absolute tolerance.
            try:
                if abs(float(exp_val) - float(act_val)) < 0.01:
                    continue
            except (TypeError, ValueError, OverflowError):
                # Non-numeric values, or ints too large for a float: treat as
                # a mismatch instead of letting the error abort grading.
                pass
            return False
    return True
|