Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import json | |
| import time | |
| from pathlib import Path | |
| from typing import Iterable, List, Dict, Any, Protocol, Tuple, Optional | |
| # ---- app imports | |
| from nl2sql.pipeline import Pipeline, FinalResult | |
| from nl2sql.ambiguity_detector import AmbiguityDetector | |
| from nl2sql.planner import Planner | |
| from nl2sql.generator import Generator | |
| from nl2sql.safety import Safety | |
| from nl2sql.executor import Executor | |
| from nl2sql.verifier import Verifier | |
| from nl2sql.repair import Repair | |
| # ---- adapters | |
| from adapters.db.sqlite_adapter import SQLiteAdapter | |
| from adapters.llm.openai_provider import OpenAIProvider | |
| # ---- LLM protocol (unifies OpenAIProvider and DummyLLM for mypy) | |
class LLMProvider(Protocol):
    """Structural interface shared by the LLM backends used in this script.

    Any object exposing ``provider_id`` and the three keyword-only methods
    below can be handed to the Planner, Generator and Repair stages, which
    lets a single annotated variable hold either the real or the dummy
    provider.

    NOTE(review): the trailing ``int, int, float`` of each returned tuple are
    presumably token counts and a cost figure -- confirm against the real
    provider implementation.
    """

    # Short identifier of the backend.
    provider_id: str

    def plan(
        self, *, user_query: str, schema_preview: str
    ) -> Tuple[str, int, int, float]:
        """Return a 4-tuple whose first element is the plan text."""
        ...

    def generate_sql(
        self,
        *,
        user_query: str,
        schema_preview: str,
        plan_text: str,
        clarify_answers: Optional[Any] = None,
    ) -> Tuple[str, str, int, int, float]:
        """Return a 5-tuple starting with the SQL text and its rationale."""
        ...

    def repair(
        self, *, sql: str, error_msg: str, schema_preview: str
    ) -> Tuple[str, int, int, float]:
        """Return a 4-tuple whose first element is the (repaired) SQL text."""
        ...
| # ---- fallback: Dummy LLM (so it runs without API keys) | |
class DummyLLM:
    """Offline stand-in provider that returns canned answers.

    Every method reports ``0, 0, 0.0`` for the numeric tail of its result
    tuple, so the pipeline can run end-to-end without any API key.
    """

    provider_id = "dummy-llm"

    def plan(
        self, *, user_query: str, schema_preview: str
    ) -> Tuple[str, int, int, float]:
        """Return a fixed bullet-list plan that echoes ``user_query``.

        ``schema_preview`` is accepted for interface compatibility and is
        not used.
        """
        steps = (
            f"- understand question: {user_query}",
            "- identify tables",
            "- join if needed",
            "- filter",
            "- order/limit",
        )
        return "\n".join(steps), 0, 0, 0.0

    def generate_sql(
        self,
        *,
        user_query: str,
        schema_preview: str,
        plan_text: str,
        clarify_answers: Optional[Any] = None,
    ) -> Tuple[str, str, int, int, float]:
        """Return a constant demo query and rationale, ignoring all inputs."""
        # Trivial statement that only exists so the pipeline flows end-to-end.
        return "SELECT 1 AS one;", "Demo SQL from DummyLLM", 0, 0, 0.0

    def repair(
        self, *, sql: str, error_msg: str, schema_preview: str
    ) -> Tuple[str, int, int, float]:
        """Return ``sql`` unchanged; the dummy provider performs no repair."""
        return sql, 0, 0, 0.0
def ensure_demo_db(path: Path) -> None:
    """Create a tiny SQLite database at ``path`` unless a file already exists.

    The database holds a single ``users(id, name, spend)`` table with three
    demo rows, so the executor has something to run against. Missing parent
    directories are created.

    If building the database fails, the connection is still closed and the
    partially written file is removed; otherwise the ``path.exists()`` guard
    would treat the broken file as a finished demo database on the next run.

    Args:
        path: Location of the SQLite file.

    Raises:
        Exception: whatever ``sqlite3`` or the filesystem raised while
            creating the database, re-raised after cleanup.
    """
    if path.exists():
        return

    import sqlite3

    path.parent.mkdir(parents=True, exist_ok=True)
    try:
        con = sqlite3.connect(path)
        try:
            # Connection as context manager: commit on success, roll back on error.
            with con:
                con.execute(
                    "CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, spend REAL);"
                )
                con.executemany(
                    "INSERT INTO users(id,name,spend) VALUES(?,?,?)",
                    [(1, "Alice", 120.5), (2, "Bob", 80.0), (3, "Carol", 155.0)],
                )
        finally:
            # The context manager above does not close the connection itself.
            con.close()
    except Exception:
        # Do not leave a half-built file that later runs would silently reuse.
        path.unlink(missing_ok=True)
        raise
