Melika Kheirieh committed on
Commit
598536c
·
1 Parent(s): b0bec17

feat(benchmarks): align Spider eval with config-driven Pipeline and native Safety; log per-stage trace; add CSV summary

Browse files
Files changed (2) hide show
  1. .coverage +0 -0
  2. benchmarks/evaluate_spider.py +115 -442
.coverage CHANGED
Binary files a/.coverage and b/.coverage differ
 
benchmarks/evaluate_spider.py CHANGED
@@ -1,452 +1,125 @@
1
- from __future__ import annotations
 
 
 
 
2
 
3
  import json
4
- import subprocess
5
  import time
6
  from pathlib import Path
7
- from typing import Any, Iterable, Optional, Tuple, cast
8
-
9
- from tqdm import tqdm
10
- from langchain_community.utilities import SQLDatabase
11
- from sqlglot import parse_one, exp
12
- from sqlglot.errors import ParseError
13
- from sqlalchemy import create_engine, inspect
14
- from spider_loader import load_spider_sqlite
15
-
16
-
17
- def _try_import_pipeline():
18
- """
19
- Try multiple plausible entrypoints from nl2sql.
20
- Returns a tuple of callables or None:
21
- (make_pipeline | None, run_function | None, PipelineClass | None)
22
- """
23
- make_pipeline = None
24
- run_fn = None
25
- PipelineCls = None
26
- try:
27
- from nl2sql.pipeline import make_pipeline as _mk # type: ignore
28
-
29
- make_pipeline = _mk
30
- except Exception:
31
- pass
32
- try:
33
- from nl2sql.pipeline import run_nl2sql as _run # type: ignore
34
-
35
- run_fn = _run
36
- except Exception:
37
- pass
38
- try:
39
- from nl2sql.pipeline import Pipeline as _P # type: ignore
40
-
41
- PipelineCls = _P
42
- except Exception:
43
- pass
44
- return make_pipeline, run_fn, PipelineCls
45
-
46
-
47
- LOG_DIR = Path("logs/spider_eval")
48
- LOG_DIR.mkdir(parents=True, exist_ok=True)
49
-
50
- FORBIDDEN_NODES: Tuple[type, ...] = (
51
- exp.Insert,
52
- exp.Delete,
53
- exp.Update,
54
- exp.Drop,
55
- exp.Alter,
56
- exp.Attach,
57
- exp.Pragma,
58
- exp.Create,
59
- )
60
-
61
-
62
- def normalize_sql(sql: str) -> str:
63
- return " ".join(sql.lower().strip().split())
64
-
65
-
66
- def compare_results(
67
- pred_rows: Optional[Iterable[Any]], gold_rows: Optional[Iterable[Any]]
68
- ) -> bool:
69
- if pred_rows is None or gold_rows is None:
70
- return False
71
- return set(pred_rows) == set(gold_rows)
72
-
73
-
74
- def try_execute_sql(
75
- sql_db: SQLDatabase,
76
- sql: str,
77
- timeout: Optional[float] = None, # kept for API compatibility
78
- ) -> tuple[Optional[list[tuple[Any, ...]]], float, Optional[str]]:
79
- start = time.time()
80
  try:
81
- raw_rows = sql_db.run(sql)
82
-
83
- # Normalize result shape for MyPy and downstream code
84
- if isinstance(raw_rows, list):
85
- rows = [tuple(r) for r in raw_rows]
86
- elif isinstance(raw_rows, tuple):
87
- rows = [tuple(raw_rows)]
88
- else:
89
- # Fallback cast — if library returns ResultSet or something similar
90
- rows = cast(list[tuple[Any, ...]], raw_rows)
91
-
92
- return rows, time.time() - start, None
93
-
94
- except Exception as e:
95
- return None, time.time() - start, str(e)
96
-
97
-
98
- def exact_match_structural(sql_pred: str, sql_gold: str) -> bool:
99
- try:
100
- ast_pred = parse_one(sql_pred)
101
- ast_gold = parse_one(sql_gold)
102
- except Exception:
103
- return False
104
-
105
- def normalize_ast(node: exp.Expression) -> exp.Expression:
106
- for name, arg in node.args.items():
107
- if isinstance(arg, list):
108
- arg.sort(key=lambda x: str(x))
109
- for child in arg:
110
- normalize_ast(child)
111
- elif isinstance(arg, exp.Expression):
112
- normalize_ast(arg)
113
- if isinstance(node, exp.Alias):
114
- return normalize_ast(node.this)
115
- return node
116
-
117
- norm_prd = normalize_ast(ast_pred)
118
- norm_gold = normalize_ast(ast_gold)
119
- return norm_prd == norm_gold
120
-
121
-
122
- def get_git_commit_hash() -> str:
123
- try:
124
- out = (
125
- subprocess.check_output(["git", "rev-parse", "HEAD"])
126
- .strip()
127
- .decode("ascii")
128
- )
129
- return out
130
- except Exception:
131
- return "UNKNOWN"
132
-
133
-
134
- def is_safe_sql(sql: str, dialect: Optional[str] = None) -> bool:
135
- try:
136
- ast = parse_one(sql, read=dialect)
137
- except ParseError:
138
- return False
139
- if not isinstance(ast, exp.Select):
140
- return False
141
- for node in ast.walk():
142
- if isinstance(node, FORBIDDEN_NODES):
143
- return False
144
- return True
145
-
146
 
