# nl-sql / scripts/archive_sweep.py
# Author: liovina. Commit 4b4ff9e (verified): "Deploy NL_SQL HEAD to HF Space".
# File size: 6.72 kB.
"""Archive sweep: scan eval/reports/**/*.json for stale (archived) pred_sql
strings that match gold under the *current* corrected runner, for every qid
that the baseline report currently misses (i.e. whose record has a falsy
`match`).
Use this after a runner-level fix (e.g. the day-5 bind-bug fix in
`db/connection.py`) or a scoring-methodology change (e.g. Counter → set in
`compare_results`): pred_sqls that were written long ago may have become
correct because the gold side stopped silently dropping rows or because the
matcher is no longer over-strict. Each rescue is a *re-verification*, not a
fresh model call — strictly $0 budget and offline.
Audit discipline: every candidate is re-executed live; the script never trusts
a stored `match` flag from an archived source report (the baseline's flag is
used only to choose which qids to sweep). Audit the output afterwards via
`scripts/audit_rescore.py`.
Example:
uv run python scripts/archive_sweep.py \
--baseline eval/reports/2026-05-23/v24-v23-plus-archive-rescore-959-merged.json \
--out eval/reports/2026-05-23/archive-sweep-v24-candidates.json
"""
from __future__ import annotations
import argparse
import glob
import json
from pathlib import Path
from typing import Any
from nl_sql.db import DatabaseSpec
from nl_sql.db.connection import execute_readonly, sqlite_url_readonly
from nl_sql.eval.metrics.execution_accuracy import compare_results
from nl_sql.eval.runner import _execute_gold
def main() -> int:
p = argparse.ArgumentParser(description=__doc__)
p.add_argument("--baseline", type=Path, required=True)
p.add_argument("--reports-root", type=Path, default=Path("eval/reports"))
p.add_argument("--out", type=Path, required=True)
p.add_argument(
"--data-root",
type=Path,
default=Path("data/bird_mini_dev/MINIDEV/dev_databases"),
)
p.add_argument(
"--only-qids",
type=str,
default=None,
help="Optional comma-separated qids to restrict sweep to.",
)
args = p.parse_args()
baseline = json.loads(args.baseline.read_text(encoding="utf-8"))
misses = [r for r in baseline["records"] if not r.get("match")]
if args.only_qids:
wanted = {int(x.strip()) for x in args.only_qids.split(",") if x.strip()}
misses = [r for r in misses if r.get("question_id") in wanted]
miss_index = {r["question_id"]: r for r in misses}
print(f"baseline: {args.baseline}")
print(f" misses: {len(misses)} (qids: {sorted(miss_index)})")
candidates: dict[int, set[str]] = {q: set() for q in miss_index}
for rp in glob.glob(str(args.reports_root / "**" / "*.json"), recursive=True):
rp_path = Path(rp)
if rp_path.resolve() == args.baseline.resolve():
continue
try:
d = json.loads(rp_path.read_text(encoding="utf-8"))
except Exception:
continue
recs = d.get("records", []) if isinstance(d, dict) else []
for r in recs:
if not isinstance(r, dict):
continue
qid_raw = r.get("question_id") or r.get("qid")
if not isinstance(qid_raw, int) or qid_raw not in miss_index:
continue
qid = qid_raw
for key in ("pred_sql", "alt_pred"):
pred = r.get(key) or ""
if isinstance(pred, str) and pred.strip():
candidates[qid].add(pred.strip())
total_cands = sum(len(v) for v in candidates.values())
print(f" unique candidate pred_sqls: {total_cands}")
rescues: list[dict[str, Any]] = []
examined: list[dict[str, Any]] = []
for qid in sorted(miss_index):
miss = miss_index[qid]
db_id = miss["db_id"]
gold_sql = miss["gold_sql"]
db_path = args.data_root / db_id / f"{db_id}.sqlite"
spec = DatabaseSpec(id=db_id, dialect="sqlite", url=sqlite_url_readonly(db_path))
engine = spec.make_engine()
try:
try:
gold_rows, _ = _execute_gold(
engine, gold_sql, statement_timeout_ms=30_000, row_cap=10_000
)
except Exception as exc:
print(f" qid={qid}: gold failed: {exc!r}")
gold_rows = []
found = False
for pred in sorted(candidates[qid]):
try:
with execute_readonly(
engine, pred, statement_timeout_ms=30_000, row_cap=10_000
) as result:
pred_rows = list(result.rows)
except Exception:
pred_rows = []
continue
cmp = compare_results(gold_rows, pred_rows, gold_sql=gold_sql)
if cmp.match:
rescues.append(
{
"question_id": qid,
"difficulty": miss.get("difficulty"),
"db_id": db_id,
"alt_pred": pred,
"alt_match": True,
"alt_rows": len(pred_rows),
"gold_rows": len(gold_rows),
"baseline_match": False,
}
)
print(
f" qid={qid} {miss.get('difficulty'):>11} db={db_id}: RESCUE "
f"(alt_rows={len(pred_rows)}, gold_rows={len(gold_rows)})"
)
found = True
break
examined.append(
{
"question_id": qid,
"difficulty": miss.get("difficulty"),
"db_id": db_id,
"candidates": len(candidates[qid]),
"rescued": found,
}
)
if not found:
print(
f" qid={qid} {miss.get('difficulty'):>11} db={db_id}: no archive rescue "
f"({len(candidates[qid])} cand)"
)
finally:
engine.dispose()
out = {
"alt_model": "archive-sweep",
"baseline": str(args.baseline).replace("\\", "/"),
"summary": {
"voted_better": len(rescues),
"voted_worse": 0,
"voted_same": 0,
"examined_qids": len(miss_index),
"total_candidates": total_cands,
},
"examined": examined,
"records": rescues,
}
args.out.parent.mkdir(parents=True, exist_ok=True)
args.out.write_text(json.dumps(out, indent=2), encoding="utf-8")
print(f"wrote: {args.out}")
print(f" rescues: {len(rescues)} / {len(miss_index)} misses")
return 0
if __name__ == "__main__":
    # Run the sweep and hand main()'s integer result to the interpreter as the
    # process exit status.
    _exit_status = main()
    raise SystemExit(_exit_status)