Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import subprocess | |
| import sys | |
| import argparse | |
| import re | |
| import sqlite3 | |
| from pathlib import Path | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from peft import PeftModel | |
| from prompting import encode_prompt | |
| # ---------------- SQL CLEAN ---------------- | |
def extract_sql(text: str) -> str:
    """Normalize raw model output into a single-line SQL statement.

    Steps, in order:
      1. Keep only what follows the last ``SQL:`` marker, if one is present.
      2. Drop any preamble before the first ``SELECT`` keyword
         (case-insensitive).
      3. Replace every double quote with a single quote.
      4. Collapse each run of whitespace (newlines, tabs, ...) to one space.
      5. Append a trailing ``;`` when missing.

    Args:
        text: Decoded generation output.

    Returns:
        The cleaned statement. Text without a ``SELECT`` keyword still goes
        through steps 1 and 3-5, so empty input yields ``";"``.
    """
    text = text.strip()
    if "SQL:" in text:
        text = text.split("SQL:")[-1]

    # Match SELECT as a whole word rather than "SELECT " with a literal space:
    # output such as "SELECT\nname FROM t" must be recognised as well, otherwise
    # the preamble in front of the keyword would leak into the query.
    match = re.search(r"\bSELECT\b.*", text, re.IGNORECASE | re.DOTALL)
    if match:
        text = match.group(0)

    # NOTE(review): this also rewrites double quotes that sit inside string
    # literals or delimit identifiers -- confirm that is intended.
    text = text.replace('"', "'")
    text = re.sub(r"\s+", " ", text).strip()

    if not text.endswith(";"):
        text += ";"
    return text
| # ---------------- ROBUST ACC PARSER ---------------- | |
def parse_exec_accuracy(stdout: str):
    """Pull the execution-accuracy figure out of the evaluator's stdout.

    Lines are scanned top to bottom. The first line that mentions
    "execution" (case-insensitive) and contains at least one decimal number
    wins; the last decimal number on that line is returned.

    Args:
        stdout: Captured standard output of the evaluation script.

    Returns:
        The parsed value as a ``float``, or ``None`` when no matching line
        carries a decimal number.
    """
    decimal = re.compile(r"\d+\.\d+")
    hits_per_line = (
        decimal.findall(row)
        for row in stdout.splitlines()
        if "execution" in row.lower()
    )
    for hits in hits_per_line:
        if hits:
            return float(hits[-1])
    return None
def _execution_match(db_path: Path, pred_sql: str, gold_sql: str) -> bool:
    """Return True when both queries yield the same rows, ignoring row order.

    Any failure to open the database or to run either query counts as a
    mismatch (best-effort check; nothing is raised to the caller).
    """
    # Function-scope imports keep this block usable without touching the
    # file's import section.
    from collections import Counter
    from contextlib import closing

    if not db_path.is_file():
        # sqlite3.connect() would otherwise create an empty database file here.
        return False

    try:
        # closing() guarantees the connection is released even when a query
        # raises.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(pred_sql)
            pred_rows = cursor.fetchall()
            cursor.execute(gold_sql)
            gold_rows = cursor.fetchall()
    except (sqlite3.Error, sqlite3.Warning, ValueError):
        # sqlite3.Warning is not a subclass of sqlite3.Error; older Python
        # versions raise it (or ValueError) for malformed statements.
        return False

    # Multiset comparison: order-insensitive and, unlike sorted(), it does not
    # raise TypeError when a column mixes NULL with other values.
    return Counter(pred_rows) == Counter(gold_rows)


def main():
    """Generate SQL for the dev set, score it live, then run the Spider evaluator.

    Command-line options:
        --adapter      Adapter directory, joined onto the project root.
        --num_samples  Number of leading dev examples to evaluate.

    Side effects:
        Writes ``pred.sql`` in the project root (one ``<sql>\\t<db_id>`` line
        per example), prints progress to stdout and runs the evaluation
        script as a subprocess.

    Raises:
        FileNotFoundError: If the adapter directory does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, default="checkpoints/sft_best_bart_2")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter
    if not adapter_dir.exists():
        raise FileNotFoundError(f"Adapter not found: {adapter_dir}")

    db_root = project_root / "data/database"
    table_json = project_root / "data/tables.json"
    dev_json = project_root / "data/dev.json"
    gold_sql_file = project_root / "data/dev_gold.sql"
    pred_sql_file = project_root / "pred.sql"

    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print("Using device:", device)

    # -------- LOAD MODEL --------
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir)

    BASE_MODEL = "facebook/bart-base"
    print(f"Loading base model {BASE_MODEL}...")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)

    print("Loading LoRA adapter...")
    model = PeftModel.from_pretrained(base_model, adapter_dir).to(device)
    model = model.merge_and_unload()
    model.eval()

    # The attention mask below is derived from pad_token_id, so it must be set.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # -------- LOAD DATA --------
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(dev_json, encoding="utf-8") as dev_file:
        dev = json.load(dev_file)[: args.num_samples]

    print("Generating SQL predictions...\n")
    correct = 0
    total = len(dev)

    with open(pred_sql_file, "w") as f, torch.no_grad():
        for i, ex in enumerate(dev, 1):
            question = ex["question"]
            db_id = ex["db_id"]
            gold_query = ex["query"]

            # NOTE(review): encode_prompt is defined elsewhere; presumably it
            # returns an unbatched tensor of token ids -- confirm.
            prompt_ids = encode_prompt(
                tokenizer,
                question,
                db_id,
                device=device,
                max_input_tokens=512,
            )
            input_ids = prompt_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=160,
                num_beams=4,
                do_sample=False,
            )
            pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
            pred_sql = extract_sql(pred)
            f.write(f"{pred_sql}\t{db_id}\n")

            # -------- LIVE EXECUTION CHECK --------
            db_path = db_root / db_id / f"{db_id}.sqlite"
            if _execution_match(db_path, pred_sql, gold_query):
                correct += 1

            if i % 10 == 0 or i == total:
                current_acc = correct / i
                print(f"{i}/{total} | Acc: {current_acc:.3f}")

    print("\nGeneration finished.\n")

    # -------- RUN OFFICIAL SPIDER EVAL --------
    # Prefer the BART-specific evaluator when the project ships one.
    eval_script = project_root / "spider_eval/evaluation.py"
    if (project_root / "spider_eval/evaluation_bart.py").exists():
        eval_script = project_root / "spider_eval/evaluation_bart.py"

    # NOTE(review): the full gold file is passed even when --num_samples
    # truncates the predictions -- confirm the evaluator tolerates that.
    cmd = [
        sys.executable,
        str(eval_script),
        "--gold", str(gold_sql_file),
        "--pred", str(pred_sql_file),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    print(f"\nRunning Spider evaluation using {eval_script.name}...")
    proc = subprocess.run(cmd, capture_output=True, text=True, errors="ignore")

    if proc.returncode != 0:
        print("\nSpider evaluation crashed.")
        print(proc.stderr)
        return

    print("\n--- Spider Eval Output ---")
    # Only the tail of the report is shown.
    print("\n".join(proc.stdout.splitlines()[-20:]))

    acc = parse_exec_accuracy(proc.stdout)
    if acc is not None:
        print(f"\n🎯 Official Execution Accuracy: {acc*100:.2f}%")
    else:
        print("\nCould not parse official accuracy.")
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()