# text2sql-demo / src/evaluate_sft_bart.py
# (originally scraped from a model-hub page: uploader "tjhalanigrid",
#  commit "Add src folder", revision dc59b01)
from __future__ import annotations
import json
import subprocess
import sys
import argparse
import re
import sqlite3
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
from prompting import encode_prompt
# ---------------- SQL CLEAN ----------------
def extract_sql(text: str) -> str:
    """Normalize raw model output into a single-line SQL string.

    Keeps only the portion after an optional "SQL:" marker, drops any
    prefix before the first ``SELECT``, swaps double quotes for single
    quotes, collapses whitespace runs, and guarantees a trailing ``;``.
    """
    candidate = text.strip()
    # Keep only what follows the last "SQL:" marker, if one is present.
    if "SQL:" in candidate:
        candidate = candidate.split("SQL:")[-1]
    # From the first SELECT (case-insensitive, spanning newlines) to the end.
    found = re.search(r"(SELECT .*?)(?:$)", candidate, re.IGNORECASE | re.DOTALL)
    if found:
        candidate = found.group(1)
    # SQLite string literals use single quotes.
    candidate = candidate.replace('"', "'")
    # Flatten to a single line with single spaces.
    candidate = re.sub(r"\s+", " ", candidate).strip()
    return candidate if candidate.endswith(";") else candidate + ";"
# ---------------- ROBUST ACC PARSER ----------------
def parse_exec_accuracy(stdout: str):
    """Pull the execution accuracy out of the Spider evaluator's stdout.

    Scans each line for the word "execution" (case-insensitive) and
    returns the last decimal number on the first such line that carries
    one. Returns None when no matching line contains a float.
    """
    for row in stdout.splitlines():
        if "execution" not in row.lower():
            continue
        floats = re.findall(r"\d+\.\d+", row)
        if floats:
            return float(floats[-1])
    return None
def main():
    """Evaluate a LoRA-fine-tuned BART text-to-SQL model on Spider dev.

    Generates predictions for the first ``--num_samples`` dev examples,
    writes them to ``pred.sql``, reports a live execution-match accuracy
    against each example's SQLite database, then runs the official Spider
    evaluation script on the prediction file.

    Raises:
        FileNotFoundError: when the adapter directory does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, default="checkpoints/sft_best_bart_2")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter
    if not adapter_dir.exists():
        raise FileNotFoundError(f"Adapter not found: {adapter_dir}")

    db_root = project_root / "data/database"
    table_json = project_root / "data/tables.json"
    dev_json = project_root / "data/dev.json"
    gold_sql_file = project_root / "data/dev_gold.sql"
    pred_sql_file = project_root / "pred.sql"

    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print("Using device:", device)

    # -------- LOAD MODEL --------
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir)
    BASE_MODEL = "facebook/bart-base"
    print(f"Loading base model {BASE_MODEL}...")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    print("Loading LoRA adapter...")
    model = PeftModel.from_pretrained(base_model, adapter_dir).to(device)
    # Fold the adapter weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # -------- LOAD DATA --------
    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    print("Generating SQL predictions...\n")
    correct = 0
    total = len(dev)
    with open(pred_sql_file, "w") as f, torch.no_grad():
        for i, ex in enumerate(dev, 1):
            question = ex["question"]
            db_id = ex["db_id"]
            gold_query = ex["query"]
            prompt_ids = encode_prompt(
                tokenizer,
                question,
                db_id,
                device=device,
                max_input_tokens=512,
            )
            input_ids = prompt_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=160,
                num_beams=4,
                do_sample=False,
            )
            pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
            pred_sql = extract_sql(pred)
            # NOTE(review): the official Spider pred file format is usually one
            # SQL per line without the "\t{db_id}" suffix; this assumes the
            # (possibly customized) eval script tolerates it -- confirm against
            # spider_eval/evaluation*.py.
            f.write(f"{pred_sql}\t{db_id}\n")

            # -------- LIVE EXECUTION CHECK --------
            # Best-effort: run both queries on the example's database and
            # compare result sets ignoring row order. Any failure (missing DB,
            # invalid SQL, unorderable rows) counts the example as incorrect.
            try:
                db_path = db_root / db_id / f"{db_id}.sqlite"
                conn = sqlite3.connect(db_path)
                try:
                    cursor = conn.cursor()
                    cursor.execute(pred_sql)
                    pred_rows = cursor.fetchall()
                    cursor.execute(gold_query)
                    gold_rows = cursor.fetchall()
                finally:
                    # Always release the connection, even when a query raises;
                    # previously a failed execute leaked the handle.
                    conn.close()
                # order insensitive comparison
                if sorted(pred_rows) == sorted(gold_rows):
                    correct += 1
            except Exception:
                pass  # execution failed -> treated as incorrect

            if i % 10 == 0 or i == total:
                current_acc = correct / i
                print(f"{i}/{total} | Acc: {current_acc:.3f}")

    print("\nGeneration finished.\n")

    # -------- RUN OFFICIAL SPIDER EVAL --------
    # Prefer the BART-specific evaluator when the project ships one.
    eval_script = project_root / "spider_eval/evaluation.py"
    if (project_root / "spider_eval/evaluation_bart.py").exists():
        eval_script = project_root / "spider_eval/evaluation_bart.py"
    cmd = [
        sys.executable,
        str(eval_script),
        "--gold", str(gold_sql_file),
        "--pred", str(pred_sql_file),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    print(f"\nRunning Spider evaluation using {eval_script.name}...")
    # errors="ignore" guards against non-UTF-8 bytes in the evaluator output.
    proc = subprocess.run(cmd, capture_output=True, text=True, errors="ignore")
    if proc.returncode != 0:
        print("\nSpider evaluation crashed.")
        print(proc.stderr)
        return
    print("\n--- Spider Eval Output ---")
    print("\n".join(proc.stdout.splitlines()[-20:]))
    acc = parse_exec_accuracy(proc.stdout)
    if acc is not None:
        print(f"\n🎯 Official Execution Accuracy: {acc*100:.2f}%")
    else:
        print("\nCould not parse official accuracy.")
# Script entry point: run the full evaluation when executed directly.
if __name__ == "__main__":
    main()