def build_pipeline(db_path: Path, use_openai: bool) -> Pipeline:
    """Assemble a ``Pipeline`` backed by the SQLite database at ``db_path``.

    Args:
        db_path: SQLite file handed (as a string) to ``SQLiteAdapter``.
        use_openai: When true *and* the ``OPENAI_API_KEY`` environment
            variable is set, ``OpenAIProvider`` is used; in every other case
            the offline ``DummyLLM`` is used.

    Returns:
        A pipeline whose planner, generator and repair stages share one
        provider instance.
    """
    # Database side first: adapter, then the executor that wraps it.
    adapter = SQLiteAdapter(str(db_path))
    sql_executor = Executor(adapter)

    # Annotated with the Protocol so either provider type-checks.
    provider: LLMProvider = (
        OpenAIProvider()
        if use_openai and os.getenv("OPENAI_API_KEY")
        else DummyLLM()
    )

    return Pipeline(
        detector=AmbiguityDetector(),
        planner=Planner(provider),
        generator=Generator(provider),
        safety=Safety(),
        executor=sql_executor,
        verifier=Verifier(),
        repair=Repair(provider),
    )
| def _sum_cost(traces: Iterable[Dict[str, Any]]) -> float: | |
| total = 0.0 | |
| for tr in traces: | |
| try: | |
| total += float(tr.get("cost_usd", 0.0)) | |
| except Exception: | |
| # ignore bad values | |
| pass | |
| return total | |
| def _is_safe_fail(ok: bool, details: List[str] | None) -> float: | |
| """Return 1.0 when pipeline failed due to unsafe SQL (heuristic).""" | |
| if ok: | |
| return 0.0 | |
| txt = " ".join(details or []).lower() | |
| return 1.0 if "unsafe" in txt else 0.0 | |
def run_benchmark(
    queries: List[str], schema_preview: str, pipeline: Pipeline, outfile: Path
) -> None:
    """Run every query through ``pipeline`` and write one JSON line per query.

    Each output row contains:

    * ``query``: the input question,
    * ``exec_acc``: ``1.0`` when the result is not ambiguous, has no error
      and a truthy ``ok`` flag, else ``0.0``,
    * ``safe_fail``: see ``_is_safe_fail``,
    * ``latency_ms``: wall-clock time of ``pipeline.run`` in milliseconds,
    * ``cost_usd``: see ``_sum_cost``,
    * ``repair_attempts``: number of traces whose ``stage`` is ``"repair"``,
    * ``provider``: ``provider_id`` of the generator's LLM, or ``"unknown"``.

    Args:
        queries: Natural-language questions to benchmark.
        schema_preview: Schema text passed unchanged to every run.
        pipeline: The pipeline under test.
        outfile: Destination ``.jsonl`` path; parent directories are created
            and an existing file is overwritten.
    """
    # The provider belongs to the pipeline, not to a query: resolve it once
    # instead of on every iteration, and tolerate any missing link in the chain.
    provider = getattr(
        getattr(getattr(pipeline, "generator", None), "llm", None),
        "provider_id",
        "unknown",
    )

    results: List[Dict[str, Any]] = []
    for q in queries:
        # Time only the pipeline call itself.
        t0 = time.perf_counter()
        res: FinalResult = pipeline.run(user_query=q, schema_preview=schema_preview)
        latency_ms = (time.perf_counter() - t0) * 1000.0

        ok = (not res.ambiguous) and (not res.error) and bool(res.ok)
        traces = res.traces or []
        results.append(
            {
                "query": q,
                "exec_acc": 1.0 if ok else 0.0,
                "safe_fail": _is_safe_fail(ok, res.details),
                "latency_ms": latency_ms,
                "cost_usd": _sum_cost(traces),
                "repair_attempts": sum(1 for t in traces if t.get("stage") == "repair"),
                "provider": provider,
            }
        )

    outfile.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding keeps the output independent of the platform locale.
    with open(outfile, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(row) + "\n" for row in results)
    print(f"[OK] wrote {len(results)} rows → {outfile}")
def main() -> None:
    """Command-line entry point: prepare the demo DB and run the demo benchmark.

    ``--outfile`` and ``--db`` are joined onto the project root (the parent
    of the directory containing this file) and then resolved. ``--use-openai``
    asks for the OpenAI provider.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--outfile", default="benchmarks/results/demo.jsonl")
    cli.add_argument("--db", default="data/bench_demo.db")
    cli.add_argument(
        "--use-openai",
        action="store_true",
        help="Use OpenAI provider if API key present",
    )
    opts = cli.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    results_file = (project_root / opts.outfile).resolve()
    db_file = (project_root / opts.db).resolve()

    ensure_demo_db(db_file)
    pipeline = build_pipeline(db_file, use_openai=opts.use_openai)

    # a small demo set; replace with Spider when ready
    demo_queries = ["show all users", "top spenders", "sum of spend"]
    demo_schema = "CREATE TABLE users(id INT, name TEXT, spend REAL);"
    run_benchmark(demo_queries, demo_schema, pipeline, results_file)


if __name__ == "__main__":
    main()