Spaces:
Sleeping
Sleeping
Commit ·
dc59b01
1
Parent(s): 16941e7
Add src folder
Browse files- src/__pycache__/schema_encoder.cpython-310.pyc +0 -0
- src/__pycache__/sql_validator.cpython-310.pyc +0 -0
- src/__pycache__/text2sql_engine.cpython-310.pyc +0 -0
- src/ask.py +93 -0
- src/component_analysis.py +229 -0
- src/convert_to_hf_dataset.py +8 -0
- src/eval_baseline_codet5.py +112 -0
- src/eval_both_metrics.py +144 -0
- src/eval_rl_fixed.py +466 -0
- src/eval_rl_t5.py +279 -0
- src/eval_single_model.py +218 -0
- src/evaluate_model_codet5.py +392 -0
- src/evaluate_model_t5_small_sft.py +179 -0
- src/evaluate_rl_bart.py +138 -0
- src/evaluate_sft_bart.py +190 -0
- src/execution_reward.py +409 -0
- src/generate_sql.py +68 -0
- src/human_eval_runner.py +152 -0
- src/load_lora_model.py +30 -0
- src/make_rl_dataset.py +20 -0
- src/manual_check.py +44 -0
- src/predict.py +112 -0
- src/prepare_dataset.py +143 -0
- src/prompting.py +151 -0
- src/run_sql.py +39 -0
- src/schema_encoder.py +51 -0
- src/schema_linker.py +215 -0
- src/sql_validator.py +133 -0
- src/text2sql_engine.py +286 -0
- src/tokenize_dataset.py +57 -0
- src/train_model.py +105 -0
- src/train_rl.py +816 -0
- src/train_rl_bart.py +370 -0
- src/train_rl_codet5.py +409 -0
- src/train_rl_lora.py +151 -0
- src/train_sft.py +192 -0
- src/train_sft_bart.py +202 -0
- src/train_sft_codet5.py +195 -0
src/__pycache__/schema_encoder.cpython-310.pyc
ADDED
|
Binary file (1.76 kB). View file
|
|
|
src/__pycache__/sql_validator.cpython-310.pyc
ADDED
|
Binary file (3.25 kB). View file
|
|
|
src/__pycache__/text2sql_engine.cpython-310.pyc
ADDED
|
Binary file (8.38 kB). View file
|
|
|
src/ask.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TERMINAL CHAT WITH DATABASE
|
| 3 |
+
Run:
|
| 4 |
+
python src/ask.py chinook_1
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
from text2sql_engine import get_engine
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# -------------------------------
|
| 12 |
+
# Pretty table printer
|
| 13 |
+
# -------------------------------
|
| 14 |
+
def print_table(cols, rows, limit=20):
    """Pretty-print query results as an aligned text table (first `limit` rows)."""
    if not rows or not cols:
        print("No results\n")
        return

    headers = [str(c) for c in cols]
    visible = rows[:limit]

    # Column width: at least 12 chars, stretched to the widest printed value.
    widths = [max(len(h), 12) for h in headers]
    for row in visible:
        for idx, cell in enumerate(row):
            cell_len = len(str(cell))
            if cell_len > widths[idx]:
                widths[idx] = cell_len

    header_line = " | ".join(h.ljust(w) for h, w in zip(headers, widths))
    print("\n" + header_line)
    print("-" * len(header_line))

    for row in visible:
        cells = [str(row[idx]).ljust(widths[idx]) for idx in range(len(headers))]
        print(" | ".join(cells))

    if len(rows) > limit:
        print(f"\n... showing first {limit} rows of {len(rows)}")

    print()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# -------------------------------
|
| 41 |
+
# Main loop
|
| 42 |
+
# -------------------------------
|
| 43 |
+
def main():
    """Interactive terminal loop: translate questions into SQL for one database."""
    if len(sys.argv) < 2:
        print("Usage: python src/ask.py <db_id>")
        return

    db_id = sys.argv[1].strip()

    print("Loading model... (first time takes 20-40s)")
    engine = get_engine()

    print(f"\nConnected to database: {db_id}")
    print("Type 'exit' to quit\n")

    running = True
    while running:
        try:
            question = input("Ask> ").strip()

            if not question:
                continue
            if question.lower() in ("exit", "quit"):
                running = False
                continue

            answer = engine.ask(question, db_id)
            if answer is None:
                print("Model returned no output\n")
                continue

            print("\nGenerated SQL:")
            print(answer.get("sql", "<no sql>"))

            if answer.get("error"):
                print("\nSQL Error:")
                print(answer["error"])
            else:
                print_table(answer.get("columns", []), answer.get("rows", []))

        except KeyboardInterrupt:
            running = False
        except Exception as exc:
            print("\nRuntime error:", exc, "\n")

    print("\nBye!")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
main()
|
src/component_analysis.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sqlite3
|
| 3 |
+
import torch
|
| 4 |
+
import re
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 9 |
+
from peft import PeftModel
|
| 10 |
+
|
| 11 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
+
DB_ROOT = PROJECT_ROOT / "data" / "database"
|
| 13 |
+
|
| 14 |
+
# -------------------------------
|
| 15 |
+
# Extract SQL components
|
| 16 |
+
# -------------------------------
|
| 17 |
+
def extract_components(sql):
    """Flag which SQL clauses appear in `sql` (case-insensitive substring checks)."""
    lowered = sql.lower()
    flags = {}
    flags["select"] = "select" in lowered
    flags["where"] = "where" in lowered
    flags["group"] = "group by" in lowered
    flags["order"] = "order by" in lowered
    flags["and_or"] = (" and " in lowered) or (" or " in lowered)
    flags["join"] = "join" in lowered
    return flags
|
| 27 |
+
|
| 28 |
+
# -------------------------------
|
| 29 |
+
# Fallback Difficulty Estimator
|
| 30 |
+
# -------------------------------
|
| 31 |
+
def estimate_difficulty(sql):
    """Fallback difficulty bucket if 'difficulty' is missing from the JSON.

    Returns one of 'easy' | 'medium' | 'hard' | 'extra' based on a cheap
    structural heuristic: set operators or 3+ joins => extra; 2 joins or a
    grouped query with extra conditions => hard; a single join / GROUP BY /
    ORDER BY => medium; otherwise easy.
    """
    sql = sql.lower()
    joins = sql.count("join")
    # Count whole-word connectives only. The original bare substring count
    # also matched the "or" inside "order by" (and "and"/"or" inside
    # identifiers), wrongly promoting simple grouped queries to "hard".
    conditions = sql.count(" and ") + sql.count(" or ")

    if "intersect" in sql or "except" in sql or "union" in sql or joins > 2:
        return "extra"
    elif joins == 2 or ("group by" in sql and conditions > 0):
        return "hard"
    elif joins == 1 or "group by" in sql or "order by" in sql:
        return "medium"
    else:
        return "easy"
|
| 45 |
+
|
| 46 |
+
# -------------------------------
|
| 47 |
+
# Load schema
|
| 48 |
+
# -------------------------------
|
| 49 |
+
def load_schema(db_path):
    """Return a compact schema string for a SQLite db: one
    `table(col1, col2, ...)` line per table listed in sqlite_master."""
    connection = sqlite3.connect(db_path)
    # Some Spider databases contain non-UTF8 bytes; drop them on decode.
    connection.text_factory = lambda raw: raw.decode(errors='ignore')
    cur = connection.cursor()

    table_rows = cur.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()

    lines = []
    for (table_name,) in table_rows:
        columns = cur.execute(f"PRAGMA table_info({table_name});").fetchall()
        names = ", ".join(info[1] for info in columns)
        lines.append(f"{table_name}({names})\n")

    connection.close()
    return "".join(lines)
|
| 66 |
+
|
| 67 |
+
# -------------------------------
|
| 68 |
+
# Prompt
|
| 69 |
+
# -------------------------------
|
| 70 |
+
def build_prompt(question, schema):
    """Format the model input prompt (must match the layout used in training)."""
    return (
        "Database Schema:\n"
        f"{schema}\n"
        "\n"
        "Translate English to SQL:\n"
        f"{question}\n"
        "SQL:\n"
    )
|
| 78 |
+
|
| 79 |
+
# -------------------------------
|
| 80 |
+
# Main
|
| 81 |
+
# -------------------------------
|
| 82 |
+
def main():
    """Grouped evaluation of the RL-tuned CodeT5 adapter on the Spider dev set.

    For each example: generate SQL, bucket it by difficulty, then record
    (a) exact-string overall accuracy per difficulty and (b) per-component
    (SELECT/WHERE/GROUP/ORDER/AND-OR/JOIN) match rates. Saves a grouped bar
    chart and prints a per-difficulty summary.
    """
    # Adapter checkpoint and base model are hard-coded for this analysis run.
    adapter = "checkpoints/rl_step_1800"
    base_model = "Salesforce/codet5-base"

    # Prefer Apple-Silicon GPU, then CUDA, else CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading tokenizer and models...")
    tokenizer = AutoTokenizer.from_pretrained(adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, adapter).to(device)
    # Fold LoRA weights into the base model for plain inference.
    model = model.merge_and_unload()
    model.eval()

    dev_json = PROJECT_ROOT / "data" / "dev.json"

    with open(dev_json) as f:
        dev = json.load(f)[:1000]  # Adjust number to test more/less

    components_list = ["select", "where", "group", "order", "and_or", "join"]
    difficulties_list = ["easy", "medium", "hard", "extra"]

    # Nested dictionary for components: stats[component][difficulty] -> counts
    stats = {
        comp: {diff: {"correct": 0, "total": 0} for diff in difficulties_list}
        for comp in components_list
    }

    # 🚀 NEW: Trackers for OVERALL accuracy by difficulty
    overall_correct = {diff: 0 for diff in difficulties_list}
    overall_total = {diff: 0 for diff in difficulties_list}

    print(f"\nRunning grouped evaluation on {len(dev)} examples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        gold_sql = ex["query"]
        db_id = ex["db_id"]

        # Determine difficulty (dataset label if present, else heuristic).
        difficulty = ex.get("difficulty", estimate_difficulty(gold_sql))
        if difficulty not in difficulties_list:
            difficulty = "medium"

        db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)
        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Deterministic beam search; no sampling so runs are reproducible.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1000,
                num_beams=4,
                do_sample=False
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only text after the last "SQL:" marker if the model echoed the prompt.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1]

        # --- 1. Update Overall Accuracy Trackers ---
        overall_total[difficulty] += 1
        # Simple string match for quick overall accuracy
        if pred_sql.strip().lower() == gold_sql.strip().lower():
            overall_correct[difficulty] += 1

        # --- 2. Update Component Stats ---
        pred_comp = extract_components(pred_sql)
        gold_comp = extract_components(gold_sql)

        for comp in components_list:
            if gold_comp[comp]:  # If the gold SQL required this component
                stats[comp][difficulty]["total"] += 1
                if pred_comp[comp]:  # If the model successfully generated it
                    stats[comp][difficulty]["correct"] += 1

        if i % 20 == 0:
            print(f"Processed {i}/{len(dev)}")

    # -------------------------------
    # Plotting (Grouped Bar Chart)
    # -------------------------------
    x = np.arange(len(components_list))
    width = 0.2

    def get_acc(diff):
        # Percentage correct per component for one difficulty bucket (0 if unseen).
        return [
            (stats[comp][diff]["correct"] / stats[comp][diff]["total"] * 100) if stats[comp][diff]["total"] > 0 else 0
            for comp in components_list
        ]

    acc_easy = get_acc("easy")
    acc_medium = get_acc("medium")
    acc_hard = get_acc("hard")
    acc_extra = get_acc("extra")

    fig, ax = plt.subplots(figsize=(14, 7))

    # Four bars per component, offset around the tick position.
    bars1 = ax.bar(x - 1.5 * width, acc_easy, width, label='Easy', color='#2ecc71')
    bars2 = ax.bar(x - 0.5 * width, acc_medium, width, label='Medium', color='#f1c40f')
    bars3 = ax.bar(x + 0.5 * width, acc_hard, width, label='Hard', color='#e67e22')
    bars4 = ax.bar(x + 1.5 * width, acc_extra, width, label='Extra', color='#e74c3c')

    ax.set_ylabel('Accuracy (%)', fontsize=12)
    ax.set_title('SQL Component Match Accuracy by Difficulty Level', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels([c.upper() for c in components_list], fontsize=11)
    ax.legend(title="Query Difficulty")
    ax.set_ylim(0, 115)  # headroom above 100% for the rotated labels

    def autolabel(rects):
        # Annotate each non-zero bar with its rounded percentage.
        for rect in rects:
            height = rect.get_height()
            if height > 0:
                ax.annotate(f'{int(height)}%',
                            xy=(rect.get_x() + rect.get_width() / 2, height),
                            xytext=(0, 3),
                            textcoords="offset points",
                            ha='center', va='bottom', fontsize=8, rotation=90)

    autolabel(bars1)
    autolabel(bars2)
    autolabel(bars3)
    autolabel(bars4)

    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.savefig("component_by_difficulty_plot.png", dpi=300)

    # -------------------------------
    # 🚀 Terminal Printout
    # -------------------------------
    print("\n✅ Saved merged plot -> component_by_difficulty_plot.png")

    print("\n========================================")
    print("🏆 OVERALL AVERAGE ACCURACY BY DIFFICULTY")
    print("========================================")
    for diff in difficulties_list:
        if overall_total[diff] > 0:
            avg = round((overall_correct[diff] / overall_total[diff]) * 100, 2)
            print(f"{diff.capitalize():<8}: {avg:>5}% ({overall_correct[diff]}/{overall_total[diff]} queries)")
        else:
            print(f"{diff.capitalize():<8}: N/A (0 queries)")
    print("========================================\n")
|
| 227 |
+
|
| 228 |
+
if __name__ == "__main__":
|
| 229 |
+
main()
|
src/convert_to_hf_dataset.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Convert the processed training CSV into a HuggingFace dataset on disk."""
from datasets import Dataset
import pandas as pd

# NOTE(review): relative paths assume this script is run from inside src/ — confirm.
frame = pd.read_csv("../data/processed/train.csv")
dataset = Dataset.from_pandas(frame)
dataset.save_to_disk("../data/processed/train")
print("DONE")
|
| 8 |
+
|
src/eval_baseline_codet5.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sqlite3
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 6 |
+
|
| 7 |
+
# ---------------- PROMPT (same style as training) ----------------
|
| 8 |
+
def build_prompt(question, schema):
    """Build the generation prompt (same layout the model saw at training time)."""
    parts = (
        "translate English to SQL:",
        "",
        "Schema:",
        schema,
        "",
        "Question:",
        question,
        "",
        "SQL:",
    )
    return "\n".join(parts)
|
| 18 |
+
|
| 19 |
+
# ---------------- LOAD SCHEMA ----------------
|
| 20 |
+
def load_schema(db_path):
    """Summarize a SQLite database as `table(col, ...)` lines, one per table."""
    con = sqlite3.connect(db_path)
    cur = con.cursor()

    table_rows = cur.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()

    pieces = []
    for (name,) in table_rows:
        info = cur.execute(f"PRAGMA table_info({name});").fetchall()
        pieces.append(f"{name}({', '.join(row[1] for row in info)})\n")

    con.close()
    return "".join(pieces)
|
| 36 |
+
|
| 37 |
+
# ---------------- EXECUTION MATCH ----------------
|
| 38 |
+
def execution_match(pred_sql, gold_sql, db_path):
    """Return True iff both queries execute on the database at `db_path` and
    produce identical result rows.

    Any execution error (invalid SQL, missing table, unreadable db) counts as
    a non-match and yields False.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()

        cur.execute(pred_sql)
        pred = cur.fetchall()

        cur.execute(gold_sql)
        gold = cur.fetchall()

        return pred == gold
    except Exception:
        return False
    finally:
        # Fixed: the connection leaked whenever a query raised; always close.
        if conn is not None:
            conn.close()
|
| 54 |
+
|
| 55 |
+
# ---------------- MAIN ----------------
|
| 56 |
+
def main():
    """Evaluate the untuned CodeT5 base model on the first 100 Spider dev
    examples using execution-match accuracy, printing running progress and a
    final summary."""
    project_root = Path(__file__).resolve().parents[1]

    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    # Apple-Silicon GPU when available, otherwise CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    print("Loading BASE CodeT5...")
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base").to(device)
    model.eval()

    with open(dev_json) as f:
        dev = json.load(f)[:100]

    correct = 0

    print(f"\nEvaluating {len(dev)} samples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]

        db_path = db_root / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)

        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)

        # Deterministic beam search so the baseline number is reproducible.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                num_beams=4,
                do_sample=False
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the text after the last "SQL:" marker if the model echoed the prompt.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        if execution_match(pred_sql, gold_sql, db_path):
            correct += 1

        if i % 10 == 0:
            # Fixed: progress previously hard-coded "/100" even when fewer samples loaded.
            print(f"{i}/{len(dev)} | Accuracy: {correct/i:.3f}")

    print("\n=============================")
    # Fixed: summary previously printed the raw count with a stray '%' sign
    # ("{correct}% / 100 = {correct}%"); report count and true percentage.
    print(f"BASE MODEL ACCURACY: {correct}/{len(dev)} = {correct/len(dev)*100:.1f}%")
    print("=============================")
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
|
| 112 |
+
main()
|
src/eval_both_metrics.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sqlite3
|
| 3 |
+
import torch
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
import argparse
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 9 |
+
from peft import PeftModel
|
| 10 |
+
|
| 11 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
+
DB_ROOT = PROJECT_ROOT / "data" / "database"
|
| 13 |
+
|
| 14 |
+
# -------------------------------
|
| 15 |
+
# 1. NORMALIZATION FOR EXACT MATCH
|
| 16 |
+
# -------------------------------
|
| 17 |
+
def normalize_sql(sql):
    """Normalize SQL for a fair Exact Match comparison.

    Standardizes quote style, collapses whitespace runs, lowercases, and drops
    any trailing semicolons (plus surrounding whitespace) so that cosmetic
    differences don't count as mismatches.
    """
    sql = sql.replace('"', "'")     # Standardize quotes
    sql = re.sub(r"\s+", " ", sql)  # Remove extra spaces/newlines
    sql = sql.strip().lower()       # Lowercase everything
    # Fixed: plain rstrip(";") could leave a trailing space
    # ("SELECT 1 ;" -> "select 1 "), so equivalent queries compared unequal.
    # Strip trailing semicolons and spaces together.
    sql = sql.rstrip("; ")
    return sql
|
| 24 |
+
|
| 25 |
+
# -------------------------------
|
| 26 |
+
# 2. EXECUTION ACCURACY CHECK
|
| 27 |
+
# -------------------------------
|
| 28 |
+
def check_execution(pred_sql, gold_sql, db_path):
    """Execution accuracy check: run both queries on the database at `db_path`
    and return True iff their result rows match exactly.

    Any failure — invalid SQL, missing db, or a query aborted by the 5-second
    progress-handler timeout — returns False.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Handle bad (non-UTF8) characters in Spider DBs
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # 5-second timeout: a nonzero return from the progress handler makes
        # SQLite abort the running query with OperationalError.
        # NOTE: the budget is shared across BOTH queries, not per query.
        start_time = time.monotonic()

        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 5.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()

        # Get Predicted Result
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        # Get Gold Result
        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        return pred_res == gold_res
    except Exception:
        return False
    finally:
        # Fixed: the connection leaked whenever a query raised; always close.
        if conn is not None:
            conn.close()
|
| 55 |
+
|
| 56 |
+
# -------------------------------
|
| 57 |
+
# 3. LOAD SCHEMA
|
| 58 |
+
# -------------------------------
|
| 59 |
+
def load_schema(db_path):
    """Render the database schema as one `table(col, ...)` line per table."""
    db = sqlite3.connect(db_path)
    # Tolerate non-UTF8 bytes present in some Spider databases.
    db.text_factory = lambda blob: blob.decode(errors='ignore')
    cur = db.cursor()
    out = ""
    for (tbl,) in cur.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall():
        col_info = cur.execute(f"PRAGMA table_info({tbl});").fetchall()
        out += tbl + "(" + ", ".join(c[1] for c in col_info) + ")\n"
    db.close()
    return out
|
| 71 |
+
|
| 72 |
+
# -------------------------------
|
| 73 |
+
# 4. MAIN PIPELINE
|
| 74 |
+
# -------------------------------
|
| 75 |
+
def main():
    """Evaluate an SFT/RLHF LoRA checkpoint on the Spider dev set, reporting
    BOTH Exact Match (normalized string equality) and Execution Accuracy
    (identical query results on the target SQLite database).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
    parser.add_argument("--num_samples", type=int, default=1034, help="How many samples to evaluate")
    args = parser.parse_args()

    # Prefer Apple-Silicon GPU, then CUDA, else CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    base_model = "Salesforce/codet5-base"

    print(f"\n🚀 Loading Model from: {args.adapter}")
    tokenizer = AutoTokenizer.from_pretrained(args.adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, args.adapter).to(device)
    # Fold LoRA weights into the base model for plain inference.
    model = model.merge_and_unload()
    model.eval()

    dev_json = PROJECT_ROOT / "data" / "dev.json"
    with open(dev_json) as f:
        dev = json.load(f)[:args.num_samples]

    em_correct = 0
    ex_correct = 0
    total = len(dev)

    print(f"\n📊 Evaluating {total} queries for BOTH Exact Match and Execution Accuracy...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        gold_sql = ex["query"]
        db_id = ex["db_id"]
        db_path = DB_ROOT / db_id / f"{db_id}.sqlite"

        # Generate SQL (deterministic beam search, no sampling)
        schema = load_schema(db_path)
        prompt = f"Database Schema:\n{schema}\nTranslate English to SQL:\n{question}\nSQL:\n"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=100, num_beams=4, do_sample=False)

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only text after the last "SQL:" marker if the model echoed the prompt.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        # --- METRIC 1: EXACT MATCH ---
        is_em = (normalize_sql(pred_sql) == normalize_sql(gold_sql))
        if is_em:
            em_correct += 1

        # --- METRIC 2: EXECUTION ACCURACY ---
        is_ex = check_execution(pred_sql, gold_sql, db_path)
        if is_ex:
            ex_correct += 1

        if i % 50 == 0 or i == total:
            print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    # Final Results
    final_em = (em_correct / total) * 100
    final_ex = (ex_correct / total) * 100

    print("\n==========================================")
    print(f"🎯 FINAL RESULTS FOR: {args.adapter}")
    print("==========================================")
    print(f"Exact Match (EM) Accuracy : {final_em:.2f}%")
    print(f"Execution (EX) Accuracy : {final_ex:.2f}%")
    print("==========================================\n")
