Melika Kheirieh commited on
Commit
ebc7457
Β·
1 Parent(s): 1615704

feat(benchmarks): add pro evaluator with EM, structural match, execution accuracy, and safety consistency metrics

Browse files
Files changed (1) hide show
  1. benchmarks/evaluate_spider_pro.py +309 -0
benchmarks/evaluate_spider_pro.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full benchmark for NL2SQL pipeline.
3
+
4
+ Metrics:
5
+ - EM (exact match)
6
+ - Structural Match (sqlglot AST)
7
+ - Execution Accuracy
8
+ - Safety consistency (pipeline vs AST)
9
+ - Latency (end-to-end) + per-stage trace (via pipeline if available)
10
+
11
+ Outputs:
12
+ JSONL (logs), JSON (summary), CSV (compact table)
13
+
14
+ Run example:
15
+ python benchmarks/evaluate_spider_pro.py --limit 10 --sleep 0.1 --db sqlite --adapter data/chinook.db
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import csv
22
+ import json
23
+ import sqlite3
24
+ import time
25
+ from pathlib import Path
26
+ from typing import Any, Dict, List, Optional, cast
27
+
28
+ import sqlglot
29
+ from sqlglot.errors import ParseError
30
+
31
+ # Reuse existing factories from FastAPI router (no new DI needed)
32
+ from app.routers.nl2sql import ( # type: ignore
33
+ _pipeline as DEFAULT_PIPELINE,
34
+ _build_pipeline,
35
+ _select_adapter,
36
+ )
37
+ from nl2sql.safety import Safety
38
+
39
+
40
+ # -------------------- Helpers --------------------
41
+
42
+
43
+ def _int_ms(start: float) -> int:
44
+ return int((time.perf_counter() - start) * 1000)
45
+
46
+
47
+ def _parse_sql(sql: str) -> Optional[sqlglot.Expression]:
48
+ try:
49
+ return sqlglot.parse_one(sql, read="sqlite")
50
+ except ParseError:
51
+ return None
52
+
53
+
54
+ def _is_structural_match(sql1: str, sql2: str) -> bool:
55
+ a, b = _parse_sql(sql1), _parse_sql(sql2)
56
+ return (a == b) if (a is not None and b is not None) else False
57
+
58
+
59
+ def _exec_sql(conn: sqlite3.Connection, sql: str) -> List[tuple]:
60
+ try:
61
+ cur = conn.execute(sql)
62
+ return [tuple(r) for r in cur.fetchall()]
63
+ except Exception:
64
+ return []
65
+
66
+
67
+ def _derive_schema_preview_safe(pipeline_obj: Any) -> Optional[str]:
68
+ for attr in ("executor", "adapter"):
69
+ obj = getattr(pipeline_obj, attr, None)
70
+ if obj is not None and hasattr(obj, "derive_schema_preview"):
71
+ try:
72
+ # type: ignore[no-any-return]
73
+ return obj.derive_schema_preview() # pragma: no cover
74
+ except Exception:
75
+ pass
76
+ return None
77
+
78
+
79
+ def _to_stage_list(trace_obj: Any) -> List[Dict[str, Any]]:
80
+ """
81
+ Normalize pipeline trace (list of dataclass or dict) to:
82
+ [{'stage': str, 'ms': int}, ...]
83
+ """
84
+ stages: List[Dict[str, Any]] = []
85
+ if not isinstance(trace_obj, list):
86
+ return stages
87
+
88
+ for t in trace_obj:
89
+ if isinstance(t, dict):
90
+ stage = t.get("stage", "?")
91
+ ms = t.get("duration_ms", 0)
92
+ else:
93
+ stage = getattr(t, "stage", "?")
94
+ ms = getattr(t, "duration_ms", 0)
95
+ try:
96
+ stages.append({"stage": str(stage), "ms": int(ms)})
97
+ except Exception:
98
+ stages.append({"stage": str(stage), "ms": 0})
99
+ return stages
100
+
101
+
102
+ # -------------------- Main --------------------
103
+
104
+
105
+ def main() -> None:
106
+ parser = argparse.ArgumentParser()
107
+ parser.add_argument("--limit", type=int, default=10, help="Max number of examples")
108
+ parser.add_argument("--resume", type=int, default=0, help="Skip first N examples")
109
+ parser.add_argument(
110
+ "--sleep", type=float, default=0.0, help="Delay (seconds) between queries"
111
+ )
112
+ parser.add_argument(
113
+ "--split", type=str, default="test", help="Dataset split (placeholder)"
114
+ )
115
+ parser.add_argument(
116
+ "--db", type=str, default="sqlite", help="Database ID (e.g., sqlite/postgres)"
117
+ )
118
+ parser.add_argument(
119
+ "--adapter",
120
+ type=str,
121
+ default="data/chinook.db",
122
+ help="SQLite file path for local eval",
123
+ )
124
+ args = parser.parse_args()
125
+
126
+ # SQLite connection for execution-accuracy
127
+ conn = sqlite3.connect(args.adapter)
128
+
129
+ # Build pipeline from router factories
130
+ try:
131
+ adapter = _select_adapter(args.db)
132
+ pipeline = _build_pipeline(adapter)
133
+ using_default = False
134
+ except Exception:
135
+ pipeline = DEFAULT_PIPELINE
136
+ using_default = True
137
+
138
+ safety = Safety()
139
+ schema_preview = _derive_schema_preview_safe(pipeline)
140
+ print(f"βœ… Pipeline ready (db={args.db}, default={using_default})")
141
+
142
+ # Minimal sample dataset for demonstration; replace with real Spider subset if available
143
+ DATASET: List[Dict[str, Any]] = [
144
+ {
145
+ "id": 1,
146
+ "question": "list all customers",
147
+ "gold_sql": "SELECT * FROM customers;",
148
+ },
149
+ {
150
+ "id": 2,
151
+ "question": "top 3 albums by total sales",
152
+ "gold_sql": """
153
+ SELECT a.Title, SUM(i.Quantity * i.UnitPrice) AS total
154
+ FROM albums a
155
+ JOIN tracks t ON a.AlbumId = t.AlbumId
156
+ JOIN invoice_items i ON t.TrackId = i.TrackId
157
+ GROUP BY a.AlbumId
158
+ ORDER BY total DESC
159
+ LIMIT 3;
160
+ """,
161
+ },
162
+ {
163
+ "id": 3,
164
+ "question": "number of employees per city",
165
+ "gold_sql": """
166
+ SELECT City, COUNT(*) AS cnt
167
+ FROM employees
168
+ GROUP BY City
169
+ ORDER BY cnt DESC;
170
+ """,
171
+ },
172
+ ]
173
+
174
+ sliced = DATASET[args.resume : args.resume + args.limit]
175
+
176
+ # Eval loop
177
+ results: List[Dict[str, Any]] = []
178
+ for idx, ex in enumerate(sliced, start=1):
179
+ qid = cast(int, ex.get("id", idx))
180
+ q: str = cast(str, ex.get("question", ""))
181
+ gold_sql: str = cast(str, ex.get("gold_sql", "")).strip()
182
+ print(f"\n[{idx}] {q}")
183
+
184
+ t0 = time.perf_counter()
185
+ try:
186
+ out = pipeline.run(user_query=q, schema_preview=(schema_preview or "")) # type: ignore[misc]
187
+ latency = _int_ms(t0)
188
+
189
+ # Safely extract predicted SQL:
190
+ sql_pred_obj = getattr(out, "sql", None)
191
+ if sql_pred_obj is None:
192
+ data_obj = getattr(out, "data", None)
193
+ if data_obj is not None:
194
+ sql_pred_obj = getattr(data_obj, "sql", None)
195
+
196
+ sql_pred: str = str(sql_pred_obj) if sql_pred_obj is not None else ""
197
+ if not sql_pred.strip():
198
+ raise ValueError("No SQL generated")
199
+
200
+ # Metrics
201
+ em = sql_pred.strip().lower() == gold_sql.strip().lower()
202
+ sm = _is_structural_match(sql_pred, gold_sql)
203
+
204
+ safe_ast = safety.check(sql_pred) # pipeline has its own safety as well
205
+ safe_pipeline = bool(getattr(out, "ok", True))
206
+ safety_consistent = safe_ast.ok == safe_pipeline
207
+
208
+ gold_exec = _exec_sql(conn, gold_sql)
209
+ pred_exec = _exec_sql(conn, sql_pred)
210
+ exec_acc = gold_exec == pred_exec
211
+
212
+ stages = _to_stage_list(getattr(out, "trace", None))
213
+
214
+ results.append(
215
+ {
216
+ "id": qid,
217
+ "question": q,
218
+ "sql_pred": sql_pred,
219
+ "sql_gold": gold_sql,
220
+ "em": em,
221
+ "sm": sm,
222
+ "exec_acc": exec_acc,
223
+ "safety_consistent": safety_consistent,
224
+ "latency_ms": latency,
225
+ "trace": stages,
226
+ "error": None,
227
+ }
228
+ )
229
+ print(f"βœ… OK | EM={em} | SM={sm} | Exec={exec_acc} | {latency} ms")
230
+
231
+ except Exception as e:
232
+ latency = _int_ms(t0)
233
+ results.append(
234
+ {
235
+ "id": qid,
236
+ "question": q,
237
+ "sql_pred": None,
238
+ "sql_gold": gold_sql,
239
+ "em": False,
240
+ "sm": False,
241
+ "exec_acc": False,
242
+ "safety_consistent": None,
243
+ "latency_ms": latency,
244
+ "trace": [],
245
+ "error": str(e),
246
+ }
247
+ )
248
+ print(f"❌ Fail ({latency} ms): {e}")
249
+ time.sleep(args.sleep)
250
+
251
+ # Summary
252
+ total = len(results)
253
+ avg_latency = round(sum(r["latency_ms"] for r in results) / max(total, 1), 1)
254
+ em_rate = (sum(1 for r in results if r["em"]) / max(total, 1)) if total else 0.0
255
+ sm_rate = (sum(1 for r in results if r["sm"]) / max(total, 1)) if total else 0.0
256
+ exec_acc_rate = (
257
+ (sum(1 for r in results if r["exec_acc"]) / max(total, 1)) if total else 0.0
258
+ )
259
+
260
+ summary: Dict[str, Any] = {
261
+ "total": total,
262
+ "avg_latency_ms": avg_latency,
263
+ "EM": em_rate,
264
+ "SM": sm_rate,
265
+ "ExecAcc": exec_acc_rate,
266
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
267
+ "db": args.db,
268
+ "using_default_pipeline": using_default,
269
+ }
270
+
271
+ # Persist outputs (timestamped dir)
272
+ out_dir = Path("benchmarks") / "results_pro" / time.strftime("%Y%m%d-%H%M%S")
273
+ out_dir.mkdir(parents=True, exist_ok=True)
274
+
275
+ jsonl_path = out_dir / "spider_eval_pro.jsonl"
276
+ with jsonl_path.open("w", encoding="utf-8") as f:
277
+ for r in results:
278
+ json.dump(r, f, ensure_ascii=False)
279
+ f.write("\n")
280
+
281
+ json_path = out_dir / "summary.json"
282
+ with json_path.open("w", encoding="utf-8") as f:
283
+ json.dump(summary, f, indent=2)
284
+
285
+ csv_path = out_dir / "summary.csv"
286
+ with csv_path.open("w", newline="", encoding="utf-8") as f:
287
+ writer = csv.DictWriter(
288
+ f,
289
+ fieldnames=["id", "question", "em", "sm", "exec_acc", "latency_ms"],
290
+ )
291
+ writer.writeheader()
292
+ for r in results:
293
+ writer.writerow(
294
+ {
295
+ "id": r["id"],
296
+ "question": r["question"],
297
+ "em": "βœ…" if r["em"] else "❌",
298
+ "sm": "βœ…" if r["sm"] else "❌",
299
+ "exec_acc": "βœ…" if r["exec_acc"] else "❌",
300
+ "latency_ms": r["latency_ms"],
301
+ }
302
+ )
303
+
304
+ print("\nπŸ“Š Summary:", json.dumps(summary, indent=2))
305
+ print(f"πŸ’Ύ Saved to:\n- {jsonl_path}\n- {json_path}\n- {csv_path}")
306
+
307
+
308
+ if __name__ == "__main__":
309
+ main()