"""Evaluate a PEFT (LoRA) adapter on the Spider text-to-SQL dev set.

Loads the CodeT5 base model, merges the adapter weights, generates SQL for
each dev example, live-tracks rough exact-match (EM) and execution (EX)
accuracy, and finally shells out to the official Spider evaluation script
for the authoritative numbers.

NOTE: two earlier commented-out revisions of this script were removed;
see version control for their history.
"""
import json
import subprocess
import sys
import argparse
import random
import sqlite3
import time
import re
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel

# Assuming you have a prompting.py that has encode_prompt
from prompting import encode_prompt


# -------------------------------
# LIVE CHECK HELPERS
# -------------------------------
def normalize_sql(sql):
    """Basic normalization for the live progress bar.

    Lowercases, collapses whitespace, unifies quote style, and strips a
    trailing semicolon so superficially different queries compare equal.
    """
    sql = sql.replace('"', "'")
    sql = re.sub(r"\s+", " ", sql)
    return sql.strip().lower().rstrip(";")


def check_execution(pred_sql, gold_sql, db_path):
    """Basic execution check for the live progress bar.

    Runs both queries against the example's sqlite database and compares
    their (sorted) result sets. Any failure — bad SQL, missing DB, or the
    2-second watchdog firing — counts as a non-match.

    Returns:
        bool: True iff both queries ran and produced the same rows.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Tolerate non-UTF-8 blobs in the Spider databases.
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # 2-second timeout so the live tracker doesn't freeze forever:
        # a nonzero return from the progress handler aborts the query.
        start_time = time.monotonic()

        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 2.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()
        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        # Simple sorted check for the live tracker (order-insensitive).
        return sorted(pred_res) == sorted(gold_res)
    except Exception:
        return False
    finally:
        # Close even when a query raises — the original leaked the
        # connection on the error path.
        if conn is not None:
            conn.close()


# -------------------------------
# SPIDER PARSER
# -------------------------------
def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
    """Pull the overall accuracy out of the Spider eval script's stdout.

    Args:
        stdout: captured standard output of spider_eval/evaluation.py.
        metric_type: "exec" (execution accuracy line starts with
            "execution") or "match" (exact-match line starts with "exact").

    Returns:
        The accuracy as a float, or None if it could not be parsed
        (including an unrecognized metric_type).
    """
    prefix = {"exec": "execution", "match": "exact"}.get(metric_type)
    if prefix is None:
        return None
    for line in stdout.splitlines():
        stripped = line.strip()
        if stripped.startswith(prefix):
            try:
                return float(stripped.split()[-1])
            except (ValueError, IndexError):
                # Line looked right but the last token wasn't a number;
                # keep scanning instead of crashing the report.
                pass
    return None


def _run_spider_eval(eval_script, gold_path, pred_path, etype, db_root, table_json):
    """Run the official Spider evaluation once and parse its accuracy.

    Args:
        eval_script: path to spider_eval/evaluation.py.
        gold_path / pred_path: gold and prediction files written by main().
        etype: "match" or "exec" — which metric to compute.
        db_root / table_json: Spider database dir and tables.json.

    Returns:
        Parsed accuracy float, or None if the output couldn't be parsed.
    """
    cmd = [
        sys.executable, str(eval_script),
        "--gold", str(gold_path),
        "--pred", str(pred_path),
        "--etype", etype,
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc = subprocess.run(cmd, capture_output=True, text=True)
    return _parse_spider_accuracy(proc.stdout, etype)


# -------------------------------
# MAIN
# -------------------------------
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True,
                        help="Path to your SFT or RLHF checkpoint")
    parser.add_argument("--num_samples", type=int, default=700,
                        help="Number of samples to evaluate")
    parser.add_argument("--shuffle_dev", action="store_true")
    parser.add_argument("--shuffle_seed", type=int, default=42)
    args = parser.parse_args()

    # Paths are resolved relative to the project root (parent of this
    # script's directory).
    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter
    db_root = project_root / "data" / "database"
    table_json = project_root / "data" / "tables.json"
    dev_json = project_root / "data" / "dev.json"
    pred_path = project_root / "temp_predictions.txt"
    temp_gold_path = project_root / "temp_gold.sql"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    BASE_MODEL = "Salesforce/codet5-base"
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading Model: {args.adapter}...")
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    # Merge LoRA weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()

    with dev_json.open() as f:
        dev = json.load(f)
    if args.shuffle_dev:
        # Seeded RNG so shuffled subsets are reproducible across runs.
        rng = random.Random(args.shuffle_seed)
        rng.shuffle(dev)
    dev = dev[: args.num_samples]
    total = len(dev)

    gen_kwargs = dict(
        max_new_tokens=160,
        num_beams=4,
        do_sample=False,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"\nšŸš€ Generating and live-tracking {total} samples...\n")
    em_correct = 0
    ex_correct = 0

    with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
        for i, ex in enumerate(dev, start=1):
            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]
            db_path = db_root / db_id / f"{db_id}.sqlite"

            # Generate
            input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            # Write to files for official spider eval later
            out_pred.write(f"{pred_sql}\n")
            out_gold.write(f"{gold_query}\t{db_id}\n")

            # --- LIVE TRACKING CHECKS ---
            if normalize_sql(pred_sql) == normalize_sql(gold_query):
                em_correct += 1
            if check_execution(pred_sql, gold_query, db_path):
                ex_correct += 1

            # Print progress every 10 loops (and on the final sample).
            if i % 10 == 0 or i == total:
                print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    print("\nGeneration finished. Running Official Spider Evaluations for final numbers...\n")
    eval_script = project_root / "spider_eval" / "evaluation.py"

    # 1. RUN EXACT MATCH EVAL
    exact_acc = _run_spider_eval(eval_script, temp_gold_path, pred_path, "match", db_root, table_json)

    # 2. RUN EXECUTION EVAL
    exec_acc = _run_spider_eval(eval_script, temp_gold_path, pred_path, "exec", db_root, table_json)

    print("==========================================")
    print(f"šŸŽÆ OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
    print("==========================================")
    if exact_acc is not None:
        print(f"Exact Set Match Accuracy : {exact_acc*100:.2f}%")
    else:
        print("Exact Set Match Accuracy : Could not parse output")
    if exec_acc is not None:
        print(f"Execution Accuracy : {exec_acc*100:.2f}%")
    else:
        print("Execution Accuracy : Could not parse output")
    print("==========================================\n")


if __name__ == "__main__":
    main()