147
- # --- جایگزین get_schema_preview از app.routers ---
148
- def get_schema_preview_sqlalchemy(db_path: str, max_cols: int = 0) -> str:
149
- """
150
- Lightweight schema preview using SQLAlchemy inspector.
151
- max_cols=0 => unlimited
152
- """
153
- engine = create_engine(f"sqlite:///{db_path}")
154
- insp = inspect(engine)
155
- lines: list[str] = []
156
- for tbl in sorted(insp.get_table_names()):
157
- cols = insp.get_columns(tbl)
158
- if max_cols > 0:
159
- cols = cols[:max_cols]
160
- col_str = ", ".join(f"{c['name']}:{c.get('type')}" for c in cols)
161
- pks = insp.get_pk_constraint(tbl).get("constrained_columns") or []
162
- pk_str = f" | PK: {', '.join(pks)}" if pks else ""
163
- fks = insp.get_foreign_keys(tbl)
164
- fk_str = ""
165
- if fks:
166
- fks_desc = []
167
- for fk in fks:
168
- ref = fk.get("referred_table")
169
- cols_fk = ", ".join(fk.get("constrained_columns") or [])
170
- ref_cols = ", ".join(fk.get("referred_columns") or [])
171
- fks_desc.append(f"{cols_fk} -> {ref}({ref_cols})")
172
- fk_str = " | FK: " + " ; ".join(fks_desc)
173
- lines.append(f"{tbl}({col_str}){pk_str}{fk_str}")
174
- engine.dispose()
175
- return "\n".join(lines)
176
-
177
-
178
- def _generate_sql(
179
- question: str, sql_db: SQLDatabase, schema_text: str, max_output_tokens: int = 1000
180
- ) -> tuple[str, str, dict[str, Any]]:
181
- """
182
- Returns: (status_msg, sql_text, extra_output)
183
- Strategy:
184
- 1) If nl2sql.pipeline.run_nl2sql exists: call it.
185
- 2) Else if nl2sql.pipeline.make_pipeline exists: build and run.
186
- 3) Else if nl2sql.pipeline.Pipeline exists: instantiate minimal pipeline and run.
187
- 4) Else: raise NotImplementedError.
188
- """
189
- make_pipeline, run_fn, PipelineCls = _try_import_pipeline()
190
-
191
- # Case 1: direct run function
192
- if run_fn is not None:
193
- res = run_fn(
194
- question=question,
195
- schema_text=schema_text,
196
- sql_db=sql_db,
197
- max_output_tokens=max_output_tokens,
198
  )
