nl2sql-copilot / benchmarks /evaluate_spider.py
Melika Kheirieh
fix(types): resolve mypy errors and make pytest pass
eee3f75
raw
history blame
15.5 kB
from __future__ import annotations
import json
import subprocess
import time
from pathlib import Path
from typing import Any, Iterable, Optional, Tuple, cast
from tqdm import tqdm
from langchain_community.utilities import SQLDatabase
from sqlglot import parse_one, exp
from sqlglot.errors import ParseError
from sqlalchemy import create_engine, inspect
from spider_loader import load_spider_sqlite
def _try_import_pipeline():
"""
Try multiple plausible entrypoints from nl2sql.
Returns a tuple of callables or None:
(make_pipeline | None, run_function | None, PipelineClass | None)
"""
make_pipeline = None
run_fn = None
PipelineCls = None
try:
from nl2sql.pipeline import make_pipeline as _mk # type: ignore
make_pipeline = _mk
except Exception:
pass
try:
from nl2sql.pipeline import run_nl2sql as _run # type: ignore
run_fn = _run
except Exception:
pass
try:
from nl2sql.pipeline import Pipeline as _P # type: ignore
PipelineCls = _P
except Exception:
pass
return make_pipeline, run_fn, PipelineCls
# All artifacts of an evaluation run (per-example JSONL, pred/gold text files,
# metrics JSON) go under this directory, resolved relative to the current
# working directory. It is created eagerly when the module is imported.
LOG_DIR = Path("logs") / "spider_eval"
LOG_DIR.mkdir(exist_ok=True, parents=True)

# sqlglot node types that mark a statement as modifying data, the schema or
# the connection; ``is_safe_sql`` rejects any parse tree containing one.
FORBIDDEN_NODES: Tuple[type, ...] = (
    exp.Insert, exp.Delete, exp.Update,  # data modification
    exp.Drop, exp.Alter,                 # schema changes
    exp.Attach, exp.Pragma,              # ATTACH / PRAGMA statements
    exp.Create,                          # object creation
)
def normalize_sql(sql: str) -> str:
    """Lower-case *sql* and collapse every whitespace run into one space.

    Used for the textual exact-match metric. String literals are lower-cased
    as well, so this is only suitable for comparison, never for execution.
    """
    tokens = sql.lower().split()  # split() also drops leading/trailing whitespace
    return " ".join(tokens)
def compare_results(
    pred_rows: Optional[Iterable[Any]], gold_rows: Optional[Iterable[Any]]
) -> bool:
    """Return True when both result sets exist and contain the same distinct rows.

    The comparison is set-based: row order and duplicate rows are ignored.
    A missing result (``None``, e.g. because the query failed to execute)
    never matches anything.
    """
    if pred_rows is not None and gold_rows is not None:
        return set(pred_rows) == set(gold_rows)
    return False
def _rows_from_run_output(raw: Any) -> list[tuple[Any, ...]]:
    """
    Convert the value returned by ``SQLDatabase.run`` into a list of row tuples.

    langchain's ``SQLDatabase.run`` (default ``fetch="all"``) returns the rows
    as the *string* ``str(list_of_tuples)``, or ``""`` when there are none,
    rather than a list. That text is parsed back into Python values here so
    that results can be compared row by row. Other shapes (a list of rows, a
    single tuple, dict rows, a cursor-like object with ``fetchall``) are
    normalised as well.

    NOTE(review): langchain may truncate long string values before
    stringifying them; confirm whether exact long-text comparison matters.
    """
    from ast import literal_eval

    if raw is None:
        return []
    if isinstance(raw, str):
        if not raw.strip():
            return []
        try:
            raw = literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Non-literal reprs (e.g. float nan/inf): compare the whole text
            # as one cell instead of failing the query.
            return [(raw,)]
    if isinstance(raw, tuple):
        return [raw]  # a single row
    if isinstance(raw, dict):
        return [tuple(raw.values())]  # a single row with column names
    if hasattr(raw, "fetchall"):
        raw = raw.fetchall()
    if isinstance(raw, (str, bytes)) or not hasattr(raw, "__iter__"):
        return [(raw,)]  # a scalar result
    rows: list[tuple[Any, ...]] = []
    for row in raw:
        if isinstance(row, dict):
            rows.append(tuple(row.values()))
        elif isinstance(row, (str, bytes)) or not hasattr(row, "__iter__"):
            rows.append((row,))
        else:
            rows.append(tuple(row))
    return rows


