# Source: nl2sql-copilot/benchmarks/evaluate_spider.py
# Author: Melika Kheirieh
# Commit: e207f41 ("Add first benchmark")
# Size: 1.54 kB
# (Repository page links at time of capture: raw, history, blame)
from __future__ import annotations
import time, json, csv
from pathlib import Path
from tqdm import tqdm
from app import get_schema_preview, on_generate_query, make_sql_chain
from langchain_community.utilities import SQLDatabase
from benchmarks import load_spider_sqlite
# Directory that receives the evaluation result files; the path is relative
# to the current working directory.
LOG_DIR: Path = Path("logs/spider_eval")
# Created at import time (including missing parents); a no-op if it exists.
LOG_DIR.mkdir(exist_ok=True, parents=True)
def run_eval(split: str = "dev", limit: int = 20) -> None:
    """Run the query generator over Spider examples and log results as JSONL.

    Loads examples with ``load_spider_sqlite(split)``, keeps the first
    ``limit`` of them, and for each one opens its SQLite database, builds a
    chain with ``make_sql_chain`` and calls ``on_generate_query`` with the
    example's question. One JSON object per processed example is written to
    ``LOG_DIR / f"{split}_results_{ts}.jsonl"`` where ``ts`` is the integer
    Unix time at which the file is written.

    Args:
        split: Passed to ``load_spider_sqlite`` and used as the output file
            name prefix.
        limit: Upper bound for the slice ``data[:limit]`` applied to the
            loaded examples.

    Raises:
        Any exception raised while processing an example propagates to the
        caller. The results collected before the failure are still written
        to the output file first, so a late failure does not discard the
        work already done.
    """
    data = load_spider_sqlite(split)
    data = data[:limit]
    print(f"Running eval on {len(data)} examples...")

    results = []
    try:
        for ex in tqdm(data):
            db_path = str(ex.db_path)
            # NOTE(review): meaning of the second argument (0) is defined by
            # get_schema_preview in `app` -- confirm there.
            schema = get_schema_preview(db_path, 0)
            sql_db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
            chain = make_sql_chain(sql_db)
            # State dict handed to on_generate_query; keys are reproduced
            # exactly as the handler receives them here.
            state = {
                "db_path": db_path,
                "sql_db": sql_db,
                "schema_text": schema,
                "chain": chain,
            }
            # NOTE(review): 1000 is presumably a row/size limit understood by
            # on_generate_query -- confirm against its signature in `app`.
            msg, sql, output = on_generate_query(ex.question, 1000, state)
            results.append({
                "db_id": ex.db_id,
                "question": ex.question,
                "gold_sql": ex.gold_sql,
                "pred_sql": sql,
                "status": msg,
                "output": output,
            })
            # Short pause between examples; presumably throttles requests to
            # the model backend -- TODO confirm.
            time.sleep(0.3)
    finally:
        # Always persist what has been collected, even when an example
        # raised, so partial runs are not lost. The exception (if any)
        # continues to propagate after this block.
        ts = int(time.time())
        out_path = LOG_DIR / f"{split}_results_{ts}.jsonl"
        with out_path.open("w", encoding="utf-8") as f:
            for r in results:
                f.write(json.dumps(r, ensure_ascii=False) + "\n")
        print(f"Wrote results → {out_path}")
# Script entry point: evaluates the first 20 examples of the "train" split
# (note that run_eval's own default split is "dev").
if __name__ == "__main__":
    run_eval(split="train", limit=20)