"""Grouped evaluation of a text-to-SQL model: component-match accuracy by difficulty.

Loads a PEFT-adapted CodeT5 model, generates SQL for Spider-style dev examples,
and reports (a) exact-match accuracy per difficulty bucket and (b) per-SQL-component
match rates, rendered as a grouped bar chart.
"""

import json
import re
import sqlite3
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

PROJECT_ROOT = Path(__file__).resolve().parents[1]
DB_ROOT = PROJECT_ROOT / "data" / "database"

# Evaluation axes shared by the helpers below.
COMPONENTS = ["select", "where", "group", "order", "and_or", "join"]
DIFFICULTIES = ["easy", "medium", "hard", "extra"]

# Word-boundary patterns: plain substring tests miscount keywords embedded in
# identifiers (e.g. "brand" contains "and", "joined" contains "join").
_COND_RE = re.compile(r"\b(?:and|or)\b")
_JOIN_RE = re.compile(r"\bjoin\b")


# -------------------------------
# Extract SQL components
# -------------------------------
def extract_components(sql):
    """Return a dict flagging which SQL clauses appear in *sql* (case-insensitive).

    Keys match ``COMPONENTS``. ``and_or`` uses a word-boundary match so that
    identifiers such as "brand" or "order" do not produce false positives.
    """
    sql = sql.lower()
    return {
        "select": "select" in sql,
        "where": "where" in sql,
        "group": "group by" in sql,
        "order": "order by" in sql,
        "and_or": _COND_RE.search(sql) is not None,
        "join": _JOIN_RE.search(sql) is not None,
    }


# -------------------------------
# Fallback Difficulty Estimator
# -------------------------------
def estimate_difficulty(sql):
    """Fallback Spider-style difficulty bucket when 'difficulty' is missing from the JSON.

    Heuristic: set operations or 3+ joins -> "extra"; 2 joins or a grouped
    query with conditions -> "hard"; 1 join / GROUP BY / ORDER BY -> "medium";
    otherwise "easy". Keyword counts are word-bounded to avoid matching
    substrings inside identifiers.
    """
    sql = sql.lower()
    joins = len(_JOIN_RE.findall(sql))
    conditions = len(_COND_RE.findall(sql))
    if "intersect" in sql or "except" in sql or "union" in sql or joins > 2:
        return "extra"
    if joins == 2 or ("group by" in sql and conditions > 0):
        return "hard"
    if joins == 1 or "group by" in sql or "order by" in sql:
        return "medium"
    return "easy"


# -------------------------------
# Load schema
# -------------------------------
def load_schema(db_path):
    """Return a one-table-per-line schema string ``name(col1, col2, ...)`` for *db_path*.

    Undecodable bytes in the database text are ignored rather than raising.
    The connection is closed even if a query fails.
    """
    conn = sqlite3.connect(db_path)
    conn.text_factory = lambda b: b.decode(errors="ignore")
    try:
        cursor = conn.cursor()
        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
        schema = ""
        for (table,) in tables:
            # Quote the identifier: table names may be reserved words or
            # contain characters that would otherwise break the PRAGMA.
            cols = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
            col_names = [c[1] for c in cols]  # column 1 of table_info is the name
            schema += f"{table}({', '.join(col_names)})\n"
        return schema
    finally:
        conn.close()


# -------------------------------
# Prompt
# -------------------------------
def build_prompt(question, schema):
    """Build the seq2seq input: schema context, the English question, and an SQL cue."""
    return f"""Database Schema:
{schema}

Translate English to SQL:
{question}

SQL:
"""


def _evaluate(model, tokenizer, device, dev):
    """Run generation over *dev* and collect accuracy statistics.

    Returns ``(stats, overall_correct, overall_total)`` where *stats* is a
    nested dict ``stats[component][difficulty] -> {"correct", "total"}`` and
    the overall dicts count exact string matches per difficulty.
    """
    stats = {
        comp: {diff: {"correct": 0, "total": 0} for diff in DIFFICULTIES}
        for comp in COMPONENTS
    }
    overall_correct = {diff: 0 for diff in DIFFICULTIES}
    overall_total = {diff: 0 for diff in DIFFICULTIES}

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        gold_sql = ex["query"]
        db_id = ex["db_id"]

        # Prefer the dataset's own label; fall back to the heuristic, and
        # coerce unknown labels to "medium" so indexing never fails.
        difficulty = ex.get("difficulty", estimate_difficulty(gold_sql))
        if difficulty not in DIFFICULTIES:
            difficulty = "medium"

        db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)
        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1000,
                num_beams=4,
                do_sample=False,
            )
        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # If the model echoed the prompt, keep only the text after the cue.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1]

        # --- 1. Overall exact-match accuracy (simple normalized string compare) ---
        overall_total[difficulty] += 1
        if pred_sql.strip().lower() == gold_sql.strip().lower():
            overall_correct[difficulty] += 1

        # --- 2. Component stats: counted only when gold requires the component ---
        pred_comp = extract_components(pred_sql)
        gold_comp = extract_components(gold_sql)
        for comp in COMPONENTS:
            if gold_comp[comp]:
                stats[comp][difficulty]["total"] += 1
                if pred_comp[comp]:
                    stats[comp][difficulty]["correct"] += 1

        if i % 20 == 0:
            print(f"Processed {i}/{len(dev)}")

    return stats, overall_correct, overall_total