def try_execute_sql(
    sql_db: SQLDatabase,
    sql: str,
    timeout: Optional[float] = None,  # accepted but ignored; kept for API compatibility
) -> tuple[Optional[list[tuple[Any, ...]]], float, Optional[str]]:
    """
    Execute *sql* through *sql_db* and report ``(rows, elapsed_seconds, error)``.

    On success ``rows`` is a list of hashable row tuples (see
    ``_rows_from_run_output``) and ``error`` is ``None``. On any exception
    ``rows`` is ``None`` and ``error`` is the exception text. The elapsed time
    is measured with a monotonic clock in both cases.
    """
    start = time.perf_counter()
    try:
        rows = _rows_from_run_output(sql_db.run(sql))
    except Exception as e:
        return None, time.perf_counter() - start, str(e)
    return rows, time.perf_counter() - start, None
def exact_match_structural(sql_pred: str, sql_gold: str) -> bool:
    """
    Compare two queries by their normalised sqlglot parse trees.

    Normalisation is applied bottom-up to both trees:

    * projection aliases (``expr AS name``) are replaced by the aliased
      expression at every depth, not only at the root;
    * list-valued children (SELECT columns, GROUP BY keys, IN lists, joins,
      ...) are sorted by their already-normalised SQL text, so their order
      does not matter. The exception is lists whose order carries meaning:
      ORDER BY keys and function arguments.

    Returns False if either query cannot be parsed.
    """
    try:
        ast_pred = parse_one(sql_pred)
        ast_gold = parse_one(sql_gold)
    except Exception:
        return False

    # Lists under these nodes are positional: ``ORDER BY a, b`` differs from
    # ``ORDER BY b, a`` and ``f(a, b)`` differs from ``f(b, a)``.
    order_sensitive = (exp.Order, exp.Func)

    def normalize_ast(node: exp.Expression) -> exp.Expression:
        # Iterate over a snapshot because ``node.set`` rewrites ``node.args``.
        for name, arg in list(node.args.items()):
            if isinstance(arg, list):
                # Normalise children first so the sort keys reflect their
                # normalised form, then store the (possibly replaced) nodes.
                children = [
                    normalize_ast(child) if isinstance(child, exp.Expression) else child
                    for child in arg
                ]
                if not isinstance(node, order_sensitive):
                    children.sort(key=str)
                node.set(name, children)
            elif isinstance(arg, exp.Expression):
                # Re-attach the returned node: an Alias child is replaced by
                # its inner expression.
                node.set(name, normalize_ast(arg))
        if isinstance(node, exp.Alias):
            return node.this
        return node

    return normalize_ast(ast_pred) == normalize_ast(ast_gold)
def get_git_commit_hash() -> str:
    """Return the SHA of the current ``HEAD`` commit, or ``"UNKNOWN"``.

    Runs ``git rev-parse HEAD`` in the current working directory. Any failure
    (git missing, not a repository, non-zero exit, undecodable output) yields
    ``"UNKNOWN"`` instead of raising.
    """
    try:
        raw = subprocess.check_output(["git", "rev-parse", "HEAD"])
        commit = raw.strip().decode("ascii")
    except Exception:
        commit = "UNKNOWN"
    return commit
def is_safe_sql(sql: str, dialect: Optional[str] = None) -> bool:
    """
    Return True only if *sql* is exactly one read-only query.

    The check accepts a statement only if all of these hold:

    * it parses with sqlglot (in *dialect*, if given) to a single ``SELECT``,
      or to a ``UNION`` / ``INTERSECT`` / ``EXCEPT`` of selects;
    * no node anywhere in its tree is one of ``FORBIDDEN_NODES``.

    The check fails closed. Non-string or blank input, text sqlglot cannot
    tokenize or parse, and scripts with more than one statement are all
    rejected.
    """
    from sqlglot import parse
    from sqlglot.errors import TokenError

    if not isinstance(sql, str) or not sql.strip():
        return False
    try:
        statements = [s for s in parse(sql, read=dialect) if s is not None]
    except (ParseError, TokenError):
        # TokenError (e.g. an unterminated string literal) is not a
        # ParseError subclass; it used to escape and abort the whole run.
        return False
    # ``parse_one`` may keep only the first of several statements; require
    # exactly one so ``SELECT ...; DROP ...`` is not waved through.
    if len(statements) != 1:
        return False
    tree = statements[0]
    # Set operations are read-only too, but sqlglot represents them as
    # Union/Intersect/Except nodes rather than Select.
    if not isinstance(tree, (exp.Select, exp.Union, exp.Intersect, exp.Except)):
        return False
    return not any(isinstance(node, FORBIDDEN_NODES) for node in tree.walk())