|
| 142 |
+
|
| 143 |
+
if __name__ == "__main__":
|
| 144 |
+
main()
|
src/eval_rl_fixed.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import json
|
| 2 |
+
# import sqlite3
|
| 3 |
+
# import argparse
|
| 4 |
+
# from pathlib import Path
|
| 5 |
+
# import torch
|
| 6 |
+
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 7 |
+
# from peft import PeftModel
|
| 8 |
+
|
| 9 |
+
# # ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
|
| 10 |
+
# def build_prompt(question, schema):
|
| 11 |
+
# return f"""
|
| 12 |
+
# Database Schema:
|
| 13 |
+
# {schema}
|
| 14 |
+
|
| 15 |
+
# Translate English to SQL:
|
| 16 |
+
# {question}
|
| 17 |
+
# SQL:
|
| 18 |
+
# """
|
| 19 |
+
|
| 20 |
+
# # ---------------- LOAD SCHEMA ----------------
|
| 21 |
+
# def load_schema(db_path):
|
| 22 |
+
# conn = sqlite3.connect(db_path)
|
| 23 |
+
# cursor = conn.cursor()
|
| 24 |
+
|
| 25 |
+
# tables = cursor.execute(
|
| 26 |
+
# "SELECT name FROM sqlite_master WHERE type='table';"
|
| 27 |
+
# ).fetchall()
|
| 28 |
+
|
| 29 |
+
# schema = ""
|
| 30 |
+
# for (table,) in tables:
|
| 31 |
+
# cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
|
| 32 |
+
# col_names = [c[1] for c in cols]
|
| 33 |
+
# schema += f"{table}({', '.join(col_names)})\n"
|
| 34 |
+
|
| 35 |
+
# conn.close()
|
| 36 |
+
# return schema
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# # ---------------- EXECUTION CHECK ----------------
|
| 40 |
+
# def execution_match(pred_sql, gold_sql, db_path):
|
| 41 |
+
# try:
|
| 42 |
+
# conn = sqlite3.connect(db_path)
|
| 43 |
+
# cur = conn.cursor()
|
| 44 |
+
|
| 45 |
+
# cur.execute(pred_sql)
|
| 46 |
+
# pred = cur.fetchall()
|
| 47 |
+
|
| 48 |
+
# cur.execute(gold_sql)
|
| 49 |
+
# gold = cur.fetchall()
|
| 50 |
+
|
| 51 |
+
# conn.close()
|
| 52 |
+
# return pred == gold
|
| 53 |
+
|
| 54 |
+
# except Exception:
|
| 55 |
+
# return False
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# # ---------------- MAIN ----------------
|
| 59 |
+
# def main():
|
| 60 |
+
# parser = argparse.ArgumentParser()
|
| 61 |
+
# parser.add_argument("--adapter", type=str, required=True)
|
| 62 |
+
# parser.add_argument("--num_samples", type=int, default=1034)
|
| 63 |
+
# args = parser.parse_args()
|
| 64 |
+
|
| 65 |
+
# project_root = Path(__file__).resolve().parents[1]
|
| 66 |
+
|
| 67 |
+
# dev_json = project_root / "data" / "dev.json"
|
| 68 |
+
# db_root = project_root / "data" / "database"
|
| 69 |
+
|
| 70 |
+
# device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 71 |
+
|
| 72 |
+
# # load model
|
| 73 |
+
# base_model = "Salesforce/codet5-base"
|
| 74 |
+
# tokenizer = AutoTokenizer.from_pretrained(args.adapter)
|
| 75 |
+
# base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
|
| 76 |
+
# model = PeftModel.from_pretrained(base, args.adapter).to(device)
|
| 77 |
+
# model = model.merge_and_unload()
|
| 78 |
+
|
| 79 |
+
# with open(dev_json) as f:
|
| 80 |
+
# dev = json.load(f)[: args.num_samples]
|
| 81 |
+
|
| 82 |
+
# correct = 0
|
| 83 |
+
|
| 84 |
+
# print(f"Evaluating {len(dev)} examples...\n")
|
| 85 |
+
|
| 86 |
+
# for i, ex in enumerate(dev, 1):
|
| 87 |
+
# question = ex["question"]
|
| 88 |
+
# db_id = ex["db_id"]
|
| 89 |
+
# gold_sql = ex["query"]
|
| 90 |
+
|
| 91 |
+
# db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 92 |
+
# schema = load_schema(db_path)
|
| 93 |
+
|
| 94 |
+
# prompt = build_prompt(question, schema)
|
| 95 |
+
|
| 96 |
+
# inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 97 |
+
|
| 98 |
+
# with torch.no_grad():
|
| 99 |
+
# outputs = model.generate(
|
| 100 |
+
# **inputs,
|
| 101 |
+
# max_new_tokens=80,
|
| 102 |
+
# do_sample=False,
|
| 103 |
+
# num_beams=4,
|
| 104 |
+
# )
|
| 105 |
+
|
| 106 |
+
# pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 107 |
+
|
| 108 |
+
# if "SQL:" in pred_sql:
|
| 109 |
+
# pred_sql = pred_sql.split("SQL:")[-1].strip()
|
| 110 |
+
|
| 111 |
+
# match = execution_match(pred_sql, gold_sql, db_path)
|
| 112 |
+
|
| 113 |
+
# if match:
|
| 114 |
+
# correct += 1
|
| 115 |
+
|
| 116 |
+
# if i % 10 == 0:
|
| 117 |
+
# print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")
|
| 118 |
+
|
| 119 |
+
# print("\n=============================")
|
| 120 |
+
# print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
|
| 121 |
+
# print("=============================")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# if __name__ == "__main__":
|
| 125 |
+
# main()
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# import json
|
| 129 |
+
# import sqlite3
|
| 130 |
+
# import argparse
|
| 131 |
+
# import time
|
| 132 |
+
# from pathlib import Path
|
| 133 |
+
# import torch
|
| 134 |
+
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 135 |
+
# from peft import PeftModel
|
| 136 |
+
|
| 137 |
+
# # ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
|
| 138 |
+
# def build_prompt(question, schema):
|
| 139 |
+
# return f"""
|
| 140 |
+
# Database Schema:
|
| 141 |
+
# {schema}
|
| 142 |
+
|
| 143 |
+
# Translate English to SQL:
|
| 144 |
+
# {question}
|
| 145 |
+
# SQL:
|
| 146 |
+
# """
|
| 147 |
+
|
| 148 |
+
# # ---------------- LOAD SCHEMA ----------------
|
| 149 |
+
# def load_schema(db_path):
|
| 150 |
+
# conn = sqlite3.connect(db_path)
|
| 151 |
+
# cursor = conn.cursor()
|
| 152 |
+
|
| 153 |
+
# tables = cursor.execute(
|
| 154 |
+
# "SELECT name FROM sqlite_master WHERE type='table';"
|
| 155 |
+
# ).fetchall()
|
| 156 |
+
|
| 157 |
+
# schema = ""
|
| 158 |
+
# for (table,) in tables:
|
| 159 |
+
# cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
|
| 160 |
+
# col_names = [c[1] for c in cols]
|
| 161 |
+
# schema += f"{table}({', '.join(col_names)})\n"
|
| 162 |
+
|
| 163 |
+
# conn.close()
|
| 164 |
+
# return schema
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# # ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
|
| 168 |
+
# def execution_match(pred_sql, gold_sql, db_path):
|
| 169 |
+
# try:
|
| 170 |
+
# conn = sqlite3.connect(db_path)
|
| 171 |
+
|
| 172 |
+
# # --- 5-SECOND TIMEOUT SO EVALUATION DOESN'T FREEZE ---
|
| 173 |
+
# start_time = time.monotonic()
|
| 174 |
+
# def timeout_handler():
|
| 175 |
+
# return 1 if (time.monotonic() - start_time) > 5.0 else 0
|
| 176 |
+
# conn.set_progress_handler(timeout_handler, 10000)
|
| 177 |
+
|
| 178 |
+
# cur = conn.cursor()
|
| 179 |
+
|
| 180 |
+
# cur.execute(pred_sql)
|
| 181 |
+
# pred = cur.fetchall()
|
| 182 |
+
|
| 183 |
+
# cur.execute(gold_sql)
|
| 184 |
+
# gold = cur.fetchall()
|
| 185 |
+
|
| 186 |
+
# conn.close()
|
| 187 |
+
# return pred == gold
|
| 188 |
+
|
| 189 |
+
# except Exception:
|
| 190 |
+
# return False
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# # ---------------- MAIN ----------------
|
| 194 |
+
# def main():
|
| 195 |
+
# parser = argparse.ArgumentParser()
|
| 196 |
+
# parser.add_argument("--adapter", type=str, required=True)
|
| 197 |
+
# parser.add_argument("--num_samples", type=int, default=1034)
|
| 198 |
+
# args = parser.parse_args()
|
| 199 |
+
|
| 200 |
+
# project_root = Path(__file__).resolve().parents[1]
|
| 201 |
+
|
| 202 |
+
# dev_json = project_root / "data" / "dev.json"
|
| 203 |
+
# db_root = project_root / "data" / "database"
|
| 204 |
+
|
| 205 |
+
# # 🎯 Added CUDA support for Nvidia GPUs
|
| 206 |
+
# device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
|
| 207 |
+
|
| 208 |
+
# # load model
|
| 209 |
+
# base_model = "Salesforce/codet5-base"
|
| 210 |
+
# print(f"Loading Base: {base_model}")
|
| 211 |
+
# print(f"Loading Adapter: {args.adapter}")
|
| 212 |
+
|
| 213 |
+
# tokenizer = AutoTokenizer.from_pretrained(args.adapter)
|
| 214 |
+
# base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
|
| 215 |
+
# model = PeftModel.from_pretrained(base, args.adapter).to(device)
|
| 216 |
+
# model = model.merge_and_unload()
|
| 217 |
+
|
| 218 |
+
# with open(dev_json) as f:
|
| 219 |
+
# dev = json.load(f)[: args.num_samples]
|
| 220 |
+
|
| 221 |
+
# correct = 0
|
| 222 |
+
|
| 223 |
+
# print(f"Evaluating {len(dev)} examples...\n")
|
| 224 |
+
|
| 225 |
+
# for i, ex in enumerate(dev, 1):
|
| 226 |
+
# question = ex["question"]
|
| 227 |
+
# db_id = ex["db_id"]
|
| 228 |
+
# gold_sql = ex["query"]
|
| 229 |
+
|
| 230 |
+
# db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 231 |
+
# schema = load_schema(db_path)
|
| 232 |
+
|
| 233 |
+
# prompt = build_prompt(question, schema)
|
| 234 |
+
|
| 235 |
+
# inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 236 |
+
|
| 237 |
+
# with torch.no_grad():
|
| 238 |
+
# outputs = model.generate(
|
| 239 |
+
# **inputs,
|
| 240 |
+
# max_new_tokens=80,
|
| 241 |
+
# do_sample=False,
|
| 242 |
+
# num_beams=4,
|
| 243 |
+
# )
|
| 244 |
+
|
| 245 |
+
# pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 246 |
+
|
| 247 |
+
# if "SQL:" in pred_sql:
|
| 248 |
+
# pred_sql = pred_sql.split("SQL:")[-1].strip()
|
| 249 |
+
|
| 250 |
+
# match = execution_match(pred_sql, gold_sql, db_path)
|
| 251 |
+
|
| 252 |
+
# if match:
|
| 253 |
+
# correct += 1
|
| 254 |
+
|
| 255 |
+
# if i % 10 == 0:
|
| 256 |
+
# print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")
|
| 257 |
+
|
| 258 |
+
# print("\n=============================")
|
| 259 |
+
# print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
|
| 260 |
+
# print("=============================")
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
# if __name__ == "__main__":
|
| 264 |
+
# main()
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
import json
|
| 268 |
+
import subprocess
|
| 269 |
+
import sys
|
| 270 |
+
import argparse
|
| 271 |
+
import random
|
| 272 |
+
import sqlite3
|
| 273 |
+
import time
|
| 274 |
+
import re
|
| 275 |
+
from pathlib import Path
|
| 276 |
+
|
| 277 |
+
import torch
|
| 278 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 279 |
+
from peft import PeftModel
|
| 280 |
+
|
| 281 |
+
# Assuming you have a prompting.py that has encode_prompt
|
| 282 |
+
from prompting import encode_prompt
|
| 283 |
+
|
| 284 |
+
# -------------------------------
|
| 285 |
+
# LIVE CHECK HELPERS
|
| 286 |
+
# -------------------------------
|
| 287 |
+
def normalize_sql(sql):
|
| 288 |
+
"""Basic normalization for the live progress bar."""
|
| 289 |
+
sql = sql.replace('"', "'")
|
| 290 |
+
sql = re.sub(r"\s+", " ", sql)
|
| 291 |
+
return sql.strip().lower().rstrip(";")
|
| 292 |
+
|
| 293 |
+
def check_execution(pred_sql, gold_sql, db_path):
|
| 294 |
+
"""Basic execution check for the live progress bar."""
|
| 295 |
+
try:
|
| 296 |
+
conn = sqlite3.connect(db_path)
|
| 297 |
+
conn.text_factory = lambda b: b.decode(errors='ignore')
|
| 298 |
+
|
| 299 |
+
# 2-second timeout so the live tracker doesn't freeze forever
|
| 300 |
+
start_time = time.monotonic()
|
| 301 |
+
def timeout_handler():
|
| 302 |
+
return 1 if (time.monotonic() - start_time) > 2.0 else 0
|
| 303 |
+
conn.set_progress_handler(timeout_handler, 10000)
|
| 304 |
+
|
| 305 |
+
cursor = conn.cursor()
|
| 306 |
+
cursor.execute(pred_sql)
|
| 307 |
+
pred_res = cursor.fetchall()
|
| 308 |
+
|
| 309 |
+
cursor.execute(gold_sql)
|
| 310 |
+
gold_res = cursor.fetchall()
|
| 311 |
+
conn.close()
|
| 312 |
+
|
| 313 |
+
# Simple sorted check for the live tracker
|
| 314 |
+
return sorted(pred_res) == sorted(gold_res)
|
| 315 |
+
except Exception:
|
| 316 |
+
return False
|
| 317 |
+
|
| 318 |
+
# -------------------------------
|
| 319 |
+
# SPIDER PARSER
|
| 320 |
+
# -------------------------------
|
| 321 |
+
def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
|
| 322 |
+
for line in stdout.splitlines():
|
| 323 |
+
if metric_type == "exec" and line.strip().startswith("execution"):
|
| 324 |
+
try: return float(line.split()[-1])
|
| 325 |
+
except: pass
|
| 326 |
+
elif metric_type == "match" and line.strip().startswith("exact"):
|
| 327 |
+
try: return float(line.split()[-1])
|
| 328 |
+
except: pass
|
| 329 |
+
return None
|
| 330 |
+
|
| 331 |
+
# -------------------------------
|
| 332 |
+
# MAIN
|
| 333 |
+
# -------------------------------
|
| 334 |
+
def main():
|
| 335 |
+
parser = argparse.ArgumentParser()
|
| 336 |
+
parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
|
| 337 |
+
parser.add_argument("--num_samples", type=int, default=700, help="Number of samples to evaluate")
|
| 338 |
+
parser.add_argument("--shuffle_dev", action="store_true")
|
| 339 |
+
parser.add_argument("--shuffle_seed", type=int, default=42)
|
| 340 |
+
args = parser.parse_args()
|
| 341 |
+
|
| 342 |
+
project_root = Path(__file__).resolve().parents[1]
|
| 343 |
+
adapter_dir = project_root / args.adapter
|
| 344 |
+
|
| 345 |
+
db_root = project_root / "data" / "database"
|
| 346 |
+
table_json = project_root / "data" / "tables.json"
|
| 347 |
+
dev_json = project_root / "data" / "dev.json"
|
| 348 |
+
|
| 349 |
+
pred_path = project_root / "temp_predictions.txt"
|
| 350 |
+
temp_gold_path = project_root / "temp_gold.sql"
|
| 351 |
+
|
| 352 |
+
if not adapter_dir.exists():
|
| 353 |
+
raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
|
| 354 |
+
|
| 355 |
+
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
|
| 356 |
+
print(f"Using device: {device}")
|
| 357 |
+
|
| 358 |
+
BASE_MODEL = "Salesforce/codet5-base"
|
| 359 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
| 360 |
+
if tokenizer.pad_token is None:
|
| 361 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 362 |
+
|
| 363 |
+
print(f"Loading Model: {args.adapter}...")
|
| 364 |
+
base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
|
| 365 |
+
model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
|
| 366 |
+
model = model.merge_and_unload()
|
| 367 |
+
model.eval()
|
| 368 |
+
|
| 369 |
+
with dev_json.open() as f:
|
| 370 |
+
dev = json.load(f)
|
| 371 |
+
|
| 372 |
+
if args.shuffle_dev:
|
| 373 |
+
rng = random.Random(args.shuffle_seed)
|
| 374 |
+
rng.shuffle(dev)
|
| 375 |
+
|
| 376 |
+
dev = dev[: args.num_samples]
|
| 377 |
+
total = len(dev)
|
| 378 |
+
|
| 379 |
+
gen_kwargs = dict(
|
| 380 |
+
max_new_tokens=160,
|
| 381 |
+
num_beams=4,
|
| 382 |
+
do_sample=False,
|
| 383 |
+
early_stopping=True,
|
| 384 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 385 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
print(f"\n🚀 Generating and live-tracking {total} samples...\n")
|
| 389 |
+
|
| 390 |
+
em_correct = 0
|
| 391 |
+
ex_correct = 0
|
| 392 |
+
|
| 393 |
+
with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
|
| 394 |
+
for i, ex in enumerate(dev, start=1):
|
| 395 |
+
db_id = ex["db_id"]
|
| 396 |
+
question = ex["question"]
|
| 397 |
+
gold_query = ex["query"]
|
| 398 |
+
db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 399 |
+
|
| 400 |
+
# Generate
|
| 401 |
+
input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
|
| 402 |
+
input_ids = input_ids.unsqueeze(0).to(device)
|
| 403 |
+
attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
|
| 404 |
+
|
| 405 |
+
outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
|
| 406 |
+
pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
| 407 |
+
|
| 408 |
+
# Write to files for official spider eval later
|
| 409 |
+
out_pred.write(f"{pred_sql}\n")
|
| 410 |
+
out_gold.write(f"{gold_query}\t{db_id}\n")
|
| 411 |
+
|
| 412 |
+
# --- LIVE TRACKING CHECKS ---
|
| 413 |
+
if normalize_sql(pred_sql) == normalize_sql(gold_query):
|
| 414 |
+
em_correct += 1
|
| 415 |
+
if check_execution(pred_sql, gold_query, db_path):
|
| 416 |
+
ex_correct += 1
|
| 417 |
+
|
| 418 |
+
# Print progress every 50 loops
|
| 419 |
+
if i % 10 == 0 or i == total:
|
| 420 |
+
print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")
|
| 421 |
+
|
| 422 |
+
print("\nGeneration finished. Running Official Spider Evaluations for final numbers...\n")
|
| 423 |
+
|
| 424 |
+
eval_script = project_root / "spider_eval" / "evaluation.py"
|
| 425 |
+
|
| 426 |
+
# 1. RUN EXACT MATCH EVAL
|
| 427 |
+
cmd_match = [
|
| 428 |
+
sys.executable, str(eval_script),
|
| 429 |
+
"--gold", str(temp_gold_path),
|
| 430 |
+
"--pred", str(pred_path),
|
| 431 |
+
"--etype", "match",
|
| 432 |
+
"--db", str(db_root),
|
| 433 |
+
"--table", str(table_json),
|
| 434 |
+
]
|
| 435 |
+
proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
|
| 436 |
+
exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")
|
| 437 |
+
|
| 438 |
+
# 2. RUN EXECUTION EVAL
|
| 439 |
+
cmd_exec = [
|
| 440 |
+
sys.executable, str(eval_script),
|
| 441 |
+
"--gold", str(temp_gold_path),
|
| 442 |
+
"--pred", str(pred_path),
|
| 443 |
+
"--etype", "exec",
|
| 444 |
+
"--db", str(db_root),
|
| 445 |
+
"--table", str(table_json),
|
| 446 |
+
]
|
| 447 |
+
proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
|
| 448 |
+
exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")
|
| 449 |
+
|
| 450 |
+
print("==========================================")
|
| 451 |
+
print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
|
| 452 |
+
print("==========================================")
|
| 453 |
+
|
| 454 |
+
if exact_acc is not None:
|
| 455 |
+
print(f"Exact Set Match Accuracy : {exact_acc*100:.2f}%")
|
| 456 |
+
else:
|
| 457 |
+
print("Exact Set Match Accuracy : Could not parse output")
|
| 458 |
+
|
| 459 |
+
if exec_acc is not None:
|
| 460 |
+
print(f"Execution Accuracy : {exec_acc*100:.2f}%")
|
| 461 |
+
else:
|
| 462 |
+
print("Execution Accuracy : Could not parse output")
|
| 463 |
+
print("==========================================\n")
|
| 464 |
+
|
| 465 |
+
if __name__ == "__main__":
|
| 466 |
+
main()
|
src/eval_rl_t5.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import sys
|
| 2 |
+
# import os
|
| 3 |
+
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 4 |
+
# import json
|
| 5 |
+
|
| 6 |
+
# import subprocess
|
| 7 |
+
|
| 8 |
+
# import argparse
|
| 9 |
+
# from pathlib import Path
|
| 10 |
+
|
| 11 |
+
# import torch
|
| 12 |
+
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 13 |
+
# from peft import PeftModel
|
| 14 |
+
|
| 15 |
+
# # IMPORTANT: must match training prompt format
|
| 16 |
+
# from prompting import build_prompt
|
| 17 |
+
# from schema_utils import get_schema as get_db_schema
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# def _parse_exec_accuracy(stdout: str):
|
| 21 |
+
# for line in stdout.splitlines():
|
| 22 |
+
# if line.strip().startswith("execution"):
|
| 23 |
+
# parts = line.split()
|
| 24 |
+
# try:
|
| 25 |
+
# return float(parts[-1])
|
| 26 |
+
# except Exception:
|
| 27 |
+
# return None
|
| 28 |
+
# return None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# def main():
|
| 32 |
+
# parser = argparse.ArgumentParser()
|
| 33 |
+
# parser.add_argument("--adapter", type=str, default="checkpoints/best_rlhf_model")
|
| 34 |
+
# parser.add_argument("--num_samples", type=int, default=200)
|
| 35 |
+
# args = parser.parse_args()
|
| 36 |
+
|
| 37 |
+
# project_root = Path(__file__).resolve().parents[1]
|
| 38 |
+
# adapter_dir = project_root / args.adapter
|
| 39 |
+
|
| 40 |
+
# if not adapter_dir.exists():
|
| 41 |
+
# raise FileNotFoundError(f"Adapter not found: {adapter_dir}")
|
| 42 |
+
|
| 43 |
+
# db_root = project_root / "data" / "database"
|
| 44 |
+
# table_json = project_root / "data" / "tables.json"
|
| 45 |
+
# dev_json = project_root / "data" / "dev.json"
|
| 46 |
+
# gold_sql = project_root / "data" / "dev_gold.sql"
|
| 47 |
+
# pred_path = project_root / "predictions_rl.txt"
|
| 48 |
+
|
| 49 |
+
# device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 50 |
+
|
| 51 |
+
# # ---- LOAD MODEL (CodeT5 + LoRA) ----
|
| 52 |
+
# base_model = "Salesforce/codet5-base"
|
| 53 |
+
|
| 54 |
+
# tokenizer = AutoTokenizer.from_pretrained(str(adapter_dir))
|
| 55 |
+
# base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
|
| 56 |
+
# model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
|
| 57 |
+
|
| 58 |
+
# # merge LoRA for faster inference
|
| 59 |
+
# model = model.merge_and_unload()
|
| 60 |
+
# model.eval()
|
| 61 |
+
# model.config.use_cache = True
|
| 62 |
+
|
| 63 |
+
# if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
| 64 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
| 65 |
+
|
| 66 |
+
# # ---- LOAD DATA ----
|
| 67 |
+
# with dev_json.open() as f:
|
| 68 |
+
# dev = json.load(f)
|
| 69 |
+
|
| 70 |
+
# dev = dev[: args.num_samples]
|
| 71 |
+
|
| 72 |
+
# gen_kwargs = dict(
|
| 73 |
+
# max_new_tokens=120,
|
| 74 |
+
# do_sample=False,
|
| 75 |
+
# num_beams=1,
|
| 76 |
+
# pad_token_id=tokenizer.pad_token_id,
|
| 77 |
+
# eos_token_id=tokenizer.eos_token_id,
|
| 78 |
+
# )
|
| 79 |
+
|
| 80 |
+
# print(f"Generating {len(dev)} predictions...")
|
| 81 |
+
|
| 82 |
+
# with pred_path.open("w") as out_f, torch.no_grad():
|
| 83 |
+
# for i, ex in enumerate(dev, start=1):
|
| 84 |
+
# db_id = ex["db_id"]
|
| 85 |
+
# question = ex["question"]
|
| 86 |
+
|
| 87 |
+
# db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 88 |
+
# schema = get_db_schema(str(db_path))
|
| 89 |
+
# prompt = build_prompt(question, schema, use_schema=True)
|
| 90 |
+
|
| 91 |
+
# inputs = tokenizer(
|
| 92 |
+
# prompt,
|
| 93 |
+
# return_tensors="pt",
|
| 94 |
+
# truncation=True,
|
| 95 |
+
# max_length=512
|
| 96 |
+
# ).to(device)
|
| 97 |
+
|
| 98 |
+
# out = model.generate(**inputs, **gen_kwargs)
|
| 99 |
+
# pred_sql = tokenizer.decode(out[0], skip_special_tokens=True).strip()
|
| 100 |
+
|
| 101 |
+
# out_f.write(f"{pred_sql}\t{db_id}\n")
|
| 102 |
+
|
| 103 |
+
# if i % 20 == 0 or i == len(dev):
|
| 104 |
+
# print(f"{i}/{len(dev)} done")
|
| 105 |
+
|
| 106 |
+
# # ---- SPIDER OFFICIAL EVAL ----
|
| 107 |
+
# eval_script = project_root / "spider_eval" / "evaluation.py"
|
| 108 |
+
|
| 109 |
+
# cmd = [
|
| 110 |
+
# sys.executable,
|
| 111 |
+
# str(eval_script),
|
| 112 |
+
# "--gold",
|
| 113 |
+
# str(gold_sql),
|
| 114 |
+
# "--pred",
|
| 115 |
+
# str(pred_path),
|
| 116 |
+
# "--etype",
|
| 117 |
+
# "exec",
|
| 118 |
+
# "--db",
|
| 119 |
+
# str(db_root),
|
| 120 |
+
# "--table",
|
| 121 |
+
# str(table_json),
|
| 122 |
+
# ]
|
| 123 |
+
|
| 124 |
+
# print("\nRunning Spider execution evaluation...\n")
|
| 125 |
+
# proc = subprocess.run(cmd, capture_output=True, text=True)
|
| 126 |
+
|
| 127 |
+
# if proc.returncode != 0:
|
| 128 |
+
# print(proc.stdout)
|
| 129 |
+
# print(proc.stderr)
|
| 130 |
+
# sys.exit(proc.returncode)
|
| 131 |
+
|
| 132 |
+
# print(proc.stdout)
|
| 133 |
+
|
| 134 |
+
# acc = _parse_exec_accuracy(proc.stdout)
|
| 135 |
+
# if acc is not None:
|
| 136 |
+
# print(f"\nFINAL EXECUTION ACCURACY: {acc*100:.2f}%")
|
| 137 |
+
# else:
|
| 138 |
+
# print("Could not parse execution accuracy")
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# if __name__ == "__main__":
|
| 142 |
+
# main()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
import json
|
| 146 |
+
import sqlite3
|
| 147 |
+
import argparse
|
| 148 |
+
import time
|
| 149 |
+
from pathlib import Path
|
| 150 |
+
import torch
|
| 151 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 152 |
+
from peft import PeftModel
|
| 153 |
+
|
| 154 |
+
# ---------------- PROMPT (FIXED TO PERFECTLY MATCH RLHF TRAINING) ----------------
|
| 155 |
+
def build_prompt(question, schema):
|
| 156 |
+
return f"translate English to SQL:\n\nSchema:\n{schema}\n\nQuestion:\n{question}\n\nSQL:"
|
| 157 |
+
|
| 158 |
+
# ---------------- LOAD SCHEMA (FIXED TO MATCH TRAINING FORMAT) ----------------
|
| 159 |
+
def load_schema(db_path):
|
| 160 |
+
conn = sqlite3.connect(db_path)
|
| 161 |
+
cursor = conn.cursor()
|
| 162 |
+
|
| 163 |
+
tables = cursor.execute(
|
| 164 |
+
"SELECT name FROM sqlite_master WHERE type='table';"
|
| 165 |
+
).fetchall()
|
| 166 |
+
|
| 167 |
+
schema = ""
|
| 168 |
+
for (table,) in tables:
|
| 169 |
+
cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
|
| 170 |
+
col_names = [c[1] for c in cols]
|
| 171 |
+
# Space-separated, not newline-separated, just like the RLHF script
|
| 172 |
+
schema += f"{table}({', '.join(col_names)}) "
|
| 173 |
+
|
| 174 |
+
conn.close()
|
| 175 |
+
return schema.strip()
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
|
| 179 |
+
def execution_match(pred_sql, gold_sql, db_path):
|
| 180 |
+
try:
|
| 181 |
+
conn = sqlite3.connect(db_path)
|
| 182 |
+
|
| 183 |
+
# --- 5-SECOND TIMEOUT SO THE SCRIPT DOESN'T HANG ---
|
| 184 |
+
start_time = time.monotonic()
|
| 185 |
+
def timeout_handler():
|
| 186 |
+
return 1 if (time.monotonic() - start_time) > 5.0 else 0
|
| 187 |
+
conn.set_progress_handler(timeout_handler, 10000)
|
| 188 |
+
|
| 189 |
+
cur = conn.cursor()
|
| 190 |
+
|
| 191 |
+
cur.execute(pred_sql)
|
| 192 |
+
pred = cur.fetchall()
|
| 193 |
+
|
| 194 |
+
cur.execute(gold_sql)
|
| 195 |
+
gold = cur.fetchall()
|
| 196 |
+
|
| 197 |
+
conn.close()
|
| 198 |
+
return pred == gold
|
| 199 |
+
|
| 200 |
+
except Exception:
|
| 201 |
+
return False
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# ---------------- MAIN ----------------
|
| 205 |
+
def main():
|
| 206 |
+
parser = argparse.ArgumentParser()
|
| 207 |
+
# 🎯 Set the default directly to your best RLHF model!
|
| 208 |
+
parser.add_argument("--adapter", type=str, default="checkpoints/rlhf_t5_best")
|
| 209 |
+
parser.add_argument("--num_samples", type=int, default=1000)
|
| 210 |
+
args = parser.parse_args()
|
| 211 |
+
|
| 212 |
+
project_root = Path(__file__).resolve().parents[1]
|
| 213 |
+
|
| 214 |
+
# Resolve adapter path safely
|
| 215 |
+
adapter_path = project_root / args.adapter
|
| 216 |
+
|
| 217 |
+
dev_json = project_root / "data" / "dev.json"
|
| 218 |
+
db_root = project_root / "data" / "database"
|
| 219 |
+
|
| 220 |
+
# 🎯 Added CUDA support
|
| 221 |
+
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
|
| 222 |
+
|
| 223 |
+
# load model
|
| 224 |
+
base_model = "t5-small"
|
| 225 |
+
print(f"Loading Base: {base_model}")
|
| 226 |
+
print(f"Loading Adapter: {adapter_path}")
|
| 227 |
+
|
| 228 |
+
tokenizer = AutoTokenizer.from_pretrained(str(adapter_path))
|
| 229 |
+
base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
|
| 230 |
+
model = PeftModel.from_pretrained(base, str(adapter_path)).to(device)
|
| 231 |
+
model = model.merge_and_unload()
|
| 232 |
+
|
| 233 |
+
with open(dev_json) as f:
|
| 234 |
+
dev = json.load(f)[: args.num_samples]
|
| 235 |
+
|
| 236 |
+
correct = 0
|
| 237 |
+
|
| 238 |
+
print(f"Evaluating {len(dev)} examples...\n")
|
| 239 |
+
|
| 240 |
+
for i, ex in enumerate(dev, 1):
|
| 241 |
+
question = ex["question"]
|
| 242 |
+
db_id = ex["db_id"]
|
| 243 |
+
gold_sql = ex["query"]
|
| 244 |
+
|
| 245 |
+
db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 246 |
+
schema = load_schema(db_path)
|
| 247 |
+
|
| 248 |
+
prompt = build_prompt(question, schema)
|
| 249 |
+
|
| 250 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 251 |
+
|
| 252 |
+
with torch.no_grad():
|
| 253 |
+
outputs = model.generate(
|
| 254 |
+
**inputs,
|
| 255 |
+
max_new_tokens=80,
|
| 256 |
+
do_sample=False,
|
| 257 |
+
num_beams=4,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 261 |
+
|
| 262 |
+
if "SQL:" in pred_sql:
|
| 263 |
+
pred_sql = pred_sql.split("SQL:")[-1].strip()
|
| 264 |
+
|
| 265 |
+
match = execution_match(pred_sql, gold_sql, db_path)
|
| 266 |
+
|
| 267 |
+
if match:
|
| 268 |
+
correct += 1
|
| 269 |
+
|
| 270 |
+
if i % 10 == 0:
|
| 271 |
+
print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")
|
| 272 |
+
|
| 273 |
+
print("\n=============================")
|
| 274 |
+
print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
|
| 275 |
+
print("=============================")
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
if __name__ == "__main__":
|
| 279 |
+
main()
|
src/eval_single_model.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import subprocess
|
| 3 |
+
import sys
|
| 4 |
+
import argparse
|
| 5 |
+
import random
|
| 6 |
+
import sqlite3
|
| 7 |
+
import time
|
| 8 |
+
import re
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import numpy as np
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 15 |
+
from peft import PeftModel
|
| 16 |
+
|
| 17 |
+
# Assuming you have a prompting.py that has encode_prompt
|
| 18 |
+
from prompting import encode_prompt
|
| 19 |
+
|
| 20 |
+
# -------------------------------
|
| 21 |
+
# LIVE CHECK HELPERS
|
| 22 |
+
# -------------------------------
|
| 23 |
+
def normalize_sql(sql):
|
| 24 |
+
sql = sql.replace('"', "'")
|
| 25 |
+
sql = re.sub(r"\s+", " ", sql)
|
| 26 |
+
return sql.strip().lower().rstrip(";")
|
| 27 |
+
|
| 28 |
+
def check_execution(pred_sql, gold_sql, db_path):
|
| 29 |
+
try:
|
| 30 |
+
conn = sqlite3.connect(db_path)
|
| 31 |
+
conn.text_factory = lambda b: b.decode(errors='ignore')
|
| 32 |
+
|
| 33 |
+
start_time = time.monotonic()
|
| 34 |
+
def timeout_handler():
|
| 35 |
+
return 1 if (time.monotonic() - start_time) > 2.0 else 0
|
| 36 |
+
conn.set_progress_handler(timeout_handler, 10000)
|
| 37 |
+
|
| 38 |
+
cursor = conn.cursor()
|
| 39 |
+
cursor.execute(pred_sql)
|
| 40 |
+
pred_res = cursor.fetchall()
|
| 41 |
+
|
| 42 |
+
cursor.execute(gold_sql)
|
| 43 |
+
gold_res = cursor.fetchall()
|
| 44 |
+
conn.close()
|
| 45 |
+
|
| 46 |
+
return sorted(pred_res) == sorted(gold_res)
|
| 47 |
+
except Exception:
|
| 48 |
+
return False
|
| 49 |
+
|
| 50 |
+
# -------------------------------
|
| 51 |
+
# SPIDER PARSER
|
| 52 |
+
# -------------------------------
|
| 53 |
+
def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
|
| 54 |
+
for line in stdout.splitlines():
|
| 55 |
+
if metric_type == "exec" and line.strip().startswith("execution"):
|
| 56 |
+
try: return float(line.split()[-1])
|
| 57 |
+
except: pass
|
| 58 |
+
elif metric_type == "match" and line.strip().startswith("exact"):
|
| 59 |
+
try: return float(line.split()[-1])
|
| 60 |
+
except: pass
|
| 61 |
+
return None
|
| 62 |
+
|
| 63 |
+
# -------------------------------
|
| 64 |
+
# MAIN
|
| 65 |
+
# -------------------------------
|
| 66 |
+
def main():
    """Evaluate a PEFT adapter on the Spider dev set and update the
    cross-model comparison plot.

    Generates SQL for the first --num_samples dev questions, live-tracks
    rough exact-match (EM) and execution (EX) accuracy, then runs the
    official Spider evaluator twice (etype=match and etype=exec), stores
    the scores under comparison_plots/all_metrics.json keyed by
    --model_name, and regenerates the grouped bar chart.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your checkpoint")
    parser.add_argument("--base_model", type=str, required=True, help="E.g., facebook/bart-base, t5-small")
    parser.add_argument("--model_name", type=str, required=True, help="Name for the plot label (e.g., 'BART RLHF')")
    parser.add_argument("--num_samples", type=int, default=700)
    args = parser.parse_args()

    # All paths are resolved relative to the project root (parent of src/).
    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data" / "database"
    table_json = project_root / "data" / "tables.json"
    dev_json = project_root / "data" / "dev.json"

    pred_path = project_root / "temp_predictions.txt"
    temp_gold_path = project_root / "temp_gold.sql"

    # NEW: Plot directory setup
    plot_dir = project_root / "comparison_plots"
    plot_dir.mkdir(parents=True, exist_ok=True)
    results_json_path = plot_dir / "all_metrics.json"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading Base Model: {args.base_model} on {device}...")

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load base weights, attach the LoRA/PEFT adapter, then merge the
    # adapter into the base weights for faster inference.
    base = AutoModelForSeq2SeqLM.from_pretrained(args.base_model).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    model = model.merge_and_unload()
    model.eval()

    with dev_json.open() as f:
        dev = json.load(f)[: args.num_samples]
    total = len(dev)

    # Deterministic beam-search decoding.
    gen_kwargs = dict(
        max_new_tokens=160,
        num_beams=4,
        do_sample=False,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"\n🚀 Generating and live-tracking {total} samples...\n")

    em_correct = 0
    ex_correct = 0

    with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
        for i, ex in enumerate(dev, start=1):
            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]
            db_path = db_root / db_id / f"{db_id}.sqlite"

            # Generate
            # encode_prompt is expected to return a 1-D tensor of token ids
            # — TODO confirm against prompting.encode_prompt.
            input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            # Spider eval expects gold lines as "<sql>\t<db_id>".
            out_pred.write(f"{pred_sql}\n")
            out_gold.write(f"{gold_query}\t{db_id}\n")

            # --- PRINT FIRST 3 EXAMPLES ---
            if i <= 3:
                print(f"--- 🔍 Example {i} ---")
                print(f"Q : {question}")
                print(f"Gold: {gold_query}")
                print(f"Pred: {pred_sql}")
                print("-" * 25)

            # --- LIVE TRACKING CHECKS ---
            # Rough running metrics only; the authoritative numbers come
            # from the official evaluator below.
            if normalize_sql(pred_sql) == normalize_sql(gold_query):
                em_correct += 1
            if check_execution(pred_sql, gold_query, db_path):
                ex_correct += 1

            if i % 50 == 0 or i == total:
                print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    print("\nRunning Official Spider Evaluations...")
    eval_script = project_root / "spider_eval" / "evaluation.py"

    # Official exact-set-match evaluation.
    proc_match = subprocess.run([sys.executable, str(eval_script), "--gold", str(temp_gold_path), "--pred", str(pred_path), "--etype", "match", "--db", str(db_root), "--table", str(table_json)], capture_output=True, text=True)
    exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")

    # Official execution-accuracy evaluation.
    proc_exec = subprocess.run([sys.executable, str(eval_script), "--gold", str(temp_gold_path), "--pred", str(pred_path), "--etype", "exec", "--db", str(db_root), "--table", str(table_json)], capture_output=True, text=True)
    exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")

    print("\n==========================================")
    print(f"🎯 RESULTS FOR: {args.model_name}")
    print("==========================================")
    # NOTE(review): a parsed accuracy of exactly 0.0 is indistinguishable
    # from a parse failure here (both print 0.00%).
    exact_val = exact_acc * 100 if exact_acc else 0
    exec_val = exec_acc * 100 if exec_acc else 0
    print(f"Exact Match : {exact_val:.2f}%")
    print(f"Execution : {exec_val:.2f}%")
    print("==========================================\n")

    # -------------------------------
    # SAVE JSON & GENERATE PLOT
    # -------------------------------
    # Merge this run's scores into the accumulated metrics file.
    if results_json_path.exists():
        with open(results_json_path, 'r') as f:
            all_results = json.load(f)
    else:
        all_results = {}

    all_results[args.model_name] = {"EM": exact_val, "EX": exec_val}

    with open(results_json_path, 'w') as f:
        json.dump(all_results, f, indent=4)

    labels = list(all_results.keys())
    em_vals = [all_results[k]["EM"] for k in labels]
    ex_vals = [all_results[k]["EX"] for k in labels]

    x = np.arange(len(labels))
    width = 0.35

    # Grouped bar chart: one EM and one EX bar per model.
    plt.figure(figsize=(max(8, len(labels) * 1.5), 6))
    plt.bar(x - width/2, em_vals, width, label='Exact Match', color='#3498db')
    plt.bar(x + width/2, ex_vals, width, label='Execution', color='#2ecc71')

    plt.ylabel('Accuracy (%)', fontweight='bold')
    plt.title('Model Comparison: Exact Match vs Execution Accuracy', fontweight='bold', fontsize=14)
    plt.xticks(x, labels, rotation=45, ha="right")
    plt.legend()
    plt.ylim(0, max(max(em_vals, default=0), max(ex_vals, default=0)) + 15)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Attach labels to bars
    for i in range(len(labels)):
        plt.text(x[i] - width/2, em_vals[i] + 1, f"{em_vals[i]:.1f}%", ha='center', fontsize=9)
        plt.text(x[i] + width/2, ex_vals[i] + 1, f"{ex_vals[i]:.1f}%", ha='center', fontsize=9)

    plt.tight_layout()
    plot_path = plot_dir / "accuracy_comparison.png"
    plt.savefig(plot_path, dpi=300)
    print(f"📈 Updated comparison plot saved to: {plot_path}")
|
| 216 |
+
|
| 217 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
src/evaluate_model_codet5.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
# import json
|
| 4 |
+
# import subprocess
|
| 5 |
+
# import sys
|
| 6 |
+
# import argparse
|
| 7 |
+
# import sqlite3
|
| 8 |
+
# import random
|
| 9 |
+
# from pathlib import Path
|
| 10 |
+
|
| 11 |
+
# import torch
|
| 12 |
+
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 13 |
+
# from peft import PeftModel
|
| 14 |
+
|
| 15 |
+
# from prompting import encode_prompt
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# def _parse_exec_accuracy(stdout: str) -> float | None:
|
| 19 |
+
# for line in stdout.splitlines():
|
| 20 |
+
# if line.strip().startswith("execution"):
|
| 21 |
+
# try:
|
| 22 |
+
# return float(line.split()[-1])
|
| 23 |
+
# except:
|
| 24 |
+
# return None
|
| 25 |
+
# return None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# def main():
|
| 29 |
+
|
| 30 |
+
# # ---------------- ARGUMENTS ----------------
|
| 31 |
+
# parser = argparse.ArgumentParser()
|
| 32 |
+
# parser.add_argument("--adapter", type=str, default="checkpoints/sft_adapter_codet5")
|
| 33 |
+
# parser.add_argument("--num_samples", type=int, default=1000)
|
| 34 |
+
# parser.add_argument("--shuffle_dev", action="store_true")
|
| 35 |
+
# parser.add_argument("--shuffle_seed", type=int, default=42)
|
| 36 |
+
# parser.add_argument("--accuracy_log", type=str, default="")
|
| 37 |
+
# args = parser.parse_args()
|
| 38 |
+
|
| 39 |
+
# project_root = Path(__file__).resolve().parents[1]
|
| 40 |
+
# adapter_dir = project_root / args.adapter
|
| 41 |
+
|
| 42 |
+
# db_root = project_root / "data" / "database"
|
| 43 |
+
# table_json = project_root / "data" / "tables.json"
|
| 44 |
+
# dev_json = project_root / "data" / "dev.json"
|
| 45 |
+
# gold_sql = project_root / "data" / "dev_gold.sql"
|
| 46 |
+
# pred_path = project_root / "predictions.txt"
|
| 47 |
+
|
| 48 |
+
# if not adapter_dir.exists():
|
| 49 |
+
# raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
|
| 50 |
+
|
| 51 |
+
# # ---------------- DEVICE ----------------
|
| 52 |
+
# device = "mps" if torch.backends.mps.is_available() else (
|
| 53 |
+
# "cuda" if torch.cuda.is_available() else "cpu"
|
| 54 |
+
# )
|
| 55 |
+
# print("Using device:", device)
|
| 56 |
+
|
| 57 |
+
# # ---------------- LOAD MODEL ----------------
|
| 58 |
+
# BASE_MODEL = "Salesforce/codet5-base"
|
| 59 |
+
|
| 60 |
+
# tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
| 61 |
+
|
| 62 |
+
# if tokenizer.pad_token is None:
|
| 63 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
| 64 |
+
|
| 65 |
+
# base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
|
| 66 |
+
# model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
|
| 67 |
+
|
| 68 |
+
# model = model.merge_and_unload()
|
| 69 |
+
# model.eval()
|
| 70 |
+
|
| 71 |
+
# # ---------------- LOAD DATA ----------------
|
| 72 |
+
# with dev_json.open() as f:
|
| 73 |
+
# dev = json.load(f)
|
| 74 |
+
|
| 75 |
+
# if args.shuffle_dev:
|
| 76 |
+
# rng = random.Random(args.shuffle_seed)
|
| 77 |
+
# rng.shuffle(dev)
|
| 78 |
+
|
| 79 |
+
# dev = dev[: args.num_samples]
|
| 80 |
+
|
| 81 |
+
# # ---------------- GENERATION CONFIG ----------------
|
| 82 |
+
# gen_kwargs = dict(
|
| 83 |
+
# max_new_tokens=160,
|
| 84 |
+
# num_beams=4,
|
| 85 |
+
# do_sample=False,
|
| 86 |
+
# early_stopping=True,
|
| 87 |
+
# pad_token_id=tokenizer.pad_token_id,
|
| 88 |
+
# eos_token_id=tokenizer.eos_token_id,
|
| 89 |
+
# )
|
| 90 |
+
|
| 91 |
+
# print("Generating predictions...\n")
|
| 92 |
+
|
| 93 |
+
# correct = 0
|
| 94 |
+
# total = len(dev)
|
| 95 |
+
# accuracy_log_fh = None
|
| 96 |
+
|
| 97 |
+
# if args.accuracy_log:
|
| 98 |
+
# accuracy_log_path = (project_root / args.accuracy_log).resolve()
|
| 99 |
+
# accuracy_log_path.parent.mkdir(parents=True, exist_ok=True)
|
| 100 |
+
# accuracy_log_fh = accuracy_log_path.open("w")
|
| 101 |
+
# print(f"Writing running accuracy log to: {accuracy_log_path}")
|
| 102 |
+
|
| 103 |
+
# with pred_path.open("w") as out_f, torch.no_grad():
|
| 104 |
+
|
| 105 |
+
# for i, ex in enumerate(dev, start=1):
|
| 106 |
+
|
| 107 |
+
# db_id = ex["db_id"]
|
| 108 |
+
# question = ex["question"]
|
| 109 |
+
# gold_query = ex["query"]
|
| 110 |
+
|
| 111 |
+
# input_ids = encode_prompt(
|
| 112 |
+
# tokenizer,
|
| 113 |
+
# question,
|
| 114 |
+
# db_id,
|
| 115 |
+
# device=device,
|
| 116 |
+
# max_input_tokens=512,
|
| 117 |
+
# )
|
| 118 |
+
|
| 119 |
+
# input_ids = input_ids.unsqueeze(0).to(device)
|
| 120 |
+
# attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
|
| 121 |
+
|
| 122 |
+
# outputs = model.generate(
|
| 123 |
+
# input_ids=input_ids,
|
| 124 |
+
# attention_mask=attention_mask,
|
| 125 |
+
# **gen_kwargs
|
| 126 |
+
# )
|
| 127 |
+
|
| 128 |
+
# pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
| 129 |
+
# out_f.write(f"{pred_sql}\t{db_id}\n")
|
| 130 |
+
|
| 131 |
+
# # ---------------- LIVE EXECUTION CHECK ----------------
|
| 132 |
+
# try:
|
| 133 |
+
# db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 134 |
+
|
| 135 |
+
# conn = sqlite3.connect(db_path)
|
| 136 |
+
# cursor = conn.cursor()
|
| 137 |
+
|
| 138 |
+
# cursor.execute(pred_sql)
|
| 139 |
+
# pred_rows = cursor.fetchall()
|
| 140 |
+
|
| 141 |
+
# cursor.execute(gold_query)
|
| 142 |
+
# gold_rows = cursor.fetchall()
|
| 143 |
+
|
| 144 |
+
# conn.close()
|
| 145 |
+
|
| 146 |
+
# if sorted(pred_rows) == sorted(gold_rows):
|
| 147 |
+
# correct += 1
|
| 148 |
+
|
| 149 |
+
# except Exception:
|
| 150 |
+
# pass # execution failed
|
| 151 |
+
|
| 152 |
+
# # 🔥 PRINT EVERY 10
|
| 153 |
+
# if i % 10 == 0 or i == total:
|
| 154 |
+
# current_acc = correct / i
|
| 155 |
+
# line = f"{i}/{total} | Acc: {current_acc:.3f}"
|
| 156 |
+
# print(line)
|
| 157 |
+
# if accuracy_log_fh is not None:
|
| 158 |
+
# accuracy_log_fh.write(line + "\n")
|
| 159 |
+
|
| 160 |
+
# if accuracy_log_fh is not None:
|
| 161 |
+
# accuracy_log_fh.close()
|
| 162 |
+
|
| 163 |
+
# print("\nGeneration finished.\n")
|
| 164 |
+
|
| 165 |
+
# # ---------------- OFFICIAL SPIDER EVAL ----------------
|
| 166 |
+
# eval_script = project_root / "spider_eval" / "evaluation.py"
|
| 167 |
+
|
| 168 |
+
# cmd = [
|
| 169 |
+
# sys.executable,
|
| 170 |
+
# str(eval_script),
|
| 171 |
+
# "--gold", str(gold_sql),
|
| 172 |
+
# "--pred", str(pred_path),
|
| 173 |
+
# "--etype", "exec",
|
| 174 |
+
# "--db", str(db_root),
|
| 175 |
+
# "--table", str(table_json),
|
| 176 |
+
# ]
|
| 177 |
+
|
| 178 |
+
# print("Running Spider evaluation...")
|
| 179 |
+
# proc = subprocess.run(cmd, capture_output=True, text=True)
|
| 180 |
+
|
| 181 |
+
# print(proc.stdout)
|
| 182 |
+
|
| 183 |
+
# exec_acc = _parse_exec_accuracy(proc.stdout)
|
| 184 |
+
# if exec_acc is not None:
|
| 185 |
+
# print(f"\n🎯 Official Execution Accuracy: {exec_acc*100:.2f}%")
|
| 186 |
+
# else:
|
| 187 |
+
# print("Could not parse accuracy.")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# if __name__ == "__main__":
|
| 191 |
+
# main()
|
| 192 |
+
|
| 193 |
+
import json
|
| 194 |
+
import subprocess
|
| 195 |
+
import sys
|
| 196 |
+
import argparse
|
| 197 |
+
import random
|
| 198 |
+
import sqlite3
|
| 199 |
+
import time
|
| 200 |
+
import re
|
| 201 |
+
from pathlib import Path
|
| 202 |
+
|
| 203 |
+
import torch
|
| 204 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 205 |
+
from peft import PeftModel
|
| 206 |
+
|
| 207 |
+
# Assuming you have a prompting.py that has encode_prompt
|
| 208 |
+
from prompting import encode_prompt
|
| 209 |
+
|
| 210 |
+
# -------------------------------
|
| 211 |
+
# LIVE CHECK HELPERS
|
| 212 |
+
# -------------------------------
|
| 213 |
+
def normalize_sql(sql):
    """Loosely canonicalize SQL for the live exact-match counter.

    Unifies double quotes to single quotes, collapses whitespace runs,
    lowercases, and trims trailing semicolons.
    """
    with_single_quotes = sql.replace('"', "'")
    one_spaced = " ".join(with_single_quotes.split())
    return one_spaced.strip().lower().rstrip(";")
|
| 218 |
+
|
| 219 |
+
def check_execution(pred_sql, gold_sql, db_path):
    """Basic execution check for the live progress bar.

    Runs both queries against the same SQLite database and returns True
    iff they produce identical row multisets. Any failure (bad SQL,
    missing DB, timeout, unsortable rows) counts as a mismatch.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Some Spider DBs contain non-UTF8 text; decode leniently.
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # 2-second timeout so the live tracker doesn't freeze forever
        start_time = time.monotonic()
        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 2.0 else 0
        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        # Simple sorted check for the live tracker
        return sorted(pred_res) == sorted(gold_res)
    except Exception:
        return False
    finally:
        # BUGFIX: previously the connection leaked when execution raised
        # (only the success path called conn.close()).
        if conn is not None:
            conn.close()
|
| 243 |
+
|
| 244 |
+
# -------------------------------
|
| 245 |
+
# SPIDER PARSER
|
| 246 |
+
# -------------------------------
|
| 247 |
+
def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
|
| 248 |
+
for line in stdout.splitlines():
|
| 249 |
+
if metric_type == "exec" and line.strip().startswith("execution"):
|
| 250 |
+
try: return float(line.split()[-1])
|
| 251 |
+
except: pass
|
| 252 |
+
elif metric_type == "match" and line.strip().startswith("exact"):
|
| 253 |
+
try: return float(line.split()[-1])
|
| 254 |
+
except: pass
|
| 255 |
+
return None
|
| 256 |
+
|
| 257 |
+
# -------------------------------
|
| 258 |
+
# MAIN
|
| 259 |
+
# -------------------------------
|
| 260 |
+
def main():
    """Evaluate a CodeT5 PEFT adapter on the Spider dev set.

    Generates SQL for --num_samples dev questions (optionally shuffled),
    live-tracks rough EM/EX accuracy, then runs the official Spider
    evaluator twice (etype=match and etype=exec) and prints the final
    official numbers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
    parser.add_argument("--num_samples", type=int, default=1034, help="Number of samples to evaluate")
    parser.add_argument("--shuffle_dev", action="store_true")
    parser.add_argument("--shuffle_seed", type=int, default=42)
    args = parser.parse_args()

    # Paths are resolved relative to the project root (parent of src/).
    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data" / "database"
    table_json = project_root / "data" / "tables.json"
    dev_json = project_root / "data" / "dev.json"

    pred_path = project_root / "temp_predictions.txt"
    temp_gold_path = project_root / "temp_gold.sql"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    BASE_MODEL = "Salesforce/codet5-base"
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load base weights + adapter, then merge the adapter for fast inference.
    print(f"Loading Model: {args.adapter}...")
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    model = model.merge_and_unload()
    model.eval()

    with dev_json.open() as f:
        dev = json.load(f)

    # Optional deterministic shuffle before truncating to num_samples.
    if args.shuffle_dev:
        rng = random.Random(args.shuffle_seed)
        rng.shuffle(dev)

    dev = dev[: args.num_samples]
    total = len(dev)

    # Deterministic beam-search decoding.
    gen_kwargs = dict(
        max_new_tokens=160,
        num_beams=4,
        do_sample=False,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"\n🚀 Generating and live-tracking {total} samples...\n")

    em_correct = 0
    ex_correct = 0

    with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
        for i, ex in enumerate(dev, start=1):
            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]
            db_path = db_root / db_id / f"{db_id}.sqlite"

            # Generate
            # encode_prompt is expected to return a 1-D tensor of token ids
            # — TODO confirm against prompting.encode_prompt.
            input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            # Write to files for official spider eval later
            # (gold lines use the "<sql>\t<db_id>" format Spider expects).
            out_pred.write(f"{pred_sql}\n")
            out_gold.write(f"{gold_query}\t{db_id}\n")

            # --- LIVE TRACKING CHECKS ---
            # Rough running metrics only; official numbers come below.
            if normalize_sql(pred_sql) == normalize_sql(gold_query):
                em_correct += 1
            if check_execution(pred_sql, gold_query, db_path):
                ex_correct += 1

            # Print progress every 50 loops
            if i % 50 == 0 or i == total:
                print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    print("\nGeneration finished. Running Official Spider Evaluations for final numbers...\n")

    eval_script = project_root / "spider_eval" / "evaluation.py"

    # 1. RUN EXACT MATCH EVAL
    cmd_match = [
        sys.executable, str(eval_script),
        "--gold", str(temp_gold_path),
        "--pred", str(pred_path),
        "--etype", "match",
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
    exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")

    # 2. RUN EXECUTION EVAL
    cmd_exec = [
        sys.executable, str(eval_script),
        "--gold", str(temp_gold_path),
        "--pred", str(pred_path),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
    exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")

    print("==========================================")
    print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
    print("==========================================")

    if exact_acc is not None:
        print(f"Exact Set Match Accuracy : {exact_acc*100:.2f}%")
    else:
        print("Exact Set Match Accuracy : Could not parse output")

    if exec_acc is not None:
        print(f"Execution Accuracy : {exec_acc*100:.2f}%")
    else:
        print("Execution Accuracy : Could not parse output")
    print("==========================================\n")
|
| 390 |
+
|
| 391 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
src/evaluate_model_t5_small_sft.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import subprocess
|
| 5 |
+
import sys
|
| 6 |
+
import argparse
|
| 7 |
+
import re
|
| 8 |
+
import sqlite3
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 13 |
+
from peft import PeftModel
|
| 14 |
+
from prompting import encode_prompt
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ---------------- PARSE ACC ----------------
|
| 18 |
+
def _parse_exec_accuracy(stdout: str) -> float | None:
|
| 19 |
+
for line in stdout.splitlines():
|
| 20 |
+
if line.strip().startswith("execution"):
|
| 21 |
+
try:
|
| 22 |
+
return float(line.split()[-1])
|
| 23 |
+
except:
|
| 24 |
+
return None
|
| 25 |
+
return None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------- CLEAN SQL ----------------
|
| 29 |
+
def clean_prediction(pred_sql: str) -> str:
    """Post-process a raw model generation into one clean SQL line.

    Drops any echoed "SQL:" prompt prefix (keeping only the text after
    the last occurrence), normalizes double quotes to single quotes,
    collapses whitespace, and guarantees a trailing semicolon.
    """
    text = pred_sql.strip()

    # rpartition keeps everything after the *last* "SQL:" marker and is a
    # no-op (returns the whole string) when the marker is absent.
    text = text.rpartition("SQL:")[2]

    text = " ".join(text.replace('"', "'").split())

    return text if text.endswith(";") else text + ";"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main():
    """Evaluate a t5-small SFT PEFT adapter on the Spider dev set.

    Generates SQL for the first --num_samples dev questions, writes the
    cleaned predictions to pred.sql, live-tracks a rough execution
    accuracy, then runs the official Spider evaluator (etype=exec)
    against data/dev_gold.sql and prints the official number.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, default="checkpoints/sft_t5")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    # Paths are resolved relative to the project root (parent of src/).
    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data/database"
    table_json = project_root / "data/tables.json"
    dev_json = project_root / "data/dev.json"
    gold_sql = project_root / "data/dev_gold.sql"
    pred_path = project_root / "pred.sql"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    # ---------------- DEVICE ----------------
    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print("Using device:", device)

    # ---------------- LOAD MODEL ----------------
    BASE_MODEL = "t5-small"

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)

    # Attach the adapter, then merge it into the base for fast inference.
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    model = model.merge_and_unload()
    model.eval()

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ---------------- LOAD DATA ----------------
    with dev_json.open() as f:
        dev = json.load(f)[: args.num_samples]

    print("Generating predictions...\n")

    correct = 0
    total = len(dev)

    # ---------------- GENERATE + LIVE EXEC ----------------
    with pred_path.open("w") as out_f, torch.no_grad():

        for i, ex in enumerate(dev, start=1):

            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]

            # encode_prompt is expected to return a 1-D tensor of token
            # ids — TODO confirm against prompting.encode_prompt.
            prompt_ids = encode_prompt(
                tokenizer,
                question,
                db_id,
                device=device,
                max_input_tokens=512,
            )

            input_ids = prompt_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            # Deterministic beam-search decoding.
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=160,
                num_beams=4,
                do_sample=False,
                early_stopping=True,
            )

            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
            pred_sql = clean_prediction(pred_sql)

            out_f.write(pred_sql + "\n")

            # -------- LIVE EXECUTION CHECK --------
            # Rough running metric only; the official evaluator below is
            # authoritative. Any execution error counts as a miss.
            try:
                db_path = db_root / db_id / f"{db_id}.sqlite"

                conn = sqlite3.connect(db_path)
                cursor = conn.cursor()

                cursor.execute(pred_sql)
                pred_rows = cursor.fetchall()

                cursor.execute(gold_query)
                gold_rows = cursor.fetchall()

                conn.close()

                if sorted(pred_rows) == sorted(gold_rows):
                    correct += 1

            except Exception:
                pass  # execution failed

            # 🔥 PRINT EVERY 10
            if i % 10 == 0 or i == total:
                current_acc = correct / i
                print(f"{i}/{total} | Acc: {current_acc:.3f}")

    print("\nGeneration finished.\n")

    # ---------------- SPIDER EVAL ----------------
    eval_script = project_root / "spider_eval/evaluation.py"

    cmd = [
        sys.executable,
        str(eval_script),
        "--gold", str(gold_sql),
        "--pred", str(pred_path),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]

    print("Running Spider evaluation...")
    proc = subprocess.run(cmd, capture_output=True, text=True)

    print(proc.stdout)

    exec_acc = _parse_exec_accuracy(proc.stdout)
    if exec_acc is not None:
        print(f"\n🎯 Official Execution Accuracy: {exec_acc*100:.2f}%")
    else:
        print("Could not parse accuracy.")
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
src/evaluate_rl_bart.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import json
|
| 3 |
+
import sqlite3
|
| 4 |
+
import argparse
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 9 |
+
from peft import PeftModel
|
| 10 |
+
|
| 11 |
+
# ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
def build_prompt(question, schema):
    """Assemble the generation prompt; layout must match the training prompt exactly."""
    sections = [
        "",
        "Database Schema:",
        schema,
        "",
        "Translate English to SQL:",
        question,
        "SQL:",
        "",
    ]
    return "\n".join(sections)
