Melika Kheirieh committed on
Commit
e207f41
·
1 Parent(s): c50d16d

Add first benchmark

Browse files
.gitignore CHANGED
@@ -3,4 +3,4 @@ __pycache__/
3
  .venv/
4
  .idea/
5
  .DS_Store
6
- db/
 
3
  .venv/
4
  .idea/
5
  .DS_Store
6
+ data/
app.py CHANGED
@@ -119,6 +119,7 @@ Respond in this exact JSON format:
119
 
120
 
121
  def make_sql_chain(sql_db: SQLDatabase):
 
122
  llm = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE)
123
  chain = create_sql_query_chain(llm, sql_db, prompt=custom_prompt, k=20)
124
  return chain
 
119
 
120
 
def make_sql_chain(sql_db: SQLDatabase):
    """Build the text-to-SQL chain for the given database wrapper.

    Args:
        sql_db: Database wrapper handed to ``create_sql_query_chain``. It must
            expose a ``get_table_info`` attribute (duck-typed check).

    Returns:
        The chain produced by ``create_sql_query_chain`` using the module's
        ``custom_prompt`` and ``k=20``.

    Raises:
        TypeError: If ``sql_db`` has no ``get_table_info`` attribute.
    """
    # Raise explicitly instead of using ``assert``: assertions are removed when
    # Python runs with -O, which would silently disable this validation.
    if not hasattr(sql_db, "get_table_info"):
        raise TypeError("Expected LangChain SQLDatabase")

    llm = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE)
    return create_sql_query_chain(llm, sql_db, prompt=custom_prompt, k=20)
