"""
nl2sql-bench/server/grader.py
==============================
Deterministic, programmatic reward grader.
No LLM-as-judge. Every reward is computed by comparing the agent's SQL
execution results against a ground-truth result set.
Reward decomposition (sums to 1.0 for a perfect first-attempt answer):
+0.10 syntax_ok — query runs without SQLite error
+0.20 columns_match — returned column names match ground truth (compared
as a set, so column order is ignored)
+0.20 row_count_match — number of returned rows matches
+0.50 exact_match — full result set equals ground truth (order-aware
for ORDER BY queries, order-agnostic otherwise)
Step penalty:
-0.05 per step beyond the first (encourages solving in fewer steps),
clamped so the minimum is always 0.0.
All rewards are floats in [0.0, 1.0].
"""
from __future__ import annotations

import re
import sqlite3
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple
# ── Result normalisation ───────────────────────────────────────────────────
def _normalise_value(v: Any) -> Any:
"""Round floats for comparison so 1.2300000001 == 1.23."""
if isinstance(v, float):
return round(v, 4)
if isinstance(v, str):
return v.strip()
return v
def _normalise_row(row: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new row dict whose cells have been passed through ``_normalise_value``.

    Column names and their order are preserved; the input dict is not modified.
    """
    return dict(zip(row, map(_normalise_value, row.values())))
def _normalise_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Normalise every row of a result set, returning a new list in the same order."""
    return list(map(_normalise_row, rows))
# ── SQL execution ──────────────────────────────────────────────────────────
def _leading_keyword(query: str) -> str:
    """Return the first word of *query*, lower-cased, skipping SQL comments.

    Leading whitespace, ``-- ...`` line comments and ``/* ... */`` block
    comments are skipped so they cannot hide the statement keyword from the
    write-operation guard. The word ends at the first character that is not
    alphanumeric or ``_`` (so ``drop/**/table`` yields ``"drop"``). Returns
    ``""`` if the query does not start with a word.
    """
    pos, end = 0, len(query)
    while pos < end:
        if query[pos].isspace():
            pos += 1
        elif query.startswith("--", pos):
            newline = query.find("\n", pos)
            pos = end if newline == -1 else newline + 1
        elif query.startswith("/*", pos):
            close = query.find("*/", pos + 2)
            pos = end if close == -1 else close + 2
        else:
            break
    stop = pos
    while stop < end and (query[stop].isalnum() or query[stop] == "_"):
        stop += 1
    return query[pos:stop].lower()


def execute_query(
    conn: sqlite3.Connection,
    query: str,
    max_rows: int = 200,
) -> Tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
    """
    Execute a SQL query safely.

    Returns (rows, error_string).
    rows is None on error; otherwise it is a list of {column: value} dicts.
    At most ``max_rows`` rows are fetched; any further rows are silently
    dropped.

    The write-operation guard inspects only the statement's leading keyword
    (after comments). It is a best-effort filter, not a sandbox: e.g. a
    CTE-prefixed ``WITH ... DELETE`` is not caught.
    NOTE(review): opening the database read-only would be the robust
    safeguard — confirm how ``conn`` is created.
    """
    query = query.strip().rstrip(";")
    if not query:
        return None, "Empty query."
    # Block write operations — the environment is read-only from the agent's
    # view. ATTACH/VACUUM can also create files on disk.
    forbidden = (
        "insert", "update", "delete", "drop", "alter",
        "create", "replace", "truncate", "pragma",
        "attach", "detach", "vacuum", "reindex", "analyze",
    )
    first_word = _leading_keyword(query)
    if first_word in forbidden:
        return None, (
            f"Operation '{first_word.upper()}' is not allowed. "
            "Only SELECT queries are permitted."
        )
    try:
        cur = conn.execute(query)
        cols = [d[0] for d in cur.description] if cur.description else []
        rows = [dict(zip(cols, row)) for row in cur.fetchmany(max_rows)]
    except (sqlite3.Error, sqlite3.Warning) as exc:
        # Older Python versions raise sqlite3.Warning (not a subclass of
        # sqlite3.Error) for multi-statement input such as "SELECT 1; DROP ...".
        return None, str(exc)
    return rows, None
# ── Grading logic ──────────────────────────────────────────────────────────
class GradeResult:
    """
    Result of grading a single agent query.

    Attributes
    ----------
    reward : Final reward, as supplied by the caller.
    syntax_ok : Whether the query executed without error.
    columns_match : Whether the returned column names matched.
    row_count_match : Whether the number of returned rows matched.
    exact_match : Whether the full result set matched.
    step_penalty : Amount subtracted for extra steps (expected >= 0).
    breakdown : Signed contribution of each reward component. Criterion
        credit is only granted when ``syntax_ok`` is True. The entries are
        not clamped, so their sum may differ from ``reward`` (e.g. when the
        caller clamps a negative total to 0.0).
    """

    __slots__ = (
        "reward", "syntax_ok", "columns_match",
        "row_count_match", "exact_match", "step_penalty",
        "breakdown",
    )

    def __init__(
        self,
        reward: float,
        syntax_ok: bool,
        columns_match: bool,
        row_count_match: bool,
        exact_match: bool,
        step_penalty: float,
    ) -> None:
        self.reward = reward
        self.syntax_ok = syntax_ok
        self.columns_match = columns_match
        self.row_count_match = row_count_match
        self.exact_match = exact_match
        self.step_penalty = step_penalty
        self.breakdown = {
            "syntax_ok": 0.10 if syntax_ok else 0.0,
            "columns_match": 0.20 if (syntax_ok and columns_match) else 0.0,
            "row_count_match": 0.20 if (syntax_ok and row_count_match) else 0.0,
            "exact_match": 0.50 if (syntax_ok and exact_match) else 0.0,
            # 0.0 - x rather than -x: a zero penalty must stay 0.0, not
            # become -0.0 (which serialises, e.g. to JSON, as "-0.0").
            "step_penalty": 0.0 - step_penalty,
        }

    def __repr__(self) -> str:  # pragma: no cover
        return (
            f"GradeResult(reward={self.reward:.3f}, "
            f"exact={self.exact_match}, cols={self.columns_match}, "
            f"rows={self.row_count_match}, syntax={self.syntax_ok})"
        )
def _rows_match_unordered(
    left: List[Dict[str, Any]],
    right: List[Dict[str, Any]],
) -> bool:
    """Return True if *left* and *right* contain the same rows, ignoring order.

    Rows are compared as a multiset using ``==`` on cell values, so an int
    and an equal float (``1`` and ``1.0``) match exactly as they do in the
    order-sensitive comparison. Sorting rows by their string form would
    order such values differently on the two sides and report a false
    mismatch.
    """
    def _row_key(row: Dict[str, Any]) -> Tuple[Tuple[str, Any], ...]:
        # Column names within a row are unique, so sorting the items only
        # ever compares names, never (possibly mixed-type) values.
        return tuple(sorted(row.items()))

    try:
        return Counter(map(_row_key, left)) == Counter(map(_row_key, right))
    except TypeError:
        # Unhashable cell values (sqlite3 does not produce these, but callers
        # may pass arbitrary rows): fall back to a quadratic match using ==.
        unmatched = list(right)
        for row in left:
            try:
                unmatched.remove(row)
            except ValueError:
                return False
        return not unmatched


def grade(
    actual_rows: Optional[List[Dict[str, Any]]],
    ground_truth_rows: List[Dict[str, Any]],
    error: Optional[str],
    step: int,
    order_sensitive: bool = False,
) -> GradeResult:
    """
    Grade the agent's query result against ground truth.

    Parameters
    ----------
    actual_rows : Rows returned by the agent's query (None on error).
    ground_truth_rows : Expected rows (pre-computed at task load time).
    error : SQLite error string (None if query ran successfully).
    step : Current step number (1-indexed) for penalty calculation.
    order_sensitive : If True, row order matters (queries with ORDER BY).

    Returns
    -------
    GradeResult
        Reward clamped to [0.0, 1.0] plus per-criterion flags. A failed
        query scores 0.0 with no step penalty recorded.

    Notes
    -----
    Column names are read from the first row of each result, so an empty
    result set carries no column information: its columns "match" only an
    empty ground truth.
    """
    # ── Syntax ─────────────────────────────────────────────────────────────
    syntax_ok = error is None and actual_rows is not None
    if not syntax_ok:
        return GradeResult(
            reward=0.0,
            syntax_ok=False,
            columns_match=False,
            row_count_match=False,
            exact_match=False,
            step_penalty=0.0,
        )

    gt_norm = _normalise_rows(ground_truth_rows)
    act_norm = _normalise_rows(actual_rows)

    # ── Structure: column names (as a set) and row count ───────────────────
    gt_cols = set(gt_norm[0]) if gt_norm else set()
    act_cols = set(act_norm[0]) if act_norm else set()
    columns_match = act_cols == gt_cols
    row_count_match = len(act_norm) == len(gt_norm)

    # ── Exact match: ordered list compare, or multiset compare ─────────────
    if not (columns_match and row_count_match):
        exact_match = False
    elif order_sensitive:
        exact_match = act_norm == gt_norm
    else:
        exact_match = _rows_match_unordered(act_norm, gt_norm)

    # ── Score assembly ─────────────────────────────────────────────────────
    raw = (
        0.10  # syntax
        + (0.20 if columns_match else 0.0)
        + (0.20 if row_count_match else 0.0)
        + (0.50 if exact_match else 0.0)
    )
    # -0.05 for every step after the first; step <= 1 costs nothing.
    penalty = max(0.0, step - 1) * 0.05
    reward = float(max(0.0, min(1.0, raw - penalty)))
    return GradeResult(
        reward=reward,
        syntax_ok=syntax_ok,
        columns_match=columns_match,
        row_count_match=row_count_match,
        exact_match=exact_match,
        step_penalty=penalty,
    )
# ── Convenience: pre-compute ground truth rows ─────────────────────────────
def compute_ground_truth(
    conn: sqlite3.Connection,
    sql: str,
) -> List[Dict[str, Any]]:
    """
    Run a task's reference SQL and return its normalised result rows.

    The query goes through ``execute_query`` with its default row limit
    (200), so the same write guard and truncation apply as for agent queries.

    Raises
    ------
    ValueError
        If the reference query is rejected or fails to execute; the message
        includes the error and the offending SQL.
    """
    rows, error = execute_query(conn, sql)
    if rows is not None and not error:
        return _normalise_rows(rows)
    raise ValueError(f"Ground-truth SQL failed: {error}\nSQL: {sql}")
# Tokens relevant to finding a top-level ORDER BY. String literals, quoted
# identifiers and comments are listed first so that any "ORDER BY" text
# inside them is consumed as part of that token rather than reported.
_ORDER_BY_SCAN_RE = re.compile(
    r"'(?:[^']|'')*'"      # single-quoted string literal ('' escapes a quote)
    r'|"(?:[^"]|"")*"'     # double-quoted identifier
    r"|--[^\n]*"           # line comment
    r"|/\*.*?\*/"          # block comment
    r"|[()]"               # parentheses, to track nesting depth
    r"|\bORDER\s+BY\b",    # the clause itself, any whitespace between words
    re.IGNORECASE | re.DOTALL,
)


def has_order_by(sql: str) -> bool:
    """
    Heuristic: does the top-level query have an ORDER BY?

    ORDER BY clauses nested inside parentheses (subqueries, CTE bodies,
    window specifications such as ``OVER (ORDER BY ...)``, aggregate
    arguments) are ignored, because they do not fix the order of the final
    result. So is ORDER BY text inside string literals, double-quoted
    identifiers and comments. Any whitespace, including newlines, may
    separate ORDER and BY. ``[bracketed]`` and backtick identifiers are not
    recognised — sufficient for our controlled task SQL.
    """
    depth = 0
    for match in _ORDER_BY_SCAN_RE.finditer(sql):
        token = match.group()
        if token == "(":
            depth += 1
        elif token == ")":
            depth = max(0, depth - 1)
        elif depth == 0 and token[0] in "oO":
            # Only the ORDER BY alternative starts with 'o'; literals and
            # comments start with a quote, '-' or '/'.
            return True
    return False