|
| 21 |
+
|
| 22 |
+
# ---------------- LOAD SCHEMA ----------------
def load_schema(db_path):
    """Return a compact textual schema for a SQLite database.

    One line per table in the form ``table(col1, col2, ...)``.

    Fix: the previous version leaked the connection when any query raised;
    the connection is now closed in a ``finally`` block.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema = ""
        for (table,) in tables:
            # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
            cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
            col_names = [c[1] for c in cols]
            schema += f"{table}({', '.join(col_names)})\n"

        return schema
    finally:
        conn.close()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
def execution_match(pred_sql, gold_sql, db_path):
    """Return True iff *pred_sql* and *gold_sql* yield identical rows (order-sensitive).

    Any sqlite error or a query exceeding the 5-second budget counts as a
    non-match.

    Fix: the previous version leaked the connection whenever execution
    raised; the connection is now closed in a ``finally`` block.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)

        # --- 5-SECOND TIMEOUT SO EVALUATION DOESN'T FREEZE ---
        start_time = time.monotonic()

        def timeout_handler():
            # A non-zero return value makes SQLite abort the running statement.
            return 1 if (time.monotonic() - start_time) > 5.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cur = conn.cursor()

        cur.execute(pred_sql)
        pred = cur.fetchall()

        cur.execute(gold_sql)
        gold = cur.fetchall()

        return pred == gold

    except Exception:
        return False
    finally:
        if conn is not None:
            conn.close()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ---------------- MAIN ----------------
def main():
    """Evaluate a LoRA adapter on Spider dev via live execution matching.

    Loads facebook/bart-base, merges the given adapter, generates SQL for
    each dev example with deterministic beam search, and counts predictions
    whose execution result equals the gold query's result.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True)
    # Default covers the whole dev.json slice used by this project.
    parser.add_argument("--num_samples", type=int, default=1034)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]

    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    # 🎯 Added CUDA support for Nvidia GPUs; prefers Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    # load model (base weights + LoRA adapter)
    base_model = "facebook/bart-base"
    print(f"Loading Base: {base_model}")
    print(f"Loading Adapter: {args.adapter}")

    tokenizer = AutoTokenizer.from_pretrained(args.adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, args.adapter).to(device)
    # Fold the LoRA weights into the base model for plain inference.
    model = model.merge_and_unload()

    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    correct = 0

    print(f"Evaluating {len(dev)} examples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]

        # Spider layout: data/database/<db_id>/<db_id>.sqlite
        db_path = db_root / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)

        # Prompt layout must mirror training (see build_prompt).
        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            # Deterministic beam search; 80 new tokens covers typical Spider SQL.
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                do_sample=False,
                num_beams=4,
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # If the model echoed the prompt, keep only the part after "SQL:".
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        match = execution_match(pred_sql, gold_sql, db_path)

        if match:
            correct += 1

        # Progress report every 10 examples.
        if i % 10 == 0:
            print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")

    print("\n=============================")
    print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
    print("=============================")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# Run the evaluation when this file is invoked as a script.
if __name__ == "__main__":
    main()
|
src/evaluate_sft_bart.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import subprocess
|
| 5 |
+
import sys
|
| 6 |
+
import argparse
|
| 7 |
+
import re
|
| 8 |
+
import sqlite3
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 13 |
+
from peft import PeftModel
|
| 14 |
+
from prompting import encode_prompt
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ---------------- SQL CLEAN ----------------
def extract_sql(text: str) -> str:
    """Normalize raw model output into one runnable SQL string.

    Keeps everything after the last "SQL:" marker, trims any prefix before
    the first SELECT, converts double quotes to single quotes, collapses
    whitespace, and guarantees a trailing semicolon.
    """
    cleaned = text.strip()

    # Drop any prompt echo preceding the final "SQL:" marker.
    _, marker, tail = cleaned.rpartition("SQL:")
    if marker:
        cleaned = tail

    sel = re.search(r"(SELECT .*?)(?:$)", cleaned, re.IGNORECASE | re.DOTALL)
    if sel:
        cleaned = sel.group(1)

    # Single-quote literals, then collapse runs of whitespace.
    cleaned = re.sub(r"\s+", " ", cleaned.replace('"', "'")).strip()

    return cleaned if cleaned.endswith(";") else cleaned + ";"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ---------------- ROBUST ACC PARSER ----------------
def parse_exec_accuracy(stdout: str):
    """Scan eval output for the first "execution" line carrying a decimal number.

    Returns the last decimal on that line as float, or None if no line matches.
    """
    for row in stdout.splitlines():
        if "execution" not in row.lower():
            continue
        decimals = re.findall(r"\d+\.\d+", row)
        if decimals:
            return float(decimals[-1])
    return None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def main():
    """Evaluate an SFT BART LoRA adapter on Spider dev.

    Generates predictions to pred.sql (tab-separated with db_id), tracks a
    live execution-match accuracy while generating, then shells out to the
    official Spider evaluation script for the authoritative number.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, default="checkpoints/sft_best_bart_2")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Adapter not found: {adapter_dir}")

    db_root = project_root / "data/database"
    table_json = project_root / "data/tables.json"
    dev_json = project_root / "data/dev.json"
    gold_sql_file = project_root / "data/dev_gold.sql"
    pred_sql_file = project_root / "pred.sql"

    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print("Using device:", device)

    # -------- LOAD MODEL --------
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir)

    BASE_MODEL = "facebook/bart-base"
    print(f"Loading base model {BASE_MODEL}...")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)

    print("Loading LoRA adapter...")
    model = PeftModel.from_pretrained(base_model, adapter_dir).to(device)
    # Fold the LoRA weights into the base model for plain inference.
    model = model.merge_and_unload()
    model.eval()

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # -------- LOAD DATA --------
    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    print("Generating SQL predictions...\n")

    correct = 0
    total = len(dev)

    with open(pred_sql_file, "w") as f, torch.no_grad():

        for i, ex in enumerate(dev, 1):

            question = ex["question"]
            db_id = ex["db_id"]
            gold_query = ex["query"]

            # encode_prompt (project helper) builds the schema-aware prompt ids.
            prompt_ids = encode_prompt(
                tokenizer,
                question,
                db_id,
                device=device,
                max_input_tokens=512,
            )

            input_ids = prompt_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            # Deterministic beam search.
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=160,
                num_beams=4,
                do_sample=False,
            )

            pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
            pred_sql = extract_sql(pred)

            # Spider pred format: SQL<TAB>db_id
            f.write(f"{pred_sql}\t{db_id}\n")

            # -------- LIVE EXECUTION CHECK --------
            try:
                db_path = db_root / db_id / f"{db_id}.sqlite"

                conn = sqlite3.connect(db_path)
                cursor = conn.cursor()

                cursor.execute(pred_sql)
                pred_rows = cursor.fetchall()

                cursor.execute(gold_query)
                gold_rows = cursor.fetchall()

                conn.close()

                # order insensitive comparison
                if sorted(pred_rows) == sorted(gold_rows):
                    correct += 1

            except Exception:
                pass  # execution failed — counted as a miss

            if i % 10 == 0 or i == total:
                current_acc = correct / i
                print(f"{i}/{total} | Acc: {current_acc:.3f}")

    print("\nGeneration finished.\n")

    # -------- RUN OFFICIAL SPIDER EVAL --------
    # Prefer the BART-specific evaluation script if present.
    eval_script = project_root / "spider_eval/evaluation.py"
    if (project_root / "spider_eval/evaluation_bart.py").exists():
        eval_script = project_root / "spider_eval/evaluation_bart.py"

    cmd = [
        sys.executable,
        str(eval_script),
        "--gold", str(gold_sql_file),
        "--pred", str(pred_sql_file),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]

    print(f"\nRunning Spider evaluation using {eval_script.name}...")
    proc = subprocess.run(cmd, capture_output=True, text=True, errors="ignore")

    if proc.returncode != 0:
        print("\nSpider evaluation crashed.")
        print(proc.stderr)
        return

    print("\n--- Spider Eval Output ---")
    # Only the tail of the eval output carries the summary table.
    print("\n".join(proc.stdout.splitlines()[-20:]))

    acc = parse_exec_accuracy(proc.stdout)
    if acc is not None:
        print(f"\n🎯 Official Execution Accuracy: {acc*100:.2f}%")
    else:
        print("\nCould not parse official accuracy.")
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# Run the evaluation when this file is invoked as a script.
if __name__ == "__main__":
    main()
|
src/execution_reward.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import sqlite3
|
| 6 |
+
import time
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import List, Optional, Sequence, Set, Tuple, Union
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import sqlparse
|
| 12 |
+
from sqlparse.sql import Function, Identifier, IdentifierList, Statement, Token, Where
|
| 13 |
+
from sqlparse.tokens import DML, Keyword, Name, Number, Punctuation, String, Whitespace
|
| 14 |
+
except Exception: # pragma: no cover
|
| 15 |
+
sqlparse = None # type: ignore[assignment]
|
| 16 |
+
Statement = object # type: ignore[misc,assignment]
|
| 17 |
+
Token = object # type: ignore[misc,assignment]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _normalize_sql(sql: str) -> str:
|
| 21 |
+
if not isinstance(sql, str):
|
| 22 |
+
return ""
|
| 23 |
+
s = sql.strip()
|
| 24 |
+
if s.startswith("```"):
|
| 25 |
+
# Strip markdown fences if present.
|
| 26 |
+
s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
|
| 27 |
+
s = re.sub(r"\n?```$", "", s).strip()
|
| 28 |
+
if s.lower().startswith("sql:"):
|
| 29 |
+
s = s[4:].strip()
|
| 30 |
+
# Keep only the first statement to avoid accidental multi-statement execution.
|
| 31 |
+
if ";" in s:
|
| 32 |
+
s = s.split(";", 1)[0].strip()
|
| 33 |
+
return s
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _connect_readonly(db_path: str) -> sqlite3.Connection:
|
| 37 |
+
# Read-only prevents any accidental mutation during reward computation.
|
| 38 |
+
# Note: requires SQLite URI support (built-in).
|
| 39 |
+
uri = f"file:{os.path.abspath(db_path)}?mode=ro"
|
| 40 |
+
conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
|
| 41 |
+
conn.execute("PRAGMA query_only = ON;")
|
| 42 |
+
conn.execute("PRAGMA foreign_keys = ON;")
|
| 43 |
+
return conn
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _with_timeout(conn: sqlite3.Connection, timeout_s: float = 1.0) -> None:
|
| 47 |
+
start = time.monotonic()
|
| 48 |
+
|
| 49 |
+
def _handler() -> int:
|
| 50 |
+
return 1 if (time.monotonic() - start) > timeout_s else 0
|
| 51 |
+
|
| 52 |
+
# Call handler every N VM opcodes.
|
| 53 |
+
conn.set_progress_handler(_handler, 10_000)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _list_tables(conn: sqlite3.Connection) -> List[str]:
|
| 57 |
+
try:
|
| 58 |
+
cur = conn.execute(
|
| 59 |
+
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
|
| 60 |
+
)
|
| 61 |
+
return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
|
| 62 |
+
except sqlite3.Error:
|
| 63 |
+
return []
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool:
|
| 67 |
+
s = sql.lower()
|
| 68 |
+
for t in table_names:
|
| 69 |
+
tl = t.lower()
|
| 70 |
+
if not tl:
|
| 71 |
+
continue
|
| 72 |
+
if re.search(rf"\b{re.escape(tl)}\b", s):
|
| 73 |
+
return True
|
| 74 |
+
return False
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
    """Return True iff SQLite can plan *sql* — a cheap validity probe without fetching rows."""
    try:
        _with_timeout(conn, timeout_s=1.0)
        conn.execute(f"EXPLAIN QUERY PLAN {sql}")
    except sqlite3.Error:
        return False
    return True
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _execute(conn: sqlite3.Connection, sql: str, max_rows: int = 1000) -> Tuple[bool, List[Tuple], Optional[str]]:
    """Run *sql* under a 1s budget; return (ok, up to *max_rows* rows as plain tuples, error or None)."""
    try:
        _with_timeout(conn, timeout_s=1.0)
        fetched = conn.execute(sql).fetchmany(max_rows)
    except sqlite3.Error as err:
        return False, [], str(err)
    # Normalize to plain tuples for deterministic comparison.
    return True, [tuple(row) for row in fetched], None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Tokens treated as SQL syntax rather than identifiers when extracting