benchmarks/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .spider_loader import load_spider_sqlite, open_readonly_connection
2
+
3
# Public API of the package: every name re-exported from spider_loader above.
__all__ = ["load_spider_sqlite", "open_readonly_connection"]
benchmarks/evaluate_spider.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import time, json, csv
3
+ from pathlib import Path
4
+ from tqdm import tqdm
5
+
6
+ from app import get_schema_preview, on_generate_query, make_sql_chain
7
+ from langchain_community.utilities import SQLDatabase
8
+ from benchmarks import load_spider_sqlite
9
+
10
+
11
# Directory that receives the evaluation result files. It is created at import
# time (including missing parents) so later writes can assume it exists.
LOG_DIR = Path("logs") / "spider_eval"
LOG_DIR.mkdir(exist_ok=True, parents=True)
13
+
14
def run_eval(split="dev", limit=20, delay=0.3):
    """Run the query-generation pipeline over Spider examples and log the results.

    Each example is passed to ``on_generate_query`` and the outcome is appended
    as one JSON object per line to
    ``LOG_DIR / f"{split}_results_{timestamp}.jsonl"``.

    Args:
        split: Spider split name forwarded to ``load_spider_sqlite``.
        limit: Slice bound applied to the loaded examples (``None`` keeps all).
        delay: Seconds to pause after each example; values ``<= 0`` disable
            the pause.
    """
    data = load_spider_sqlite(split)
    data = data[:limit]
    print(f"Running eval on {len(data)} examples...")

    ts = int(time.time())
    out_path = LOG_DIR / f"{split}_results_{ts}.jsonl"

    # Per-database resources keyed by path. Examples that share a database
    # reuse the same schema preview, SQLDatabase and chain instead of
    # rebuilding them on every iteration.
    resources = {}

    # Results are streamed to disk as they are produced, so an exception or
    # interrupt part-way through does not discard the examples already done.
    with out_path.open("w", encoding="utf-8") as f:
        for ex in tqdm(data):
            db_path = str(ex.db_path)

            if db_path not in resources:
                # NOTE(review): meaning of the second argument (0) is defined
                # in app.get_schema_preview -- confirm there.
                schema = get_schema_preview(db_path, 0)
                sql_db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
                chain = make_sql_chain(sql_db)
                resources[db_path] = (schema, sql_db, chain)
            schema, sql_db, chain = resources[db_path]

            # A fresh state dict per example so nothing leaks between calls.
            state = {
                "db_path": db_path,
                "sql_db": sql_db,
                "schema_text": schema,
                "chain": chain,
            }

            # NOTE(review): 1000 is presumably a row limit -- confirm against
            # app.on_generate_query.
            msg, sql, output = on_generate_query(ex.question, 1000, state)

            record = {
                "db_id": ex.db_id,
                "question": ex.question,
                "gold_sql": ex.gold_sql,
                "pred_sql": sql,
                "status": msg,
                "output": output,
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            f.flush()

            if delay > 0:
                time.sleep(delay)

    print(f"Wrote results → {out_path}")
55
+
56
if __name__ == "__main__":
    # Script entry point: evaluate 20 examples from the training split.
    run_eval(split="train", limit=20)
benchmarks/spider_loader.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import json, pathlib, sqlite3
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional
5
+ import os
6
+
7
# Root directory of the Spider dataset. Taken from the SPIDER_ROOT environment
# variable when set, otherwise the relative path "data/spider".
SPIDER_ROOT = pathlib.Path(os.environ.get("SPIDER_ROOT", "data/spider"))
10
+
11
@dataclass
class SpiderItem:
    """A single Spider example together with the SQLite file it refers to."""

    # Identifier of the database the example belongs to.
    db_id: str
    # Natural-language question text of the example.
    question: str
    # Reference SQL for the question.
    gold_sql: str
    # Location of the database's .sqlite file on disk.
    db_path: pathlib.Path
17
+
18
def load_spider_sqlite(split: str = "dev", limit: Optional[int] = None) -> List[SpiderItem]:
    """Load Spider examples for ``split`` and pair each with its SQLite file.

    Args:
        split: ``"dev"`` (reads ``dev.json``) or ``"train"`` (reads
            ``train_spider.json``) under ``SPIDER_ROOT``.
        limit: Slice bound applied to the examples. ``None`` loads everything;
            ``0`` yields an empty list.

    Returns:
        One ``SpiderItem`` per selected example, in file order.

    Raises:
        KeyError: If ``split`` is not a known split name.
        RuntimeError: If the split file cannot be read or parsed.
        FileNotFoundError: If the SQLite file of a selected example is missing.
    """
    split_files = {"dev": "dev.json", "train": "train_spider.json"}
    try:
        fn = split_files[split]
    except KeyError:
        # Same exception type as a plain lookup, but with an actionable message.
        raise KeyError(
            f"Unknown Spider split {split!r}; expected one of {sorted(split_files)}"
        ) from None

    json_path = SPIDER_ROOT / fn
    try:
        items = json.loads(json_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        # OSError covers I/O failures; ValueError covers JSON and text decoding
        # errors. Chain the cause so the original traceback stays available.
        raise RuntimeError(f"Failed to read Spider split file: {json_path} ({e})") from e

    # Compare against None explicitly: ``limit or len(items)`` would treat a
    # limit of 0 as "no limit" and return the whole split.
    selected = items if limit is None else items[:limit]

    out: List[SpiderItem] = []
    for ex in selected:
        db_id = ex["db_id"]
        db_path = SPIDER_ROOT / "database" / db_id / f"{db_id}.sqlite"
        if not db_path.exists():
            raise FileNotFoundError(f"Missing SQLite DB for {db_id}: {db_path}")
        out.append(
            SpiderItem(
                db_id=db_id,
                question=ex["question"],
                gold_sql=ex["query"],
                db_path=db_path,
            )
        )
    return out
42
+
43
+
44
def open_readonly_connection(db_path: pathlib.Path, timeout: float=5.0) -> sqlite3.Connection:
    """Open a read-only SQLite connection to ``db_path``.

    Args:
        db_path: Location of the SQLite database file.
        timeout: Forwarded to ``sqlite3.connect`` as its ``timeout`` argument.

    Returns:
        A connection opened with ``mode=ro`` whose ``row_factory`` is
        ``sqlite3.Row``.
    """
    # Build the URI with pathlib so characters that are special in URIs
    # ('#', '?', '%', spaces) are percent-encoded instead of being parsed as a
    # fragment or query string. ``mode=ro`` is the only parameter needed; URI
    # handling itself is enabled by ``uri=True`` below.
    uri = f"{pathlib.Path(db_path).resolve().as_uri()}?mode=ro"
    conn = sqlite3.connect(uri, uri=True, timeout=timeout)
    conn.row_factory = sqlite3.Row
    return conn
logs/spider_eval/train_results_1760375755.jsonl ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"db_id": "department_management", "question": "How many heads of the departments are older than 56 ?", "gold_sql": "SELECT count(*) FROM head WHERE age > 56", "pred_sql": "SELECT COUNT(*) FROM head WHERE age > 56", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"COUNT(*)\": 5\n }\n]"}
2
+ {"db_id": "department_management", "question": "List the name, born state and age of the heads of departments ordered by age.", "gold_sql": "SELECT name , born_state , age FROM head ORDER BY age", "pred_sql": "SELECT h.name, h.born_state, h.age FROM head h JOIN management m ON h.head_ID = m.head_ID ORDER BY h.age", "status": "Returned 5 row(s). Showing up to 50.", "output": "[\n {\n \"name\": \"Dudley Hart\",\n \"born_state\": \"California\",\n \"age\": 52.0\n },\n {\n \"name\": \"Jeff Maggert\",\n \"born_state\": \"Delaware\",\n \"age\": 53.0\n },\n {\n \"name\": \"Franklin Langham\",\n \"born_state\": \"Connecticut\",\n \"age\": 67.0\n },\n {\n \"name\": \"Billy Mayfair\",\n \"born_state\": \"California\",\n \"age\": 69.0\n },\n {\n \"name\": \"K. J. Choi\",\n \"born_state\": \"Alabama\",\n \"age\": 69.0\n }\n]"}
3
+ {"db_id": "department_management", "question": "List the creation year, name and budget of each department.", "gold_sql": "SELECT creation , name , budget_in_billions FROM department", "pred_sql": "SELECT Creation, Name, Budget_in_Billions FROM department;", "status": "Returned 15 row(s). Showing up to 50.", "output": "[\n {\n \"Creation\": \"1789\",\n \"Name\": \"State\",\n \"Budget_in_Billions\": 9.96\n },\n {\n \"Creation\": \"1789\",\n \"Name\": \"Treasury\",\n \"Budget_in_Billions\": 11.1\n },\n {\n \"Creation\": \"1947\",\n \"Name\": \"Defense\",\n \"Budget_in_Billions\": 439.3\n },\n {\n \"Creation\": \"1870\",\n \"Name\": \"Justice\",\n \"Budget_in_Billions\": 23.4\n },\n {\n \"Creation\": \"1849\",\n \"Name\": \"Interior\",\n \"Budget_in_Billions\": 10.7\n },\n {\n \"Creation\": \"1889\",\n \"Name\": \"Agriculture\",\n \"Budget_in_Billions\": 77.6\n },\n {\n \"Creation\": \"1903\",\n \"Name\": \"Commerce\",\n \"Budget_in_Billions\": 6.2\n },\n {\n \"Creation\": \"1913\",\n \"Name\": \"Labor\",\n \"Budget_in_Billions\": 59.7\n },\n {\n \"Creation\": \"1953\",\n \"Name\": \"Health and Human Services\",\n \"Budget_in_Billions\": 543.2\n },\n {\n \"Creation\": \"1965\",\n \"Name\": \"Housing and Urban Development\",\n \"Budget_in_Billions\": 46.2\n },\n {\n \"Creation\": \"1966\",\n \"Name\": \"Transportation\",\n \"Budget_in_Billions\": 58.0\n },\n {\n \"Creation\": \"1977\",\n \"Name\": \"Energy\",\n \"Budget_in_Billions\": 21.5\n },\n {\n \"Creation\": \"1979\",\n \"Name\": \"Education\",\n \"Budget_in_Billions\": 62.8\n },\n {\n \"Creation\": \"1989\",\n \"Name\": \"Veterans Affairs\",\n \"Budget_in_Billions\": 73.2\n },\n {\n \"Creation\": \"2002\",\n \"Name\": \"Homeland Security\",\n \"Budget_in_Billions\": 44.6\n }\n]"}
4
+ {"db_id": "department_management", "question": "What are the maximum and minimum budget of the departments?", "gold_sql": "SELECT max(budget_in_billions) , min(budget_in_billions) FROM department", "pred_sql": "SELECT MAX(Budget_in_Billions) AS Max_Budget, MIN(Budget_in_Billions) AS Min_Budget FROM department", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Max_Budget\": 543.2,\n \"Min_Budget\": 6.2\n }\n]"}
5
+ {"db_id": "department_management", "question": "What is the average number of employees of the departments whose rank is between 10 and 15?", "gold_sql": "SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15", "pred_sql": "SELECT AVG(Num_Employees) AS average_num_employees FROM department WHERE Ranking BETWEEN 10 AND 15;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"average_num_employees\": 105468.16666666667\n }\n]"}
6
+ {"db_id": "department_management", "question": "What are the names of the heads who are born outside the California state?", "gold_sql": "SELECT name FROM head WHERE born_state != 'California'", "pred_sql": "SELECT name FROM head WHERE born_state <> 'California';", "status": "Returned 6 row(s). Showing up to 50.", "output": "[\n {\n \"name\": \"Tiger Woods\"\n },\n {\n \"name\": \"K. J. Choi\"\n },\n {\n \"name\": \"Jeff Maggert\"\n },\n {\n \"name\": \"Stewart Cink\"\n },\n {\n \"name\": \"P\\u00e1draig Harrington\"\n },\n {\n \"name\": \"Franklin Langham\"\n }\n]"}
7
+ {"db_id": "department_management", "question": "What are the distinct creation years of the departments managed by a secretary born in state 'Alabama'?", "gold_sql": "SELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id JOIN head AS T3 ON T2.head_id = T3.head_id WHERE T3.born_state = 'Alabama'", "pred_sql": "SELECT DISTINCT d.Creation FROM department d JOIN management m ON d.Department_ID = m.department_ID JOIN head h ON m.head_ID = h.head_ID WHERE h.born_state = 'Alabama'", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Creation\": \"1903\"\n }\n]"}
8
+ {"db_id": "department_management", "question": "What are the names of the states where at least 3 heads were born?", "gold_sql": "SELECT born_state FROM head GROUP BY born_state HAVING count(*) >= 3", "pred_sql": "SELECT born_state FROM head GROUP BY born_state HAVING COUNT(head_ID) >= 3", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"born_state\": \"California\"\n }\n]"}
9
+ {"db_id": "department_management", "question": "In which year were most departments established?", "gold_sql": "SELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1", "pred_sql": "SELECT strftime('%Y', Creation) AS Year, COUNT(*) AS DepartmentCount FROM department GROUP BY Year ORDER BY DepartmentCount DESC LIMIT 1;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Year\": \"-4707\",\n \"DepartmentCount\": 12\n }\n]"}
10
+ {"db_id": "department_management", "question": "Show the name and number of employees for the departments managed by heads whose temporary acting value is 'Yes'?", "gold_sql": "SELECT T1.name , T1.num_employees FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id WHERE T2.temporary_acting = 'Yes'", "pred_sql": "SELECT d.Name, d.Num_Employees FROM department d JOIN management m ON d.Department_ID = m.department_ID WHERE m.temporary_acting = 'Yes'", "status": "Returned 3 row(s). Showing up to 50.", "output": "[\n {\n \"Name\": \"Treasury\",\n \"Num_Employees\": 115897.0\n },\n {\n \"Name\": \"Homeland Security\",\n \"Num_Employees\": 208000.0\n },\n {\n \"Name\": \"Treasury\",\n \"Num_Employees\": 115897.0\n }\n]"}
11
+ {"db_id": "department_management", "question": "How many acting statuses are there?", "gold_sql": "SELECT count(DISTINCT temporary_acting) FROM management", "pred_sql": "SELECT COUNT(DISTINCT temporary_acting) AS acting_status_count FROM management;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"acting_status_count\": 2\n }\n]"}
12
+ {"db_id": "department_management", "question": "How many departments are led by heads who are not mentioned?", "gold_sql": "SELECT count(*) FROM department WHERE department_id NOT IN (SELECT department_id FROM management);", "pred_sql": "SELECT COUNT(*) FROM department d LEFT JOIN management m ON d.Department_ID = m.department_ID WHERE m.head_ID IS NULL;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"COUNT(*)\": 11\n }\n]"}
13
+ {"db_id": "department_management", "question": "What are the distinct ages of the heads who are acting?", "gold_sql": "SELECT DISTINCT T1.age FROM management AS T2 JOIN head AS T1 ON T1.head_id = T2.head_id WHERE T2.temporary_acting = 'Yes'", "pred_sql": "SELECT DISTINCT h.age FROM head h JOIN management m ON h.head_ID = m.head_ID WHERE m.temporary_acting = 'Yes';", "status": "Returned 3 row(s). Showing up to 50.", "output": "[\n {\n \"age\": 53.0\n },\n {\n \"age\": 52.0\n },\n {\n \"age\": 69.0\n }\n]"}
14
+ {"db_id": "department_management", "question": "List the states where both the secretary of 'Treasury' department and the secretary of 'Homeland Security' were born.", "gold_sql": "SELECT T3.born_state FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id JOIN head AS T3 ON T2.head_id = T3.head_id WHERE T1.name = 'Treasury' INTERSECT SELECT T3.born_state FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id JOIN head AS T3 ON T2.head_id = T3.head_id WHERE T1.name = 'Homeland Security'", "pred_sql": "SELECT DISTINCT h1.born_state FROM management m1 JOIN head h1 ON m1.head_ID = h1.head_ID JOIN department d1 ON m1.department_ID = d1.Department_ID WHERE d1.Name = 'Treasury' AND h1.born_state IN (SELECT h2.born_state FROM management m2 JOIN head h2 ON m2.head_ID = h2.head_ID JOIN department d2 ON m2.department_ID = d2.Department_ID WHERE d2.Name = 'Homeland Security')", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"born_state\": \"California\"\n }\n]"}
15
+ {"db_id": "department_management", "question": "Which department has more than 1 head at a time? List the id, name and the number of heads.", "gold_sql": "SELECT T1.department_id , T1.name , count(*) FROM management AS T2 JOIN department AS T1 ON T1.department_id = T2.department_id GROUP BY T1.department_id HAVING count(*) > 1", "pred_sql": "SELECT d.Department_ID, d.Name, COUNT(m.head_ID) AS Number_of_Heads FROM department d JOIN management m ON d.Department_ID = m.department_ID GROUP BY d.Department_ID, d.Name HAVING COUNT(m.head_ID) > 1", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Department_ID\": 2,\n \"Name\": \"Treasury\",\n \"Number_of_Heads\": 2\n }\n]"}
16
+ {"db_id": "department_management", "question": "Which head's name has the substring 'Ha'? List the id and name.", "gold_sql": "SELECT head_id , name FROM head WHERE name LIKE '%Ha%'", "pred_sql": "SELECT h.head_ID, h.name FROM head h JOIN management m ON h.head_ID = m.head_ID WHERE h.name LIKE '%Ha%'", "status": "Returned 2 row(s). Showing up to 50.", "output": "[\n {\n \"head_ID\": 10,\n \"name\": \"Franklin Langham\"\n },\n {\n \"head_ID\": 4,\n \"name\": \"Dudley Hart\"\n }\n]"}
17
+ {"db_id": "farm", "question": "How many farms are there?", "gold_sql": "SELECT count(*) FROM farm", "pred_sql": "", "status": "Error: Error code: 403 - {'error': {'message': 'token quota is not enough, token remain quota: $0.000008, need quota: $0.000114 (request id: 20251013204552914852769LyjH3RSk)', 'type': 'new_api_error', 'param': '', 'code': 'pre_consume_token_quota_failed'}}", "output": ""}
18
+ {"db_id": "farm", "question": "Count the number of farms.", "gold_sql": "SELECT count(*) FROM farm", "pred_sql": "", "status": "Error: Error code: 403 - {'error': {'message': 'token quota is not enough, token remain quota: $0.000008, need quota: $0.000114 (request id: 20251013204553544953937LKDeBltm)', 'type': 'new_api_error', 'param': '', 'code': 'pre_consume_token_quota_failed'}}", "output": ""}
19
+ {"db_id": "farm", "question": "List the total number of horses on farms in ascending order.", "gold_sql": "SELECT Total_Horses FROM farm ORDER BY Total_Horses ASC", "pred_sql": "", "status": "Error: Error code: 403 - {'error': {'message': 'token quota is not enough, token remain quota: $0.000008, need quota: $0.000114 (request id: 20251013204554156810277QpyTydIp)', 'type': 'new_api_error', 'param': '', 'code': 'pre_consume_token_quota_failed'}}", "output": ""}
20
+ {"db_id": "farm", "question": "What is the total horses record for each farm, sorted ascending?", "gold_sql": "SELECT Total_Horses FROM farm ORDER BY Total_Horses ASC", "pred_sql": "", "status": "Error: Error code: 403 - {'error': {'message': 'token quota is not enough, token remain quota: $0.000008, need quota: $0.000116 (request id: 202510132045548459300909IPAkckU)', 'type': 'new_api_error', 'param': '', 'code': 'pre_consume_token_quota_failed'}}", "output": ""}