# text2sql-demo / src/evaluate_rl_bart.py
# Last commit dc59b01 by tjhalanigrid: "Add src folder"
import argparse
import json
import sqlite3
import time
from collections import Counter
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
def build_prompt(question, schema):
    """Assemble the model input for a single question.

    The text must match the training-time prompt byte for byte, so every
    newline is written out explicitly rather than relying on the layout of a
    triple-quoted literal.

    Args:
        question: Natural-language question to translate.
        schema: Schema description (as produced by ``load_schema``).

    Returns:
        The prompt string: a leading newline, the schema section, the
        question section, and a final ``SQL:`` cue line.
    """
    return (
        "\n"
        f"Database Schema:\n{schema}\n"
        f"Translate English to SQL:\n{question}\n"
        "SQL:\n"
    )
# ---------------- LOAD SCHEMA ----------------
def load_schema(db_path):
    """Describe every table of an SQLite database as ``table(col1, col2, ...)``.

    Args:
        db_path: Path (str or Path) to an existing SQLite database file.

    Returns:
        One newline-terminated line per table, in the order the tables are
        listed in ``sqlite_master``; an empty string if there are no tables.

    Raises:
        sqlite3.OperationalError: if the database file does not exist or
            cannot be opened.
    """
    # Open read-only: a wrong/missing path raises instead of silently creating
    # an empty database file (the default sqlite3.connect behavior).
    conn = sqlite3.connect(Path(db_path).resolve().as_uri() + "?mode=ro", uri=True)
    try:
        cursor = conn.cursor()
        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
        lines = []
        for (table,) in tables:
            # Quote the identifier so table names containing spaces or SQL
            # keywords (e.g. "order") don't break the PRAGMA statement.
            quoted = '"' + table.replace('"', '""') + '"'
            cols = cursor.execute(f"PRAGMA table_info({quoted});").fetchall()
            col_names = [c[1] for c in cols]  # column 1 of table_info is the name
            lines.append(f"{table}({', '.join(col_names)})\n")
        return "".join(lines)
    finally:
        conn.close()
# ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
def execution_match(pred_sql, gold_sql, db_path, timeout=5.0):
    """Return True if ``pred_sql`` and ``gold_sql`` yield the same result on ``db_path``.

    The database is opened read-only, so model-generated statements such as
    ``DROP TABLE`` (which would otherwise autocommit) fail instead of
    modifying the benchmark data. A missing file counts as a mismatch and is
    not created.

    Rows are compared in order when the gold query contains ``ORDER BY``.
    Otherwise they are compared as a multiset, because SQL guarantees no row
    order without ``ORDER BY``. The ``ORDER BY`` check is a textual
    heuristic on the gold query.

    Args:
        pred_sql: Predicted SQL (arbitrary model output).
        gold_sql: Reference SQL.
        db_path: Path (str or Path) to the SQLite database file.
        timeout: Seconds each query may run before SQLite interrupts it.

    Returns:
        True on a result match. False on a mismatch, or if either query
        fails, times out, or the database cannot be opened.
    """
    deadline = 0.0  # reset before each query; read by the progress handler

    def abort_if_slow():
        # SQLite aborts the running statement when this returns non-zero,
        # so evaluation can't hang on a pathological query.
        return 1 if time.monotonic() > deadline else 0

    conn = None
    try:
        conn = sqlite3.connect(Path(db_path).resolve().as_uri() + "?mode=ro", uri=True)
        conn.set_progress_handler(abort_if_slow, 10000)

        # Each query gets its own time budget.
        deadline = time.monotonic() + timeout
        pred = conn.execute(pred_sql).fetchall()
        deadline = time.monotonic() + timeout
        gold = conn.execute(gold_sql).fetchall()
    except Exception:
        # Deliberate best-effort: any failure (invalid SQL, timeout,
        # write attempt, unreadable DB, ...) is scored as a mismatch.
        return False
    finally:
        if conn is not None:
            conn.close()

    if "order by" in " ".join(gold_sql.lower().split()):
        return pred == gold
    return Counter(pred) == Counter(gold)
# ---------------- MAIN ----------------
def main():
    """Evaluate a PEFT-adapted BART text-to-SQL model by execution accuracy.

    Loads ``facebook/bart-base``, applies and merges the adapter given by
    ``--adapter``, and generates SQL with beam search for the first
    ``--num_samples`` examples of ``data/dev.json``. Each prediction is scored
    with ``execution_match`` against the example's SQLite database. Running
    and final accuracy are printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True)
    parser.add_argument("--num_samples", type=int, default=1034)
    # Decoding settings; the defaults are the values previously hard-coded.
    parser.add_argument("--max_new_tokens", type=int, default=80)
    parser.add_argument("--num_beams", type=int, default=4)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    # Prefer Apple MPS, then CUDA (Nvidia GPUs), then CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # load model
    base_model = "facebook/bart-base"
    print(f"Loading Base: {base_model}")
    print(f"Loading Adapter: {args.adapter}")
    tokenizer = AutoTokenizer.from_pretrained(args.adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, args.adapter).to(device)
    # Fold the adapter weights into the base model for plain inference.
    model = model.merge_and_unload()
    model.eval()

    # Read as UTF-8 explicitly: the platform default encoding (e.g. cp1252 on
    # Windows) can fail on non-ASCII questions.
    with open(dev_json, encoding="utf-8") as f:
        dev = json.load(f)[: args.num_samples]

    if not dev:
        # Avoid ZeroDivisionError in the accuracy computation below.
        print("No examples to evaluate.")
        return

    correct = 0
    # Many examples share a database; read each schema only once.
    schema_cache = {}
    print(f"Evaluating {len(dev)} examples...\n")
    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]
        db_path = db_root / db_id / f"{db_id}.sqlite"

        if db_id not in schema_cache:
            schema_cache[db_id] = load_schema(db_path)
        prompt = build_prompt(question, schema_cache[db_id])

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                num_beams=args.num_beams,
            )
        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the last "SQL:" cue (presumably guards
        # against the model echoing part of the prompt).
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        if execution_match(pred_sql, gold_sql, db_path):
            correct += 1
        if i % 10 == 0:
            print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")

    print("\n=============================")
    print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
    print("=============================")


if __name__ == "__main__":
    main()