# column/table names (see extract_columns / extract_tables fallbacks).
_SQL_KEYWORDS_TO_IGNORE = {
    "select",
    "from",
    "where",
    "join",
    "inner",
    "left",
    "right",
    "full",
    "outer",
    "on",
    "group",
    "by",
    "order",
    "limit",
    "having",
    "distinct",
    "union",
    "intersect",
    "except",
    "as",
    "and",
    "or",
    "not",
    "in",
    "is",
    "null",
    "like",
    "between",
    "case",
    "when",
    "then",
    "else",
    "end",
    "asc",
    "desc",
}

# Built-in SQL function names, likewise excluded from column extraction.
_SQL_FUNCTIONS_TO_IGNORE = {
    "count",
    "avg",
    "min",
    "max",
    "sum",
    "lower",
    "upper",
    "substr",
    "coalesce",
    "round",
    "date",
    "datetime",
    "strftime",
}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def extract_tables(sql: str) -> Set[str]:
    """
    Best-effort table extraction from SQL using sqlparse.
    Returns lowercase table names (unqualified).

    Falls back to a naive FROM/JOIN regex when sqlparse is unavailable
    (its import is guarded at module top); returns an empty set on any
    parse failure.
    """
    sql = _normalize_sql(sql)
    if not sql:
        return set()
    if sqlparse is None:
        # Fallback: naive regex for FROM/JOIN.
        found = set()
        for m in re.finditer(r"\b(from|join)\s+([a-zA-Z_][\w$]*)", sql, flags=re.I):
            found.add(m.group(2).lower())
        return found

    try:
        statements = sqlparse.parse(sql)
    except Exception:
        return set()

    tables: Set[str] = set()

    def _add_identifier_as_table(ident: Identifier) -> None:
        # Prefer real name over alias; strip any schema prefix.
        name = ident.get_real_name() or ident.get_name()
        if not name:
            return
        tables.add(name.lower())

    for st in statements:
        if not isinstance(st, Statement):
            continue
        # State machine over flattened tokens: after FROM/JOIN, the next
        # name-like token is taken as a table name.
        seen_from = False
        for tok in st.flatten():
            if tok.ttype in Whitespace:
                continue
            if tok.ttype is Keyword and tok.value.upper() in {"FROM", "JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN", "FULL JOIN"}:
                seen_from = True
                continue
            if not seen_from:
                continue

            # NOTE(review): st.flatten() yields leaf tokens, so the
            # Identifier branch below may rarely fire in practice — the
            # `ttype is Name` branch typically catches table names; verify.
            if isinstance(tok, Identifier):
                _add_identifier_as_table(tok)
                seen_from = False
            elif tok.ttype is Name:
                tables.add(tok.value.lower())
                seen_from = False
            elif tok.ttype is Keyword and tok.value.upper() in {"WHERE", "GROUP", "ORDER", "HAVING", "LIMIT"}:
                # Clause keyword ends the FROM region without a table name.
                seen_from = False

    return tables
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def extract_columns(sql: str) -> Set[str]:
    """
    Best-effort column extraction from SQL using sqlparse.
    Returns lowercase column names (unqualified).

    Keywords and common SQL function names are filtered out; falls back to
    a naive word scan when sqlparse is unavailable, and returns an empty
    set on any parse failure.
    """
    sql = _normalize_sql(sql)
    if not sql:
        return set()
    if sqlparse is None:
        # Fallback: naive dotted identifiers and bare names after SELECT/WHERE/etc.
        cols = set()
        for m in re.finditer(r"\b([a-zA-Z_][\w$]*)\b", sql):
            w = m.group(1).lower()
            if w in _SQL_KEYWORDS_TO_IGNORE or w in _SQL_FUNCTIONS_TO_IGNORE:
                continue
            cols.add(w)
        return cols

    try:
        statements = sqlparse.parse(sql)
    except Exception:
        return set()

    cols: Set[str] = set()

    def _maybe_add_col(name: Optional[str]) -> None:
        # Filter quotes, wildcards, keywords and function names before adding.
        if not name:
            return
        n = name.strip().strip('"').strip("'").lower()
        if not n or n == "*":
            return
        if n in _SQL_KEYWORDS_TO_IGNORE or n in _SQL_FUNCTIONS_TO_IGNORE:
            return
        cols.add(n)

    def _handle_identifier(ident: Identifier) -> None:
        # If qualified (t.col), keep only col for overlap/hallucination checks.
        _maybe_add_col(ident.get_real_name() or ident.get_name())

    for st in statements:
        if not isinstance(st, Statement):
            continue
        for tok in st.flatten():
            # Skip whitespace/punctuation/string literals/numbers.
            if getattr(tok, "ttype", None) in (Whitespace, Punctuation, String, Number):
                continue

            if isinstance(tok, Function):
                fname = tok.get_name()
                if fname:
                    # Don't treat function name as a column.
                    pass
                continue

            if isinstance(tok, IdentifierList):
                # Comma-separated projection lists, e.g. "a, b, c".
                for ident in tok.get_identifiers():
                    if isinstance(ident, Identifier):
                        _handle_identifier(ident)
                continue

            if isinstance(tok, Identifier):
                _handle_identifier(tok)
                continue

            # NOTE(review): with flatten(), most names arrive here as bare
            # Name tokens, so table names may also be collected as "columns";
            # acceptable for the overlap heuristic — verify if precision matters.
            if getattr(tok, "ttype", None) is Name:
                _maybe_add_col(tok.value)

    return cols
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]:
    """
    Return (tables, columns) sets from the SQLite schema, all lowercased.
    Columns are collected into one global unqualified set.
    """
    table_names: Set[str] = set()
    column_names: Set[str] = set()
    for raw_name in _list_tables(conn):
        lowered = raw_name.lower()
        if not lowered:
            continue
        table_names.add(lowered)
        try:
            info_rows = conn.execute(f'PRAGMA table_info("{raw_name}")').fetchall()
        except sqlite3.Error:
            # Skip tables whose metadata cannot be read.
            continue
        column_names.update(
            row[1].lower() for row in info_rows if row and isinstance(row[1], str)
        )
    return table_names, column_names
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool:
|
| 300 |
+
# Deterministic comparison: compare exact row tuples in order.
|
| 301 |
+
return a == b
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@dataclass
class RewardDebugStats:
    """Mutable counters for reward diagnostics; rates are derived in get_debug_metrics()."""

    total: int = 0         # denominator: number of predictions scored
    parsed_ok: int = 0     # feeds valid_sql_rate
    table_match: int = 0   # feeds table_match_rate
    column_match: int = 0  # feeds column_match_rate
    executed_ok: int = 0   # not currently reported by get_debug_metrics
    exact_match: int = 0   # feeds execution_accuracy
|
| 313 |
+
|
| 314 |
+
# Module-level accumulator read by get_debug_metrics() and rebound by reset_debug_metrics().
_DEBUG = RewardDebugStats()
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def reset_debug_metrics() -> None:
    """Zero all debug counters by rebinding the module-level _DEBUG instance."""
    global _DEBUG
    _DEBUG = RewardDebugStats()
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def get_debug_metrics() -> dict:
    """Return debug counters as rates; denominator floored at 1 to avoid div-by-zero.

    NOTE(review): nothing in this module increments _DEBUG (execution_reward
    does not touch it), so callers are expected to update the counters —
    verify against the training scripts.
    """
    denom = max(_DEBUG.total, 1)
    return {
        "valid_sql_rate": _DEBUG.parsed_ok / denom,
        "table_match_rate": _DEBUG.table_match / denom,
        "column_match_rate": _DEBUG.column_match / denom,
        "execution_accuracy": _DEBUG.exact_match / denom,
    }
|
| 330 |
+
|
| 331 |
+
# Sentinel returned by execute_sql() so callers can distinguish a failed
# query from one that legitimately returned zero rows.
EXECUTION_ERROR = "EXECUTION_ERROR"
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def execute_sql(conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
    """
    Execute SQL safely under a 1-second budget.

    If sqlite raises ANY exception, return EXECUTION_ERROR (NOT an empty
    list), so "failed" is distinguishable from "ran and returned no rows".
    """
    try:
        _with_timeout(conn, timeout_s=1.0)
        fetched = conn.execute(sql).fetchmany(max_rows)
    except Exception:
        return EXECUTION_ERROR
    # Plain tuples give deterministic comparisons downstream.
    return [tuple(row) for row in fetched]
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def _sqlparse_valid_select(sql: str) -> bool:
    """
    Parse validation using sqlparse:
    - parse() non-empty
    - contains a SELECT statement

    Returns False when sqlparse is unavailable (its import is guarded at
    module top) or when parsing/classification raises.
    """
    if sqlparse is None:
        return False
    try:
        stmts = sqlparse.parse(sql)
        if not stmts:
            return False
        for st in stmts:
            try:
                # get_type() classifies a statement by its leading DML keyword.
                if hasattr(st, "get_type") and st.get_type() == "SELECT":
                    return True
            except Exception:
                continue
        return False
    except Exception:
        return False
|
| 370 |
+
|
| 371 |
+
def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
    """Shaped reward in [-1.0, 1.0] for a predicted SQL query.

    Scoring:
      -1.0 : unparseable / non-SELECT prediction, or any unexpected error
      -0.2 : valid-SELECT baseline, then additively:
      +0.3 : predicted table set exactly equals the gold table set
      +0.3 * fraction of gold columns present in the prediction
      +0.2 : prediction executes without error
       1.0 : execution result exactly matches the gold query's result
    """
    try:
        sql = _normalize_sql(pred_sql)
        gold = _normalize_sql(gold_sql)

        # Fast rejection: prediction must at least mention SELECT.
        if not sql or "SELECT" not in sql.upper():
            return -1.0

        if not _sqlparse_valid_select(sql):
            return -1.0

        reward = -0.2  # valid SQL baseline

        pred_tables = extract_tables(sql)
        gold_tables = extract_tables(gold)

        if pred_tables == gold_tables and len(gold_tables) > 0:
            reward += 0.3

        pred_cols = extract_columns(sql)
        gold_cols = extract_columns(gold)

        if gold_cols:
            overlap = len(pred_cols & gold_cols) / len(gold_cols)
            reward += 0.3 * overlap

        # BUG FIX: sqlite3.Connection's context manager only manages the
        # transaction — it does NOT close the connection.  Close explicitly
        # so repeated reward calls don't leak file handles.
        conn = _connect_readonly(db_path)
        try:
            pred_res = execute_sql(conn, sql)
            if pred_res != EXECUTION_ERROR:
                reward += 0.2

            gold_res = execute_sql(conn, gold)
            if pred_res != EXECUTION_ERROR and _safe_results_equal(pred_res, gold_res):
                return 1.0
        finally:
            conn.close()

        return max(-1.0, min(1.0, reward))

    except Exception:
        return -1.0
|
src/generate_sql.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
import os

import torch
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel

from prompting import encode_prompt


def main():
    """Generate one SQL query for a --question/--db_id pair using the RLHF LoRA model.

    Loads the base seq2seq model, applies (and merges) the PEFT adapter, builds
    the canonical prompt via prompting.encode_prompt, and prints the decoded SQL.
    """
    parser = argparse.ArgumentParser(description="Generate SQL from a question + db_id using the RLHF model.")
    parser.add_argument("--question", type=str, required=True)
    parser.add_argument("--db_id", type=str, required=True)
    parser.add_argument("--model_dir", type=str, default=None, help="Defaults to outputs/rlhf_text2sql")
    parser.add_argument("--use_schema", action="store_true", help="Include schema in the prompt (must match training).")
    parser.add_argument("--max_schema_chars", type=int, default=1500)
    parser.add_argument("--max_new_tokens", type=int, default=80)
    args = parser.parse_args()

    # FIX: add a CUDA fallback for consistency with human_eval_runner.py
    # (previously only MPS/CPU were considered).
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    adapter_dir = args.model_dir or os.path.join(project_root, "outputs", "rlhf_text2sql")
    base_model = os.environ.get("BASE_MODEL", "t5-small")
    fallback_base_model = os.path.join(project_root, "models", "t5_spider_sft")
    if not os.path.isdir(base_model) and os.path.isdir(fallback_base_model):
        base_model = fallback_base_model

    # NOTE(review): local_files_only is True when base_model is a hub id
    # (not a local directory) — this relies on the HF cache being populated;
    # confirm that offline-only loading is intended.
    local_only = not os.path.isdir(base_model)
    tokenizer_source = adapter_dir if os.path.isdir(adapter_dir) else base_model
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, local_files_only=not os.path.isdir(tokenizer_source))
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model, local_files_only=local_only).to(device)
    model = PeftModel.from_pretrained(base, adapter_dir).to(device)
    # Merge adapters for faster/stabler generation.
    model = model.merge_and_unload()
    model.config.use_cache = False

    # T5-family tokenizers normally have a pad token; this guards the ones that don't.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    input_ids = encode_prompt(
        tokenizer,
        args.question,
        args.db_id,
        device=device,
        max_input_tokens=512,
    )

    # Greedy decoding. FIX: dropped early_stopping=True — it only applies to
    # beam search (num_beams > 1) and otherwise emits a transformers warning.
    gen_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        do_sample=False,
        num_beams=1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    with torch.no_grad():
        out = model.generate(input_ids=input_ids.unsqueeze(0), **gen_kwargs)

    sql = tokenizer.decode(out[0], skip_special_tokens=True).strip()
    print(sql)


if __name__ == "__main__":
    main()
|
src/human_eval_runner.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sqlite3
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 6 |
+
from peft import PeftModel
|
| 7 |
+
|
| 8 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 9 |
+
DB_ROOT = PROJECT_ROOT / "data" / "database"
|
| 10 |
+
|
| 11 |
+
# Added CUDA fallback for consistency
|
| 12 |
+
DEVICE = "mps" if torch.backends.mps.is_available() else (
|
| 13 |
+
"cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
# ================= LOAD MODEL =================
|
| 17 |
+
def load_model(adapter_path):
    """Load CodeT5-base with a PEFT adapter located at *adapter_path*.

    *adapter_path* is interpreted relative to PROJECT_ROOT. Returns a
    ``(tokenizer, model)`` pair with the model moved to DEVICE and set to
    eval mode. Raises FileNotFoundError when the adapter folder is missing.
    """
    base_name = "Salesforce/codet5-base"

    # 🐛 FIXED: Convert relative path to absolute path to prevent Hugging Face 404 errors
    abs_path = (PROJECT_ROOT / adapter_path).resolve()
    if not abs_path.exists():
        raise FileNotFoundError(f"Adapter not found at: {abs_path}")

    print(f"\nLoading model from: {abs_path}")

    # 🐛 FIXED: Added fallback in case tokenizer isn't saved in the adapter folder
    # (some training scripts only save adapter weights, not tokenizer files).
    try:
        tokenizer = AutoTokenizer.from_pretrained(str(abs_path), local_files_only=True)
    except Exception:
        print("Adapter tokenizer missing — using base tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(base_name)

    # Base model is always pulled from the hub/cache; adapter weights come from disk.
    base = AutoModelForSeq2SeqLM.from_pretrained(base_name).to(DEVICE)
    model = PeftModel.from_pretrained(base, str(abs_path)).to(DEVICE)
    model.eval()

    return tokenizer, model
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ================= SCHEMA =================
|
| 42 |
+
def load_schema(db_id):
    """Return a textual schema for a Spider database: one "table(col1, col2)" line per table.

    Fixes vs. original: the table identifier is quoted in PRAGMA (so tables
    named after SQL keywords or containing odd characters no longer break),
    and the connection is closed even if a query raises.
    """
    db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema = ""
        for (table,) in tables:
            # Quoted identifier: PRAGMA table_info(order) would otherwise be a syntax error.
            cols = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
            col_names = [c[1] for c in cols]
            schema += f"{table}({', '.join(col_names)})\n"
    finally:
        conn.close()
    return schema
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ================= GENERATE =================
|
| 62 |
+
def generate_sql(tokenizer, model, question, db_id):
    """Generate a SQL string for *question* against Spider database *db_id*.

    Builds a schema-aware prompt, runs deterministic beam-search decoding on
    DEVICE, and strips any echoed "SQL:" prefix from the decoded text.
    """
    schema = load_schema(db_id)

    prompt = f"""
Database Schema:
{schema}

Translate English to SQL:
{question}
SQL:
"""

    # FIX: truncate to the encoder's limit — large schemas previously produced
    # inputs longer than the model's maximum position embeddings.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(DEVICE)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=120,
            num_beams=4,
            do_sample=False
        )

    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Some checkpoints echo the prompt's trailing "SQL:" marker; keep only the tail.
    if "SQL:" in sql:
        sql = sql.split("SQL:")[-1]

    return sql.strip()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ================= EXECUTE =================
|
| 93 |
+
def try_execute(sql, db_id):
    """Return True iff *sql* executes without error on Spider database *db_id*.

    Fixes vs. original: the bare ``except:`` is narrowed to ``except Exception``
    (no longer swallows KeyboardInterrupt/SystemExit), and the connection is
    closed on the failure path as well (it previously leaked on error).
    """
    db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        cur.execute(sql)
        cur.fetchall()
        return True
    except Exception:
        return False
    finally:
        if conn is not None:
            conn.close()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ================= MAIN =================
|
| 107 |
+
def main():
    """Run the execution-success human evaluation, comparing SFT vs RLHF adapters.

    For every entry in data/human_eval.json, generate SQL with each adapter,
    execute it against the corresponding Spider database, and report the
    per-model success rate (execution without error).
    """
    # paths (change if needed)
    SFT_MODEL = "checkpoints/sft_adapter_codet5"  # Ensure this matches your actual SFT folder name!
    RLHF_MODEL = "checkpoints/best_rlhf_model"

    tokenizer_sft, model_sft = load_model(SFT_MODEL)
    tokenizer_rl, model_rl = load_model(RLHF_MODEL)

    human_eval_path = PROJECT_ROOT / "data/human_eval.json"
    with open(human_eval_path) as f:
        questions = json.load(f)

    sft_success = 0
    rl_success = 0

    print("\nRunning Human Evaluation...\n")

    for i, q in enumerate(questions, 1):
        db = q["db_id"]
        question = q["question"]

        # Same question through both models so the comparison is paired.
        sql_sft = generate_sql(tokenizer_sft, model_sft, question, db)
        sql_rl = generate_sql(tokenizer_rl, model_rl, question, db)

        # "Success" here means the SQL executes without raising, not correctness.
        ok_sft = try_execute(sql_sft, db)
        ok_rl = try_execute(sql_rl, db)

        if ok_sft:
            sft_success += 1
        if ok_rl:
            rl_success += 1

        print(f"\nQ{i}: {question}")
        print(f"SFT : {'OK' if ok_sft else 'FAIL'}")
        print(f"RLHF: {'OK' if ok_rl else 'FAIL'}")

    print("\n=============================")
    print("HUMAN EVALUATION RESULT")
    print("=============================")
    # NOTE(review): raises ZeroDivisionError if human_eval.json is empty —
    # confirm the eval file is always non-empty.
    print(f"SFT Success: {sft_success}/{len(questions)} = {sft_success/len(questions)*100:.2f}%")
    print(f"RLHF Success: {rl_success}/{len(questions)} = {rl_success/len(questions)*100:.2f}%")
    print("=============================\n")


if __name__ == "__main__":
    main()
|
src/load_lora_model.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import LoraConfig, get_peft_model, TaskType

# Device selection: Apple MPS if available, otherwise CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

MODEL_PATH = "../outputs/model"  # your supervised trained model

print("Loading base model...")
model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH).to(device)

# Tokenizer comes from the vanilla t5-small vocab, not the fine-tuned folder.
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# ---------------- LoRA CONFIG ----------------
# NOTE(review): target_modules ["q", "v"] matches T5's attention projection
# module names — confirm against the checkpoint if it is not vanilla T5.
lora_config = LoraConfig(
    r=8,  # rank (small brain attachment)
    lora_alpha=16,
    target_modules=["q", "v"],  # attention matrices only
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)

print("Attaching LoRA adapters...")
model = get_peft_model(model, lora_config)

# Prints the trainable/total parameter counts for a quick sanity check.
model.print_trainable_parameters()

print("READY ✔ LoRA model loaded")
|
| 30 |
+
|
src/make_rl_dataset.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
import os

from datasets import load_dataset

# Export the Spider training split to a flat JSON file for the RL pipeline.
print("Loading Spider dataset...")
dataset = load_dataset("spider", split="train")

# Keep only the three fields the RL pipeline consumes.
data = [
    {
        "question": ex["question"],
        "query": ex["query"],
        "db_id": ex["db_id"],  # ⭐ CRITICAL FIELD
    }
    for ex in dataset
]

print("Saving JSON...")
# FIX: create the output directory so the script works from a fresh checkout.
os.makedirs("data", exist_ok=True)
with open("data/train_spider.json", "w") as f:
    json.dump(data, f, indent=2)

print("Done! File saved at data/train_spider.json")
|
src/manual_check.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel

BASE_MODEL = "Salesforce/codet5-base"
ADAPTER = "checkpoints/sft_adapter"  # change if needed

# Apple MPS if available, otherwise CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

print("Loading model...")

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
model = PeftModel.from_pretrained(model, ADAPTER)

model = model.to(device)
model.eval()

# 5 random Spider style questions
questions = [
    "List all employee names",
    "Find the number of students in each department",
    "Show the average salary of employees",
    "Which flights depart from LA?",
    "Find customers who bought more than 5 items"
]

# Spot-check the SFT adapter's raw generations on schema-free prompts.
for q in questions:
    prompt = f"Translate to SQL: {q}"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        # FIX: the original passed temperature=0.0 — invalid for sampling and
        # ignored (with a warning) under greedy decoding. do_sample=False is
        # the correct way to request deterministic output.
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False,
        )

    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print("\nQUESTION:", q)
    print("SQL:", sql)
    print("-" * 60)
src/predict.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
import sqlite3
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# --------------------------------------------------
# PATH
# --------------------------------------------------
# Fine-tuned seq2seq checkpoint saved by the supervised training script.
MODEL_PATH = "outputs/model"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

print("Loading fine-tuned model...")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
model.eval()

# --------------------------------------------------
# CONNECT DATABASE
# --------------------------------------------------
# A single module-level connection/cursor is shared by the interactive loop
# below; the path is relative to the repo root (run the script from there).
print("Connecting to database...")
# conn = sqlite3.connect("../data/database/department_management/department_management.sqlite")
conn = sqlite3.connect("data/database/department_management/department_management.sqlite")
cursor = conn.cursor()
print("Database connected ✔")
|
| 25 |
+
|
| 26 |
+
# --------------------------------------------------
|
| 27 |
+
# BUILD PROMPT
|
| 28 |
+
# --------------------------------------------------
|
| 29 |
+
def build_prompt(question):
    """Wrap *question* in the translation prompt, embedding the fixed
    department_management schema description."""
    schema = """
Table department columns = Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees.
Table head columns = head_ID, name, born_state, age.
Table management columns = department_ID, head_ID, temporary_acting.
"""
    prefix = "translate English to SQL: "
    return prefix + schema + " question: " + question
|
| 36 |
+
|
| 37 |
+
# --------------------------------------------------
|
| 38 |
+
# GENERATE SQL
|
| 39 |
+
# --------------------------------------------------
|
| 40 |
+
def generate_sql(question):
    """Predict a SQL query for *question* using the module-level tokenizer/model.

    Encodes the schema-aware prompt (max 256 tokens, truncated), decodes with
    beam search, and returns the stripped SQL string.
    """
    prompt = build_prompt(question)

    encoding = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256
    )

    with torch.no_grad():
        outputs = model.generate(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            max_length=256,
            num_beams=5,  # beam search → deterministic, higher-quality output
            early_stopping=True
        )

    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return sql.strip()
|
| 63 |
+
|
| 64 |
+
# --------------------------------------------------
|
| 65 |
+
# EVALUATE SQL (REWARD FUNCTION)
|
| 66 |
+
# --------------------------------------------------
|
| 67 |
+
def evaluate_sql(sql):
    """Execute *sql* on the shared module-level cursor and score it.

    Returns (reward, payload):
      * (-1.0, error-string)  — SQL failed to execute
      * (-0.2, rows)          — executed but returned no rows
      * ( 1.0, rows)          — executed with at least one row
    """
    try:
        cursor.execute(sql)
        rows = cursor.fetchall()
    except Exception as err:
        # invalid SQL
        return -1.0, str(err)

    # executed but no useful result
    if not rows:
        return -0.2, rows
    # good query
    return 1.0, rows
|
| 83 |
+
|
| 84 |
+
# --------------------------------------------------
|
| 85 |
+
# INTERACTIVE LOOP
|
| 86 |
+
# --------------------------------------------------
|
| 87 |
+
# REPL: read a question, predict SQL, execute it, and report the reward/result.
while True:
    q = input("\nAsk question (type exit to quit): ")

    if q.lower() == "exit":
        break

    sql = generate_sql(q)

    print("\nPredicted SQL:")
    print(sql)

    # ---------------- RUN SQL + REWARD ----------------
    reward, output = evaluate_sql(sql)

    print("\nReward:", reward)

    if reward == -1.0:
        # evaluate_sql returned the exception text in this case
        print("SQL Error:", output)

    elif reward == -0.2:
        print("No results found")

    else:
        # Successful query: output is the list of fetched rows.
        print("\nAnswer:")
        for r in output:
            print(r)
|
src/prepare_dataset.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import sqlite3
|
| 4 |
+
from datasets import Dataset
|
| 5 |
+
from transformers import T5Tokenizer
|
| 6 |
+
|
| 7 |
+
# =========================================================
|
| 8 |
+
# PROJECT ROOT (VERY IMPORTANT — fixes path issues)
|
| 9 |
+
# =========================================================
|
| 10 |
+
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 11 |
+
|
| 12 |
+
TRAIN_JSON = os.path.join(BASE_DIR, "data", "train_spider.json")
|
| 13 |
+
DEV_JSON = os.path.join(BASE_DIR, "data", "dev.json")
|
| 14 |
+
DB_FOLDER = os.path.join(BASE_DIR, "data", "database")
|
| 15 |
+
|
| 16 |
+
SAVE_TRAIN = os.path.join(BASE_DIR, "data", "tokenized", "train")
|
| 17 |
+
SAVE_DEV = os.path.join(BASE_DIR, "data", "tokenized", "validation")
|
| 18 |
+
|
| 19 |
+
os.makedirs(os.path.dirname(SAVE_TRAIN), exist_ok=True)
|
| 20 |
+
|
| 21 |
+
print("Project root:", BASE_DIR)
|
| 22 |
+
print("Train file:", TRAIN_JSON)
|
| 23 |
+
print("Database folder:", DB_FOLDER)
|
| 24 |
+
|
| 25 |
+
# =========================================================
|
| 26 |
+
# TOKENIZER
|
| 27 |
+
# =========================================================
|
| 28 |
+
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
| 29 |
+
|
| 30 |
+
# =========================================================
|
| 31 |
+
# READ DATABASE SCHEMA
|
| 32 |
+
# =========================================================
|
| 33 |
+
def get_schema(db_path):
    """Return the schema of the SQLite db at *db_path*, one "table(col, ...)" line per table.

    Fixes vs. original: table identifiers are quoted in PRAGMA so reserved-word
    table names (e.g. "order") no longer raise, and the connection is closed
    even when a query fails.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema_text = []

        for table in tables:
            table = table[0]

            # Quoting the identifier handles keyword/odd table names.
            columns = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
            col_names = [c[1] for c in columns]

            schema_text.append(f"{table}({', '.join(col_names)})")
    finally:
        conn.close()
    return "\n".join(schema_text)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# =========================================================
|
| 56 |
+
# BUILD TRAINING EXAMPLES
|
| 57 |
+
# =========================================================
|
| 58 |
+
def build_examples(spider_json):
    """Build an (input, output) HF Dataset from a Spider-style JSON file.

    Each input is a schema-aware prompt whose schema is read live from the
    example's SQLite file under DB_FOLDER; the output is the gold SQL query.
    Examples whose database file is missing on disk are silently skipped.
    """
    print(f"\nBuilding dataset from: {spider_json}")

    data = json.load(open(spider_json))

    inputs = []
    outputs = []

    for ex in data:

        question = ex["question"]
        sql = ex["query"]
        db_id = ex["db_id"]

        db_path = os.path.join(DB_FOLDER, db_id, f"{db_id}.sqlite")

        # skip if db missing (safety)
        if not os.path.exists(db_path):
            continue

        schema = get_schema(db_path)

        # ⭐ SCHEMA-AWARE PROMPT (VERY IMPORTANT)
        input_text = f"""Database Schema:
{schema}

