| """ | |
| Spider benchmark evaluator (pro): | |
| - Computes EM / SM / ExecAcc vs. gold SQL | |
| - Records per-sample latency and (if present) per-stage timings from pipeline traces | |
| - Persists eval.jsonl (per-sample), summary.json (aggregates incl. p50/p95, per-stage means), results.csv | |
| - No external deps; percentile and normalization are implemented locally. | |
| """ | |
from __future__ import annotations

import argparse
import json
import re
import sqlite3
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple

from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
from adapters.db.sqlite_adapter import SQLiteAdapter
from benchmarks.spider_loader import load_spider_sqlite
# -------------------------- Config --------------------------
RESULT_ROOT = Path("benchmarks/results_pro")
TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
RESULT_DIR = RESULT_ROOT / TIMESTAMP

STAGES = [
    "detector",
    "planner",
    "generator",
    "safety",
    "executor",
    "verifier",
    "repair",
]

# -------------------------- SQL utils -----------------------
def extract_clean_sql(text: str | None) -> str:
    """Extract a clean SQL string from LLM-ish output (may include fences/JSON)."""
    sql = (text or "").strip()
    # strip ```sql fences
    sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.I)
    sql = sql.replace("```", "")
    # JSON-like {"sql": "..."}
    m = re.search(r'"sql"\s*:\s*"([^"]+)"', sql)
    if m:
        sql = m.group(1)
    # unescape
    sql = sql.replace('\\"', '"').replace("\\n", " ").replace("\\t", " ")
    # keep everything from the first SQL-ish keyword onward
    m2 = re.search(r"\b(select|with|insert|update|delete)\b[\s\S]+", sql, re.I)
    if m2:
        sql = m2.group(0)
    sql = re.sub(r"\s+", " ", sql).strip().rstrip(";")
    return sql
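
# Illustrative behavior of extract_clean_sql:
#   extract_clean_sql('```sql\nSELECT * FROM t;\n```')  -> 'SELECT * FROM t'
#   extract_clean_sql('{"sql": "SELECT 1"}')            -> 'SELECT 1'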

def normalize_sql(sql: str) -> str:
    """
    Conservative normalization for exact-match (EM):
    - Trim, collapse spaces, drop trailing ';'
    - Drop trailing 'LIMIT n'
    - Remove table prefixes only in single-table, no-join queries
    - Unquote identifiers like `name` or "name"
    - Uppercase common SQL keywords (string literals unaffected)
    """
    if not sql:
        return ""
    s = sql.strip()
    # Collapse whitespace early and drop trailing ';'
    s = re.sub(r"\s+", " ", s).strip().rstrip(";")
    # Drop trailing LIMIT n
    s = re.sub(r"(?i)\s+LIMIT\s+\d+\s*$", "", s)
    # Remove table prefixes only if single FROM and no JOIN
    lower = s.lower()
    if lower.count(" from ") == 1 and " join " not in lower:
        m = re.search(r"\bfrom\s+([a-z_][a-z0-9_]*)", s, flags=re.IGNORECASE)
        if m:
            table = m.group(1)
            s = re.sub(rf"\b{re.escape(table)}\.(\w+)\b", r"\1", s)
    # Unquote identifiers: `foo` -> foo, "foo" -> foo (strings '...' remain)
    s = re.sub(r"`([A-Za-z_]\w*)`", r"\1", s)
    s = re.sub(r'"([A-Za-z_]\w*)"', r"\1", s)
    # Normalize comma spacing: "a , b" -> "a, b"
    s = re.sub(r"\s*,\s*", ", ", s)
    # Final whitespace collapse
    s = re.sub(r"\s+", " ", s).strip()
    # Uppercase common keywords (word-boundary safe)
    for kw in [
        "select",
        "from",
        "where",
        "group by",
        "order by",
        "having",
        "limit",
        "join",
        "on",
        "and",
        "or",
        "asc",
        "desc",
    ]:
        s = re.sub(rf"(?i)\b{kw}\b", kw.upper(), s)
    return s
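
# Illustrative: normalize_sql('select t.name from t where t.id = 1;')
#   -> 'SELECT name FROM t WHERE id = 1'
# (table prefix dropped because there is exactly one FROM and no JOIN).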

# ---------------------- Schema extraction -------------------
def get_database_schema(db_path: Path) -> Dict[str, Any]:
    """Extract schema from SQLite database (tables, columns, FKs)."""
    schema: Dict[str, Any] = {"tables": {}}
    if not db_path.exists():
        return schema
    conn = sqlite3.connect(str(db_path))
    cur = conn.cursor()
    try:
        cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
        )
        for (table,) in cur.fetchall():
            cur.execute(f"PRAGMA table_info('{table}')")
            cols = [
                {"name": c[1], "type": c[2], "primary_key": bool(c[5])}
                for c in cur.fetchall()
            ]
            cur.execute(f"PRAGMA foreign_key_list('{table}')")
            fks = [
                {"column": fk[3], "referenced_table": fk[2], "referenced_column": fk[4]}
                for fk in cur.fetchall()
            ]
            schema["tables"][table] = {"columns": cols, "foreign_keys": fks}
    finally:
        conn.close()
    return schema
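
# Returned shape (illustrative; table/column names hypothetical):
#   {"tables": {"users": {"columns": [{"name": "id", "type": "INTEGER", "primary_key": True}],
#                         "foreign_keys": []}}}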

def format_schema_for_prompt(schema: Dict[str, Any]) -> str:
    """Plain-text schema for prompt (minimal but helpful)."""
    if not schema.get("tables"):
        return ""
    lines: List[str] = []
    for t, info in schema["tables"].items():
        cols = [
            f"{c['name']} {c['type']}{' PK' if c.get('primary_key') else ''}"
            for c in info.get("columns", [])
        ]
        lines.append(f"Table: {t}")
        lines.append(f"Columns: {', '.join(cols)}")
        fks = info.get("foreign_keys") or []
        if fks:
            lines.append(
                "FKs: "
                + ", ".join(
                    f"{fk['column']} -> {fk['referenced_table']}.{fk['referenced_column']}"
                    for fk in fks
                )
            )
        lines.append("")
    return "\n".join(lines).strip()

# ---------------------- Exec/eval metrics -------------------
def _exec_sql(db: Path, sql: str) -> Tuple[bool, List[Tuple]]:
    if not sql:
        return False, []
    try:
        conn = sqlite3.connect(str(db))
        try:
            rows = conn.execute(sql).fetchall()
        finally:
            conn.close()  # close even when execution raises
        return True, rows
    except Exception:
        return False, []

def _same_rows(a: List[Tuple], b: List[Tuple]) -> bool:
    # Order-insensitive compare; the length check guards against most
    # duplicate-count mismatches that a plain set comparison would hide.
    return set(a) == set(b) and len(a) == len(b)

def evaluate_sql(pred: str, gold: str, db: Path) -> Dict[str, float]:
    """Return {'em', 'sm', 'exec'}: em/sm in {0.0, 1.0}, exec in {0.0, 0.5, 1.0} (sm ~ set-match)."""
    em = 1.0 if normalize_sql(pred) == normalize_sql(gold) else 0.0
    gold_ok, gold_rows = _exec_sql(db, gold)
    pred_ok, pred_rows = _exec_sql(db, pred)
    sm = 0.0
    exec_acc = 0.0
    if gold_ok and pred_ok:
        if _same_rows(gold_rows, pred_rows):
            sm = 1.0
            exec_acc = 1.0
        else:
            exec_acc = 0.5  # partial credit for executing but mismatched rows
    return {"em": em, "sm": sm, "exec": exec_acc}

# ---------------------- Trace flatten helpers -------------------
def _flatten_trace_entry(d: Dict[str, Any]) -> Dict[str, Any]:
    out = dict(d or {})
    notes = out.pop("notes", {}) or {}
    # promote selected keys to top-level for easier analysis
    for k in (
        "tokens_in",
        "tokens_out",
        "cost_usd",
        "sql_length",
        "row_count",
        "verified",
        "error_type",
        "repair_attempts",
        "skipped",
        "col_count",
    ):
        if k in notes:
            out[k] = notes[k]
    if notes:
        out["notes"] = notes
    return out

def _per_stage_ms(trace_list: List[Dict[str, Any]]) -> Dict[str, float]:
    acc = {s: 0.0 for s in STAGES}
    cnt = {s: 0 for s in STAGES}
    for t in trace_list:
        s = t.get("stage")
        if s in acc:
            ms = t.get("duration_ms", t.get("ms", 0.0))
            try:
                acc[s] += float(ms)
                cnt[s] += 1
            except Exception:
                pass
    return {s: round(acc[s] / cnt[s], 2) if cnt[s] else 0.0 for s in STAGES}
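
# Note: _flatten_trace_entry and _per_stage_ms are not called in this script;
# they appear to be kept for ad-hoc per-sample trace analysis elsewhere.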

# ---------------------- Dataclass + runner ------------------
@dataclass
class SpiderSample:
    question: str
    db_id: str
    db_path: Path
    gold_sql: str

def _percentile(values: List[float], p: float) -> float:
    """Compute p-th percentile (0..100) without numpy (linear interpolation)."""
    if not values:
        return 0.0
    vals = sorted(values)
    k = (len(vals) - 1) * (p / 100.0)
    f = int(k)
    c = min(f + 1, len(vals) - 1)
    if f == c:
        return float(vals[f])
    return float(vals[f] * (c - k) + vals[c] * (k - f))
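
# Illustrative: _percentile([10, 20, 30, 40], 50.0) -> 25.0
# (k = 1.5, so the result is midway between vals[1] = 20 and vals[2] = 30).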

def _stage_ms_from_trace(trace_item: Dict[str, Any]) -> float:
    """Accepts {'stage': ..., 'ms': ...} OR {'stage': ..., 'duration_ms': ...}."""
    if not trace_item:
        return 0.0
    for key in ("ms", "duration_ms"):
        if key in trace_item:
            try:
                return float(trace_item[key])
            except Exception:
                return 0.0
    return 0.0

def _collect_stage_means(eval_rows: List[Dict[str, Any]]) -> Dict[str, float]:
    """Average per-stage ms across all records (0 if absent)."""
    totals = {s: 0.0 for s in STAGES}
    counts = {s: 0 for s in STAGES}
    for r in eval_rows:
        trace_list = r.get("trace") or r.get("traces") or []
        for t in trace_list:
            s = t.get("stage")
            if s in totals:
                totals[s] += _stage_ms_from_trace(t)
                counts[s] += 1
    return {s: round(totals[s] / counts[s], 2) if counts[s] else 0.0 for s in STAGES}
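
# Illustrative: two trace entries {"stage": "generator", "duration_ms": 120} and
# {"stage": "generator", "duration_ms": 80} yield {"generator": 100.0, ...},
# with stages that never appear reported as 0.0.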

def run_pipeline_on_sample(
    pipeline: Any,
    sample: SpiderSample,
    schema_cache: Dict[str, str],
    debug: bool = False,
) -> Dict[str, Any]:
    """Run pipeline on one sample and extract normalized prediction + traces."""
    # cache schema per db_id
    if sample.db_id not in schema_cache:
        schema_dict = get_database_schema(sample.db_path)
        schema_cache[sample.db_id] = format_schema_for_prompt(schema_dict)
        if debug:
            print(
                f" [schema] Loaded {len(schema_cache[sample.db_id])} chars for {sample.db_id}"
            )
    schema = schema_cache[sample.db_id]
    try:
        res = pipeline.run(user_query=sample.question, schema_preview=schema)
        # extract SQL from whichever attribute is populated
        pred_sql = ""
        if hasattr(res, "sql") and res.sql:
            pred_sql = extract_clean_sql(res.sql)
        else:
            for attr in ("final_sql", "generated_sql", "answer"):
                if getattr(res, attr, None):
                    pred_sql = extract_clean_sql(str(getattr(res, attr)))
                    if pred_sql:
                        break
        return {
            "ok": bool(getattr(res, "ok", True)),
            "sql": pred_sql,
            "trace": getattr(res, "traces", []) or getattr(res, "trace", []),
            "error": None,
        }
    except Exception as e:
        if debug:
            import traceback

            traceback.print_exc()
        return {"ok": False, "sql": "", "trace": [], "error": str(e)}

# --------------------------- Main --------------------------
def main() -> None:
    ap = argparse.ArgumentParser(description="Evaluate NL2SQL on Spider (pro)")
    ap.add_argument("--spider", action="store_true", help="Use Spider dataset loader")
    ap.add_argument("--split", default="dev", choices=["dev", "train"])
    ap.add_argument("--limit", type=int, default=20)
    ap.add_argument("--debug", action="store_true")
    ap.add_argument("--config", default="configs/sqlite_pipeline.yaml")
    args = ap.parse_args()

    if not args.spider:
        print("Use --spider to run Spider evaluation.")
        return

    # load items
    print(f"Loading Spider {args.split} split...")
    items = load_spider_sqlite(split=args.split, limit=args.limit)
    if not items:
        print("❌ No samples loaded. Check SPIDER_ROOT.")
        return
    print(f"✔ Loaded {len(items)} samples")

    RESULT_DIR.mkdir(parents=True, exist_ok=True)
    schema_cache: Dict[str, str] = {}
    eval_rows: List[Dict[str, Any]] = []

    for i, it in enumerate(items, 1):
        sample = SpiderSample(
            question=it.question,
            db_id=it.db_id,
            db_path=Path(it.db_path),
            gold_sql=it.gold_sql,
        )
        print(f"\n🧠 [{i}/{len(items)}] [{sample.db_id}] {sample.question}")

        adapter = SQLiteAdapter(str(sample.db_path))
        pipeline = pipeline_from_config_with_adapter(args.config, adapter=adapter)

        t0 = time.perf_counter()
        out = run_pipeline_on_sample(pipeline, sample, schema_cache, args.debug)
        latency_ms = int((time.perf_counter() - t0) * 1000)

        metrics = evaluate_sql(out["sql"], sample.gold_sql, sample.db_path)
        row = {
            "source": "spider",
            "db_id": sample.db_id,
            "query": sample.question,
            "gold_sql": sample.gold_sql,
            "pred_sql": out["sql"],
            "ok": out["ok"],
            "latency_ms": latency_ms,
            "em": metrics["em"],
            "sm": metrics["sm"],
            "exec_acc": metrics["exec"],
            "error": out.get("error"),
            "trace": out.get("trace", []),
        }
        eval_rows.append(row)

        if args.debug:
            status = "✅" if row["ok"] and row["em"] == 1.0 else "⚠️"
            print(
                f"{status} ({latency_ms} ms) | EM={row['em']} SM={row['sm']} ExecAcc={row['exec_acc']}"
            )
            if row["em"] < 1.0:
                print(f" gold: {sample.gold_sql}")
                print(f" pred: {out['sql'] or 'EMPTY'}")

    # persist eval.jsonl (RESULT_DIR was already created above)
    with (RESULT_DIR / "eval.jsonl").open("w", encoding="utf-8") as f:
        for r in eval_rows:
            json.dump(r, f, ensure_ascii=False)
            f.write("\n")

    # aggregates
    total = len(eval_rows)
    success = sum(1 for r in eval_rows if r["ok"])
    avg_em = sum(r["em"] for r in eval_rows) / total if total else 0.0
    avg_sm = sum(r["sm"] for r in eval_rows) / total if total else 0.0
    avg_exec = sum(r["exec_acc"] for r in eval_rows) / total if total else 0.0
    avg_lat = sum(r["latency_ms"] for r in eval_rows) / total if total else 0.0
    p50 = _percentile([r["latency_ms"] for r in eval_rows], 50.0)
    p95 = _percentile([r["latency_ms"] for r in eval_rows], 95.0)
    stage_means = _collect_stage_means(eval_rows)

    summary = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "split": args.split,
        "config": args.config,
        "total": total,
        "success": success,
        "success_rate": round(success / total, 3) if total else 0.0,
        "avg_latency_ms": round(avg_lat, 1),
        "p50_latency_ms": round(p50, 1),
        "p95_latency_ms": round(p95, 1),
        "EM": round(avg_em, 3),
        "SM": round(avg_sm, 3),
        "ExecAcc": round(avg_exec, 3),
        **{f"{s}_avg_ms": stage_means[s] for s in STAGES},
    }
    (RESULT_DIR / "summary.json").write_text(
        json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8"
    )

    # CSV (json.dumps quotes the free-text query so embedded commas stay CSV-safe)
    with (RESULT_DIR / "results.csv").open("w", encoding="utf-8") as f:
        f.write("db_id,query,ok,em,sm,exec_acc,latency_ms\n")
        for r in eval_rows:
            f.write(
                f"{r['db_id']},{json.dumps(r['query'])},{'✅' if r['ok'] else '❌'},"
                f"{r['em']},{r['sm']},{r['exec_acc']},{r['latency_ms']}\n"
            )

    print("\n================== Evaluation Summary ==================")
    print(f"Total samples: {total}")
    print(f"Successful runs: {success} ({summary['success_rate'] * 100:.1f}%)")
    print(f"Avg EM: {summary['EM']}")
    print(f"Avg SM: {summary['SM']}")
    print(f"Avg ExecAcc: {summary['ExecAcc']}")
    print(
        f"Avg Latency: {summary['avg_latency_ms']} ms | p50={summary['p50_latency_ms']} ms | p95={summary['p95_latency_ms']} ms"
    )
    print(f"Results saved to {RESULT_DIR}")
    print("========================================================")

if __name__ == "__main__":
    # main() creates RESULT_DIR itself, so no empty result dirs are left
    # behind when the script exits early (e.g. missing --spider).
    main()
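
# Example invocation (script path assumed; adjust to where this file lives):
#   python benchmarks/eval_spider_pro.py --spider --split dev --limit 20 --debug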