Spaces:
Sleeping
Sleeping
Melika Kheirieh
committed on
Commit
·
e207f41
1
Parent(s):
c50d16d
Add first benchmark
Browse files- .gitignore +1 -1
- app.py +1 -0
- benchmarks/__init__.py +3 -0
- benchmarks/evaluate_spider.py +57 -0
- benchmarks/spider_loader.py +48 -0
- logs/spider_eval/train_results_1760375755.jsonl +20 -0
.gitignore
CHANGED
|
@@ -3,4 +3,4 @@ __pycache__/
|
|
| 3 |
.venv/
|
| 4 |
.idea/
|
| 5 |
.DS_Store
|
| 6 |
-
|
|
|
|
| 3 |
.venv/
|
| 4 |
.idea/
|
| 5 |
.DS_Store
|
| 6 |
+
data/
|
app.py
CHANGED
|
@@ -119,6 +119,7 @@ Respond in this exact JSON format:
|
|
| 119 |
|
| 120 |
|
| 121 |
def make_sql_chain(sql_db: SQLDatabase):
|
|
|
|
| 122 |
llm = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE)
|
| 123 |
chain = create_sql_query_chain(llm, sql_db, prompt=custom_prompt, k=20)
|
| 124 |
return chain
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
def make_sql_chain(sql_db: SQLDatabase):
    """Build the SQL query chain for ``sql_db``.

    A ``ChatOpenAI`` model configured from ``LLM_MODEL`` and
    ``LLM_TEMPERATURE`` is handed to ``create_sql_query_chain`` together with
    ``sql_db``, ``custom_prompt`` and ``k=20``; the resulting chain is
    returned.
    """
    # Duck-type guard only; being an ``assert`` it is skipped under ``python -O``.
    assert hasattr(sql_db, "get_table_info"), "Expected LangChain SQLDatabase"
    model = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE)
    return create_sql_query_chain(model, sql_db, prompt=custom_prompt, k=20)
|
benchmarks/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .spider_loader import load_spider_sqlite, open_readonly_connection

# Every helper re-exported above is part of the package's public API, so both
# names are listed (``open_readonly_connection`` was imported but not exported).
__all__ = ["load_spider_sqlite", "open_readonly_connection"]
|
benchmarks/evaluate_spider.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import time, json, csv
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
from app import get_schema_preview, on_generate_query, make_sql_chain
|
| 7 |
+
from langchain_community.utilities import SQLDatabase
|
| 8 |
+
from benchmarks import load_spider_sqlite
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Evaluation result files are written below this directory. It is created when
# the module is imported so later writes never hit a missing parent.
LOG_DIR = Path("logs/spider_eval")
LOG_DIR.mkdir(exist_ok=True, parents=True)
|
| 13 |
+
|
| 14 |
+
def run_eval(split="dev", limit=20, *, delay=0.3):
    """Run the query-generation pipeline over Spider examples and log results.

    For each example the database is opened through LangChain's
    ``SQLDatabase``, a chain is built with ``make_sql_chain`` and the question
    is passed to ``on_generate_query``. One JSON object per example is then
    written to ``LOG_DIR / f"{split}_results_{timestamp}.jsonl"``.

    Args:
        split: Spider split name forwarded to ``load_spider_sqlite``.
        limit: Number of leading examples to evaluate; ``None`` evaluates the
            whole split.
        delay: Seconds to sleep after each example (keyword-only).

    Returns:
        The ``Path`` of the JSONL file that was written.
    """
    # For a positive limit let the loader stop early, so it does not have to
    # build (and check the database file of) every example in the split.
    # Other values keep the original "load everything, then slice" behaviour.
    loader_limit = limit if limit is not None and limit > 0 else None
    data = load_spider_sqlite(split, limit=loader_limit)[:limit]
    print(f"Running eval on {len(data)} examples...")

    results = []
    # Consecutive examples usually target the same database, so the schema
    # preview, SQLDatabase and chain are rebuilt only when the path changes.
    current_db_path = None
    schema = sql_db = chain = None
    for ex in tqdm(data):
        db_path = str(ex.db_path)
        if db_path != current_db_path:
            schema = get_schema_preview(db_path, 0)
            sql_db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
            chain = make_sql_chain(sql_db)
            current_db_path = db_path

        # A fresh dict per example, in case on_generate_query mutates it.
        state = {
            "db_path": db_path,
            "sql_db": sql_db,
            "schema_text": schema,
            "chain": chain,
        }

        # NOTE(review): 1000 is presumably a row cap for the executed query;
        # confirm against on_generate_query in app.py.
        msg, sql, output = on_generate_query(ex.question, 1000, state)

        results.append({
            "db_id": ex.db_id,
            "question": ex.question,
            "gold_sql": ex.gold_sql,
            "pred_sql": sql,
            "status": msg,
            "output": output,
        })

        time.sleep(delay)

    # Harmless if the directory already exists; guards against it having been
    # removed since import.
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    ts = int(time.time())
    out_path = LOG_DIR / f"{split}_results_{ts}.jsonl"
    with out_path.open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in results)

    print(f"Wrote results → {out_path}")
    return out_path