Translate English to SQL:
{question}
SQL:
"""

        inputs.append(input_text)
        outputs.append(sql)

    return Dataset.from_dict({"input": inputs, "output": outputs})
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# =========================================================
|
| 97 |
+
# TOKENIZE
|
| 98 |
+
# =========================================================
|
| 99 |
+
def tokenize(example):
    """Tokenize one (input, output) pair into T5 encoder inputs plus labels.

    Inputs are padded/truncated to 512 tokens, targets to 256; the target ids
    become the "labels" field expected by the seq2seq trainer.
    """
    encoded = tokenizer(
        example["input"],
        max_length=512,
        padding="max_length",
        truncation=True,
    )

    target = tokenizer(
        example["output"],
        max_length=256,
        padding="max_length",
        truncation=True,
    )

    encoded["labels"] = target["input_ids"]
    return encoded
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# =========================================================
|
| 120 |
+
# RUN PIPELINE
|
| 121 |
+
# =========================================================
|
| 122 |
+
# =========================================================
# RUN PIPELINE
# =========================================================
# Build → tokenize → persist; validation reuses the same helpers on dev.json.
print("\nBuilding TRAIN dataset...")
train_dataset = build_examples(TRAIN_JSON)

print("Tokenizing TRAIN dataset...")
tokenized_train = train_dataset.map(tokenize, batched=False)

print("Saving TRAIN dataset...")
tokenized_train.save_to_disk(SAVE_TRAIN)


print("\nBuilding VALIDATION dataset...")
val_dataset = build_examples(DEV_JSON)

print("Tokenizing VALIDATION dataset...")
tokenized_val = val_dataset.map(tokenize, batched=False)

print("Saving VALIDATION dataset...")
tokenized_val.save_to_disk(SAVE_DEV)

print("\nDONE ✔ Dataset prepared successfully!")
print("Train saved at:", SAVE_TRAIN)
print("Validation saved at:", SAVE_DEV)
|
src/prompting.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import sqlite3
|
| 6 |
+
from contextlib import closing
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
# Keep for compatibility with existing imports. Schema linking is disabled for
|
| 12 |
+
# SFT/RL alignment in this project version (full schema, deterministic order).
|
| 13 |
+
USE_SCHEMA_LINKING = False
|
| 14 |
+
|
| 15 |
+
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 16 |
+
DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")
|
| 17 |
+
|
| 18 |
+
SCHEMA_CACHE: Dict[str, str] = {}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_schema_text(db_id: str) -> str:
    """
    Deterministic schema string:
        table(col1, col2, ...)
    Tables ordered alphabetically. Columns kept in PRAGMA order.

    Results are memoized in SCHEMA_CACHE keyed by db_id; on any database
    error the schema resolves to "".
    """
    if db_id in SCHEMA_CACHE:
        return SCHEMA_CACHE[db_id]

    db_path = os.path.join(DB_ROOT, db_id, f"{db_id}.sqlite")
    schema_lines = []
    try:
        with closing(sqlite3.connect(db_path)) as conn:
            cur = conn.cursor()
            # Exclude SQLite's internal bookkeeping tables (sqlite_sequence, ...).
            tables = cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            ).fetchall()
            table_names = sorted([t[0] for t in tables if t and isinstance(t[0], str)])
            for tname in table_names:
                cols = cur.execute(f'PRAGMA table_info("{tname}")').fetchall()
                col_names = [c[1] for c in cols if c and isinstance(c[1], str)]
                schema_lines.append(f"{tname}({', '.join(col_names)})")
    except Exception:
        # NOTE(review): a transient failure caches "" permanently for this
        # db_id — confirm that is acceptable for long-lived processes.
        schema_lines = []

    schema_text = "\n".join(schema_lines).strip()
    SCHEMA_CACHE[db_id] = schema_text
    return schema_text
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def clean_gold_sql(sql: str) -> str:
    """
    Lowercase SQL + strip common Spider aliases safely.
    If alias removal is ambiguous (same table used multiple times), keep SQL as-is.
    """
    if not isinstance(sql, str):
        return ""
    stripped = sql.strip().rstrip(";").strip()
    if not stripped:
        return ""

    # alias (lowercased, e.g. "t1") -> table name, harvested from FROM/JOIN.
    aliases = {}
    # lowercased table name -> number of FROM/JOIN occurrences.
    seen = {}
    decl_pattern = r"\b(from|join)\s+([a-zA-Z_][\w$]*)\s+(?:as\s+)?(t\d+)\b"
    for match in re.finditer(decl_pattern, stripped, flags=re.I):
        tbl = match.group(2)
        alias = match.group(3)
        seen[tbl.lower()] = seen.get(tbl.lower(), 0) + 1
        aliases[alias.lower()] = tbl

    # Self-joins make alias removal ambiguous → just lowercase the input.
    if any(count > 1 for count in seen.values()):
        return stripped.lower()

    result = stripped
    # Pass 1: rewrite alias-qualified column refs (t1.col -> table.col).
    for alias, tbl in aliases.items():
        result = re.sub(rf"\b{re.escape(alias)}\.", f"{tbl}.", result, flags=re.I)

    # Pass 2: drop the alias declarations ("table AS t1" / "table t1").
    for alias, tbl in aliases.items():
        result = re.sub(rf"\b{re.escape(tbl)}\s+as\s+{re.escape(alias)}\b", tbl, result, flags=re.I)
        result = re.sub(rf"\b{re.escape(tbl)}\s+{re.escape(alias)}\b", tbl, result, flags=re.I)

    return result.lower().strip()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def build_prompt(
    question: str,
    db_id: str,
    *,
    schema_text: str,
    training_sql: Optional[str] = None,
) -> str:
    """
    Assemble the canonical prompt:

    You are a SQLite expert.

    Database: <db_id>

    Schema:
    <table>(col1, col2, ...)
    ...

    Question:
    <question>

    SQL:
    <gold sql> (training only)
    """
    sections = [
        "You are a SQLite expert.",
        f"Database: {db_id}",
        f"Schema:\n{schema_text}",
        f"Question:\n{question}",
        "SQL:",
    ]
    prompt = "\n\n".join(sections)
    # Training mode appends the gold SQL directly after the "SQL:" marker.
    if training_sql is not None:
        prompt = prompt + "\n" + training_sql
    return prompt
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def encode_prompt(
    tokenizer,
    question: str,
    db_id: str,
    *,
    device: str,
    max_input_tokens: int = 512,
    training_sql: Optional[str] = None,
) -> torch.Tensor:
    """
    Inference mode: stops at "SQL:"
    Training mode: can include SQL target (optional; we still recommend decoder labels).
    Truncation happens only on schema portion by character trimming (deterministic).

    Returns a 1-D LongTensor of token ids on *device* (batch dim removed).

    NOTE(review): despite the note above, the tokenizer truncation below cuts
    the prompt's tail (question / "SQL:" marker), not specifically the schema
    — confirm which behavior is intended.
    """
    schema_text = get_schema_text(db_id)
    prompt = build_prompt(question, db_id, schema_text=schema_text, training_sql=training_sql)
    enc = tokenizer(
        prompt,
        truncation=True,
        max_length=max_input_tokens,
        padding=False,
        return_tensors="pt",
    )
    # [0] drops the batch dimension added by return_tensors="pt".
    return enc.input_ids[0].to(device)
|
src/run_sql.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sqlite3
|
| 5 |
+
from contextlib import closing
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def execute_sql(db_path: str, sql: str):
    """Run *sql* against the SQLite file at *db_path*.

    Returns a dict: {"ok": bool, "rows": list, "error": str | None}.
    Any exception (bad path, bad SQL) yields ok=False with the error text;
    the connection is always closed.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        rows = conn.cursor().execute(sql).fetchall()
        return {"ok": True, "rows": rows, "error": None}
    except Exception as exc:
        return {"ok": False, "rows": [], "error": str(exc)}
    finally:
        if conn is not None:
            conn.close()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def main():
    """CLI entry point: parse args, resolve the DB path, run the query, print JSON."""
    arg_parser = argparse.ArgumentParser(description="Safely execute SQL against a Spider SQLite DB.")
    db_group = arg_parser.add_mutually_exclusive_group(required=True)
    db_group.add_argument("--db_path", type=str, help="Path to SQLite database file")
    db_group.add_argument("--db_id", type=str, help="Spider database id (uses data/database/<db_id>/<db_id>.sqlite)")
    arg_parser.add_argument("--sql", type=str, required=True, help="SQL to execute")
    opts = arg_parser.parse_args()

    # Repo root is one directory above src/ (this file's parent's parent).
    root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    target_db = opts.db_path if opts.db_path else os.path.join(
        root_dir, "data", "database", opts.db_id, f"{opts.db_id}.sqlite"
    )

    outcome = execute_sql(target_db, opts.sql)
    # default=str keeps non-JSON-native cell values (dates, bytes) printable.
    print(json.dumps(outcome, ensure_ascii=False, default=str))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if __name__ == "__main__":
|
| 38 |
+
main()
|
| 39 |
+
|
src/schema_encoder.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class SchemaEncoder:
    """Reads SQLite schemas from disk and renders them as prompt text.

    ``db_root`` is a ``pathlib.Path``-like directory laid out Spider-style:
    <db_root>/<db_id>/<db_id>.sqlite
    """

    def __init__(self, db_root):
        self.db_root = db_root  # base directory containing one folder per db_id

    def get_tables_and_columns(self, db_id):
        """Return {table_name: [column_name, ...]} for every table in the DB.

        The connection is released in a ``finally`` block so it is closed even
        if a PRAGMA fails mid-way (the previous version leaked the handle on
        error). The PRAGMA identifier is quoted so table names containing
        spaces or SQL keywords do not break the statement.
        """
        db_path = self.db_root / db_id / f"{db_id}.sqlite"
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            tables = cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table';"
            ).fetchall()

            schema = {}
            for (table,) in tables:
                cols = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
                schema[table] = [c[1] for c in cols]  # c[1] is the column name
        finally:
            conn.close()
        return schema

    # -----------------------------------
    # Strategy 1: Structured (current)
    # -----------------------------------
    def structured_schema(self, db_id):
        """Render the schema as compact 'table(col1, col2)' lines, one per table."""
        schema = self.get_tables_and_columns(db_id)
        lines = [f"{table}({', '.join(cols)})" for table, cols in schema.items()]
        return "\n".join(lines)

    # -----------------------------------
    # Strategy 2: Natural Language
    # -----------------------------------
    def natural_language_schema(self, db_id):
        """Render the schema as one English sentence per table."""
        schema = self.get_tables_and_columns(db_id)
        lines = [
            f"The table '{table}' contains the columns: {', '.join(cols)}."
            for table, cols in schema.items()
        ]
        return "\n".join(lines)
|
src/schema_linker.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple schema linking for Spider-style Text-to-SQL.
|
| 3 |
+
|
| 4 |
+
Goal:
|
| 5 |
+
- Given (question, db_id), select a small set of relevant tables/columns
|
| 6 |
+
to include in the prompt (RAG-style schema retrieval).
|
| 7 |
+
|
| 8 |
+
Design constraints:
|
| 9 |
+
- Pure Python (no heavy external deps).
|
| 10 |
+
- Robust to missing/odd schemas: never crash.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import re
|
| 18 |
+
import sqlite3
|
| 19 |
+
from contextlib import closing
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
_ALNUM_RE = re.compile(r"[A-Za-z0-9]+")
|
| 25 |
+
_CAMEL_RE = re.compile(r"([a-z])([A-Z])")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _normalize_identifier(text: str) -> str:
|
| 29 |
+
"""
|
| 30 |
+
Normalize a schema identifier:
|
| 31 |
+
- split underscores
|
| 32 |
+
- split camelCase / PascalCase boundaries
|
| 33 |
+
- lowercase
|
| 34 |
+
"""
|
| 35 |
+
text = str(text or "")
|
| 36 |
+
text = text.replace("_", " ")
|
| 37 |
+
text = _CAMEL_RE.sub(r"\1 \2", text)
|
| 38 |
+
return text.lower()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _tokenize(text: str) -> List[str]:
|
| 42 |
+
text = _normalize_identifier(text)
|
| 43 |
+
return _ALNUM_RE.findall(text)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass(frozen=True)
class TableSchema:
    """Immutable view of a single table: its name plus ordered column names."""

    table_name: str  # table name as declared in tables.json (or SQLite fallback)
    columns: Tuple[str, ...]  # column names belonging to this table, in schema order
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class SchemaLinker:
    """
    Loads Spider `tables.json` and (optionally) SQLite schemas from disk.
    Provides a lightweight table scoring function based on token overlap.
    """

    def __init__(self, tables_json_path: str, db_root: Optional[str] = None):
        """Parse tables.json eagerly; SQLite schemas are loaded lazily per db_id."""
        self.tables_json_path = tables_json_path
        self.db_root = db_root
        # db_id -> tables parsed from tables.json (canonical Spider schema).
        self._tables_by_db: Dict[str, List[TableSchema]] = {}
        # db_id -> {table: [cols]} read from the actual SQLite file (memoized).
        self._sqlite_schema_cache: Dict[str, Dict[str, List[str]]] = {}
        self._load_tables_json()

    def _load_tables_json(self) -> None:
        """Populate self._tables_by_db from the Spider tables.json file."""
        with open(self.tables_json_path) as f:
            entries = json.load(f)

        tables_by_db: Dict[str, List[TableSchema]] = {}
        for entry in entries:
            db_id = entry["db_id"]
            # Prefer *_original names (actual identifiers); fall back to the
            # human-readable variants if a dataset omits them.
            table_names: List[str] = entry.get("table_names_original") or entry.get("table_names") or []
            col_names: List[Sequence] = entry.get("column_names_original") or entry.get("column_names") or []

            columns_by_table_idx: Dict[int, List[str]] = {i: [] for i in range(len(table_names))}
            for col in col_names:
                # Spider format: [table_idx, col_name]
                if not col or len(col) < 2:
                    continue
                table_idx, col_name = col[0], col[1]
                if table_idx is None or table_idx < 0:
                    continue  # skip "*"
                if table_idx not in columns_by_table_idx:
                    continue
                columns_by_table_idx[table_idx].append(str(col_name))

            tables: List[TableSchema] = []
            for i, tname in enumerate(table_names):
                cols = tuple(columns_by_table_idx.get(i, []))
                tables.append(TableSchema(table_name=str(tname), columns=cols))

            tables_by_db[db_id] = tables

        self._tables_by_db = tables_by_db

    def _db_path(self, db_id: str) -> Optional[str]:
        """Return the SQLite file path for db_id, or None if missing/unconfigured."""
        if not self.db_root:
            return None
        path = os.path.join(self.db_root, db_id, f"{db_id}.sqlite")
        return path if os.path.exists(path) else None

    def _load_sqlite_schema(self, db_id: str) -> Dict[str, List[str]]:
        """
        Load actual SQLite schema (table -> columns). Cached per db_id.
        """
        if db_id in self._sqlite_schema_cache:
            return self._sqlite_schema_cache[db_id]

        schema: Dict[str, List[str]] = {}
        db_path = self._db_path(db_id)
        if not db_path:
            # No DB on disk: cache the empty result so we don't re-stat the path.
            self._sqlite_schema_cache[db_id] = schema
            return schema

        try:
            with closing(sqlite3.connect(db_path)) as conn:
                cursor = conn.cursor()
                tables = cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
                for (table_name,) in tables:
                    # NOTE(review): table_name is interpolated unquoted into the
                    # PRAGMA; fine for Spider's identifiers, fragile for odd names.
                    columns = cursor.execute(f"PRAGMA table_info({table_name});").fetchall()
                    schema[str(table_name)] = [str(col[1]) for col in columns]
        except Exception:
            # Robustness requirement: never crash on a broken DB file.
            schema = {}

        self._sqlite_schema_cache[db_id] = schema
        return schema

    def get_schema(self, db_id: str) -> List[TableSchema]:
        """
        Returns a list of table schemas for this db.
        Prefers `tables.json` (Spider canonical), but can fallback to SQLite if needed.
        """
        tables = self._tables_by_db.get(db_id, [])
        if tables:
            return tables

        sqlite_schema = self._load_sqlite_schema(db_id)
        return [TableSchema(table_name=t, columns=tuple(cols)) for t, cols in sqlite_schema.items()]

    def score_tables(self, question: str, db_id: str) -> List[Tuple[float, TableSchema]]:
        """
        Score each table using token overlap:
        - table token overlap (higher weight)
        - column token overlap (lower weight)
        """
        q_tokens = set(_tokenize(question))
        tables = self.get_schema(db_id)

        scored: List[Tuple[float, TableSchema]] = []
        for t in tables:
            table_tokens = set(_tokenize(t.table_name))
            col_tokens: set[str] = set()
            for c in t.columns:
                col_tokens.update(_tokenize(c))

            table_overlap = len(q_tokens & table_tokens)
            col_overlap = len(q_tokens & col_tokens)

            # Simple weighted overlap (tuned to bias table matches).
            score = 3.0 * table_overlap + 1.0 * col_overlap

            # Small boost for substring mentions (helps e.g. "album" vs "albums").
            q_text = _normalize_identifier(question)
            if t.table_name and _normalize_identifier(t.table_name) in q_text:
                score += 0.5

            scored.append((score, t))

        # Sort by score desc; table_name is a deterministic tie-break (note:
        # reverse=True makes the name tie-break descending as well).
        scored.sort(key=lambda x: (x[0], x[1].table_name), reverse=True)
        return scored

    def select_top_tables(self, question: str, db_id: str, top_k: int = 4) -> List[TableSchema]:
        """Return the top_k highest-scoring tables (at least 1 when any exist)."""
        scored = self.score_tables(question, db_id)
        if not scored:
            return []
        top_k = max(1, int(top_k))
        selected = [t for _, t in scored[:top_k]]

        # If everything scores 0, still return a stable selection.
        if scored[0][0] <= 0:
            tables = self.get_schema(db_id)
            return tables[:top_k]

        return selected

    def columns_for_selected_tables(self, db_id: str, selected_tables: Iterable[TableSchema]) -> Dict[str, List[str]]:
        """
        Returns only columns belonging to selected tables.
        Prefer SQLite columns (actual DB) if available; fallback to tables.json.
        """
        sqlite_schema = self._load_sqlite_schema(db_id)
        out: Dict[str, List[str]] = {}
        for t in selected_tables:
            if t.table_name in sqlite_schema and sqlite_schema[t.table_name]:
                out[t.table_name] = sqlite_schema[t.table_name]
            else:
                out[t.table_name] = list(t.columns)
        return out

    def format_relevant_schema(self, question: str, db_id: str, top_k: int = 4) -> Tuple[List[str], Dict[str, List[str]]]:
        """
        Returns:
        - lines: ["table(col1, col2)", ...]
        - selected: {table: [cols...], ...}
        """
        selected_tables = self.select_top_tables(question, db_id, top_k=top_k)
        selected = self.columns_for_selected_tables(db_id, selected_tables)

        lines: List[str] = []
        for table_name, cols in selected.items():
            cols_str = ", ".join(cols)
            lines.append(f"{table_name}({cols_str})")

        return lines, selected
|
| 215 |
+
|
src/sql_validator.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
import re
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
class SQLValidator:
    """Heuristic, string-level pre-execution checks for model-generated SQL.

    No SQL parsing is done: checks catch missing SELECT/FROM, references to
    unknown tables/columns, and DML/DDL keywords before a query is allowed
    to touch a database.
    """

    def __init__(self, db_root):
        self.db_root = Path(db_root)  # directory with <db_id>/<db_id>.sqlite

    # ---------------------------
    # Load schema
    # ---------------------------
    def load_schema(self, db_id):
        """Return {table_lower: [col_lower, ...]} read from the SQLite file.

        The connection is closed in a ``finally`` block so a failing PRAGMA
        no longer leaks the handle; the PRAGMA identifier is quoted.
        """
        db_path = self.db_root / db_id / f"{db_id}.sqlite"

        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            tables = cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table';"
            ).fetchall()

            schema = {}
            for (table,) in tables:
                cols = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
                schema[table.lower()] = [c[1].lower() for c in cols]
        finally:
            conn.close()
        return schema

    # ---------------------------
    # Basic syntax check
    # ---------------------------
    def basic_structure_valid(self, sql):
        """Cheap sanity check: must mention SELECT and FROM and not be tiny."""
        s = sql.lower()

        if "select" not in s or "from" not in s:
            return False, "Missing SELECT or FROM"

        if len(s.split()) < 4:
            return False, "Too short to be SQL"

        return True, None

    # ---------------------------
    # Extract identifiers
    # ---------------------------
    def extract_identifiers(self, sql):
        """Return the set of lowercase word-like tokens appearing in *sql*."""
        return set(re.findall(r"[A-Za-z_]+", sql.lower()))

    # ---------------------------
    # Table validation
    # ---------------------------
    def validate_tables(self, sql, schema):
        """Pass iff at least one known table name appears in the query."""
        words = self.extract_identifiers(sql)
        if not words & set(schema.keys()):
            return False, "No valid table used"
        return True, None

    # ---------------------------
    # Column validation
    # ---------------------------
    def validate_columns(self, sql, schema):
        """Reject queries referencing too many identifiers absent from the schema."""
        words = self.extract_identifiers(sql)

        valid_columns = set()
        for cols in schema.values():
            valid_columns.update(cols)

        # ignore SQL keywords
        keywords = {
            "select","from","where","join","on","group","by",
            "order","limit","count","sum","avg","min","max",
            "and","or","in","like","distinct","asc","desc"
        }

        invalid = [
            w for w in words
            if w not in valid_columns and w not in schema
            and w not in keywords and not w.isdigit()
        ]

        # allow small hallucinations but block many
        if len(invalid) > 3:
            return False, f"Too many unknown identifiers: {invalid[:5]}"

        return True, None

    # ---------------------------
    # Dangerous query protection
    # ---------------------------
    def block_dangerous(self, sql):
        """Reject DML/DDL statements.

        Uses word boundaries instead of substring matching, so legitimate
        identifiers such as 'last_updated' or 'dropdown' are no longer
        false positives (the old ``b in s`` check blocked them).
        """
        match = re.search(r"\b(drop|delete|update|insert|alter)\b", sql.lower())
        if match:
            return False, f"Dangerous keyword detected: {match.group(1)}"
        return True, None

    # ---------------------------
    # Main validation
    # ---------------------------
    def validate(self, sql, db_id):
        """Run every check in order; return (False, reason) on first failure."""
        schema = self.load_schema(db_id)

        checks = [
            self.block_dangerous(sql),
            self.basic_structure_valid(sql),
            self.validate_tables(sql, schema),
            self.validate_columns(sql, schema),
        ]

        for ok, msg in checks:
            if not ok:
                return False, msg

        return True, None
|
src/text2sql_engine.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
import sqlite3
|
| 4 |
+
import torch
|
| 5 |
+
import re
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 9 |
+
from peft import PeftModel
|
| 10 |
+
from src.sql_validator import SQLValidator
|
| 11 |
+
from src.schema_encoder import SchemaEncoder
|
| 12 |
+
|
| 13 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 14 |
+
DB_ROOT = PROJECT_ROOT / "data" / "database"
|
| 15 |
+
|
| 16 |
+
# ==========================================
|
| 17 |
+
# UNIVERSAL STRING NORMALIZERS
|
| 18 |
+
# ==========================================
|
| 19 |
+
def normalize_question(q: str):
    """Canonicalize a user question.

    Lowercases and strips, rewrites 'distinct N' as 'N distinct' so
    downstream regexes see a consistent shape, and collapses runs of
    whitespace to single spaces.
    """
    lowered = q.lower().strip()
    reordered = re.sub(r"distinct\s+(\d+)", r"\1 distinct", lowered)
    return re.sub(r"\s+", " ", reordered)
|
| 24 |
+
|
| 25 |
+
def semantic_fix(question, sql):
    """Universal structural fixes that apply to ALL queries and ALL databases.

    Currently one rule: if the question explicitly asks for N rows (a number
    preceded by show/list/top/limit/get/first/last) and the SQL has neither a
    LIMIT nor a COUNT(, append ``LIMIT N``.
    """
    question_lc = question.lower().strip()
    sql_lc = sql.lower()

    # Only numbers with an explicit quantifier verb count as a row limit —
    # this deliberately ignores bare numbers such as years ("2000").
    wanted = re.search(r'\b(?:show|list|top|limit|get|first|last)\s+(\d+)\b', question_lc)
    if not wanted or "limit" in sql_lc or "count(" in sql_lc:
        return sql
    return f"{sql.rstrip(';').strip()} LIMIT {wanted.group(1)}"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Text2SQLEngine:
    """End-to-end Text-to-SQL pipeline: prompt building, generation with an
    optional LoRA adapter, validation, safe execution, and a one-shot
    self-correction retry when SQLite reports an error."""

    def __init__(self,
                 adapter_path="checkpoints/best_rlhf_model",
                 base_model_name="Salesforce/codet5-base",
                 use_lora=True):
        """Load the seq2seq model (optionally with a LoRA adapter) onto the
        best available device (mps > cuda > cpu)."""
        self.device = "mps" if torch.backends.mps.is_available() else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        self.validator = SQLValidator(DB_ROOT)
        self.schema_encoder = SchemaEncoder(DB_ROOT)
        self.schema_mode = "structured"

        # Security Keywords (regex used by both defense layers below)
        self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'

        print("Loading base model...")
        base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        if not use_lora:
            # Plain base model path: no adapter, tokenizer from the hub.
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            self.model = base.to(self.device)
            self.model.eval()
            print("✅ Base model ready\n")
            return

        adapter_path = (PROJECT_ROOT / adapter_path).resolve()

        print("Loading tokenizer and LoRA adapter...")
        try:
            # Prefer a tokenizer saved alongside the adapter; fall back to the hub.
            self.tokenizer = AutoTokenizer.from_pretrained(str(adapter_path), local_files_only=True)
        except Exception:
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        self.model = PeftModel.from_pretrained(base, str(adapter_path)).to(self.device)
        self.model.eval()
        print("✅ RLHF model ready\n")

    # ==========================================
    # ---------------- PROMPT BUILDERS ---------
    # ==========================================
    def build_prompt(self, question, schema):
        """First-pass prompt: schema + question, ending at 'SQL:'."""
        return f"""You are an expert SQL generator.
Database schema:
{schema}
Generate a valid SQLite query for the question.
Question:
{question}
SQL:
"""

    def build_repair_prompt(self, question, schema, bad_sql, error_msg):
        """Repair prompt: feeds the failed SQL and the SQLite error back to
        the model, with an explicit warning when a hallucinated column can
        be parsed out of the error message."""
        # UNIVERSAL UPGRADE: Extract the hallucinated column and explicitly warn the model
        hallucinated_warning = ""
        col_match = re.search(r"no such column:\s*([^\s]+)", error_msg, re.IGNORECASE)
        if col_match:
            bad_col = col_match.group(1)
            hallucinated_warning = f"\n🚨 CRITICAL ERROR: You hallucinated the column '{bad_col}'. IT DOES NOT EXIST. Look at the schema and find the actual column name (it might be spelled differently or be a synonym like 'details', 'desc', or have a typo)."

        return f"""You are an expert SQL generator.
Database schema:
{schema}

You generated this incorrect SQL for the question "{question}":
{bad_sql}

Execution failed with this SQLite error:
{error_msg}{hallucinated_warning}

UNIVERSAL RULES TO FIX THIS:
1. NEVER invent or guess column names. Use ONLY the exact table and column names listed in the schema above.
2. Watch out for typos in the database schema! If you need 'assessment', look for 'asessment'. If you need 'name', look for 'details'.
3. If the error is "no such column", you either hallucinated the name, or you forgot an INNER JOIN. Check the schema and fix it.
4. If the query requires a COUNT() but also selects names, ensure you added a GROUP BY.

Write the corrected SQLite SQL query.
SQL:
"""

    def get_schema(self, db_id):
        """Return the schema rendered in the 'structured' text format."""
        return self.schema_encoder.structured_schema(db_id)

    # ==========================================
    # ---------------- SQL POSTPROCESS ---------
    # ==========================================
    def extract_sql(self, text: str):
        """Pull the SQL out of a decoded model output: take everything after
        the last 'SQL:' marker, then from the first SELECT onward, and stop
        at the first semicolon."""
        text = text.strip()
        if "SQL:" in text:
            text = text.split("SQL:")[-1]
        match = re.search(r"select[\s\S]*", text, re.IGNORECASE)
        if match:
            text = match.group(0)
        return text.split(";")[0].strip()

    def clean_sql(self, sql: str):
        """Normalize quoting and collapse whitespace to single spaces."""
        # NOTE(review): replacing every '"' with "'" also rewrites
        # double-quoted identifiers into string literals — confirm intended.
        sql = sql.replace('"', "'")
        sql = re.sub(r"\s+", " ", sql)
        return sql.strip()

    def repair_logic(self, question, sql):
        """Universal logical repairs (like missing NOT NULL for negation)"""
        q = question.lower()
        s = sql.lower()

        # Universal Negation Auto-Joiner: questions like "never"/"without"
        # usually need an anti-join; rewrite a plain JOIN into LEFT JOIN ... IS NULL.
        if any(word in q for word in ["never", "no ", "without"]):
            m = re.search(r"from\s+(\w+).*join\s+(\w+)", s)
            if m:
                left, right = m.group(1), m.group(2)
                key = re.search(r"on\s+(\w+\.\w+)\s*=\s*(\w+\.\w+)", s)
                if key:
                    sql = f"SELECT {left}.* FROM {left} LEFT JOIN {right} ON {key.group(1)} = {key.group(2)} WHERE {key.group(2)} IS NULL"

        # Universal LIKE wildcard injection: "contain/with/include" implies
        # substring match, so turn = 'x' comparisons into LIKE '%x%'.
        if any(w in q for w in ["contain", "with", "include"]):
            sql = re.sub(r"=\s*'([^']+)'", r"LIKE '%\1%'", sql, flags=re.IGNORECASE)

        return sql

    # ==========================================
    # ---------------- GENERATE ----------------
    # ==========================================
    def generate_sql(self, prompt, is_repair=False):
        """Generate SQL for *prompt*. First attempts use deterministic beam
        search; repair attempts use temperature sampling so the model does
        not reproduce the same broken query."""
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device)

        # FIXED: Dynamic Generation Parameters (No more terminal warnings)
        gen_kwargs = {
            "max_new_tokens": 128,
        }

        if is_repair:
            # If the model failed, it needs to think differently.
            # We turn off rigid beam search and introduce sampling so it doesn't repeat the exact same broken SQL.
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = 0.5
            gen_kwargs["top_p"] = 0.9
        else:
            # First attempt is strictly deterministic for maximum benchmark accuracy
            gen_kwargs["num_beams"] = 5
            gen_kwargs["do_sample"] = False
            gen_kwargs["early_stopping"] = True  # <--- Moved here so it doesn't clash with sampling!

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.clean_sql(self.extract_sql(decoded))

    # ==========================================
    # ---------------- EXECUTE -----------------
    # ==========================================
    def execute_sql(self, question, sql, db_id):
        """Post-process, validate, and execute *sql* against the db_id's
        SQLite file. Returns (sql, columns, rows, error) where error is
        None on success."""
        # 🛡️ DEFENSE LAYER 2: Block Execution of Malicious SQL
        if re.search(self.dml_keywords, sql, re.IGNORECASE):
            return sql, [], [], "❌ Security Alert: Malicious DML/DDL SQL syntax blocked."

        db_path = DB_ROOT / db_id / f"{db_id}.sqlite"

        # Apply the universal repairs before validation so the validated
        # query is exactly the one that runs.
        sql = self.repair_logic(question, sql)
        sql = self.clean_sql(sql)
        sql = semantic_fix(question, sql)

        is_valid, reason = self.validator.validate(sql, db_id)
        if not is_valid:
            return sql, [], [], f"Blocked unsafe SQL: {reason}"

        try:
            conn = sqlite3.connect(db_path)
            start_time = time.monotonic()
            # Progress handler returning non-zero aborts the query — acts as
            # a ~5 second wall-clock timeout for runaway queries.
            def timeout_handler():
                return 1 if (time.monotonic() - start_time) > 5.0 else 0
            conn.set_progress_handler(timeout_handler, 10000)

            cursor = conn.cursor()
            cursor.execute(sql)
            rows = cursor.fetchall()
            columns = [d[0] for d in cursor.description] if cursor.description else []
            conn.close()

            return sql, columns, rows, None

        except Exception as e:
            # NOTE(review): the connection is not closed on this path — the
            # handle is leaked until GC. Consider try/finally.
            return sql, [], [], str(e)

    # ==========================================
    # ---------------- PIPELINE ----------------
    # ==========================================
    def ask(self, question, db_id):
        """Full pipeline: normalize -> generate -> execute -> (on error) one
        self-correction retry. Returns a dict with question/sql/columns/rows/error."""
        question = normalize_question(question)

        # 🛡️ DEFENSE LAYER 1: Block Malicious Natural Language Intent Early
        if re.search(self.dml_keywords, question, re.IGNORECASE):
            return {
                "question": question,
                "sql": "-- BLOCKED",
                "columns": [],
                "rows": [],
                "error": "❌ Security Alert: Malicious intent (DELETE/DROP/UPDATE) detected in the prompt."
            }

        # 1. First Pass Generation
        schema = self.get_schema(db_id)
        prompt = self.build_prompt(question, schema)

        # is_repair=False -> Uses strict Beam Search
        raw_sql = self.generate_sql(prompt, is_repair=False)

        # 2. First Execution Attempt
        final_sql, cols, rows, error = self.execute_sql(question, raw_sql, db_id)

        # 🤖 3. UNIVERSAL AGENTIC SELF-CORRECTION LOOP
        # Security blocks are deliberate and must not be "repaired" around.
        if error and "Security Alert" not in error:
            print(f"\n Caught SQLite Error: {error}")
            print(f" Triggering Stochastic LLM Self-Correction...")

            # Feed the explicit error instructions back to the LLM
            repair_prompt = self.build_repair_prompt(question, schema, final_sql, error)

            # is_repair=True -> Uses Temperature Sampling to break out of hallucination loops
            repaired_sql = self.generate_sql(repair_prompt, is_repair=True)

            # Try executing the repaired SQL
            final_sql, cols, rows, error = self.execute_sql(question, repaired_sql, db_id)

            if not error:
                print("✅ Universal Agent successfully self-corrected the query!")
            else:
                print("❌ Model failed self-correction.")

        return {
            "question": question,
            "sql": final_sql,
            "columns": cols,
            "rows": rows,
            "error": error
        }
|
| 280 |
+
|
| 281 |
+
# Lazily-constructed module-level singleton so the model loads at most once.
_engine = None


def get_engine():
    """Return the shared Text2SQLEngine, creating it on first use."""
    global _engine
    engine = _engine
    if engine is None:
        engine = Text2SQLEngine()
        _engine = engine
    return engine
|
src/tokenize_dataset.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import Dataset
from transformers import T5Tokenizer
import pandas as pd

print("Loading processed dataset...")
train = pd.read_csv("../data/processed/train.csv")
val = pd.read_csv("../data/processed/validation.csv")

# remove hidden pandas index column if exists (survives CSV round-trips)
train = train.drop(columns=[c for c in train.columns if "index" in c.lower()], errors="ignore")
val = val.drop(columns=[c for c in val.columns if "index" in c.lower()], errors="ignore")

print("Loading tokenizer (t5-small)...")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

SQL_PREFIX = "translate English to SQL: "

# ------------------------------------------------------
# TOKENIZATION FUNCTION
# ------------------------------------------------------
def tokenize(example):
    """Tokenize one (input, sql) row for T5 seq2seq training.

    ``example["input"]`` already contains schema + question; the T5 task
    prefix is prepended here. Label positions that hold the pad token are
    replaced with -100 so the cross-entropy loss ignores padding.
    """
    # input = schema + question
    input_text = SQL_PREFIX + example["input"]

    # target = real SQL
    target_sql = example["sql"]

    model_inputs = tokenizer(
        input_text,
        text_target=target_sql,
        max_length=256,
        padding="max_length",
        truncation=True
    )

    # BUGFIX: with padding="max_length" the labels are padded with the pad
    # token id. Unless those positions are masked to -100, the model is
    # trained to emit padding and the reported loss is diluted.
    model_inputs["labels"] = [
        (tok if tok != tokenizer.pad_token_id else -100)
        for tok in model_inputs["labels"]
    ]

    return model_inputs