199
- # Expecting a dict-like or object with attributes; normalize:
200
- if isinstance(res, dict):
201
- msg = res.get("status", "ok")
202
- sql = res.get("sql", "")
203
- return msg, sql, res
204
- # fallback generic
205
- msg = getattr(res, "status", "ok")
206
- sql = getattr(res, "sql", "")
207
- return msg, sql, {"result": res}
208
-
209
- # Case 2: factory + run
210
- if make_pipeline is not None:
211
- pipe = make_pipeline(sql_db=sql_db, schema_text=schema_text) # type: ignore[arg-type]
212
- # Common conventions:
213
- if hasattr(pipe, "run"):
214
- out = pipe.run(question) # type: ignore[call-arg]
215
- elif hasattr(pipe, "execute"):
216
- out = pipe.execute(question) # type: ignore[call-arg]
217
- else:
218
- raise RuntimeError("Pipeline object has no run/execute()")
219
- msg = getattr(out, "status", "ok")
220
- sql = getattr(out, "sql", "")
221
- return msg, sql, {"result": out}
222
-
223
- # Case 3: class-based pipeline
224
- if PipelineCls is not None:
225
- # Try minimal constructor names; adjust to your class signature if needed
226
- # We pass what we have; extra kwargs should be ignored or have defaults.
227
- pipe = PipelineCls(sql_db=sql_db, schema_text=schema_text)
228
- if hasattr(pipe, "run"):
229
- out = pipe.run(question) # type: ignore[call-arg]
230
- else:
231
- raise RuntimeError("Pipeline class has no run()")
232
- msg = getattr(out, "status", "ok")
233
- sql = getattr(out, "sql", "")
234
- return msg, sql, {"result": out}
235
-
236
- raise NotImplementedError(
237
- "Cannot locate a public NL2SQL entrypoint in nl2sql.pipeline. "
238
- "Expose one of: run_nl2sql(), make_pipeline(), or Pipeline.run()."
239
- )
240
-
241
-
242
- def run_eval(
243
- split: str = "dev", limit: int = 100, resume: bool = True, sleep_time: float = 0.01
244
- ) -> None:
245
- data = load_spider_sqlite(split)
246
- if len(data) < limit:
247
- limit = len(data)
248
- data = data[:limit]
249
- print(f"Running eval on {len(data)} examples in split={split}...")
250
-
251
- commit_hash = get_git_commit_hash()
252
- start_ts = int(time.time())
253
-
254
- pred_txt = LOG_DIR / f"{split}_pred_{start_ts}.txt"
255
- gold_txt = LOG_DIR / f"{split}_gold_{start_ts}.txt"
256
- results_fn = LOG_DIR / f"{split}_results_{start_ts}.jsonl"
257
- metrics_fn = LOG_DIR / f"{split}_metrics_{start_ts}.json"
258
-
259
- done: set[tuple[str, str]] = set()
260
- if resume and results_fn.exists():
261
- with results_fn.open("r", encoding="utf-8") as f:
262
- for line in f:
263
- if line.startswith("#"):
264
- continue
265
- try:
266
- r = json.loads(line)
267
- done.add((r.get("db_id"), r.get("question")))
268
- except Exception:
269
- pass
270
-
271
- write_header = not results_fn.exists()
272
- agg: list[dict[str, Any]] = []
273
-
274
- with (
275
- results_fn.open("a", encoding="utf-8") as fout,
276
- pred_txt.open("a", encoding="utf-8") as fpred,
277
- gold_txt.open("a", encoding="utf-8") as fgold,
278
- ):
279
- if write_header:
280
- header = {
281
- "commit_hash": commit_hash,
282
- "split": split,
283
- "limit": limit,
284
- "start_time": start_ts,
285
  }
286
- fout.write("# " + json.dumps(header, ensure_ascii=False) + "\n")
287
- fout.flush()
288
-
289
- for ex in tqdm(data):
290
- key = (ex.db_id, ex.question)
291
- if resume and key in done:
292
- continue
293
-
294
- db_path = str(ex.db_path)
295
- schema = get_schema_preview_sqlalchemy(db_path, max_cols=0)
296
- sql_db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
297
-
298
- t0 = time.time()
299
- try:
300
- msg, sql, output = _generate_sql(
301
- ex.question, sql_db, schema, max_output_tokens=1000
302
- )
303
- except NotImplementedError as e:
304
- rec = {
305
- "db_id": ex.db_id,
306
- "question": ex.question,
307
- "gold_sql": ex.gold_sql,
308
- "pred_sql": "",
309
- "status": "no_entrypoint",
310
- "output": {"error": str(e)},
311
- "gen_time": time.time() - t0,
312
- "exec_time": None,
313
- "error": "no_entrypoint",
314
- "gold_error": None,
315
- "pred_rows": None,
316
- "gold_rows": None,
317
- "exact_match": False,
318
- "exact_match_structural": False,
319
- "execution_accuracy": False,
320
- "safe_check_failed": True,
321
- }
322
- fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
323
- fout.flush()
324
- fgold.write(f"{ex.gold_sql}\t{ex.db_id}\n")
325
- fgold.flush()
326
- agg.append(rec)
327
- if sleep_time > 0:
328
- time.sleep(sleep_time)
329
- continue
330
-
331
- gen_time = time.time() - t0
332
-
333
- safe_flag = is_safe_sql(sql)
334
- if not safe_flag:
335
- rec = {
336
- "db_id": ex.db_id,
337
- "question": ex.question,
338
- "gold_sql": ex.gold_sql,
339
- "pred_sql": sql,
340
- "status": "rejected_safe_check",
341
- "output": output,
342
- "gen_time": gen_time,
343
- "exec_time": None,
344
- "error": "unsafe_sql",
345
- "gold_error": None,
346
- "pred_rows": None,
347
- "gold_rows": None,
348
- "exact_match": False,
349
- "exact_match_structural": False,
350
- "execution_accuracy": False,
351
- "safe_check_failed": True,
352
- }
353
- fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
354
- fout.flush()
355
- fpred.write(f"{sql}\t{ex.db_id}\n")
356
- fgold.write(f"{ex.gold_sql}\t{ex.db_id}\n")
357
- fpred.flush()
358
- fgold.flush()
359
- agg.append(rec)
360
- if sleep_time > 0:
361
- time.sleep(sleep_time)
362
- continue
363
-
364
- pred_rows, exec_time, error = try_execute_sql(sql_db, sql)
365
- gold_rows, gold_time, gold_error = try_execute_sql(sql_db, ex.gold_sql)
366
-
367
- skip = gold_error is not None
368
- em = normalize_sql(sql) == normalize_sql(ex.gold_sql) if not skip else False
369
- em_struct = exact_match_structural(sql, ex.gold_sql) if not skip else False
370
- exec_acc = compare_results(pred_rows, gold_rows) if not skip else False
371
-
372
- rec = {
373
- "db_id": ex.db_id,
374
- "question": ex.question,
375
- "gold_sql": ex.gold_sql,
376
- "pred_sql": sql,
377
- "status": msg,
378
- "output": output,
379
- "gen_time": gen_time,
380
- "exec_time": exec_time,
381
- "error": error,
382
- "gold_error": gold_error,
383
- "pred_rows": pred_rows,
384
- "gold_rows": gold_rows,
385
- "exact_match": em,
386
- "exact_match_structural": em_struct,
387
- "execution_accuracy": exec_acc,
388
- "safe_check_failed": False,
389
  }
