# nl2sql-copilot / benchmarks/evaluate_spider.py
# Melika Kheirieh: "Add more advanced metrics" (commit 5eeca35)
# Hosting-page residue kept for reference: raw | history | blame | 9.35 kB
from __future__ import annotations

import ast
import json
import subprocess
import time
from pathlib import Path

from langchain_community.utilities import SQLDatabase
from sqlglot import exp, parse_one
from sqlglot.errors import ParseError, TokenError
from tqdm import tqdm

from app import get_schema_preview, make_sql_chain, on_generate_query
from benchmarks import load_spider_sqlite
# Directory that receives every artifact of an evaluation run (results JSONL,
# pred/gold text files, metrics JSON). It is created as an import-time side effect.
LOG_DIR = Path("logs/spider_eval")
LOG_DIR.mkdir(parents=True, exist_ok=True)
def normalize_sql(sql: str) -> str:
    """Return *sql* lower-cased with each whitespace run collapsed to one space.

    This is the simple, purely textual version; it could be made stronger by
    parsing the query and re-rendering it. Note that the lower-casing applies
    to the whole string, including any quoted literals.
    """
    tokens = sql.lower().split()
    return " ".join(tokens)
def compare_results(pred_rows, gold_rows):
    """Return True when both result sets contain the same rows, ignoring order.

    Duplicate rows are ignored as well (set semantics, as before).

    Args:
        pred_rows: Result of the predicted query, or None if it failed.
        gold_rows: Result of the gold query, or None if it failed.

    Each argument may be a collection of rows or the textual ``repr`` of one
    (for example ``"[(1,), (2,)]"``). Text is parsed back into Python values
    before comparing; calling ``set()`` on the raw string would compare
    individual characters and report ``"[(12,)]"`` equal to ``"[(21,)]"``.
    NOTE(review): this assumes ``sql_db.run`` may hand back such a string —
    confirm against the database wrapper in use.

    Returns:
        bool: False if either side is None; otherwise whether the row sets are
        equal. Text that cannot be parsed is compared verbatim (stripped).
    """
    if pred_rows is None or gold_rows is None:
        return False

    def _coerce(rows):
        # Non-text inputs are already row collections.
        if not isinstance(rows, str):
            return rows
        text = rows.strip()
        if not text:
            # An empty string stands for "no rows".
            return []
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a Python literal (e.g. an error message): keep it as text.
            return text
        if isinstance(parsed, (list, tuple, set, frozenset)):
            return parsed
        # A lone scalar is treated as a single row.
        return [parsed]

    def _freeze(value):
        # Make rows hashable so list-shaped rows do not raise TypeError.
        if isinstance(value, (list, tuple)):
            return tuple(_freeze(item) for item in value)
        if isinstance(value, (set, frozenset)):
            return frozenset(_freeze(item) for item in value)
        if isinstance(value, dict):
            return frozenset((key, _freeze(item)) for key, item in value.items())
        return value

    pred = _coerce(pred_rows)
    gold = _coerce(gold_rows)
    if isinstance(pred, str) or isinstance(gold, str):
        # At least one side could not be parsed; fall back to text equality.
        return pred == gold
    return {_freeze(row) for row in pred} == {_freeze(row) for row in gold}
def try_execute_sql(sql_db, sql, timeout: float = None):
    """Run *sql* through ``sql_db.run`` and report the outcome without raising.

    Args:
        sql_db: Object exposing a ``run(sql)`` method.
        sql: The statement to execute.
        timeout: Accepted but not used by this implementation.

    Returns:
        tuple: ``(rows, elapsed_seconds, error)``. On success ``error`` is None;
        on any exception ``rows`` is None and ``error`` is ``str(exception)``.
    """
    began = time.time()
    rows = None
    failure = None
    try:
        rows = sql_db.run(sql)
    except Exception as exc:
        # Best-effort execution: the caller records the message instead of crashing.
        failure = str(exc)
    return rows, time.time() - began, failure
def exact_match_structural(sql_pred: str, sql_gold: str) -> bool:
    """Compare two queries by their normalized sqlglot syntax trees.

    Both queries are parsed with ``parse_one`` and normalized in place:

    * every ``exp.Alias`` node is replaced by the expression it wraps, so
      projection aliases do not affect the comparison;
    * every list-valued argument is sorted by the SQL text of its items, so
      ordering inside any list is ignored.

    Args:
        sql_pred: Predicted SQL text.
        sql_gold: Reference SQL text.

    Returns:
        bool: True when the normalized trees compare equal; False when they
        differ or when either query cannot be parsed.
    """
    try:
        ast_pred = parse_one(sql_pred)
        ast_gold = parse_one(sql_gold)
    except Exception:
        # Any parse or tokenize failure means a match cannot be established.
        return False

    def normalize_ast(node: exp.Expression):
        # Snapshot the items because node.set() mutates node.args while we iterate.
        for name, arg in list(node.args.items()):
            if isinstance(arg, list):
                # Normalize children first so the sort key reflects their final
                # form; sorting on pre-normalization text can order equivalent
                # lists differently.
                children = [
                    normalize_ast(child) if isinstance(child, exp.Expression) else child
                    for child in arg
                ]
                children.sort(key=str)
                node.set(name, children)
            elif isinstance(arg, exp.Expression):
                # Re-attach the result so an unwrapped alias replaces the
                # original child; discarding it would leave aliases in place.
                node.set(name, normalize_ast(arg))
        if isinstance(node, exp.Alias):
            # The wrapped expression was normalized by the loop above.
            return node.this
        return node

    norm_pred = normalize_ast(ast_pred)
    norm_gold = normalize_ast(ast_gold)
    return norm_pred == norm_gold