# ------------------------------------------------------
# DATASET CONVERSION
# ------------------------------------------------------
print("Preparing dataset...")
train_ds = Dataset.from_pandas(train)
val_ds = Dataset.from_pandas(val)

print("Tokenizing train...")
train_ds = train_ds.map(tokenize, remove_columns=train_ds.column_names)

print("Tokenizing validation...")
val_ds = val_ds.map(tokenize, remove_columns=val_ds.column_names)

# save tokenized dataset
train_ds.save_to_disk("../data/tokenized/train")
val_ds.save_to_disk("../data/tokenized/validation")

print("DONE ✔ Tokenized dataset saved correctly")
|
src/train_model.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
from datasets import load_from_disk
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments
)

# ======================================================
# DEVICE (Mac M1/M2/M3 Safe)
# ======================================================
# Prefer the Apple Metal backend when present, otherwise run on CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)

# ======================================================
# LOAD TOKENIZED DATASET (FIXED PATHS)
# ======================================================
# NOTE(review): paths are relative to the current working directory, not to
# this file — run the script from the project root.
print("Loading tokenized dataset...")

train_dataset = load_from_disk("data/tokenized/train")
val_dataset = load_from_disk("data/tokenized/validation")

print("Train size:", len(train_dataset))
print("Validation size:", len(val_dataset))

# ======================================================
# LOAD MODEL
# ======================================================
print("Loading model (t5-small)...")

model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Prevent Mac memory crash
model.config.use_cache = False

# Important T5 settings (prevents generation bugs)
# T5 conventionally starts decoding from the pad token.
model.config.decoder_start_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# ======================================================
# DATA COLLATOR
# ======================================================
# Pads batches and prepares decoder inputs for seq2seq training.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model
)

# ======================================================
# TRAINING ARGUMENTS (Mac Safe)
# ======================================================
print("Setting training config...")

training_args = Seq2SeqTrainingArguments(
    output_dir="outputs/model",

    evaluation_strategy="epoch",
    save_strategy="epoch",

    learning_rate=3e-4,
    num_train_epochs=5,

    # Tiny per-device batches with accumulation => effective batch size 8.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,

    logging_steps=50,

    # Mixed precision disabled: fp16/bf16 are unreliable on MPS.
    fp16=False,
    bf16=False,
    dataloader_pin_memory=False,

    predict_with_generate=True,
    report_to="none"
)

# ======================================================
# TRAINER
# ======================================================
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# ======================================================
# TRAIN
# ======================================================
print("Training started 🚀")
trainer.train()

# ======================================================
# SAVE MODEL
# ======================================================
print("Saving model...")
trainer.save_model("outputs/model")
tokenizer.save_pretrained("outputs/model")

print("\nDONE ✔ Base model trained successfully")
|
src/train_rl.py
ADDED
|
@@ -0,0 +1,816 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =========================================================
|
| 2 |
+
# RLHF TRAINING FOR TEXT2SQL (STABLE PPO VERSION)
|
| 3 |
+
# =========================================================
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from datasets import load_dataset
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
|
| 10 |
+
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
|
| 11 |
+
from peft import PeftModel
|
| 12 |
+
import os, sys, sqlite3, re, random
|
| 13 |
+
|
| 14 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 15 |
+
from execution_reward import execution_reward, extract_tables, extract_columns
|
| 16 |
+
try:
|
| 17 |
+
import sqlparse # gate PPO updates on parsable SQL only
|
| 18 |
+
except Exception: # pragma: no cover
|
| 19 |
+
sqlparse = None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ======================================================
|
| 23 |
+
# DEVICE
|
| 24 |
+
# ======================================================
|
| 25 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 26 |
+
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
|
| 27 |
+
print("Using device:", device)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ======================================================
|
| 31 |
+
# TRAINING SETTINGS
|
| 32 |
+
# ======================================================
|
| 33 |
+
NUM_EPOCHS = 5
|
| 34 |
+
LOG_EVERY = 20
|
| 35 |
+
USE_SCHEMA = True
|
| 36 |
+
SCHEMA_WARMUP_EPOCHS = 0
|
| 37 |
+
MAX_SCHEMA_CHARS = 1500
|
| 38 |
+
MAX_OUTPUT_TOKENS = 80
|
| 39 |
+
ROLLOUTS_PER_EPOCH = 2048
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ======================================================
|
| 43 |
+
# PATHS
|
| 44 |
+
# ======================================================
|
| 45 |
+
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 46 |
+
|
| 47 |
+
# 🎯 FIXED: Save ONLY the best model to this exact path
|
| 48 |
+
RL_MODEL_PATH = os.path.join(PROJECT_ROOT, "checkpoints", "rlhf_t5_best")
|
| 49 |
+
output_dir = RL_MODEL_PATH
|
| 50 |
+
|
| 51 |
+
DB_ROOT = os.path.join(PROJECT_ROOT, "data/database")
|
| 52 |
+
|
| 53 |
+
# 🎯 Updated to point to our newly trained t5-small SFT model
|
| 54 |
+
ADAPTER_PATH = os.path.join(PROJECT_ROOT, "checkpoints/sft_t5")
|
| 55 |
+
|
| 56 |
+
FALLBACK_ADAPTER_PATH = os.path.join(PROJECT_ROOT, "models/t5_spider_sft_lora")
|
| 57 |
+
FALLBACK_ADAPTER_PATH_2 = os.path.join(PROJECT_ROOT, "outputs/sft_text2sql")
|
| 58 |
+
# 🎯 ENSURING t5-small is used
|
| 59 |
+
BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ======================================================
|
| 63 |
+
# LOAD MODEL (LoRA)
|
| 64 |
+
# ======================================================
|
| 65 |
+
print("Loading base:", BASE_MODEL)
|
| 66 |
+
if not os.path.isdir(ADAPTER_PATH):
|
| 67 |
+
if os.path.isdir(FALLBACK_ADAPTER_PATH):
|
| 68 |
+
ADAPTER_PATH = FALLBACK_ADAPTER_PATH
|
| 69 |
+
elif os.path.isdir(FALLBACK_ADAPTER_PATH_2):
|
| 70 |
+
ADAPTER_PATH = FALLBACK_ADAPTER_PATH_2
|
| 71 |
+
print("Loading adapters:", ADAPTER_PATH)
|
| 72 |
+
|
| 73 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
| 74 |
+
|
| 75 |
+
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(BASE_MODEL).to(device)
|
| 76 |
+
model.pretrained_model = PeftModel.from_pretrained(model.pretrained_model, ADAPTER_PATH)
|
| 77 |
+
|
| 78 |
+
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(BASE_MODEL).to(device)
|
| 79 |
+
ref_model.pretrained_model = PeftModel.from_pretrained(ref_model.pretrained_model, ADAPTER_PATH)
|
| 80 |
+
|
| 81 |
+
ref_model.eval()
|
| 82 |
+
for p in ref_model.parameters():
|
| 83 |
+
p.requires_grad_(False)
|
| 84 |
+
|
| 85 |
+
# Freeze base transformer weights; train LoRA adapters + value head.
|
| 86 |
+
for name, p in model.named_parameters():
|
| 87 |
+
# Train value head
|
| 88 |
+
if name.startswith("v_head"):
|
| 89 |
+
p.requires_grad = True
|
| 90 |
+
# Train LoRA adapters (policy learning!)
|
| 91 |
+
elif "lora_" in name:
|
| 92 |
+
p.requires_grad = True
|
| 93 |
+
# Freeze base model
|
| 94 |
+
else:
|
| 95 |
+
p.requires_grad = False
|
| 96 |
+
|
| 97 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 98 |
+
total = sum(p.numel() for p in model.parameters())
|
| 99 |
+
print(f"Trainable params: {trainable}/{total} ({100*trainable/total:.2f}%)")
|
| 100 |
+
|
| 101 |
+
model.config.use_cache = False
|
| 102 |
+
ref_model.config.use_cache = False
|
| 103 |
+
|
| 104 |
+
if tokenizer.pad_token_id is None:
|
| 105 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ======================================================
|
| 109 |
+
# DATASET
|
| 110 |
+
# ======================================================
|
| 111 |
+
print("Loading Spider subset...")
|
| 112 |
+
random.seed(0)
|
| 113 |
+
|
| 114 |
+
# Train on a small, stable curriculum of DBs first.
|
| 115 |
+
TRAIN_DBS = [
|
| 116 |
+
"flight_1",
|
| 117 |
+
"student_assessment",
|
| 118 |
+
"store_1",
|
| 119 |
+
"bike_1",
|
| 120 |
+
"book_2",
|
| 121 |
+
"chinook_1",
|
| 122 |
+
]
|
| 123 |
+
|
| 124 |
+
dataset = load_dataset("spider", split="train")
|
| 125 |
+
_TRAIN_DBS_SET = set(TRAIN_DBS)
|
| 126 |
+
dataset = dataset.filter(lambda x: x["db_id"] in _TRAIN_DBS_SET)
|
| 127 |
+
dataset = dataset.select(range(min(800, len(dataset))))
|
| 128 |
+
|
| 129 |
+
print("Using RLHF DBs:", TRAIN_DBS)
|
| 130 |
+
print("Filtered size:", len(dataset))
|
| 131 |
+
|
| 132 |
+
total_steps = ROLLOUTS_PER_EPOCH
|
| 133 |
+
|
| 134 |
+
# ======================================================
|
| 135 |
+
# DB UTILITIES
|
| 136 |
+
# ======================================================
|
| 137 |
+
def get_db_path(db_id):
    """Return the on-disk SQLite file path for a Spider database id."""
    filename = f"{db_id}.sqlite"
    return os.path.join(DB_ROOT, db_id, filename)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_db_schema(db_path):
    """Return a compact textual schema of a SQLite DB: "tbl(col1, col2) ...".

    Best-effort: any database error yields whatever was collected so far
    (possibly the empty string), and the connection is always closed.
    """
    schema_text = ""
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        for (table_name,) in tables:
            # Quote the identifier so unusual table names cannot break PRAGMA.
            columns = cursor.execute(f'PRAGMA table_info("{table_name}");').fetchall()
            col_names = [col[1] for col in columns]
            schema_text += f"{table_name}({', '.join(col_names)}) "
    except sqlite3.Error:
        # BUGFIX: was a bare `except:` that also swallowed KeyboardInterrupt
        # and hid the leaked connection below.
        pass
    finally:
        if conn is not None:
            conn.close()

    return schema_text
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ======================================================
|
| 165 |
+
# PROMPT
|
| 166 |
+
# ======================================================
|
| 167 |
+
PREFIX = "translate English to SQL:"
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def trim_schema(schema: str, max_chars: int = 1200) -> str:
    """Clamp a schema string to at most ``max_chars`` characters ("" for None)."""
    if schema is None:
        return ""
    text = str(schema)
    return text if len(text) <= max_chars else text[:max_chars]
|
| 177 |
+
def build_prompt(question: str, schema: str, use_schema: bool) -> str:
    """Assemble the T5 text prompt, optionally embedding a trimmed schema."""
    if use_schema:
        trimmed = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
        return f"{PREFIX}\n\nSchema:\n{trimmed}\n\nQuestion:\n{question}\n\nSQL:"
    return f"{PREFIX}\n\nQuestion:\n{question}\n\nSQL:"
|
| 183 |
+
|
| 184 |
+
def encode_prompt(question, schema, use_schema):
    """Tokenize the prompt to input ids on ``device``, budgeting tokens so the
    question and trailing "SQL:" always survive; only the schema is truncated.

    Returns a 1-D LongTensor of token ids, EOS-terminated when the tokenizer
    defines an EOS token.
    """
    # Never truncate the question; only truncate schema tokens if needed.
    if not use_schema:
        prompt = build_prompt(question, schema, use_schema=False)
        return tokenizer(prompt, return_tensors="pt", truncation=True).input_ids[0].to(device)

    schema = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
    prefix_schema = f"{PREFIX}\n\nSchema:\n"
    mid = "\n\nQuestion:\n"
    suffix = f"{question}\n\nSQL:"

    # Tokenize each prompt segment separately so each can be budgeted on its own.
    prefix_ids = tokenizer.encode(prefix_schema, add_special_tokens=False)
    schema_ids = tokenizer.encode(schema, add_special_tokens=False)
    mid_ids = tokenizer.encode(mid, add_special_tokens=False)
    suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)

    max_len = getattr(tokenizer, "model_max_length", 512)
    eos_id = tokenizer.eos_token_id
    # Reserve one slot for EOS when the tokenizer has one.
    max_without_eos = max_len - (1 if eos_id is not None else 0)

    # Ensure the question+SQL suffix always fits; truncate schema first.
    fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)
    if fixed_len > max_without_eos:
        # Extremely rare; clip the suffix (question) only if unavoidable.
        keep = max(0, max_without_eos - (len(prefix_ids) + len(mid_ids)))
        suffix_ids = suffix_ids[:keep]
        fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)

    # Whatever budget is left after the fixed segments goes to the schema.
    remaining_for_schema = max_without_eos - fixed_len
    if remaining_for_schema < 0:
        remaining_for_schema = 0
    schema_ids = schema_ids[:remaining_for_schema]

    ids = prefix_ids + schema_ids + mid_ids + suffix_ids
    # Defensive clamp; the budgeting above should already guarantee the length.
    ids = ids[:max_without_eos]
    if eos_id is not None:
        ids = ids + [eos_id]

    return torch.tensor(ids, dtype=torch.long).to(device)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# ======================================================
|
| 226 |
+
# SQL CONSTRAINED DECODING
|
| 227 |
+
# ======================================================
|
| 228 |
+
SQL_KEYWORDS = [
|
| 229 |
+
"select", "from", "where", "join", "inner", "left", "right",
|
| 230 |
+
"full", "outer", "on", "group", "by", "order", "having",
|
| 231 |
+
"limit", "distinct", "as", "and", "or", "not", "in", "is",
|
| 232 |
+
"null", "like", "between", "asc", "desc", "union",
|
| 233 |
+
"intersect", "except",
|
| 234 |
+
]
|
| 235 |
+
|
| 236 |
+
SQL_OPERATORS = ["*", ",", ".", "(", ")", "=", "<", ">", "!", "+", "-", "/", "%", "_"]
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _piece_token_str(tok: str) -> str:
|
| 240 |
+
# T5 SentencePiece: "▁" marks a leading space; strip it for char checks.
|
| 241 |
+
return tok.lstrip("▁")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def _precompute_always_allowed_token_ids():
    """Build the set of token ids that are always legal during SQL decoding.

    Includes special tokens, whitespace pieces, operator/punctuation pieces,
    purely numeric pieces, and every SQL keyword in common casings. Runs a
    full vocabulary scan, so it is computed once at module import time.
    """
    vocab_size = len(tokenizer)
    allowed = set()

    # Always allow special tokens.
    for tid in [tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.unk_token_id]:
        if tid is not None and tid >= 0:
            allowed.add(int(tid))

    # Allow whitespace/newlines in case they exist as pieces.
    for s in [" ", "\n", "\t"]:
        allowed.update(tokenizer.encode(s, add_special_tokens=False))

    # Allow operator/punctuation and purely numeric pieces broadly.
    op_chars = set("".join(SQL_OPERATORS))
    for tid in range(vocab_size):
        tok = tokenizer.convert_ids_to_tokens(tid)
        if not isinstance(tok, str) or not tok:
            continue
        piece = _piece_token_str(tok)
        if not piece:
            continue
        # BUGFIX(cleanup): the original had both `piece.isdigit()` and a
        # separate `all(ch.isdigit() for ch in piece)` branch — for non-empty
        # strings these are exactly equivalent, so one check suffices.
        if piece.isdigit() or all(ch in op_chars for ch in piece):
            allowed.add(tid)

    # Allow keyword pieces (lower/UPPER/Capitalized, with and without a
    # leading space so both mid-sentence and initial forms tokenize cleanly).
    for kw in SQL_KEYWORDS:
        for variant in {kw, kw.upper(), kw.capitalize()}:
            allowed.update(tokenizer.encode(" " + variant, add_special_tokens=False))
            allowed.update(tokenizer.encode(variant, add_special_tokens=False))

    return allowed
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
ALWAYS_ALLOWED_TOKEN_IDS = _precompute_always_allowed_token_ids()
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _schema_allowed_token_ids(table_names, column_names):
    """Extend the always-allowed SQL vocabulary with this DB's identifiers."""
    allowed = set(ALWAYS_ALLOWED_TOKEN_IDS)

    def _add_identifier(name: str):
        if not name:
            return
        # Whole identifier in several casings, plus underscore/space-split parts.
        variants = {name, name.lower(), name.upper()}
        variants.update(p for p in re.split(r"[_\s]+", name) if p)
        for v in variants:
            allowed.update(tokenizer.encode(" " + v, add_special_tokens=False))
            allowed.update(tokenizer.encode(v, add_special_tokens=False))

    for identifier in list(table_names) + list(column_names):
        _add_identifier(identifier)

    return allowed
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class SQLVocabularyLogitsProcessor(LogitsProcessor):
    """Masks generation logits so only whitelisted token ids can be sampled.

    Applies a cached additive bias: 0 for allowed ids, -inf for everything
    else. The bias tensor is rebuilt whenever the incoming scores' device,
    dtype, or vocabulary size changes.
    """

    def __init__(self, allowed_token_ids):
        self.allowed_token_ids = {int(i) for i in allowed_token_ids if int(i) >= 0}
        self._bias = None
        self._bias_vocab_size = None

    def _get_bias(self, scores: torch.Tensor) -> torch.Tensor:
        """Return (rebuilding if stale) the -inf/0 mask matching ``scores``."""
        vocab_size = int(scores.shape[-1])
        cached = self._bias
        stale = (
            cached is None
            or cached.device != scores.device
            or cached.dtype != scores.dtype
            or self._bias_vocab_size != vocab_size
        )
        if stale:
            mask = torch.full(
                (vocab_size,), float("-inf"), device=scores.device, dtype=scores.dtype
            )
            for token_id in self.allowed_token_ids:
                if token_id < vocab_size:
                    mask[token_id] = 0.0
            self._bias = mask
            self._bias_vocab_size = vocab_size
        return self._bias

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # -inf on disallowed ids removes them from the softmax entirely.
        return scores + self._get_bias(scores)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
_DB_VOCAB_CACHE = {}
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def get_db_tables_columns(db_path: str):
    """Return (table_names, column_names) for a SQLite DB, memoized by path.

    Best-effort: an unreadable database yields empty lists. Results are cached
    in the module-level ``_DB_VOCAB_CACHE``.
    """
    if db_path in _DB_VOCAB_CACHE:
        return _DB_VOCAB_CACHE[db_path]
    tables, cols = [], []
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        for (tname,) in cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        ).fetchall():
            if not tname:
                continue
            tables.append(tname)
            try:
                for row in cur.execute(f'PRAGMA table_info("{tname}")').fetchall():
                    if row and isinstance(row[1], str):
                        cols.append(row[1])
            except sqlite3.Error:
                # A single broken table should not discard the others.
                continue
    except sqlite3.Error:
        # Schema vocabulary is an optimization only; never crash training.
        pass
    finally:
        # BUGFIX: the original leaked the connection when an exception fired
        # before conn.close() was reached.
        if conn is not None:
            conn.close()
    _DB_VOCAB_CACHE[db_path] = (tables, cols)
    return tables, cols
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# ======================================================
|
| 366 |
+
# PPO CONFIG (stable learning)
|
| 367 |
+
# ======================================================
|
| 368 |
+
ppo_config = PPOConfig(
|
| 369 |
+
learning_rate=2e-5, # was 1e-6 → model could not move
|
| 370 |
+
batch_size=8, # better gradient estimate
|
| 371 |
+
mini_batch_size=2,
|
| 372 |
+
gradient_accumulation_steps=2, # stable updates on small data
|
| 373 |
+
ppo_epochs=1,
|
| 374 |
+
|
| 375 |
+
# --- KL control (MOST IMPORTANT FIX) ---
|
| 376 |
+
init_kl_coef=0.05, # reduce punishment
|
| 377 |
+
target_kl=0.15, # relax constraint to avoid skipped updates
|
| 378 |
+
adap_kl_ctrl=True,
|
| 379 |
+
|
| 380 |
+
# --- stability ---
|
| 381 |
+
cliprange=0.25,
|
| 382 |
+
cliprange_value=0.25,
|
| 383 |
+
whiten_rewards=True,
|
| 384 |
+
kl_penalty="kl",
|
| 385 |
+
max_grad_norm=1.0,
|
| 386 |
+
)
|
| 387 |
+
trainer = PPOTrainer(
|
| 388 |
+
config=ppo_config,
|
| 389 |
+
model=model,
|
| 390 |
+
ref_model=ref_model,
|
| 391 |
+
tokenizer=tokenizer,
|
| 392 |
+
)
|
| 393 |
+
optimizer = trainer.optimizer
|
| 394 |
+
|
| 395 |
+
# Provide `.device` attribute for the supervised anchor helper.
|
| 396 |
+
try:
|
| 397 |
+
model.device = torch.device(device)
|
| 398 |
+
except Exception:
|
| 399 |
+
pass
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
# ======================================================
|
| 403 |
+
# GENERATION (schema-constrained decoding)
|
| 404 |
+
# ======================================================
|
| 405 |
+
generation_kwargs = dict(
|
| 406 |
+
max_new_tokens=64, # 128 causes garbage SQL loops
|
| 407 |
+
|
| 408 |
+
do_sample=True,
|
| 409 |
+
temperature=0.9, # encourage exploration
|
| 410 |
+
top_p=0.95,
|
| 411 |
+
top_k=100,
|
| 412 |
+
|
| 413 |
+
repetition_penalty=1.1, # prevents SELECT SELECT SELECT
|
| 414 |
+
no_repeat_ngram_size=3,
|
| 415 |
+
|
| 416 |
+
num_beams=1,
|
| 417 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 418 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 419 |
+
)
|
| 420 |
+
# ======================================================
|
| 421 |
+
# TRAIN LOOP
|
| 422 |
+
# ======================================================
|
| 423 |
+
print("Starting RL training 🚀")
|
| 424 |
+
|
| 425 |
+
query_buffer, response_buffer, reward_buffer, gold_buffer = [], [], [], []
|
| 426 |
+
query_text_buffer = []
|
| 427 |
+
|
| 428 |
+
best_reward = -999999
|
| 429 |
+
best_epoch = -1
|
| 430 |
+
|
| 431 |
+
def _is_parsable_sql(sql: str) -> bool:
    """Cheap gate: does this look like a SELECT query that sqlparse accepts?"""
    candidate = (sql or "").strip()
    if not candidate:
        return False
    upper = candidate.upper()
    if "SELECT" not in upper or "FROM" not in upper:
        return False
    if sqlparse is None:
        # sqlparse unavailable: fall back to the keyword heuristic alone.
        return True
    try:
        return bool(sqlparse.parse(candidate))
    except Exception:
        return False
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def _pad_2d(seqs, pad_id: int):
    """Right-pad 1-D tensors to a common length; return (ids, attention_mask)."""
    width = max(int(s.numel()) for s in seqs)
    ids = torch.full((len(seqs), width), int(pad_id), dtype=torch.long, device=device)
    mask = torch.zeros((len(seqs), width), dtype=torch.long, device=device)
    for row, seq in enumerate(seqs):
        length = int(seq.numel())
        ids[row, :length] = seq.to(device)
        mask[row, :length] = 1
    return ids, mask
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def _shift_right(labels: torch.Tensor, start_id: int) -> torch.Tensor:
|
| 458 |
+
dec = labels.clone()
|
| 459 |
+
dec[:, 1:] = labels[:, :-1]
|
| 460 |
+
dec[:, 0] = int(start_id)
|
| 461 |
+
return dec
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def safe_get_kl(stats):
    """Pull the first KL-looking scalar out of a PPO stats dict, else None."""
    if not isinstance(stats, dict):
        return None
    for key, value in stats.items():
        if "kl" not in str(key).lower():
            continue
        try:
            return float(value.item() if hasattr(value, "item") else value)
        except Exception:
            # Unconvertible value: behave like the original and give up.
            return None
    return None
|
| 475 |
+
|
| 476 |
+
def supervised_anchor_step(model, tokenizer, queries, gold_sqls, weight=0.05):
    """Accumulate a small teacher-forcing gradient on gold SQL during PPO.

    Calls ``backward()`` per example (gradients accumulate); the caller is
    responsible for ``optimizer.step()`` / ``zero_grad()``.

    Args:
        model: seq2seq policy (forward returns logits as ``outputs[0]``).
        tokenizer: tokenizer shared with the policy.
        queries: prompt strings, aligned with gold_sqls.
        gold_sqls: gold SQL target strings.
        weight: scale applied to each example's loss before backward.

    Returns:
        Sum of the (unweighted) per-example cross-entropy losses.
    """
    model.train()
    total_loss = 0.0

    # FIX: the original fed labels[:, :-1] as decoder input and labels[:, 1:]
    # as targets, so the first gold token was never supervised and no decoder
    # start token was used. Prepend the model's start token (pad as fallback)
    # and supervise every label position instead.
    start_id = getattr(getattr(model, "config", None), "decoder_start_token_id", None)
    if start_id is None:
        start_id = tokenizer.pad_token_id

    for q, gold in zip(queries, gold_sqls):

        enc = tokenizer(q, return_tensors="pt", truncation=True).to(model.device)
        dec = tokenizer(text_target=gold, return_tensors="pt", truncation=True)

        labels = dec.input_ids.to(model.device)

        # Teacher forcing: decoder sees [start, y_0..y_{T-2}], predicts [y_0..y_{T-1}].
        start_col = torch.full_like(labels[:, :1], int(start_id))
        decoder_input_ids = torch.cat([start_col, labels[:, :-1]], dim=1).contiguous()
        target_ids = labels.contiguous()

        outputs = model(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            decoder_input_ids=decoder_input_ids,
        )

        logits = outputs[0]

        vocab_size = logits.size(-1)
        # Pad positions in the targets are excluded from the loss.
        loss = F.cross_entropy(
            logits.view(-1, vocab_size),
            target_ids.view(-1),
            ignore_index=tokenizer.pad_token_id,
        )

        (loss * weight).backward()
        total_loss += loss.item()

    return total_loss
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
@torch.no_grad()
def _estimate_policy_entropy(query_tensors, response_tensors) -> torch.Tensor:
    """
    Per-sample mean token entropy of the policy over its sampled responses.
    Used as a small bonus to reduce repetition collapse.
    """
    pad = int(tokenizer.pad_token_id)
    src_ids, src_mask = _pad_2d(query_tensors, pad)
    tgt_ids, tgt_mask = _pad_2d(response_tensors, pad)

    cfg = model.pretrained_model.config
    start = int(getattr(cfg, "decoder_start_token_id", pad))
    decoder_inputs = _shift_right(tgt_ids, start)

    outputs = model.pretrained_model(
        input_ids=src_ids,
        attention_mask=src_mask,
        decoder_input_ids=decoder_inputs,
        use_cache=False,
    )
    log_probs = torch.log_softmax(outputs.logits, dim=-1)
    token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # [B, T]
    # Average only over real (non-pad) response positions.
    lengths = tgt_mask.sum(dim=-1).clamp_min(1)
    return (token_entropy * tgt_mask).sum(dim=-1) / lengths  # [B]
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def _repeat_penalty(response_tensor: torch.Tensor) -> float:
|
| 540 |
+
"""
|
| 541 |
+
Penalize repetition to avoid 'SELECT SELECT SELECT' collapse.
|
| 542 |
+
Simple heuristic: consecutive duplicate token ratio + low-unique-token ratio.
|
| 543 |
+
"""
|
| 544 |
+
ids = response_tensor.detach().tolist()
|
| 545 |
+
n = len(ids)
|
| 546 |
+
if n <= 1:
|
| 547 |
+
return 0.0
|
| 548 |
+
consec_dup = 0
|
| 549 |
+
for i in range(1, n):
|
| 550 |
+
if ids[i] == ids[i - 1]:
|
| 551 |
+
consec_dup += 1
|
| 552 |
+
unique_ratio = len(set(ids)) / n
|
| 553 |
+
consec_ratio = consec_dup / (n - 1)
|
| 554 |
+
# Higher penalty when low unique + high consecutive duplicates
|
| 555 |
+
return float(0.5 * consec_ratio + 0.5 * (1.0 - unique_ratio))
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def _supervised_anchor_step(query_tensors, gold_sql_texts, weight: float = 0.05) -> None:
    """
    Small teacher-forcing step on gold SQL to anchor grammar during PPO.
    Runs only if PPOTrainer exposes (accelerator, optimizer); otherwise a no-op.

    Args:
        query_tensors: list of 1-D encoder input-id tensors, one per sample.
        gold_sql_texts: gold SQL strings aligned with query_tensors.
        weight: scale applied to the supervised loss before the update.
    """
    if not gold_sql_texts:
        return
    accelerator = getattr(trainer, "accelerator", None)
    optimizer = getattr(trainer, "optimizer", None)
    if accelerator is None or optimizer is None:
        return

    pad_id = int(tokenizer.pad_token_id)
    enc_ids, enc_attn = _pad_2d(query_tensors, pad_id)

    # Tokenize gold SQL targets (decoder side); an empty gold string falls
    # back to a trivial-but-valid statement so batch shapes stay consistent.
    gold_ids = []
    for s in gold_sql_texts:
        g = (s or "").strip()
        if not g:
            g = "SELECT 1"
        ids = tokenizer.encode(g, add_special_tokens=False)[:256]
        if tokenizer.eos_token_id is not None:
            ids = ids + [int(tokenizer.eos_token_id)]
        gold_ids.append(torch.tensor(ids, dtype=torch.long))

    dec_ids, dec_attn = _pad_2d(gold_ids, pad_id)
    labels = dec_ids.clone()
    labels[dec_attn == 0] = -100  # pads excluded from the CE loss

    # PEFT model forward supports labels -> returns loss
    out = model.pretrained_model(
        input_ids=enc_ids,
        attention_mask=enc_attn,
        labels=labels,
        use_cache=False,
    )
    loss = out.loss * float(weight)

    # FIX: the original used a conditional *expression* purely for its side
    # effect (`optimizer.zero_grad(...) if hasattr(...) else None`); use a
    # plain guarded statement instead.
    if hasattr(optimizer, "zero_grad"):
        optimizer.zero_grad(set_to_none=True)
    accelerator.backward(loss)
    optimizer.step()
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
def _curriculum_allows(gold_sql: str, epoch_num: int) -> bool:
    """Curriculum gate: epoch 1 single-table only; epoch 2 adds joins;
    epoch 3+ allows everything. Set operations are deferred to epoch 3+."""
    upper_sql = (gold_sql or "").upper()
    joined = "JOIN" in upper_sql
    set_op = any(kw in upper_sql for kw in ("UNION", "INTERSECT", "EXCEPT"))
    is_single = len(extract_tables(gold_sql)) <= 1 and not joined

    if epoch_num == 1:
        return is_single and not set_op
    if epoch_num == 2:
        return (is_single or joined) and not set_op
    return True
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
# Main PPO loop: sample rollouts with replacement, gate updates on parsable
# SQL, buffer informative samples until a full PPO batch, then step + anchor.
for epoch in range(1, NUM_EPOCHS + 1):

    # Schema text is only added to prompts after the warm-up epochs.
    use_schema_this_epoch = USE_SCHEMA and (epoch > SCHEMA_WARMUP_EPOCHS)

    epoch_reward_sum = 0
    negative_rewards = 0
    partial_rewards = 0
    correct_rewards = 0

    total_considered = 0
    valid_sql_count = 0
    exec_correct_count = 0
    table_overlap_sum = 0.0
    column_overlap_sum = 0.0
    kl_values = []

    for step in range(1, total_steps + 1):

        example = dataset[random.randrange(len(dataset))]

        question = example["question"]
        gold_sql = example["query"]
        db_id = example["db_id"]
        db_path = get_db_path(db_id)

        # NOTE: sampling-with-replacement provides more rollouts per epoch.

        schema = get_db_schema(db_path)
        question_text = build_prompt(question, schema, use_schema_this_epoch)
        query_tensor = encode_prompt(question, schema, use_schema_this_epoch)

        # ----- generate -----
        # Constrain decoding to SQL keywords plus this DB's identifiers.
        table_names, column_names = get_db_tables_columns(db_path)
        allowed_ids = _schema_allowed_token_ids(table_names, column_names)
        logits_processor = LogitsProcessorList([SQLVocabularyLogitsProcessor(allowed_ids)])

        response = trainer.generate([query_tensor], logits_processor=logits_processor, **generation_kwargs)[0]
        response_tensor = response.squeeze(0)[:MAX_OUTPUT_TOKENS]

        pred_sql = tokenizer.decode(response_tensor.cpu(), skip_special_tokens=True)

        total_considered += 1

        # PPO must optimize ONLY when SQL parses successfully.
        if not _is_parsable_sql(pred_sql):
            negative_rewards += 1
            continue

        # Reject generations shorter than 6 tokens.
        if int(response_tensor.numel()) < 6:
            negative_rewards += 1
            continue

        # ----- reward -----
        reward_value = execution_reward(pred_sql, db_path, gold_sql)

        # SQL validity gate: if invalid/unparsable -> reward_value is None -> skip PPO entirely.
        if reward_value is None:
            if step % 100 == 0:
                ratio = valid_sql_count / max(total_considered, 1)
                print(f"\nLearning ratio: {valid_sql_count}/{total_considered} ({ratio:.3f})")
                if ratio < 0.15:
                    print("MODEL COLLAPSING")
            continue

        # Clip rewards to [-1, 1]
        reward_value = float(max(-1.0, min(1.0, reward_value)))
        # Penalize repetition in decoded output (token-level heuristic).
        reward_value = float(max(-1.0, min(1.0, reward_value - 0.2 * _repeat_penalty(response_tensor))))
        # Keep rewards on CPU for normalization; move to device only for trainer.step().
        reward_tensor = torch.tensor(reward_value, dtype=torch.float32)

        epoch_reward_sum += reward_value

        # ----- metrics -----
        # "Valid sample" means reward is not None (parsable SQL).
        valid_sql_count += 1

        pred_tables = extract_tables(pred_sql)
        gold_tables = extract_tables(gold_sql)
        pred_cols = extract_columns(pred_sql)
        gold_cols = extract_columns(gold_sql)

        if len(gold_tables) > 0:
            table_overlap_sum += len(pred_tables & gold_tables) / max(len(gold_tables), 1)
        if len(gold_cols) > 0:
            column_overlap_sum += len(pred_cols & gold_cols) / max(len(gold_cols), 1)

        # execution_reward returns 1.0 for correct execution result.
        if reward_value >= 1.0:
            exec_correct_count += 1

        if reward_value <= -1.0:
            negative_rewards += 1
        elif reward_value >= 1.0:
            correct_rewards += 1
        else:
            partial_rewards += 1

        # Train only on informative samples:
        # - invalid SQL already skipped (reward is None)
        # - very small magnitude signal skipped
        if abs(reward_value) < 0.1:
            continue

        query_buffer.append(query_tensor)
        response_buffer.append(response_tensor)
        reward_buffer.append(reward_tensor)
        gold_buffer.append(gold_sql)
        query_text_buffer.append(question_text)

        # ----- PPO update -----
        if len(query_buffer) == ppo_config.batch_size:
            # move rewards to device
            reward_buffer = [r.to(device) for r in reward_buffer]

            # run PPO step
            stats = trainer.step(query_buffer, response_buffer, reward_buffer)

            # log KL safely (no control logic)
            kl = safe_get_kl(stats)
            if kl is not None:
                kl_values.append(kl)

            # --- supervised anchor to prevent grammar collapse ---
            # NOTE(review): `optimizer` is not defined in this section of the
            # file — presumably created earlier (or it's trainer.optimizer);
            # confirm, otherwise this raises NameError on the first PPO batch.
            supervised_anchor_step(model, tokenizer, query_text_buffer, gold_buffer, weight=0.05)
            optimizer.step()
            optimizer.zero_grad()

            # reset buffers
            query_buffer, response_buffer, reward_buffer, gold_buffer = [], [], [], []
            query_text_buffer = []

        # ----- learning ratio logging -----
        if step % 100 == 0:
            ratio = valid_sql_count / max(total_considered, 1)
            print(f"\nLearning ratio: {valid_sql_count}/{total_considered} ({ratio:.3f})")
            if ratio < 0.15:
                print("MODEL COLLAPSING")
                # Increase KL coefficient dynamically when valid_sql_rate drops.
                try:
                    if hasattr(trainer, "kl_ctl") and hasattr(trainer.kl_ctl, "value"):
                        trainer.kl_ctl.value *= 1.5
                        print(f"Increasing KL coef -> {trainer.kl_ctl.value:.4f}")
                except Exception:
                    pass

        # ----- logging -----
        if step % LOG_EVERY == 0:
            avg_reward = epoch_reward_sum / step
            print("\n---------------------------")
            print(f"Epoch {epoch}/{NUM_EPOCHS} | Step {step}/{total_steps} | Avg Reward {avg_reward:.3f}")
            print("DB:", db_id)
            print("Q:", question)
            print("SQL:", pred_sql)
            print("Reward:", reward_value)

    # epoch stats
    print(f"\nEpoch {epoch} stats:")
    print("negative:", negative_rewards)
    print("partial:", partial_rewards)
    print("correct:", correct_rewards)

    denom = max(total_considered, 1)
    print("\nEpoch metrics:")
    print(f"execution_accuracy: {exec_correct_count/denom:.3f}")
    print(f"valid_sql_rate: {valid_sql_count/denom:.3f}")
    print(f"table_match_rate: {table_overlap_sum/denom:.3f}")
    print(f"column_match_rate: {column_overlap_sum/denom:.3f}")
    print(f"avg_reward: {epoch_reward_sum/max(denom,1):.3f}")
    if kl_values:
        avg_kl = sum(kl_values) / max(len(kl_values), 1)
        print(f"avg_kl: {avg_kl:.3f}")
        if avg_kl < -8:
            print("\nKL collapse guard triggered (avg_kl < -8). Stopping early.")
            break

    # 🎯 FIXED: Removed the code that saved intermediate checkpoints at the end of each epoch

    # Only save if this epoch is the best one so far
    epoch_avg_reward = epoch_reward_sum / max(denom, 1)
    if epoch_avg_reward > best_reward:
        best_reward = epoch_avg_reward
        best_epoch = epoch

        print(f"\nNew best model at epoch {epoch} with reward {best_reward:.4f}")

        # 🎯 FIXED: Save directly to checkpoints/rlhf_t5_best, overwriting if needed
        # NOTE(review): `output_dir` is defined earlier in the file (not
        # visible in this chunk) — verify it points at the intended checkpoint.
        os.makedirs(output_dir, exist_ok=True)

        trainer.model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)