390
- fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
391
- fout.flush()
392
- fpred.write(f"{sql}\t{ex.db_id}\n")
393
- fgold.write(f"{ex.gold_sql}\t{ex.db_id}\n")
394
- fpred.flush()
395
- fgold.flush()
396
- agg.append(rec)
397
-
398
- if sleep_time > 0:
399
- time.sleep(sleep_time)
400
-
401
- valid = [
402
- r
403
- for r in agg
404
- if (not r.get("safe_check_failed", False)) and (r.get("gold_error") is None)
405
- ]
406
- total_valid = len(valid)
407
- total_all = len(agg)
408
- if total_valid == 0:
409
- print("No valid examples to compute metrics")
410
- return
411
-
412
- em_count = sum(1 for r in valid if r["exact_match"])
413
- em_struct_count = sum(1 for r in valid if r["exact_match_structural"])
414
- exec_acc_count = sum(1 for r in valid if r["execution_accuracy"])
415
- error_count = sum(
416
- 1
417
- for r in agg
418
- if (r.get("error") is not None) and (not r.get("safe_check_failed", False))
419
- )
420
- safe_fail_count = sum(1 for r in agg if r.get("safe_check_failed", False))
421
- avg_gen_time = sum(float(r["gen_time"]) for r in valid) / total_valid
422
- avg_exec_time = sum(float(r["exec_time"]) for r in valid) / total_valid
423
-
424
- metrics = {
425
- "commit_hash": commit_hash,
426
- "split": split,
427
- "limit": limit,
428
- "total_examples": total_all,
429
- "valid_examples": total_valid,
430
- "exact_match_rate": em_count / total_valid,
431
- "exact_match_structural_rate": em_struct_count / total_valid,
432
- "execution_accuracy_rate": exec_acc_count / total_valid,
433
- "error_rate": error_count / total_valid,
434
- "safe_check_fail_rate": safe_fail_count / total_all,
435
- "avg_gen_time": avg_gen_time,
436
- "avg_exec_time": avg_exec_time,
437
- "run_id": start_ts,
438
- }
439
-
440
- metrics_fn = LOG_DIR / f"{split}_metrics_{start_ts}.json"
441
- with metrics_fn.open("w", encoding="utf-8") as fm:
442
- json.dump(metrics, fm, ensure_ascii=False, indent=2)
443
-
444
- print("Metrics:", metrics)
445
- print(f"Wrote results → {results_fn}")
446
- print(f"Wrote pred file → {pred_txt}")
447
- print(f"Wrote gold file → {gold_txt}")
448
- print(f"Wrote metrics → {metrics_fn}")
449
-
450
 
451
- if __name__ == "__main__":
452
- run_eval("dev", limit=10, resume=True, sleep_time=0.05)
 
1
+ """
2
+ Evaluate NL2SQL pipeline performance on Spider-like queries.
3
+ Uses config-driven Pipeline, native Safety checks, and per-stage latency tracing.
4
+ Outputs: JSONL (detailed logs), JSON (metrics summary), and CSV (for README).
5
+ """
6
 
