File size: 1,541 Bytes
e207f41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from __future__ import annotations
import time, json, csv
from pathlib import Path
from tqdm import tqdm

from  app import get_schema_preview, on_generate_query, make_sql_chain
from langchain_community.utilities import SQLDatabase
from benchmarks import load_spider_sqlite


# Directory that receives the timestamped .jsonl result files.
LOG_DIR = Path("logs") / "spider_eval"
LOG_DIR.mkdir(exist_ok=True, parents=True)

def run_eval(split="dev", limit=20):
    """Run the SQL-generation pipeline over the first *limit* Spider examples.

    Parameters
    ----------
    split : str
        Spider split to load (e.g. "dev" or "train").
    limit : int
        Maximum number of examples to evaluate.

    For each example a fresh SQLDatabase/chain is built for its sqlite file,
    a query is generated via ``on_generate_query``, and one JSON record is
    appended to a timestamped ``.jsonl`` file under ``LOG_DIR``.
    """
    data = load_spider_sqlite(split)[:limit]
    print(f"Running eval on {len(data)} examples...")

    results = []
    for ex in tqdm(data):
        db_path = str(ex.db_path)

        # Reuse the already-stringified path (original re-converted ex.db_path).
        # Second argument 0 presumably selects "no sample rows" — confirm in app.
        schema = get_schema_preview(db_path, 0)

        sql_db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
        chain = make_sql_chain(sql_db)

        state = {
            "db_path": db_path,
            "sql_db": sql_db,
            "schema_text": schema,
            "chain": chain,
        }

        msg, sql, output = on_generate_query(ex.question, 1000, state)

        results.append({
            "db_id": ex.db_id,
            "question": ex.question,
            "gold_sql": ex.gold_sql,
            "pred_sql": sql,
            "status": msg,
            "output": output,
        })

        # Gentle pacing between examples to avoid hammering the backend.
        time.sleep(0.3)

    ts = int(time.time())
    out_path = LOG_DIR / f"{split}_results_{ts}.jsonl"
    with out_path.open("w", encoding="utf-8") as f:
        for r in results:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    print(f"Wrote results → {out_path}")

if __name__ == "__main__":
    # Evaluate the first 20 examples of the training split.
    run_eval(split="train", limit=20)