print(f"\nTraining finished.")
print(f"Best epoch: {best_epoch}")
print(f"Best reward: {best_reward:.4f}")
print(f"Best model saved at: {output_dir}")
|
src/train_rl_bart.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =========================================================
|
| 2 |
+
# RLHF TRAINING FOR TEXT2SQL (OPTIMIZED PPO VERSION - BART)
|
| 3 |
+
# =========================================================
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
+
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
|
| 9 |
+
from peft import PeftModel
|
| 10 |
+
import os, sys, sqlite3, re, random
|
| 11 |
+
|
| 12 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 13 |
+
from execution_reward import execution_reward, extract_tables, extract_columns
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import sqlparse # gate PPO updates on parsable SQL only
|
| 17 |
+
except Exception: # pragma: no cover
|
| 18 |
+
sqlparse = None
|
| 19 |
+
|
| 20 |
+
# ======================================================
# DEVICE
# ======================================================
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Prefer Apple MPS, then CUDA, then CPU.
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# ======================================================
# TRAINING SETTINGS (🚀 OPTIMIZED FOR SPEED)
# ======================================================
NUM_EPOCHS = 10  # Increased to compensate for faster epochs
LOG_EVERY = 5  # Print logs much more frequently
MAX_SCHEMA_CHARS = 1500  # schema text is truncated to this many characters
MAX_OUTPUT_TOKENS = 48  # 🚀 Down from 64. 95% of Spider SQL is <40 tokens.
ROLLOUTS_PER_EPOCH = 256  # 🚀 Down from 1024. Epochs will finish 4x faster!

# ======================================================
# PATHS
# ======================================================
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DB_ROOT = os.path.join(PROJECT_ROOT, "data/database")

# 🎯 Strict Input: Load strictly from your SFT BART checkpoint
ADAPTER_PATH = os.path.join(PROJECT_ROOT, "checkpoints/sft_best_bart_2")

# 🎯 Strict Output: Save strictly to rl_best_bart
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "checkpoints/rl_best_bart")

# Base checkpoint is overridable via the BASE_MODEL env var.
BASE_MODEL = os.environ.get("BASE_MODEL", "facebook/bart-base")

# Fail fast if the SFT LoRA adapter is missing.
if not os.path.exists(ADAPTER_PATH):
    raise RuntimeError(f"❌ No valid LoRA adapter found at: {ADAPTER_PATH}")

print("Loading base:", BASE_MODEL)
print("Loading adapter:", ADAPTER_PATH)
|
| 55 |
+
|
| 56 |
+
# ======================================================
# TOKENIZER
# ======================================================
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ======================================================
# LOAD PPO MODEL
# ======================================================
# Policy: BART base + TRL value head, with the SFT LoRA adapter trainable.
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32
).to(device)

model.pretrained_model = PeftModel.from_pretrained(
    model.pretrained_model,
    ADAPTER_PATH,
    is_trainable=True
)

# ======================================================
# LOAD REFERENCE MODEL (FROZEN)
# ======================================================
# Same base + adapter, fully frozen — PPO's KL penalty is taken against it.
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32
).to(device)

ref_model.pretrained_model = PeftModel.from_pretrained(
    ref_model.pretrained_model,
    ADAPTER_PATH,
    is_trainable=False
)

ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad = False

# ======================================================
# TRAINABLE PARAMS — ONLY LoRA + VALUE HEAD
# ======================================================
for name, p in model.named_parameters():
    if "lora_" in name or "v_head" in name:
        p.requires_grad = True
    else:
        p.requires_grad = False

model.train()

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable}/{total} ({100*trainable/total:.2f}%)")

# Disable the KV cache for training-time forward passes.
model.config.use_cache = False
ref_model.config.use_cache = False
|
| 112 |
+
|
| 113 |
+
# ======================================================
# DATASET
# ======================================================
print("Loading Spider subset...")
random.seed(0)  # deterministic sampling order across runs

# Databases used for RL rollouts; only these db_ids are kept from Spider.
TRAIN_DBS = [
    # already trained
    "flight_1","student_assessment","store_1","bike_1","book_2","chinook_1",
    "academic","aircraft","car_1","cinema","club_1","csu_1",

    # medium difficulty (NEW)
    "college_1","college_2","company_1","company_employee",
    "customer_complaints","department_store","employee_hire_evaluation",
    "museum_visit","products_for_hire","restaurant_1",
    "school_finance","shop_membership","small_bank_1",
    "soccer_1","student_1","tvshow","voter_1","world_1"
]
dataset = load_dataset("spider", split="train")
dataset = dataset.filter(lambda x: x["db_id"] in TRAIN_DBS)
|
| 133 |
+
|
| 134 |
+
def valid_example(x):
    """Keep only examples whose question has 5–40 whitespace-separated words."""
    word_count = len(x["question"].split())
    return 5 <= word_count <= 40
|
| 136 |
+
|
| 137 |
+
# Drop questions outside the 5–40 word range before sampling rollouts.
dataset = dataset.filter(valid_example)
print("Filtered dataset size:", len(dataset))
|
| 139 |
+
|
| 140 |
+
def sample_example():
    """Draw one training example uniformly at random (with replacement)."""
    idx = random.randrange(len(dataset))
    return dataset[idx]
|
| 142 |
+
|
| 143 |
+
# ======================================================
|
| 144 |
+
# DB UTILITIES
|
| 145 |
+
# ======================================================
|
| 146 |
+
def get_db_path(db_id):
    """Resolve a Spider db_id to its on-disk SQLite file under DB_ROOT."""
    filename = f"{db_id}.sqlite"
    return os.path.join(DB_ROOT, db_id, filename)
|
| 148 |
+
|
| 149 |
+
_SCHEMA_CACHE = {}


def get_db_schema_cached(db_path):
    """Return (and memoize) a compact textual schema for a SQLite database.

    One line per table, formatted ``table(col1, col2, ...)``; the whole text
    is stripped. An unreadable database yields (and caches) an empty string.

    Fixes vs. the original:
      - ``except:`` narrowed to ``sqlite3.Error`` so KeyboardInterrupt etc.
        are no longer swallowed;
      - connection is closed in a ``finally`` block, so it can no longer leak
        when an error fires mid-iteration;
      - the table identifier is quoted inside the PRAGMA to survive unusual
        table names.
    """
    if db_path in _SCHEMA_CACHE:
        return _SCHEMA_CACHE[db_path]

    schema_text = ""
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        tables = cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
        for (table_name,) in tables:
            columns = cursor.execute(f'PRAGMA table_info("{table_name}");').fetchall()
            col_names = [col[1] for col in columns]
            schema_text += f"{table_name}({', '.join(col_names)})\n"
    except sqlite3.Error:
        # Best-effort: an unreadable DB is represented by an empty schema.
        pass
    finally:
        if conn is not None:
            conn.close()

    _SCHEMA_CACHE[db_path] = schema_text.strip()
    return _SCHEMA_CACHE[db_path]
|
| 172 |
+
|
| 173 |
+
# ======================================================
|
| 174 |
+
# PROMPT
|
| 175 |
+
# ======================================================
|
| 176 |
+
def trim_schema(schema: str, max_chars: int = 1200) -> str:
    """Coerce schema to str and cap its length at max_chars (None -> "")."""
    if schema is None:
        return ""
    text = str(schema)
    return text if len(text) <= max_chars else text[:max_chars]
|
| 183 |
+
|
| 184 |
+
def build_prompt(question: str, schema: str) -> str:
    """Compose the seq2seq prompt: schema block, instruction, question, SQL cue."""
    clipped = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
    return (
        f"Database Schema:\n{clipped}\n\n"
        f"Translate English to SQL:\n{question}\nSQL:\n"
    )
|
| 187 |
+
|
| 188 |
+
# ======================================================
# PPO CONFIG (STABLE POLICY LEARNING)
# ======================================================
# NOTE(review): field names (target_kl, whiten_rewards, kl_penalty) match the
# TRL 0.x PPOConfig API — confirm against the pinned trl version.
ppo_config = PPOConfig(
    learning_rate=3e-6,  # slower = prevents policy jump (very important)
    batch_size=8,
    mini_batch_size=4,  # good size, keep this
    gradient_accumulation_steps=2,

    ppo_epochs=2,  # smoother policy update (was 1 → unstable)

    # ---- KL CONTROL (main fix for negative KL) ----
    init_kl_coef=0.1,
    target_kl=0.08,  # 0.02 was too strict → caused oscillation
    adap_kl_ctrl=True,

    # ---- CLIPPING ----
    cliprange=0.15,
    cliprange_value=0.15,

    # ---- REWARD STABILITY ----
    whiten_rewards=True,  # VERY IMPORTANT for binary execution reward
    kl_penalty="kl",

    # ---- GRADIENT SAFETY ----
    max_grad_norm=0.3,
)
trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)

# Best-effort: pin a .device attribute on the wrapped model for later use.
try:
    model.device = torch.device(device)
except Exception:
    pass

# ======================================================
# GENERATION CONFIG
# ======================================================
generation_kwargs = dict(
    max_new_tokens=MAX_OUTPUT_TOKENS,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
|
| 238 |
+
# ======================================================
# TRAIN LOOP (BATCHED & OPTIMIZED)
# ======================================================
print("Starting RL training 🚀 (BART PPO Optimized)")

best_reward = -1e9
global_ppo_step = 0
model.train()

for epoch in range(1, NUM_EPOCHS + 1):
    epoch_reward_sum = 0
    valid_sql_count = 0
    total_seen = 0

    # One iteration = one PPO batch of `ppo_config.batch_size` rollouts.
    for step in range(0, ROLLOUTS_PER_EPOCH, ppo_config.batch_size):

        batch_prompts = []
        batch_meta = []

        for _ in range(ppo_config.batch_size):
            example = sample_example()
            question = example["question"]
            gold_sql = example["query"]
            db_id = example["db_id"]
            db_path = get_db_path(db_id)

            schema = get_db_schema_cached(db_path)
            prompt = build_prompt(question, schema)

            batch_prompts.append(prompt)
            batch_meta.append((question, gold_sql, db_path, db_id))

        encoded_inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
            pad_to_multiple_of=8
        ).to(device)

        # NOTE(review): these query tensors keep their right padding; TRL's
        # step() receives them as-is — confirm this matches the trl version's
        # expectations.
        query_tensors = [encoded_inputs.input_ids[i] for i in range(ppo_config.batch_size)]

        # 🎯 BYPASS: Native model.generate to prevent TRL's truncation crash
        with torch.no_grad():
            response_tensors_raw = model.generate(
                input_ids=encoded_inputs.input_ids,
                attention_mask=encoded_inputs.attention_mask,
                **generation_kwargs
            )

        batch_rewards = []
        batch_responses_text = []
        response_tensors = []

        for i in range(ppo_config.batch_size):
            resp = response_tensors_raw[i]

            # 🎯 Strip padding safely so TRL's mask calculation never crashes
            non_pad_mask = resp != tokenizer.pad_token_id
            if non_pad_mask.sum() == 0:
                # All-pad generation: substitute a single EOS token.
                resp = torch.tensor([tokenizer.eos_token_id], device=device)
                non_pad_mask = resp != tokenizer.pad_token_id

            # Keep everything up to the last non-pad token.
            valid_len = non_pad_mask.nonzero()[-1].item() + 1
            clean_resp = resp[:valid_len]
            response_tensors.append(clean_resp)

            response = tokenizer.decode(clean_resp, skip_special_tokens=True)
            batch_responses_text.append(response)

            question, gold_sql, db_path, db_id = batch_meta[i]
            total_seen += 1

            # No SELECT at all → flat minimum reward.
            if "select" not in response.lower():
                batch_rewards.append(torch.tensor(-1.0, dtype=torch.float32).to(device))
                continue

            reward = execution_reward(response, db_path, gold_sql)
            if reward is None:
                batch_rewards.append(torch.tensor(-1.0, dtype=torch.float32).to(device))
                continue

            reward = float(reward)

            # Shaping bonuses for table/column overlap with the gold query.
            pred_tables = extract_tables(response)
            gold_tables = extract_tables(gold_sql)
            if len(gold_tables) > 0:
                reward += 0.25 * (len(pred_tables & gold_tables) / len(gold_tables))

            pred_cols = extract_columns(response)
            gold_cols = extract_columns(gold_sql)
            if len(gold_cols) > 0:
                reward += 0.15 * (len(pred_cols & gold_cols) / len(gold_cols))

            reward = max(-1.0, min(1.0, reward))
            batch_rewards.append(torch.tensor(reward, dtype=torch.float32).to(device))

            epoch_reward_sum += reward
            valid_sql_count += 1

        # ---------- PPO UPDATE ----------
        # Any TRL-internal failure skips this batch rather than aborting.
        try:
            trainer.step(query_tensors, response_tensors, batch_rewards)
            global_ppo_step += 1
        except Exception as e:
            print("⚠️ PPO skipped:", e)
            continue

        # ---------- LOG ----------
        if step % (LOG_EVERY * ppo_config.batch_size) == 0 and valid_sql_count > 0:
            print("\n---------------------------")
            print(f"Epoch {epoch}/{NUM_EPOCHS} Step {step}/{ROLLOUTS_PER_EPOCH} | Global Update {global_ppo_step}")
            print("Avg Reward:", round(epoch_reward_sum/valid_sql_count,3))
            print("Valid SQL:", valid_sql_count,"/",total_seen)

            sample_idx = random.randint(0, ppo_config.batch_size - 1)
            print("DB:", batch_meta[sample_idx][3])
            print("Q:", batch_meta[sample_idx][0])
            print("SQL:", batch_responses_text[sample_idx])
            print("Reward:", round(batch_rewards[sample_idx].item(), 3))

    # ---------- SAVE ONLY THE BEST MODEL ----------
    avg_reward = epoch_reward_sum / max(valid_sql_count, 1)

    if avg_reward > best_reward:
        best_reward = avg_reward
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        model.save_pretrained(OUTPUT_DIR)
        tokenizer.save_pretrained(OUTPUT_DIR)

        print(f"\n✅ Saved BEST RLHF model for Epoch {epoch} (reward {best_reward:.3f}) at {OUTPUT_DIR}")
|
src/train_rl_codet5.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =========================================================
|
| 2 |
+
# RLHF TRAINING FOR TEXT2SQL (STABLE PPO VERSION)
|
| 3 |
+
# =========================================================
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
+
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
|
| 9 |
+
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
|
| 10 |
+
from peft import PeftModel
|
| 11 |
+
import os, sys, sqlite3, re, random
|
| 12 |
+
|
| 13 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 14 |
+
from execution_reward import execution_reward, extract_tables, extract_columns
|
| 15 |
+
try:
|
| 16 |
+
import sqlparse # gate PPO updates on parsable SQL only
|
| 17 |
+
except Exception: # pragma: no cover
|
| 18 |
+
sqlparse = None
|
| 19 |
+
|
| 20 |
+
# ======================================================
|
| 21 |
+
# DEVICE
|
| 22 |
+
# ======================================================
|
| 23 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 24 |
+
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
|
| 25 |
+
print("Using device:", device)
|
| 26 |
+
|
| 27 |
+
# ======================================================
|
| 28 |
+
# TRAINING SETTINGS
|
| 29 |
+
# ======================================================
|
| 30 |
+
NUM_EPOCHS = 15
|
| 31 |
+
LOG_EVERY = 20
|
| 32 |
+
USE_SCHEMA = True
|
| 33 |
+
SCHEMA_WARMUP_EPOCHS = 2
|
| 34 |
+
MAX_SCHEMA_CHARS = 1500
|
| 35 |
+
MAX_OUTPUT_TOKENS = 64 # 🚀 Speed up: Reduced max tokens
|
| 36 |
+
ROLLOUTS_PER_EPOCH = 1024
|
| 37 |
+
|
| 38 |
+
# ======================================================
|
| 39 |
+
# PATHS
|
| 40 |
+
# ======================================================
|
| 41 |
+
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 42 |
+
RL_MODEL_PATH = os.path.join(PROJECT_ROOT, "outputs/rlhf_text2sql")
|
| 43 |
+
output_dir = RL_MODEL_PATH
|
| 44 |
+
DB_ROOT = os.path.join(PROJECT_ROOT, "data/database")
|
| 45 |
+
|
| 46 |
+
# Explicit resume checkpoint
|
| 47 |
+
RESUME_CHECKPOINT = os.path.join(PROJECT_ROOT, "checkpoints/milestone_before_more_dbs")
|
| 48 |
+
|
| 49 |
+
ADAPTER_PATH = os.path.abspath(os.path.join(PROJECT_ROOT, "checkpoints/sft_adapter_codet5"))
|
| 50 |
+
FALLBACK_ADAPTER_PATH = ADAPTER_PATH
|
| 51 |
+
FALLBACK_ADAPTER_PATH_2 = os.path.join(PROJECT_ROOT, "checkpoints")
|
| 52 |
+
|
| 53 |
+
BASE_MODEL = os.environ.get("BASE_MODEL", "Salesforce/codet5-base")
|
| 54 |
+
|
| 55 |
+
# ======================================================
|
| 56 |
+
# LOAD MODEL (LoRA)
|
| 57 |
+
# ======================================================
|
| 58 |
+
def find_valid_adapter(path_candidates):
    """Return the first directory that contains a LoRA ``adapter_config.json``.

    The explicit resume checkpoint always wins over the candidate list;
    matching candidates are returned as absolute paths.  Returns ``None``
    when no adapter is found anywhere.
    """
    marker = "adapter_config.json"

    # 🚀 SAFETY & RESUME: the milestone checkpoint takes priority.
    if os.path.exists(os.path.join(RESUME_CHECKPOINT, marker)):
        print(f"\n✅ Resuming RL training from checkpoint: {RESUME_CHECKPOINT}\n")
        return RESUME_CHECKPOINT

    return next(
        (
            os.path.abspath(candidate)
            for candidate in path_candidates
            if candidate and os.path.exists(os.path.join(candidate, marker))
        ),
        None,
    )