# --- Replacement for get_schema_preview from app.routers ---
def get_schema_preview_sqlalchemy(db_path: str, max_cols: int = 0) -> str:
    """
    Build a compact schema description of a SQLite file, one line per table.

    Each line has the form
    ``table(col:TYPE, ...) | PK: a, b | FK: x -> other(y) ; ...``.
    Tables are sorted by name, and the PK / FK parts are omitted when empty.

    Args:
        db_path: Filesystem path of the SQLite database.
        max_cols: Keep only the first ``max_cols`` columns of each table;
            0 (or any non-positive value) means unlimited.

    Returns:
        The table lines joined by newlines (``""`` if there are no tables).

    NOTE(review): sqlite normally creates a missing file on connect, so a wrong
    path yields an empty preview rather than an error. Confirm whether a guard
    is wanted.
    """
    engine = create_engine(f"sqlite:///{db_path}")
    lines: list[str] = []
    try:
        insp = inspect(engine)
        for tbl in sorted(insp.get_table_names()):
            cols = insp.get_columns(tbl)
            if max_cols > 0:
                cols = cols[:max_cols]
            col_str = ", ".join(f"{c['name']}:{c.get('type')}" for c in cols)

            pks = insp.get_pk_constraint(tbl).get("constrained_columns") or []
            pk_str = f" | PK: {', '.join(pks)}" if pks else ""

            fks_desc = []
            for fk in insp.get_foreign_keys(tbl):
                ref = fk.get("referred_table")
                cols_fk = ", ".join(fk.get("constrained_columns") or [])
                ref_cols = ", ".join(fk.get("referred_columns") or [])
                fks_desc.append(f"{cols_fk} -> {ref}({ref_cols})")
            fk_str = (" | FK: " + " ; ".join(fks_desc)) if fks_desc else ""

            lines.append(f"{tbl}({col_str}){pk_str}{fk_str}")
    finally:
        # Release pooled connections even when inspection fails midway.
        engine.dispose()
    return "\n".join(lines)
def _status_and_sql(out: Any) -> tuple[Any, str]:
    """Read ``status`` and ``sql`` from a pipeline result that is a dict or an object."""
    if isinstance(out, dict):
        status = out.get("status", "ok")
        sql = out.get("sql", "")
    else:
        status = getattr(out, "status", "ok")
        sql = getattr(out, "sql", "")
    # Downstream code parses, normalises and writes the SQL as text, so a
    # missing value (None) becomes "" and anything else is stringified.
    return status, "" if sql is None else str(sql)


def _run_pipeline_object(pipe: Any, question: str, what: str) -> Any:
    """Call ``pipe.run(question)``, falling back to ``pipe.execute(question)``."""
    if hasattr(pipe, "run"):
        return pipe.run(question)  # type: ignore[call-arg]
    if hasattr(pipe, "execute"):
        return pipe.execute(question)  # type: ignore[call-arg]
    raise RuntimeError(f"{what} has no run/execute()")


