Melika Kheirieh committed on
Commit 454d146 · 1 Parent(s): 8103714

fix(grafana): move nl2sql.json into provisioning folder and fix dashboard mount path
benchmarks/evaluate_spider.py CHANGED
@@ -1,73 +1,93 @@
  from __future__ import annotations

  import csv
  import json
  import os
  import time
  from pathlib import Path
  from typing import Any, Dict, List, Optional

- # Reuse existing factories from your FastAPI router (no new DI needed)
- from app.routers.nl2sql import (  # type: ignore
-     _pipeline as DEFAULT_PIPELINE,
-     _build_pipeline,
-     _select_adapter,
- )

- # -------------------- Config --------------------

- DATASET: List[str] = [
      "list all customers",
      "show total invoices per country",
      "top 3 albums by total sales",
      "artists with more than 3 albums",
      "number of employees per city",
  ]

- # DB id/mode follows your router convention; adjust if needed
- DB_ID: str = os.getenv("DB_MODE", "sqlite")
-
- # Results directory with timestamped subfolder (keeps previous runs)
  RESULT_ROOT = Path("benchmarks") / "results"
  TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
  RESULT_DIR = RESULT_ROOT / TIMESTAMP


- # -------------------- Helpers --------------------
-
-
  def _int_ms(start: float) -> int:
      return int((time.perf_counter() - start) * 1000)


  def _derive_schema_preview_safe(pipeline_obj: Any) -> Optional[str]:
-     """
-     Try to derive schema preview from the adapter/executor if such a method exists.
-     Kept intentionally permissive to avoid tight coupling.
-     """
      try:
-         # common places the adapter might live
-         candidates: List[Any] = [
              getattr(pipeline_obj, "executor", None),
              getattr(pipeline_obj, "adapter", None),
          ]
          for c in candidates:
              if c and hasattr(c, "derive_schema_preview"):
-                 return c.derive_schema_preview()  # type: ignore[no-any-return, call-arg]
      except Exception:
          pass
      return None


  def _to_stage_list(trace_obj: Any) -> List[Dict[str, Any]]:
-     """
-     Normalize pipeline trace (list of dataclass or dict) to a list of dicts:
-     [{ "stage": str, "ms": int }, ...]
-     """
-     stages: List[Dict[str, Any]] = []
      if not isinstance(trace_obj, list):
-         return stages
-
      for t in trace_obj:
          if isinstance(t, dict):
              stage = t.get("stage", "?")
@@ -76,126 +96,303 @@ def _to_stage_list(trace_obj: Any) -> List[Dict[str, Any]]:
              stage = getattr(t, "stage", "?")
              ms = getattr(t, "duration_ms", 0)
          try:
-             stages.append({"stage": str(stage), "ms": int(ms)})
          except Exception:
-             stages.append({"stage": str(stage), "ms": 0})
-     return stages


- # -------------------- Main --------------------


- def main() -> None:
      RESULT_DIR.mkdir(parents=True, exist_ok=True)

-     # Build pipeline from router factories (no new DI required)
-     try:
-         adapter = _select_adapter(DB_ID)  # e.g., "sqlite" / "postgres"
-         pipeline = _build_pipeline(adapter)
-         using_default = False
-     except Exception:
-         pipeline = DEFAULT_PIPELINE
-         using_default = True

      print(
-         f"✅ Pipeline ready "
-         f"(db_id={DB_ID}, source={'default' if using_default else 'custom adapter'})"
      )

-     # Optional schema preview
      schema_preview = _derive_schema_preview_safe(pipeline)
      if schema_preview:
          print("📄 Derived schema preview ✓")
      else:
          print("ℹ️ No schema preview (adapter does not expose it or not needed)")

-     # Evaluate
-     records: List[Dict[str, Any]] = []
-     for q in DATASET:
          print(f"\n🧠 Query: {q}")
          t0 = time.perf_counter()
          try:
-             result = pipeline.run(
-                 user_query=q,
-                 schema_preview=schema_preview or "",  # <- force str
              )
-             latency_ms = _int_ms(t0)
-
-             # ok flag -> coerce to bool for mypy and consistency
-             ok_flag = bool(getattr(result, "ok", True))
-             stages = _to_stage_list(getattr(result, "trace", None))
-
-             rec: Dict[str, Any] = {
-                 "query": q,
-                 "ok": ok_flag,
-                 "latency_ms": latency_ms,
-                 "trace": stages,
-                 "error": None,
-             }
-             records.append(rec)
              print(f"✅ Success ({latency_ms} ms)")
          except Exception as exc:
-             latency_ms = _int_ms(t0)
-             rec = {
-                 "query": q,
-                 "ok": False,
-                 "latency_ms": latency_ms,
-                 "trace": [],
-                 "error": str(exc),
-             }
-             records.append(rec)
              print(f"❌ Failed: {exc!s} ({latency_ms} ms)")

-     # Aggregate metrics
-     avg_latency = (
-         round(sum(r["latency_ms"] for r in records) / max(len(records), 1), 1)
-         if records
-         else 0.0
-     )
-     success_rate = (
-         sum(1 for r in records if bool(r.get("ok"))) / max(len(records), 1)
-         if records
-         else 0.0
-     )
-
-     summary: Dict[str, Any] = {
-         "queries_total": len(records),
-         "success_rate": success_rate,
-         "avg_latency_ms": avg_latency,
-         "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
-         "db_id": DB_ID,
-         "pipeline_source": "default" if using_default else "adapter",
      }

-     # Persist outputs
-     jsonl_path = RESULT_DIR / "spider_eval.jsonl"
-     with jsonl_path.open("w", encoding="utf-8") as f:
-         for r in records:
-             json.dump(r, f, ensure_ascii=False)
-             f.write("\n")

-     summary_path = RESULT_DIR / "metrics_summary.json"
-     with summary_path.open("w", encoding="utf-8") as f:
-         json.dump(summary, f, indent=2)

-     csv_path = RESULT_DIR / "results.csv"
-     with csv_path.open("w", newline="", encoding="utf-8") as f:
-         writer = csv.DictWriter(f, fieldnames=["query", "ok", "latency_ms"])
-         writer.writeheader()
-         for r in records:
-             writer.writerow(
                  {
-                     "query": r["query"],
-                     "ok": "✅" if bool(r["ok"]) else "❌",
-                     "latency_ms": int(r["latency_ms"]),
                  }
              )

-     print(
-         "\n💾 Saved outputs:\n"
-         f"- {jsonl_path}\n- {summary_path}\n- {csv_path}\n"
-         f"📊 Avg latency: {avg_latency} ms | Success rate: {success_rate:.0%}"
      )


  if __name__ == "__main__":
+ """
+ Lightweight eval runner for two modes:
+   1) Single-DB demo mode (default): run a list of questions against one SQLite DB.
+   2) Spider mode (--spider): load a subset of the Spider dataset and run each question
+      against its own database (resolved via SPIDER_ROOT).
+
+ - Uses your official pipeline factory (no app/router imports).
+ - Works with a real LLM (OPENAI_API_KEY) or stub mode (PYTEST_CURRENT_TEST=1).
+ - Produces JSONL + JSON summary + CSV under benchmarks/results/<timestamp>/
+
+ Examples:
+   # Demo (single DB), stub mode
+   PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 \
+   python benchmarks/evaluate_spider.py --db-path demo.db
+
+   # Spider subset (20 items), stub mode
+   export SPIDER_ROOT=$PWD/data/spider
+   PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 \
+   python benchmarks/evaluate_spider.py --spider --split dev --limit 20
+
+ Notes:
+   - In stub mode, all LLM calls are mocked for offline evaluation.
+   - Results are saved under benchmarks/results/<timestamp>/.
+ """
+
  from __future__ import annotations

+ import argparse
  import csv
  import json
  import os
  import time
  from pathlib import Path
  from typing import Any, Dict, List, Optional
+ import sqlite3
+
+ from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
+ from adapters.db.sqlite_adapter import SQLiteAdapter

+ # Only needed in --spider mode
+ try:
+     from benchmarks.spider_loader import load_spider_sqlite, open_readonly_connection
+ except Exception:
+     load_spider_sqlite = None  # type: ignore[assignment]
+     open_readonly_connection = None  # type: ignore[assignment]

+ # Resolve repo root and default config path relative to this file (not CWD)
+ THIS_DIR = Path(__file__).resolve().parent  # .../benchmarks
+ REPO_ROOT = THIS_DIR.parent  # repo root
+ CONFIG_PATH = str(REPO_ROOT / "configs" / "sqlite_pipeline.yaml")

+ DEFAULT_DATASET: List[str] = [
      "list all customers",
      "show total invoices per country",
      "top 3 albums by total sales",
      "artists with more than 3 albums",
      "number of employees per city",
  ]
+ # Back-compat for tests: monkeypatchable dataset at module level
+ DATASET: List[str] = list(DEFAULT_DATASET)

  RESULT_ROOT = Path("benchmarks") / "results"
  TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
  RESULT_DIR = RESULT_ROOT / TIMESTAMP


  def _int_ms(start: float) -> int:
+     """Convert elapsed seconds to integer milliseconds."""
      return int((time.perf_counter() - start) * 1000)


  def _derive_schema_preview_safe(pipeline_obj: Any) -> Optional[str]:
+     """Safely call derive_schema_preview() if available on the adapter/executor."""
      try:
+         candidates = [
              getattr(pipeline_obj, "executor", None),
              getattr(pipeline_obj, "adapter", None),
          ]
          for c in candidates:
              if c and hasattr(c, "derive_schema_preview"):
+                 return c.derive_schema_preview()  # type: ignore[no-any-return]
      except Exception:
          pass
      return None


  def _to_stage_list(trace_obj: Any) -> List[Dict[str, Any]]:
+     """Normalize pipeline trace into a list of dicts for logging/CSV export."""
+     out: List[Dict[str, Any]] = []
      if not isinstance(trace_obj, list):
+         return out
      for t in trace_obj:
          if isinstance(t, dict):
              stage = t.get("stage", "?")
…
              stage = getattr(t, "stage", "?")
              ms = getattr(t, "duration_ms", 0)
          try:
+             out.append({"stage": str(stage), "ms": int(ms)})
          except Exception:
+             out.append({"stage": str(stage), "ms": 0})
+     return out
+
+
+ def _load_dataset_from_file(path: Optional[str]) -> List[str]:
+     """
+     Load dataset questions.
+     Accepts either a list of strings or a list of {"question": "..."} objects.
+     """
+     if not path:
+         # Use module-level DATASET so tests can monkeypatch it
+         return list(DATASET)
+
+     p = Path(path)
+     if not p.exists():
+         raise FileNotFoundError(f"dataset file not found: {p}")
+     data = json.loads(p.read_text(encoding="utf-8"))
+     if isinstance(data, list):
+         if all(isinstance(x, str) for x in data):
+             return list(data)
+         if all(isinstance(x, dict) and "question" in x for x in data):
+             return [str(x["question"]) for x in data]
+     raise ValueError(
+         "Dataset file must be a JSON array of strings or objects with a 'question' field."
+     )


+ def _ensure_demo_db(db_path: Path) -> None:
+     """Create an empty SQLite DB for demo runs if it doesn't exist."""
+     if db_path.exists():
+         return
+     db_path.parent.mkdir(parents=True, exist_ok=True)
+     conn = sqlite3.connect(str(db_path))
+     try:
+         # Keep it minimal; SELECT 1 works without any tables.
+         conn.execute("SELECT 1;")
+     finally:
+         conn.close()


+ def _save_outputs(rows: List[Dict[str, Any]], meta: Dict[str, Any]) -> None:
+     """Persist JSONL + JSON summary + CSV (write both new and legacy filenames)."""
      RESULT_DIR.mkdir(parents=True, exist_ok=True)

+     # Filenames (new + legacy for back-compat with tests)
+     jsonl_path = RESULT_DIR / "eval.jsonl"
+     summary_path = RESULT_DIR / "summary.json"
+     csv_path = RESULT_DIR / "results.csv"
+
+     jsonl_path_legacy = RESULT_DIR / "spider_eval.jsonl"
+     summary_path_legacy = RESULT_DIR / "metrics_summary.json"
+
+     # --- Write JSONL (both names) ---
+     with jsonl_path.open("w", encoding="utf-8") as f:
+         for r in rows:
+             json.dump(r, f, ensure_ascii=False)
+             f.write("\n")
+     # duplicate for the legacy name
+     with jsonl_path_legacy.open("w", encoding="utf-8") as f:
+         for r in rows:
+             json.dump(r, f, ensure_ascii=False)
+             f.write("\n")
+
+     # --- Build summary dict ---
+     summary = {
+         # keep both for compatibility with old tests/consumers
+         "queries_total": len(rows),
+         "total": len(rows),
+         "pipeline_source": meta.get(
+             "pipeline_source", "adapter"
+         ),  # for backward-compat with tests
+         "success_rate": (sum(1 for r in rows if r.get("ok")) / max(len(rows), 1))
+         if rows
+         else 0.0,
+         "avg_latency_ms": (
+             round(sum(int(r.get("latency_ms", 0)) for r in rows) / max(len(rows), 1), 1)
+         )
+         if rows
+         else 0.0,
+         **meta,
+         "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
+     }
+
+     # --- Write summary (both names) ---
+     with summary_path.open("w", encoding="utf-8") as f:
+         json.dump(summary, f, indent=2)
+     with summary_path_legacy.open("w", encoding="utf-8") as f:
+         json.dump(summary, f, indent=2)
+
+     # --- Write CSV (single name) ---
+     with csv_path.open("w", newline="", encoding="utf-8") as f:
+         writer = csv.DictWriter(f, fieldnames=["query", "ok", "latency_ms"])
+         writer.writeheader()
+         for r in rows:
+             writer.writerow(
+                 {
+                     "query": r.get("query", ""),
+                     "ok": "✅" if r.get("ok") else "❌",
+                     "latency_ms": int(r.get("latency_ms", 0)),
+                 }
+             )

      print(
+         "\n💾 Saved outputs:\n"
+         f"- {jsonl_path} (and {jsonl_path_legacy})\n"
+         f"- {summary_path} (and {summary_path_legacy})\n"
+         f"- {csv_path}\n"
+         f"📊 Avg latency: {summary['avg_latency_ms']} ms | "
+         f"Success rate: {summary['success_rate']:.0%}\n"
      )

+
+ def _run_single_db_mode(db_path: Path, questions: List[str], config_path: str) -> None:
+     """Evaluate a list of questions against a single SQLite DB."""
+     adapter = SQLiteAdapter(str(db_path))
+     pipeline = pipeline_from_config_with_adapter(config_path, adapter=adapter)
+
      schema_preview = _derive_schema_preview_safe(pipeline)
      if schema_preview:
          print("📄 Derived schema preview ✓")
      else:
          print("ℹ️ No schema preview (adapter does not expose it or not needed)")

+     rows: List[Dict[str, Any]] = []
+     for q in questions:
          print(f"\n🧠 Query: {q}")
          t0 = time.perf_counter()
          try:
+             result = pipeline.run(user_query=q, schema_preview=schema_preview or "")
+             latency_ms = _int_ms(t0) or 1  # clamp to 1 ms for nicer CSV in stub mode
+             stages = _to_stage_list(
+                 getattr(result, "traces", getattr(result, "trace", []))
+             )
+             rows.append(
+                 {
+                     "source": "demo",
+                     "db_id": Path(db_path).stem,
+                     "query": q,
+                     "ok": bool(getattr(result, "ok", True)),
+                     "latency_ms": latency_ms,
+                     "trace": stages,
+                     "error": None,
+                 }
              )
              print(f"✅ Success ({latency_ms} ms)")
          except Exception as exc:
+             latency_ms = _int_ms(t0) or 1
+             rows.append(
+                 {
+                     "source": "demo",
+                     "db_id": Path(db_path).stem,
+                     "query": q,
+                     "ok": False,
+                     "latency_ms": latency_ms,
+                     "trace": [],
+                     "error": str(exc),
+                 }
+             )
              print(f"❌ Failed: {exc!s} ({latency_ms} ms)")

+     meta = {
+         "mode": "single-db",
+         "db_path": str(db_path),
+         "config": config_path,
+         "provider_hint": ("STUBS" if os.getenv("PYTEST_CURRENT_TEST") else "REAL"),
      }
+     _save_outputs(rows, meta)


+ def _run_spider_mode(split: str, limit: int, config_path: str) -> None:
+     """Evaluate a Spider subset. Each example points to its own DB under SPIDER_ROOT."""
+     if load_spider_sqlite is None or open_readonly_connection is None:
+         raise RuntimeError(
+             "Spider utilities are not available. Ensure benchmarks/spider_loader.py exists."
+         )

+     items = load_spider_sqlite(split=split, limit=limit)
+     print(f"🗂 Loaded {len(items)} Spider items (split={split}).")
+
+     rows: List[Dict[str, Any]] = []
+
+     for i, ex in enumerate(items, 1):
+         print(f"\n[{i}] {ex.db_id} :: {ex.question}")
+         adapter = SQLiteAdapter(ex.db_path)
+         pipeline = pipeline_from_config_with_adapter(config_path, adapter=adapter)
+
+         # derive schema per-DB (optional)
+         schema_preview = _derive_schema_preview_safe(pipeline)
+
+         t0 = time.perf_counter()
+         try:
+             result = pipeline.run(
+                 user_query=ex.question, schema_preview=schema_preview or ""
+             )
+             latency_ms = _int_ms(t0) or 1
+             stages = _to_stage_list(
+                 getattr(result, "traces", getattr(result, "trace", []))
+             )
+             rows.append(
+                 {
+                     "source": "spider",
+                     "db_id": ex.db_id,
+                     "query": ex.question,
+                     "ok": bool(getattr(result, "ok", True)),
+                     "latency_ms": latency_ms,
+                     "trace": stages,
+                     "error": None,
+                 }
+             )
+             print(f"✅ Success ({latency_ms} ms)")
+         except Exception as exc:
+             latency_ms = _int_ms(t0) or 1
+             rows.append(
                  {
+                     "source": "spider",
+                     "db_id": ex.db_id,
+                     "query": ex.question,
+                     "ok": False,
+                     "latency_ms": latency_ms,
+                     "trace": [],
+                     "error": str(exc),
                  }
              )
+             print(f"❌ Failed: {exc!s} ({latency_ms} ms)")

+     meta = {
+         "mode": "spider",
+         "split": split,
+         "limit": limit,
+         "config": config_path,
+         "provider_hint": ("STUBS" if os.getenv("PYTEST_CURRENT_TEST") else "REAL"),
+         "spider_root": os.getenv("SPIDER_ROOT", ""),
+     }
+     _save_outputs(rows, meta)
+
+
+ def main() -> None:
+     ap = argparse.ArgumentParser()
+     ap.add_argument(
+         "--spider",
+         action="store_true",
+         help="Enable Spider mode (reads from SPIDER_ROOT; ignores --db-path).",
      )
+     ap.add_argument(
+         "--split",
+         type=str,
+         default="dev",
+         choices=["dev", "train"],
+         help="Spider split to use (default: dev).",
+     )
+     ap.add_argument(
+         "--limit",
+         type=int,
+         default=20,
+         help="Number of Spider items to evaluate (default: 20).",
+     )
+
+     ap.add_argument(
+         "--db-path",
+         type=str,
+         default="demo.db",
+         help="Path to the SQLite database file (single-DB mode).",
+     )
+     ap.add_argument(
+         "--dataset-file",
+         type=str,
+         default=None,
+         help="Optional JSON file with questions (single-DB mode).",
+     )
+     ap.add_argument(
+         "--config",
+         type=str,
+         default=CONFIG_PATH,
+         help=f"Pipeline YAML config (default: {CONFIG_PATH})",
+     )
+     args, _unknown = ap.parse_known_args()
+
+     if args.spider:
+         # Spider mode: read items from SPIDER_ROOT and evaluate per-DB
+         if not os.getenv("SPIDER_ROOT"):
+             raise RuntimeError(
+                 "SPIDER_ROOT is not set. It must point to the folder that contains "
+                 "dev.json/train_spider.json and the database/ directory."
+             )
+         _run_spider_mode(args.split, args.limit, args.config)
+     else:
+         # Single-DB demo mode
+         db_path = Path(args.db_path).resolve()
+         # Auto-create the demo DB for test/smoke runs; otherwise keep the strict check
+         if db_path.name == "demo.db":
+             _ensure_demo_db(db_path)
+         elif not db_path.exists():
+             raise FileNotFoundError(f"SQLite DB not found: {db_path}")
+         questions = _load_dataset_from_file(args.dataset_file)
+         _run_single_db_mode(db_path, questions, args.config)


  if __name__ == "__main__":
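
A quick note on the dataset-file contract added above: `_load_dataset_from_file` accepts either a JSON array of strings or an array of objects with a "question" field. A minimal sketch of both shapes (the `questions.json` path is purely illustrative, and it assumes the repo root is on PYTHONPATH so `benchmarks.evaluate_spider` and its dependencies import cleanly):

    import json
    from pathlib import Path

    from benchmarks.evaluate_spider import _load_dataset_from_file

    # Shape 1: a plain JSON array of strings ("questions.json" is a hypothetical path)
    Path("questions.json").write_text(
        json.dumps(["list all customers", "number of employees per city"]),
        encoding="utf-8",
    )
    print(_load_dataset_from_file("questions.json"))
    # ['list all customers', 'number of employees per city']

    # Shape 2: an array of {"question": ...} objects
    Path("questions.json").write_text(
        json.dumps([{"question": "top 3 albums by total sales"}]), encoding="utf-8"
    )
    print(_load_dataset_from_file("questions.json"))
    # ['top 3 albums by total sales']

Passing the file through `--dataset-file questions.json` in single-DB mode would then evaluate exactly those questions instead of the module-level DATASET.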
benchmarks/evaluate_spider_pro.py CHANGED
@@ -1,18 +1,38 @@
  """
- Full benchmark for NL2SQL pipeline.

- Metrics:
- - EM (exact match)
- - Structural Match (sqlglot AST)
- - Execution Accuracy
- - Safety consistency (pipeline vs AST)
- - Latency (end-to-end) + per-stage trace (via pipeline if available)

- Outputs:
-   JSONL (logs), JSON (summary), CSV (compact table)

- Run example:
-   python benchmarks/evaluate_spider_pro.py --limit 10 --sleep 0.1 --db sqlite --adapter data/chinook.db
  """

  from __future__ import annotations
@@ -20,71 +40,71 @@ from __future__ import annotations
  import argparse
  import csv
  import json
- import sqlite3
  import time
  from pathlib import Path
- from typing import Any, Dict, List, Optional, cast

  import sqlglot
  from sqlglot.errors import ParseError

- # Reuse existing factories from FastAPI router (no new DI needed)
- from app.routers.nl2sql import (  # type: ignore
-     _pipeline as DEFAULT_PIPELINE,
-     _build_pipeline,
-     _select_adapter,
- )
- from nl2sql.safety import Safety


- # -------------------- Helpers --------------------


- def _int_ms(start: float) -> int:
-     return int((time.perf_counter() - start) * 1000)

-
- def _parse_sql(sql: str) -> Optional[sqlglot.Expression]:
-     try:
-         return sqlglot.parse_one(sql, read="sqlite")
-     except ParseError:
-         return None


- def _is_structural_match(sql1: str, sql2: str) -> bool:
-     a, b = _parse_sql(sql1), _parse_sql(sql2)
-     return (a == b) if (a is not None and b is not None) else False


- def _exec_sql(conn: sqlite3.Connection, sql: str) -> List[tuple]:
-     try:
-         cur = conn.execute(sql)
-         return [tuple(r) for r in cur.fetchall()]
-     except Exception:
-         return []


  def _derive_schema_preview_safe(pipeline_obj: Any) -> Optional[str]:
-     for attr in ("executor", "adapter"):
-         obj = getattr(pipeline_obj, attr, None)
-         if obj is not None and hasattr(obj, "derive_schema_preview"):
-             try:
-                 # type: ignore[no-any-return]
-                 return obj.derive_schema_preview()  # pragma: no cover
-             except Exception:
-                 pass
      return None


  def _to_stage_list(trace_obj: Any) -> List[Dict[str, Any]]:
-     """
-     Normalize pipeline trace (list of dataclass or dict) to:
-     [{'stage': str, 'ms': int}, ...]
-     """
-     stages: List[Dict[str, Any]] = []
      if not isinstance(trace_obj, list):
-         return stages
-
      for t in trace_obj:
          if isinstance(t, dict):
              stage = t.get("stage", "?")
@@ -93,216 +113,377 @@ def _to_stage_list(trace_obj: Any) -> List[Dict[str, Any]]:
              stage = getattr(t, "stage", "?")
              ms = getattr(t, "duration_ms", 0)
          try:
-             stages.append({"stage": str(stage), "ms": int(ms)})
          except Exception:
-             stages.append({"stage": str(stage), "ms": 0})
-     return stages


- # -------------------- Main --------------------


- def main() -> None:
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--limit", type=int, default=10, help="Max number of examples")
-     parser.add_argument("--resume", type=int, default=0, help="Skip first N examples")
-     parser.add_argument(
-         "--sleep", type=float, default=0.0, help="Delay (seconds) between queries"
      )
-     parser.add_argument(
-         "--split", type=str, default="test", help="Dataset split (placeholder)"
      )
-     parser.add_argument(
-         "--db", type=str, default="sqlite", help="Database ID (e.g., sqlite/postgres)"
      )
-     parser.add_argument(
-         "--adapter",
-         type=str,
-         default="data/chinook.db",
-         help="SQLite file path for local eval",
      )
-     args = parser.parse_args()

-     # SQLite connection for execution-accuracy
-     conn = sqlite3.connect(args.adapter)

-     # Build pipeline from router factories
-     try:
-         adapter = _select_adapter(args.db)
-         pipeline = _build_pipeline(adapter)
-         using_default = False
-     except Exception:
-         pipeline = DEFAULT_PIPELINE
-         using_default = True

-     safety = Safety()
-     schema_preview = _derive_schema_preview_safe(pipeline)
-     print(f"✅ Pipeline ready (db={args.db}, default={using_default})")
-
-     # Minimal sample dataset for demonstration; replace with real Spider subset if available
-     DATASET: List[Dict[str, Any]] = [
-         {
-             "id": 1,
-             "question": "list all customers",
-             "gold_sql": "SELECT * FROM customers;",
-         },
-         {
-             "id": 2,
-             "question": "top 3 albums by total sales",
-             "gold_sql": """
-                 SELECT a.Title, SUM(i.Quantity * i.UnitPrice) AS total
-                 FROM albums a
-                 JOIN tracks t ON a.AlbumId = t.AlbumId
-                 JOIN invoice_items i ON t.TrackId = i.TrackId
-                 GROUP BY a.AlbumId
-                 ORDER BY total DESC
-                 LIMIT 3;
-             """,
-         },
-         {
-             "id": 3,
-             "question": "number of employees per city",
-             "gold_sql": """
-                 SELECT City, COUNT(*) AS cnt
-                 FROM employees
-                 GROUP BY City
-                 ORDER BY cnt DESC;
-             """,
-         },
-     ]

-     sliced = DATASET[args.resume : args.resume + args.limit]

-     # Eval loop
-     results: List[Dict[str, Any]] = []
-     for idx, ex in enumerate(sliced, start=1):
-         qid = cast(int, ex.get("id", idx))
-         q: str = cast(str, ex.get("question", ""))
-         gold_sql: str = cast(str, ex.get("gold_sql", "")).strip()
-         print(f"\n[{idx}] {q}")

          t0 = time.perf_counter()
          try:
-             out = pipeline.run(user_query=q, schema_preview=(schema_preview or ""))  # type: ignore[misc]
-             latency = _int_ms(t0)
-
-             # Safely extract predicted SQL:
-             sql_pred_obj = getattr(out, "sql", None)
-             if sql_pred_obj is None:
-                 data_obj = getattr(out, "data", None)
-                 if data_obj is not None:
-                     sql_pred_obj = getattr(data_obj, "sql", None)
-
-             sql_pred: str = str(sql_pred_obj) if sql_pred_obj is not None else ""
-             if not sql_pred.strip():
-                 raise ValueError("No SQL generated")
-
-             # Metrics
-             em = sql_pred.strip().lower() == gold_sql.strip().lower()
-             sm = _is_structural_match(sql_pred, gold_sql)
-
-             safe_ast = safety.check(sql_pred)  # pipeline has its own safety as well
-             safe_pipeline = bool(getattr(out, "ok", True))
-             safety_consistent = safe_ast.ok == safe_pipeline
-
-             gold_exec = _exec_sql(conn, gold_sql)
-             pred_exec = _exec_sql(conn, sql_pred)
-             exec_acc = gold_exec == pred_exec

-             stages = _to_stage_list(getattr(out, "trace", None))

-             results.append(
                  {
-                     "id": qid,
-                     "question": q,
                      "sql_pred": sql_pred,
                      "sql_gold": gold_sql,
                      "em": em,
                      "sm": sm,
                      "exec_acc": exec_acc,
-                     "safety_consistent": safety_consistent,
-                     "latency_ms": latency,
                      "trace": stages,
                      "error": None,
                  }
              )
-             print(f"✅ OK | EM={em} | SM={sm} | Exec={exec_acc} | {latency} ms")
-
-         except Exception as e:
-             latency = _int_ms(t0)
-             results.append(
                  {
-                     "id": qid,
-                     "question": q,
                      "sql_pred": None,
-                     "sql_gold": gold_sql,
                      "em": False,
                      "sm": False,
                      "exec_acc": False,
-                     "safety_consistent": None,
-                     "latency_ms": latency,
                      "trace": [],
-                     "error": str(e),
                  }
              )
-             print(f"❌ Fail ({latency} ms): {e}")
-         time.sleep(args.sleep)
-
-     # Summary
-     total = len(results)
-     avg_latency = round(sum(r["latency_ms"] for r in results) / max(total, 1), 1)
-     em_rate = (sum(1 for r in results if r["em"]) / max(total, 1)) if total else 0.0
-     sm_rate = (sum(1 for r in results if r["sm"]) / max(total, 1)) if total else 0.0
-     exec_acc_rate = (
-         (sum(1 for r in results if r["exec_acc"]) / max(total, 1)) if total else 0.0
      )

-     summary: Dict[str, Any] = {
          "total": total,
          "avg_latency_ms": avg_latency,
-         "EM": em_rate,
-         "SM": sm_rate,
-         "ExecAcc": exec_acc_rate,
          "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
-         "db": args.db,
-         "using_default_pipeline": using_default,
      }

-     # Persist outputs (timestamped dir)
-     out_dir = Path("benchmarks") / "results_pro" / time.strftime("%Y%m%d-%H%M%S")
-     out_dir.mkdir(parents=True, exist_ok=True)

-     jsonl_path = out_dir / "spider_eval_pro.jsonl"
-     with jsonl_path.open("w", encoding="utf-8") as f:
-         for r in results:
-             json.dump(r, f, ensure_ascii=False)
-             f.write("\n")

-     json_path = out_dir / "summary.json"
-     with json_path.open("w", encoding="utf-8") as f:
-         json.dump(summary, f, indent=2)

-     csv_path = out_dir / "summary.csv"
-     with csv_path.open("w", newline="", encoding="utf-8") as f:
-         writer = csv.DictWriter(
-             f,
-             fieldnames=["id", "question", "em", "sm", "exec_acc", "latency_ms"],
-         )
-         writer.writeheader()
-         for r in results:
-             writer.writerow(
-                 {
-                     "id": r["id"],
-                     "question": r["question"],
-                     "em": "✅" if r["em"] else "❌",
-                     "sm": "✅" if r["sm"] else "❌",
-                     "exec_acc": "✅" if r["exec_acc"] else "❌",
-                     "latency_ms": r["latency_ms"],
-                 }
-             )

-     print("\n📊 Summary:", json.dumps(summary, indent=2))
-     print(f"💾 Saved to:\n- {jsonl_path}\n- {json_path}\n- {csv_path}")


  if __name__ == "__main__":
 
  """
+ Pro evaluation runner: an extension of `evaluate_spider.py` with additional metrics
+ (EM, SM, ExecAcc) and richer logging for research-style benchmarking. Two modes:

+ 1) Single-DB demo mode (default)
+    - Runs a list of questions against one SQLite DB
+    - Reports latency/ok (no EM/SM/ExecAcc because there's no gold SQL)

+ 2) Spider mode (--spider)
+    - Loads a subset of the Spider dataset via SPIDER_ROOT
+    - For each item, builds a per-DB pipeline and computes:
+        * EM (exact SQL string match, case-insensitive)
+        * SM (structural match via sqlglot AST)
+        * ExecAcc (result equivalence by executing gold vs. predicted SQL)
+    - Also logs latency, (optional) traces, and aggregates a summary
+
+ Works with:
+   - a real LLM (OPENAI_API_KEY set)
+   - stub mode (PYTEST_CURRENT_TEST=1) for zero-cost offline runs

+ Outputs:
+   benchmarks/results_pro/<timestamp>/
+     - eval.jsonl    # per-sample rows
+     - summary.json  # aggregate metrics
+     - results.csv   # human-friendly table
+
+ Examples:
+   # Demo (single DB), stub mode
+   PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 \
+   python benchmarks/evaluate_spider_pro.py --db-path demo.db
+
+   # Spider subset (20 items), stub mode
+   export SPIDER_ROOT=$PWD/data/spider
+   PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 \
+   python benchmarks/evaluate_spider_pro.py --spider --split dev --limit 20
  """

  from __future__ import annotations

  import argparse
  import csv
  import json
+ import os
  import time
  from pathlib import Path
+ from typing import Any, Dict, List, Optional

  import sqlglot
  from sqlglot.errors import ParseError

+ from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
+ from adapters.db.sqlite_adapter import SQLiteAdapter

+ # Only needed for Spider mode
+ try:
+     from benchmarks.spider_loader import load_spider_sqlite, open_readonly_connection
+ except Exception:
+     load_spider_sqlite = None  # type: ignore[assignment]
+     open_readonly_connection = None  # type: ignore[assignment]

+ # Resolve repo root and default config path relative to this file (not CWD)
+ THIS_DIR = Path(__file__).resolve().parent  # .../benchmarks
+ REPO_ROOT = THIS_DIR.parent  # repo root
+ CONFIG_PATH = str(REPO_ROOT / "configs" / "sqlite_pipeline.yaml")


+ # Default demo questions for single-DB mode
+ DEFAULT_DATASET: List[str] = [
+     "list all customers",
+     "show total invoices per country",
+     "top 3 albums by total sales",
+     "artists with more than 3 albums",
+     "number of employees per city",
+ ]

+ RESULT_ROOT = Path("benchmarks") / "results_pro"
+ TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
+ RESULT_DIR = RESULT_ROOT / TIMESTAMP


+ # -------------------- Utilities --------------------


+ def _int_ms(start: float) -> int:
+     """Convert elapsed seconds to integer milliseconds."""
+     return int((time.perf_counter() - start) * 1000)


  def _derive_schema_preview_safe(pipeline_obj: Any) -> Optional[str]:
+     """Safely call derive_schema_preview() if available on the adapter/executor."""
+     try:
+         for c in (
+             getattr(pipeline_obj, "executor", None),
+             getattr(pipeline_obj, "adapter", None),
+         ):
+             if c and hasattr(c, "derive_schema_preview"):
+                 return c.derive_schema_preview()  # type: ignore[no-any-return]
+     except Exception:
+         pass
      return None


  def _to_stage_list(trace_obj: Any) -> List[Dict[str, Any]]:
+     """Normalize pipeline trace into a list of dicts for logging/export."""
+     out: List[Dict[str, Any]] = []
      if not isinstance(trace_obj, list):
+         return out
      for t in trace_obj:
          if isinstance(t, dict):
              stage = t.get("stage", "?")
…
              stage = getattr(t, "stage", "?")
              ms = getattr(t, "duration_ms", 0)
          try:
+             out.append({"stage": str(stage), "ms": int(ms)})
          except Exception:
+             out.append({"stage": str(stage), "ms": 0})
+     return out


+ def _parse_sql(sql: str):
+     try:
+         return sqlglot.parse_one(sql, read="sqlite")
+     except ParseError:
+         return None


+ def _structural_match(pred: str, gold: str) -> bool:
+     """AST-level equality via sqlglot; returns False if either side can't be parsed."""
+     a, b = _parse_sql(pred), _parse_sql(gold)
+     return (a == b) if (a is not None and b is not None) else False
+
+
+ def _load_dataset_from_file(path: Optional[str]) -> List[str]:
+     """Load questions from a JSON file: list[str] or list[{question: str}]."""
+     if not path:
+         return DEFAULT_DATASET
+     p = Path(path)
+     if not p.exists():
+         raise FileNotFoundError(f"dataset file not found: {p}")
+     data = json.loads(p.read_text(encoding="utf-8"))
+     if isinstance(data, list):
+         if all(isinstance(x, str) for x in data):
+             return list(data)
+         if all(isinstance(x, dict) and "question" in x for x in data):
+             return [str(x["question"]) for x in data]
+     raise ValueError(
+         "Dataset file must be a JSON array of strings or objects with a 'question' field."
      )


+ def _extract_sql(result: Any) -> str:
+     """
+     Extract SQL from the pipeline result in a mypy-friendly way.
+     Supports both result.sql and result.data.sql shapes.
+     """
+     sql_pred: Optional[str] = getattr(result, "sql", None)
+     if not sql_pred:
+         data = getattr(result, "data", None)
+         if data is not None:
+             sql_pred = getattr(data, "sql", None)
+     return (sql_pred or "").strip()


+ def _save_outputs(rows: List[Dict[str, Any]], summary: Dict[str, Any]) -> None:
+     """Persist JSONL + JSON summary + CSV for the pro runner."""
+     RESULT_DIR.mkdir(parents=True, exist_ok=True)
+
+     jsonl_path = RESULT_DIR / "eval.jsonl"
+     with jsonl_path.open("w", encoding="utf-8") as f:
+         for r in rows:
+             f.write(json.dumps(r, ensure_ascii=False) + "\n")
+
+     with (RESULT_DIR / "summary.json").open("w", encoding="utf-8") as f:
+         json.dump(summary, f, indent=2)
+
+     csv_path = RESULT_DIR / "results.csv"
+     # For pro, include the pro columns when present (Spider mode)
+     fieldnames = [
+         "source",
+         "db_id",
+         "query",
+         "em",
+         "sm",
+         "exec_acc",
+         "ok",
+         "latency_ms",
+     ]
+     with csv_path.open("w", newline="", encoding="utf-8") as f:
+         wr = csv.DictWriter(f, fieldnames=fieldnames)
+         wr.writeheader()
+         for r in rows:
+             wr.writerow(
+                 {
+                     "source": r.get("source", "demo"),
+                     "db_id": r.get("db_id", ""),
+                     "query": r.get("query", ""),
+                     "em": ("✅" if r.get("em") else "❌") if "em" in r else "",
+                     "sm": ("✅" if r.get("sm") else "❌") if "sm" in r else "",
+                     "exec_acc": ("✅" if r.get("exec_acc") else "❌")
+                     if "exec_acc" in r
+                     else "",
+                     "ok": "✅" if r.get("ok") else "❌",
+                     "latency_ms": int(r.get("latency_ms", 0)),
+                 }
+             )
+
+     print(
+         "\n💾 Saved outputs:\n"
+         f"- {jsonl_path}\n- {RESULT_DIR / 'summary.json'}\n- {csv_path}\n"
+         f"📊 Avg latency: {summary.get('avg_latency_ms', 0.0)} ms "
+         f"| EM: {summary.get('EM', 0.0):.3f} "
+         f"| SM: {summary.get('SM', 0.0):.3f} "
+         f"| ExecAcc: {summary.get('ExecAcc', 0.0):.3f} "
+         f"| Success: {summary.get('success_rate', 0.0):.0%}\n"
      )


+ # -------------------- Runners --------------------


+ def _run_single_db_mode(db_path: Path, questions: List[str], config_path: str) -> None:
+     """
+     Single-DB demo mode.
+     Only latency/ok is reported (no EM/SM/ExecAcc, because we don't have gold SQL).
+     """
+     adapter = SQLiteAdapter(str(db_path))
+     pipeline = pipeline_from_config_with_adapter(config_path, adapter=adapter)
+
+     schema_preview = _derive_schema_preview_safe(pipeline)
+     if schema_preview:
+         print("📄 Derived schema preview ✓")
+     else:
+         print("ℹ️ No schema preview (adapter does not expose it or not needed)")
+
+     rows: List[Dict[str, Any]] = []
+     for q in questions:
+         print(f"\n🧠 Query: {q}")
+         t0 = time.perf_counter()
+         try:
+             result = pipeline.run(user_query=q, schema_preview=schema_preview or "")
+             latency_ms = _int_ms(t0) or 1  # clamp to 1 ms for nicer CSV in stub mode
+             stages = _to_stage_list(
+                 getattr(result, "traces", getattr(result, "trace", []))
+             )
+             rows.append(
+                 {
+                     "source": "demo",
+                     "db_id": Path(db_path).stem,
+                     "query": q,
+                     "ok": bool(getattr(result, "ok", True)),
+                     "latency_ms": latency_ms,
+                     "trace": stages,
+                     "error": None,
+                 }
+             )
+             print(f"✅ Success ({latency_ms} ms)")
+         except Exception as exc:
+             latency_ms = _int_ms(t0) or 1
+             rows.append(
+                 {
+                     "source": "demo",
+                     "db_id": Path(db_path).stem,
+                     "query": q,
+                     "ok": False,
+                     "latency_ms": latency_ms,
+                     "trace": [],
+                     "error": str(exc),
+                 }
+             )
+             print(f"❌ Failed: {exc!s} ({latency_ms} ms)")
+
+     success_rate = (
+         (sum(1 for r in rows if r.get("ok")) / max(len(rows), 1)) if rows else 0.0
      )
+     avg_latency = (
+         round(sum(int(r.get("latency_ms", 0)) for r in rows) / max(len(rows), 1), 1)
+         if rows
+         else 0.0
      )
+     summary = {
+         "mode": "single-db",
+         "db_path": str(db_path),
+         "config": config_path,
+         "provider_hint": ("STUBS" if os.getenv("PYTEST_CURRENT_TEST") else "REAL"),
+         "total": len(rows),
+         "EM": 0.0,
+         "SM": 0.0,
+         "ExecAcc": 0.0,  # not applicable in demo
+         "success_rate": success_rate,
+         "avg_latency_ms": avg_latency,
+         "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
+     }
+     _save_outputs(rows, summary)


+ def _run_spider_mode(split: str, limit: int, config_path: str) -> None:
+     """
+     Spider mode: compute EM/SM/ExecAcc with per-DB pipelines.
+     Requires SPIDER_ROOT pointing to a folder that contains dev.json/train_spider.json and database/.
+     """
+     if load_spider_sqlite is None or open_readonly_connection is None:
+         raise RuntimeError(
+             "Spider utilities are not available. Ensure benchmarks/spider_loader.py exists."
+         )

+     items = load_spider_sqlite(split=split, limit=limit)
+     print(f"🗂 Loaded {len(items)} Spider items (split={split}).")
+
+     rows: List[Dict[str, Any]] = []

+     for i, ex in enumerate(items, 1):
+         print(f"\n[{i}] {ex.db_id} :: {ex.question}")
+         adapter = SQLiteAdapter(ex.db_path)
+         pipeline = pipeline_from_config_with_adapter(config_path, adapter=adapter)

+         # Optional schema preview per DB
+         schema_preview = _derive_schema_preview_safe(pipeline)
+
+         # Open a read-only connection for ExecAcc computation
+         conn = open_readonly_connection(ex.db_path)

          t0 = time.perf_counter()
          try:
+             result = pipeline.run(
+                 user_query=ex.question, schema_preview=schema_preview or ""
+             )
+             latency_ms = _int_ms(t0) or 1
+             stages = _to_stage_list(
+                 getattr(result, "traces", getattr(result, "trace", []))
+             )
+
+             # Extract predicted SQL from the result (supports both .sql and .data.sql)
+             sql_pred = _extract_sql(result)
+
+             # Pro metrics
+             gold_sql = ex.gold_sql.strip()
+             em = (sql_pred.lower() == gold_sql.lower()) if sql_pred else False
+             sm = _structural_match(sql_pred, gold_sql) if sql_pred else False

+             try:
+                 gold_exec = conn.execute(gold_sql).fetchall()
+             except Exception:
+                 gold_exec = []
+             try:
+                 pred_exec = conn.execute(sql_pred).fetchall() if sql_pred else []
+             except Exception:
+                 pred_exec = []
+             exec_acc = gold_exec == pred_exec

+             rows.append(
                  {
+                     "source": "spider",
+                     "db_id": ex.db_id,
+                     "query": ex.question,
                      "sql_pred": sql_pred,
                      "sql_gold": gold_sql,
                      "em": em,
                      "sm": sm,
                      "exec_acc": exec_acc,
+                     "ok": bool(getattr(result, "ok", True)),
+                     "latency_ms": latency_ms,
                      "trace": stages,
                      "error": None,
                  }
              )
+             print(f"✅ OK | EM={em} | SM={sm} | Exec={exec_acc} | {latency_ms} ms")
+         except Exception as exc:
+             latency_ms = _int_ms(t0) or 1
+             rows.append(
                  {
+                     "source": "spider",
+                     "db_id": ex.db_id,
+                     "query": ex.question,
                      "sql_pred": None,
+                     "sql_gold": ex.gold_sql,
                      "em": False,
                      "sm": False,
                      "exec_acc": False,
+                     "ok": False,
+                     "latency_ms": latency_ms,
                      "trace": [],
+                     "error": str(exc),
                  }
              )
+             print(f"❌ Fail: {exc!s} ({latency_ms} ms)")
+         finally:
+             try:
+                 conn.close()
+             except Exception:
+                 pass
+
+     # Aggregate pro metrics
+     total = len(rows)
+     em_rate = (sum(1 for r in rows if r.get("em")) / max(total, 1)) if rows else 0.0
+     sm_rate = (sum(1 for r in rows if r.get("sm")) / max(total, 1)) if rows else 0.0
+     exec_rate = (
+         (sum(1 for r in rows if r.get("exec_acc")) / max(total, 1)) if rows else 0.0
+     )
+     success_rate = (
+         (sum(1 for r in rows if r.get("ok")) / max(total, 1)) if rows else 0.0
+     )
+     avg_latency = (
+         round(sum(int(r.get("latency_ms", 0)) for r in rows) / max(total, 1), 1)
+         if rows
+         else 0.0
      )

+     summary = {
+         "mode": "spider",
+         "split": split,
+         "limit": limit,
+         "config": config_path,
+         "provider_hint": ("STUBS" if os.getenv("PYTEST_CURRENT_TEST") else "REAL"),
+         "spider_root": os.getenv("SPIDER_ROOT", ""),
          "total": total,
+         "EM": round(em_rate, 3),
+         "SM": round(sm_rate, 3),
+         "ExecAcc": round(exec_rate, 3),
+         "success_rate": success_rate,
          "avg_latency_ms": avg_latency,
          "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
      }
+     _save_outputs(rows, summary)


+ # -------------------- CLI --------------------


+ def main() -> None:
+     ap = argparse.ArgumentParser()
+     ap.add_argument(
+         "--spider",
+         action="store_true",
+         help="Enable Spider mode (reads from SPIDER_ROOT; ignores --db-path).",
+     )
+     ap.add_argument(
+         "--split",
+         type=str,
+         default="dev",
+         choices=["dev", "train"],
+         help="Spider split to use (default: dev).",
+     )
+     ap.add_argument(
+         "--limit",
+         type=int,
+         default=20,
+         help="Number of Spider items to evaluate (default: 20).",
+     )

+     ap.add_argument(
+         "--db-path",
+         type=str,
+         default="demo.db",
+         help="Path to the SQLite database file (single-DB mode).",
+     )
+     ap.add_argument(
+         "--dataset-file",
+         type=str,
+         default=None,
+         help="Optional JSON file with questions (single-DB mode).",
+     )
+     ap.add_argument(
+         "--config",
+         type=str,
+         default=CONFIG_PATH,
+         help=f"Pipeline YAML config (default: {CONFIG_PATH})",
+     )
+     args = ap.parse_args()
+
+     if args.spider:
+         if not os.getenv("SPIDER_ROOT"):
+             raise RuntimeError(
+                 "SPIDER_ROOT is not set. It must point to the folder that directly contains "
+                 "dev.json/train_spider.json and the database/ directory."
+             )
+         _run_spider_mode(args.split, args.limit, args.config)
+     else:
+         db_path = Path(args.db_path).resolve()
+         if not db_path.exists():
+             raise FileNotFoundError(f"SQLite DB not found: {db_path}")
+         questions = _load_dataset_from_file(args.dataset_file)
+         _run_single_db_mode(db_path, questions, args.config)


  if __name__ == "__main__":
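
To make the EM/SM distinction above concrete: EM compares the raw SQL strings (lower-cased), while `_structural_match` compares sqlglot ASTs, so it tolerates whitespace and keyword-case differences but not different identifiers or shapes. A standalone sketch (the sample queries are invented for illustration):

    import sqlglot

    gold = "SELECT count(*) FROM singer"
    pred = "select   COUNT( * )   from singer"

    # EM: lower-cased string equality, as in _run_spider_mode; extra whitespace breaks it
    em = pred.strip().lower() == gold.strip().lower()

    # SM: AST equality, as in _structural_match; whitespace and keyword case don't matter
    sm = sqlglot.parse_one(pred, read="sqlite") == sqlglot.parse_one(gold, read="sqlite")

    print(em, sm)  # False True

ExecAcc is looser still: two syntactically different queries count as equivalent whenever they return identical result sets on the target database.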
benchmarks/results/20251108-110451/eval.jsonl ADDED
@@ -0,0 +1,20 @@
+ {"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age for all French singers?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "Show the name and the release year of the song by the youngest singer.", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What are the names and release years for all the songs of the youngest singer?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What are all distinct countries where singers above age 20 are from?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What are the different countries with singers above age 20?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "Show all countries and the number of singers in each country.", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "How many singers are from each country?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "List all song names by singers above the average age.", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What are all the song names by singers who are older than average?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "Show location and name for all stadiums with a capacity between 5000 and 10000.", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What are the locations and names of all stations with capacity between 5000 and 10000?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the maximum capacity and the average of all stadiums ?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the average and maximum capacities for all stadiums ?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the name and capacity for the stadium with highest average attendance?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the name and capacity for the stadium with the highest average attendance?", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
benchmarks/results/20251108-110451/results.csv ADDED
@@ -0,0 +1,21 @@
+ source,db_id,query,ok,latency_ms
+ spider,concert_singer,How many singers do we have?,✅,1
+ spider,concert_singer,What is the total number of singers?,✅,1
+ spider,concert_singer,"Show name, country, age for all singers ordered by age from the oldest to the youngest.",✅,1
+ spider,concert_singer,"What are the names, countries, and ages for every singer in descending order of age?",✅,1
+ spider,concert_singer,"What is the average, minimum, and maximum age of all singers from France?",✅,1
+ spider,concert_singer,"What is the average, minimum, and maximum age for all French singers?",✅,1
+ spider,concert_singer,Show the name and the release year of the song by the youngest singer.,✅,1
+ spider,concert_singer,What are the names and release years for all the songs of the youngest singer?,✅,1
+ spider,concert_singer,What are all distinct countries where singers above age 20 are from?,✅,1
+ spider,concert_singer,What are the different countries with singers above age 20?,✅,1
+ spider,concert_singer,Show all countries and the number of singers in each country.,✅,1
+ spider,concert_singer,How many singers are from each country?,✅,1
+ spider,concert_singer,List all song names by singers above the average age.,✅,1
+ spider,concert_singer,What are all the song names by singers who are older than average?,✅,1
+ spider,concert_singer,Show location and name for all stadiums with a capacity between 5000 and 10000.,✅,1
+ spider,concert_singer,What are the locations and names of all stations with capacity between 5000 and 10000?,✅,1
+ spider,concert_singer,What is the maximum capacity and the average of all stadiums ?,✅,1
+ spider,concert_singer,What is the average and maximum capacities for all stadiums ?,✅,1
+ spider,concert_singer,What is the name and capacity for the stadium with highest average attendance?,✅,1
+ spider,concert_singer,What is the name and capacity for the stadium with the highest average attendance?,✅,1
benchmarks/results/20251108-110451/summary.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "total": 20,
3
+ "success_rate": 1.0,
4
+ "avg_latency_ms": 1.0,
5
+ "mode": "spider",
6
+ "split": "dev",
7
+ "limit": 20,
8
+ "config": "configs/sqlite_pipeline.yaml",
9
+ "provider_hint": "STUBS",
10
+ "spider_root": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/data/spider",
11
+ "timestamp": "2025-11-08 11:04:51"
12
+ }
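Note: the fields in this summary can be re-derived from the per-query records in results.jsonl above. A minimal sketch, assuming the record layout shown in that file (the helper below is illustrative, not part of this repo):

import json
from pathlib import Path


def summarize(jsonl_path: str) -> dict:
    """Recompute total / success_rate / avg_latency_ms from a results JSONL."""
    lines = Path(jsonl_path).read_text(encoding="utf-8").splitlines()
    rows = [json.loads(ln) for ln in lines if ln.strip()]
    total = len(rows)
    ok = sum(1 for r in rows if r.get("ok"))
    avg_ms = sum(r.get("latency_ms", 0) for r in rows) / max(total, 1)
    return {
        "total": total,
        "success_rate": ok / max(total, 1),
        "avg_latency_ms": round(avg_ms, 1),
    }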
benchmarks/results_demo/20251108-111403/demo.jsonl ADDED
@@ -0,0 +1,5 @@
1
+ {"query": "list all customers", "ok": true, "latency_ms": 12, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
2
+ {"query": "show total invoices per country", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
3
+ {"query": "top 3 albums by total sales", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
4
+ {"query": "artists with more than 3 albums", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
5
+ {"query": "number of employees per city", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 0}, {"stage": "generator", "ms": 0}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
benchmarks/results_demo/20251108-111403/results.csv ADDED
@@ -0,0 +1,6 @@
1
+ query,ok,latency_ms
2
+ list all customers,✅,12
3
+ show total invoices per country,✅,1
4
+ top 3 albums by total sales,✅,1
5
+ artists with more than 3 albums,✅,1
6
+ number of employees per city,✅,1
benchmarks/results_demo/20251108-111403/summary.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "avg_latency_ms": 3.2,
3
+ "success_rate": 1.0,
4
+ "db_path": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/demo.db",
5
+ "config": "configs/sqlite_pipeline.yaml",
6
+ "provider_hint": "STUBS",
7
+ "timestamp": "2025-11-08 11:14:03"
8
+ }
benchmarks/results_pro/20251108-105442/spider_eval_pro.jsonl ADDED
@@ -0,0 +1,20 @@
1
+ {"id": 1, "db_id": "concert_singer", "question": "How many singers do we have?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT count(*) FROM singer", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
2
+ {"id": 2, "db_id": "concert_singer", "question": "What is the total number of singers?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT count(*) FROM singer", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
3
+ {"id": 3, "db_id": "concert_singer", "question": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "sql_pred": "SELECT 1;", "sql_gold": "SELECT name , country , age FROM singer ORDER BY age DESC", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
4
+ {"id": 4, "db_id": "concert_singer", "question": "What are the names, countries, and ages for every singer in descending order of age?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT name , country , age FROM singer ORDER BY age DESC", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
5
+ {"id": 5, "db_id": "concert_singer", "question": "What is the average, minimum, and maximum age of all singers from France?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
6
+ {"id": 6, "db_id": "concert_singer", "question": "What is the average, minimum, and maximum age for all French singers?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
7
+ {"id": 7, "db_id": "concert_singer", "question": "Show the name and the release year of the song by the youngest singer.", "sql_pred": "SELECT 1;", "sql_gold": "SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
8
+ {"id": 8, "db_id": "concert_singer", "question": "What are the names and release years for all the songs of the youngest singer?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
9
+ {"id": 9, "db_id": "concert_singer", "question": "What are all distinct countries where singers above age 20 are from?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT DISTINCT country FROM singer WHERE age > 20", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
10
+ {"id": 10, "db_id": "concert_singer", "question": "What are the different countries with singers above age 20?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT DISTINCT country FROM singer WHERE age > 20", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
11
+ {"id": 11, "db_id": "concert_singer", "question": "Show all countries and the number of singers in each country.", "sql_pred": "SELECT 1;", "sql_gold": "SELECT country , count(*) FROM singer GROUP BY country", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
12
+ {"id": 12, "db_id": "concert_singer", "question": "How many singers are from each country?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT country , count(*) FROM singer GROUP BY country", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
13
+ {"id": 13, "db_id": "concert_singer", "question": "List all song names by singers above the average age.", "sql_pred": "SELECT 1;", "sql_gold": "SELECT song_name FROM singer WHERE age > (SELECT avg(age) FROM singer)", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
14
+ {"id": 14, "db_id": "concert_singer", "question": "What are all the song names by singers who are older than average?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT song_name FROM singer WHERE age > (SELECT avg(age) FROM singer)", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
15
+ {"id": 15, "db_id": "concert_singer", "question": "Show location and name for all stadiums with a capacity between 5000 and 10000.", "sql_pred": "SELECT 1;", "sql_gold": "SELECT LOCATION , name FROM stadium WHERE capacity BETWEEN 5000 AND 10000", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
16
+ {"id": 16, "db_id": "concert_singer", "question": "What are the locations and names of all stations with capacity between 5000 and 10000?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT LOCATION , name FROM stadium WHERE capacity BETWEEN 5000 AND 10000", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
17
+ {"id": 17, "db_id": "concert_singer", "question": "What is the maximum capacity and the average of all stadiums ?", "sql_pred": "SELECT 1;", "sql_gold": "select max(capacity), average from stadium", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
18
+ {"id": 18, "db_id": "concert_singer", "question": "What is the average and maximum capacities for all stadiums ?", "sql_pred": "SELECT 1;", "sql_gold": "select avg(capacity) , max(capacity) from stadium", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
19
+ {"id": 19, "db_id": "concert_singer", "question": "What is the name and capacity for the stadium with highest average attendance?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT name , capacity FROM stadium ORDER BY average DESC LIMIT 1", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
20
+ {"id": 20, "db_id": "concert_singer", "question": "What is the name and capacity for the stadium with the highest average attendance?", "sql_pred": "SELECT 1;", "sql_gold": "SELECT name , capacity FROM stadium ORDER BY average DESC LIMIT 1", "em": false, "sm": false, "exec_acc": false, "latency_ms": 0, "error": null}
benchmarks/results_pro/20251108-105442/summary.csv ADDED
@@ -0,0 +1,21 @@
1
+ id,db_id,question,em,sm,exec_acc,latency_ms
2
+ 1,concert_singer,How many singers do we have?,❌,❌,❌,0
3
+ 2,concert_singer,What is the total number of singers?,❌,❌,❌,0
4
+ 3,concert_singer,"Show name, country, age for all singers ordered by age from the oldest to the youngest.",❌,❌,❌,0
5
+ 4,concert_singer,"What are the names, countries, and ages for every singer in descending order of age?",❌,❌,❌,0
6
+ 5,concert_singer,"What is the average, minimum, and maximum age of all singers from France?",❌,❌,❌,0
7
+ 6,concert_singer,"What is the average, minimum, and maximum age for all French singers?",❌,❌,❌,0
8
+ 7,concert_singer,Show the name and the release year of the song by the youngest singer.,❌,❌,❌,0
9
+ 8,concert_singer,What are the names and release years for all the songs of the youngest singer?,❌,❌,❌,0
10
+ 9,concert_singer,What are all distinct countries where singers above age 20 are from?,❌,❌,❌,0
11
+ 10,concert_singer,What are the different countries with singers above age 20?,❌,❌,❌,0
12
+ 11,concert_singer,Show all countries and the number of singers in each country.,❌,❌,❌,0
13
+ 12,concert_singer,How many singers are from each country?,❌,❌,❌,0
14
+ 13,concert_singer,List all song names by singers above the average age.,❌,❌,❌,0
15
+ 14,concert_singer,What are all the song names by singers who are older than average?,❌,❌,❌,0
16
+ 15,concert_singer,Show location and name for all stadiums with a capacity between 5000 and 10000.,❌,❌,❌,0
17
+ 16,concert_singer,What are the locations and names of all stations with capacity between 5000 and 10000?,❌,❌,❌,0
18
+ 17,concert_singer,What is the maximum capacity and the average of all stadiums ?,❌,❌,❌,0
19
+ 18,concert_singer,What is the average and maximum capacities for all stadiums ?,❌,❌,❌,0
20
+ 19,concert_singer,What is the name and capacity for the stadium with highest average attendance?,❌,❌,❌,0
21
+ 20,concert_singer,What is the name and capacity for the stadium with the highest average attendance?,❌,❌,❌,0
benchmarks/results_pro/20251108-105442/summary.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "total": 20,
3
+ "EM": 0.0,
4
+ "SM": 0.0,
5
+ "ExecAcc": 0.0,
6
+ "avg_latency_ms": 0.0,
7
+ "timestamp": "2025-11-08 10:54:42"
8
+ }
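Note: em/sm/exec_acc are presumably the usual Spider-style metrics (exact match, structure match, execution accuracy); with the stub prediction "SELECT 1;" all three are 0.0 by construction. A minimal sketch of an execution-accuracy check, assuming read-only SQLite access (illustrative; the evaluator in this repo may normalize results differently):

import sqlite3
from collections import Counter


def exec_match(db_path: str, sql_pred: str, sql_gold: str) -> bool:
    """Do predicted and gold SQL return the same multiset of rows?"""
    con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
    try:
        pred = con.execute(sql_pred).fetchall()
        gold = con.execute(sql_gold).fetchall()
    except sqlite3.Error:
        # a prediction that fails to execute cannot be execution-accurate
        return False
    finally:
        con.close()
    # order-insensitive comparison of result rows
    return Counter(pred) == Counter(gold)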
benchmarks/run.py DELETED
@@ -1,214 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import argparse
4
- import os
5
- import json
6
- import time
7
- from pathlib import Path
8
- from typing import Iterable, List, Dict, Any, Protocol, Tuple, Optional
9
-
10
- # ---- app imports
11
- from nl2sql.pipeline import Pipeline, FinalResult
12
- from nl2sql.ambiguity_detector import AmbiguityDetector
13
- from nl2sql.planner import Planner
14
- from nl2sql.generator import Generator
15
- from nl2sql.safety import Safety
16
- from nl2sql.executor import Executor
17
- from nl2sql.verifier import Verifier
18
- from nl2sql.repair import Repair
19
-
20
- # ---- adapters
21
- from adapters.db.sqlite_adapter import SQLiteAdapter
22
- from adapters.llm.openai_provider import OpenAIProvider
23
-
24
-
25
- # ---- LLM protocol (unifies OpenAIProvider and DummyLLM for mypy)
26
- class LLMProvider(Protocol):
27
- """Minimal interface required by Planner/Generator/Repair stages."""
28
-
29
- provider_id: str
30
-
31
- def plan(
32
- self, *, user_query: str, schema_preview: str
33
- ) -> Tuple[str, int, int, float]: ...
34
-
35
- def generate_sql(
36
- self,
37
- *,
38
- user_query: str,
39
- schema_preview: str,
40
- plan_text: str,
41
- clarify_answers: Optional[Any] = None,
42
- ) -> Tuple[str, str, int, int, float]: ...
43
-
44
- def repair(
45
- self, *, sql: str, error_msg: str, schema_preview: str
46
- ) -> Tuple[str, int, int, float]: ...
47
-
48
-
49
- # ---- fallback: Dummy LLM (so it runs without API keys)
50
- class DummyLLM:
51
- provider_id = "dummy-llm"
52
-
53
- def plan(
54
- self, *, user_query: str, schema_preview: str
55
- ) -> Tuple[str, int, int, float]:
56
- text = (
57
- f"- understand question: {user_query}\n"
58
- "- identify tables\n- join if needed\n- filter\n- order/limit"
59
- )
60
- return text, 0, 0, 0.0
61
-
62
- def generate_sql(
63
- self,
64
- *,
65
- user_query: str,
66
- schema_preview: str,
67
- plan_text: str,
68
- clarify_answers: Optional[Any] = None,
69
- ) -> Tuple[str, str, int, int, float]:
70
- # naive demo SQL (so pipeline flows end-to-end)
71
- sql = "SELECT 1 AS one;"
72
- rationale = "Demo SQL from DummyLLM"
73
- return sql, rationale, 0, 0, 0.0
74
-
75
- def repair(
76
- self, *, sql: str, error_msg: str, schema_preview: str
77
- ) -> Tuple[str, int, int, float]:
78
- return sql, 0, 0, 0.0
79
-
80
-
81
- def ensure_demo_db(path: Path) -> None:
82
- """Create a tiny SQLite db if missing, so executor has something to run."""
83
- if path.exists():
84
- return
85
- import sqlite3
86
-
87
- path.parent.mkdir(parents=True, exist_ok=True)
88
- con = sqlite3.connect(path)
89
- cur = con.cursor()
90
- cur.execute("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, spend REAL);")
91
- cur.executemany(
92
- "INSERT INTO users(id,name,spend) VALUES(?,?,?)",
93
- [(1, "Alice", 120.5), (2, "Bob", 80.0), (3, "Carol", 155.0)],
94
- )
95
- con.commit()
96
- con.close()
97
-
98
-
99
- def build_pipeline(db_path: Path, use_openai: bool) -> Pipeline:
100
- # DB adapter
101
- db = SQLiteAdapter(str(db_path))
102
- executor = Executor(db)
103
-
104
- # LLM provider (typed to the Protocol so mypy accepts either provider)
105
- llm: LLMProvider
106
- if use_openai and os.getenv("OPENAI_API_KEY"):
107
- llm = OpenAIProvider() # conforms to LLMProvider
108
- else:
109
- llm = DummyLLM() # conforms to LLMProvider
110
-
111
- # stages
112
- detector = AmbiguityDetector()
113
- planner = Planner(llm)
114
- generator = Generator(llm)
115
- safety = Safety()
116
- verifier = Verifier()
117
- repair = Repair(llm)
118
-
119
- # pipeline
120
- return Pipeline(
121
- detector=detector,
122
- planner=planner,
123
- generator=generator,
124
- safety=safety,
125
- executor=executor,
126
- verifier=verifier,
127
- repair=repair,
128
- )
129
-
130
-
131
- def _sum_cost(traces: Iterable[Dict[str, Any]]) -> float:
132
- total = 0.0
133
- for tr in traces:
134
- try:
135
- total += float(tr.get("cost_usd", 0.0))
136
- except Exception:
137
- # ignore bad values
138
- pass
139
- return total
140
-
141
-
142
- def _is_safe_fail(ok: bool, details: List[str] | None) -> float:
143
- """Return 1.0 when pipeline failed due to unsafe SQL (heuristic)."""
144
- if ok:
145
- return 0.0
146
- txt = " ".join(details or []).lower()
147
- return 1.0 if "unsafe" in txt else 0.0
148
-
149
-
150
- def run_benchmark(
151
- queries: List[str], schema_preview: str, pipeline: Pipeline, outfile: Path
152
- ) -> None:
153
- results: List[Dict[str, Any]] = []
154
- for q in queries:
155
- t0 = time.perf_counter()
156
- res: FinalResult = pipeline.run(user_query=q, schema_preview=schema_preview)
157
- latency_ms = (time.perf_counter() - t0) * 1000.0
158
-
159
- ok = (not res.ambiguous) and (not res.error) and bool(res.ok)
160
- traces = res.traces or []
161
- cost_sum = _sum_cost(traces)
162
-
163
- results.append(
164
- {
165
- "query": q,
166
- "exec_acc": 1.0 if ok else 0.0,
167
- "safe_fail": _is_safe_fail(ok, res.details),
168
- "latency_ms": latency_ms,
169
- "cost_usd": cost_sum,
170
- "repair_attempts": sum(1 for t in traces if t.get("stage") == "repair"),
171
- "provider": getattr(
172
- getattr(pipeline.generator, "llm", None), "provider_id", "unknown"
173
- ),
174
- }
175
- )
176
-
177
- outfile.parent.mkdir(parents=True, exist_ok=True)
178
- with open(outfile, "w") as f:
179
- for row in results:
180
- f.write(json.dumps(row) + "\n")
181
- print(f"[OK] wrote {len(results)} rows → {outfile}")
182
-
183
-
184
- def main() -> None:
185
- parser = argparse.ArgumentParser()
186
- parser.add_argument("--outfile", default="benchmarks/results/demo.jsonl")
187
- parser.add_argument("--db", default="data/bench_demo.db")
188
- parser.add_argument(
189
- "--use-openai",
190
- action="store_true",
191
- help="Use OpenAI provider if API key present",
192
- )
193
- args = parser.parse_args()
194
-
195
- root = Path(__file__).resolve().parents[1] # project root
196
- outfile = (root / args.outfile).resolve()
197
- db_path = (root / args.db).resolve()
198
-
199
- ensure_demo_db(db_path)
200
- pipe = build_pipeline(db_path, use_openai=args.use_openai)
201
-
202
- # a small demo set; replace with Spider when ready
203
- queries = [
204
- "show all users",
205
- "top spenders",
206
- "sum of spend",
207
- ]
208
- schema_preview = "CREATE TABLE users(id INT, name TEXT, spend REAL);"
209
-
210
- run_benchmark(queries, schema_preview, pipe, outfile)
211
-
212
-
213
- if __name__ == "__main__":
214
- main()
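Note: the removed runner typed its provider slot with a structural Protocol, so OpenAIProvider and DummyLLM were interchangeable without a shared base class. A minimal sketch of that pattern (names simplified; illustrative only):

from typing import Protocol, Tuple


class PlanProvider(Protocol):
    provider_id: str

    def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]: ...


class StubLLM:
    provider_id = "stub"

    def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]:
        return f"plan for: {user_query}", 0, 0, 0.0


llm: PlanProvider = StubLLM()  # accepted structurally by mypy; no inheritance needed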
benchmarks/spider_loader.py CHANGED
@@ -1,12 +1,11 @@
1
  from __future__ import annotations
 
2
  import json
3
- import pathlib
4
  import sqlite3
5
  from dataclasses import dataclass
 
6
  from typing import List, Optional
7
- import os
8
-
9
- SPIDER_ROOT = pathlib.Path(os.getenv("SPIDER_ROOT", "data/spider"))
10
 
11
 
12
  @dataclass
@@ -14,40 +13,150 @@ class SpiderItem:
14
  db_id: str
15
  question: str
16
  gold_sql: str
17
- db_path: pathlib.Path
18
 
19
 
20
  def load_spider_sqlite(
21
- split: str = "dev", limit: Optional[int] = None
22
  ) -> List[SpiderItem]:
23
- fn = {"dev": "dev.json", "train": "train_spider.json"}[split]
24
- json_path = SPIDER_ROOT / fn
25
  try:
26
  items = json.loads(json_path.read_text(encoding="utf-8"))
27
  except Exception as e:
28
  raise RuntimeError(f"Failed to read Spider split file: {json_path} ({e})")
29
 
30
- out: list[SpiderItem] = []
31
- for ex in items[: (limit or len(items))]:
32
- db_id = ex["db_id"]
33
- db_path = SPIDER_ROOT / "database" / db_id / f"{db_id}.sqlite"
34
- if not db_path.exists():
35
- raise FileNotFoundError(f"Missing SQLite DB for {db_id}: {db_path}")
36
  out.append(
37
- SpiderItem(
38
- db_id=db_id,
39
- question=ex["question"],
40
- gold_sql=ex["query"],
41
- db_path=db_path,
42
- )
43
  )
44
  return out
45
 
46
 
47
- def open_readonly_connection(
48
- db_path: pathlib.Path, timeout: float = 5.0
49
- ) -> sqlite3.Connection:
50
- uri = f"file:{db_path}?mode=ro&uri=true"
51
- conn = sqlite3.connect(uri, uri=True, timeout=timeout)
52
- conn.row_factory = sqlite3.Row
53
- return conn
 
1
  from __future__ import annotations
2
+
3
  import json
4
+ import os
5
  import sqlite3
6
  from dataclasses import dataclass
7
+ from pathlib import Path
8
  from typing import List, Optional
9
 
10
 
11
  @dataclass
 
13
  db_id: str
14
  question: str
15
  gold_sql: str
16
+ db_path: str # absolute path to the sqlite file
17
+
18
+
19
+ # ---------- helpers ----------
20
+
21
+
22
+ def _candidate_roots(env_root: Optional[str]) -> List[Path]:
23
+ """
24
+ Build a small list of candidate Spider roots to tolerate common layouts:
25
+ - $SPIDER_ROOT
26
+ - data/spider
27
+ - data/spider/spider (when the repo was cloned into data/spider/spider)
28
+ - <env>/spider (when SPIDER_ROOT points to the parent directory)
29
+ """
30
+ cands: List[Path] = []
31
+ if env_root:
32
+ p = Path(env_root).expanduser().resolve()
33
+ cands.append(p)
34
+ cands.append((p / "spider").resolve())
35
+ # project-local defaults
36
+ here = Path.cwd().resolve()
37
+ cands.append((here / "data" / "spider").resolve())
38
+ cands.append((here / "data" / "spider" / "spider").resolve())
39
+ # de-dup
40
+ seen, uniq = set(), []
41
+ for x in cands:
42
+ if str(x) not in seen:
43
+ uniq.append(x)
44
+ seen.add(str(x))
45
+ return uniq
46
+
47
+
48
+ def _resolve_split_json(root: Path, split: str) -> Path:
49
+ """
50
+ Map split name to file name and return full path under `root`.
51
+ Spider uses:
52
+ - dev.json
53
+ - train_spider.json
54
+ """
55
+ fname = "dev.json" if split == "dev" else "train_spider.json"
56
+ return (root / fname).resolve()
57
+
58
+
59
+ def _resolve_database_dir(root: Path) -> Path:
60
+ return (root / "database").resolve()
61
+
62
+
63
+ def _ensure_exists(path: Path, kind: str) -> None:
64
+ if not path.exists():
65
+ raise FileNotFoundError(f"{kind} not found: {path}")
66
+
67
+
68
+ # ---------- public API ----------
69
 
70
 
71
  def load_spider_sqlite(
72
+ *, split: str = "dev", limit: Optional[int] = None
73
  ) -> List[SpiderItem]:
74
+ """
75
+ Load a subset of Spider (dev/train) and attach absolute sqlite db paths.
76
+ Looks under:
77
+ - $SPIDER_ROOT (if set)
78
+ - ./data/spider
79
+ - ./data/spider/spider
80
+ - $SPIDER_ROOT/spider
81
+ """
82
+ env_root = os.getenv("SPIDER_ROOT")
83
+ roots = _candidate_roots(env_root)
84
+
85
+ # find a root that actually contains the split file & database/
86
+ json_path: Optional[Path] = None
87
+ database_dir: Optional[Path] = None
88
+ chosen_root: Optional[Path] = None
89
+
90
+ for r in roots:
91
+ jp = _resolve_split_json(r, split)
92
+ dbd = _resolve_database_dir(r)
93
+ if jp.exists() and dbd.exists():
94
+ json_path, database_dir, chosen_root = jp, dbd, r
95
+ break
96
+
97
+ if json_path is None or database_dir is None:
98
+ debug = "\n".join(
99
+ f"- {str(_resolve_split_json(r, split))} | {str(_resolve_database_dir(r))}"
100
+ for r in roots
101
+ )
102
+ raise RuntimeError(
103
+ "Failed to locate Spider dataset.\n"
104
+ f"Checked candidates for split='{split}':\n{debug}\n"
105
+ "Tip: export SPIDER_ROOT=/absolute/path/to/spider "
106
+ "(the folder that directly contains dev.json/train_spider.json and database/)"
107
+ )
108
+
109
+ # read split
110
  try:
111
  items = json.loads(json_path.read_text(encoding="utf-8"))
112
  except Exception as e:
113
  raise RuntimeError(f"Failed to read Spider split file: {json_path} ({e})")
114
 
115
+ # build rows with absolute sqlite path
116
+ out: List[SpiderItem] = []
117
+ for obj in items:
118
+ db_id: str = obj.get("db_id", "")
119
+ q: str = obj.get("question", "").strip()
120
+ gold: str = obj.get("query", obj.get("sql", "")).strip() # Spider uses 'query'
121
+ if not (db_id and q and gold):
122
+ continue
123
+
124
+ # <root>/database/<db_id>/<db_id>.sqlite
125
+ db_file = (database_dir / db_id / f"{db_id}.sqlite").resolve()
126
+ if not db_file.exists():
127
+ # some mirrors use .db; try a fallback
128
+ alt = (database_dir / db_id / f"{db_id}.db").resolve()
129
+ if alt.exists():
130
+ db_file = alt
131
+ else:
132
+ # skip if DB file missing
133
+ # (you could also raise here if you prefer strict behavior)
134
+ continue
135
+
136
  out.append(
137
+ SpiderItem(db_id=db_id, question=q, gold_sql=gold, db_path=str(db_file))
138
  )
139
+
140
+ if limit is not None and len(out) >= limit:
141
+ break
142
+
143
+ if not out:
144
+ raise RuntimeError(
145
+ f"No usable items from {json_path} (limit={limit}). "
146
+ "Check db files under database/<db_id>/<db_id>.sqlite"
147
+ )
148
+
149
+ # small info for sanity
150
+ print(
151
+ f"βœ” Spider root: {chosen_root}\n"
152
+ f"βœ” Split file: {json_path.name} ({len(out)} items)"
153
+ )
154
  return out
155
 
156
 
157
+ def open_readonly_connection(db_path: str) -> sqlite3.Connection:
158
+ """
159
+ Open SQLite in read-only mode (URI).
160
+ """
161
+ uri = f"file:{Path(db_path).resolve()}?mode=ro"
162
+ return sqlite3.connect(uri, uri=True, check_same_thread=False)
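Note: a quick usage sketch for the two public entry points above (the limit value and invocation are illustrative; SPIDER_ROOT must point at a local Spider checkout as the docstring describes, and the script should run with PYTHONPATH=$PWD):

from benchmarks.spider_loader import load_spider_sqlite, open_readonly_connection

items = load_spider_sqlite(split="dev", limit=5)
for item in items:
    conn = open_readonly_connection(item.db_path)
    try:
        # gold SQL is executable as-is against the bundled Spider databases
        rows = conn.execute(item.gold_sql).fetchall()
        print(item.db_id, "->", len(rows), "rows")
    finally:
        conn.close()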
 
scripts/smoke_run.py ADDED
@@ -0,0 +1,335 @@
1
+ """
2
+ Minimal smoke/demo runner for the NL2SQL pipeline.
3
+
4
+ - Builds the pipeline via the official factory (no app/router imports).
5
+ - Runs a small set of demo questions against a SQLite DB.
6
+ - Works in two modes:
7
+ * Stub mode (set PYTEST_CURRENT_TEST=1) → no API key needed.
8
+ * Real mode (set OPENAI_API_KEY=...) → uses actual LLM provider.
9
+
10
+ Outputs:
11
+ benchmarks/results_demo/<timestamp>/
12
+ - demo.jsonl # one JSON record per query
13
+ - summary.json # latency & success overview
14
+ - results.csv # compact table for quick inspection
15
+
16
+ Usage examples:
17
+ PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 \
18
+ python scripts/smoke_run.py --db-path demo.db
19
+
20
+ # With a custom dataset file (JSON: list[str] or list[{question: "..."}])
21
+ PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 \
22
+ python scripts/smoke_run.py --db-path demo.db --dataset-file benchmarks/demo.json
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import argparse
28
+ import csv
29
+ import json
30
+ import os
31
+ import time
32
+ from pathlib import Path
33
+ from typing import Any, Dict, List, Optional
34
+ import sqlite3
35
+
36
+ from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
37
+ from adapters.db.sqlite_adapter import SQLiteAdapter
38
+
39
+ CONFIG_PATH = "configs/sqlite_pipeline.yaml"
40
+ DEFAULT_QUESTIONS: List[str] = [
41
+ "list all customers",
42
+ "show total invoices per country",
43
+ "top 3 albums by total sales",
44
+ "artists with more than 3 albums",
45
+ "number of employees per city",
46
+ ]
47
+
48
+ RESULT_ROOT = Path("benchmarks") / "results_demo"
49
+ TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
50
+ RESULT_DIR = RESULT_ROOT / TIMESTAMP
51
+
52
+
53
+ def ensure_demo_db(db_path: Path) -> None:
54
+ """Create a tiny demo SQLite DB if it doesn't exist."""
55
+ if db_path.exists():
56
+ return
57
+ db_path.parent.mkdir(parents=True, exist_ok=True)
58
+ conn = sqlite3.connect(str(db_path))
59
+ cur = conn.cursor()
60
+
61
+ # Minimal schema that matches our default demo questions
62
+ cur.executescript("""
63
+ DROP TABLE IF EXISTS customers;
64
+ DROP TABLE IF EXISTS invoices;
65
+ DROP TABLE IF EXISTS employees;
66
+ DROP TABLE IF EXISTS artists;
67
+ DROP TABLE IF EXISTS albums;
68
+
69
+ CREATE TABLE customers (
70
+ id INTEGER PRIMARY KEY,
71
+ name TEXT,
72
+ country TEXT
73
+ );
74
+
75
+ CREATE TABLE invoices (
76
+ id INTEGER PRIMARY KEY,
77
+ customer_id INTEGER,
78
+ total REAL,
79
+ country TEXT,
80
+ FOREIGN KEY (customer_id) REFERENCES customers(id)
81
+ );
82
+
83
+ CREATE TABLE employees (
84
+ id INTEGER PRIMARY KEY,
85
+ name TEXT,
86
+ city TEXT
87
+ );
88
+
89
+ CREATE TABLE artists (
90
+ id INTEGER PRIMARY KEY,
91
+ name TEXT
92
+ );
93
+
94
+ CREATE TABLE albums (
95
+ id INTEGER PRIMARY KEY,
96
+ artist_id INTEGER,
97
+ title TEXT,
98
+ sales REAL DEFAULT 0,
99
+ FOREIGN KEY (artist_id) REFERENCES artists(id)
100
+ );
101
+ """)
102
+
103
+ # Seed a bit of data
104
+ cur.executemany(
105
+ "INSERT INTO customers (id, name, country) VALUES (?, ?, ?)",
106
+ [
107
+ (1, "Alice", "USA"),
108
+ (2, "Bob", "Germany"),
109
+ (3, "Carlos", "Brazil"),
110
+ (4, "Darya", "Iran"),
111
+ ],
112
+ )
113
+ cur.executemany(
114
+ "INSERT INTO invoices (id, customer_id, total, country) VALUES (?, ?, ?, ?)",
115
+ [
116
+ (1, 1, 120.5, "USA"),
117
+ (2, 2, 75.0, "Germany"),
118
+ (3, 1, 33.2, "USA"),
119
+ (4, 3, 48.0, "Brazil"),
120
+ (5, 4, 90.0, "Iran"),
121
+ ],
122
+ )
123
+ cur.executemany(
124
+ "INSERT INTO employees (id, name, city) VALUES (?, ?, ?)",
125
+ [
126
+ (1, "Eve", "New York"),
127
+ (2, "Frank", "Berlin"),
128
+ (3, "Gita", "Tehran"),
129
+ ],
130
+ )
131
+ cur.executemany(
132
+ "INSERT INTO artists (id, name) VALUES (?, ?)",
133
+ [
134
+ (1, "ABand"),
135
+ (2, "BGroup"),
136
+ (3, "CEnsemble"),
137
+ ],
138
+ )
139
+ cur.executemany(
140
+ "INSERT INTO albums (id, artist_id, title, sales) VALUES (?, ?, ?, ?)",
141
+ [
142
+ (1, 1, "First Light", 500.0),
143
+ (2, 1, "Second Wind", 300.0),
144
+ (3, 2, "Blue Lines", 900.0),
145
+ (4, 3, "Echoes", 150.0),
146
+ ],
147
+ )
148
+
149
+ conn.commit()
150
+ conn.close()
151
+
152
+
153
+ def _ms(start_s: float) -> int:
154
+ """Convert elapsed seconds to integer milliseconds."""
155
+ return int((time.perf_counter() - start_s) * 1000)
156
+
157
+
158
+ def _derive_schema_preview(pipeline_obj: Any) -> Optional[str]:
159
+ """Try to derive schema preview from adapter/executor if available."""
160
+ for attr in ("executor", "adapter"):
161
+ obj = getattr(pipeline_obj, attr, None)
162
+ if obj and hasattr(obj, "derive_schema_preview"):
163
+ try:
164
+ return obj.derive_schema_preview() # type: ignore[no-any-return]
165
+ except Exception:
166
+ pass
167
+ return None
168
+
169
+
170
+ def _normalize_trace(trace_obj: Any) -> List[Dict[str, Any]]:
171
+ """Convert trace to a list of {stage, ms} dicts for logging/export."""
172
+ out: List[Dict[str, Any]] = []
173
+ if not isinstance(trace_obj, list):
174
+ return out
175
+ for t in trace_obj:
176
+ if isinstance(t, dict):
177
+ stage = t.get("stage", "?")
178
+ ms = t.get("duration_ms", 0)
179
+ else:
180
+ stage = getattr(t, "stage", "?")
181
+ ms = getattr(t, "duration_ms", 0)
182
+ try:
183
+ out.append({"stage": str(stage), "ms": int(ms)})
184
+ except Exception:
185
+ out.append({"stage": str(stage), "ms": 0})
186
+ return out
187
+
188
+
189
+ def _load_questions(path: Optional[str]) -> List[str]:
190
+ """Load questions from a JSON file or return defaults."""
191
+ if not path:
192
+ return DEFAULT_QUESTIONS
193
+ p = Path(path)
194
+ if not p.exists():
195
+ raise FileNotFoundError(f"dataset file not found: {p}")
196
+ data = json.loads(p.read_text(encoding="utf-8"))
197
+ if isinstance(data, list):
198
+ if all(isinstance(x, str) for x in data):
199
+ return list(data)
200
+ if all(isinstance(x, dict) and "question" in x for x in data):
201
+ return [str(x["question"]) for x in data]
202
+ raise ValueError(
203
+ "Dataset must be a JSON array of strings or objects with a 'question' field."
204
+ )
205
+
206
+
207
+ def main() -> None:
208
+ ap = argparse.ArgumentParser()
209
+ ap.add_argument(
210
+ "--db-path",
211
+ type=str,
212
+ default="demo.db",
213
+ help="Path to SQLite DB (default: demo.db)",
214
+ )
215
+ ap.add_argument(
216
+ "--dataset-file",
217
+ type=str,
218
+ default=None,
219
+ help="Optional JSON file: list[str] or list[{question: str}]",
220
+ )
221
+ ap.add_argument(
222
+ "--config",
223
+ type=str,
224
+ default=CONFIG_PATH,
225
+ help=f"Pipeline YAML (default: {CONFIG_PATH})",
226
+ )
227
+ args = ap.parse_args()
228
+
229
+ RESULT_DIR.mkdir(parents=True, exist_ok=True)
230
+
231
+ # Resolve DB path and ensure demo DB exists for quick smoke runs
232
+ db_path = Path(args.db_path).resolve()
233
+ ensure_demo_db(db_path)
234
+
235
+ # Build pipeline via the official factory (factory decides real vs stub by env)
236
+ adapter = SQLiteAdapter(str(db_path))
237
+ pipeline = pipeline_from_config_with_adapter(args.config, adapter=adapter)
238
+
239
+ schema_preview = _derive_schema_preview(pipeline)
240
+ print(f"✅ Pipeline ready (db={db_path.name}, config={args.config})")
241
+ print(
242
+ "📄 Schema preview:",
243
+ "yes" if schema_preview else "no",
244
+ "| provider:",
245
+ "STUBS" if os.getenv("PYTEST_CURRENT_TEST") else "REAL",
246
+ )
247
+
248
+ questions = _load_questions(args.dataset_file)
249
+ print(f"🗂 Loaded {len(questions)} questions.")
250
+
251
+ rows: List[Dict[str, Any]] = []
252
+ for q in questions:
253
+ print(f"\n🧠 Query: {q}")
254
+ t0 = time.perf_counter()
255
+ try:
256
+ result = pipeline.run(user_query=q, schema_preview=schema_preview or "")
257
+ latency_ms = _ms(t0) or 1 # clamp to 1ms when stubs are instant
258
+ stages = _normalize_trace(
259
+ getattr(result, "traces", getattr(result, "trace", []))
260
+ )
261
+ rows.append(
262
+ {
263
+ "query": q,
264
+ "ok": bool(getattr(result, "ok", True)),
265
+ "latency_ms": latency_ms,
266
+ "trace": stages,
267
+ "error": None,
268
+ }
269
+ )
270
+ print(f"✅ Success ({latency_ms} ms)")
271
+ except Exception as exc:
272
+ latency_ms = _ms(t0) or 1
273
+ rows.append(
274
+ {
275
+ "query": q,
276
+ "ok": False,
277
+ "latency_ms": latency_ms,
278
+ "trace": [],
279
+ "error": str(exc),
280
+ }
281
+ )
282
+ print(f"❌ Failed: {exc!s} ({latency_ms} ms)")
283
+
284
+ # Aggregate and persist
285
+ avg_latency = (
286
+ round(sum(r["latency_ms"] for r in rows) / max(len(rows), 1), 1)
287
+ if rows
288
+ else 0.0
289
+ )
290
+ success_rate = (
291
+ (sum(1 for r in rows if r["ok"]) / max(len(rows), 1)) if rows else 0.0
292
+ )
293
+ meta = {
294
+ "db_path": str(db_path),
295
+ "config": args.config,
296
+ "provider_hint": "STUBS" if os.getenv("PYTEST_CURRENT_TEST") else "REAL",
297
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
298
+ }
299
+
300
+ jsonl_path = RESULT_DIR / "demo.jsonl"
301
+ with jsonl_path.open("w", encoding="utf-8") as f:
302
+ for r in rows:
303
+ json.dump(r, f, ensure_ascii=False)
304
+ f.write("\n")
305
+
306
+ summary_path = RESULT_DIR / "summary.json"
307
+ with summary_path.open("w", encoding="utf-8") as f:
308
+ json.dump(
309
+ {"avg_latency_ms": avg_latency, "success_rate": success_rate, **meta},
310
+ f,
311
+ indent=2,
312
+ )
313
+
314
+ csv_path = RESULT_DIR / "results.csv"
315
+ with csv_path.open("w", newline="", encoding="utf-8") as f:
316
+ wr = csv.DictWriter(f, fieldnames=["query", "ok", "latency_ms"])
317
+ wr.writeheader()
318
+ for r in rows:
319
+ wr.writerow(
320
+ {
321
+ "query": r["query"],
322
+ "ok": "✅" if r["ok"] else "❌",
323
+ "latency_ms": int(r["latency_ms"]),
324
+ }
325
+ )
326
+
327
+ print(
328
+ "\n💾 Saved outputs:\n"
329
+ f"- {jsonl_path}\n- {summary_path}\n- {csv_path}\n"
330
+ f"📊 Avg latency: {avg_latency} ms | Success rate: {success_rate:.0%}\n"
331
+ )
332
+
333
+
334
+ if __name__ == "__main__":
335
+ main()
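Note: per the module docstring, --dataset-file accepts a JSON array of strings or of objects with a "question" field. A minimal way to produce one (the file name is just an example):

import json
from pathlib import Path

# Either form below is accepted by _load_questions():
questions = ["list all customers", "top 3 albums by total sales"]
# questions = [{"question": "list all customers"}]
Path("benchmarks/demo.json").write_text(json.dumps(questions), encoding="utf-8")

Then run, as in the docstring: PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 python scripts/smoke_run.py --db-path demo.db --dataset-file benchmarks/demo.json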