|
| 68 |
+
|
| 69 |
+
print("Loading base:", BASE_MODEL)
|
| 70 |
+
|
| 71 |
+
ADAPTER_PATH = find_valid_adapter([
|
| 72 |
+
ADAPTER_PATH,
|
| 73 |
+
FALLBACK_ADAPTER_PATH,
|
| 74 |
+
FALLBACK_ADAPTER_PATH_2,
|
| 75 |
+
])
|
| 76 |
+
|
| 77 |
+
if ADAPTER_PATH is None:
|
| 78 |
+
raise RuntimeError("❌ No valid LoRA adapter found!")
|
| 79 |
+
|
| 80 |
+
print("Loading adapter:", ADAPTER_PATH)
|
| 81 |
+
|
| 82 |
+
# ======================================================
|
| 83 |
+
# TOKENIZER
|
| 84 |
+
# ======================================================
|
| 85 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
|
| 86 |
+
if tokenizer.pad_token is None:
|
| 87 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 88 |
+
|
| 89 |
+
# ======================================================
|
| 90 |
+
# LOAD PPO MODEL
|
| 91 |
+
# ======================================================
|
| 92 |
+
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
|
| 93 |
+
BASE_MODEL,
|
| 94 |
+
torch_dtype=torch.float32
|
| 95 |
+
).to(device)
|
| 96 |
+
|
| 97 |
+
# 🚀 RESUME: Load adapter dynamically and ensure it's trainable
|
| 98 |
+
model.pretrained_model = PeftModel.from_pretrained(
|
| 99 |
+
model.pretrained_model,
|
| 100 |
+
ADAPTER_PATH,
|
| 101 |
+
is_trainable=True
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# ======================================================
|
| 105 |
+
# LOAD REFERENCE MODEL (FROZEN)
|
| 106 |
+
# ======================================================
|
| 107 |
+
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
|
| 108 |
+
BASE_MODEL,
|
| 109 |
+
torch_dtype=torch.float32
|
| 110 |
+
).to(device)
|
| 111 |
+
|
| 112 |
+
ref_model.pretrained_model = PeftModel.from_pretrained(
|
| 113 |
+
ref_model.pretrained_model,
|
| 114 |
+
ADAPTER_PATH,
|
| 115 |
+
is_trainable=False
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
ref_model.eval()
|
| 119 |
+
for p in ref_model.parameters():
|
| 120 |
+
p.requires_grad = False
|
| 121 |
+
|
| 122 |
+
# ======================================================
|
| 123 |
+
# TRAINABLE PARAMS — ONLY LoRA + VALUE HEAD
|
| 124 |
+
# ======================================================
|
| 125 |
+
for name, p in model.named_parameters():
|
| 126 |
+
if "lora_" in name or "v_head" in name:
|
| 127 |
+
p.requires_grad = True
|
| 128 |
+
else:
|
| 129 |
+
p.requires_grad = False
|
| 130 |
+
|
| 131 |
+
model.train()
|
| 132 |
+
|
| 133 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 134 |
+
total = sum(p.numel() for p in model.parameters())
|
| 135 |
+
print(f"Trainable params: {trainable}/{total} ({100*trainable/total:.2f}%)")
|
| 136 |
+
|
| 137 |
+
model.config.use_cache = False
|
| 138 |
+
ref_model.config.use_cache = False
|
| 139 |
+
|
| 140 |
+
# ======================================================
|
| 141 |
+
# DATASET
|
| 142 |
+
# ======================================================
|
| 143 |
+
print("Loading Spider subset...")
|
| 144 |
+
random.seed(0)
|
| 145 |
+
|
| 146 |
+
TRAIN_DBS = [
|
| 147 |
+
# already trained
|
| 148 |
+
"flight_1","student_assessment","store_1","bike_1","book_2","chinook_1",
|
| 149 |
+
"academic","aircraft","car_1","cinema","club_1","csu_1",
|
| 150 |
+
|
| 151 |
+
# medium difficulty (NEW)
|
| 152 |
+
"college_1","college_2","company_1","company_employee",
|
| 153 |
+
"customer_complaints","department_store","employee_hire_evaluation",
|
| 154 |
+
"museum_visit","products_for_hire","restaurant_1",
|
| 155 |
+
"school_finance","shop_membership","small_bank_1",
|
| 156 |
+
"soccer_1","student_1","tvshow","voter_1","world_1"
|
| 157 |
+
]
|
| 158 |
+
|
| 159 |
+
dataset = load_dataset("spider", split="train")
|
| 160 |
+
dataset = dataset.filter(lambda x: x["db_id"] in TRAIN_DBS)
|
| 161 |
+
|
| 162 |
+
def valid_example(x):
    """Keep only examples whose question has 5–40 whitespace tokens."""
    word_count = len(x["question"].split())
    return 5 <= word_count <= 40
|
| 164 |
+
|
| 165 |
+
dataset = dataset.filter(valid_example)
|
| 166 |
+
print("Filtered dataset size:", len(dataset))
|
| 167 |
+
|
| 168 |
+
def sample_example():
    """Draw one uniformly random row from the filtered Spider dataset."""
    idx = random.randrange(len(dataset))
    return dataset[idx]
|
| 170 |
+
|
| 171 |
+
# ======================================================
|
| 172 |
+
# DB UTILITIES
|
| 173 |
+
# ======================================================
|
| 174 |
+
def get_db_path(db_id):
    """Map a Spider ``db_id`` to its on-disk SQLite file under DB_ROOT."""
    filename = f"{db_id}.sqlite"
    return os.path.join(DB_ROOT, db_id, filename)
|
| 176 |
+
|
| 177 |
+
# 🚀 SPEED OPTIMIZATION: Cache schema so we don't spam disk IO
_SCHEMA_CACHE = {}

def get_db_schema_cached(db_path):
    """Return a compact ``table(col1, col2) `` schema string for a SQLite DB.

    Results are memoized per ``db_path``.  On any database/IO error an
    empty string is cached, so a missing or broken DB is only probed once
    (matching the original best-effort behavior).
    """
    if db_path in _SCHEMA_CACHE:
        return _SCHEMA_CACHE[db_path]

    parts = []
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
        for (table_name,) in tables:
            # Table names come from sqlite_master, so the f-string PRAGMA is
            # not user-controlled input.
            columns = cursor.execute(f"PRAGMA table_info({table_name});").fetchall()
            col_names = [col[1] for col in columns]
            parts.append(f"{table_name}({', '.join(col_names)}) ")
    except (sqlite3.Error, OSError):
        # Best-effort: an unreadable DB yields an empty schema string.
        # (Was a bare `except:` that also swallowed KeyboardInterrupt.)
        pass
    finally:
        # The original leaked the connection when an error fired mid-read;
        # always close it here.
        if conn is not None:
            conn.close()

    schema_text = "".join(parts)
    _SCHEMA_CACHE[db_path] = schema_text
    return schema_text
|
| 201 |
+
|
| 202 |
+
# ======================================================
|
| 203 |
+
# PROMPT
|
| 204 |
+
# ======================================================
|
| 205 |
+
def trim_schema(schema: str, max_chars: int = 1200) -> str:
    """Coerce *schema* to text and cap it at *max_chars* characters.

    ``None`` becomes the empty string; anything else is stringified and
    truncated from the right (never padded).
    """
    if schema is None:
        return ""
    text = str(schema)
    return text if len(text) <= max_chars else text[:max_chars]
|
| 212 |
+
|
| 213 |
+
def build_prompt(question: str, schema: str, use_schema: bool) -> str:
    """Assemble the model prompt, optionally prefixed with the DB schema.

    When *use_schema* is true, the schema is first trimmed to
    ``MAX_SCHEMA_CHARS`` characters before being prepended.
    """
    if use_schema:
        trimmed = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
        return f"### Database Schema:\n{trimmed}\n### Question:\n{question}\n### SQL:"
    return f"### Question:\n{question}\n### SQL:"
|
| 218 |
+
|
| 219 |
+
# ======================================================
|
| 220 |
+
# PPO CONFIG (STABLE POLICY LEARNING)
|
| 221 |
+
# ======================================================
|
| 222 |
+
ppo_config = PPOConfig(
|
| 223 |
+
learning_rate=5e-6,
|
| 224 |
+
batch_size=8,
|
| 225 |
+
mini_batch_size=2,
|
| 226 |
+
|
| 227 |
+
gradient_accumulation_steps=2,
|
| 228 |
+
ppo_epochs=1,
|
| 229 |
+
init_kl_coef=0.2,
|
| 230 |
+
target_kl=0.02,
|
| 231 |
+
adap_kl_ctrl=True,
|
| 232 |
+
cliprange=0.1,
|
| 233 |
+
cliprange_value=0.1,
|
| 234 |
+
whiten_rewards=False,
|
| 235 |
+
kl_penalty="kl",
|
| 236 |
+
max_grad_norm=0.5,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
trainer = PPOTrainer(
|
| 240 |
+
config=ppo_config,
|
| 241 |
+
model=model,
|
| 242 |
+
ref_model=ref_model,
|
| 243 |
+
tokenizer=tokenizer,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
try:
|
| 247 |
+
model.device = torch.device(device)
|
| 248 |
+
except Exception:
|
| 249 |
+
pass
|
| 250 |
+
|
| 251 |
+
# ======================================================
|
| 252 |
+
# GENERATION CONFIG
|
| 253 |
+
# ======================================================
|
| 254 |
+
# 🚀 SPEED OPTIMIZATION: generation limits and randomness bypass
|
| 255 |
+
generation_kwargs = dict(
|
| 256 |
+
max_new_tokens=MAX_OUTPUT_TOKENS,
|
| 257 |
+
do_sample=True, # TRL Requires do_sample=True
|
| 258 |
+
temperature=1.0, # Disabled randomness logic
|
| 259 |
+
top_p=1.0, # Disabled randomness logic
|
| 260 |
+
top_k=0, # Disabled randomness logic
|
| 261 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 262 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
# ======================================================
|
| 266 |
+
# TRAIN LOOP (BATCHED & OPTIMIZED)
|
| 267 |
+
# ======================================================
|
| 268 |
+
print("Starting RL training 🚀 (CodeT5 PPO Stable)")
|
| 269 |
+
|
| 270 |
+
best_reward = -1e9
|
| 271 |
+
global_ppo_step = 0
|
| 272 |
+
model.train()
|
| 273 |
+
|
| 274 |
+
for epoch in range(1, NUM_EPOCHS + 1):
|
| 275 |
+
epoch_reward_sum = 0
|
| 276 |
+
valid_sql_count = 0
|
| 277 |
+
total_seen = 0
|
| 278 |
+
|
| 279 |
+
# Process in exact chunks matching batch_size to avoid buffer remnants
|
| 280 |
+
for step in range(0, ROLLOUTS_PER_EPOCH, ppo_config.batch_size):
|
| 281 |
+
|
| 282 |
+
batch_prompts = []
|
| 283 |
+
batch_meta = [] # Store tuple of (question, gold_sql, db_path, db_id)
|
| 284 |
+
|
| 285 |
+
# 🚀 BATCH PREPARATION
|
| 286 |
+
for _ in range(ppo_config.batch_size):
|
| 287 |
+
example = sample_example()
|
| 288 |
+
question = example["question"]
|
| 289 |
+
gold_sql = example["query"]
|
| 290 |
+
db_id = example["db_id"]
|
| 291 |
+
db_path = get_db_path(db_id)
|
| 292 |
+
|
| 293 |
+
schema = get_db_schema_cached(db_path)
|
| 294 |
+
prompt = build_prompt(question, schema, use_schema=True)
|
| 295 |
+
|
| 296 |
+
batch_prompts.append(prompt)
|
| 297 |
+
batch_meta.append((question, gold_sql, db_path, db_id))
|
| 298 |
+
|
| 299 |
+
# 🚀 SPEED OPTIMIZATION: Padded Batch Tokenization (Multiple of 8)
|
| 300 |
+
encoded_inputs = tokenizer(
|
| 301 |
+
batch_prompts,
|
| 302 |
+
return_tensors="pt",
|
| 303 |
+
padding=True,
|
| 304 |
+
truncation=True,
|
| 305 |
+
max_length=512,
|
| 306 |
+
pad_to_multiple_of=8
|
| 307 |
+
).to(device)
|
| 308 |
+
|
| 309 |
+
# TRL expects lists of 1D tensors
|
| 310 |
+
query_tensors = [encoded_inputs.input_ids[i] for i in range(ppo_config.batch_size)]
|
| 311 |
+
|
| 312 |
+
# 🚀 SPEED OPTIMIZATION: Disable gradients for generation pass
|
| 313 |
+
with torch.no_grad():
|
| 314 |
+
response_tensors = trainer.generate(
|
| 315 |
+
query_tensors,
|
| 316 |
+
**generation_kwargs
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
batch_rewards = []
|
| 320 |
+
batch_responses_text = []
|
| 321 |
+
|
| 322 |
+
# 🚀 BATCH SQL REWARD EXECUTION (Strictly CPU strings)
|
| 323 |
+
for i in range(ppo_config.batch_size):
|
| 324 |
+
response = tokenizer.decode(response_tensors[i], skip_special_tokens=True)
|
| 325 |
+
batch_responses_text.append(response)
|
| 326 |
+
question, gold_sql, db_path, db_id = batch_meta[i]
|
| 327 |
+
|
| 328 |
+
total_seen += 1
|
| 329 |
+
|
| 330 |
+
# ---------- BASIC SQL FILTER ----------
|
| 331 |
+
if "select" not in response.lower():
|
| 332 |
+
batch_rewards.append(torch.tensor(-1.0, dtype=torch.float32).to(device))
|
| 333 |
+
continue
|
| 334 |
+
|
| 335 |
+
# ---------- EXECUTION REWARD ----------
|
| 336 |
+
reward = execution_reward(response, db_path, gold_sql)
|
| 337 |
+
if reward is None:
|
| 338 |
+
batch_rewards.append(torch.tensor(-1.0, dtype=torch.float32).to(device))
|
| 339 |
+
continue
|
| 340 |
+
|
| 341 |
+
reward = float(reward)
|
| 342 |
+
|
| 343 |
+
# ---------- TABLE BONUS ----------
|
| 344 |
+
pred_tables = extract_tables(response)
|
| 345 |
+
gold_tables = extract_tables(gold_sql)
|
| 346 |
+
if len(gold_tables) > 0:
|
| 347 |
+
reward += 0.25 * (len(pred_tables & gold_tables) / len(gold_tables))
|
| 348 |
+
|
| 349 |
+
# ---------- COLUMN BONUS ----------
|
| 350 |
+
pred_cols = extract_columns(response)
|
| 351 |
+
gold_cols = extract_columns(gold_sql)
|
| 352 |
+
if len(gold_cols) > 0:
|
| 353 |
+
reward += 0.15 * (len(pred_cols & gold_cols) / len(gold_cols))
|
| 354 |
+
|
| 355 |
+
# ---------- CLAMP ----------
|
| 356 |
+
reward = max(-1.0, min(1.0, reward))
|
| 357 |
+
batch_rewards.append(torch.tensor(reward, dtype=torch.float32).to(device))
|
| 358 |
+
|
| 359 |
+
epoch_reward_sum += reward
|
| 360 |
+
valid_sql_count += 1
|
| 361 |
+
|
| 362 |
+
# ---------- PPO UPDATE ----------
|
| 363 |
+
try:
|
| 364 |
+
trainer.step(
|
| 365 |
+
query_tensors,
|
| 366 |
+
response_tensors,
|
| 367 |
+
batch_rewards
|
| 368 |
+
)
|
| 369 |
+
global_ppo_step += 1
|
| 370 |
+
except Exception as e:
|
| 371 |
+
print("⚠️ PPO skipped:", e)
|
| 372 |
+
continue
|
| 373 |
+
|
| 374 |
+
# 🚀 AUTO CHECKPOINT SAVING: Every 200 PPO Updates
|
| 375 |
+
if global_ppo_step > 0 and global_ppo_step % 200 == 0:
|
| 376 |
+
step_save_path = os.path.join(PROJECT_ROOT, f"checkpoints/rl_step_{global_ppo_step}")
|
| 377 |
+
os.makedirs(step_save_path, exist_ok=True)
|
| 378 |
+
|
| 379 |
+
# Saves ONLY the adapter, keeping disk usage tiny!
|
| 380 |
+
model.save_pretrained(step_save_path)
|
| 381 |
+
tokenizer.save_pretrained(step_save_path)
|
| 382 |
+
print(f"\n💾 [AUTO-SAVE] Checkpoint saved at PPO step {global_ppo_step} -> {step_save_path}")
|
| 383 |
+
|
| 384 |
+
# ---------- LOG ----------
|
| 385 |
+
if step % (LOG_EVERY * ppo_config.batch_size) == 0 and valid_sql_count > 0:
|
| 386 |
+
print("\n---------------------------")
|
| 387 |
+
print(f"Epoch {epoch}/{NUM_EPOCHS} Step {step}/{ROLLOUTS_PER_EPOCH} | Global Update {global_ppo_step}")
|
| 388 |
+
print("Avg Reward:", round(epoch_reward_sum/valid_sql_count,3))
|
| 389 |
+
print("Valid SQL:", valid_sql_count,"/",total_seen)
|
| 390 |
+
|
| 391 |
+
# Print sample from latest batch
|
| 392 |
+
sample_idx = random.randint(0, ppo_config.batch_size - 1)
|
| 393 |
+
print("DB:", batch_meta[sample_idx][3])
|
| 394 |
+
print("Q:", batch_meta[sample_idx][0])
|
| 395 |
+
print("SQL:", batch_responses_text[sample_idx])
|
| 396 |
+
print("Reward:", round(batch_rewards[sample_idx].item(), 3))
|
| 397 |
+
|
| 398 |
+
# ---------- SAVE BEST MODEL (INSIDE EPOCH) ----------
|
| 399 |
+
avg_reward = epoch_reward_sum / max(valid_sql_count, 1)
|
| 400 |
+
|
| 401 |
+
if avg_reward > best_reward:
|
| 402 |
+
best_reward = avg_reward
|
| 403 |
+
save_path = os.path.join(PROJECT_ROOT, "checkpoints/best_rlhf_model")
|
| 404 |
+
os.makedirs(save_path, exist_ok=True)
|
| 405 |
+
|
| 406 |
+
model.save_pretrained(save_path)
|
| 407 |
+
tokenizer.save_pretrained(save_path)
|
| 408 |
+
|
| 409 |
+
print(f"\n✅ Saved BEST RLHF model for Epoch {epoch} (reward {best_reward:.3f})")
|
src/train_rl_lora.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ======================================
|
| 2 |
+
# RLHF Text2SQL — FINAL WORKING VERSION
|
| 3 |
+
# T5-small + LoRA + PPO + Execution Reward
|
| 4 |
+
# Single-sample stable training (Mac MPS safe)
|
| 5 |
+
# ======================================
|
| 6 |
+
|
| 7 |
+
from execution_reward import execution_reward
|
| 8 |
+
import os, gc, json, random, torch
|
| 9 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 10 |
+
from trl import PPOTrainer, PPOConfig
|
| 11 |
+
from trl.models.modeling_value_head import AutoModelForSeq2SeqLMWithValueHead
|
| 12 |
+
from peft import LoraConfig, get_peft_model
|
| 13 |
+
|
| 14 |
+
# ---------------- SETTINGS ----------------
|
| 15 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
| 16 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 17 |
+
|
| 18 |
+
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 19 |
+
print("Using device:", device)
|
| 20 |
+
|
| 21 |
+
os.makedirs("rlhf_text2sql_lora", exist_ok=True)
|
| 22 |
+
|
| 23 |
+
# ---------------- MODEL ----------------
|
| 24 |
+
model_name = "google/flan-t5-small"
|
| 25 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 26 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 27 |
+
|
| 28 |
+
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 29 |
+
|
| 30 |
+
# LoRA
|
| 31 |
+
lora_config = LoraConfig(
|
| 32 |
+
r=8,
|
| 33 |
+
lora_alpha=16,
|
| 34 |
+
target_modules=["q","v"],
|
| 35 |
+
lora_dropout=0.05,
|
| 36 |
+
bias="none",
|
| 37 |
+
task_type="SEQ_2_SEQ_LM",
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
base_model = get_peft_model(base_model, lora_config)
|
| 41 |
+
|
| 42 |
+
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(base_model).to(device)
|
| 43 |
+
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_name).to(device)
|
| 44 |
+
|
| 45 |
+
model.config.use_cache = False
|
| 46 |
+
ref_model.config.use_cache = False
|
| 47 |
+
|
| 48 |
+
# ---------------- DATA ----------------
|
| 49 |
+
with open("data/train_spider.json") as f:
|
| 50 |
+
dataset = json.load(f)
|
| 51 |
+
|
| 52 |
+
def build_prompt(example):
    """Format one Spider example as the model's translation prompt."""
    question = example["question"]
    return f"Translate to SQL: {question}"
|
| 54 |
+
|
| 55 |
+
# ---------------- PPO ----------------
|
| 56 |
+
ppo_config = PPOConfig(
|
| 57 |
+
batch_size=1,
|
| 58 |
+
mini_batch_size=1,
|
| 59 |
+
learning_rate=2e-6,
|
| 60 |
+
target_kl=0.05,
|
| 61 |
+
adap_kl_ctrl=True,
|
| 62 |
+
init_kl_coef=0.2,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
ppo_trainer = PPOTrainer(
|
| 66 |
+
config=ppo_config,
|
| 67 |
+
model=model,
|
| 68 |
+
ref_model=ref_model,
|
| 69 |
+
tokenizer=tokenizer,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# ---------------- GENERATION ----------------
|
| 73 |
+
def generate_sql(query_tensors):
    """Greedy-decode SQL for a batch of query tensors.

    Sampling is disabled so decoding stays deterministic (avoids the NaN
    blow-ups seen with stochastic decoding), and every returned tensor is
    scrubbed of NaN/inf values as an extra safety net on MPS.
    """
    with torch.no_grad():
        raw_responses = ppo_trainer.generate(
            query_tensors,
            max_new_tokens=64,
            do_sample=False,        # 🔴 CRITICAL: disable sampling
            num_beams=1,            # stable decoding
            early_stopping=True,
            pad_token_id=tokenizer.eos_token_id,  # prevents invalid tokens
        )

    # extra safety (important on MPS): replace NaN/inf token values
    return [
        torch.nan_to_num(response, nan=0, posinf=0, neginf=0)
        for response in raw_responses
    ]
|
| 99 |
+
|
| 100 |
+
# ---------------- TRAIN ----------------
|
| 101 |
+
MAX_STEPS = 1200
|
| 102 |
+
|
| 103 |
+
for step in range(MAX_STEPS):
|
| 104 |
+
|
| 105 |
+
# pick random Spider example
|
| 106 |
+
example = random.choice(dataset)
|
| 107 |
+
|
| 108 |
+
question = example["question"]
|
| 109 |
+
gold_sql = example["query"]
|
| 110 |
+
db_id = example["db_id"]
|
| 111 |
+
db_path = f"data/database/{db_id}/{db_id}.sqlite"
|
| 112 |
+
|
| 113 |
+
# tokenize
|
| 114 |
+
enc = tokenizer(build_prompt(example), return_tensors="pt")
|
| 115 |
+
query_tensor = enc.input_ids.to(device)
|
| 116 |
+
query_tensors = [query_tensor[0]]
|
| 117 |
+
|
| 118 |
+
# generate SQL
|
| 119 |
+
response_tensors = generate_sql(query_tensors)
|
| 120 |
+
pred_sql = tokenizer.decode(response_tensors[0], skip_special_tokens=True)
|
| 121 |
+
|
| 122 |
+
# -------- EXECUTION REWARD --------
|
| 123 |
+
reward = execution_reward(pred_sql, gold_sql, db_path)
|
| 124 |
+
reward_tensor = torch.tensor([reward], dtype=torch.float32).to(device)
|
| 125 |
+
|
| 126 |
+
# PPO update
|
| 127 |
+
stats = ppo_trainer.step(query_tensors, response_tensors, [reward_tensor])
|
| 128 |
+
|
| 129 |
+
# stabilize
|
| 130 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 131 |
+
|
| 132 |
+
# cleanup
|
| 133 |
+
del query_tensor, response_tensors, reward_tensor
|
| 134 |
+
gc.collect()
|
| 135 |
+
if device == "mps":
|
| 136 |
+
torch.mps.empty_cache()
|
| 137 |
+
|
| 138 |
+
# log
|
| 139 |
+
if step % 20 == 0:
|
| 140 |
+
print(f"\nStep {step}/{MAX_STEPS}")
|
| 141 |
+
print("DB:", db_id)
|
| 142 |
+
print("Q:", question)
|
| 143 |
+
print("Pred:", pred_sql)
|
| 144 |
+
print("Gold:", gold_sql)
|
| 145 |
+
print("Reward:", reward)
|
| 146 |
+
|
| 147 |
+
# ---------------- SAVE ----------------
|
| 148 |
+
model.save_pretrained("rlhf_text2sql_lora")
|
| 149 |
+
tokenizer.save_pretrained("rlhf_text2sql_lora")
|
| 150 |
+
|
| 151 |
+
print("\nTraining complete — model saved!")
|
src/train_sft.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
from peft import LoraConfig, get_peft_model
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoModelForSeq2SeqLM,
|
| 9 |
+
AutoTokenizer,
|
| 10 |
+
DataCollatorForSeq2Seq,
|
| 11 |
+
Seq2SeqTrainer,
|
| 12 |
+
Seq2SeqTrainingArguments,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from prompting import clean_gold_sql, get_schema_text, build_prompt
|
| 16 |
+
|
| 17 |
+
# =====================================================
|
| 18 |
+
# SETTINGS
|
| 19 |
+
# =====================================================
|
| 20 |
+
BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
|
| 21 |
+
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 22 |
+
|
| 23 |
+
# 🎯 FIXED: Save final model to checkpoints/sft_t5 to protect existing models
|
| 24 |
+
OUT_DIR = os.path.join(PROJECT_ROOT, "checkpoints", "sft_t5")
|
| 25 |
+
|
| 26 |
+
TRAIN_SPLIT = "train[:7000]"
|
| 27 |
+
EPOCHS = 8
|
| 28 |
+
LR = 3e-4
|
| 29 |
+
PER_DEVICE_BATCH = 4
|
| 30 |
+
GRAD_ACCUM = 2
|
| 31 |
+
|
| 32 |
+
MAX_INPUT = 512
|
| 33 |
+
MAX_OUTPUT = 128
|
| 34 |
+
|
| 35 |
+
# =====================================================
|
| 36 |
+
# DEVICE
|
| 37 |
+
# =====================================================
|
| 38 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 39 |
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 40 |
+
print("Using device:", device)
|
| 41 |
+
|
| 42 |
+
# =====================================================
|
| 43 |
+
# TOKENIZER
|
| 44 |
+
# =====================================================
|
| 45 |
+
print("Loading tokenizer/model:", BASE_MODEL)
|
| 46 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
| 47 |
+
|
| 48 |
+
if tokenizer.pad_token_id is None:
|
| 49 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 50 |
+
|
| 51 |
+
# =====================================================
|
| 52 |
+
# PREPROCESS FUNCTION (CRITICAL FIXED VERSION)
|
| 53 |
+
# =====================================================
|
| 54 |
+
def preprocess_function(example):
    """Tokenize one Spider example into (prompt, SQL-label) features.

    Both sides are padded to fixed length; label padding is masked with
    -100 so the loss ignores those positions.
    """
    db_id = example["db_id"]
    target_sql = clean_gold_sql(example["query"])

    # Schema-aware prompt for this question/database pair.
    prompt = build_prompt(
        example["question"],
        db_id,
        schema_text=get_schema_text(db_id),
        training_sql=None,
    )

    model_inputs = tokenizer(
        prompt,
        max_length=MAX_INPUT,
        truncation=True,
        padding="max_length",
    )

    target_ids = tokenizer(
        target_sql,
        max_length=MAX_OUTPUT,
        truncation=True,
        padding="max_length",
    )["input_ids"]

    # Replace pad tokens with -100 (the loss's ignore_index).
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [-100 if t == pad_id else t for t in target_ids]
    return model_inputs
|
| 87 |
+
|
| 88 |
+
# =====================================================
# DATASET
# =====================================================
print("Loading Spider subset:", TRAIN_SPLIT)
dataset = load_dataset("spider", split=TRAIN_SPLIT)
# 90/10 train/eval split, fixed seed for reproducibility.
dataset = dataset.train_test_split(test_size=0.1, seed=42)

train_ds = dataset["train"]
eval_ds = dataset["test"]

print("Tokenizing dataset (single process, stable)...")

train_tok = train_ds.map(
    preprocess_function,
    batched=False,
    num_proc=1,  # 🔥 VERY IMPORTANT FIX: multiprocess map was unstable here
    remove_columns=train_ds.column_names,
    load_from_cache_file=False,
)

eval_tok = eval_ds.map(
    preprocess_function,
    batched=False,
    num_proc=1,
    remove_columns=eval_ds.column_names,
    load_from_cache_file=False,
)

print("Train dataset size:", len(train_tok))
print("Eval dataset size:", len(eval_tok))

# =====================================================
# MODEL + LoRA
# =====================================================
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)

# use_cache conflicts with gradient checkpointing during training.
base_model.config.use_cache = False
base_model.gradient_checkpointing_enable()

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
    target_modules=["q", "v"],  # correct for T5 attention projections
)

model = get_peft_model(base_model, lora_config)
model.to(device)

# =====================================================
# TRAINER
# =====================================================
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
)

args = Seq2SeqTrainingArguments(
    # 🎯 FIXED: Changed path to prevent mixing logs with your old CodeT5 logs
    output_dir=os.path.join(PROJECT_ROOT, "checkpoints", "sft_t5_runs"),
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    per_device_train_batch_size=PER_DEVICE_BATCH,
    per_device_eval_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRAD_ACCUM,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    evaluation_strategy="epoch",

    # 🎯 FIXED: "no" completely stops intermediate saving! Only the final model will be saved.
    save_strategy="no",

    logging_steps=50,
    report_to=[],
    fp16=False,  # keep full fp32
    bf16=False,
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# =====================================================
# TRAIN
# =====================================================
trainer.train()

# =====================================================
# SAVE
# =====================================================
print("Saving LoRA adapter to:", OUT_DIR)
os.makedirs(OUT_DIR, exist_ok=True)
model.save_pretrained(OUT_DIR)  # saves the LoRA adapter weights only
tokenizer.save_pretrained(OUT_DIR)

print("DONE ✔ SFT warmup finished")
|
src/train_sft_bart.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
from peft import LoraConfig, get_peft_model
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoModelForSeq2SeqLM,
|
| 9 |
+
AutoTokenizer,
|
| 10 |
+
DataCollatorForSeq2Seq,
|
| 11 |
+
Seq2SeqTrainer,
|
| 12 |
+
Seq2SeqTrainingArguments,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from prompting import clean_gold_sql, get_schema_text, build_prompt
|
| 16 |
+
|
| 17 |
+
# =====================================================
# SETTINGS
# =====================================================
# Base model is overridable via the BASE_MODEL env var.
BASE_MODEL = os.environ.get("BASE_MODEL", "facebook/bart-base")
# Repository root: parent of the directory containing this script (src/).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

OUT_DIR = os.path.join(PROJECT_ROOT, "checkpoints", "sft_best_bart_2")

TRAIN_SPLIT = "train[:7000]"  # first 7000 Spider training examples

EPOCHS = 12
LR = 3e-4
PER_DEVICE_BATCH = 16
GRAD_ACCUM = 4  # effective batch size = PER_DEVICE_BATCH * GRAD_ACCUM

# Token-length limits for the prompt and the target SQL.
MAX_INPUT = 512
MAX_OUTPUT = 128

# =====================================================
# DEVICE
# =====================================================
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence tokenizers fork warning
# Prefer MPS (Apple silicon), then CUDA, then CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu"))
print("Using device:", device)

# =====================================================
# TOKENIZER
# =====================================================
print("Loading tokenizer/model:", BASE_MODEL)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Some checkpoints ship without a pad token; reuse EOS so padding works.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
|
| 50 |
+
|
| 51 |
+
# =====================================================
|
| 52 |
+
# PREPROCESS FUNCTION
|
| 53 |
+
# =====================================================
|
| 54 |
+
def preprocess_function(example):
    """Tokenize one Spider example into (prompt, SQL-label) features for BART SFT.

    Builds a schema-aware prompt from the question and database id, tokenizes
    prompt and gold SQL to fixed lengths, and masks label padding with -100 —
    the ignore_index used by Hugging Face's cross-entropy loss — so padded
    positions do not contribute to the loss.
    """
    question = example["question"]
    db_id = example["db_id"]
    gold_sql = clean_gold_sql(example["query"])

    # ---- Build Prompt ----
    schema_text = get_schema_text(db_id)
    prompt = build_prompt(question, db_id, schema_text=schema_text, training_sql=None)

    model_inputs = tokenizer(
        prompt,
        max_length=MAX_INPUT,
        truncation=True,
        padding="max_length",
    )

    # ---- Target SQL ----
    labels = tokenizer(
        gold_sql,
        max_length=MAX_OUTPUT,
        truncation=True,
        padding="max_length",
    )["input_ids"]

    # BUG FIX: was -400. The loss only ignores -100 (ignore_index), so with
    # -400 the padding positions were not masked and an out-of-range id could
    # crash or corrupt the loss. -100 also matches the sibling T5/CodeT5 scripts.
    labels = [
        (tok if tok != tokenizer.pad_token_id else -100)
        for tok in labels
    ]

    model_inputs["labels"] = labels
    return model_inputs
|
| 87 |
+
|
| 88 |
+
# =====================================================
# DATASET
# =====================================================
print("Loading Spider subset:", TRAIN_SPLIT)
dataset = load_dataset("spider", split=TRAIN_SPLIT)
# 90/10 train/eval split, fixed seed for reproducibility.
dataset = dataset.train_test_split(test_size=0.1, seed=42)

train_ds = dataset["train"]
eval_ds = dataset["test"]

print("Tokenizing dataset (single process, stable)...")

train_tok = train_ds.map(
    preprocess_function,
    batched=False,
    num_proc=1,  # single-process map for stability
    remove_columns=train_ds.column_names,
    load_from_cache_file=False,
)

eval_tok = eval_ds.map(
    preprocess_function,
    batched=False,
    num_proc=1,
    remove_columns=eval_ds.column_names,
    load_from_cache_file=False,
)

print("Train dataset size:", len(train_tok))
print("Eval dataset size:", len(eval_tok))

# =====================================================
# MODEL + LoRA
# =====================================================
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)

# Disable cache during training (only useful for generation).
base_model.config.use_cache = False

# 🚀 UPGRADE 1: Expanded LoRA brainpower
lora_config = LoraConfig(
    r=16,  # Increased rank for more learning capacity
    lora_alpha=32,  # Alpha is typically 2x the rank
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
    # Target all attention and dense layers in BART
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
)

model = get_peft_model(base_model, lora_config)
model.to(device)

# =====================================================
# TRAINER
# =====================================================
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
)

args = Seq2SeqTrainingArguments(
    output_dir=os.path.join(PROJECT_ROOT, "checkpoints", "sft_bart_runs"),
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    per_device_train_batch_size=PER_DEVICE_BATCH,
    per_device_eval_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRAD_ACCUM,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,

    # 🚀 UPGRADE 2 & 3: Better optimization & generalization
    warmup_ratio=0.05,  # Slowly ramp up learning rate
    weight_decay=0.01,  # Penalize over-reliance on single tokens
    label_smoothing_factor=0.1,  # Prevent overconfidence in SQL token matching

    evaluation_strategy="epoch",
    save_strategy="epoch",

    # Keep only the single best checkpoint, selected by lowest eval loss.
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    logging_steps=50,
    report_to=[],
    fp16=False,
    bf16=False,
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# =====================================================
# TRAIN
# =====================================================
trainer.train()

# =====================================================
# SAVE BEST MODEL
# =====================================================
print("Saving best BART LoRA adapter to:", OUT_DIR)
os.makedirs(OUT_DIR, exist_ok=True)

# trainer.model is the best checkpoint because load_best_model_at_end=True.
trainer.model.save_pretrained(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)

print("DONE ✔ SFT BART finished")
|
src/train_sft_codet5.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
from peft import LoraConfig, get_peft_model
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoModelForSeq2SeqLM,
|
| 9 |
+
AutoTokenizer,
|
| 10 |
+
DataCollatorForSeq2Seq,
|
| 11 |
+
Seq2SeqTrainer,
|
| 12 |
+
Seq2SeqTrainingArguments,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from prompting import clean_gold_sql, get_schema_text, build_prompt
|
| 16 |
+
|
| 17 |
+
# =====================================================
# SETTINGS
# =====================================================
BASE_MODEL = "Salesforce/codet5-base"
# Repository root: parent of the directory containing this script (src/).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
OUT_DIR = os.path.join(PROJECT_ROOT, "checkpoints", "sft_adapter_codet5")

TRAIN_SPLIT = "train[:7000]"  # first 7000 Spider training examples
EPOCHS = 10
LR = 2e-4
PER_DEVICE_BATCH = 2  # codet5 bigger -> reduce
GRAD_ACCUM = 4  # effective batch size = PER_DEVICE_BATCH * GRAD_ACCUM

# Token-length limits for the prompt and the target SQL.
MAX_INPUT = 512
MAX_OUTPUT = 160

# =====================================================
# DEVICE
# =====================================================
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence tokenizers fork warning
# Prefer Apple-silicon MPS when available, otherwise CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

# =====================================================
# TOKENIZER
# =====================================================
print("Loading tokenizer/model:", BASE_MODEL)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Some checkpoints ship without a pad token; reuse EOS so padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
|
| 48 |
+
|
| 49 |
+
# =====================================================
|
| 50 |
+
# PREPROCESS FUNCTION
|
| 51 |
+
# =====================================================
|
| 52 |
+
def preprocess_function(example):
    """Turn one Spider example into tokenized (prompt, SQL) training features.

    The prompt embeds the database schema; label padding is masked with
    -100 so the loss skips those positions.
    """
    db_id = example["db_id"]
    target_sql = clean_gold_sql(example["query"])

    # Schema-aware prompt for this question/database pair.
    prompt = build_prompt(
        example["question"],
        db_id,
        schema_text=get_schema_text(db_id),
        training_sql=None,
    )

    features = tokenizer(
        prompt,
        max_length=MAX_INPUT,
        truncation=True,
        padding="max_length",
    )

    label_ids = tokenizer(
        target_sql,
        max_length=MAX_OUTPUT,
        truncation=True,
        padding="max_length",
    )["input_ids"]

    # Mask pad positions so cross-entropy ignores them.
    pad_id = tokenizer.pad_token_id
    features["labels"] = [-100 if t == pad_id else t for t in label_ids]
    return features
|
| 79 |
+
|
| 80 |
+
# =====================================================
# DATASET
# =====================================================
print("Loading Spider subset:", TRAIN_SPLIT)
dataset = load_dataset("spider", split=TRAIN_SPLIT)
# 90/10 train/eval split, fixed seed for reproducibility.
dataset = dataset.train_test_split(test_size=0.1, seed=42)

train_ds = dataset["train"]
eval_ds = dataset["test"]

print("Tokenizing dataset...")

train_tok = train_ds.map(
    preprocess_function,
    batched=False,
    num_proc=1,  # single-process map for stability
    remove_columns=train_ds.column_names,
    load_from_cache_file=False,
)

eval_tok = eval_ds.map(
    preprocess_function,
    batched=False,
    num_proc=1,
    remove_columns=eval_ds.column_names,
    load_from_cache_file=False,
)

print("Train dataset size:", len(train_tok))
print("Eval dataset size:", len(eval_tok))

# =====================================================
# MODEL + LoRA (CODET5 FIXED)
# =====================================================
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)

# use_cache conflicts with gradient checkpointing during training.
base_model.config.use_cache = False
base_model.gradient_checkpointing_enable()

# 🔥 DIFFERENT FROM T5
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
    target_modules=["q", "v"],  # IMPORTANT FOR CODET5
)

model = get_peft_model(base_model, lora_config)
model.to(device)

model.print_trainable_parameters()

# =====================================================
# TRAINER
# =====================================================
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
)

args = Seq2SeqTrainingArguments(
    output_dir=os.path.join(PROJECT_ROOT, "checkpoints", "sft_runs_codet5"),
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    per_device_train_batch_size=PER_DEVICE_BATCH,
    per_device_eval_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRAD_ACCUM,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,  # keep only the most recent checkpoint
    logging_steps=50,
    report_to=[],
    fp16=False,
    bf16=False,
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# =====================================================
# TRAIN
# =====================================================
trainer.train()

# =====================================================
# SAVE (SAFE PEFT SAVE)
# =====================================================
print("Saving LoRA adapter to:", OUT_DIR)
os.makedirs(OUT_DIR, exist_ok=True)

# unwrap trainer model (important!)
peft_model = trainer.model

# ensure on cpu before saving (mac mps bug fix)
peft_model = peft_model.to("cpu")

# save adapter only
peft_model.save_pretrained(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)

print("DONE ✔ CodeT5 SFT finished")
|