def _generate_sql(
    question: str, sql_db: SQLDatabase, schema_text: str, max_output_tokens: int = 1000
) -> tuple[str, str, dict[str, Any]]:
    """
    Ask the NL2SQL pipeline for a SQL query answering *question*.

    Returns:
        ``(status_msg, sql_text, extra_output)``. ``sql_text`` is always a
        string (``""`` if the pipeline produced none). ``extra_output`` is
        the raw result when ``run_nl2sql`` returned a dict, and otherwise
        ``{"result": <raw result>}``.

    Strategy (the first available entrypoint in ``nl2sql.pipeline`` wins):
        1) ``run_nl2sql(...)``: the only entrypoint that also receives
           ``max_output_tokens``.
        2) ``make_pipeline(sql_db=..., schema_text=...)`` and then
           ``.run(question)`` or ``.execute(question)``.
        3) ``Pipeline(sql_db=..., schema_text=...)`` and then
           ``.run(question)`` or ``.execute(question)``.
        4) Otherwise raise ``NotImplementedError``.

    Results may be dicts or objects with ``status`` / ``sql`` attributes. Both
    shapes are understood by every strategy.

    Raises:
        NotImplementedError: no entrypoint could be imported.
        RuntimeError: the pipeline object has neither ``run`` nor ``execute``.
    """
    make_pipeline, run_fn, PipelineCls = _try_import_pipeline()

    # Case 1: direct run function.
    if run_fn is not None:
        res = run_fn(
            question=question,
            schema_text=schema_text,
            sql_db=sql_db,
            max_output_tokens=max_output_tokens,
        )
        msg, sql = _status_and_sql(res)
        return msg, sql, res if isinstance(res, dict) else {"result": res}

    # Case 2: factory + run.
    if make_pipeline is not None:
        pipe = make_pipeline(sql_db=sql_db, schema_text=schema_text)  # type: ignore[arg-type]
        out = _run_pipeline_object(pipe, question, "Pipeline object")
        msg, sql = _status_and_sql(out)
        return msg, sql, {"result": out}

    # Case 3: class-based pipeline. Extra constructor kwargs are expected to
    # be optional; adjust here if the class signature differs.
    if PipelineCls is not None:
        pipe = PipelineCls(sql_db=sql_db, schema_text=schema_text)
        out = _run_pipeline_object(pipe, question, "Pipeline class")
        msg, sql = _status_and_sql(out)
        return msg, sql, {"result": out}

    raise NotImplementedError(
        "Cannot locate a public NL2SQL entrypoint in nl2sql.pipeline. "
        "Expose one of: run_nl2sql(), make_pipeline(), or Pipeline.run()."
    )
def _latest_run_ts(split: str) -> Optional[int]:
    """Return the timestamp of the newest ``{split}_results_<ts>.jsonl`` in LOG_DIR, if any."""
    prefix = f"{split}_results_"
    stamps = [
        int(path.stem[len(prefix):])
        for path in LOG_DIR.glob(f"{prefix}*.jsonl")
        if path.stem[len(prefix):].isdigit()
    ]
    return max(stamps, default=None)


def _load_previous_records(results_fn: Path) -> dict[tuple[Any, Any], dict[str, Any]]:
    """Return the per-example records already stored in *results_fn*, keyed by (db_id, question)."""
    records: dict[tuple[Any, Any], dict[str, Any]] = {}
    if not results_fn.exists():
        return records
    with results_fn.open("r", encoding="utf-8") as f:
        for line in f:
            if line.startswith("#") or not line.strip():
                continue  # run header or blank line
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                # e.g. a line truncated by an interrupted run; that example is
                # simply evaluated again.
                continue
            if isinstance(rec, dict):
                records[(rec.get("db_id"), rec.get("question"))] = rec
    return records


def _one_line(text: Any) -> str:
    """Render *text* without newlines or tabs, so each example stays one tab-separated line."""
    s = "" if text is None else str(text)
    return s.replace("\r", " ").replace("\n", " ").replace("\t", " ")


def _make_record(
    ex: Any,
    *,
    pred_sql: str,
    status: Any,
    output: Any,
    gen_time: float,
    safe_check_failed: bool,
    exec_time: Optional[float] = None,
    error: Optional[str] = None,
    gold_error: Optional[str] = None,
    pred_rows: Any = None,
    gold_rows: Any = None,
    em: bool = False,
    em_struct: bool = False,
    exec_acc: bool = False,
) -> dict[str, Any]:
    """Build the per-example JSON record (field order is the on-disk order)."""
    return {
        "db_id": ex.db_id,
        "question": ex.question,
        "gold_sql": ex.gold_sql,
        "pred_sql": pred_sql,
        "status": status,
        "output": output,
        "gen_time": gen_time,
        "exec_time": exec_time,
        "error": error,
        "gold_error": gold_error,
        "pred_rows": pred_rows,
        "gold_rows": gold_rows,
        "exact_match": em,
        "exact_match_structural": em_struct,
        "execution_accuracy": exec_acc,
        "safe_check_failed": safe_check_failed,
    }


def _write_record(rec: dict[str, Any], fout: Any, fpred: Any, fgold: Any) -> None:
    """Append *rec* to the JSONL file and one line each to the pred and gold files."""
    # default=str: pipeline outputs and some row values (e.g. bytes) are not
    # JSON-native; log their text form instead of aborting the run.
    fout.write(json.dumps(rec, ensure_ascii=False, default=str) + "\n")
    # The pred and gold files are written in lockstep (always both), so line
    # i of each refers to the same example.
    fpred.write(f"{_one_line(rec['pred_sql'])}\t{rec['db_id']}\n")
    fgold.write(f"{_one_line(rec['gold_sql'])}\t{rec['db_id']}\n")
    for handle in (fout, fpred, fgold):
        handle.flush()