7
  import json
8
+ import csv
9
  import time
10
  from pathlib import Path
11
+ from nl2sql.pipeline import Pipeline
12
+
13
+ # ---------- Config ----------
14
+ DATASET = [
15
+ "list all customers",
16
+ "show total invoices per country",
17
+ "top 3 albums by total sales",
18
+ "artists with more than 3 albums",
19
+ "number of employees per city",
20
+ ]
21
+
22
+ CONFIG_PATH = "configs/pipeline.yaml"
23
+ RESULT_DIR = Path("benchmarks/results")
24
+ RESULT_DIR.mkdir(parents=True, exist_ok=True)
25
+
26
+ # ---------- Initialize pipeline ----------
27
+ pipeline = Pipeline.from_config(CONFIG_PATH)
28
+ print(f"✅ Loaded pipeline from {CONFIG_PATH}")
29
+
30
+ # Optional: schema preview if adapter supports it
31
+ schema_preview = None
32
+ try:
33
+ adapter = getattr(pipeline, "executor", None)
34
+ if adapter and hasattr(adapter, "derive_schema_preview"):
35
+ schema_preview = adapter.derive_schema_preview()
36
+ print("📄 Derived schema preview successfully.")
37
+ except Exception as e:
38
+ print(f"⚠️ Could not derive schema preview: {e}")
39
+
40
+ # ---------- Evaluation ----------
41
+ records = []
42
+ for q in DATASET:
43
+ print(f"\n🧠 Query: {q}")
44
+ start = time.perf_counter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  try:
46
+ result = pipeline.run(user_query=q, schema_preview=schema_preview)
47
+ latency = int((time.perf_counter() - start) * 1000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ trace = getattr(result, "trace", None)
50
+ stages = []
51
+ if trace:
52
+ # trace might be list of StageTrace or dicts
53
+ try:
54
+ for t in trace:
55
+ stages.append(
56
+ {"stage": t.get("stage", "?"), "ms": t.get("duration_ms", 0)}
57
+ if isinstance(t, dict)
58
+ else {
59
+ "stage": getattr(t, "stage", "?"),
60
+ "ms": getattr(t, "duration_ms", 0),
61
+ }
62
+ )
63
+ except Exception:
64
+ pass
65
+
66
+ records.append(
67
+ {
68
+ "query": q,
69
+ "ok": True,
70
+ "latency_ms": latency,
71
+ "trace": stages,
72
+ "error": None,
73
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  )
75
+ print(f"✅ Success ({latency} ms)")
76
+ except Exception as e:
77
+ latency = int((time.perf_counter() - start) * 1000)
78
+ records.append(
79
+ {
80
+ "query": q,
81
+ "ok": False,
82
+ "latency_ms": latency,
83
+ "trace": [],
84
+ "error": str(e),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  }
86
+ )
87
+ print(f"❌ Failed: {e} ({latency} ms)")
88
+
89
+ # ---------- Aggregate metrics ----------
90
+ avg_latency = round(sum(r["latency_ms"] for r in records) / len(records), 1)
91
+ success_rate = sum(1 for r in records if r["ok"]) / len(records)
92
+ print(f"\n📊 Average latency: {avg_latency} ms | Success rate: {success_rate:.0%}")
93
+
94
+ summary = {
95
+ "queries_total": len(records),
96
+ "success_rate": success_rate,
97
+ "avg_latency_ms": avg_latency,
98
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
99
+ }
100
+
101
+ # ---------- Save outputs ----------
102
+ jsonl_path = RESULT_DIR / "spider_eval.jsonl"
103
+ with open(jsonl_path, "w", encoding="utf-8") as f:
104
+ for r in records:
105
+ json.dump(r, f, ensure_ascii=False)
106
+ f.write("\n")
107
+
108
+ summary_path = RESULT_DIR / "metrics_summary.json"
109
+ with open(summary_path, "w", encoding="utf-8") as f:
110
+ json.dump(summary, f, indent=2)
111
+
112
+ csv_path = RESULT_DIR / "results.csv"
113
+ with open(csv_path, "w", newline="", encoding="utf-8") as f:
114
+ writer = csv.DictWriter(f, fieldnames=["query", "ok", "latency_ms"])
115
+ writer.writeheader()
116
+ for r in records:
117
+ writer.writerow(
118
+ {
119
+ "query": r["query"],
120
+ "ok": "✅" if r["ok"] else "❌",
121
+ "latency_ms": r["latency_ms"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  }
123
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ print(f"\n💾 Saved logs to:\n- {jsonl_path}\n- {summary_path}\n- {csv_path}")