def get_git_commit_hash() -> str:
    """Return the output of ``git rev-parse HEAD``, or ``"UNKNOWN"`` on failure.

    Any exception (git missing, command failing, undecodable output) yields the
    ``"UNKNOWN"`` placeholder instead of propagating.
    """
    command = ["git", "rev-parse", "HEAD"]
    try:
        raw = subprocess.check_output(command)
        commit = raw.strip().decode("ascii")
    except Exception:
        commit = "UNKNOWN"
    return commit
# sqlglot expression types that must not appear anywhere in a generated query.
# is_safe_sql rejects a statement as soon as its tree contains one of these
# (data modification, schema changes, ATTACH and PRAGMA).
FORBIDDEN_NODES = (
    exp.Insert,
    exp.Delete,
    exp.Update,
    exp.Drop,
    exp.Alter,
    exp.Attach,
    exp.Pragma,
    exp.Create,
)
def is_safe_sql(sql: str, dialect: str | None = None) -> bool:
try:
ast = parse_one(sql, read=dialect)
except ParseError:
return False
if not isinstance(ast, exp.Select):
return False
for node in ast.walk():
if isinstance(node, FORBIDDEN_NODES):
return False
return True
def _single_line(text) -> str:
    """Flatten *text* onto one line so it fits the tab-separated pred/gold files."""
    return " ".join(str(text).splitlines()).replace("\t", " ").strip()


def _load_prior_records(results_fn) -> dict:
    """Return ``{(db_id, question): record}`` parsed from an existing results file."""
    prior = {}
    with results_fn.open("r", encoding="utf-8") as f:
        for line in f:
            # Lines starting with "#" carry the run header, not a record.
            if line.startswith("#") or not line.strip():
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                # Unparseable line (e.g. cut off mid-write); skip it.
                continue
            if isinstance(rec, dict):
                prior[(rec.get("db_id"), rec.get("question"))] = rec
    return prior


def _write_record(rec, fout, fpred, fgold) -> None:
    """Append one record to the JSONL log and to the pred/gold text files."""
    # default=str keeps the run alive when rows hold values json cannot encode.
    fout.write(json.dumps(rec, ensure_ascii=False, default=str) + "\n")
    # One example per line: embedded newlines/tabs would break the line format.
    fpred.write(f"{_single_line(rec['pred_sql'])}\t{rec['db_id']}\n")
    fgold.write(f"{_single_line(rec['gold_sql'])}\t{rec['db_id']}\n")
    for handle in (fout, fpred, fgold):
        handle.flush()


