Melika Kheirieh committed
Commit ed681b1 · 1 Parent(s): 598536c

feat(benchmarks): align Spider eval with config-driven Pipeline and native Safety; log per-stage trace; add CSV summary

Files changed (1)
  1. benchmarks/evaluate_spider.py +180 -103
benchmarks/evaluate_spider.py CHANGED
@@ -1,17 +1,22 @@
-"""
-Evaluate NL2SQL pipeline performance on Spider-like queries.
-Uses config-driven Pipeline, native Safety checks, and per-stage latency tracing.
-Outputs: JSONL (detailed logs), JSON (metrics summary), and CSV (for README).
-"""
+from __future__ import annotations

-import json
 import csv
+import json
+import os
 import time
 from pathlib import Path
-from nl2sql.pipeline import Pipeline
+from typing import Any, Dict, List, Optional
+
+# Reuse existing factories from your FastAPI router (no new DI needed)
+from app.routers.nl2sql import (  # type: ignore
+    _pipeline as DEFAULT_PIPELINE,
+    _build_pipeline,
+    _select_adapter,
+)

-# ---------- Config ----------
-DATASET = [
+# -------------------- Config --------------------
+
+DATASET: List[str] = [
     "list all customers",
     "show total invoices per country",
     "top 3 albums by total sales",
@@ -19,107 +24,179 @@ DATASET = [
     "number of employees per city",
 ]

-CONFIG_PATH = "configs/pipeline.yaml"
-RESULT_DIR = Path("benchmarks/results")
-RESULT_DIR.mkdir(parents=True, exist_ok=True)
-
-# ---------- Initialize pipeline ----------
-pipeline = Pipeline.from_config(CONFIG_PATH)
-print(f"✅ Loaded pipeline from {CONFIG_PATH}")
-
-# Optional: schema preview if adapter supports it
-schema_preview = None
-try:
-    adapter = getattr(pipeline, "executor", None)
-    if adapter and hasattr(adapter, "derive_schema_preview"):
-        schema_preview = adapter.derive_schema_preview()
-        print("📄 Derived schema preview successfully.")
-except Exception as e:
-    print(f"⚠️ Could not derive schema preview: {e}")
-
-# ---------- Evaluation ----------
-records = []
-for q in DATASET:
-    print(f"\n🧠 Query: {q}")
-    start = time.perf_counter()
+# DB id/mode follows your router convention; adjust if needed
+DB_ID: str = os.getenv("DB_MODE", "sqlite")
+
+# Results directory with timestamped subfolder (keeps previous runs)
+RESULT_ROOT = Path("benchmarks") / "results"
+TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
+RESULT_DIR = RESULT_ROOT / TIMESTAMP
+
+
+# -------------------- Helpers --------------------
+
+
+def _int_ms(start: float) -> int:
+    return int((time.perf_counter() - start) * 1000)
+
+
+def _derive_schema_preview_safe(pipeline_obj: Any) -> Optional[str]:
+    """
+    Try to derive a schema preview from the adapter/executor if such a method exists.
+    Kept intentionally permissive to avoid tight coupling.
+    """
     try:
-        result = pipeline.run(user_query=q, schema_preview=schema_preview)
-        latency = int((time.perf_counter() - start) * 1000)
-
-        trace = getattr(result, "trace", None)
-        stages = []
-        if trace:
-            # trace might be a list of StageTrace or dicts
-            try:
-                for t in trace:
-                    stages.append(
-                        {"stage": t.get("stage", "?"), "ms": t.get("duration_ms", 0)}
-                        if isinstance(t, dict)
-                        else {
-                            "stage": getattr(t, "stage", "?"),
-                            "ms": getattr(t, "duration_ms", 0),
-                        }
-                    )
-            except Exception:
-                pass
-
-        records.append(
-            {
+        # Common places the adapter might live
+        candidates: List[Any] = [
+            getattr(pipeline_obj, "executor", None),
+            getattr(pipeline_obj, "adapter", None),
+        ]
+        for c in candidates:
+            if c and hasattr(c, "derive_schema_preview"):
+                return c.derive_schema_preview()  # type: ignore[no-any-return, call-arg]
+    except Exception:
+        pass
+    return None
+
+
+def _to_stage_list(trace_obj: Any) -> List[Dict[str, Any]]:
+    """
+    Normalize a pipeline trace (list of dataclasses or dicts) to a list of dicts:
+    [{"stage": str, "ms": int}, ...]
+    """
+    stages: List[Dict[str, Any]] = []
+    if not isinstance(trace_obj, list):
+        return stages
+
+    for t in trace_obj:
+        if isinstance(t, dict):
+            stage = t.get("stage", "?")
+            ms = t.get("duration_ms", 0)
+        else:
+            stage = getattr(t, "stage", "?")
+            ms = getattr(t, "duration_ms", 0)
+        try:
+            stages.append({"stage": str(stage), "ms": int(ms)})
+        except Exception:
+            stages.append({"stage": str(stage), "ms": 0})
+    return stages
+
+
+# -------------------- Main --------------------
+
+
+def main() -> None:
+    RESULT_DIR.mkdir(parents=True, exist_ok=True)
+
+    # Build pipeline from router factories (no new DI required)
+    try:
+        adapter = _select_adapter(DB_ID)  # e.g., "sqlite" / "postgres"
+        pipeline = _build_pipeline(adapter)
+        using_default = False
+    except Exception:
+        pipeline = DEFAULT_PIPELINE
+        using_default = True
+
+    print(
+        f"✅ Pipeline ready "
+        f"(db_id={DB_ID}, source={'default' if using_default else 'custom adapter'})"
+    )
+
+    # Optional schema preview
+    schema_preview = _derive_schema_preview_safe(pipeline)
+    if schema_preview:
+        print("📄 Derived schema preview ✓")
+    else:
+        print("ℹ️ No schema preview (adapter does not expose it or not needed)")
+
+    # Evaluate
+    records: List[Dict[str, Any]] = []
+    for q in DATASET:
+        print(f"\n🧠 Query: {q}")
+        t0 = time.perf_counter()
+        try:
+            result = pipeline.run(
+                user_query=q,
+                schema_preview=schema_preview or "",  # <- force str
+            )
+            latency_ms = _int_ms(t0)
+
+            # ok flag -> coerce to bool for mypy and consistency
+            ok_flag = bool(getattr(result, "ok", True))
+            stages = _to_stage_list(getattr(result, "trace", None))
+
+            rec: Dict[str, Any] = {
                 "query": q,
-                "ok": True,
-                "latency_ms": latency,
+                "ok": ok_flag,
+                "latency_ms": latency_ms,
                 "trace": stages,
                 "error": None,
             }
-        )
-        print(f"✅ Success ({latency} ms)")
-    except Exception as e:
-        latency = int((time.perf_counter() - start) * 1000)
-        records.append(
-            {
+            records.append(rec)
+            print(f"✅ Success ({latency_ms} ms)")
+        except Exception as exc:
+            latency_ms = _int_ms(t0)
+            rec = {
                 "query": q,
                 "ok": False,
-                "latency_ms": latency,
+                "latency_ms": latency_ms,
                 "trace": [],
-                "error": str(e),
+                "error": str(exc),
             }
-        )
-        print(f"❌ Failed: {e} ({latency} ms)")
-
-# ---------- Aggregate metrics ----------
-avg_latency = round(sum(r["latency_ms"] for r in records) / len(records), 1)
-success_rate = sum(1 for r in records if r["ok"]) / len(records)
-print(f"\n📊 Average latency: {avg_latency} ms | Success rate: {success_rate:.0%}")
-
-summary = {
-    "queries_total": len(records),
-    "success_rate": success_rate,
-    "avg_latency_ms": avg_latency,
-    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
-}
-
-# ---------- Save outputs ----------
-jsonl_path = RESULT_DIR / "spider_eval.jsonl"
-with open(jsonl_path, "w", encoding="utf-8") as f:
-    for r in records:
-        json.dump(r, f, ensure_ascii=False)
-        f.write("\n")
-
-summary_path = RESULT_DIR / "metrics_summary.json"
-with open(summary_path, "w", encoding="utf-8") as f:
-    json.dump(summary, f, indent=2)
-
-csv_path = RESULT_DIR / "results.csv"
-with open(csv_path, "w", newline="", encoding="utf-8") as f:
-    writer = csv.DictWriter(f, fieldnames=["query", "ok", "latency_ms"])
-    writer.writeheader()
-    for r in records:
-        writer.writerow(
-            {
-                "query": r["query"],
-                "ok": "✅" if r["ok"] else "❌",
-                "latency_ms": r["latency_ms"],
-            }
-        )
+            records.append(rec)
+            print(f"❌ Failed: {exc!s} ({latency_ms} ms)")
+
+    # Aggregate metrics
+    avg_latency = (
+        round(sum(r["latency_ms"] for r in records) / max(len(records), 1), 1)
+        if records
+        else 0.0
+    )
+    success_rate = (
+        sum(1 for r in records if bool(r.get("ok"))) / max(len(records), 1)
+        if records
+        else 0.0
+    )
+
+    summary: Dict[str, Any] = {
+        "queries_total": len(records),
+        "success_rate": success_rate,
+        "avg_latency_ms": avg_latency,
+        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
+        "db_id": DB_ID,
+        "pipeline_source": "default" if using_default else "adapter",
+    }
+
+    # Persist outputs
+    jsonl_path = RESULT_DIR / "spider_eval.jsonl"
+    with jsonl_path.open("w", encoding="utf-8") as f:
+        for r in records:
+            json.dump(r, f, ensure_ascii=False)
+            f.write("\n")
+
+    summary_path = RESULT_DIR / "metrics_summary.json"
+    with summary_path.open("w", encoding="utf-8") as f:
+        json.dump(summary, f, indent=2)
+
+    csv_path = RESULT_DIR / "results.csv"
+    with csv_path.open("w", newline="", encoding="utf-8") as f:
+        writer = csv.DictWriter(f, fieldnames=["query", "ok", "latency_ms"])
+        writer.writeheader()
+        for r in records:
+            writer.writerow(
+                {
+                    "query": r["query"],
+                    "ok": "✅" if bool(r["ok"]) else "❌",
+                    "latency_ms": int(r["latency_ms"]),
+                }
+            )
+
+    print(
+        "\n💾 Saved outputs:\n"
+        f"- {jsonl_path}\n- {summary_path}\n- {csv_path}\n"
+        f"📊 Avg latency: {avg_latency} ms | Success rate: {success_rate:.0%}"
+    )
+

-print(f"\n💾 Saved logs to:\n- {jsonl_path}\n- {summary_path}\n- {csv_path}")
+if __name__ == "__main__":
+    main()
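
For reference, a minimal sketch of consuming one run's outputs, assuming the benchmark has been executed once (e.g. `DB_MODE=sqlite python benchmarks/evaluate_spider.py`). The directory layout and file names come from the script above; the snippet itself is illustrative and not part of this commit:

import csv
import json
from pathlib import Path

# Pick the most recent timestamped run directory (layout created by main() above).
runs = sorted(Path("benchmarks/results").iterdir(), key=lambda p: p.name)
latest = runs[-1]

# Aggregate metrics from metrics_summary.json.
summary = json.loads((latest / "metrics_summary.json").read_text(encoding="utf-8"))
print(f"{summary['queries_total']} queries, "
      f"success rate {summary['success_rate']:.0%}, "
      f"avg latency {summary['avg_latency_ms']} ms")

# Per-query records: one JSON object per line, each with its per-stage trace.
with (latest / "spider_eval.jsonl").open(encoding="utf-8") as f:
    records = [json.loads(line) for line in f]
for r in records:
    slowest = max(r["trace"], key=lambda s: s["ms"], default=None)
    print(r["query"], r["latency_ms"], "ms", slowest)

# README-facing table from results.csv.
with (latest / "results.csv").open(encoding="utf-8") as f:
    for row in csv.DictReader(f):
        print(row["query"], row["ok"], row["latency_ms"])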