def _evaluate_example(ex: Any, db_cache: dict[str, tuple[str, SQLDatabase]]) -> dict[str, Any]:
    """Generate SQL for one example, safety-check it, execute it and score it against gold."""
    db_path = str(ex.db_path)
    if db_path not in db_cache:
        # The schema text and the SQLDatabase wrapper depend only on the
        # database file, so build them once per database, not per question.
        db_cache[db_path] = (
            get_schema_preview_sqlalchemy(db_path, max_cols=0),
            SQLDatabase.from_uri(f"sqlite:///{db_path}"),
        )
    schema, sql_db = db_cache[db_path]

    t0 = time.perf_counter()
    try:
        msg, sql, output = _generate_sql(ex.question, sql_db, schema, max_output_tokens=1000)
    except NotImplementedError as e:
        return _make_record(
            ex,
            pred_sql="",
            status="no_entrypoint",
            output={"error": str(e)},
            gen_time=time.perf_counter() - t0,
            error="no_entrypoint",
            safe_check_failed=True,
        )
    gen_time = time.perf_counter() - t0

    if not is_safe_sql(sql):
        return _make_record(
            ex,
            pred_sql=sql,
            status="rejected_safe_check",
            output=output,
            gen_time=gen_time,
            error="unsafe_sql",
            safe_check_failed=True,
        )

    pred_rows, exec_time, error = try_execute_sql(sql_db, sql)
    gold_rows, _gold_time, gold_error = try_execute_sql(sql_db, ex.gold_sql)
    # If the gold query itself fails, the example cannot be scored: every
    # metric stays False and it is excluded from the rates.
    scorable = gold_error is None
    return _make_record(
        ex,
        pred_sql=sql,
        status=msg,
        output=output,
        gen_time=gen_time,
        exec_time=exec_time,
        error=error,
        gold_error=gold_error,
        pred_rows=pred_rows,
        gold_rows=gold_rows,
        em=scorable and normalize_sql(sql) == normalize_sql(ex.gold_sql),
        em_struct=scorable and exact_match_structural(sql, ex.gold_sql),
        exec_acc=scorable and compare_results(pred_rows, gold_rows),
        safe_check_failed=False,
    )


def _summarize_metrics(
    agg: list[dict[str, Any]], *, commit_hash: str, split: str, limit: int, run_ts: int
) -> Optional[dict[str, Any]]:
    """Aggregate per-example records into run metrics; None if nothing is scorable."""
    # Scorable = passed the safety check and the gold SQL executed.
    valid = [
        r
        for r in agg
        if (not r.get("safe_check_failed", False)) and (r.get("gold_error") is None)
    ]
    total_valid = len(valid)
    total_all = len(agg)
    if total_valid == 0:
        return None

    em_count = sum(1 for r in valid if r["exact_match"])
    em_struct_count = sum(1 for r in valid if r["exact_match_structural"])
    exec_acc_count = sum(1 for r in valid if r["execution_accuracy"])
    # Counted over the same population as the denominator, so the rate is <= 1.
    error_count = sum(1 for r in valid if r.get("error") is not None)
    safe_fail_count = sum(1 for r in agg if r.get("safe_check_failed", False))

    return {
        "commit_hash": commit_hash,
        "split": split,
        "limit": limit,
        "total_examples": total_all,
        "valid_examples": total_valid,
        "exact_match_rate": em_count / total_valid,
        "exact_match_structural_rate": em_struct_count / total_valid,
        "execution_accuracy_rate": exec_acc_count / total_valid,
        "error_rate": error_count / total_valid,
        "safe_check_fail_rate": safe_fail_count / total_all,
        "avg_gen_time": sum(float(r["gen_time"]) for r in valid) / total_valid,
        "avg_exec_time": sum(float(r["exec_time"]) for r in valid) / total_valid,
        "run_id": run_ts,
    }