def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01, run_id=None):
    """Generate SQL for Spider examples, score it, and write logs and metrics.

    For each example the function builds a SQL chain for the example's SQLite
    database, generates a query with ``on_generate_query``, rejects it if
    ``is_safe_sql`` fails, and otherwise executes both predicted and gold SQL
    and scores exact match, structural match and execution accuracy.

    Files written under ``LOG_DIR`` (``<id>`` is the run id):
        ``{split}_results_<id>.jsonl``  one JSON record per example, after a
                                        ``# {...}`` header line
        ``{split}_pred_<id>.txt``       ``<predicted sql>\\t<db_id>`` per line
        ``{split}_gold_<id>.txt``       ``<gold sql>\\t<db_id>`` per line
        ``{split}_metrics_<id>.json``   aggregate metrics

    Args:
        split: Dataset split passed to ``load_spider_sqlite``.
        limit: Maximum number of examples; ``None`` evaluates all of them.
        resume: When True and the results file for this run id already exists,
            examples recorded there are not re-run and their stored records
            are included in the metrics.
        sleep_time: Seconds to sleep after each evaluated example (skipped
            when not positive).
        run_id: Run id to reuse. Defaults to the current Unix time, which
            names fresh files, so ``resume`` only finds earlier records when
            the id of a previous run is passed here.

    Returns:
        None. Metrics are printed and written to the metrics file; if no
        example is valid a message is printed and no metrics file is written.
    """
    data = load_spider_sqlite(split)
    limit = len(data) if limit is None else min(limit, len(data))
    data = data[:limit]
    print(f"Running eval on {len(data)} examples in split={split}...")

    commit_hash = get_git_commit_hash()
    start_ts = int(time.time()) if run_id is None else int(run_id)
    pred_txt = LOG_DIR / f"{split}_pred_{start_ts}.txt"
    gold_txt = LOG_DIR / f"{split}_gold_{start_ts}.txt"
    results_fn = LOG_DIR / f"{split}_results_{start_ts}.jsonl"
    metrics_fn = LOG_DIR / f"{split}_metrics_{start_ts}.json"

    prior = {}
    if resume and results_fn.exists():
        prior = _load_prior_records(results_fn)
    write_header = not results_fn.exists()

    agg = []
    with results_fn.open("a", encoding="utf-8") as fout, \
            pred_txt.open("a", encoding="utf-8") as fpred, \
            gold_txt.open("a", encoding="utf-8") as fgold:
        if write_header:
            header = {
                "commit_hash": commit_hash,
                "split": split,
                "limit": limit,
                "start_time": start_ts,
            }
            fout.write("# " + json.dumps(header, ensure_ascii=False) + "\n")
            fout.flush()

        for ex in tqdm(data):
            key = (ex.db_id, ex.question)
            if key in prior:
                # Already evaluated in this run: reuse the stored record so the
                # metrics cover the whole slice, not only the new examples.
                agg.append(prior[key])
                continue

            db_path = str(ex.db_path)
            schema = get_schema_preview(db_path, 0)
            sql_db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
            chain = make_sql_chain(sql_db)
            state = {
                "db_path": db_path,
                "sql_db": sql_db,
                "schema_text": schema,
                "chain": chain,
            }

            t0 = time.time()
            msg, sql, output = on_generate_query(ex.question, 1000, state)
            gen_time = time.time() - t0

            rejected = not is_safe_sql(sql)
            if rejected:
                # Unsafe predictions are logged but never executed.
                status, error = "rejected_safe_check", "unsafe_sql"
                exec_time = gold_error = pred_rows = gold_rows = None
                em = em_struct = exec_acc = False
            else:
                status = msg
                pred_rows, exec_time, error = try_execute_sql(sql_db, sql)
                gold_rows, _gold_time, gold_error = try_execute_sql(sql_db, ex.gold_sql)
                # When the gold query itself fails the example cannot be scored.
                gold_ok = gold_error is None
                em = gold_ok and normalize_sql(sql) == normalize_sql(ex.gold_sql)
                em_struct = gold_ok and exact_match_structural(sql, ex.gold_sql)
                exec_acc = gold_ok and compare_results(pred_rows, gold_rows)

            rec = {
                "db_id": ex.db_id,
                "question": ex.question,
                "gold_sql": ex.gold_sql,
                "pred_sql": sql,
                "status": status,
                "output": output,
                "gen_time": gen_time,
                "exec_time": exec_time,
                "error": error,
                "gold_error": gold_error,
                "pred_rows": pred_rows,
                "gold_rows": gold_rows,
                "exact_match": em,
                "exact_match_structural": em_struct,
                "execution_accuracy": exec_acc,
                "safe_check_failed": rejected,
            }
            _write_record(rec, fout, fpred, fgold)
            agg.append(rec)
            if sleep_time > 0:
                time.sleep(sleep_time)

    # "Valid" = passed the safety check and the gold query executed.
    valid = [
        r for r in agg
        if (not r.get("safe_check_failed", False)) and r.get("gold_error") is None
    ]
    total_valid = len(valid)
    total_all = len(agg)
    if total_valid == 0:
        print("No valid examples to compute metrics")
        return

    em_count = sum(1 for r in valid if r["exact_match"])
    em_struct_count = sum(1 for r in valid if r["exact_match_structural"])
    exec_acc_count = sum(1 for r in valid if r["execution_accuracy"])
    # NOTE(review): errors are counted over every non-rejected record (including
    # ones whose gold query failed) but divided by the valid count below.
    error_count = sum(
        1 for r in agg
        if (r.get("error") is not None) and (not r.get("safe_check_failed", False))
    )
    safe_fail_count = sum(1 for r in agg if r.get("safe_check_failed", False))
    avg_gen_time = sum(r["gen_time"] for r in valid) / total_valid
    avg_exec_time = sum(r["exec_time"] for r in valid) / total_valid

    metrics = {
        "commit_hash": commit_hash,
        "split": split,
        "limit": limit,
        "total_examples": total_all,
        "valid_examples": total_valid,
        "exact_match_rate": em_count / total_valid,
        "exact_match_structural_rate": em_struct_count / total_valid,
        "execution_accuracy_rate": exec_acc_count / total_valid,
        "error_rate": error_count / total_valid,
        "safe_check_fail_rate": safe_fail_count / total_all,
        "avg_gen_time": avg_gen_time,
        "avg_exec_time": avg_exec_time,
        "run_id": start_ts,
    }
    with metrics_fn.open("w", encoding="utf-8") as fm:
        json.dump(metrics, fm, ensure_ascii=False, indent=2)

    print("Metrics:", metrics)
    print(f"Wrote results → {results_fn}")
    print(f"Wrote pred file → {pred_txt}")
    print(f"Wrote gold file → {gold_txt}")
    print(f"Wrote metrics → {metrics_fn}")
# Script entry point: evaluate the first 10 examples of the dev split,
# sleeping 0.05 s between examples.
if __name__ == "__main__":
    run_eval("dev", limit=10, resume=True, sleep_time=0.05)