def _plot_results(stats, out_path="component_by_difficulty_plot.png"):
    """Render the grouped bar chart of component accuracy per difficulty to *out_path*."""
    x = np.arange(len(COMPONENTS))
    width = 0.2

    def get_acc(diff):
        # Percentage accuracy per component; 0 when no gold query needed it.
        return [
            (stats[comp][diff]["correct"] / stats[comp][diff]["total"] * 100)
            if stats[comp][diff]["total"] > 0
            else 0
            for comp in COMPONENTS
        ]

    fig, ax = plt.subplots(figsize=(14, 7))
    bar_groups = [
        ax.bar(x - 1.5 * width, get_acc("easy"), width, label="Easy", color="#2ecc71"),
        ax.bar(x - 0.5 * width, get_acc("medium"), width, label="Medium", color="#f1c40f"),
        ax.bar(x + 0.5 * width, get_acc("hard"), width, label="Hard", color="#e67e22"),
        ax.bar(x + 1.5 * width, get_acc("extra"), width, label="Extra", color="#e74c3c"),
    ]

    ax.set_ylabel("Accuracy (%)", fontsize=12)
    ax.set_title(
        "SQL Component Match Accuracy by Difficulty Level",
        fontsize=14,
        fontweight="bold",
    )
    ax.set_xticks(x)
    ax.set_xticklabels([c.upper() for c in COMPONENTS], fontsize=11)
    ax.legend(title="Query Difficulty")
    ax.set_ylim(0, 115)  # headroom above 100% for the rotated labels

    def autolabel(rects):
        # Annotate each non-zero bar with its rounded percentage.
        for rect in rects:
            height = rect.get_height()
            if height > 0:
                ax.annotate(
                    f"{int(height)}%",
                    xy=(rect.get_x() + rect.get_width() / 2, height),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                    rotation=90,
                )

    for bars in bar_groups:
        autolabel(bars)

    ax.yaxis.grid(True, linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.savefig(out_path, dpi=300)


def _print_summary(overall_correct, overall_total):
    """Print overall exact-match accuracy per difficulty to the terminal."""
    print("\nāœ… Saved merged plot -> component_by_difficulty_plot.png")
    print("\n========================================")
    print("šŸ† OVERALL AVERAGE ACCURACY BY DIFFICULTY")
    print("========================================")
    for diff in DIFFICULTIES:
        if overall_total[diff] > 0:
            avg = round((overall_correct[diff] / overall_total[diff]) * 100, 2)
            print(
                f"{diff.capitalize():<8}: {avg:>5}% "
                f"({overall_correct[diff]}/{overall_total[diff]} queries)"
            )
        else:
            print(f"{diff.capitalize():<8}: N/A (0 queries)")
    print("========================================\n")


# -------------------------------
# Main
# -------------------------------
def main():
    """Load the model, evaluate the dev set, save the plot, and print the summary."""
    adapter = "checkpoints/rl_step_1800"
    base_model = "Salesforce/codet5-base"
    # Prefer Apple MPS, then CUDA, then CPU.
    device = (
        "mps"
        if torch.backends.mps.is_available()
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )

    print("Loading tokenizer and models...")
    tokenizer = AutoTokenizer.from_pretrained(adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, adapter).to(device)
    # Fold the adapter weights into the base model for plain inference.
    model = model.merge_and_unload()
    model.eval()

    dev_json = PROJECT_ROOT / "data" / "dev.json"
    with open(dev_json) as f:
        dev = json.load(f)[:1000]  # Adjust number to test more/less

    print(f"\nRunning grouped evaluation on {len(dev)} examples...\n")
    stats, overall_correct, overall_total = _evaluate(model, tokenizer, device, dev)
    _plot_results(stats)
    _print_summary(overall_correct, overall_total)


if __name__ == "__main__":
    main()