|
| 55 |
+
|
| 56 |
+
if __name__ == "__main__":
    # Local import: argparse is only needed when the module runs as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the Spider evaluation and write a JSONL results log."
    )
    # Defaults reproduce the previously hard-coded run_eval("train", 20).
    parser.add_argument(
        "--split",
        default="train",
        choices=("dev", "train"),
        help="Spider split to evaluate (default: train)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=20,
        help="number of leading examples to evaluate (default: 20)",
    )
    args = parser.parse_args()
    run_eval(args.split, args.limit)
|
benchmarks/spider_loader.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import json, pathlib, sqlite3
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
# Root directory of the Spider dataset. Taken from the SPIDER_ROOT environment
# variable when set, otherwise the relative path "data/spider".
SPIDER_ROOT = pathlib.Path(os.environ.get("SPIDER_ROOT", "data/spider"))
|
| 10 |
+
|
| 11 |
+
@dataclass
class SpiderItem:
    """One Spider example together with the SQLite file of its database."""

    # Identifier of the database the question is asked against.
    db_id: str
    # Natural-language question.
    question: str
    # Reference SQL query for the question.
    gold_sql: str
    # Location of the database's ``.sqlite`` file.
    db_path: pathlib.Path
|
| 17 |
+
|
| 18 |
+
def load_spider_sqlite(split: str = "dev", limit: Optional[int] = None) -> List[SpiderItem]:
    """Load Spider examples and pair each with its SQLite database file.

    The split's JSON file is read from ``SPIDER_ROOT`` and each example's
    database is expected at ``SPIDER_ROOT/database/<db_id>/<db_id>.sqlite``.

    Args:
        split: ``"dev"`` (``dev.json``) or ``"train"`` (``train_spider.json``).
        limit: Keep only the first ``limit`` examples; ``None`` keeps all of
            them and ``0`` yields an empty list.

    Returns:
        The selected examples as ``SpiderItem`` objects, in file order.

    Raises:
        KeyError: ``split`` is not a known split name.
        RuntimeError: The split file cannot be read or is not valid JSON.
        FileNotFoundError: The SQLite file of a selected example is missing.
    """
    split_files = {"dev": "dev.json", "train": "train_spider.json"}
    try:
        fn = split_files[split]
    except KeyError:
        # Still a KeyError for existing callers, but now says what is accepted.
        raise KeyError(
            f"Unknown Spider split {split!r}; expected one of {sorted(split_files)}"
        ) from None

    json_path = SPIDER_ROOT / fn
    try:
        items = json.loads(json_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        # OSError covers missing/unreadable files; ValueError covers both
        # JSONDecodeError and UnicodeDecodeError. Chain the cause for debugging.
        raise RuntimeError(f"Failed to read Spider split file: {json_path} ({e})") from e

    out: List[SpiderItem] = []
    # Slicing with None keeps everything; unlike ``limit or len(items)`` it
    # does not mistake an explicit limit of 0 for "no limit".
    for ex in items[:limit]:
        db_id = ex["db_id"]
        db_path = SPIDER_ROOT / "database" / db_id / f"{db_id}.sqlite"
        if not db_path.exists():
            raise FileNotFoundError(f"Missing SQLite DB for {db_id}: {db_path}")
        out.append(
            SpiderItem(
                db_id=db_id,
                question=ex["question"],
                gold_sql=ex["query"],
                db_path=db_path,
            )
        )
    return out
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def open_readonly_connection(db_path: pathlib.Path, timeout: float = 5.0) -> sqlite3.Connection:
    """Open a SQLite database in read-only mode.

    Args:
        db_path: Location of the database file (a ``str`` is accepted too).
        timeout: Seconds to wait on a locked database, passed to
            ``sqlite3.connect``.

    Returns:
        A connection whose ``row_factory`` is ``sqlite3.Row``, so columns can
        be read by name.

    Raises:
        sqlite3.OperationalError: The file cannot be opened (``mode=ro`` never
            creates a missing database).
    """
    # Path.as_uri() requires an absolute path and percent-encodes characters
    # such as "#", "?", "%" and spaces, which SQLite would otherwise read as
    # URI syntax. "mode=ro" is the only query parameter needed here.
    uri = f"{pathlib.Path(db_path).resolve().as_uri()}?mode=ro"
    conn = sqlite3.connect(uri, uri=True, timeout=timeout)
    conn.row_factory = sqlite3.Row
    return conn
|
logs/spider_eval/train_results_1760375755.jsonl
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"db_id": "department_management", "question": "How many heads of the departments are older than 56 ?", "gold_sql": "SELECT count(*) FROM head WHERE age > 56", "pred_sql": "SELECT COUNT(*) FROM head WHERE age > 56", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"COUNT(*)\": 5\n }\n]"}
|
| 2 |
+
{"db_id": "department_management", "question": "List the name, born state and age of the heads of departments ordered by age.", "gold_sql": "SELECT name , born_state , age FROM head ORDER BY age", "pred_sql": "SELECT h.name, h.born_state, h.age FROM head h JOIN management m ON h.head_ID = m.head_ID ORDER BY h.age", "status": "Returned 5 row(s). Showing up to 50.", "output": "[\n {\n \"name\": \"Dudley Hart\",\n \"born_state\": \"California\",\n \"age\": 52.0\n },\n {\n \"name\": \"Jeff Maggert\",\n \"born_state\": \"Delaware\",\n \"age\": 53.0\n },\n {\n \"name\": \"Franklin Langham\",\n \"born_state\": \"Connecticut\",\n \"age\": 67.0\n },\n {\n \"name\": \"Billy Mayfair\",\n \"born_state\": \"California\",\n \"age\": 69.0\n },\n {\n \"name\": \"K. J. Choi\",\n \"born_state\": \"Alabama\",\n \"age\": 69.0\n }\n]"}
|
| 3 |
+
{"db_id": "department_management", "question": "List the creation year, name and budget of each department.", "gold_sql": "SELECT creation , name , budget_in_billions FROM department", "pred_sql": "SELECT Creation, Name, Budget_in_Billions FROM department;", "status": "Returned 15 row(s). Showing up to 50.", "output": "[\n {\n \"Creation\": \"1789\",\n \"Name\": \"State\",\n \"Budget_in_Billions\": 9.96\n },\n {\n \"Creation\": \"1789\",\n \"Name\": \"Treasury\",\n \"Budget_in_Billions\": 11.1\n },\n {\n \"Creation\": \"1947\",\n \"Name\": \"Defense\",\n \"Budget_in_Billions\": 439.3\n },\n {\n \"Creation\": \"1870\",\n \"Name\": \"Justice\",\n \"Budget_in_Billions\": 23.4\n },\n {\n \"Creation\": \"1849\",\n \"Name\": \"Interior\",\n \"Budget_in_Billions\": 10.7\n },\n {\n \"Creation\": \"1889\",\n \"Name\": \"Agriculture\",\n \"Budget_in_Billions\": 77.6\n },\n {\n \"Creation\": \"1903\",\n \"Name\": \"Commerce\",\n \"Budget_in_Billions\": 6.2\n },\n {\n \"Creation\": \"1913\",\n \"Name\": \"Labor\",\n \"Budget_in_Billions\": 59.7\n },\n {\n \"Creation\": \"1953\",\n \"Name\": \"Health and Human Services\",\n \"Budget_in_Billions\": 543.2\n },\n {\n \"Creation\": \"1965\",\n \"Name\": \"Housing and Urban Development\",\n \"Budget_in_Billions\": 46.2\n },\n {\n \"Creation\": \"1966\",\n \"Name\": \"Transportation\",\n \"Budget_in_Billions\": 58.0\n },\n {\n \"Creation\": \"1977\",\n \"Name\": \"Energy\",\n \"Budget_in_Billions\": 21.5\n },\n {\n \"Creation\": \"1979\",\n \"Name\": \"Education\",\n \"Budget_in_Billions\": 62.8\n },\n {\n \"Creation\": \"1989\",\n \"Name\": \"Veterans Affairs\",\n \"Budget_in_Billions\": 73.2\n },\n {\n \"Creation\": \"2002\",\n \"Name\": \"Homeland Security\",\n \"Budget_in_Billions\": 44.6\n }\n]"}
|
| 4 |
+
{"db_id": "department_management", "question": "What are the maximum and minimum budget of the departments?", "gold_sql": "SELECT max(budget_in_billions) , min(budget_in_billions) FROM department", "pred_sql": "SELECT MAX(Budget_in_Billions) AS Max_Budget, MIN(Budget_in_Billions) AS Min_Budget FROM department", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Max_Budget\": 543.2,\n \"Min_Budget\": 6.2\n }\n]"}
|
| 5 |
+
{"db_id": "department_management", "question": "What is the average number of employees of the departments whose rank is between 10 and 15?", "gold_sql": "SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15", "pred_sql": "SELECT AVG(Num_Employees) AS average_num_employees FROM department WHERE Ranking BETWEEN 10 AND 15;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"average_num_employees\": 105468.16666666667\n }\n]"}
|
| 6 |
+
{"db_id": "department_management", "question": "What are the names of the heads who are born outside the California state?", "gold_sql": "SELECT name FROM head WHERE born_state != 'California'", "pred_sql": "SELECT name FROM head WHERE born_state <> 'California';", "status": "Returned 6 row(s). Showing up to 50.", "output": "[\n {\n \"name\": \"Tiger Woods\"\n },\n {\n \"name\": \"K. J. Choi\"\n },\n {\n \"name\": \"Jeff Maggert\"\n },\n {\n \"name\": \"Stewart Cink\"\n },\n {\n \"name\": \"P\\u00e1draig Harrington\"\n },\n {\n \"name\": \"Franklin Langham\"\n }\n]"}
|
| 7 |
+
{"db_id": "department_management", "question": "What are the distinct creation years of the departments managed by a secretary born in state 'Alabama'?", "gold_sql": "SELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id JOIN head AS T3 ON T2.head_id = T3.head_id WHERE T3.born_state = 'Alabama'", "pred_sql": "SELECT DISTINCT d.Creation FROM department d JOIN management m ON d.Department_ID = m.department_ID JOIN head h ON m.head_ID = h.head_ID WHERE h.born_state = 'Alabama'", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Creation\": \"1903\"\n }\n]"}
|
| 8 |
+
{"db_id": "department_management", "question": "What are the names of the states where at least 3 heads were born?", "gold_sql": "SELECT born_state FROM head GROUP BY born_state HAVING count(*) >= 3", "pred_sql": "SELECT born_state FROM head GROUP BY born_state HAVING COUNT(head_ID) >= 3", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"born_state\": \"California\"\n }\n]"}
|
| 9 |
+
{"db_id": "department_management", "question": "In which year were most departments established?", "gold_sql": "SELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1", "pred_sql": "SELECT strftime('%Y', Creation) AS Year, COUNT(*) AS DepartmentCount FROM department GROUP BY Year ORDER BY DepartmentCount DESC LIMIT 1;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Year\": \"-4707\",\n \"DepartmentCount\": 12\n }\n]"}
|
| 10 |
+
{"db_id": "department_management", "question": "Show the name and number of employees for the departments managed by heads whose temporary acting value is 'Yes'?", "gold_sql": "SELECT T1.name , T1.num_employees FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id WHERE T2.temporary_acting = 'Yes'", "pred_sql": "SELECT d.Name, d.Num_Employees FROM department d JOIN management m ON d.Department_ID = m.department_ID WHERE m.temporary_acting = 'Yes'", "status": "Returned 3 row(s). Showing up to 50.", "output": "[\n {\n \"Name\": \"Treasury\",\n \"Num_Employees\": 115897.0\n },\n {\n \"Name\": \"Homeland Security\",\n \"Num_Employees\": 208000.0\n },\n {\n \"Name\": \"Treasury\",\n \"Num_Employees\": 115897.0\n }\n]"}
|
| 11 |
+
{"db_id": "department_management", "question": "How many acting statuses are there?", "gold_sql": "SELECT count(DISTINCT temporary_acting) FROM management", "pred_sql": "SELECT COUNT(DISTINCT temporary_acting) AS acting_status_count FROM management;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"acting_status_count\": 2\n }\n]"}
|
| 12 |
+
{"db_id": "department_management", "question": "How many departments are led by heads who are not mentioned?", "gold_sql": "SELECT count(*) FROM department WHERE department_id NOT IN (SELECT department_id FROM management);", "pred_sql": "SELECT COUNT(*) FROM department d LEFT JOIN management m ON d.Department_ID = m.department_ID WHERE m.head_ID IS NULL;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"COUNT(*)\": 11\n }\n]"}
|
| 13 |
+
{"db_id": "department_management", "question": "What are the distinct ages of the heads who are acting?", "gold_sql": "SELECT DISTINCT T1.age FROM management AS T2 JOIN head AS T1 ON T1.head_id = T2.head_id WHERE T2.temporary_acting = 'Yes'", "pred_sql": "SELECT DISTINCT h.age FROM head h JOIN management m ON h.head_ID = m.head_ID WHERE m.temporary_acting = 'Yes';", "status": "Returned 3 row(s). Showing up to 50.", "output": "[\n {\n \"age\": 53.0\n },\n {\n \"age\": 52.0\n },\n {\n \"age\": 69.0\n }\n]"}
|
| 14 |
+
{"db_id": "department_management", "question": "List the states where both the secretary of 'Treasury' department and the secretary of 'Homeland Security' were born.", "gold_sql": "SELECT T3.born_state FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id JOIN head AS T3 ON T2.head_id = T3.head_id WHERE T1.name = 'Treasury' INTERSECT SELECT T3.born_state FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id JOIN head AS T3 ON T2.head_id = T3.head_id WHERE T1.name = 'Homeland Security'", "pred_sql": "SELECT DISTINCT h1.born_state FROM management m1 JOIN head h1 ON m1.head_ID = h1.head_ID JOIN department d1 ON m1.department_ID = d1.Department_ID WHERE d1.Name = 'Treasury' AND h1.born_state IN (SELECT h2.born_state FROM management m2 JOIN head h2 ON m2.head_ID = h2.head_ID JOIN department d2 ON m2.department_ID = d2.Department_ID WHERE d2.Name = 'Homeland Security')", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"born_state\": \"California\"\n }\n]"}
|
| 15 |
+
{"db_id": "department_management", "question": "Which department has more than 1 head at a time? List the id, name and the number of heads.", "gold_sql": "SELECT T1.department_id , T1.name , count(*) FROM management AS T2 JOIN department AS T1 ON T1.department_id = T2.department_id GROUP BY T1.department_id HAVING count(*) > 1", "pred_sql": "SELECT d.Department_ID, d.Name, COUNT(m.head_ID) AS Number_of_Heads FROM department d JOIN management m ON d.Department_ID = m.department_ID GROUP BY d.Department_ID, d.Name HAVING COUNT(m.head_ID) > 1", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Department_ID\": 2,\n \"Name\": \"Treasury\",\n \"Number_of_Heads\": 2\n }\n]"}
|
| 16 |
+
{"db_id": "department_management", "question": "Which head's name has the substring 'Ha'? List the id and name.", "gold_sql": "SELECT head_id , name FROM head WHERE name LIKE '%Ha%'", "pred_sql": "SELECT h.head_ID, h.name FROM head h JOIN management m ON h.head_ID = m.head_ID WHERE h.name LIKE '%Ha%'", "status": "Returned 2 row(s). Showing up to 50.", "output": "[\n {\n \"head_ID\": 10,\n \"name\": \"Franklin Langham\"\n },\n {\n \"head_ID\": 4,\n \"name\": \"Dudley Hart\"\n }\n]"}
|
| 17 |
+
{"db_id": "farm", "question": "How many farms are there?", "gold_sql": "SELECT count(*) FROM farm", "pred_sql": "", "status": "Error: Error code: 403 - {'error': {'message': 'token quota is not enough, token remain quota: $0.000008, need quota: $0.000114 (request id: 20251013204552914852769LyjH3RSk)', 'type': 'new_api_error', 'param': '', 'code': 'pre_consume_token_quota_failed'}}", "output": ""}
|
| 18 |
+
{"db_id": "farm", "question": "Count the number of farms.", "gold_sql": "SELECT count(*) FROM farm", "pred_sql": "", "status": "Error: Error code: 403 - {'error': {'message': 'token quota is not enough, token remain quota: $0.000008, need quota: $0.000114 (request id: 20251013204553544953937LKDeBltm)', 'type': 'new_api_error', 'param': '', 'code': 'pre_consume_token_quota_failed'}}", "output": ""}
|
| 19 |
+
{"db_id": "farm", "question": "List the total number of horses on farms in ascending order.", "gold_sql": "SELECT Total_Horses FROM farm ORDER BY Total_Horses ASC", "pred_sql": "", "status": "Error: Error code: 403 - {'error': {'message': 'token quota is not enough, token remain quota: $0.000008, need quota: $0.000114 (request id: 20251013204554156810277QpyTydIp)', 'type': 'new_api_error', 'param': '', 'code': 'pre_consume_token_quota_failed'}}", "output": ""}
|
| 20 |
+
{"db_id": "farm", "question": "What is the total horses record for each farm, sorted ascending?", "gold_sql": "SELECT Total_Horses FROM farm ORDER BY Total_Horses ASC", "pred_sql": "", "status": "Error: Error code: 403 - {'error': {'message': 'token quota is not enough, token remain quota: $0.000008, need quota: $0.000116 (request id: 202510132045548459300909IPAkckU)', 'type': 'new_api_error', 'param': '', 'code': 'pre_consume_token_quota_failed'}}", "output": ""}
|