File size: 5,097 Bytes
dc59b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import json
import sqlite3
import torch
import re
import time
import argparse
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel

# This script lives one directory below the repo root, so parents[1] is the root.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Spider layout: data/database/<db_id>/<db_id>.sqlite (see db_path construction in main()).
DB_ROOT = PROJECT_ROOT / "data" / "database"

# -------------------------------
# 1. NORMALIZATION FOR EXACT MATCH
# -------------------------------
def normalize_sql(sql):
    """Normalize a SQL string so Exact Match grading ignores cosmetics.

    Standardizes quote characters, collapses all whitespace runs to single
    spaces, lowercases, and drops trailing semicolons.

    Args:
        sql: Raw SQL string (predicted or gold).

    Returns:
        Canonicalized SQL string suitable for equality comparison.
    """
    sql = sql.replace('"', "'")        # Standardize quotes
    sql = re.sub(r"\s+", " ", sql)     # Collapse spaces/newlines/tabs
    sql = sql.strip().lower()          # Case-insensitive comparison
    # Strip trailing semicolons, then strip again: "select 1 ;" would
    # otherwise keep a trailing space and spuriously fail Exact Match.
    sql = sql.rstrip(";").strip()
    return sql

# -------------------------------
# 2. EXECUTION ACCURACY CHECK
# -------------------------------
def check_execution(pred_sql, gold_sql, db_path):
    """Execute both queries against the DB; True iff result sets match.

    Any failure (syntax error, missing table, timeout, ...) is scored as an
    incorrect prediction rather than raised.

    Args:
        pred_sql: Model-generated SQL string.
        gold_sql: Reference (gold) SQL string.
        db_path: Path to the SQLite database file.

    Returns:
        True if both queries execute and return identical row lists.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Some Spider DBs contain bytes that are not valid UTF-8.
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # Abort runaway queries after ~5s: a non-zero return from the
        # progress handler makes SQLite interrupt the current statement.
        start_time = time.monotonic()
        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 5.0 else 0
        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()

        # Get Predicted Result
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        # Get Gold Result
        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        return pred_res == gold_res
    except Exception:
        # Any execution error counts as a miss for this metric.
        return False
    finally:
        # The original closed the connection only on the success path,
        # leaking it whenever a query raised; always close here.
        if conn is not None:
            conn.close()

# -------------------------------
# 3. LOAD SCHEMA
# -------------------------------
def load_schema(db_path):
    """Return a compact textual schema of the SQLite database at db_path.

    Format: one ``table(col1, col2, ...)`` line per table.

    Args:
        db_path: Path to the SQLite database file.

    Returns:
        Schema description string (empty for a database with no tables).
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.text_factory = lambda b: b.decode(errors='ignore')
        cursor = conn.cursor()
        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
        lines = []
        for (table,) in tables:
            # Quote the table name: Spider table names may contain spaces
            # or reserved words, which break a bare PRAGMA argument.
            cols = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
            col_names = [c[1] for c in cols]  # index 1 = column name
            lines.append(f"{table}({', '.join(col_names)})")
        return "".join(f"{line}\n" for line in lines)
    finally:
        # The original never closed the connection if a query raised.
        conn.close()

# -------------------------------
# 4. MAIN PIPELINE
# -------------------------------
def main():
    """CLI entry point: evaluate a PEFT adapter on the Spider dev set.

    Loads the CodeT5 base model plus the adapter checkpoint given via
    ``--adapter``, generates SQL for each dev example (schema-conditioned
    prompt, beam search), and reports both Exact Match (EM) and Execution
    (EX) accuracy.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
    parser.add_argument("--num_samples", type=int, default=1034, help="How many samples to evaluate")
    args = parser.parse_args()

    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    base_model = "Salesforce/codet5-base"

    print(f"\n🚀 Loading Model from: {args.adapter}")
    # Tokenizer comes from the adapter dir; weights are base + PEFT adapter,
    # merged into a single model so generation runs without PEFT hooks.
    tokenizer = AutoTokenizer.from_pretrained(args.adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, args.adapter).to(device)
    model = model.merge_and_unload()
    model.eval()

    dev_json = PROJECT_ROOT / "data" / "dev.json"
    with open(dev_json) as f:
        dev = json.load(f)[:args.num_samples]

    em_correct = 0
    ex_correct = 0
    total = len(dev)
    # NOTE(review): total == 0 (empty/missing dev.json) would divide by zero
    # below — assumes dev.json is non-empty; verify upstream.

    print(f"\n📊 Evaluating {total} queries for BOTH Exact Match and Execution Accuracy...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        gold_sql = ex["query"]
        db_id = ex["db_id"]
        db_path = DB_ROOT / db_id / f"{db_id}.sqlite"

        # Generate SQL: schema is re-read per example since db_id varies.
        schema = load_schema(db_path)
        prompt = f"Database Schema:\n{schema}\nTranslate English to SQL:\n{question}\nSQL:\n"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Deterministic beam search; no gradients needed for evaluation.
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=100, num_beams=4, do_sample=False)
        
        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # If the model echoed the prompt, keep only the text after "SQL:".
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        # --- METRIC 1: EXACT MATCH ---
        is_em = (normalize_sql(pred_sql) == normalize_sql(gold_sql))
        if is_em:
            em_correct += 1

        # --- METRIC 2: EXECUTION ACCURACY ---
        is_ex = check_execution(pred_sql, gold_sql, db_path)
        if is_ex:
            ex_correct += 1

        # Progress line every 50 examples and at the end.
        if i % 50 == 0 or i == total:
            print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    # Final Results
    final_em = (em_correct / total) * 100
    final_ex = (ex_correct / total) * 100

    print("\n==========================================")
    print(f"🎯 FINAL RESULTS FOR: {args.adapter}")
    print("==========================================")
    print(f"Exact Match (EM) Accuracy      : {final_em:.2f}%")
    print(f"Execution (EX) Accuracy        : {final_ex:.2f}%")
    print("==========================================\n")

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()