def run_eval(
    split: str = "dev", limit: int = 100, resume: bool = True, sleep_time: float = 0.01
) -> None:
    """
    Evaluate the NL2SQL pipeline on the first *limit* examples of a Spider split.

    For every example this function:

    * generates SQL and rejects it unless ``is_safe_sql`` passes;
    * executes the predicted and gold SQL against the example's SQLite DB;
    * records exact match, structural exact match and execution accuracy.

    Artifacts, all under ``LOG_DIR`` and sharing one run timestamp ``<ts>``:

    * ``{split}_results_<ts>.jsonl``: a ``# {header}`` line, then one JSON
      record per example.
    * ``{split}_pred_<ts>.txt``: predicted SQL, a tab, then db_id. One line
      per example.
    * ``{split}_gold_<ts>.txt``: gold SQL, a tab, then db_id. Line-aligned
      with the pred file.
    * ``{split}_metrics_<ts>.json``: the aggregate metrics. All rates except
      ``safe_check_fail_rate`` are computed over examples that passed the
      safety check and whose gold SQL executed.

    Args:
        split: Spider split passed to ``load_spider_sqlite``.
        limit: Maximum number of examples (capped at the split size).
        resume: Continue the newest existing run for *split*. Its files are
            appended to, examples already recorded there are not re-run, and
            metrics cover old and new records alike. With ``False`` a fresh
            run is always started.
        sleep_time: Seconds to pause after each newly evaluated example
            (0 disables the pause).
    """
    data = load_spider_sqlite(split)
    if len(data) < limit:
        limit = len(data)
    data = data[:limit]
    print(f"Running eval on {len(data)} examples in split={split}...")
    commit_hash = get_git_commit_hash()

    # File names carry the run timestamp, so resuming must reuse the newest
    # existing run's timestamp; a fresh timestamp would never find old results.
    run_ts = _latest_run_ts(split) if resume else None
    if run_ts is None:
        run_ts = int(time.time())
    pred_txt = LOG_DIR / f"{split}_pred_{run_ts}.txt"
    gold_txt = LOG_DIR / f"{split}_gold_{run_ts}.txt"
    results_fn = LOG_DIR / f"{split}_results_{run_ts}.jsonl"
    metrics_fn = LOG_DIR / f"{split}_metrics_{run_ts}.json"

    previous = _load_previous_records(results_fn) if resume else {}
    write_header = not results_fn.exists()
    agg: list[dict[str, Any]] = []
    db_cache: dict[str, tuple[str, SQLDatabase]] = {}

    with (
        results_fn.open("a", encoding="utf-8") as fout,
        pred_txt.open("a", encoding="utf-8") as fpred,
        gold_txt.open("a", encoding="utf-8") as fgold,
    ):
        if write_header:
            header = {
                "commit_hash": commit_hash,
                "split": split,
                "limit": limit,
                "start_time": run_ts,
            }
            fout.write("# " + json.dumps(header, ensure_ascii=False) + "\n")
            fout.flush()

        for ex in tqdm(data):
            key = (ex.db_id, ex.question)
            if key in previous:
                # Already evaluated in an earlier session of this run: reuse
                # its record for the metrics instead of re-running it.
                agg.append(previous[key])
                continue
            rec = _evaluate_example(ex, db_cache)
            _write_record(rec, fout, fpred, fgold)
            agg.append(rec)
            if sleep_time > 0:
                time.sleep(sleep_time)

    metrics = _summarize_metrics(
        agg, commit_hash=commit_hash, split=split, limit=limit, run_ts=run_ts
    )
    if metrics is None:
        print("No valid examples to compute metrics")
        return

    with metrics_fn.open("w", encoding="utf-8") as fm:
        json.dump(metrics, fm, ensure_ascii=False, indent=2)
    print("Metrics:", metrics)
    print(f"Wrote results → {results_fn}")
    print(f"Wrote pred file → {pred_txt}")
    print(f"Wrote gold file → {gold_txt}")
    print(f"Wrote metrics → {metrics_fn}")
if __name__ == "__main__":
    # Command-line entry point. The defaults reproduce the previous hard-coded
    # call: run_eval("dev", limit=10, resume=True, sleep_time=0.05).
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate the NL2SQL pipeline on the Spider SQLite benchmark."
    )
    parser.add_argument("--split", default="dev", help="Spider split to evaluate (default: dev).")
    parser.add_argument(
        "--limit", type=int, default=10, help="Maximum number of examples (default: 10)."
    )
    parser.add_argument(
        "--resume",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Continue the newest existing run for this split (default: on).",
    )
    parser.add_argument(
        "--sleep-time",
        type=float,
        default=0.05,
        help="Seconds to pause between examples (default: 0.05).",
    )
    args = parser.parse_args()
    run_eval(args.split, limit=args.limit, resume=args.resume, sleep_time=args.sleep_time)