# NL2SQL Copilot benchmark driver
# Author: Melika Kheirieh
# Provenance: commit 570f7bd — "init: NL2SQL Copilot base with API and Dockerfile" (4.64 kB)
# benchmarks/run.py
from __future__ import annotations
import argparse
import os
import json
import time
from pathlib import Path
# ---- app imports
from nl2sql.pipeline import Pipeline
from nl2sql.ambiguity_detector import AmbiguityDetector
from nl2sql.planner import Planner
from nl2sql.generator import Generator
from nl2sql.safety import Safety
from nl2sql.executor import Executor
from nl2sql.verifier import Verifier
from nl2sql.repair import Repair
# ---- adapters
from adapters.db.sqlite_adapter import SQLiteAdapter
from adapters.llm.openai_provider import OpenAIProvider
# ---- fallbacks: Dummy LLM (so it runs without API keys)
class DummyLLM:
    """Keyless stand-in for a real LLM provider so the benchmark runs offline.

    Each method mirrors the provider interface used by the pipeline stages and
    returns its payload followed by (prompt_tokens, completion_tokens, cost_usd),
    all reported as zero usage.
    """

    provider_id = "dummy-llm"

    def plan(self, *, user_query: str, schema_preview: str):
        # Canned step-by-step plan; only the question text varies.
        steps = (
            f"- understand question: {user_query}\n"
            "- identify tables\n"
            "- join if needed\n"
            "- filter\n"
            "- order/limit"
        )
        return steps, 0, 0, 0.0

    def generate_sql(self, *, user_query: str, schema_preview: str, plan_text: str, clarify_answers=None):
        # Fixed demo SQL so the pipeline flows end-to-end without a model call.
        return "SELECT 1 AS one;", "Demo SQL from DummyLLM", 0, 0, 0.0

    def repair(self, *, sql: str, error_msg: str, schema_preview: str):
        # No-op repair: echo the SQL back unchanged with zero usage.
        return sql, 0, 0, 0.0
def ensure_demo_db(path: Path) -> None:
    """Create a tiny SQLite db if missing, so the executor has something to run.

    Idempotent: if ``path`` already exists it is left untouched. Parent
    directories are created as needed.

    Args:
        path: Filesystem location of the SQLite database file.
    """
    if path.exists():
        return
    import sqlite3  # local import: only needed on first-time setup
    path.parent.mkdir(parents=True, exist_ok=True)
    con = sqlite3.connect(path)
    try:
        # Single demo table with a handful of rows for the benchmark queries.
        con.execute("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, spend REAL);")
        con.executemany(
            "INSERT INTO users(id,name,spend) VALUES(?,?,?)",
            [(1, "Alice", 120.5), (2, "Bob", 80.0), (3, "Carol", 155.0)],
        )
        con.commit()
    finally:
        # Close even if setup fails, so a partial file isn't left locked.
        con.close()
def build_pipeline(db_path: Path, use_openai: bool) -> Pipeline:
    """Wire all pipeline stages over a SQLite database and return the pipeline.

    Falls back to DummyLLM unless OpenAI was requested AND an OPENAI_API_KEY
    is present in the environment, so the demo always runs keyless.
    """
    # Execution backend: SQLite adapter wrapped by the executor stage.
    executor = Executor(SQLiteAdapter(str(db_path)))

    # LLM provider selection (keyless fallback keeps the benchmark runnable).
    wants_openai = use_openai and os.getenv("OPENAI_API_KEY")
    llm = OpenAIProvider() if wants_openai else DummyLLM()

    # Assemble the stage objects and hand them to the pipeline.
    return Pipeline(
        detector=AmbiguityDetector(),
        planner=Planner(llm),
        generator=Generator(llm),
        safety=Safety(),
        executor=executor,
        verifier=Verifier(),
        repair=Repair(llm),
    )
def run_benchmark(queries, schema_preview, pipeline: Pipeline, outfile: Path):
    """Run each query through the pipeline and write one JSONL row per query.

    Args:
        queries: Iterable of natural-language query strings.
        schema_preview: Schema text passed to the pipeline alongside each query.
        pipeline: Pipeline whose ``run(user_query=..., schema_preview=...)``
            returns a result dict (keys observed here: "ambiguous", "error",
            "traces").
        outfile: Destination JSONL path; parent directories are created.
    """
    results = []
    for q in queries:
        t0 = time.perf_counter()
        r = pipeline.run(user_query=q, schema_preview=schema_preview)
        latency_ms = (time.perf_counter() - t0) * 1000

        # A run counts as correct when it neither asked for clarification
        # nor reported an error.
        ok = (not r.get("ambiguous")) and ("error" not in r)
        traces = r.get("traces", [])

        # Sum per-stage costs; a malformed trace entry is skipped, not fatal.
        cost_sum = 0.0
        for trace in traces:
            try:
                cost_sum += float(trace.get("cost_usd", 0.0))
            except (AttributeError, TypeError, ValueError):
                pass

        # safe_fail flags failed runs whose output mentions "unsafe"
        # (i.e. the safety stage rejected the SQL).
        safe_fail = 1.0 if (not ok and "unsafe" in str(r).lower()) else 0.0

        results.append({
            "query": q,
            "exec_acc": 1.0 if ok else 0.0,
            "safe_fail": safe_fail,
            "latency_ms": latency_ms,
            "cost_usd": cost_sum,
            "repair_attempts": sum(1 for t in traces if t.get("stage") == "repair"),
            "provider": pipeline.generator.llm.provider_id if hasattr(pipeline.generator, "llm") else "unknown",
        })

    outfile.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding so output is stable across platforms.
    with open(outfile, "w", encoding="utf-8") as f:
        for row in results:
            f.write(json.dumps(row) + "\n")
    print(f"[OK] wrote {len(results)} rows → {outfile}")
def main():
    """CLI entry point: prepare the demo environment and run the benchmark."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--outfile", default="benchmarks/results/demo.jsonl")
    parser.add_argument("--db", default="data/bench_demo.db")
    parser.add_argument("--use-openai", action="store_true",
                        help="Use OpenAI provider if API key present")
    args = parser.parse_args()

    # Resolve paths relative to the project root (one level above benchmarks/).
    root = Path(__file__).resolve().parents[1]
    out_path = (root / args.outfile).resolve()
    db_path = (root / args.db).resolve()

    ensure_demo_db(db_path)
    pipeline = build_pipeline(db_path, use_openai=args.use_openai)

    # A small demo set; replace with Spider when ready.
    demo_queries = [
        "show all users",
        "top spenders",
        "sum of spend",
    ]
    schema_preview = "CREATE TABLE users(id INT, name TEXT, spend REAL);"
    run_benchmark(demo_queries, schema_preview, pipeline, out_path)


if __name__ == "__main__":
    main()