Melika Kheirieh committed · Commit eee3f75 · Parent: a337fad

fix(types): resolve mypy errors and make pytest pass
Changed files:
- app/schemas.py                 +3 -3
- benchmarks/evaluate_spider.py  +211 -55
- benchmarks/run.py              +34 -10
app/schemas.py
CHANGED
@@ -1,5 +1,5 @@
-from pydantic import BaseModel
-from typing import List, Optional, Any, Dict
+from pydantic import BaseModel, Field
+from typing import List, Optional, Any, Dict, Mapping, Sequence
 
 
 class NL2SQLRequest(BaseModel):
@@ -21,7 +21,7 @@ class NL2SQLResponse(BaseModel):
     ambiguous: bool = False
     sql: Optional[str] = None
     rationale: Optional[str] = None
-    traces: ...
+    traces: Sequence[TraceModel | Mapping[str, Any]] = Field(default_factory=list)
 
 
 class ClarifyResponse(BaseModel):
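Note on the traces change: Sequence[TraceModel | Mapping[str, Any]] lets callers pass typed trace objects or plain dicts, and Field(default_factory=list) avoids a shared mutable default. A minimal self-contained sketch (the stand-in TraceModel below is hypothetical; the real model lives elsewhere in app/schemas.py):

    from typing import Any, Mapping, Optional, Sequence
    from pydantic import BaseModel, Field

    class TraceModel(BaseModel):  # stand-in for the project's real TraceModel
        step: str
        detail: str = ""

    class NL2SQLResponse(BaseModel):
        ambiguous: bool = False
        sql: Optional[str] = None
        rationale: Optional[str] = None
        traces: Sequence[TraceModel | Mapping[str, Any]] = Field(default_factory=list)

    # Typed and untyped traces both validate:
    r = NL2SQLResponse(sql="SELECT 1", traces=[TraceModel(step="plan"), {"step": "generate"}])
    print(len(r.traces))  # 2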
benchmarks/evaluate_spider.py
CHANGED
@@ -1,38 +1,96 @@
 from __future__ import annotations
-
+
 import json
 import subprocess
+import time
 from pathlib import Path
-from ...
+from typing import Any, Iterable, Optional, Tuple, cast
 
-from ...
+from tqdm import tqdm
 from langchain_community.utilities import SQLDatabase
-from benchmarks import load_spider_sqlite
-
 from sqlglot import parse_one, exp
 from sqlglot.errors import ParseError
+from sqlalchemy import create_engine, inspect
+from spider_loader import load_spider_sqlite
+
+
+def _try_import_pipeline():
+    """
+    Try multiple plausible entrypoints from nl2sql.
+    Returns a tuple of callables or None:
+    (make_pipeline | None, run_function | None, PipelineClass | None)
+    """
+    make_pipeline = None
+    run_fn = None
+    PipelineCls = None
+    try:
+        from nl2sql.pipeline import make_pipeline as _mk  # type: ignore
+
+        make_pipeline = _mk
+    except Exception:
+        pass
+    try:
+        from nl2sql.pipeline import run_nl2sql as _run  # type: ignore
+
+        run_fn = _run
+    except Exception:
+        pass
+    try:
+        from nl2sql.pipeline import Pipeline as _P  # type: ignore
+
+        PipelineCls = _P
+    except Exception:
+        pass
+    return make_pipeline, run_fn, PipelineCls
+
 
 LOG_DIR = Path("logs/spider_eval")
 LOG_DIR.mkdir(parents=True, exist_ok=True)
 
+FORBIDDEN_NODES: Tuple[type, ...] = (
+    exp.Insert,
+    exp.Delete,
+    exp.Update,
+    exp.Drop,
+    exp.Alter,
+    exp.Attach,
+    exp.Pragma,
+    exp.Create,
+)
+
 
 def normalize_sql(sql: str) -> str:
-    # simple version; could be made stronger with parse + rebuild
     return " ".join(sql.lower().strip().split())
 
 
-def compare_results(...):
+def compare_results(
+    pred_rows: Optional[Iterable[Any]], gold_rows: Optional[Iterable[Any]]
+) -> bool:
     if pred_rows is None or gold_rows is None:
         return False
-    # if order does not matter
     return set(pred_rows) == set(gold_rows)
 
 
-def try_execute_sql(...):
+def try_execute_sql(
+    sql_db: SQLDatabase,
+    sql: str,
+    timeout: Optional[float] = None,  # kept for API compatibility
+) -> tuple[Optional[list[tuple[Any, ...]]], float, Optional[str]]:
     start = time.time()
     try:
-        ...
+        raw_rows = sql_db.run(sql)
+
+        # Normalize result shape for mypy and downstream code
+        if isinstance(raw_rows, list):
+            rows = [tuple(r) for r in raw_rows]
+        elif isinstance(raw_rows, tuple):
+            rows = [tuple(raw_rows)]
+        else:
+            # Fallback cast — if library returns ResultSet or something similar
+            rows = cast(list[tuple[Any, ...]], raw_rows)
+
         return rows, time.time() - start, None
+
     except Exception as e:
         return None, time.time() - start, str(e)
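The shape normalization inside try_execute_sql is what satisfies mypy here: whatever the driver hands back, downstream code sees list[tuple]. The same idea as a standalone sketch (no langchain required; normalize_rows is a hypothetical helper, not part of the commit):

    from typing import Any, cast

    def normalize_rows(raw_rows: Any) -> list[tuple[Any, ...]]:
        # A list of rows becomes a list of tuples; a single row becomes a one-element list
        if isinstance(raw_rows, list):
            return [tuple(r) for r in raw_rows]
        if isinstance(raw_rows, tuple):
            return [tuple(raw_rows)]
        # Otherwise trust the annotation (e.g. a ResultSet-like object)
        return cast(list[tuple[Any, ...]], raw_rows)

    print(normalize_rows([(1, "a"), (2, "b")]))  # [(1, 'a'), (2, 'b')]
    print(normalize_rows((1, "a")))              # [(1, 'a')]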
|
@@ -44,7 +102,7 @@ def exact_match_structural(sql_pred: str, sql_gold: str) -> bool:
     except Exception:
         return False
 
-def normalize_ast(node: exp.Expression):
+def normalize_ast(node: exp.Expression) -> exp.Expression:
     for name, arg in node.args.items():
         if isinstance(arg, list):
             arg.sort(key=lambda x: str(x))
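normalize_ast sorts every list-valued arg of a node so order-insensitive clauses compare equal; exact_match_structural (its body falls mostly outside this hunk) evidently parses both queries and compares the normalized trees. A sketch under that assumption:

    from sqlglot import parse_one, exp

    def normalize_ast(node: exp.Expression) -> exp.Expression:
        # Sort list args (e.g. the select list) for order-insensitive comparison
        for name, arg in node.args.items():
            if isinstance(arg, list):
                arg.sort(key=lambda x: str(x))
        return node

    def exact_match_structural_sketch(a: str, b: str) -> bool:
        try:
            return normalize_ast(parse_one(a)) == normalize_ast(parse_one(b))
        except Exception:
            return False

    print(exact_match_structural_sketch("SELECT b, a FROM t", "SELECT a, b FROM t"))  # True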
|
@@ -73,19 +131,7 @@ def get_git_commit_hash() -> str:
     return "UNKNOWN"
 
 
-... = (
-    exp.Insert,
-    exp.Delete,
-    exp.Update,
-    exp.Drop,
-    exp.Alter,
-    exp.Attach,
-    exp.Pragma,
-    exp.Create,
-)
-
-
-def is_safe_sql(sql: str, dialect: str | None = None) -> bool:
+def is_safe_sql(sql: str, dialect: Optional[str] = None) -> bool:
     try:
         ast = parse_one(sql, read=dialect)
     except ParseError:
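is_safe_sql allow-lists read-only statements: parse failures are rejected, and the parsed tree is checked against FORBIDDEN_NODES before return True. The lines between the parse and return True are elided in this view, so the AST scan below is an assumption about how the guard works, using sqlglot's find():

    from sqlglot import parse_one, exp
    from sqlglot.errors import ParseError

    FORBIDDEN_NODES = (exp.Insert, exp.Delete, exp.Update, exp.Drop,
                       exp.Alter, exp.Attach, exp.Pragma, exp.Create)

    def is_safe_sql_sketch(sql: str) -> bool:
        try:
            ast = parse_one(sql)
        except ParseError:
            return False
        # Reject if any forbidden node type appears anywhere in the tree
        return ast.find(*FORBIDDEN_NODES) is None

    print(is_safe_sql_sketch("SELECT name FROM users WHERE id = 1"))  # True
    print(is_safe_sql_sketch("DROP TABLE users"))                     # False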
@@ -98,7 +144,104 @@ def is_safe_sql(sql: str, dialect: str | None = None) -> bool:
     return True
 
 
-def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
+# --- replacement for get_schema_preview from app.routers ---
+def get_schema_preview_sqlalchemy(db_path: str, max_cols: int = 0) -> str:
+    """
+    Lightweight schema preview using SQLAlchemy inspector.
+    max_cols=0 => unlimited
+    """
+    engine = create_engine(f"sqlite:///{db_path}")
+    insp = inspect(engine)
+    lines: list[str] = []
+    for tbl in sorted(insp.get_table_names()):
+        cols = insp.get_columns(tbl)
+        if max_cols > 0:
+            cols = cols[:max_cols]
+        col_str = ", ".join(f"{c['name']}:{c.get('type')}" for c in cols)
+        pks = insp.get_pk_constraint(tbl).get("constrained_columns") or []
+        pk_str = f" | PK: {', '.join(pks)}" if pks else ""
+        fks = insp.get_foreign_keys(tbl)
+        fk_str = ""
+        if fks:
+            fks_desc = []
+            for fk in fks:
+                ref = fk.get("referred_table")
+                cols_fk = ", ".join(fk.get("constrained_columns") or [])
+                ref_cols = ", ".join(fk.get("referred_columns") or [])
+                fks_desc.append(f"{cols_fk} -> {ref}({ref_cols})")
+            fk_str = " | FK: " + " ; ".join(fks_desc)
+        lines.append(f"{tbl}({col_str}){pk_str}{fk_str}")
+    engine.dispose()
+    return "\n".join(lines)
+
+
+def _generate_sql(
+    question: str, sql_db: SQLDatabase, schema_text: str, max_output_tokens: int = 1000
+) -> tuple[str, str, dict[str, Any]]:
+    """
+    Returns: (status_msg, sql_text, extra_output)
+    Strategy:
+    1) If nl2sql.pipeline.run_nl2sql exists: call it.
+    2) Else if nl2sql.pipeline.make_pipeline exists: build and run.
+    3) Else if nl2sql.pipeline.Pipeline exists: instantiate minimal pipeline and run.
+    4) Else: raise NotImplementedError.
+    """
+    make_pipeline, run_fn, PipelineCls = _try_import_pipeline()
+
+    # Case 1: direct run function
+    if run_fn is not None:
+        res = run_fn(
+            question=question,
+            schema_text=schema_text,
+            sql_db=sql_db,
+            max_output_tokens=max_output_tokens,
+        )
+        # Expecting a dict-like or object with attributes; normalize:
+        if isinstance(res, dict):
+            msg = res.get("status", "ok")
+            sql = res.get("sql", "")
+            return msg, sql, res
+        # fallback generic
+        msg = getattr(res, "status", "ok")
+        sql = getattr(res, "sql", "")
+        return msg, sql, {"result": res}
+
+    # Case 2: factory + run
+    if make_pipeline is not None:
+        pipe = make_pipeline(sql_db=sql_db, schema_text=schema_text)  # type: ignore[arg-type]
+        # Common conventions:
+        if hasattr(pipe, "run"):
+            out = pipe.run(question)  # type: ignore[call-arg]
+        elif hasattr(pipe, "execute"):
+            out = pipe.execute(question)  # type: ignore[call-arg]
+        else:
+            raise RuntimeError("Pipeline object has no run/execute()")
+        msg = getattr(out, "status", "ok")
+        sql = getattr(out, "sql", "")
+        return msg, sql, {"result": out}
+
+    # Case 3: class-based pipeline
+    if PipelineCls is not None:
+        # Try minimal constructor names; adjust to your class signature if needed.
+        # We pass what we have; extra kwargs should be ignored or have defaults.
+        pipe = PipelineCls(sql_db=sql_db, schema_text=schema_text)
+        if hasattr(pipe, "run"):
+            out = pipe.run(question)  # type: ignore[call-arg]
+        else:
+            raise RuntimeError("Pipeline class has no run()")
+        msg = getattr(out, "status", "ok")
+        sql = getattr(out, "sql", "")
+        return msg, sql, {"result": out}
+
+    raise NotImplementedError(
+        "Cannot locate a public NL2SQL entrypoint in nl2sql.pipeline. "
+        "Expose one of: run_nl2sql(), make_pipeline(), or Pipeline.run()."
+    )
+
+
+def run_eval(
+    split: str = "dev", limit: int = 100, resume: bool = True, sleep_time: float = 0.01
+) -> None:
     data = load_spider_sqlite(split)
     if len(data) < limit:
         limit = len(data)
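What get_schema_preview_sqlalchemy emits, demonstrated on a throwaway SQLite file (the tables here are invented for the demo; exact type strings depend on the SQLAlchemy dialect):

    import os
    import sqlite3
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), "demo.db")
    con = sqlite3.connect(path)
    con.executescript("""
        CREATE TABLE artist (id INTEGER PRIMARY KEY, name TEXT);
        CREATE TABLE album (id INTEGER PRIMARY KEY, title TEXT,
                            artist_id INTEGER REFERENCES artist(id));
    """)
    con.close()

    print(get_schema_preview_sqlalchemy(path, max_cols=0))
    # album(id:INTEGER, title:TEXT, artist_id:INTEGER) | PK: id | FK: artist_id -> artist(id)
    # artist(id:INTEGER, name:TEXT) | PK: id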
@@ -113,7 +256,7 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
     results_fn = LOG_DIR / f"{split}_results_{start_ts}.jsonl"
     metrics_fn = LOG_DIR / f"{split}_metrics_{start_ts}.json"
 
-    done = set()
+    done: set[tuple[str, str]] = set()
     if resume and results_fn.exists():
         with results_fn.open("r", encoding="utf-8") as f:
             for line in f:
@@ -126,6 +269,8 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
                     pass
 
     write_header = not results_fn.exists()
+    agg: list[dict[str, Any]] = []
+
     with (
         results_fn.open("a", encoding="utf-8") as fout,
         pred_txt.open("a", encoding="utf-8") as fpred,
@@ -141,25 +286,48 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
             fout.write("# " + json.dumps(header, ensure_ascii=False) + "\n")
             fout.flush()
 
-        agg = []
         for ex in tqdm(data):
             key = (ex.db_id, ex.question)
             if resume and key in done:
                 continue
 
             db_path = str(ex.db_path)
-            schema = ...
+            schema = get_schema_preview_sqlalchemy(db_path, max_cols=0)
             sql_db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
-            chain = make_sql_chain(sql_db)
-            state = {
-                "db_path": db_path,
-                "sql_db": sql_db,
-                "schema_text": schema,
-                "chain": chain,
-            }
 
             t0 = time.time()
-            ...
+            try:
+                msg, sql, output = _generate_sql(
+                    ex.question, sql_db, schema, max_output_tokens=1000
+                )
+            except NotImplementedError as e:
+                rec = {
+                    "db_id": ex.db_id,
+                    "question": ex.question,
+                    "gold_sql": ex.gold_sql,
+                    "pred_sql": "",
+                    "status": "no_entrypoint",
+                    "output": {"error": str(e)},
+                    "gen_time": time.time() - t0,
+                    "exec_time": None,
+                    "error": "no_entrypoint",
+                    "gold_error": None,
+                    "pred_rows": None,
+                    "gold_rows": None,
+                    "exact_match": False,
+                    "exact_match_structural": False,
+                    "execution_accuracy": False,
+                    "safe_check_failed": True,
+                }
+                fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
+                fout.flush()
+                fgold.write(f"{ex.gold_sql}\t{ex.db_id}\n")
+                fgold.flush()
+                agg.append(rec)
+                if sleep_time > 0:
+                    time.sleep(sleep_time)
+                continue
+
             gen_time = time.time() - t0
 
             safe_flag = is_safe_sql(sql)
@@ -197,21 +365,9 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
             gold_rows, gold_time, gold_error = try_execute_sql(sql_db, ex.gold_sql)
 
             skip = gold_error is not None
-
-            em = False
-            if not skip:
-                try:
-                    em = normalize_sql(sql) == normalize_sql(ex.gold_sql)
-                except Exception:
-                    pass
-
-            em_struct = False
-            if not skip:
-                em_struct = exact_match_structural(sql, ex.gold_sql)
-
-            exec_acc = False
-            if not skip:
-                exec_acc = compare_results(pred_rows, gold_rows)
+            em = normalize_sql(sql) == normalize_sql(ex.gold_sql) if not skip else False
+            em_struct = exact_match_structural(sql, ex.gold_sql) if not skip else False
+            exec_acc = compare_results(pred_rows, gold_rows) if not skip else False
 
             rec = {
                 "db_id": ex.db_id,
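The three conditional expressions keep em, em_struct, and exec_acc always-bound plain bools that mypy can verify, instead of the old multi-branch blocks. Their semantics side by side, using only functions defined in this file:

    pred = "SELECT name FROM users  ORDER BY id"
    gold = "select name from users order by id"
    em = normalize_sql(pred) == normalize_sql(gold)   # whitespace/case-insensitive -> True

    pred_rows = [(2, "b"), (1, "a")]
    gold_rows = [(1, "a"), (2, "b")]
    exec_acc = compare_results(pred_rows, gold_rows)  # order-insensitive set compare -> True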
@@ -231,7 +387,6 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
                 "execution_accuracy": exec_acc,
                 "safe_check_failed": False,
             }
-
             fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
             fout.flush()
             fpred.write(f"{sql}\t{ex.db_id}\n")
@@ -246,7 +401,7 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
     valid = [
         r
         for r in agg
-        if (not r.get("safe_check_failed", False)) and r.get("gold_error") is None
+        if (not r.get("safe_check_failed", False)) and (r.get("gold_error") is None)
     ]
     total_valid = len(valid)
     total_all = len(agg)
@@ -263,8 +418,8 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
         if (r.get("error") is not None) and (not r.get("safe_check_failed", False))
     )
     safe_fail_count = sum(1 for r in agg if r.get("safe_check_failed", False))
-    avg_gen_time = sum(r["gen_time"] for r in valid) / total_valid
-    avg_exec_time = sum(r["exec_time"] for r in valid) / total_valid
+    avg_gen_time = sum(float(r["gen_time"]) for r in valid) / total_valid
+    avg_exec_time = sum(float(r["exec_time"]) for r in valid) / total_valid
 
     metrics = {
         "commit_hash": commit_hash,
@@ -282,6 +437,7 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
         "run_id": start_ts,
     }
 
+    metrics_fn = LOG_DIR / f"{split}_metrics_{start_ts}.json"
    with metrics_fn.open("w", encoding="utf-8") as fm:
         json.dump(metrics, fm, ensure_ascii=False, indent=2)
 
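A hypothetical invocation of the rewritten evaluator (assuming the Spider SQLite databases have already been fetched by spider_loader; the flag values are illustrative):

    from benchmarks.evaluate_spider import run_eval

    # Evaluate 50 dev examples; resume=True skips (db_id, question) pairs
    # already recorded in this run's results JSONL under logs/spider_eval/.
    run_eval(split="dev", limit=50, resume=True, sleep_time=0.01)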
benchmarks/run.py
CHANGED
@@ -1,11 +1,11 @@
-# benchmarks/run.py
 from __future__ import annotations
+
 import argparse
 import os
 import json
 import time
 from pathlib import Path
-from typing import Iterable, List, Dict, Any
+from typing import Iterable, List, Dict, Any, Protocol, Tuple, Optional
 
 # ---- app imports
 from nl2sql.pipeline import Pipeline, FinalResult
@@ -22,11 +22,34 @@ from adapters.db.sqlite_adapter import SQLiteAdapter
 from adapters.llm.openai_provider import OpenAIProvider
 
 
-# ---- ...
+# ---- LLM protocol (unifies OpenAIProvider and DummyLLM for mypy)
+class LLMProvider(Protocol):
+    """Minimal interface required by Planner/Generator/Repair stages."""
+
+    provider_id: str
+
+    def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]:
+        ...
+
+    def generate_sql(
+        self,
+        *,
+        user_query: str,
+        schema_preview: str,
+        plan_text: str,
+        clarify_answers: Optional[Any] = None,
+    ) -> Tuple[str, str, int, int, float]:
+        ...
+
+    def repair(self, *, sql: str, error_msg: str, schema_preview: str) -> Tuple[str, int, int, float]:
+        ...
+
+
+# ---- fallback: Dummy LLM (so it runs without API keys)
 class DummyLLM:
     provider_id = "dummy-llm"
 
-    def plan(self, *, user_query: str, schema_preview: str):
+    def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]:
         text = (
             f"- understand question: {user_query}\n"
             "- identify tables\n- join if needed\n- filter\n- order/limit"
@@ -39,14 +62,14 @@ class DummyLLM:
         user_query: str,
         schema_preview: str,
         plan_text: str,
-        clarify_answers=None,
-    ):
+        clarify_answers: Optional[Any] = None,
+    ) -> Tuple[str, str, int, int, float]:
         # naive demo SQL (so pipeline flows end-to-end)
         sql = "SELECT 1 AS one;"
         rationale = "Demo SQL from DummyLLM"
         return sql, rationale, 0, 0, 0.0
 
-    def repair(self, *, sql: str, error_msg: str, schema_preview: str):
+    def repair(self, *, sql: str, error_msg: str, schema_preview: str) -> Tuple[str, int, int, float]:
         return sql, 0, 0, 0.0
 
 
@@ -73,11 +96,12 @@ def build_pipeline(db_path: Path, use_openai: bool) -> Pipeline:
     db = SQLiteAdapter(str(db_path))
     executor = Executor(db)
 
-    # LLM provider
+    # LLM provider (typed to the Protocol so mypy accepts either provider)
+    llm: LLMProvider
     if use_openai and os.getenv("OPENAI_API_KEY"):
-        llm = OpenAIProvider()
+        llm = OpenAIProvider()  # conforms to LLMProvider
     else:
-        llm = DummyLLM()
+        llm = DummyLLM()  # conforms to LLMProvider
 
     # stages
     detector = AmbiguityDetector()
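Why the Protocol fixes the mypy failure: without the llm: LLMProvider declaration, mypy infers llm's type from the first branch (OpenAIProvider) and flags the DummyLLM assignment as incompatible. A structural Protocol accepts any class whose method signatures match, with no shared base class. A minimal self-contained reproduction of the pattern (Greeter and the two classes are invented for illustration):

    import random
    from typing import Protocol

    class Greeter(Protocol):
        def greet(self, *, name: str) -> str: ...

    class English:
        def greet(self, *, name: str) -> str:
            return f"hello {name}"

    class French:
        def greet(self, *, name: str) -> str:
            return f"bonjour {name}"

    g: Greeter  # declare once with the protocol type...
    g = English() if random.random() < 0.5 else French()  # ...and both branches type-check
    print(g.greet(name="mypy"))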