tjhalanigrid commited on
Commit
dc59b01
·
1 Parent(s): 16941e7

Add src folder

Browse files
src/__pycache__/schema_encoder.cpython-310.pyc ADDED
Binary file (1.76 kB). View file
 
src/__pycache__/sql_validator.cpython-310.pyc ADDED
Binary file (3.25 kB). View file
 
src/__pycache__/text2sql_engine.cpython-310.pyc ADDED
Binary file (8.38 kB). View file
 
src/ask.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TERMINAL CHAT WITH DATABASE
3
+ Run:
4
+ python src/ask.py chinook_1
5
+ """
6
+
7
+ import sys
8
+ from text2sql_engine import get_engine
9
+
10
+
11
+ # -------------------------------
12
+ # Pretty table printer
13
+ # -------------------------------
14
def print_table(cols, rows, limit=20):
    """Pretty-print query results as an aligned ASCII table on stdout.

    Shows at most `limit` rows; prints a trailing note when rows were
    truncated. Empty input prints "No results".
    """
    if not rows or not cols:
        print("No results\n")
        return

    cols = [str(c) for c in cols]

    # Column widths: at least 12 chars, widened to fit header + shown cells.
    widths = [max(len(name), 12) for name in cols]
    for row in rows[:limit]:
        for idx, cell in enumerate(row):
            widths[idx] = max(widths[idx], len(str(cell)))

    header = " | ".join(name.ljust(w) for name, w in zip(cols, widths))
    print("\n" + header)
    print("-" * len(header))

    for row in rows[:limit]:
        print(" | ".join(str(cell).ljust(w) for cell, w in zip(row, widths)))

    if len(rows) > limit:
        print(f"\n... showing first {limit} rows of {len(rows)}")

    print()
38
+
39
+
40
+ # -------------------------------
41
+ # Main loop
42
+ # -------------------------------
43
def main():
    """Interactive terminal REPL: read a natural-language question, have the
    engine translate it to SQL, run it against the chosen database, and
    pretty-print the result table.

    Usage: python src/ask.py <db_id>
    """
    if len(sys.argv) < 2:
        print("Usage: python src/ask.py <db_id>")
        return

    db_id = sys.argv[1].strip()

    print("Loading model... (first time takes 20-40s)")
    engine = get_engine()

    print(f"\nConnected to database: {db_id}")
    print("Type 'exit' to quit\n")

    while True:
        try:
            q = input("Ask> ").strip()

            # Ignore empty input; re-prompt.
            if not q:
                continue

            if q.lower() in ["exit", "quit"]:
                break

            # engine.ask is expected to return a dict with "sql" and either
            # "error" or "columns"/"rows" — TODO confirm against text2sql_engine.
            result = engine.ask(q, db_id)

            if result is None:
                print("Model returned no output\n")
                continue

            print("\nGenerated SQL:")
            print(result.get("sql", "<no sql>"))

            if result.get("error"):
                print("\nSQL Error:")
                print(result["error"])
            else:
                print_table(
                    result.get("columns", []),
                    result.get("rows", []),
                )

        except KeyboardInterrupt:
            # Ctrl-C exits the loop cleanly instead of crashing the REPL.
            break
        except Exception as e:
            # Keep the REPL alive on any engine/DB error.
            print("\nRuntime error:", e, "\n")

    print("\nBye!")
90
+
91
+
92
+ if __name__ == "__main__":
93
+ main()
src/component_analysis.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sqlite3
3
+ import torch
4
+ import re
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from pathlib import Path
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+ from peft import PeftModel
10
+
11
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
12
+ DB_ROOT = PROJECT_ROOT / "data" / "database"
13
+
14
+ # -------------------------------
15
+ # Extract SQL components
16
+ # -------------------------------
17
def extract_components(sql):
    """Flag which SQL clause types appear in the query (case-insensitive).

    Returns a dict with boolean values for the keys:
    select, where, group, order, and_or, join.
    """
    lowered = sql.lower()
    markers = {
        "select": "select",
        "where": "where",
        "group": "group by",
        "order": "order by",
        "join": "join",
    }
    flags = {name: token in lowered for name, token in markers.items()}
    # Space-delimited so e.g. the "or" inside "order" is not counted.
    flags["and_or"] = (" and " in lowered) or (" or " in lowered)
    return flags
27
+
28
+ # -------------------------------
29
+ # Fallback Difficulty Estimator
30
+ # -------------------------------
31
def estimate_difficulty(sql):
    """Heuristic Spider-style difficulty bucket for a SQL query.

    Fallback used when the dataset example has no 'difficulty' field.
    Returns one of: "easy", "medium", "hard", "extra".
    """
    sql = sql.lower()
    # Count whole words only. The previous substring count treated the
    # "or" inside "order by" (and "and" inside identifiers like "brand")
    # as WHERE conditions, inflating e.g. GROUP BY + ORDER BY queries
    # into the "hard" bucket.
    joins = len(re.findall(r"\bjoin\b", sql))
    conditions = len(re.findall(r"\b(?:and|or)\b", sql))

    if "intersect" in sql or "except" in sql or "union" in sql or joins > 2:
        # Set operations or many joins -> hardest bucket.
        return "extra"
    elif joins == 2 or ("group by" in sql and conditions > 0):
        return "hard"
    elif joins == 1 or "group by" in sql or "order by" in sql:
        return "medium"
    else:
        return "easy"
45
+
46
+ # -------------------------------
47
+ # Load schema
48
+ # -------------------------------
49
def load_schema(db_path):
    """Return a compact schema summary of the SQLite DB at `db_path`:
    one "table(col1, col2, ...)" line per table."""
    conn = sqlite3.connect(db_path)
    # Spider databases contain some non-UTF8 bytes; decode leniently.
    conn.text_factory = lambda b: b.decode(errors='ignore')
    cursor = conn.cursor()

    table_rows = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()

    lines = []
    for (name,) in table_rows:
        info = cursor.execute(f"PRAGMA table_info({name});").fetchall()
        lines.append(f"{name}({', '.join(col[1] for col in info)})\n")

    conn.close()
    return "".join(lines)
66
+
67
+ # -------------------------------
68
+ # Prompt
69
+ # -------------------------------
70
def build_prompt(question, schema):
    """Assemble the schema-aware text-to-SQL prompt used by this script."""
    sections = [
        "Database Schema:",
        schema,
        "",
        "Translate English to SQL:",
        question,
        "SQL:",
        "",  # trailing newline after "SQL:"
    ]
    return "\n".join(sections)
78
+
79
+ # -------------------------------
80
+ # Main
81
+ # -------------------------------
82
def main():
    """Grouped evaluation: measure SQL-component match accuracy (SELECT,
    WHERE, GROUP BY, ORDER BY, AND/OR, JOIN) per difficulty bucket, plot a
    grouped bar chart, and print overall exact-string accuracy by difficulty.
    """
    # Hard-coded checkpoint under evaluation; edit to point elsewhere.
    adapter = "checkpoints/rl_step_1800"
    base_model = "Salesforce/codet5-base"

    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading tokenizer and models...")
    tokenizer = AutoTokenizer.from_pretrained(adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, adapter).to(device)
    # Fold the LoRA weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()

    dev_json = PROJECT_ROOT / "data" / "dev.json"

    with open(dev_json) as f:
        dev = json.load(f)[:1000]  # Adjust number to test more/less

    components_list = ["select", "where", "group", "order", "and_or", "join"]
    difficulties_list = ["easy", "medium", "hard", "extra"]

    # Nested dictionary: per component, per difficulty -> correct/total counts.
    stats = {
        comp: {diff: {"correct": 0, "total": 0} for diff in difficulties_list}
        for comp in components_list
    }

    # Trackers for OVERALL (exact string match) accuracy by difficulty.
    overall_correct = {diff: 0 for diff in difficulties_list}
    overall_total = {diff: 0 for diff in difficulties_list}

    print(f"\nRunning grouped evaluation on {len(dev)} examples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        gold_sql = ex["query"]
        db_id = ex["db_id"]

        # Difficulty: use the dataset label when present, else the heuristic.
        difficulty = ex.get("difficulty", estimate_difficulty(gold_sql))
        if difficulty not in difficulties_list:
            difficulty = "medium"

        db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)
        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1000,
                num_beams=4,
                do_sample=False
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the last "SQL:" marker, if echoed back.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1]

        # --- 1. Update overall accuracy trackers ---
        overall_total[difficulty] += 1
        # Simple case-insensitive string match for quick overall accuracy.
        if pred_sql.strip().lower() == gold_sql.strip().lower():
            overall_correct[difficulty] += 1

        # --- 2. Update per-component stats ---
        pred_comp = extract_components(pred_sql)
        gold_comp = extract_components(gold_sql)

        for comp in components_list:
            if gold_comp[comp]:  # the gold SQL required this component
                stats[comp][difficulty]["total"] += 1
                if pred_comp[comp]:  # the model generated it too
                    stats[comp][difficulty]["correct"] += 1

        if i % 20 == 0:
            print(f"Processed {i}/{len(dev)}")

    # -------------------------------
    # Plotting (grouped bar chart)
    # -------------------------------
    x = np.arange(len(components_list))
    width = 0.2

    def get_acc(diff):
        # Per-component accuracy (%) for one difficulty; 0 when no samples.
        return [
            (stats[comp][diff]["correct"] / stats[comp][diff]["total"] * 100) if stats[comp][diff]["total"] > 0 else 0
            for comp in components_list
        ]

    acc_easy = get_acc("easy")
    acc_medium = get_acc("medium")
    acc_hard = get_acc("hard")
    acc_extra = get_acc("extra")

    fig, ax = plt.subplots(figsize=(14, 7))

    bars1 = ax.bar(x - 1.5 * width, acc_easy, width, label='Easy', color='#2ecc71')
    bars2 = ax.bar(x - 0.5 * width, acc_medium, width, label='Medium', color='#f1c40f')
    bars3 = ax.bar(x + 0.5 * width, acc_hard, width, label='Hard', color='#e67e22')
    bars4 = ax.bar(x + 1.5 * width, acc_extra, width, label='Extra', color='#e74c3c')

    ax.set_ylabel('Accuracy (%)', fontsize=12)
    ax.set_title('SQL Component Match Accuracy by Difficulty Level', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels([c.upper() for c in components_list], fontsize=11)
    ax.legend(title="Query Difficulty")
    ax.set_ylim(0, 115)  # headroom above 100% for the rotated labels

    def autolabel(rects):
        # Annotate each bar with its rounded percentage, rotated vertically.
        for rect in rects:
            height = rect.get_height()
            if height > 0:
                ax.annotate(f'{int(height)}%',
                            xy=(rect.get_x() + rect.get_width() / 2, height),
                            xytext=(0, 3),
                            textcoords="offset points",
                            ha='center', va='bottom', fontsize=8, rotation=90)

    autolabel(bars1)
    autolabel(bars2)
    autolabel(bars3)
    autolabel(bars4)

    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.savefig("component_by_difficulty_plot.png", dpi=300)

    # -------------------------------
    # Terminal printout
    # -------------------------------
    print("\n✅ Saved merged plot -> component_by_difficulty_plot.png")

    print("\n========================================")
    print("🏆 OVERALL AVERAGE ACCURACY BY DIFFICULTY")
    print("========================================")
    for diff in difficulties_list:
        if overall_total[diff] > 0:
            avg = round((overall_correct[diff] / overall_total[diff]) * 100, 2)
            print(f"{diff.capitalize():<8}: {avg:>5}% ({overall_correct[diff]}/{overall_total[diff]} queries)")
        else:
            print(f"{diff.capitalize():<8}: N/A (0 queries)")
    print("========================================\n")
227
+
228
+ if __name__ == "__main__":
229
+ main()
src/convert_to_hf_dataset.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
"""Convert the processed training CSV into an on-disk HuggingFace dataset.

Paths are resolved relative to this file (same PROJECT_ROOT pattern as the
other src/ scripts), so the script works from any working directory —
previously the "../data/..." paths only worked when launched from src/.
"""
from pathlib import Path

from datasets import Dataset
import pandas as pd

PROJECT_ROOT = Path(__file__).resolve().parents[1]
PROCESSED_DIR = PROJECT_ROOT / "data" / "processed"

df = pd.read_csv(PROCESSED_DIR / "train.csv")
ds = Dataset.from_pandas(df)
ds.save_to_disk(str(PROCESSED_DIR / "train"))
print("DONE")
8
+
src/eval_baseline_codet5.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sqlite3
3
+ from pathlib import Path
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+
7
+ # ---------------- PROMPT (same style as training) ----------------
8
def build_prompt(question, schema):
    """Build the evaluation prompt (same layout the model saw in training)."""
    sections = (
        "translate English to SQL:",
        "",
        "Schema:",
        schema,
        "",
        "Question:",
        question,
        "",
        "SQL:",
    )
    return "\n".join(sections)
18
+
19
+ # ---------------- LOAD SCHEMA ----------------
20
def load_schema(db_path):
    """Summarize every table in the SQLite DB as "name(col1, col2, ...)" lines."""
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    names = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()

    pieces = []
    for (name,) in names:
        info = cursor.execute(f"PRAGMA table_info({name});").fetchall()
        pieces.append(f"{name}({', '.join(row[1] for row in info)})\n")

    conn.close()
    return "".join(pieces)
36
+
37
+ # ---------------- EXECUTION MATCH ----------------
38
def execution_match(pred_sql, gold_sql, db_path):
    """Return True iff both queries execute and return identical row lists.

    Any execution error (bad SQL, missing table, ...) counts as a miss.
    The connection is now closed on every path — the original only closed
    it on success, leaking a connection whenever a query raised.
    """
    try:
        conn = sqlite3.connect(db_path)
    except Exception:
        return False
    try:
        cur = conn.cursor()

        cur.execute(pred_sql)
        pred = cur.fetchall()

        cur.execute(gold_sql)
        gold = cur.fetchall()

        return pred == gold
    except Exception:
        return False
    finally:
        conn.close()
54
+
55
+ # ---------------- MAIN ----------------
56
def main():
    """Evaluate the un-finetuned CodeT5 base model on Spider dev examples.

    Reports execution-match accuracy: a prediction counts as correct when
    executing it returns exactly the same rows as the gold query.
    """
    project_root = Path(__file__).resolve().parents[1]

    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    device = "mps" if torch.backends.mps.is_available() else "cpu"

    print("Loading BASE CodeT5...")
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base").to(device)
    model.eval()

    with open(dev_json) as f:
        dev = json.load(f)[:100]

    correct = 0

    print(f"\nEvaluating {len(dev)} samples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]

        db_path = db_root / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)

        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                num_beams=4,
                do_sample=False
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the text after the last "SQL:" marker, if echoed back.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        if execution_match(pred_sql, gold_sql, db_path):
            correct += 1

        if i % 10 == 0:
            # Was hard-coded "i/100": wrong whenever the dev slice changes.
            print(f"{i}/{len(dev)} | Accuracy: {correct/i:.3f}")

    # Was: f"... {correct}% / 100 = {correct}%", which labelled the raw
    # correct count as a percentage twice. Report count and percentage.
    accuracy = (correct / len(dev) * 100) if dev else 0.0
    print("\n=============================")
    print(f"BASE MODEL ACCURACY: {correct}/{len(dev)} = {accuracy:.2f}%")
    print("=============================")
110
+
111
+ if __name__ == "__main__":
112
+ main()
src/eval_both_metrics.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sqlite3
3
+ import torch
4
+ import re
5
+ import time
6
+ import argparse
7
+ from pathlib import Path
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+ from peft import PeftModel
10
+
11
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
12
+ DB_ROOT = PROJECT_ROOT / "data" / "database"
13
+
14
+ # -------------------------------
15
+ # 1. NORMALIZATION FOR EXACT MATCH
16
+ # -------------------------------
17
+ def normalize_sql(sql):
18
+ """Cleans SQL to make Exact Match grading fair (ignores spacing/cases)."""
19
+ sql = sql.replace('"', "'") # Standardize quotes
20
+ sql = re.sub(r"\s+", " ", sql) # Remove extra spaces/newlines
21
+ sql = sql.strip().lower() # Lowercase everything
22
+ sql = sql.rstrip(";") # Remove trailing semicolons
23
+ return sql
24
+
25
+ # -------------------------------
26
+ # 2. EXECUTION ACCURACY CHECK
27
+ # -------------------------------
28
def check_execution(pred_sql, gold_sql, db_path):
    """Return True iff both queries execute and yield identical result rows.

    Execution errors and timeouts count as a mismatch. The connection is
    now closed on every path — the original leaked it whenever a query
    raised (the `except` returned without closing).
    """
    try:
        conn = sqlite3.connect(db_path)
    except Exception:
        return False
    try:
        # Spider databases contain some non-UTF8 bytes; decode leniently.
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # ~5-second timeout: a non-zero return from the progress handler
        # makes SQLite abort the running query with OperationalError.
        start_time = time.monotonic()
        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 5.0 else 0
        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()

        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        return pred_res == gold_res
    except Exception:
        return False
    finally:
        conn.close()
55
+
56
+ # -------------------------------
57
+ # 3. LOAD SCHEMA
58
+ # -------------------------------
59
+ def load_schema(db_path):
60
+ conn = sqlite3.connect(db_path)
61
+ conn.text_factory = lambda b: b.decode(errors='ignore')
62
+ cursor = conn.cursor()
63
+ tables = cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
64
+ schema = ""
65
+ for (table,) in tables:
66
+ cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
67
+ col_names = [c[1] for c in cols]
68
+ schema += f"{table}({', '.join(col_names)})\n"
69
+ conn.close()
70
+ return schema
71
+
72
+ # -------------------------------
73
+ # 4. MAIN PIPELINE
74
+ # -------------------------------
75
def main():
    """Evaluate an adapter checkpoint on Spider dev with two metrics:
    Exact Match (normalized string equality) and Execution Accuracy
    (identical result rows when both queries run on the target DB).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
    parser.add_argument("--num_samples", type=int, default=1034, help="How many samples to evaluate")
    args = parser.parse_args()

    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    base_model = "Salesforce/codet5-base"

    print(f"\n🚀 Loading Model from: {args.adapter}")
    tokenizer = AutoTokenizer.from_pretrained(args.adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, args.adapter).to(device)
    # Fold the LoRA weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()

    dev_json = PROJECT_ROOT / "data" / "dev.json"
    with open(dev_json) as f:
        dev = json.load(f)[:args.num_samples]

    em_correct = 0
    ex_correct = 0
    total = len(dev)

    print(f"\n📊 Evaluating {total} queries for BOTH Exact Match and Execution Accuracy...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        gold_sql = ex["query"]
        db_id = ex["db_id"]
        db_path = DB_ROOT / db_id / f"{db_id}.sqlite"

        # Generate SQL for this question.
        schema = load_schema(db_path)
        prompt = f"Database Schema:\n{schema}\nTranslate English to SQL:\n{question}\nSQL:\n"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=100, num_beams=4, do_sample=False)

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the last "SQL:" marker, if echoed back.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        # --- METRIC 1: EXACT MATCH (normalized string comparison) ---
        is_em = (normalize_sql(pred_sql) == normalize_sql(gold_sql))
        if is_em:
            em_correct += 1

        # --- METRIC 2: EXECUTION ACCURACY (same result rows) ---
        is_ex = check_execution(pred_sql, gold_sql, db_path)
        if is_ex:
            ex_correct += 1

        if i % 50 == 0 or i == total:
            print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    # Final results as percentages of the evaluated slice.
    final_em = (em_correct / total) * 100
    final_ex = (ex_correct / total) * 100

    print("\n==========================================")
    print(f"🎯 FINAL RESULTS FOR: {args.adapter}")
    print("==========================================")
    print(f"Exact Match (EM) Accuracy : {final_em:.2f}%")
    print(f"Execution (EX) Accuracy : {final_ex:.2f}%")
    print("==========================================\n")
142
+
143
+ if __name__ == "__main__":
144
+ main()
src/eval_rl_fixed.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import json
2
+ # import sqlite3
3
+ # import argparse
4
+ # from pathlib import Path
5
+ # import torch
6
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
+ # from peft import PeftModel
8
+
9
+ # # ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
10
+ # def build_prompt(question, schema):
11
+ # return f"""
12
+ # Database Schema:
13
+ # {schema}
14
+
15
+ # Translate English to SQL:
16
+ # {question}
17
+ # SQL:
18
+ # """
19
+
20
+ # # ---------------- LOAD SCHEMA ----------------
21
+ # def load_schema(db_path):
22
+ # conn = sqlite3.connect(db_path)
23
+ # cursor = conn.cursor()
24
+
25
+ # tables = cursor.execute(
26
+ # "SELECT name FROM sqlite_master WHERE type='table';"
27
+ # ).fetchall()
28
+
29
+ # schema = ""
30
+ # for (table,) in tables:
31
+ # cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
32
+ # col_names = [c[1] for c in cols]
33
+ # schema += f"{table}({', '.join(col_names)})\n"
34
+
35
+ # conn.close()
36
+ # return schema
37
+
38
+
39
+ # # ---------------- EXECUTION CHECK ----------------
40
+ # def execution_match(pred_sql, gold_sql, db_path):
41
+ # try:
42
+ # conn = sqlite3.connect(db_path)
43
+ # cur = conn.cursor()
44
+
45
+ # cur.execute(pred_sql)
46
+ # pred = cur.fetchall()
47
+
48
+ # cur.execute(gold_sql)
49
+ # gold = cur.fetchall()
50
+
51
+ # conn.close()
52
+ # return pred == gold
53
+
54
+ # except Exception:
55
+ # return False
56
+
57
+
58
+ # # ---------------- MAIN ----------------
59
+ # def main():
60
+ # parser = argparse.ArgumentParser()
61
+ # parser.add_argument("--adapter", type=str, required=True)
62
+ # parser.add_argument("--num_samples", type=int, default=1034)
63
+ # args = parser.parse_args()
64
+
65
+ # project_root = Path(__file__).resolve().parents[1]
66
+
67
+ # dev_json = project_root / "data" / "dev.json"
68
+ # db_root = project_root / "data" / "database"
69
+
70
+ # device = "mps" if torch.backends.mps.is_available() else "cpu"
71
+
72
+ # # load model
73
+ # base_model = "Salesforce/codet5-base"
74
+ # tokenizer = AutoTokenizer.from_pretrained(args.adapter)
75
+ # base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
76
+ # model = PeftModel.from_pretrained(base, args.adapter).to(device)
77
+ # model = model.merge_and_unload()
78
+
79
+ # with open(dev_json) as f:
80
+ # dev = json.load(f)[: args.num_samples]
81
+
82
+ # correct = 0
83
+
84
+ # print(f"Evaluating {len(dev)} examples...\n")
85
+
86
+ # for i, ex in enumerate(dev, 1):
87
+ # question = ex["question"]
88
+ # db_id = ex["db_id"]
89
+ # gold_sql = ex["query"]
90
+
91
+ # db_path = db_root / db_id / f"{db_id}.sqlite"
92
+ # schema = load_schema(db_path)
93
+
94
+ # prompt = build_prompt(question, schema)
95
+
96
+ # inputs = tokenizer(prompt, return_tensors="pt").to(device)
97
+
98
+ # with torch.no_grad():
99
+ # outputs = model.generate(
100
+ # **inputs,
101
+ # max_new_tokens=80,
102
+ # do_sample=False,
103
+ # num_beams=4,
104
+ # )
105
+
106
+ # pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
107
+
108
+ # if "SQL:" in pred_sql:
109
+ # pred_sql = pred_sql.split("SQL:")[-1].strip()
110
+
111
+ # match = execution_match(pred_sql, gold_sql, db_path)
112
+
113
+ # if match:
114
+ # correct += 1
115
+
116
+ # if i % 10 == 0:
117
+ # print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")
118
+
119
+ # print("\n=============================")
120
+ # print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
121
+ # print("=============================")
122
+
123
+
124
+ # if __name__ == "__main__":
125
+ # main()
126
+
127
+
128
+ # import json
129
+ # import sqlite3
130
+ # import argparse
131
+ # import time
132
+ # from pathlib import Path
133
+ # import torch
134
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
135
+ # from peft import PeftModel
136
+
137
+ # # ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
138
+ # def build_prompt(question, schema):
139
+ # return f"""
140
+ # Database Schema:
141
+ # {schema}
142
+
143
+ # Translate English to SQL:
144
+ # {question}
145
+ # SQL:
146
+ # """
147
+
148
+ # # ---------------- LOAD SCHEMA ----------------
149
+ # def load_schema(db_path):
150
+ # conn = sqlite3.connect(db_path)
151
+ # cursor = conn.cursor()
152
+
153
+ # tables = cursor.execute(
154
+ # "SELECT name FROM sqlite_master WHERE type='table';"
155
+ # ).fetchall()
156
+
157
+ # schema = ""
158
+ # for (table,) in tables:
159
+ # cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
160
+ # col_names = [c[1] for c in cols]
161
+ # schema += f"{table}({', '.join(col_names)})\n"
162
+
163
+ # conn.close()
164
+ # return schema
165
+
166
+
167
+ # # ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
168
+ # def execution_match(pred_sql, gold_sql, db_path):
169
+ # try:
170
+ # conn = sqlite3.connect(db_path)
171
+
172
+ # # --- 5-SECOND TIMEOUT SO EVALUATION DOESN'T FREEZE ---
173
+ # start_time = time.monotonic()
174
+ # def timeout_handler():
175
+ # return 1 if (time.monotonic() - start_time) > 5.0 else 0
176
+ # conn.set_progress_handler(timeout_handler, 10000)
177
+
178
+ # cur = conn.cursor()
179
+
180
+ # cur.execute(pred_sql)
181
+ # pred = cur.fetchall()
182
+
183
+ # cur.execute(gold_sql)
184
+ # gold = cur.fetchall()
185
+
186
+ # conn.close()
187
+ # return pred == gold
188
+
189
+ # except Exception:
190
+ # return False
191
+
192
+
193
+ # # ---------------- MAIN ----------------
194
+ # def main():
195
+ # parser = argparse.ArgumentParser()
196
+ # parser.add_argument("--adapter", type=str, required=True)
197
+ # parser.add_argument("--num_samples", type=int, default=1034)
198
+ # args = parser.parse_args()
199
+
200
+ # project_root = Path(__file__).resolve().parents[1]
201
+
202
+ # dev_json = project_root / "data" / "dev.json"
203
+ # db_root = project_root / "data" / "database"
204
+
205
+ # # 🎯 Added CUDA support for Nvidia GPUs
206
+ # device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
207
+
208
+ # # load model
209
+ # base_model = "Salesforce/codet5-base"
210
+ # print(f"Loading Base: {base_model}")
211
+ # print(f"Loading Adapter: {args.adapter}")
212
+
213
+ # tokenizer = AutoTokenizer.from_pretrained(args.adapter)
214
+ # base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
215
+ # model = PeftModel.from_pretrained(base, args.adapter).to(device)
216
+ # model = model.merge_and_unload()
217
+
218
+ # with open(dev_json) as f:
219
+ # dev = json.load(f)[: args.num_samples]
220
+
221
+ # correct = 0
222
+
223
+ # print(f"Evaluating {len(dev)} examples...\n")
224
+
225
+ # for i, ex in enumerate(dev, 1):
226
+ # question = ex["question"]
227
+ # db_id = ex["db_id"]
228
+ # gold_sql = ex["query"]
229
+
230
+ # db_path = db_root / db_id / f"{db_id}.sqlite"
231
+ # schema = load_schema(db_path)
232
+
233
+ # prompt = build_prompt(question, schema)
234
+
235
+ # inputs = tokenizer(prompt, return_tensors="pt").to(device)
236
+
237
+ # with torch.no_grad():
238
+ # outputs = model.generate(
239
+ # **inputs,
240
+ # max_new_tokens=80,
241
+ # do_sample=False,
242
+ # num_beams=4,
243
+ # )
244
+
245
+ # pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
246
+
247
+ # if "SQL:" in pred_sql:
248
+ # pred_sql = pred_sql.split("SQL:")[-1].strip()
249
+
250
+ # match = execution_match(pred_sql, gold_sql, db_path)
251
+
252
+ # if match:
253
+ # correct += 1
254
+
255
+ # if i % 10 == 0:
256
+ # print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")
257
+
258
+ # print("\n=============================")
259
+ # print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
260
+ # print("=============================")
261
+
262
+
263
+ # if __name__ == "__main__":
264
+ # main()
265
+
266
+
267
+ import json
268
+ import subprocess
269
+ import sys
270
+ import argparse
271
+ import random
272
+ import sqlite3
273
+ import time
274
+ import re
275
+ from pathlib import Path
276
+
277
+ import torch
278
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
279
+ from peft import PeftModel
280
+
281
+ # Assuming you have a prompting.py that has encode_prompt
282
+ from prompting import encode_prompt
283
+
284
+ # -------------------------------
285
+ # LIVE CHECK HELPERS
286
+ # -------------------------------
287
def normalize_sql(sql):
    """Lightweight canonical form of a SQL string for the live tracker:
    uniform quotes, collapsed whitespace, lowercase, no trailing ';'."""
    uniform = sql.replace('"', "'")
    uniform = re.sub(r"\s+", " ", uniform)
    return uniform.strip().lower().rstrip(";")
292
+
293
def check_execution(pred_sql, gold_sql, db_path):
    """Order-insensitive execution check used by the live progress bar.

    Returns True iff both queries run and produce the same multiset of
    rows (sorted comparison). Errors and timeouts count as a mismatch.
    The connection is now closed on every path — the original leaked it
    whenever a query raised.
    """
    try:
        conn = sqlite3.connect(db_path)
    except Exception:
        return False
    try:
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # 2-second timeout so the live tracker doesn't freeze forever; a
        # non-zero progress-handler return aborts the running query.
        start_time = time.monotonic()
        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 2.0 else 0
        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        # sorted() may raise on mixed-type rows; the except below treats
        # that as a mismatch, same as before.
        return sorted(pred_res) == sorted(gold_res)
    except Exception:
        return False
    finally:
        conn.close()
317
+
318
+ # -------------------------------
319
+ # SPIDER PARSER
320
+ # -------------------------------
321
+ def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
322
+ for line in stdout.splitlines():
323
+ if metric_type == "exec" and line.strip().startswith("execution"):
324
+ try: return float(line.split()[-1])
325
+ except: pass
326
+ elif metric_type == "match" and line.strip().startswith("exact"):
327
+ try: return float(line.split()[-1])
328
+ except: pass
329
+ return None
330
+
331
+ # -------------------------------
332
+ # MAIN
333
+ # -------------------------------
334
def main():
    """Evaluate a CodeT5 LoRA checkpoint on the Spider dev set.

    Pipeline:
      1. Load the base CodeT5 model, apply the LoRA adapter, and merge it
         for faster inference.
      2. Beam-search generate SQL for each dev question, writing predictions
         and per-example gold queries to temp files.
      3. Track approximate exact-match (EM) / execution (EX) accuracy live.
      4. Run the official Spider evaluator twice ("match" and "exec") on the
         temp files and print the final numbers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
    parser.add_argument("--num_samples", type=int, default=700, help="Number of samples to evaluate")
    parser.add_argument("--shuffle_dev", action="store_true")
    parser.add_argument("--shuffle_seed", type=int, default=42)
    args = parser.parse_args()

    # This file lives in src/, so parents[1] is the project root.
    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data" / "database"
    table_json = project_root / "data" / "tables.json"
    dev_json = project_root / "data" / "dev.json"

    # Scratch files consumed by the official Spider evaluator below.
    pred_path = project_root / "temp_predictions.txt"
    temp_gold_path = project_root / "temp_gold.sql"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    # Device preference: Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    BASE_MODEL = "Salesforce/codet5-base"
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading Model: {args.adapter}...")
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    # Merge LoRA weights into the base model so generation skips the
    # adapter indirection entirely.
    model = model.merge_and_unload()
    model.eval()

    with dev_json.open() as f:
        dev = json.load(f)

    # Optional deterministic shuffle so a small --num_samples slice is not
    # biased by the dev set's original ordering.
    if args.shuffle_dev:
        rng = random.Random(args.shuffle_seed)
        rng.shuffle(dev)

    dev = dev[: args.num_samples]
    total = len(dev)

    gen_kwargs = dict(
        max_new_tokens=160,
        num_beams=4,
        do_sample=False,  # deterministic beam search
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"\n🚀 Generating and live-tracking {total} samples...\n")

    em_correct = 0  # live (approximate) exact-match counter
    ex_correct = 0  # live (approximate) execution-match counter

    with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
        for i, ex in enumerate(dev, start=1):
            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]
            db_path = db_root / db_id / f"{db_id}.sqlite"

            # Generate SQL for this question.
            input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            # Write to files for official spider eval later.
            out_pred.write(f"{pred_sql}\n")
            out_gold.write(f"{gold_query}\t{db_id}\n")

            # --- LIVE TRACKING CHECKS (looser than the official evaluator) ---
            if normalize_sql(pred_sql) == normalize_sql(gold_query):
                em_correct += 1
            if check_execution(pred_sql, gold_query, db_path):
                ex_correct += 1

            # Print progress every 10 samples (and on the final one).
            if i % 10 == 0 or i == total:
                print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    print("\nGeneration finished. Running Official Spider Evaluations for final numbers...\n")

    eval_script = project_root / "spider_eval" / "evaluation.py"

    # 1. RUN EXACT MATCH EVAL
    cmd_match = [
        sys.executable, str(eval_script),
        "--gold", str(temp_gold_path),
        "--pred", str(pred_path),
        "--etype", "match",
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
    exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")

    # 2. RUN EXECUTION EVAL
    cmd_exec = [
        sys.executable, str(eval_script),
        "--gold", str(temp_gold_path),
        "--pred", str(pred_path),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
    exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")

    print("==========================================")
    print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
    print("==========================================")

    if exact_acc is not None:
        print(f"Exact Set Match Accuracy : {exact_acc*100:.2f}%")
    else:
        print("Exact Set Match Accuracy : Could not parse output")

    if exec_acc is not None:
        print(f"Execution Accuracy : {exec_acc*100:.2f}%")
    else:
        print("Execution Accuracy : Could not parse output")
    print("==========================================\n")


if __name__ == "__main__":
    main()
src/eval_rl_t5.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import sys
2
+ # import os
3
+ # sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4
+ # import json
5
+
6
+ # import subprocess
7
+
8
+ # import argparse
9
+ # from pathlib import Path
10
+
11
+ # import torch
12
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
+ # from peft import PeftModel
14
+
15
+ # # IMPORTANT: must match training prompt format
16
+ # from prompting import build_prompt
17
+ # from schema_utils import get_schema as get_db_schema
18
+
19
+
20
+ # def _parse_exec_accuracy(stdout: str):
21
+ # for line in stdout.splitlines():
22
+ # if line.strip().startswith("execution"):
23
+ # parts = line.split()
24
+ # try:
25
+ # return float(parts[-1])
26
+ # except Exception:
27
+ # return None
28
+ # return None
29
+
30
+
31
+ # def main():
32
+ # parser = argparse.ArgumentParser()
33
+ # parser.add_argument("--adapter", type=str, default="checkpoints/best_rlhf_model")
34
+ # parser.add_argument("--num_samples", type=int, default=200)
35
+ # args = parser.parse_args()
36
+
37
+ # project_root = Path(__file__).resolve().parents[1]
38
+ # adapter_dir = project_root / args.adapter
39
+
40
+ # if not adapter_dir.exists():
41
+ # raise FileNotFoundError(f"Adapter not found: {adapter_dir}")
42
+
43
+ # db_root = project_root / "data" / "database"
44
+ # table_json = project_root / "data" / "tables.json"
45
+ # dev_json = project_root / "data" / "dev.json"
46
+ # gold_sql = project_root / "data" / "dev_gold.sql"
47
+ # pred_path = project_root / "predictions_rl.txt"
48
+
49
+ # device = "mps" if torch.backends.mps.is_available() else "cpu"
50
+
51
+ # # ---- LOAD MODEL (CodeT5 + LoRA) ----
52
+ # base_model = "Salesforce/codet5-base"
53
+
54
+ # tokenizer = AutoTokenizer.from_pretrained(str(adapter_dir))
55
+ # base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
56
+ # model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
57
+
58
+ # # merge LoRA for faster inference
59
+ # model = model.merge_and_unload()
60
+ # model.eval()
61
+ # model.config.use_cache = True
62
+
63
+ # if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
64
+ # tokenizer.pad_token = tokenizer.eos_token
65
+
66
+ # # ---- LOAD DATA ----
67
+ # with dev_json.open() as f:
68
+ # dev = json.load(f)
69
+
70
+ # dev = dev[: args.num_samples]
71
+
72
+ # gen_kwargs = dict(
73
+ # max_new_tokens=120,
74
+ # do_sample=False,
75
+ # num_beams=1,
76
+ # pad_token_id=tokenizer.pad_token_id,
77
+ # eos_token_id=tokenizer.eos_token_id,
78
+ # )
79
+
80
+ # print(f"Generating {len(dev)} predictions...")
81
+
82
+ # with pred_path.open("w") as out_f, torch.no_grad():
83
+ # for i, ex in enumerate(dev, start=1):
84
+ # db_id = ex["db_id"]
85
+ # question = ex["question"]
86
+
87
+ # db_path = db_root / db_id / f"{db_id}.sqlite"
88
+ # schema = get_db_schema(str(db_path))
89
+ # prompt = build_prompt(question, schema, use_schema=True)
90
+
91
+ # inputs = tokenizer(
92
+ # prompt,
93
+ # return_tensors="pt",
94
+ # truncation=True,
95
+ # max_length=512
96
+ # ).to(device)
97
+
98
+ # out = model.generate(**inputs, **gen_kwargs)
99
+ # pred_sql = tokenizer.decode(out[0], skip_special_tokens=True).strip()
100
+
101
+ # out_f.write(f"{pred_sql}\t{db_id}\n")
102
+
103
+ # if i % 20 == 0 or i == len(dev):
104
+ # print(f"{i}/{len(dev)} done")
105
+
106
+ # # ---- SPIDER OFFICIAL EVAL ----
107
+ # eval_script = project_root / "spider_eval" / "evaluation.py"
108
+
109
+ # cmd = [
110
+ # sys.executable,
111
+ # str(eval_script),
112
+ # "--gold",
113
+ # str(gold_sql),
114
+ # "--pred",
115
+ # str(pred_path),
116
+ # "--etype",
117
+ # "exec",
118
+ # "--db",
119
+ # str(db_root),
120
+ # "--table",
121
+ # str(table_json),
122
+ # ]
123
+
124
+ # print("\nRunning Spider execution evaluation...\n")
125
+ # proc = subprocess.run(cmd, capture_output=True, text=True)
126
+
127
+ # if proc.returncode != 0:
128
+ # print(proc.stdout)
129
+ # print(proc.stderr)
130
+ # sys.exit(proc.returncode)
131
+
132
+ # print(proc.stdout)
133
+
134
+ # acc = _parse_exec_accuracy(proc.stdout)
135
+ # if acc is not None:
136
+ # print(f"\nFINAL EXECUTION ACCURACY: {acc*100:.2f}%")
137
+ # else:
138
+ # print("Could not parse execution accuracy")
139
+
140
+
141
+ # if __name__ == "__main__":
142
+ # main()
143
+
144
+
145
+ import json
146
+ import sqlite3
147
+ import argparse
148
+ import time
149
+ from pathlib import Path
150
+ import torch
151
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
152
+ from peft import PeftModel
153
+
154
+ # ---------------- PROMPT (FIXED TO PERFECTLY MATCH RLHF TRAINING) ----------------
155
def build_prompt(question, schema):
    """Assemble the text-to-SQL prompt in the exact layout used during
    RLHF training: instruction, schema, question, then the SQL cue,
    each separated by a blank line."""
    sections = (
        "translate English to SQL:",
        f"Schema:\n{schema}",
        f"Question:\n{question}",
        "SQL:",
    )
    return "\n\n".join(sections)
157
+
158
+ # ---------------- LOAD SCHEMA (FIXED TO MATCH TRAINING FORMAT) ----------------
159
def load_schema(db_path):
    """Serialize a SQLite database schema as 'table(col, col) table2(...)'.

    The space-separated, single-line format must match what the RLHF
    training script fed the model, so do not change the layout.

    Args:
        db_path: Path to the .sqlite file.

    Returns:
        One schema string, e.g. "singer(id, name) concert(id, year)";
        an empty string when the database has no tables.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        parts = []
        for (table,) in tables:
            # PRAGMA statements cannot take bound parameters; the table
            # names come from sqlite_master of a local trusted db, so the
            # interpolation is acceptable here.
            cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
            col_names = [c[1] for c in cols]
            # Space-separated, not newline-separated, just like the RLHF script
            parts.append(f"{table}({', '.join(col_names)})")

        return " ".join(parts)
    finally:
        # Fix: the original leaked the connection whenever a query raised;
        # always release the handle.
        conn.close()
176
+
177
+
178
+ # ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
179
def execution_match(pred_sql, gold_sql, db_path):
    """Return True iff pred_sql and gold_sql produce identical result rows.

    Both queries run against the same SQLite database.  Any failure —
    syntax error, missing table, timeout, unreadable db — counts as a
    non-match.  The comparison is order-sensitive (``pred == gold``),
    matching the original behavior.
    """
    try:
        conn = sqlite3.connect(db_path)
    except Exception:
        return False
    try:
        # ~5-second budget per pair so a pathological prediction cannot
        # hang the whole evaluation run.
        start_time = time.monotonic()

        def timeout_handler():
            # A non-zero return makes SQLite abort the running statement.
            return 1 if (time.monotonic() - start_time) > 5.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cur = conn.cursor()

        cur.execute(pred_sql)
        pred = cur.fetchall()

        cur.execute(gold_sql)
        gold = cur.fetchall()

        return pred == gold
    except Exception:
        return False
    finally:
        # Fix: the original leaked the connection whenever execute()
        # raised; always close it.
        conn.close()
202
+
203
+
204
+ # ---------------- MAIN ----------------
205
def main():
    """Evaluate a t5-small LoRA (RLHF) checkpoint by execution accuracy.

    For each Spider dev example: rebuild the training-time prompt from the
    live database schema, generate SQL with beam search, and count it
    correct when it executes to the same rows as the gold query.
    """
    parser = argparse.ArgumentParser()
    # Default points straight at the best RLHF checkpoint.
    parser.add_argument("--adapter", type=str, default="checkpoints/rlhf_t5_best")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    # This file lives in src/, so parents[1] is the project root.
    project_root = Path(__file__).resolve().parents[1]

    # Resolve adapter path relative to the project root.
    adapter_path = project_root / args.adapter

    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    # Device preference: Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    # Load base model + LoRA adapter; the tokenizer is loaded from the
    # adapter directory (it was saved alongside the checkpoint).
    base_model = "t5-small"
    print(f"Loading Base: {base_model}")
    print(f"Loading Adapter: {adapter_path}")

    tokenizer = AutoTokenizer.from_pretrained(str(adapter_path))
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_path)).to(device)
    # Merge LoRA weights for faster inference.
    model = model.merge_and_unload()

    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    correct = 0

    print(f"Evaluating {len(dev)} examples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]

        # Each Spider db lives at data/database/<db_id>/<db_id>.sqlite.
        db_path = db_root / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)

        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                do_sample=False,
                num_beams=4,
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # If the model echoed the prompt's "SQL:" cue, keep only what follows.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        match = execution_match(pred_sql, gold_sql, db_path)

        if match:
            correct += 1

        # Running accuracy every 10 examples.
        if i % 10 == 0:
            print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")

    # NOTE(review): divides by len(dev) — raises ZeroDivisionError when the
    # dev slice is empty (num_samples == 0).
    print("\n=============================")
    print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
    print("=============================")


if __name__ == "__main__":
    main()
src/eval_single_model.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import subprocess
3
+ import sys
4
+ import argparse
5
+ import random
6
+ import sqlite3
7
+ import time
8
+ import re
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ from pathlib import Path
12
+
13
+ import torch
14
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
15
+ from peft import PeftModel
16
+
17
+ # Assuming you have a prompting.py that has encode_prompt
18
+ from prompting import encode_prompt
19
+
20
+ # -------------------------------
21
+ # LIVE CHECK HELPERS
22
+ # -------------------------------
23
def normalize_sql(sql):
    """Loosely canonicalize SQL for the live exact-match counter:
    whitespace runs collapsed, lowercased, double quotes unified to
    single quotes, trailing semicolon removed."""
    canonical = re.sub(r"\s+", " ", sql).strip().lower()
    return canonical.replace('"', "'").rstrip(";")
27
+
28
def check_execution(pred_sql, gold_sql, db_path):
    """Live execution check: do the two queries return the same rows?

    Rows are compared order-insensitively (sorted) — a looser criterion
    than the official Spider evaluator; this only feeds the progress bar.
    Any error (bad SQL, missing db, timeout, unsortable rows) → False.
    """
    try:
        conn = sqlite3.connect(db_path)
    except Exception:
        return False
    try:
        # Some Spider dbs contain non-UTF-8 text; decode leniently rather
        # than crash on the first bad byte.
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # 2-second budget per pair so a runaway query can't stall the loop.
        start_time = time.monotonic()

        def timeout_handler():
            # Non-zero return tells SQLite to abort the running statement.
            return 1 if (time.monotonic() - start_time) > 2.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        # Simple sorted check for the live tracker
        return sorted(pred_res) == sorted(gold_res)
    except Exception:
        return False
    finally:
        # Fix: the original leaked the connection whenever execute()
        # raised; always close it.
        conn.close()
49
+
50
+ # -------------------------------
51
+ # SPIDER PARSER
52
+ # -------------------------------
53
+ def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
54
+ for line in stdout.splitlines():
55
+ if metric_type == "exec" and line.strip().startswith("execution"):
56
+ try: return float(line.split()[-1])
57
+ except: pass
58
+ elif metric_type == "match" and line.strip().startswith("exact"):
59
+ try: return float(line.split()[-1])
60
+ except: pass
61
+ return None
62
+
63
+ # -------------------------------
64
+ # MAIN
65
+ # -------------------------------
66
def main():
    """Evaluate one checkpoint and update the cross-model comparison plot.

    Steps:
      1. Load an arbitrary seq2seq base model (--base_model) + LoRA adapter,
         merged for fast inference.
      2. Generate SQL for the dev slice, live-tracking approximate EM/EX.
      3. Run the official Spider evaluator for the final EM/EX numbers.
      4. Persist the numbers under --model_name in comparison_plots/
         all_metrics.json and regenerate the grouped bar chart.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your checkpoint")
    parser.add_argument("--base_model", type=str, required=True, help="E.g., facebook/bart-base, t5-small")
    parser.add_argument("--model_name", type=str, required=True, help="Name for the plot label (e.g., 'BART RLHF')")
    parser.add_argument("--num_samples", type=int, default=700)
    args = parser.parse_args()

    # This file lives in src/, so parents[1] is the project root.
    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data" / "database"
    table_json = project_root / "data" / "tables.json"
    dev_json = project_root / "data" / "dev.json"

    # Scratch files consumed by the official Spider evaluator below.
    pred_path = project_root / "temp_predictions.txt"
    temp_gold_path = project_root / "temp_gold.sql"

    # Plot directory setup: metrics from every evaluated model accumulate
    # in one JSON so the comparison chart grows across runs.
    plot_dir = project_root / "comparison_plots"
    plot_dir.mkdir(parents=True, exist_ok=True)
    results_json_path = plot_dir / "all_metrics.json"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    # Device preference: Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading Base Model: {args.base_model} on {device}...")

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForSeq2SeqLM.from_pretrained(args.base_model).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    # Merge LoRA weights for faster inference.
    model = model.merge_and_unload()
    model.eval()

    with dev_json.open() as f:
        dev = json.load(f)[: args.num_samples]
    total = len(dev)

    gen_kwargs = dict(
        max_new_tokens=160,
        num_beams=4,
        do_sample=False,  # deterministic beam search
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"\n🚀 Generating and live-tracking {total} samples...\n")

    em_correct = 0  # live (approximate) exact-match counter
    ex_correct = 0  # live (approximate) execution-match counter

    with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
        for i, ex in enumerate(dev, start=1):
            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]
            db_path = db_root / db_id / f"{db_id}.sqlite"

            # Generate SQL for this question.
            input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            out_pred.write(f"{pred_sql}\n")
            out_gold.write(f"{gold_query}\t{db_id}\n")

            # --- PRINT FIRST 3 EXAMPLES (sanity check of prompt/decoding) ---
            if i <= 3:
                print(f"--- 🔍 Example {i} ---")
                print(f"Q : {question}")
                print(f"Gold: {gold_query}")
                print(f"Pred: {pred_sql}")
                print("-" * 25)

            # --- LIVE TRACKING CHECKS (looser than the official evaluator) ---
            if normalize_sql(pred_sql) == normalize_sql(gold_query):
                em_correct += 1
            if check_execution(pred_sql, gold_query, db_path):
                ex_correct += 1

            # Print progress every 50 samples (and on the final one).
            if i % 50 == 0 or i == total:
                print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    print("\nRunning Official Spider Evaluations...")
    eval_script = project_root / "spider_eval" / "evaluation.py"

    # Exact-set-match pass, then execution pass, of the official evaluator.
    proc_match = subprocess.run([sys.executable, str(eval_script), "--gold", str(temp_gold_path), "--pred", str(pred_path), "--etype", "match", "--db", str(db_root), "--table", str(table_json)], capture_output=True, text=True)
    exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")

    proc_exec = subprocess.run([sys.executable, str(eval_script), "--gold", str(temp_gold_path), "--pred", str(pred_path), "--etype", "exec", "--db", str(db_root), "--table", str(table_json)], capture_output=True, text=True)
    exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")

    print("\n==========================================")
    print(f"🎯 RESULTS FOR: {args.model_name}")
    print("==========================================")
    # NOTE(review): truthiness check maps both None and a genuine 0.0
    # accuracy to 0 — indistinguishable in the printed/plotted output.
    exact_val = exact_acc * 100 if exact_acc else 0
    exec_val = exec_acc * 100 if exec_acc else 0
    print(f"Exact Match : {exact_val:.2f}%")
    print(f"Execution : {exec_val:.2f}%")
    print("==========================================\n")

    # -------------------------------
    # SAVE JSON & GENERATE PLOT
    # -------------------------------
    if results_json_path.exists():
        with open(results_json_path, 'r') as f:
            all_results = json.load(f)
    else:
        all_results = {}

    # Re-running with the same --model_name overwrites that entry.
    all_results[args.model_name] = {"EM": exact_val, "EX": exec_val}

    with open(results_json_path, 'w') as f:
        json.dump(all_results, f, indent=4)

    labels = list(all_results.keys())
    em_vals = [all_results[k]["EM"] for k in labels]
    ex_vals = [all_results[k]["EX"] for k in labels]

    x = np.arange(len(labels))
    width = 0.35  # bar width; EM/EX bars sit side by side per model

    plt.figure(figsize=(max(8, len(labels) * 1.5), 6))
    plt.bar(x - width/2, em_vals, width, label='Exact Match', color='#3498db')
    plt.bar(x + width/2, ex_vals, width, label='Execution', color='#2ecc71')

    plt.ylabel('Accuracy (%)', fontweight='bold')
    plt.title('Model Comparison: Exact Match vs Execution Accuracy', fontweight='bold', fontsize=14)
    plt.xticks(x, labels, rotation=45, ha="right")
    plt.legend()
    # Leave headroom above the tallest bar for the value labels.
    plt.ylim(0, max(max(em_vals, default=0), max(ex_vals, default=0)) + 15)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Attach value labels to bars
    for i in range(len(labels)):
        plt.text(x[i] - width/2, em_vals[i] + 1, f"{em_vals[i]:.1f}%", ha='center', fontsize=9)
        plt.text(x[i] + width/2, ex_vals[i] + 1, f"{ex_vals[i]:.1f}%", ha='center', fontsize=9)

    plt.tight_layout()
    plot_path = plot_dir / "accuracy_comparison.png"
    plt.savefig(plot_path, dpi=300)
    print(f"📈 Updated comparison plot saved to: {plot_path}")


if __name__ == "__main__":
    main()
src/evaluate_model_codet5.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from __future__ import annotations
2
+
3
+ # import json
4
+ # import subprocess
5
+ # import sys
6
+ # import argparse
7
+ # import sqlite3
8
+ # import random
9
+ # from pathlib import Path
10
+
11
+ # import torch
12
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
+ # from peft import PeftModel
14
+
15
+ # from prompting import encode_prompt
16
+
17
+
18
+ # def _parse_exec_accuracy(stdout: str) -> float | None:
19
+ # for line in stdout.splitlines():
20
+ # if line.strip().startswith("execution"):
21
+ # try:
22
+ # return float(line.split()[-1])
23
+ # except:
24
+ # return None
25
+ # return None
26
+
27
+
28
+ # def main():
29
+
30
+ # # ---------------- ARGUMENTS ----------------
31
+ # parser = argparse.ArgumentParser()
32
+ # parser.add_argument("--adapter", type=str, default="checkpoints/sft_adapter_codet5")
33
+ # parser.add_argument("--num_samples", type=int, default=1000)
34
+ # parser.add_argument("--shuffle_dev", action="store_true")
35
+ # parser.add_argument("--shuffle_seed", type=int, default=42)
36
+ # parser.add_argument("--accuracy_log", type=str, default="")
37
+ # args = parser.parse_args()
38
+
39
+ # project_root = Path(__file__).resolve().parents[1]
40
+ # adapter_dir = project_root / args.adapter
41
+
42
+ # db_root = project_root / "data" / "database"
43
+ # table_json = project_root / "data" / "tables.json"
44
+ # dev_json = project_root / "data" / "dev.json"
45
+ # gold_sql = project_root / "data" / "dev_gold.sql"
46
+ # pred_path = project_root / "predictions.txt"
47
+
48
+ # if not adapter_dir.exists():
49
+ # raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
50
+
51
+ # # ---------------- DEVICE ----------------
52
+ # device = "mps" if torch.backends.mps.is_available() else (
53
+ # "cuda" if torch.cuda.is_available() else "cpu"
54
+ # )
55
+ # print("Using device:", device)
56
+
57
+ # # ---------------- LOAD MODEL ----------------
58
+ # BASE_MODEL = "Salesforce/codet5-base"
59
+
60
+ # tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
61
+
62
+ # if tokenizer.pad_token is None:
63
+ # tokenizer.pad_token = tokenizer.eos_token
64
+
65
+ # base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
66
+ # model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
67
+
68
+ # model = model.merge_and_unload()
69
+ # model.eval()
70
+
71
+ # # ---------------- LOAD DATA ----------------
72
+ # with dev_json.open() as f:
73
+ # dev = json.load(f)
74
+
75
+ # if args.shuffle_dev:
76
+ # rng = random.Random(args.shuffle_seed)
77
+ # rng.shuffle(dev)
78
+
79
+ # dev = dev[: args.num_samples]
80
+
81
+ # # ---------------- GENERATION CONFIG ----------------
82
+ # gen_kwargs = dict(
83
+ # max_new_tokens=160,
84
+ # num_beams=4,
85
+ # do_sample=False,
86
+ # early_stopping=True,
87
+ # pad_token_id=tokenizer.pad_token_id,
88
+ # eos_token_id=tokenizer.eos_token_id,
89
+ # )
90
+
91
+ # print("Generating predictions...\n")
92
+
93
+ # correct = 0
94
+ # total = len(dev)
95
+ # accuracy_log_fh = None
96
+
97
+ # if args.accuracy_log:
98
+ # accuracy_log_path = (project_root / args.accuracy_log).resolve()
99
+ # accuracy_log_path.parent.mkdir(parents=True, exist_ok=True)
100
+ # accuracy_log_fh = accuracy_log_path.open("w")
101
+ # print(f"Writing running accuracy log to: {accuracy_log_path}")
102
+
103
+ # with pred_path.open("w") as out_f, torch.no_grad():
104
+
105
+ # for i, ex in enumerate(dev, start=1):
106
+
107
+ # db_id = ex["db_id"]
108
+ # question = ex["question"]
109
+ # gold_query = ex["query"]
110
+
111
+ # input_ids = encode_prompt(
112
+ # tokenizer,
113
+ # question,
114
+ # db_id,
115
+ # device=device,
116
+ # max_input_tokens=512,
117
+ # )
118
+
119
+ # input_ids = input_ids.unsqueeze(0).to(device)
120
+ # attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
121
+
122
+ # outputs = model.generate(
123
+ # input_ids=input_ids,
124
+ # attention_mask=attention_mask,
125
+ # **gen_kwargs
126
+ # )
127
+
128
+ # pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
129
+ # out_f.write(f"{pred_sql}\t{db_id}\n")
130
+
131
+ # # ---------------- LIVE EXECUTION CHECK ----------------
132
+ # try:
133
+ # db_path = db_root / db_id / f"{db_id}.sqlite"
134
+
135
+ # conn = sqlite3.connect(db_path)
136
+ # cursor = conn.cursor()
137
+
138
+ # cursor.execute(pred_sql)
139
+ # pred_rows = cursor.fetchall()
140
+
141
+ # cursor.execute(gold_query)
142
+ # gold_rows = cursor.fetchall()
143
+
144
+ # conn.close()
145
+
146
+ # if sorted(pred_rows) == sorted(gold_rows):
147
+ # correct += 1
148
+
149
+ # except Exception:
150
+ # pass # execution failed
151
+
152
+ # # 🔥 PRINT EVERY 10
153
+ # if i % 10 == 0 or i == total:
154
+ # current_acc = correct / i
155
+ # line = f"{i}/{total} | Acc: {current_acc:.3f}"
156
+ # print(line)
157
+ # if accuracy_log_fh is not None:
158
+ # accuracy_log_fh.write(line + "\n")
159
+
160
+ # if accuracy_log_fh is not None:
161
+ # accuracy_log_fh.close()
162
+
163
+ # print("\nGeneration finished.\n")
164
+
165
+ # # ---------------- OFFICIAL SPIDER EVAL ----------------
166
+ # eval_script = project_root / "spider_eval" / "evaluation.py"
167
+
168
+ # cmd = [
169
+ # sys.executable,
170
+ # str(eval_script),
171
+ # "--gold", str(gold_sql),
172
+ # "--pred", str(pred_path),
173
+ # "--etype", "exec",
174
+ # "--db", str(db_root),
175
+ # "--table", str(table_json),
176
+ # ]
177
+
178
+ # print("Running Spider evaluation...")
179
+ # proc = subprocess.run(cmd, capture_output=True, text=True)
180
+
181
+ # print(proc.stdout)
182
+
183
+ # exec_acc = _parse_exec_accuracy(proc.stdout)
184
+ # if exec_acc is not None:
185
+ # print(f"\n🎯 Official Execution Accuracy: {exec_acc*100:.2f}%")
186
+ # else:
187
+ # print("Could not parse accuracy.")
188
+
189
+
190
+ # if __name__ == "__main__":
191
+ # main()
192
+
193
+ import json
194
+ import subprocess
195
+ import sys
196
+ import argparse
197
+ import random
198
+ import sqlite3
199
+ import time
200
+ import re
201
+ from pathlib import Path
202
+
203
+ import torch
204
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
205
+ from peft import PeftModel
206
+
207
+ # Assuming you have a prompting.py that has encode_prompt
208
+ from prompting import encode_prompt
209
+
210
+ # -------------------------------
211
+ # LIVE CHECK HELPERS
212
+ # -------------------------------
213
def normalize_sql(sql):
    """Loosely canonicalize SQL for the live progress bar: whitespace runs
    collapsed, lowercased, double quotes unified to single quotes, trailing
    semicolon removed."""
    canonical = re.sub(r"\s+", " ", sql).strip().lower()
    return canonical.replace('"', "'").rstrip(";")
218
+
219
def check_execution(pred_sql, gold_sql, db_path):
    """Basic execution check for the live progress bar.

    Rows are compared order-insensitively (sorted) — a looser criterion
    than the official Spider evaluator.  Any error (bad SQL, missing db,
    timeout, unsortable rows) → False.
    """
    try:
        conn = sqlite3.connect(db_path)
    except Exception:
        return False
    try:
        # Some Spider dbs contain non-UTF-8 text; decode leniently rather
        # than crash on the first bad byte.
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # 2-second timeout so the live tracker doesn't freeze forever
        start_time = time.monotonic()

        def timeout_handler():
            # Non-zero return tells SQLite to abort the running statement.
            return 1 if (time.monotonic() - start_time) > 2.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        # Simple sorted check for the live tracker
        return sorted(pred_res) == sorted(gold_res)
    except Exception:
        return False
    finally:
        # Fix: the original leaked the connection whenever execute()
        # raised; always close it.
        conn.close()
243
+
244
+ # -------------------------------
245
+ # SPIDER PARSER
246
+ # -------------------------------
247
+ def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
248
+ for line in stdout.splitlines():
249
+ if metric_type == "exec" and line.strip().startswith("execution"):
250
+ try: return float(line.split()[-1])
251
+ except: pass
252
+ elif metric_type == "match" and line.strip().startswith("exact"):
253
+ try: return float(line.split()[-1])
254
+ except: pass
255
+ return None
256
+
257
# -------------------------------
# MAIN
# -------------------------------
def main():
    """Evaluate a LoRA-adapted CodeT5 checkpoint on Spider dev.

    Generates SQL for each dev example, live-tracks exact-match (EM) and
    execution (EX) accuracy while generating, writes prediction/gold files,
    then runs the official Spider evaluation script for the final numbers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
    parser.add_argument("--num_samples", type=int, default=1034, help="Number of samples to evaluate")
    parser.add_argument("--shuffle_dev", action="store_true")
    parser.add_argument("--shuffle_seed", type=int, default=42)
    args = parser.parse_args()

    # Repo layout: <root>/src/<this file>, data under <root>/data.
    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data" / "database"
    table_json = project_root / "data" / "tables.json"
    dev_json = project_root / "data" / "dev.json"

    pred_path = project_root / "temp_predictions.txt"
    temp_gold_path = project_root / "temp_gold.sql"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    BASE_MODEL = "Salesforce/codet5-base"
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading Model: {args.adapter}...")
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    # Merge LoRA weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()

    with dev_json.open() as f:
        dev = json.load(f)

    if args.shuffle_dev:
        # Seeded shuffle so subsampled runs are reproducible.
        rng = random.Random(args.shuffle_seed)
        rng.shuffle(dev)

    dev = dev[: args.num_samples]
    total = len(dev)

    gen_kwargs = dict(
        max_new_tokens=160,
        num_beams=4,
        do_sample=False,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"\n🚀 Generating and live-tracking {total} samples...\n")

    em_correct = 0
    ex_correct = 0

    with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
        for i, ex in enumerate(dev, start=1):
            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]
            db_path = db_root / db_id / f"{db_id}.sqlite"

            # Generate
            # NOTE(review): encode_prompt presumably returns a 1-D token id
            # tensor — confirm against src/prompting.py.
            input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            # Write to files for official spider eval later
            out_pred.write(f"{pred_sql}\n")
            out_gold.write(f"{gold_query}\t{db_id}\n")

            # --- LIVE TRACKING CHECKS ---
            # These are quick approximations; the official script below
            # computes the reported metrics.
            if normalize_sql(pred_sql) == normalize_sql(gold_query):
                em_correct += 1
            if check_execution(pred_sql, gold_query, db_path):
                ex_correct += 1

            # Print progress every 50 loops
            if i % 50 == 0 or i == total:
                print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    print("\nGeneration finished. Running Official Spider Evaluations for final numbers...\n")

    eval_script = project_root / "spider_eval" / "evaluation.py"

    # 1. RUN EXACT MATCH EVAL
    cmd_match = [
        sys.executable, str(eval_script),
        "--gold", str(temp_gold_path),
        "--pred", str(pred_path),
        "--etype", "match",
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
    exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")

    # 2. RUN EXECUTION EVAL
    cmd_exec = [
        sys.executable, str(eval_script),
        "--gold", str(temp_gold_path),
        "--pred", str(pred_path),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
    exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")

    print("==========================================")
    print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
    print("==========================================")

    if exact_acc is not None:
        print(f"Exact Set Match Accuracy : {exact_acc*100:.2f}%")
    else:
        print("Exact Set Match Accuracy : Could not parse output")

    if exec_acc is not None:
        print(f"Execution Accuracy : {exec_acc*100:.2f}%")
    else:
        print("Execution Accuracy : Could not parse output")
    print("==========================================\n")


if __name__ == "__main__":
    main()
src/evaluate_model_t5_small_sft.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import subprocess
5
+ import sys
6
+ import argparse
7
+ import re
8
+ import sqlite3
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
+ from peft import PeftModel
14
+ from prompting import encode_prompt
15
+
16
+
17
+ # ---------------- PARSE ACC ----------------
18
+ def _parse_exec_accuracy(stdout: str) -> float | None:
19
+ for line in stdout.splitlines():
20
+ if line.strip().startswith("execution"):
21
+ try:
22
+ return float(line.split()[-1])
23
+ except:
24
+ return None
25
+ return None
26
+
27
+
28
# ---------------- CLEAN SQL ----------------
def clean_prediction(pred_sql: str) -> str:
    """Normalize a raw model generation into a single-line SQL statement.

    Drops any echoed "SQL:" preamble, converts double quotes to single
    quotes, collapses all whitespace runs, and guarantees a trailing
    semicolon.
    """
    text = pred_sql.strip()

    # If the model echoed the prompt, keep only what follows the last "SQL:".
    _, marker, tail = text.rpartition("SQL:")
    if marker:
        text = tail

    # Quote style + whitespace normalization in one pass.
    normalized = " ".join(text.replace('"', "'").split())

    return normalized if normalized.endswith(";") else normalized + ";"
42
+
43
+
44
def main():
    """Evaluate a LoRA-adapted t5-small SFT checkpoint on Spider dev.

    Generates SQL per dev example, live-tracks execution accuracy against
    each example's SQLite database, writes predictions to pred.sql, then
    runs the official Spider evaluation script for the final number.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, default="checkpoints/sft_t5")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data/database"
    table_json = project_root / "data/tables.json"
    dev_json = project_root / "data/dev.json"
    gold_sql = project_root / "data/dev_gold.sql"
    pred_path = project_root / "pred.sql"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    # ---------------- DEVICE ----------------
    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print("Using device:", device)

    # ---------------- LOAD MODEL ----------------
    BASE_MODEL = "t5-small"

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)

    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    # Merge LoRA weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ---------------- LOAD DATA ----------------
    with dev_json.open() as f:
        dev = json.load(f)[: args.num_samples]

    print("Generating predictions...\n")

    correct = 0
    total = len(dev)

    # ---------------- GENERATE + LIVE EXEC ----------------
    with pred_path.open("w") as out_f, torch.no_grad():

        for i, ex in enumerate(dev, start=1):

            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]

            # NOTE(review): encode_prompt presumably returns a 1-D token id
            # tensor — confirm against src/prompting.py.
            prompt_ids = encode_prompt(
                tokenizer,
                question,
                db_id,
                device=device,
                max_input_tokens=512,
            )

            input_ids = prompt_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=160,
                num_beams=4,
                do_sample=False,
                early_stopping=True,
            )

            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
            pred_sql = clean_prediction(pred_sql)

            out_f.write(pred_sql + "\n")

            # -------- LIVE EXECUTION CHECK --------
            # Order-insensitive tracker only; the official script below
            # computes the reported number.
            try:
                db_path = db_root / db_id / f"{db_id}.sqlite"

                conn = sqlite3.connect(db_path)
                cursor = conn.cursor()

                cursor.execute(pred_sql)
                pred_rows = cursor.fetchall()

                cursor.execute(gold_query)
                gold_rows = cursor.fetchall()

                conn.close()

                if sorted(pred_rows) == sorted(gold_rows):
                    correct += 1

            except Exception:
                pass  # execution failed

            # 🔥 PRINT EVERY 10
            if i % 10 == 0 or i == total:
                current_acc = correct / i
                print(f"{i}/{total} | Acc: {current_acc:.3f}")

    print("\nGeneration finished.\n")

    # ---------------- SPIDER EVAL ----------------
    eval_script = project_root / "spider_eval/evaluation.py"

    cmd = [
        sys.executable,
        str(eval_script),
        "--gold", str(gold_sql),
        "--pred", str(pred_path),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]

    print("Running Spider evaluation...")
    proc = subprocess.run(cmd, capture_output=True, text=True)

    print(proc.stdout)

    exec_acc = _parse_exec_accuracy(proc.stdout)
    if exec_acc is not None:
        print(f"\n🎯 Official Execution Accuracy: {exec_acc*100:.2f}%")
    else:
        print("Could not parse accuracy.")


if __name__ == "__main__":
    main()
src/evaluate_rl_bart.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ import sqlite3
4
+ import argparse
5
+ import time
6
+ from pathlib import Path
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+ from peft import PeftModel
10
+
11
# ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
def build_prompt(question, schema):
    """Build the exact text prompt used during training.

    The template (including its leading/trailing newlines and section
    labels) must stay byte-identical to the one used at training time,
    so do not reformat this string.
    """
    return f"""
Database Schema:
{schema}

Translate English to SQL:
{question}
SQL:
"""
21
+
22
# ---------------- LOAD SCHEMA ----------------
def load_schema(db_path):
    """Return a compact text schema: one "table(col1, col2, ...)" line per table.

    Args:
        db_path: Path to a SQLite database file.

    Returns:
        A newline-terminated string listing every table with its columns.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema = ""
        for (table,) in tables:
            # Quote the identifier so table names containing spaces or SQL
            # keywords don't break the PRAGMA statement.
            cols = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
            col_names = [c[1] for c in cols]
            schema += f"{table}({', '.join(col_names)})\n"

        return schema
    finally:
        # Always release the connection, even if a query fails mid-loop
        # (the original leaked it on exception).
        conn.close()
39
+
40
+
41
# ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
def execution_match(pred_sql, gold_sql, db_path):
    """Return True iff pred_sql and gold_sql yield identical result rows.

    Any failure (syntax error, missing table, timeout abort) counts as a
    non-match. Comparison is order-sensitive, as originally written.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)

        # --- 5-SECOND TIMEOUT SO EVALUATION DOESN'T FREEZE ---
        start_time = time.monotonic()

        def timeout_handler():
            # A non-zero return from the progress handler aborts the
            # currently running statement.
            return 1 if (time.monotonic() - start_time) > 5.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cur = conn.cursor()

        cur.execute(pred_sql)
        pred = cur.fetchall()

        cur.execute(gold_sql)
        gold = cur.fetchall()

        return pred == gold

    except Exception:
        return False
    finally:
        # The original leaked the connection whenever execute() raised;
        # close it on every path.
        if conn is not None:
            conn.close()
65
+
66
+
67
# ---------------- MAIN ----------------
def main():
    """Evaluate a LoRA-adapted BART checkpoint on Spider dev by execution
    accuracy only (no official Spider script run in this evaluator).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True)
    parser.add_argument("--num_samples", type=int, default=1034)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]

    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    # 🎯 Added CUDA support for Nvidia GPUs
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    # load model
    base_model = "facebook/bart-base"
    print(f"Loading Base: {base_model}")
    print(f"Loading Adapter: {args.adapter}")

    # Tokenizer comes from the adapter dir so any added tokens are kept.
    tokenizer = AutoTokenizer.from_pretrained(args.adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, args.adapter).to(device)
    # Merge LoRA weights so generation runs on a plain seq2seq model.
    model = model.merge_and_unload()

    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    correct = 0

    print(f"Evaluating {len(dev)} examples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]

        db_path = db_root / db_id / f"{db_id}.sqlite"
        # Schema is rebuilt per example so the prompt matches training exactly.
        schema = load_schema(db_path)

        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                do_sample=False,
                num_beams=4,
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Drop any echoed prompt up to the "SQL:" marker.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        match = execution_match(pred_sql, gold_sql, db_path)

        if match:
            correct += 1

        if i % 10 == 0:
            print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")

    print("\n=============================")
    print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
    print("=============================")


if __name__ == "__main__":
    main()
+ main()
src/evaluate_sft_bart.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import subprocess
5
+ import sys
6
+ import argparse
7
+ import re
8
+ import sqlite3
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
+ from peft import PeftModel
14
+ from prompting import encode_prompt
15
+
16
+
17
# ---------------- SQL CLEAN ----------------
def extract_sql(text: str) -> str:
    """Normalize a raw generation into one SQL statement.

    Keeps everything from the first SELECT onward, drops any echoed
    "SQL:" preamble, converts double quotes to single quotes, collapses
    whitespace, and ensures a trailing semicolon.
    """
    cleaned = text.strip()

    # The model sometimes echoes the prompt; keep only what follows "SQL:".
    _, marker, tail = cleaned.rpartition("SQL:")
    if marker:
        cleaned = tail

    # Trim any remaining preamble before the first SELECT (case-insensitive).
    found = re.search(r"(SELECT .*?)(?:$)", cleaned, re.IGNORECASE | re.DOTALL)
    if found is not None:
        cleaned = found.group(1)

    cleaned = " ".join(cleaned.replace('"', "'").split())

    return cleaned if cleaned.endswith(";") else cleaned + ";"
35
+
36
+
37
# ---------------- ROBUST ACC PARSER ----------------
def parse_exec_accuracy(stdout: str):
    """Return the last decimal number on the first numeric line mentioning
    "execution" (case-insensitive), or None when no such line exists."""
    for raw in stdout.splitlines():
        if "execution" not in raw.lower():
            continue
        decimals = re.findall(r"\d+\.\d+", raw)
        if decimals:
            return float(decimals[-1])
    return None
45
+
46
+
47
def main():
    """Evaluate a LoRA-adapted BART SFT checkpoint on Spider dev.

    Generates predictions (written as "<sql>\\t<db_id>" lines for the
    BART-aware eval script), live-tracks execution accuracy, then runs
    the official Spider evaluator.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, default="checkpoints/sft_best_bart_2")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Adapter not found: {adapter_dir}")

    db_root = project_root / "data/database"
    table_json = project_root / "data/tables.json"
    dev_json = project_root / "data/dev.json"
    gold_sql_file = project_root / "data/dev_gold.sql"
    pred_sql_file = project_root / "pred.sql"

    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print("Using device:", device)

    # -------- LOAD MODEL --------
    print("Loading tokenizer...")
    # Tokenizer comes from the adapter dir so any added tokens are kept.
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir)

    BASE_MODEL = "facebook/bart-base"
    print(f"Loading base model {BASE_MODEL}...")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)

    print("Loading LoRA adapter...")
    model = PeftModel.from_pretrained(base_model, adapter_dir).to(device)
    # Merge LoRA weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # -------- LOAD DATA --------
    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    print("Generating SQL predictions...\n")

    correct = 0
    total = len(dev)

    with open(pred_sql_file, "w") as f, torch.no_grad():

        for i, ex in enumerate(dev, 1):

            question = ex["question"]
            db_id = ex["db_id"]
            gold_query = ex["query"]

            # NOTE(review): encode_prompt presumably returns a 1-D token id
            # tensor — confirm against src/prompting.py.
            prompt_ids = encode_prompt(
                tokenizer,
                question,
                db_id,
                device=device,
                max_input_tokens=512,
            )

            input_ids = prompt_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=160,
                num_beams=4,
                do_sample=False,
            )

            pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
            pred_sql = extract_sql(pred)

            f.write(f"{pred_sql}\t{db_id}\n")

            # -------- LIVE EXECUTION CHECK --------
            try:
                db_path = db_root / db_id / f"{db_id}.sqlite"

                conn = sqlite3.connect(db_path)
                cursor = conn.cursor()

                cursor.execute(pred_sql)
                pred_rows = cursor.fetchall()

                cursor.execute(gold_query)
                gold_rows = cursor.fetchall()

                conn.close()

                # order insensitive comparison
                if sorted(pred_rows) == sorted(gold_rows):
                    correct += 1

            except Exception:
                pass  # execution failed

            if i % 10 == 0 or i == total:
                current_acc = correct / i
                print(f"{i}/{total} | Acc: {current_acc:.3f}")

    print("\nGeneration finished.\n")

    # -------- RUN OFFICIAL SPIDER EVAL --------
    # Prefer the BART-specific eval script when the repo provides one.
    eval_script = project_root / "spider_eval/evaluation.py"
    if (project_root / "spider_eval/evaluation_bart.py").exists():
        eval_script = project_root / "spider_eval/evaluation_bart.py"

    cmd = [
        sys.executable,
        str(eval_script),
        "--gold", str(gold_sql_file),
        "--pred", str(pred_sql_file),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]

    print(f"\nRunning Spider evaluation using {eval_script.name}...")
    proc = subprocess.run(cmd, capture_output=True, text=True, errors="ignore")

    if proc.returncode != 0:
        print("\nSpider evaluation crashed.")
        print(proc.stderr)
        return

    print("\n--- Spider Eval Output ---")
    # Show only the tail of the (very verbose) eval output.
    print("\n".join(proc.stdout.splitlines()[-20:]))

    acc = parse_exec_accuracy(proc.stdout)
    if acc is not None:
        print(f"\n🎯 Official Execution Accuracy: {acc*100:.2f}%")
    else:
        print("\nCould not parse official accuracy.")


if __name__ == "__main__":
    main()
190
+ main()
src/execution_reward.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import re
5
+ import sqlite3
6
+ import time
7
+ from dataclasses import dataclass
8
+ from typing import List, Optional, Sequence, Set, Tuple, Union
9
+
10
+ try:
11
+ import sqlparse
12
+ from sqlparse.sql import Function, Identifier, IdentifierList, Statement, Token, Where
13
+ from sqlparse.tokens import DML, Keyword, Name, Number, Punctuation, String, Whitespace
14
+ except Exception: # pragma: no cover
15
+ sqlparse = None # type: ignore[assignment]
16
+ Statement = object # type: ignore[misc,assignment]
17
+ Token = object # type: ignore[misc,assignment]
18
+
19
+
20
+ def _normalize_sql(sql: str) -> str:
21
+ if not isinstance(sql, str):
22
+ return ""
23
+ s = sql.strip()
24
+ if s.startswith("```"):
25
+ # Strip markdown fences if present.
26
+ s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
27
+ s = re.sub(r"\n?```$", "", s).strip()
28
+ if s.lower().startswith("sql:"):
29
+ s = s[4:].strip()
30
+ # Keep only the first statement to avoid accidental multi-statement execution.
31
+ if ";" in s:
32
+ s = s.split(";", 1)[0].strip()
33
+ return s
34
+
35
+
36
def _connect_readonly(db_path: str) -> sqlite3.Connection:
    """Open db_path as a read-only SQLite connection.

    Read-only prevents any accidental mutation during reward computation.
    Note: requires SQLite URI support (built-in).
    """
    uri = f"file:{os.path.abspath(db_path)}?mode=ro"
    # check_same_thread=False: the connection may be handed to worker threads.
    conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
    conn.execute("PRAGMA query_only = ON;")  # belt-and-braces with mode=ro
    conn.execute("PRAGMA foreign_keys = ON;")
    return conn
44
+
45
+
46
def _with_timeout(conn: sqlite3.Connection, timeout_s: float = 1.0) -> None:
    """Install a progress handler that aborts any statement on this
    connection running longer than timeout_s seconds."""
    start = time.monotonic()

    def _handler() -> int:
        # A non-zero return value makes SQLite abort the current statement.
        return 1 if (time.monotonic() - start) > timeout_s else 0

    # Call handler every N VM opcodes.
    conn.set_progress_handler(_handler, 10_000)
54
+
55
+
56
+ def _list_tables(conn: sqlite3.Connection) -> List[str]:
57
+ try:
58
+ cur = conn.execute(
59
+ "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
60
+ )
61
+ return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
62
+ except sqlite3.Error:
63
+ return []
64
+
65
+
66
+ def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool:
67
+ s = sql.lower()
68
+ for t in table_names:
69
+ tl = t.lower()
70
+ if not tl:
71
+ continue
72
+ if re.search(rf"\b{re.escape(tl)}\b", s):
73
+ return True
74
+ return False
75
+
76
+
77
def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
    """Cheap validity probe: True iff SQLite can plan (not run) the query."""
    try:
        _with_timeout(conn, timeout_s=1.0)
        conn.execute(f"EXPLAIN QUERY PLAN {sql}")
        return True
    except sqlite3.Error:
        return False


def _execute(conn: sqlite3.Connection, sql: str, max_rows: int = 1000) -> Tuple[bool, List[Tuple], Optional[str]]:
    """Run sql with a 1s timeout; return (ok, rows, error_message).

    At most max_rows rows are fetched to bound memory and time; on any
    sqlite3.Error the result is (False, [], str(error)).
    """
    try:
        _with_timeout(conn, timeout_s=1.0)
        cur = conn.execute(sql)
        rows = cur.fetchmany(max_rows)
        # Normalize to plain tuples for deterministic comparison.
        norm_rows = [tuple(r) for r in rows]
        return True, norm_rows, None
    except sqlite3.Error as e:
        return False, [], str(e)
96
+
97
+
98
+ _SQL_KEYWORDS_TO_IGNORE = {
99
+ "select",
100
+ "from",
101
+ "where",
102
+ "join",
103
+ "inner",
104
+ "left",
105
+ "right",
106
+ "full",
107
+ "outer",
108
+ "on",
109
+ "group",
110
+ "by",
111
+ "order",
112
+ "limit",
113
+ "having",
114
+ "distinct",
115
+ "union",
116
+ "intersect",
117
+ "except",
118
+ "as",
119
+ "and",
120
+ "or",
121
+ "not",
122
+ "in",
123
+ "is",
124
+ "null",
125
+ "like",
126
+ "between",
127
+ "case",
128
+ "when",
129
+ "then",
130
+ "else",
131
+ "end",
132
+ "asc",
133
+ "desc",
134
+ }
135
+
136
+ _SQL_FUNCTIONS_TO_IGNORE = {
137
+ "count",
138
+ "avg",
139
+ "min",
140
+ "max",
141
+ "sum",
142
+ "lower",
143
+ "upper",
144
+ "substr",
145
+ "coalesce",
146
+ "round",
147
+ "date",
148
+ "datetime",
149
+ "strftime",
150
+ }
151
+
152
+
153
def extract_tables(sql: str) -> Set[str]:
    """
    Best-effort table extraction from SQL using sqlparse.
    Returns lowercase table names (unqualified).

    Falls back to a FROM/JOIN regex scan when sqlparse is unavailable, and
    returns an empty set for empty or unparsable input.
    """
    sql = _normalize_sql(sql)
    if not sql:
        return set()
    if sqlparse is None:
        # Fallback: naive regex for FROM/JOIN.
        found = set()
        for m in re.finditer(r"\b(from|join)\s+([a-zA-Z_][\w$]*)", sql, flags=re.I):
            found.add(m.group(2).lower())
        return found

    try:
        statements = sqlparse.parse(sql)
    except Exception:
        return set()

    tables: Set[str] = set()

    def _add_identifier_as_table(ident: Identifier) -> None:
        # Prefer real name over alias; strip any schema prefix.
        name = ident.get_real_name() or ident.get_name()
        if not name:
            return
        tables.add(name.lower())

    for st in statements:
        if not isinstance(st, Statement):
            continue
        # State machine over the flattened token stream: after a FROM/JOIN
        # keyword, the next Name/Identifier token is taken as a table name.
        seen_from = False
        for tok in st.flatten():
            if tok.ttype in Whitespace:
                continue
            if tok.ttype is Keyword and tok.value.upper() in {"FROM", "JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN", "FULL JOIN"}:
                # NOTE(review): flatten() generally yields single-word
                # keyword tokens, so the multi-word entries here may never
                # match — plain "JOIN" should still cover those cases.
                # Verify against sqlparse's tokenizer.
                seen_from = True
                continue
            if not seen_from:
                continue

            if isinstance(tok, Identifier):
                # NOTE(review): flatten() yields leaf tokens, so grouped
                # Identifier nodes are unlikely here; kept defensively.
                _add_identifier_as_table(tok)
                seen_from = False
            elif tok.ttype is Name:
                tables.add(tok.value.lower())
                seen_from = False
            elif tok.ttype is Keyword and tok.value.upper() in {"WHERE", "GROUP", "ORDER", "HAVING", "LIMIT"}:
                seen_from = False

    return tables
205
+
206
+
207
def extract_columns(sql: str) -> Set[str]:
    """
    Best-effort column extraction from SQL using sqlparse.
    Returns lowercase column names (unqualified).

    Keyword and function tokens are filtered via the ignore sets defined
    above; falls back to a bare-word scan when sqlparse is unavailable.
    """
    sql = _normalize_sql(sql)
    if not sql:
        return set()
    if sqlparse is None:
        # Fallback: naive dotted identifiers and bare names after SELECT/WHERE/etc.
        cols = set()
        for m in re.finditer(r"\b([a-zA-Z_][\w$]*)\b", sql):
            w = m.group(1).lower()
            if w in _SQL_KEYWORDS_TO_IGNORE or w in _SQL_FUNCTIONS_TO_IGNORE:
                continue
            cols.add(w)
        return cols

    try:
        statements = sqlparse.parse(sql)
    except Exception:
        return set()

    cols: Set[str] = set()

    def _maybe_add_col(name: Optional[str]) -> None:
        # Strip surrounding quotes and normalize case before filtering.
        if not name:
            return
        n = name.strip().strip('"').strip("'").lower()
        if not n or n == "*":
            return
        if n in _SQL_KEYWORDS_TO_IGNORE or n in _SQL_FUNCTIONS_TO_IGNORE:
            return
        cols.add(n)

    def _handle_identifier(ident: Identifier) -> None:
        # If qualified (t.col), keep only col for overlap/hallucination checks.
        _maybe_add_col(ident.get_real_name() or ident.get_name())

    for st in statements:
        if not isinstance(st, Statement):
            continue
        # NOTE(review): flatten() yields leaf tokens, so the Function /
        # IdentifierList / Identifier branches below may rarely trigger;
        # the ttype-is-Name branch does most of the work. Verify against
        # sqlparse's tokenizer.
        for tok in st.flatten():
            # Skip whitespace/punctuation/string literals/numbers.
            if getattr(tok, "ttype", None) in (Whitespace, Punctuation, String, Number):
                continue

            if isinstance(tok, Function):
                fname = tok.get_name()
                if fname:
                    # Don't treat function name as a column.
                    pass
                continue

            if isinstance(tok, IdentifierList):
                for ident in tok.get_identifiers():
                    if isinstance(ident, Identifier):
                        _handle_identifier(ident)
                continue

            if isinstance(tok, Identifier):
                _handle_identifier(tok)
                continue

            if getattr(tok, "ttype", None) is Name:
                _maybe_add_col(tok.value)

    return cols
275
+
276
+
277
def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]:
    """
    Return (tables, columns) sets from SQLite schema; all lowercased.
    Columns are returned as a global set (unqualified).
    """
    tables: Set[str] = set()
    columns: Set[str] = set()

    for table in _list_tables(conn):
        lowered = table.lower()
        if not lowered:
            continue
        tables.add(lowered)

        try:
            rows = conn.execute(f'PRAGMA table_info("{table}")').fetchall()
        except sqlite3.Error:
            # Table vanished or is otherwise unreadable; skip its columns.
            continue

        for row in rows:
            if row and isinstance(row[1], str):
                columns.add(row[1].lower())

    return tables, columns
297
+
298
+
299
+ def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool:
300
+ # Deterministic comparison: compare exact row tuples in order.
301
+ return a == b
302
+
303
+
304
@dataclass
class RewardDebugStats:
    """Running counters for reward diagnostics, accumulated across calls."""

    total: int = 0         # reward evaluations seen
    parsed_ok: int = 0     # predictions that parsed as valid SELECTs
    table_match: int = 0   # predictions whose table set matched gold's
    column_match: int = 0  # predictions with counted column overlap
    executed_ok: int = 0   # predictions that executed without error
    exact_match: int = 0   # predictions whose results equalled gold's


# Module-level singleton holding the current debug counters.
_DEBUG = RewardDebugStats()


def reset_debug_metrics() -> None:
    """Reset all debug counters (e.g. at the start of an eval epoch)."""
    global _DEBUG
    _DEBUG = RewardDebugStats()


def get_debug_metrics() -> dict:
    """Return counter ratios; denominator is clamped to 1 to avoid
    division by zero before any rewards have been computed."""
    denom = max(_DEBUG.total, 1)
    return {
        "valid_sql_rate": _DEBUG.parsed_ok / denom,
        "table_match_rate": _DEBUG.table_match / denom,
        "column_match_rate": _DEBUG.column_match / denom,
        "execution_accuracy": _DEBUG.exact_match / denom,
    }
330
+
331
# Sentinel returned by execute_sql on any execution failure — kept distinct
# from [] so "errored" is never confused with "ran and returned no rows".
EXECUTION_ERROR = "EXECUTION_ERROR"


def execute_sql(conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
    """
    Execute SQL safely.

    If sqlite raises ANY exception, return EXECUTION_ERROR (NOT empty list).
    At most max_rows rows are fetched; a 1s timeout aborts runaway queries.
    """
    try:
        _with_timeout(conn, timeout_s=1.0)
        cur = conn.execute(sql)
        rows = cur.fetchmany(max_rows)
        # Plain tuples for deterministic comparison against gold results.
        return [tuple(r) for r in rows]
    except Exception:
        return EXECUTION_ERROR
347
+
348
+
349
def _sqlparse_valid_select(sql: str) -> bool:
    """
    Parse validation using sqlparse:
    - parse() non-empty
    - contains a SELECT statement

    Never raises: returns False when sqlparse is unavailable or any
    parsing step fails.
    """
    if sqlparse is None:
        return False
    try:
        stmts = sqlparse.parse(sql)
        if not stmts:
            return False
        for st in stmts:
            try:
                if hasattr(st, "get_type") and st.get_type() == "SELECT":
                    return True
            except Exception:
                continue
        return False
    except Exception:
        return False
370
+
371
def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
    """Shaped reward in [-1, 1] for a predicted SQL query.

    Components:
      -1.0           : unparsable / non-SELECT prediction, or any unexpected error
      -0.2           : baseline for a syntactically valid SELECT
      +0.3           : predicted table set exactly matches gold's (non-empty)
      +0.3 * overlap : fraction of gold's columns mentioned by the prediction
      +0.2           : prediction executes without error
       1.0           : execution results exactly match gold's (returned early)
    """
    try:
        sql = _normalize_sql(pred_sql)
        gold = _normalize_sql(gold_sql)

        if not sql or "SELECT" not in sql.upper():
            return -1.0

        if not _sqlparse_valid_select(sql):
            return -1.0

        reward = -0.2  # valid SQL baseline

        pred_tables = extract_tables(sql)
        gold_tables = extract_tables(gold)
        if pred_tables == gold_tables and len(gold_tables) > 0:
            reward += 0.3

        pred_cols = extract_columns(sql)
        gold_cols = extract_columns(gold)
        if gold_cols:
            overlap = len(pred_cols & gold_cols) / len(gold_cols)
            reward += 0.3 * overlap

        # BUGFIX: sqlite3.Connection's `with` block only manages the
        # transaction — it never closes the connection, so the original
        # leaked one file handle per reward call. Close explicitly.
        conn = _connect_readonly(db_path)
        try:
            pred_res = execute_sql(conn, sql)
            if pred_res != EXECUTION_ERROR:
                reward += 0.2

            gold_res = execute_sql(conn, gold)
            if pred_res != EXECUTION_ERROR and _safe_results_equal(pred_res, gold_res):
                return 1.0
        finally:
            conn.close()

        return max(-1.0, min(1.0, reward))

    except Exception:
        # The reward must never raise into the RL training loop.
        return -1.0
src/generate_sql.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from transformers import AutoTokenizer
6
+ from transformers import AutoModelForSeq2SeqLM
7
+ from peft import PeftModel
8
+
9
+ from prompting import encode_prompt
10
+
11
+
12
def main():
    """CLI entry point: generate one SQL query for (--question, --db_id)
    using the RLHF LoRA adapter merged into its seq2seq base model."""
    parser = argparse.ArgumentParser(description="Generate SQL from a question + db_id using the RLHF model.")
    parser.add_argument("--question", type=str, required=True)
    parser.add_argument("--db_id", type=str, required=True)
    parser.add_argument("--model_dir", type=str, default=None, help="Defaults to outputs/rlhf_text2sql")
    parser.add_argument("--use_schema", action="store_true", help="Include schema in the prompt (must match training).")
    parser.add_argument("--max_schema_chars", type=int, default=1500)
    parser.add_argument("--max_new_tokens", type=int, default=80)
    args = parser.parse_args()

    # Apple-Silicon first, otherwise CPU (no CUDA branch in this script).
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    adapter_dir = args.model_dir or os.path.join(project_root, "outputs", "rlhf_text2sql")
    # BASE_MODEL may be a hub id or a local directory; fall back to the local
    # SFT checkpoint when the hub id is not a directory but the checkpoint is.
    base_model = os.environ.get("BASE_MODEL", "t5-small")
    fallback_base_model = os.path.join(project_root, "models", "t5_spider_sft")
    if not os.path.isdir(base_model) and os.path.isdir(fallback_base_model):
        base_model = fallback_base_model

    # local_files_only only when loading from a hub id would require network.
    local_only = not os.path.isdir(base_model)
    tokenizer_source = adapter_dir if os.path.isdir(adapter_dir) else base_model
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, local_files_only=not os.path.isdir(tokenizer_source))
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model, local_files_only=local_only).to(device)
    model = PeftModel.from_pretrained(base, adapter_dir).to(device)
    # Merge adapters for faster/stabler generation.
    model = model.merge_and_unload()
    model.config.use_cache = False

    # Some tokenizers (GPT-style) ship without a pad token; reuse EOS.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    input_ids = encode_prompt(
        tokenizer,
        args.question,
        args.db_id,
        device=device,
        max_input_tokens=512,
    )

    # Greedy decoding: deterministic output for a given prompt.
    gen_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        do_sample=False,
        num_beams=1,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    with torch.no_grad():
        # encode_prompt returns a 1-D tensor; add the batch dimension here.
        out = model.generate(input_ids=input_ids.unsqueeze(0), **gen_kwargs)

    sql = tokenizer.decode(out[0], skip_special_tokens=True).strip()
    print(sql)


if __name__ == "__main__":
    main()
src/human_eval_runner.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sqlite3
3
+ from pathlib import Path
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+ from peft import PeftModel
7
+
8
# Repository root (this file lives in src/), and the Spider database folder.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
DB_ROOT = PROJECT_ROOT / "data" / "database"

# Added CUDA fallback for consistency: prefer MPS, then CUDA, then CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)
15
+
16
+ # ================= LOAD MODEL =================
17
def load_model(adapter_path):
    """Load the CodeT5 base model with a LoRA adapter applied.

    adapter_path is interpreted relative to PROJECT_ROOT. Returns
    (tokenizer, model) with the model in eval mode on DEVICE.
    Raises FileNotFoundError when the adapter folder does not exist.
    """
    base_name = "Salesforce/codet5-base"

    # 🐛 FIXED: Convert relative path to absolute path to prevent Hugging Face 404 errors
    abs_path = (PROJECT_ROOT / adapter_path).resolve()
    if not abs_path.exists():
        raise FileNotFoundError(f"Adapter not found at: {abs_path}")

    print(f"\nLoading model from: {abs_path}")

    # 🐛 FIXED: Added fallback in case tokenizer isn't saved in the adapter folder
    try:
        tokenizer = AutoTokenizer.from_pretrained(str(abs_path), local_files_only=True)
    except Exception:
        print("Adapter tokenizer missing — using base tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(base_name)

    base = AutoModelForSeq2SeqLM.from_pretrained(base_name).to(DEVICE)
    model = PeftModel.from_pretrained(base, str(abs_path)).to(DEVICE)
    model.eval()

    return tokenizer, model
39
+
40
+
41
+ # ================= SCHEMA =================
42
def load_schema(db_id):
    """Return a textual schema for <db_id>: one "table(col1, col2, ...)" line
    per table, read live from data/database/<db_id>/<db_id>.sqlite.
    """
    db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema = ""
        for (table,) in tables:
            # Quote the identifier so table names with spaces or reserved
            # words do not break the PRAGMA statement.
            cols = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
            col_names = [c[1] for c in cols]
            schema += f"{table}({', '.join(col_names)})\n"

        return schema
    finally:
        # Close even when a query raises (the original leaked the handle).
        conn.close()
59
+
60
+
61
+ # ================= GENERATE =================
62
def generate_sql(tokenizer, model, question, db_id):
    """Generate SQL for `question` against database `db_id`.

    Builds a schema-aware prompt, decodes with beam search (4 beams,
    deterministic), and strips everything up to the last "SQL:" marker
    in case the model echoes the prompt.
    """
    schema = load_schema(db_id)

    prompt = f"""
Database Schema:
{schema}

Translate English to SQL:
{question}
SQL:
"""

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=120,
            num_beams=4,
            do_sample=False
        )

    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the text after the final "SQL:" marker, if present.
    if "SQL:" in sql:
        sql = sql.split("SQL:")[-1]

    return sql.strip()
90
+
91
+
92
+ # ================= EXECUTE =================
93
def try_execute(sql, db_id):
    """Return True iff `sql` executes without error against <db_id>'s SQLite DB.

    Binary executability signal used by the human evaluation loop; result
    rows are fetched but discarded.
    """
    db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        cur.execute(sql)
        cur.fetchall()
        return True
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # not swallowed; any sqlite or OS error counts as a failure.
        return False
    finally:
        # The original leaked the connection when execution failed.
        if conn is not None:
            conn.close()
104
+
105
+
106
+ # ================= MAIN =================
107
def main():
    """Run the human-eval questions through SFT and RLHF models and report
    how many generated queries execute successfully for each."""
    # paths (change if needed)
    SFT_MODEL = "checkpoints/sft_adapter_codet5"  # Ensure this matches your actual SFT folder name!
    RLHF_MODEL = "checkpoints/best_rlhf_model"

    tokenizer_sft, model_sft = load_model(SFT_MODEL)
    tokenizer_rl, model_rl = load_model(RLHF_MODEL)

    # Expected format: a JSON list of {"db_id": ..., "question": ...}.
    human_eval_path = PROJECT_ROOT / "data/human_eval.json"
    with open(human_eval_path) as f:
        questions = json.load(f)

    sft_success = 0
    rl_success = 0

    print("\nRunning Human Evaluation...\n")

    for i, q in enumerate(questions, 1):
        db = q["db_id"]
        question = q["question"]

        sql_sft = generate_sql(tokenizer_sft, model_sft, question, db)
        sql_rl = generate_sql(tokenizer_rl, model_rl, question, db)

        # "Success" here means executable SQL, not result correctness.
        ok_sft = try_execute(sql_sft, db)
        ok_rl = try_execute(sql_rl, db)

        if ok_sft:
            sft_success += 1
        if ok_rl:
            rl_success += 1

        print(f"\nQ{i}: {question}")
        print(f"SFT : {'OK' if ok_sft else 'FAIL'}")
        print(f"RLHF: {'OK' if ok_rl else 'FAIL'}")

    print("\n=============================")
    print("HUMAN EVALUATION RESULT")
    print("=============================")
    # NOTE(review): divides by len(questions) — an empty human_eval.json
    # would raise ZeroDivisionError; confirm the file is never empty.
    print(f"SFT Success: {sft_success}/{len(questions)} = {sft_success/len(questions)*100:.2f}%")
    print(f"RLHF Success: {rl_success}/{len(questions)} = {rl_success/len(questions)*100:.2f}%")
    print("=============================\n")


if __name__ == "__main__":
    main()
src/load_lora_model.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import LoraConfig, get_peft_model, TaskType

# Script: load the supervised T5 checkpoint and attach fresh LoRA adapters.
device = "mps" if torch.backends.mps.is_available() else "cpu"

MODEL_PATH = "../outputs/model"  # your supervised trained model

print("Loading base model...")
model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH).to(device)

# Tokenizer comes from the stock t5-small vocabulary, not the checkpoint.
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# ---------------- LoRA CONFIG ----------------
lora_config = LoraConfig(
    r=8,                       # rank (small brain attachment)
    lora_alpha=16,
    target_modules=["q", "v"], # attention matrices only
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)

print("Attaching LoRA adapters...")
model = get_peft_model(model, lora_config)

# Prints trainable vs. total parameter counts for a quick sanity check.
model.print_trainable_parameters()

print("READY ✔ LoRA model loaded")
30
+
src/make_rl_dataset.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
from datasets import load_dataset

# Export the Spider train split to plain JSON for the RL pipeline.
print("Loading Spider dataset...")
train_split = load_dataset("spider", split="train")

records = [
    {
        "question": row["question"],
        "query": row["query"],
        "db_id": row["db_id"],  # ⭐ CRITICAL FIELD
    }
    for row in train_split
]

print("Saving JSON...")
with open("data/train_spider.json", "w") as f:
    json.dump(records, f, indent=2)

print("Done! File saved at data/train_spider.json")
src/manual_check.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel

# Quick manual smoke test: run 5 canned questions through the SFT adapter.
BASE_MODEL = "Salesforce/codet5-base"
ADAPTER = "checkpoints/sft_adapter"  # change if needed

device = "mps" if torch.backends.mps.is_available() else "cpu"

print("Loading model...")

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
model = PeftModel.from_pretrained(model, ADAPTER)

model = model.to(device)
model.eval()

# 5 random Spider style questions
questions = [
    "List all employee names",
    "Find the number of students in each department",
    "Show the average salary of employees",
    "Which flights depart from LA?",
    "Find customers who bought more than 5 items"
]

for q in questions:
    # No schema in the prompt here — this only checks the model produces
    # SQL-shaped text, not schema-correct SQL.
    prompt = f"Translate to SQL: {q}"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.0,  # deterministic
        )

    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print("\nQUESTION:", q)
    print("SQL:", sql)
    print("-"*60)
src/predict.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import sqlite3
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# --------------------------------------------------
# PATH
# --------------------------------------------------
# Fine-tuned seq2seq checkpoint (tokenizer + weights saved together).
MODEL_PATH = "outputs/model"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

print("Loading fine-tuned model...")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
model.eval()

# --------------------------------------------------
# CONNECT DATABASE
# --------------------------------------------------
# Hard-wired to the department_management Spider DB; paths are relative to
# the repo root (run this script from there).
print("Connecting to database...")
# conn = sqlite3.connect("../data/database/department_management/department_management.sqlite")
conn = sqlite3.connect("data/database/department_management/department_management.sqlite")
cursor = conn.cursor()
print("Database connected ✔")
25
+
26
+ # --------------------------------------------------
27
+ # BUILD PROMPT
28
+ # --------------------------------------------------
29
def build_prompt(question):
    """Build the model prompt: the hard-coded department_management schema
    followed by the user's question."""
    schema_text = """
    Table department columns = Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees.
    Table head columns = head_ID, name, born_state, age.
    Table management columns = department_ID, head_ID, temporary_acting.
    """
    return "translate English to SQL: " + schema_text + " question: " + question
36
+
37
+ # --------------------------------------------------
38
+ # GENERATE SQL
39
+ # --------------------------------------------------
40
def generate_sql(question):
    """Translate `question` to SQL with the module-level tokenizer/model.

    Uses beam search (5 beams) for a deterministic, higher-quality decode.
    Returns the stripped SQL string.
    """
    prompt = build_prompt(question)

    encoding = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256
    )

    with torch.no_grad():
        outputs = model.generate(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            max_length=256,
            num_beams=5,
            early_stopping=True
        )

    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return sql.strip()
63
+
64
+ # --------------------------------------------------
65
+ # EVALUATE SQL (REWARD FUNCTION)
66
+ # --------------------------------------------------
67
def evaluate_sql(sql):
    """Execute `sql` on the module-level cursor and map the outcome to a reward.

    Returns (reward, payload):
      ( 1.0, rows)       query ran and returned at least one row
      (-0.2, [])         query ran but returned nothing
      (-1.0, error_str)  sqlite raised (invalid SQL)
    """
    try:
        cursor.execute(sql)
        rows = cursor.fetchall()

        # executed but no useful result
        if len(rows) == 0:
            return -0.2, rows

        # good query
        else:
            return 1.0, rows

    except Exception as e:
        # invalid SQL: surface the error text instead of rows
        return -1.0, str(e)
83
+
84
+ # --------------------------------------------------
85
+ # INTERACTIVE LOOP
86
+ # --------------------------------------------------
87
# Interactive REPL: question -> SQL -> execute -> show reward/answer.
while True:
    q = input("\nAsk question (type exit to quit): ")

    if q.lower() == "exit":
        break

    sql = generate_sql(q)

    print("\nPredicted SQL:")
    print(sql)

    # ---------------- RUN SQL + REWARD ----------------
    reward, output = evaluate_sql(sql)

    print("\nReward:", reward)

    # -1.0 -> sqlite error string; -0.2 -> empty result; else rows.
    if reward == -1.0:
        print("SQL Error:", output)

    elif reward == -0.2:
        print("No results found")

    else:
        print("\nAnswer:")
        for r in output:
            print(r)
src/prepare_dataset.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os
import sqlite3
from datasets import Dataset
from transformers import T5Tokenizer

# =========================================================
# PROJECT ROOT (VERY IMPORTANT — fixes path issues)
# =========================================================
# Resolve everything relative to the repo root so the script works no
# matter which directory it is launched from.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

TRAIN_JSON = os.path.join(BASE_DIR, "data", "train_spider.json")
DEV_JSON = os.path.join(BASE_DIR, "data", "dev.json")
DB_FOLDER = os.path.join(BASE_DIR, "data", "database")

SAVE_TRAIN = os.path.join(BASE_DIR, "data", "tokenized", "train")
SAVE_DEV = os.path.join(BASE_DIR, "data", "tokenized", "validation")

# Creates data/tokenized (parent of both save dirs); save_to_disk makes the rest.
os.makedirs(os.path.dirname(SAVE_TRAIN), exist_ok=True)

print("Project root:", BASE_DIR)
print("Train file:", TRAIN_JSON)
print("Database folder:", DB_FOLDER)

# =========================================================
# TOKENIZER
# =========================================================
tokenizer = T5Tokenizer.from_pretrained("t5-small")
29
+
30
+ # =========================================================
31
+ # READ DATABASE SCHEMA
32
+ # =========================================================
33
def get_schema(db_path):
    """Return the schema of the SQLite DB at `db_path` as text:
    one "table(col1, col2, ...)" line per table, joined by newlines.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema_text = []

        for (table,) in tables:
            # Quote the identifier: table names with spaces or reserved
            # words would otherwise break the PRAGMA.
            columns = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
            col_names = [c[1] for c in columns]

            schema_text.append(f"{table}({', '.join(col_names)})")

        return "\n".join(schema_text)
    finally:
        # Close even when a query raises (the original leaked the handle).
        conn.close()
53
+
54
+
55
+ # =========================================================
56
+ # BUILD TRAINING EXAMPLES
57
+ # =========================================================
58
def build_examples(spider_json):
    """Build a HuggingFace Dataset of {input, output} pairs from a Spider JSON file.

    Each input is a schema-aware prompt (live schema + question); each output
    is the gold SQL. Examples whose database file is missing are skipped.
    """
    print(f"\nBuilding dataset from: {spider_json}")

    data = json.load(open(spider_json))

    inputs = []
    outputs = []

    for ex in data:

        question = ex["question"]
        sql = ex["query"]
        db_id = ex["db_id"]

        db_path = os.path.join(DB_FOLDER, db_id, f"{db_id}.sqlite")

        # skip if db missing (safety)
        if not os.path.exists(db_path):
            continue

        schema = get_schema(db_path)

        # ⭐ SCHEMA-AWARE PROMPT (VERY IMPORTANT)
        input_text = f"""Database Schema:
{schema}

Translate English to SQL:
{question}
SQL:
"""

        inputs.append(input_text)
        outputs.append(sql)

    return Dataset.from_dict({"input": inputs, "output": outputs})
94
+
95
+
96
+ # =========================================================
97
+ # TOKENIZE
98
+ # =========================================================
99
def tokenize(example):
    """Tokenize one {input, output} example for seq2seq training.

    Inputs are padded/truncated to 512 tokens, targets to 256; the target
    token ids become `labels`.
    """
    model_input = tokenizer(
        example["input"],
        max_length=512,
        padding="max_length",
        truncation=True
    )

    label = tokenizer(
        example["output"],
        max_length=256,
        padding="max_length",
        truncation=True
    )

    # NOTE(review): pad token ids in the labels are not replaced with -100,
    # so loss is also computed on padding — confirm this is intended.
    model_input["labels"] = label["input_ids"]
    return model_input
117
+
118
+
119
+ # =========================================================
120
+ # RUN PIPELINE
121
+ # =========================================================
122
# =========================================================
# RUN PIPELINE
# =========================================================
# Build -> tokenize -> save, first for train then for validation.
print("\nBuilding TRAIN dataset...")
train_dataset = build_examples(TRAIN_JSON)

print("Tokenizing TRAIN dataset...")
tokenized_train = train_dataset.map(tokenize, batched=False)

print("Saving TRAIN dataset...")
tokenized_train.save_to_disk(SAVE_TRAIN)


print("\nBuilding VALIDATION dataset...")
val_dataset = build_examples(DEV_JSON)

print("Tokenizing VALIDATION dataset...")
tokenized_val = val_dataset.map(tokenize, batched=False)

print("Saving VALIDATION dataset...")
tokenized_val.save_to_disk(SAVE_DEV)

print("\nDONE ✔ Dataset prepared successfully!")
print("Train saved at:", SAVE_TRAIN)
print("Validation saved at:", SAVE_DEV)
src/prompting.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import annotations

import os
import re
import sqlite3
from contextlib import closing
from typing import Dict, Optional

import torch

# Keep for compatibility with existing imports. Schema linking is disabled for
# SFT/RL alignment in this project version (full schema, deterministic order).
USE_SCHEMA_LINKING = False

# Repo root (this file lives in src/) and the Spider database folder.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")

# db_id -> rendered schema string; filled lazily by get_schema_text().
SCHEMA_CACHE: Dict[str, str] = {}
19
+
20
+
21
def get_schema_text(db_id: str) -> str:
    """
    Deterministic schema string:
        table(col1, col2, ...)
    Tables ordered alphabetically. Columns kept in PRAGMA order.

    Results are cached per db_id; any sqlite failure yields "" (also cached).
    """
    if db_id in SCHEMA_CACHE:
        return SCHEMA_CACHE[db_id]

    db_path = os.path.join(DB_ROOT, db_id, f"{db_id}.sqlite")
    schema_lines = []
    try:
        with closing(sqlite3.connect(db_path)) as conn:
            cur = conn.cursor()
            # Exclude sqlite internal tables (sqlite_sequence etc.).
            tables = cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            ).fetchall()
            table_names = sorted([t[0] for t in tables if t and isinstance(t[0], str)])
            for tname in table_names:
                # Identifier is quoted so odd table names survive the PRAGMA.
                cols = cur.execute(f'PRAGMA table_info("{tname}")').fetchall()
                col_names = [c[1] for c in cols if c and isinstance(c[1], str)]
                schema_lines.append(f"{tname}({', '.join(col_names)})")
    except Exception:
        # Missing/corrupt DB: fall back to an empty schema rather than crash.
        schema_lines = []

    schema_text = "\n".join(schema_lines).strip()
    SCHEMA_CACHE[db_id] = schema_text
    return schema_text
49
+
50
+
51
def clean_gold_sql(sql: str) -> str:
    """
    Lowercase SQL and strip common Spider T1/T2 table aliases.

    When the same table is aliased more than once (self-join), alias removal
    is ambiguous, so the SQL is returned merely lowercased.
    Non-string or blank input yields "".
    """
    if not isinstance(sql, str):
        return ""
    stripped = sql.strip().rstrip(";").strip()
    if not stripped:
        return ""

    # alias (lowercased) -> original table name, collected from FROM/JOIN.
    aliases: Dict[str, str] = {}
    # how many times each table (lowercased) is aliased
    usage: Dict[str, int] = {}

    alias_decl = r"\b(from|join)\s+([a-zA-Z_][\w$]*)\s+(?:as\s+)?(t\d+)\b"
    for match in re.finditer(alias_decl, stripped, flags=re.I):
        table_name = match.group(2)
        alias_name = match.group(3)
        key = table_name.lower()
        usage[key] = usage.get(key, 0) + 1
        aliases[alias_name.lower()] = table_name

    # Ambiguous alias removal (repeated table) → only lowercase.
    if any(count > 1 for count in usage.values()):
        return stripped.lower()

    rewritten = stripped
    # Replace alias-qualified references: alias.col -> table.col
    for alias_name, table_name in aliases.items():
        rewritten = re.sub(
            rf"\b{re.escape(alias_name)}\.", f"{table_name}.", rewritten, flags=re.I
        )

    # Drop the alias declarations themselves: "table AS t1" / "table t1".
    for alias_name, table_name in aliases.items():
        rewritten = re.sub(
            rf"\b{re.escape(table_name)}\s+as\s+{re.escape(alias_name)}\b",
            table_name,
            rewritten,
            flags=re.I,
        )
        rewritten = re.sub(
            rf"\b{re.escape(table_name)}\s+{re.escape(alias_name)}\b",
            table_name,
            rewritten,
            flags=re.I,
        )

    return rewritten.lower().strip()
88
+
89
+
90
def build_prompt(
    question: str,
    db_id: str,
    *,
    schema_text: str,
    training_sql: Optional[str] = None,
) -> str:
    """
    Assemble the canonical prompt:

        You are a SQLite expert.

        Database: <db_id>

        Schema:
        <table>(col1, col2, ...)
        ...

        Question:
        <question>

        SQL:
        <gold sql>            (appended only when training_sql is given)
    """
    sections = [
        "You are a SQLite expert.",
        "",
        f"Database: {db_id}",
        "",
        "Schema:",
        schema_text,
        "",
        "Question:",
        question,
        "",
        "SQL:",
    ]
    prompt = "\n".join(sections)
    if training_sql is not None:
        prompt = f"{prompt}\n{training_sql}"
    return prompt
126
+
127
+
128
def encode_prompt(
    tokenizer,
    question: str,
    db_id: str,
    *,
    device: str,
    max_input_tokens: int = 512,
    training_sql: Optional[str] = None,
) -> torch.Tensor:
    """
    Inference mode: stops at "SQL:"
    Training mode: can include SQL target (optional; we still recommend decoder labels).
    Truncation happens only on schema portion by character trimming (deterministic).

    Returns a 1-D tensor of input ids on `device` (no batch dimension —
    callers must unsqueeze before generate()).
    """
    schema_text = get_schema_text(db_id)
    prompt = build_prompt(question, db_id, schema_text=schema_text, training_sql=training_sql)
    # Tokenizer-level truncation caps the whole prompt at max_input_tokens.
    enc = tokenizer(
        prompt,
        truncation=True,
        max_length=max_input_tokens,
        padding=False,
        return_tensors="pt",
    )
    return enc.input_ids[0].to(device)
src/run_sql.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sqlite3
5
+ from contextlib import closing
6
+
7
+
8
def execute_sql(db_path: str, sql: str):
    """Execute `sql` against the SQLite DB at `db_path` without raising.

    Returns {"ok": bool, "rows": list, "error": str | None}; on failure
    `rows` is empty and `error` carries the exception text.
    """
    try:
        with closing(sqlite3.connect(db_path)) as conn:
            fetched = conn.cursor().execute(sql).fetchall()
        return {"ok": True, "rows": fetched, "error": None}
    except Exception as exc:
        return {"ok": False, "rows": [], "error": str(exc)}
17
+
18
+
19
def main():
    """CLI: execute one SQL statement against a Spider SQLite DB and print
    the result as a JSON object {"ok", "rows", "error"}."""
    parser = argparse.ArgumentParser(description="Safely execute SQL against a Spider SQLite DB.")
    # Exactly one of --db_path / --db_id must be given.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--db_path", type=str, help="Path to SQLite database file")
    group.add_argument("--db_id", type=str, help="Spider database id (uses data/database/<db_id>/<db_id>.sqlite)")
    parser.add_argument("--sql", type=str, required=True, help="SQL to execute")
    args = parser.parse_args()

    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if args.db_path:
        db_path = args.db_path
    else:
        db_path = os.path.join(project_root, "data", "database", args.db_id, f"{args.db_id}.sqlite")

    result = execute_sql(db_path, args.sql)
    # default=str keeps non-JSON-native cell values (dates, bytes) printable.
    print(json.dumps(result, ensure_ascii=False, default=str))


if __name__ == "__main__":
    main()
39
+
src/schema_encoder.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+
3
+
4
class SchemaEncoder:
    """Render a Spider-style SQLite database schema as text, in two formats."""

    def __init__(self, db_root):
        # db_root: path-like root holding <db_id>/<db_id>.sqlite folders;
        # must support the `/` operator (e.g. pathlib.Path).
        self.db_root = db_root

    def get_tables_and_columns(self, db_id):
        """Return {table_name: [column_names]} read live from the SQLite file."""
        db_path = self.db_root / db_id / f"{db_id}.sqlite"
        connection = sqlite3.connect(db_path)
        cursor = connection.cursor()

        table_rows = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        layout = {}
        for (name,) in table_rows:
            info = cursor.execute(f"PRAGMA table_info({name});").fetchall()
            layout[name] = [row[1] for row in info]

        connection.close()
        return layout

    # -----------------------------------
    # Strategy 1: Structured (current)
    # -----------------------------------
    def structured_schema(self, db_id):
        """Compact form: one "table(col1, col2)" line per table."""
        layout = self.get_tables_and_columns(db_id)
        return "\n".join(
            f"{name}({', '.join(cols)})" for name, cols in layout.items()
        )

    # -----------------------------------
    # Strategy 2: Natural Language
    # -----------------------------------
    def natural_language_schema(self, db_id):
        """Verbose form: one English sentence per table."""
        layout = self.get_tables_and_columns(db_id)
        sentences = [
            f"The table '{name}' contains the columns: {', '.join(cols)}."
            for name, cols in layout.items()
        ]
        return "\n".join(sentences)
src/schema_linker.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple schema linking for Spider-style Text-to-SQL.
3
+
4
+ Goal:
5
+ - Given (question, db_id), select a small set of relevant tables/columns
6
+ to include in the prompt (RAG-style schema retrieval).
7
+
8
+ Design constraints:
9
+ - Pure Python (no heavy external deps).
10
+ - Robust to missing/odd schemas: never crash.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import os
17
+ import re
18
+ import sqlite3
19
+ from contextlib import closing
20
+ from dataclasses import dataclass
21
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
22
+
23
+
24
_ALNUM_RE = re.compile(r"[A-Za-z0-9]+")   # alphanumeric token extractor
_CAMEL_RE = re.compile(r"([a-z])([A-Z])")  # camelCase boundary splitter
26
+
27
+
28
+ def _normalize_identifier(text: str) -> str:
29
+ """
30
+ Normalize a schema identifier:
31
+ - split underscores
32
+ - split camelCase / PascalCase boundaries
33
+ - lowercase
34
+ """
35
+ text = str(text or "")
36
+ text = text.replace("_", " ")
37
+ text = _CAMEL_RE.sub(r"\1 \2", text)
38
+ return text.lower()
39
+
40
+
41
+ def _tokenize(text: str) -> List[str]:
42
+ text = _normalize_identifier(text)
43
+ return _ALNUM_RE.findall(text)
44
+
45
+
46
@dataclass(frozen=True)
class TableSchema:
    """Immutable record of one database table: its name and ordered columns."""

    table_name: str
    columns: Tuple[str, ...]
50
+
51
+
52
class SchemaLinker:
    """
    Loads Spider `tables.json` and (optionally) SQLite schemas from disk.
    Provides a lightweight table scoring function based on token overlap.
    """

    def __init__(self, tables_json_path: str, db_root: Optional[str] = None):
        self.tables_json_path = tables_json_path
        self.db_root = db_root
        # db_id -> [TableSchema], parsed from tables.json at construction.
        self._tables_by_db: Dict[str, List[TableSchema]] = {}
        # db_id -> {table: [cols]}, lazily read from the SQLite file.
        self._sqlite_schema_cache: Dict[str, Dict[str, List[str]]] = {}
        self._load_tables_json()

    def _load_tables_json(self) -> None:
        """Parse tables.json into the db_id -> [TableSchema] mapping."""
        with open(self.tables_json_path) as f:
            entries = json.load(f)

        tables_by_db: Dict[str, List[TableSchema]] = {}
        for entry in entries:
            db_id = entry["db_id"]
            # Prefer original (unnormalized) names when present.
            table_names: List[str] = entry.get("table_names_original") or entry.get("table_names") or []
            col_names: List[Sequence] = entry.get("column_names_original") or entry.get("column_names") or []

            columns_by_table_idx: Dict[int, List[str]] = {i: [] for i in range(len(table_names))}
            for col in col_names:
                # Spider format: [table_idx, col_name]
                if not col or len(col) < 2:
                    continue
                table_idx, col_name = col[0], col[1]
                if table_idx is None or table_idx < 0:
                    continue  # skip "*"
                if table_idx not in columns_by_table_idx:
                    continue
                columns_by_table_idx[table_idx].append(str(col_name))

            tables: List[TableSchema] = []
            for i, tname in enumerate(table_names):
                cols = tuple(columns_by_table_idx.get(i, []))
                tables.append(TableSchema(table_name=str(tname), columns=cols))

            tables_by_db[db_id] = tables

        self._tables_by_db = tables_by_db

    def _db_path(self, db_id: str) -> Optional[str]:
        """Path to <db_root>/<db_id>/<db_id>.sqlite, or None if unavailable."""
        if not self.db_root:
            return None
        path = os.path.join(self.db_root, db_id, f"{db_id}.sqlite")
        return path if os.path.exists(path) else None

    def _load_sqlite_schema(self, db_id: str) -> Dict[str, List[str]]:
        """
        Load actual SQLite schema (table -> columns). Cached per db_id.
        Any failure yields (and caches) an empty mapping.
        """
        if db_id in self._sqlite_schema_cache:
            return self._sqlite_schema_cache[db_id]

        schema: Dict[str, List[str]] = {}
        db_path = self._db_path(db_id)
        if not db_path:
            self._sqlite_schema_cache[db_id] = schema
            return schema

        try:
            with closing(sqlite3.connect(db_path)) as conn:
                cursor = conn.cursor()
                tables = cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
                for (table_name,) in tables:
                    columns = cursor.execute(f"PRAGMA table_info({table_name});").fetchall()
                    schema[str(table_name)] = [str(col[1]) for col in columns]
        except Exception:
            schema = {}

        self._sqlite_schema_cache[db_id] = schema
        return schema

    def get_schema(self, db_id: str) -> List[TableSchema]:
        """
        Returns a list of table schemas for this db.
        Prefers `tables.json` (Spider canonical), but can fallback to SQLite if needed.
        """
        tables = self._tables_by_db.get(db_id, [])
        if tables:
            return tables

        sqlite_schema = self._load_sqlite_schema(db_id)
        return [TableSchema(table_name=t, columns=tuple(cols)) for t, cols in sqlite_schema.items()]

    def score_tables(self, question: str, db_id: str) -> List[Tuple[float, TableSchema]]:
        """
        Score each table using token overlap:
        - table token overlap (higher weight)
        - column token overlap (lower weight)
        Returns (score, table) pairs sorted best-first (ties by table name).
        """
        q_tokens = set(_tokenize(question))
        tables = self.get_schema(db_id)

        scored: List[Tuple[float, TableSchema]] = []
        for t in tables:
            table_tokens = set(_tokenize(t.table_name))
            col_tokens: set[str] = set()
            for c in t.columns:
                col_tokens.update(_tokenize(c))

            table_overlap = len(q_tokens & table_tokens)
            col_overlap = len(q_tokens & col_tokens)

            # Simple weighted overlap (tuned to bias table matches).
            score = 3.0 * table_overlap + 1.0 * col_overlap

            # Small boost for substring mentions (helps e.g. "album" vs "albums").
            q_text = _normalize_identifier(question)
            if t.table_name and _normalize_identifier(t.table_name) in q_text:
                score += 0.5

            scored.append((score, t))

        scored.sort(key=lambda x: (x[0], x[1].table_name), reverse=True)
        return scored

    def select_top_tables(self, question: str, db_id: str, top_k: int = 4) -> List[TableSchema]:
        """Top-k tables by score; falls back to schema order when nothing matches."""
        scored = self.score_tables(question, db_id)
        if not scored:
            return []
        top_k = max(1, int(top_k))
        selected = [t for _, t in scored[:top_k]]

        # If everything scores 0, still return a stable selection.
        if scored[0][0] <= 0:
            tables = self.get_schema(db_id)
            return tables[:top_k]

        return selected

    def columns_for_selected_tables(self, db_id: str, selected_tables: Iterable[TableSchema]) -> Dict[str, List[str]]:
        """
        Returns only columns belonging to selected tables.
        Prefer SQLite columns (actual DB) if available; fallback to tables.json.
        """
        sqlite_schema = self._load_sqlite_schema(db_id)
        out: Dict[str, List[str]] = {}
        for t in selected_tables:
            if t.table_name in sqlite_schema and sqlite_schema[t.table_name]:
                out[t.table_name] = sqlite_schema[t.table_name]
            else:
                out[t.table_name] = list(t.columns)
        return out

    def format_relevant_schema(self, question: str, db_id: str, top_k: int = 4) -> Tuple[List[str], Dict[str, List[str]]]:
        """
        Returns:
        - lines: ["table(col1, col2)", ...]
        - selected: {table: [cols...], ...}
        """
        selected_tables = self.select_top_tables(question, db_id, top_k=top_k)
        selected = self.columns_for_selected_tables(db_id, selected_tables)

        lines: List[str] = []
        for table_name, cols in selected.items():
            cols_str = ", ".join(cols)
            lines.append(f"{table_name}({cols_str})")

        return lines, selected
215
+
src/sql_validator.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import re
3
+ from pathlib import Path
4
+
5
class SQLValidator:
    """Cheap pre-execution guardrails for model-generated SQLite queries.

    `validate` runs four checks in order: dangerous DML/DDL keywords, basic
    SELECT/FROM structure, at least one real table name, and a bounded number
    of unknown identifiers (likely hallucinated columns).
    """

    def __init__(self, db_root):
        # Root folder laid out as <db_root>/<db_id>/<db_id>.sqlite
        self.db_root = Path(db_root)

    # ---------------------------
    # Load schema
    # ---------------------------
    def load_schema(self, db_id):
        """Return {table_name_lower: [column_name_lower, ...]} for `db_id`."""
        db_path = self.db_root / db_id / f"{db_id}.sqlite"

        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()

            tables = cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table';"
            ).fetchall()

            schema = {}
            for (table,) in tables:
                # Quote the identifier so reserved words / unusual table
                # names don't break the PRAGMA statement.
                cols = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
                schema[table.lower()] = [c[1].lower() for c in cols]
        finally:
            # FIX: previously the connection leaked when a query raised.
            conn.close()

        return schema

    # ---------------------------
    # Basic syntax check
    # ---------------------------
    def basic_structure_valid(self, sql):
        """Very coarse shape check: must mention SELECT and FROM."""
        s = sql.lower()

        if "select" not in s or "from" not in s:
            return False, "Missing SELECT or FROM"

        if len(s.split()) < 4:
            return False, "Too short to be SQL"

        return True, None

    # ---------------------------
    # Extract identifiers
    # ---------------------------
    def extract_identifiers(self, sql):
        """All lowercase word tokens in `sql` (SQL keywords included)."""
        tokens = re.findall(r"[A-Za-z_]+", sql.lower())
        return set(tokens)

    # ---------------------------
    # Table validation
    # ---------------------------
    def validate_tables(self, sql, schema):
        """Require at least one known table name to appear in the SQL."""
        words = self.extract_identifiers(sql)
        tables = set(schema.keys())

        used_tables = [w for w in words if w in tables]

        if not used_tables:
            return False, "No valid table used"

        return True, None

    # ---------------------------
    # Column validation
    # ---------------------------
    def validate_columns(self, sql, schema):
        """Reject queries with too many identifiers absent from the schema."""
        words = self.extract_identifiers(sql)

        valid_columns = set()
        for cols in schema.values():
            valid_columns.update(cols)

        # Ignore SQL keywords. Extended to cover JOIN variants, NULL tests,
        # set operations, CASE and aliasing so legitimate queries are not
        # counted as containing "unknown identifiers".
        keywords = {
            "select", "from", "where", "join", "on", "group", "by",
            "order", "limit", "count", "sum", "avg", "min", "max",
            "and", "or", "in", "like", "distinct", "asc", "desc",
            "as", "inner", "left", "right", "full", "outer", "having",
            "not", "is", "null", "between", "exists", "union",
            "intersect", "except", "case", "when", "then", "else", "end",
        }

        invalid = []
        for w in words:
            if w not in valid_columns and w not in schema and w not in keywords:
                if not w.isdigit():
                    invalid.append(w)

        # allow small hallucinations but block many
        if len(invalid) > 3:
            return False, f"Too many unknown identifiers: {invalid[:5]}"

        return True, None

    # ---------------------------
    # Dangerous query protection
    # ---------------------------
    def block_dangerous(self, sql):
        """Block DML/DDL keywords.

        FIX: uses word-boundary matching so identifiers such as
        'last_update' or 'deleted_flag' no longer trigger false positives
        (the old substring check blocked any SQL merely containing the
        letters of a keyword).
        """
        bad = ["drop", "delete", "update", "insert", "alter"]

        for b in bad:
            if re.search(rf"\b{b}\b", sql, re.IGNORECASE):
                return False, f"Dangerous keyword detected: {b}"

        return True, None

    # ---------------------------
    # Main validation
    # ---------------------------
    def validate(self, sql, db_id):
        """Run all checks; return (True, None) or (False, reason)."""

        schema = self.load_schema(db_id)

        checks = [
            self.block_dangerous(sql),
            self.basic_structure_valid(sql),
            self.validate_tables(sql, schema),
            self.validate_columns(sql, schema),
        ]

        for ok, msg in checks:
            if not ok:
                return False, msg

        return True, None
src/text2sql_engine.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import sqlite3
4
+ import torch
5
+ import re
6
+ import time
7
+ from pathlib import Path
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+ from peft import PeftModel
10
+ from src.sql_validator import SQLValidator
11
+ from src.schema_encoder import SchemaEncoder
12
+
13
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
14
+ DB_ROOT = PROJECT_ROOT / "data" / "database"
15
+
16
+ # ==========================================
17
+ # UNIVERSAL STRING NORMALIZERS
18
+ # ==========================================
19
def normalize_question(q: str):
    """Lower-case, reorder 'distinct N' -> 'N distinct', collapse whitespace."""
    text = q.strip().lower()
    # "distinct 5" reads more naturally to the model as "5 distinct".
    text = re.sub(r"distinct\s+(\d+)", r"\1 distinct", text)
    return re.sub(r"\s+", " ", text)
24
+
25
def semantic_fix(question, sql):
    """Universal structural fixes that apply to ALL queries and ALL databases."""
    q = question.lower().strip()
    lowered = sql.lower()

    # Enforce LIMIT only when the question explicitly asks for N rows
    # ("show 5", "top 3", ...). Requiring one of these verbs keeps plain
    # numbers such as years ("2000") from triggering a bogus LIMIT.
    wants_limit = re.search(r'\b(?:show|list|top|limit|get|first|last)\s+(\d+)\b', q)
    if wants_limit and "limit" not in lowered and "count(" not in lowered:
        sql = f"{sql.rstrip(';').strip()} LIMIT {wants_limit.group(1)}"

    return sql
40
+
41
+
42
class Text2SQLEngine:
    """Natural-language question -> SQL -> results pipeline.

    Wraps a seq2seq model (optionally with a LoRA adapter), builds a
    schema-aware prompt, generates SQL, validates and executes it against a
    local SQLite database, and runs one LLM self-correction pass when the
    first execution fails.
    """

    def __init__(self,
                 adapter_path="checkpoints/best_rlhf_model",
                 base_model_name="Salesforce/codet5-base",
                 use_lora=True):
        """Load tokenizer and model; skip the LoRA adapter if use_lora=False."""

        # Prefer Apple MPS, then CUDA, then CPU.
        self.device = "mps" if torch.backends.mps.is_available() else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        self.validator = SQLValidator(DB_ROOT)
        self.schema_encoder = SchemaEncoder(DB_ROOT)
        self.schema_mode = "structured"

        # Security Keywords — used to reject DML/DDL both in the user's
        # question (layer 1) and in generated SQL (layer 2).
        self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'

        print("Loading base model...")
        base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        if not use_lora:
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            self.model = base.to(self.device)
            self.model.eval()
            print("✅ Base model ready\n")
            return

        adapter_path = (PROJECT_ROOT / adapter_path).resolve()

        print("Loading tokenizer and LoRA adapter...")
        try:
            # Prefer the tokenizer saved next to the adapter (offline-safe).
            self.tokenizer = AutoTokenizer.from_pretrained(str(adapter_path), local_files_only=True)
        except Exception:
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        self.model = PeftModel.from_pretrained(base, str(adapter_path)).to(self.device)
        self.model.eval()
        print("✅ RLHF model ready\n")

    # ==========================================
    # ---------------- PROMPT BUILDERS ---------
    # ==========================================
    def build_prompt(self, question, schema):
        """Prompt for the first, deterministic generation pass."""
        return f"""You are an expert SQL generator.
Database schema:
{schema}
Generate a valid SQLite query for the question.
Question:
{question}
SQL:
"""

    def build_repair_prompt(self, question, schema, bad_sql, error_msg):
        """Prompt for the self-correction pass, embedding the SQLite error.

        When the error names a missing column, an explicit warning about the
        hallucinated name is appended to steer the model back to the schema.
        """
        hallucinated_warning = ""
        col_match = re.search(r"no such column:\s*([^\s]+)", error_msg, re.IGNORECASE)
        if col_match:
            bad_col = col_match.group(1)
            hallucinated_warning = f"\n🚨 CRITICAL ERROR: You hallucinated the column '{bad_col}'. IT DOES NOT EXIST. Look at the schema and find the actual column name (it might be spelled differently or be a synonym like 'details', 'desc', or have a typo)."

        return f"""You are an expert SQL generator.
Database schema:
{schema}

You generated this incorrect SQL for the question "{question}":
{bad_sql}

Execution failed with this SQLite error:
{error_msg}{hallucinated_warning}

UNIVERSAL RULES TO FIX THIS:
1. NEVER invent or guess column names. Use ONLY the exact table and column names listed in the schema above.
2. Watch out for typos in the database schema! If you need 'assessment', look for 'asessment'. If you need 'name', look for 'details'.
3. If the error is "no such column", you either hallucinated the name, or you forgot an INNER JOIN. Check the schema and fix it.
4. If the query requires a COUNT() but also selects names, ensure you added a GROUP BY.

Write the corrected SQLite SQL query.
SQL:
"""

    def get_schema(self, db_id):
        """Structured schema text for `db_id` (delegates to SchemaEncoder)."""
        return self.schema_encoder.structured_schema(db_id)

    # ==========================================
    # ---------------- SQL POSTPROCESS ---------
    # ==========================================
    def extract_sql(self, text: str):
        """Pull the first SELECT statement out of raw model output."""
        text = text.strip()
        if "SQL:" in text:
            text = text.split("SQL:")[-1]
        match = re.search(r"select[\s\S]*", text, re.IGNORECASE)
        if match:
            text = match.group(0)
        # Keep only the first statement.
        return text.split(";")[0].strip()

    def clean_sql(self, sql: str):
        """Normalize double quotes to single and collapse whitespace."""
        sql = sql.replace('"', "'")
        sql = re.sub(r"\s+", " ", sql)
        return sql.strip()

    def repair_logic(self, question, sql):
        """Universal logical repairs (like missing NOT NULL for negation)"""
        q = question.lower()
        s = sql.lower()

        # Universal Negation Auto-Joiner: rewrite a plain JOIN into an
        # anti-join (LEFT JOIN ... IS NULL) when the question is negated.
        if any(word in q for word in ["never", "no ", "without"]):
            m = re.search(r"from\s+(\w+).*join\s+(\w+)", s)
            if m:
                left, right = m.group(1), m.group(2)
                key = re.search(r"on\s+(\w+\.\w+)\s*=\s*(\w+\.\w+)", s)
                if key:
                    sql = f"SELECT {left}.* FROM {left} LEFT JOIN {right} ON {key.group(1)} = {key.group(2)} WHERE {key.group(2)} IS NULL"

        # Universal LIKE wildcard injection.
        # NOTE(review): the trigger word "with" matches very common phrasing
        # ("students with id = 5") and will turn exact matches into LIKE —
        # consider narrowing this trigger list.
        if any(w in q for w in ["contain", "with", "include"]):
            sql = re.sub(r"=\s*'([^']+)'", r"LIKE '%\1%'", sql, flags=re.IGNORECASE)

        return sql

    # ==========================================
    # ---------------- GENERATE ----------------
    # ==========================================
    def generate_sql(self, prompt, is_repair=False):
        """Run the model on `prompt` and return cleaned SQL text.

        First attempts use deterministic beam search; repair attempts sample
        (temperature 0.5, top-p 0.9) so the model doesn't reproduce the same
        broken SQL verbatim.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device)

        gen_kwargs = {
            "max_new_tokens": 128,
        }

        if is_repair:
            # Sampling breaks the model out of deterministic failure loops.
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = 0.5
            gen_kwargs["top_p"] = 0.9
        else:
            # First attempt is strictly deterministic for maximum benchmark accuracy.
            gen_kwargs["num_beams"] = 5
            gen_kwargs["do_sample"] = False
            gen_kwargs["early_stopping"] = True  # only meaningful with beam search

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.clean_sql(self.extract_sql(decoded))

    # ==========================================
    # ---------------- EXECUTE -----------------
    # ==========================================
    def execute_sql(self, question, sql, db_id):
        """Post-process, validate and execute `sql` against `db_id`.

        Returns (final_sql, columns, rows, error); `error` is None on success.
        """

        # 🛡️ DEFENSE LAYER 2: Block Execution of Malicious SQL
        if re.search(self.dml_keywords, sql, re.IGNORECASE):
            return sql, [], [], "❌ Security Alert: Malicious DML/DDL SQL syntax blocked."

        db_path = DB_ROOT / db_id / f"{db_id}.sqlite"

        sql = self.repair_logic(question, sql)
        sql = self.clean_sql(sql)
        sql = semantic_fix(question, sql)

        is_valid, reason = self.validator.validate(sql, db_id)
        if not is_valid:
            return sql, [], [], f"Blocked unsafe SQL: {reason}"

        conn = None
        try:
            conn = sqlite3.connect(db_path)
            start_time = time.monotonic()

            # Abort any statement running longer than 5 seconds; the handler
            # is polled every 10000 VM instructions.
            def timeout_handler():
                return 1 if (time.monotonic() - start_time) > 5.0 else 0
            conn.set_progress_handler(timeout_handler, 10000)

            cursor = conn.cursor()
            cursor.execute(sql)
            rows = cursor.fetchall()
            columns = [d[0] for d in cursor.description] if cursor.description else []
            return sql, columns, rows, None

        except Exception as e:
            return sql, [], [], str(e)

        finally:
            # FIX: previously the connection leaked whenever execute() raised
            # (the except path returned without closing it).
            if conn is not None:
                conn.close()

    # ==========================================
    # ---------------- PIPELINE ----------------
    # ==========================================
    def ask(self, question, db_id):
        """Full pipeline: generate, execute, self-correct once on failure.

        Returns a dict with keys question/sql/columns/rows/error.
        """
        question = normalize_question(question)

        # 🛡️ DEFENSE LAYER 1: Block Malicious Natural Language Intent Early
        if re.search(self.dml_keywords, question, re.IGNORECASE):
            return {
                "question": question,
                "sql": "-- BLOCKED",
                "columns": [],
                "rows": [],
                "error": "❌ Security Alert: Malicious intent (DELETE/DROP/UPDATE) detected in the prompt."
            }

        # 1. First Pass Generation
        schema = self.get_schema(db_id)
        prompt = self.build_prompt(question, schema)

        # is_repair=False -> Uses strict Beam Search
        raw_sql = self.generate_sql(prompt, is_repair=False)

        # 2. First Execution Attempt
        final_sql, cols, rows, error = self.execute_sql(question, raw_sql, db_id)

        # 🤖 3. UNIVERSAL AGENTIC SELF-CORRECTION LOOP
        if error and "Security Alert" not in error:
            print(f"\n Caught SQLite Error: {error}")
            print(f" Triggering Stochastic LLM Self-Correction...")

            # Feed the explicit error instructions back to the LLM
            repair_prompt = self.build_repair_prompt(question, schema, final_sql, error)

            # is_repair=True -> Uses Temperature Sampling to break out of hallucination loops
            repaired_sql = self.generate_sql(repair_prompt, is_repair=True)

            # Try executing the repaired SQL
            final_sql, cols, rows, error = self.execute_sql(question, repaired_sql, db_id)

            if not error:
                print("✅ Universal Agent successfully self-corrected the query!")
            else:
                print("❌ Model failed self-correction.")

        return {
            "question": question,
            "sql": final_sql,
            "columns": cols,
            "rows": rows,
            "error": error
        }
280
+
281
# Process-wide lazily-created engine singleton.
_engine = None


def get_engine():
    """Return the shared Text2SQLEngine, building it on first call."""
    global _engine
    engine = _engine
    if engine is None:
        engine = Text2SQLEngine()
        _engine = engine
    return engine
src/tokenize_dataset.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Tokenize the processed Spider CSVs into HuggingFace datasets for T5 SFT.
# NOTE(review): paths here are relative to src/ ("../data/...") while
# train_model.py loads "data/tokenized/..." relative to the repo root, so
# this script must be run from inside src/ — confirm the intended cwd.
from datasets import Dataset
from transformers import T5Tokenizer
import pandas as pd

print("Loading processed dataset...")
train = pd.read_csv("../data/processed/train.csv")
val = pd.read_csv("../data/processed/validation.csv")

# remove hidden pandas index column if exists
train = train.drop(columns=[c for c in train.columns if "index" in c.lower()], errors="ignore")
val = val.drop(columns=[c for c in val.columns if "index" in c.lower()], errors="ignore")

print("Loading tokenizer (t5-small)...")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Task prefix prepended to every input, T5 text-to-text style.
SQL_PREFIX = "translate English to SQL: "

# ------------------------------------------------------
# TOKENIZATION FUNCTION
# ------------------------------------------------------
def tokenize(example):
    """Tokenize one row: 'input' (schema + question) as source, 'sql' as target."""

    # input = schema + question
    input_text = SQL_PREFIX + example["input"]

    # target = real SQL
    target_sql = example["sql"]

    # `text_target` tokenizes the label with the same tokenizer; both sides
    # are padded/truncated to a fixed 256 tokens.
    model_inputs = tokenizer(
        input_text,
        text_target=target_sql,
        max_length=256,
        padding="max_length",
        truncation=True
    )

    return model_inputs


# ------------------------------------------------------
# DATASET CONVERSION
# ------------------------------------------------------
print("Preparing dataset...")
train_ds = Dataset.from_pandas(train)
val_ds = Dataset.from_pandas(val)

print("Tokenizing train...")
train_ds = train_ds.map(tokenize, remove_columns=train_ds.column_names)

print("Tokenizing validation...")
val_ds = val_ds.map(tokenize, remove_columns=val_ds.column_names)

# save tokenized dataset
train_ds.save_to_disk("../data/tokenized/train")
val_ds.save_to_disk("../data/tokenized/validation")

print("DONE ✔ Tokenized dataset saved correctly")
src/train_model.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Fine-tune t5-small on the tokenized Spider dataset (Mac / MPS friendly).
# Expects data/tokenized/{train,validation} produced by tokenize_dataset.py.
import torch
from datasets import load_from_disk
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments
)

# ======================================================
# DEVICE (Mac M1/M2/M3 Safe)
# ======================================================
device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)

# ======================================================
# LOAD TOKENIZED DATASET (FIXED PATHS)
# ======================================================
print("Loading tokenized dataset...")

train_dataset = load_from_disk("data/tokenized/train")
val_dataset = load_from_disk("data/tokenized/validation")

print("Train size:", len(train_dataset))
print("Validation size:", len(val_dataset))

# ======================================================
# LOAD MODEL
# ======================================================
print("Loading model (t5-small)...")

model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Prevent Mac memory crash
model.config.use_cache = False

# Important T5 settings (prevents generation bugs)
# T5 conventionally starts decoding from the pad token.
model.config.decoder_start_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# ======================================================
# DATA COLLATOR
# ======================================================
# Pads batches dynamically and prepares decoder inputs from labels.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model
)

# ======================================================
# TRAINING ARGUMENTS (Mac Safe)
# ======================================================
print("Setting training config...")

# NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in newer
# transformers releases — confirm the pinned version still accepts it.
training_args = Seq2SeqTrainingArguments(
    output_dir="outputs/model",

    evaluation_strategy="epoch",
    save_strategy="epoch",

    learning_rate=3e-4,
    num_train_epochs=5,

    # Effective batch size = 1 * 8 via gradient accumulation.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,

    logging_steps=50,

    # No mixed precision on MPS/CPU.
    fp16=False,
    bf16=False,
    dataloader_pin_memory=False,

    predict_with_generate=True,
    report_to="none"
)

# ======================================================
# TRAINER
# ======================================================
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# ======================================================
# TRAIN
# ======================================================
print("Training started 🚀")
trainer.train()

# ======================================================
# SAVE MODEL
# ======================================================
print("Saving model...")
trainer.save_model("outputs/model")
tokenizer.save_pretrained("outputs/model")

print("\nDONE ✔ Base model trained successfully")
src/train_rl.py ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =========================================================
2
+ # RLHF TRAINING FOR TEXT2SQL (STABLE PPO VERSION)
3
+ # =========================================================
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from datasets import load_dataset
8
+ from transformers import AutoTokenizer
9
+ from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
10
+ from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
11
+ from peft import PeftModel
12
+ import os, sys, sqlite3, re, random
13
+
14
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
15
+ from execution_reward import execution_reward, extract_tables, extract_columns
16
+ try:
17
+ import sqlparse # gate PPO updates on parsable SQL only
18
+ except Exception: # pragma: no cover
19
+ sqlparse = None
20
+
21
+
22
# ======================================================
# DEVICE
# ======================================================
# Silence the HF tokenizers fork warning before any tokenizer is created.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# ======================================================
# TRAINING SETTINGS
# ======================================================
NUM_EPOCHS = 5                # PPO passes over the rollout budget
LOG_EVERY = 20                # steps between progress logs
USE_SCHEMA = True             # embed DB schema text in prompts
SCHEMA_WARMUP_EPOCHS = 0      # epochs trained without schema first
MAX_SCHEMA_CHARS = 1500       # hard character cap on schema per prompt
MAX_OUTPUT_TOKENS = 80        # generation budget per rollout
ROLLOUTS_PER_EPOCH = 2048


# ======================================================
# PATHS
# ======================================================
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# 🎯 FIXED: Save ONLY the best model to this exact path
RL_MODEL_PATH = os.path.join(PROJECT_ROOT, "checkpoints", "rlhf_t5_best")
output_dir = RL_MODEL_PATH

DB_ROOT = os.path.join(PROJECT_ROOT, "data/database")

# 🎯 Updated to point to our newly trained t5-small SFT model
ADAPTER_PATH = os.path.join(PROJECT_ROOT, "checkpoints/sft_t5")

FALLBACK_ADAPTER_PATH = os.path.join(PROJECT_ROOT, "models/t5_spider_sft_lora")
FALLBACK_ADAPTER_PATH_2 = os.path.join(PROJECT_ROOT, "outputs/sft_text2sql")
# 🎯 ENSURING t5-small is used
BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")


# ======================================================
# LOAD MODEL (LoRA)
# ======================================================
print("Loading base:", BASE_MODEL)
# Fall back to older adapter checkpoints if the primary one is missing.
if not os.path.isdir(ADAPTER_PATH):
    if os.path.isdir(FALLBACK_ADAPTER_PATH):
        ADAPTER_PATH = FALLBACK_ADAPTER_PATH
    elif os.path.isdir(FALLBACK_ADAPTER_PATH_2):
        ADAPTER_PATH = FALLBACK_ADAPTER_PATH_2
print("Loading adapters:", ADAPTER_PATH)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Policy model: base + value head, with the SFT LoRA adapter attached.
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(BASE_MODEL).to(device)
model.pretrained_model = PeftModel.from_pretrained(model.pretrained_model, ADAPTER_PATH)

# Frozen reference model for the PPO KL penalty.
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(BASE_MODEL).to(device)
ref_model.pretrained_model = PeftModel.from_pretrained(ref_model.pretrained_model, ADAPTER_PATH)

ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad_(False)

# Freeze base transformer weights; train LoRA adapters + value head.
for name, p in model.named_parameters():
    # Train value head
    if name.startswith("v_head"):
        p.requires_grad = True
    # Train LoRA adapters (policy learning!)
    elif "lora_" in name:
        p.requires_grad = True
    # Freeze base model
    else:
        p.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable}/{total} ({100*trainable/total:.2f}%)")

model.config.use_cache = False
ref_model.config.use_cache = False

if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token


# ======================================================
# DATASET
# ======================================================
print("Loading Spider subset...")
random.seed(0)

# Train on a small, stable curriculum of DBs first.
TRAIN_DBS = [
    "flight_1",
    "student_assessment",
    "store_1",
    "bike_1",
    "book_2",
    "chinook_1",
]

dataset = load_dataset("spider", split="train")
_TRAIN_DBS_SET = set(TRAIN_DBS)
dataset = dataset.filter(lambda x: x["db_id"] in _TRAIN_DBS_SET)
dataset = dataset.select(range(min(800, len(dataset))))

print("Using RLHF DBs:", TRAIN_DBS)
print("Filtered size:", len(dataset))

total_steps = ROLLOUTS_PER_EPOCH
133
+
134
+ # ======================================================
135
+ # DB UTILITIES
136
+ # ======================================================
137
def get_db_path(db_id):
    """Absolute path of the SQLite file for a Spider database id."""
    return os.path.join(DB_ROOT, db_id, db_id + ".sqlite")
139
+
140
+
141
def get_db_schema(db_path):
    """Serialize a SQLite schema as concatenated 'table(col1, col2) ' fragments.

    Returns "" when the database cannot be opened or read.
    """
    parts = []
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        for (table_name,) in tables:
            # Quote the identifier so unusual table names don't break PRAGMA.
            columns = cursor.execute(f'PRAGMA table_info("{table_name}");').fetchall()
            col_names = [col[1] for col in columns]
            parts.append(f"{table_name}({', '.join(col_names)}) ")
    except sqlite3.Error:
        # FIX: was a bare `except:` that swallowed every exception (even
        # KeyboardInterrupt); only SQLite-level failures are best-effort.
        pass
    finally:
        # FIX: the connection previously leaked when an error occurred.
        if conn is not None:
            conn.close()

    # join avoids the quadratic `+=` string build of the original.
    return "".join(parts)
162
+
163
+
164
+ # ======================================================
165
+ # PROMPT
166
+ # ======================================================
167
# Task prefix shared by every prompt variant.
PREFIX = "translate English to SQL:"


def trim_schema(schema: str, max_chars: int = 1200) -> str:
    """Coerce `schema` to str and clip it to at most `max_chars` characters."""
    if schema is None:
        return ""
    text = str(schema)
    return text if len(text) <= max_chars else text[:max_chars]
177
def build_prompt(question: str, schema: str, use_schema: bool) -> str:
    """Assemble the T5 prompt, optionally embedding a (trimmed) schema."""
    if use_schema:
        schema = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
        return f"{PREFIX}\n\nSchema:\n{schema}\n\nQuestion:\n{question}\n\nSQL:"
    return f"{PREFIX}\n\nQuestion:\n{question}\n\nSQL:"
183
+
184
def encode_prompt(question, schema, use_schema):
    """Tokenize the prompt to input ids, budgeting tokens so the question
    always survives and only the schema is truncated.

    Returns a 1-D LongTensor of token ids on `device`.
    """
    # Never truncate the question; only truncate schema tokens if needed.
    if not use_schema:
        prompt = build_prompt(question, schema, use_schema=False)
        return tokenizer(prompt, return_tensors="pt", truncation=True).input_ids[0].to(device)

    schema = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
    # Encode each prompt segment separately so the schema segment can be
    # clipped on its own without touching the question.
    prefix_schema = f"{PREFIX}\n\nSchema:\n"
    mid = "\n\nQuestion:\n"
    suffix = f"{question}\n\nSQL:"

    prefix_ids = tokenizer.encode(prefix_schema, add_special_tokens=False)
    schema_ids = tokenizer.encode(schema, add_special_tokens=False)
    mid_ids = tokenizer.encode(mid, add_special_tokens=False)
    suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)

    # Reserve one slot for the EOS token we append manually at the end.
    max_len = getattr(tokenizer, "model_max_length", 512)
    eos_id = tokenizer.eos_token_id
    max_without_eos = max_len - (1 if eos_id is not None else 0)

    # Ensure the question+SQL suffix always fits; truncate schema first.
    fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)
    if fixed_len > max_without_eos:
        # Extremely rare; clip the suffix (question) only if unavoidable.
        keep = max(0, max_without_eos - (len(prefix_ids) + len(mid_ids)))
        suffix_ids = suffix_ids[:keep]
        fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)

    # Whatever budget remains after the fixed parts goes to the schema.
    remaining_for_schema = max_without_eos - fixed_len
    if remaining_for_schema < 0:
        remaining_for_schema = 0
    schema_ids = schema_ids[:remaining_for_schema]

    ids = prefix_ids + schema_ids + mid_ids + suffix_ids
    ids = ids[:max_without_eos]
    if eos_id is not None:
        ids = ids + [eos_id]

    return torch.tensor(ids, dtype=torch.long).to(device)
223
+
224
+
225
+ # ======================================================
226
+ # SQL CONSTRAINED DECODING
227
+ # ======================================================
228
+ SQL_KEYWORDS = [
229
+ "select", "from", "where", "join", "inner", "left", "right",
230
+ "full", "outer", "on", "group", "by", "order", "having",
231
+ "limit", "distinct", "as", "and", "or", "not", "in", "is",
232
+ "null", "like", "between", "asc", "desc", "union",
233
+ "intersect", "except",
234
+ ]
235
+
236
+ SQL_OPERATORS = ["*", ",", ".", "(", ")", "=", "<", ">", "!", "+", "-", "/", "%", "_"]
237
+
238
+
239
+ def _piece_token_str(tok: str) -> str:
240
+ # T5 SentencePiece: "▁" marks a leading space; strip it for char checks.
241
+ return tok.lstrip("▁")
242
+
243
+
244
def _precompute_always_allowed_token_ids():
    """Scan the tokenizer vocabulary once and collect the token ids that are
    always legal during constrained SQL decoding: special tokens, whitespace
    pieces, operator/punctuation pieces, pure-digit pieces and SQL keywords.
    """
    vocab_size = len(tokenizer)
    allowed = set()

    # Always allow special tokens.
    for tid in [tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.unk_token_id]:
        if tid is not None and tid >= 0:
            allowed.add(int(tid))

    # Allow whitespace/newlines in case they exist as pieces.
    for s in [" ", "\n", "\t"]:
        allowed.update(tokenizer.encode(s, add_special_tokens=False))

    # Allow operators/punctuation/numeric pieces broadly.
    op_chars = set("".join(SQL_OPERATORS))
    for tid in range(vocab_size):
        tok = tokenizer.convert_ids_to_tokens(tid)
        if not isinstance(tok, str) or not tok:
            continue
        piece = _piece_token_str(tok)
        if not piece:
            continue
        if all((ch in op_chars) for ch in piece):
            allowed.add(tid)
            continue
        if piece.isdigit():
            allowed.add(tid)
            continue
        # Common numeric fragments like "1", "00", etc.
        # NOTE(review): this all-digits check is redundant with the
        # `piece.isdigit()` branch just above — safe to remove.
        if all(ch.isdigit() for ch in piece):
            allowed.add(tid)

    # Allow keyword pieces (with and without a leading space, in the three
    # common casings).
    for kw in SQL_KEYWORDS:
        for variant in {kw, kw.upper(), kw.capitalize()}:
            allowed.update(tokenizer.encode(" " + variant, add_special_tokens=False))
            allowed.update(tokenizer.encode(variant, add_special_tokens=False))

    return allowed
283
+
284
+
285
# Computed once at import time: token ids legal in any generated SQL.
ALWAYS_ALLOWED_TOKEN_IDS = _precompute_always_allowed_token_ids()
286
+
287
+
288
def _schema_allowed_token_ids(table_names, column_names):
    """Extend the always-allowed vocabulary with schema identifier pieces.

    For every table/column name, whitelists the whole identifier plus its
    underscore/whitespace-separated fragments, each in original, lower and
    upper case, with and without a leading space.
    """
    allowed = set(ALWAYS_ALLOWED_TOKEN_IDS)

    def _whitelist_identifier(identifier: str):
        if not identifier:
            return
        forms = {identifier, identifier.lower(), identifier.upper()}
        for fragment in re.split(r"[_\s]+", identifier):
            if fragment:
                forms.add(fragment)
        for form in forms:
            for text in (" " + form, form):
                allowed.update(tokenizer.encode(text, add_special_tokens=False))

    for table in table_names:
        _whitelist_identifier(table)
    for column in column_names:
        _whitelist_identifier(column)

    return allowed
308
+
309
+
310
class SQLVocabularyLogitsProcessor(LogitsProcessor):
    """Mask generation logits so only whitelisted token ids stay finite.

    A bias vector (0.0 for allowed ids, -inf otherwise) is built lazily and
    cached; it is rebuilt whenever the scores' device, dtype, or vocabulary
    size changes.
    """

    def __init__(self, allowed_token_ids):
        # Keep only non-negative ids; negative sentinels cannot be indexed.
        self.allowed_token_ids = {int(i) for i in allowed_token_ids if int(i) >= 0}
        self._bias = None
        self._bias_vocab_size = None

    def _get_bias(self, scores: torch.Tensor) -> torch.Tensor:
        """Return (building if stale) the cached additive mask for *scores*."""
        vocab_size = int(scores.shape[-1])
        cache_stale = (
            self._bias is None
            or self._bias.device != scores.device
            or self._bias.dtype != scores.dtype
            or self._bias_vocab_size != vocab_size
        )
        if cache_stale:
            mask = torch.full((vocab_size,), float("-inf"), device=scores.device, dtype=scores.dtype)
            for token_id in self.allowed_token_ids:
                if token_id < vocab_size:
                    mask[token_id] = 0.0
            self._bias = mask
            self._bias_vocab_size = vocab_size
        return self._bias

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Adding -inf removes disallowed tokens from the softmax support.
        return scores + self._get_bias(scores)
334
+
335
+
336
# Memoized (table_names, column_names) per database file path.
_DB_VOCAB_CACHE = {}


def get_db_tables_columns(db_path: str):
    """Return ([table names], [column names]) for the SQLite db at *db_path*.

    Results are memoized per path. Internal ``sqlite_*`` tables are excluded,
    and any sqlite error degrades to (possibly partial) empty lists.
    """
    cached = _DB_VOCAB_CACHE.get(db_path)
    if cached is not None:
        return cached

    table_names, column_names = [], []
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        table_rows = cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        ).fetchall()
        for (tname,) in table_rows:
            if not tname:
                continue
            table_names.append(tname)
            try:
                for col_row in cur.execute(f'PRAGMA table_info("{tname}")').fetchall():
                    # Column name lives at index 1 of the PRAGMA row.
                    if col_row and isinstance(col_row[1], str):
                        column_names.append(col_row[1])
            except Exception:
                continue
        conn.close()
    except Exception:
        # Best-effort: a broken/missing database yields empty vocab lists.
        pass
    _DB_VOCAB_CACHE[db_path] = (table_names, column_names)
    return table_names, column_names
363
+
364
+
365
# ======================================================
# PPO CONFIG (stable learning)
# ======================================================
ppo_config = PPOConfig(
    learning_rate=2e-5,              # was 1e-6 → model could not move
    batch_size=8,                    # better gradient estimate
    mini_batch_size=2,
    gradient_accumulation_steps=2,   # stable updates on small data
    ppo_epochs=1,

    # --- KL control (MOST IMPORTANT FIX) ---
    init_kl_coef=0.05,               # reduce punishment
    target_kl=0.15,                  # relax constraint to avoid skipped updates
    adap_kl_ctrl=True,

    # --- stability ---
    cliprange=0.25,
    cliprange_value=0.25,
    whiten_rewards=True,
    kl_penalty="kl",
    max_grad_norm=1.0,
)
trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)
# Reused below for the supervised anchor updates piggybacking on PPO.
optimizer = trainer.optimizer

# Provide `.device` attribute for the supervised anchor helper.
# NOTE(review): value-head wrappers may not expose .device natively — TODO confirm.
try:
    model.device = torch.device(device)
except Exception:
    pass


# ======================================================
# GENERATION (schema-constrained decoding)
# ======================================================
generation_kwargs = dict(
    max_new_tokens=64,        # 128 causes garbage SQL loops

    do_sample=True,
    temperature=0.9,          # encourage exploration
    top_p=0.95,
    top_k=100,

    repetition_penalty=1.1,   # prevents SELECT SELECT SELECT
    no_repeat_ngram_size=3,

    num_beams=1,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
# ======================================================
# TRAIN LOOP
# ======================================================
print("Starting RL training 🚀")

# Rollout buffers flushed every ppo_config.batch_size accepted samples.
query_buffer, response_buffer, reward_buffer, gold_buffer = [], [], [], []
query_text_buffer = []

# Best-epoch tracking for checkpoint selection.
best_reward = -999999
best_epoch = -1
430
+
431
def _is_parsable_sql(sql: str) -> bool:
    """Heuristic gate: True when *sql* looks like a parsable SELECT query.

    Requires both SELECT and FROM (case-insensitive); defers to sqlparse
    when that library is importable, otherwise accepts on keywords alone.
    """
    text = (sql or "").strip()
    if not text:
        return False
    upper = text.upper()
    if "SELECT" not in upper or "FROM" not in upper:
        return False
    if sqlparse is None:
        # sqlparse unavailable: keyword check is the best we can do.
        return True
    try:
        parsed = sqlparse.parse(text)
    except Exception:
        return False
    return bool(parsed)
444
+
445
+
446
def _pad_2d(seqs, pad_id: int):
    """Right-pad a list of 1-D LongTensors into (ids, attention_mask).

    Both outputs are [len(seqs), max_len] long tensors on the module-level
    ``device``; the mask is 1 over real tokens and 0 over padding.
    """
    longest = max(int(seq.numel()) for seq in seqs)
    ids = torch.full((len(seqs), longest), int(pad_id), dtype=torch.long, device=device)
    mask = torch.zeros((len(seqs), longest), dtype=torch.long, device=device)
    for row, seq in enumerate(seqs):
        length = int(seq.numel())
        ids[row, :length] = seq.to(device)
        mask[row, :length] = 1
    return ids, mask
455
+
456
+
457
+ def _shift_right(labels: torch.Tensor, start_id: int) -> torch.Tensor:
458
+ dec = labels.clone()
459
+ dec[:, 1:] = labels[:, :-1]
460
+ dec[:, 0] = int(start_id)
461
+ return dec
462
+
463
+
464
def safe_get_kl(stats):
    """Best-effort extraction of a KL statistic from a PPO stats dict.

    Returns the value of the first key containing "kl" (case-insensitive)
    coerced to float, or None when *stats* is not a dict, no such key
    exists, or the value cannot be converted.
    """
    if not isinstance(stats, dict):
        return None
    for key in stats:
        if "kl" not in str(key).lower():
            continue
        value = stats[key]
        try:
            if hasattr(value, "item"):
                return float(value.item())
            return float(value)
        except Exception:
            return None
    return None
475
+
476
def supervised_anchor_step(model, tokenizer, queries, gold_sqls, weight=0.05):
    """Accumulate a weighted supervised (teacher-forcing) loss on gold SQL.

    Called right after a PPO step to anchor the policy's SQL grammar.
    Gradients are only *accumulated* here (via ``backward``); the caller is
    expected to run ``optimizer.step()`` / ``zero_grad()`` afterwards.

    Args:
        model: seq2seq policy; must expose ``.device`` and be callable with
            encoder/decoder ids returning logits at index 0.
        tokenizer: paired tokenizer (used for both prompt and target sides).
        queries: iterable of prompt strings.
        gold_sqls: matching iterable of gold SQL strings.
        weight: scale applied to each sample's loss before backprop.

    Returns:
        Sum of the *unweighted* per-sample cross-entropy losses (float).
    """
    model.train()
    total_loss = 0.0

    for q, gold in zip(queries, gold_sqls):

        enc = tokenizer(q, return_tensors="pt", truncation=True).to(model.device)
        dec = tokenizer(text_target=gold, return_tensors="pt", truncation=True)

        labels = dec.input_ids.to(model.device)

        # teacher forcing shift
        # NOTE(review): this feeds labels[:-1] as decoder input instead of
        # prepending decoder_start_token_id, so the first gold token is never
        # predicted — TODO confirm this is intentional.
        decoder_input_ids = labels[:, :-1].contiguous()
        target_ids = labels[:, 1:].contiguous()

        outputs = model(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            decoder_input_ids=decoder_input_ids,
        )

        # Index 0 holds the LM logits on this model's output tuple.
        logits = outputs[0]

        vocab_size = logits.size(-1)
        loss = F.cross_entropy(
            logits.view(-1, vocab_size),
            target_ids.view(-1),
            # Padding targets contribute nothing to the loss.
            ignore_index=tokenizer.pad_token_id,
        )

        # Accumulate scaled gradients; no optimizer step here by design.
        (loss * weight).backward()
        total_loss += loss.item()

    return total_loss
510
+
511
+
512
@torch.no_grad()
def _estimate_policy_entropy(query_tensors, response_tensors) -> torch.Tensor:
    """
    Returns per-sample average token entropy of the policy on the sampled response tokens.
    Used as a small bonus to reduce repetition collapse.

    Args:
        query_tensors: list of 1-D encoder input id tensors.
        response_tensors: matching list of sampled decoder output id tensors.

    Returns:
        Float tensor of shape [B]: mean token entropy over non-pad response
        positions for each sample.
    """
    pad_id = int(tokenizer.pad_token_id)
    enc_ids, enc_attn = _pad_2d(query_tensors, pad_id)
    dec_ids, dec_attn = _pad_2d(response_tensors, pad_id)

    # Fall back to pad_id when the config lacks decoder_start_token_id.
    start_id = int(getattr(model.pretrained_model.config, "decoder_start_token_id", pad_id))
    dec_inp = _shift_right(dec_ids, start_id)

    out = model.pretrained_model(
        input_ids=enc_ids,
        attention_mask=enc_attn,
        decoder_input_ids=dec_inp,
        use_cache=False,
    )
    # Entropy per position: H = -sum_v p(v) log p(v).
    logp = torch.log_softmax(out.logits, dim=-1)
    p = torch.exp(logp)
    ent = -(p * logp).sum(dim=-1)  # [B, T]
    # average only over non-pad positions of the sampled response
    denom = dec_attn.sum(dim=-1).clamp_min(1)
    return (ent * dec_attn).sum(dim=-1) / denom  # [B]
537
+
538
+
539
+ def _repeat_penalty(response_tensor: torch.Tensor) -> float:
540
+ """
541
+ Penalize repetition to avoid 'SELECT SELECT SELECT' collapse.
542
+ Simple heuristic: consecutive duplicate token ratio + low-unique-token ratio.
543
+ """
544
+ ids = response_tensor.detach().tolist()
545
+ n = len(ids)
546
+ if n <= 1:
547
+ return 0.0
548
+ consec_dup = 0
549
+ for i in range(1, n):
550
+ if ids[i] == ids[i - 1]:
551
+ consec_dup += 1
552
+ unique_ratio = len(set(ids)) / n
553
+ consec_ratio = consec_dup / (n - 1)
554
+ # Higher penalty when low unique + high consecutive duplicates
555
+ return float(0.5 * consec_ratio + 0.5 * (1.0 - unique_ratio))
556
+
557
+
558
def _supervised_anchor_step(query_tensors, gold_sql_texts, weight: float = 0.05) -> None:
    """
    Small teacher-forcing step on gold SQL to anchor grammar during PPO.

    Runs one weighted cross-entropy update against the gold SQL for the
    current PPO batch. No-op when there is nothing to train on or when
    PPOTrainer does not expose (accelerator, optimizer).

    Args:
        query_tensors: list of 1-D encoder input id tensors (one per sample).
        gold_sql_texts: matching list of gold SQL strings; blank entries are
            replaced with the trivial "SELECT 1" so labels are never empty.
        weight: scale applied to the supervised loss before backprop.
    """
    if not gold_sql_texts:
        return
    accelerator = getattr(trainer, "accelerator", None)
    optimizer = getattr(trainer, "optimizer", None)
    if accelerator is None or optimizer is None:
        return

    pad_id = int(tokenizer.pad_token_id)
    enc_ids, enc_attn = _pad_2d(query_tensors, pad_id)

    # Tokenize gold SQL targets (decoder side); cap length at 256 tokens.
    gold_ids = []
    for sql_text in gold_sql_texts:
        gold = (sql_text or "").strip()
        if not gold:
            gold = "SELECT 1"
        ids = tokenizer.encode(gold, add_special_tokens=False)[:256]
        if tokenizer.eos_token_id is not None:
            ids = ids + [int(tokenizer.eos_token_id)]
        gold_ids.append(torch.tensor(ids, dtype=torch.long))

    dec_ids, dec_attn = _pad_2d(gold_ids, pad_id)
    labels = dec_ids.clone()
    # -100 marks padding positions so the model's loss ignores them.
    labels[dec_attn == 0] = -100

    # PEFT model forward supports labels -> returns loss directly.
    out = model.pretrained_model(
        input_ids=enc_ids,
        attention_mask=enc_attn,
        labels=labels,
        use_cache=False,
    )
    loss = out.loss * float(weight)

    # FIX: the original used a conditional *expression* as a statement
    # (`optimizer.zero_grad(...) if hasattr(...) else None`); an explicit
    # guard makes the intent clear and behaves identically.
    if hasattr(optimizer, "zero_grad"):
        optimizer.zero_grad(set_to_none=True)
    accelerator.backward(loss)
    optimizer.step()
600
+
601
+
602
def _curriculum_allows(gold_sql: str, epoch_num: int) -> bool:
    """Curriculum gate over training examples by epoch.

    Epoch 1 admits only single-table queries without joins or set
    operations; epoch 2 additionally admits joins (still no set ops);
    epoch 3 and later admit everything.
    """
    upper_sql = (gold_sql or "").upper()
    uses_join = "JOIN" in upper_sql
    uses_set_op = any(op in upper_sql for op in ("UNION", "INTERSECT", "EXCEPT"))
    referenced_tables = extract_tables(gold_sql)
    is_single_table = len(referenced_tables) <= 1 and not uses_join

    if epoch_num == 1:
        return is_single_table and not uses_set_op
    if epoch_num == 2:
        return (is_single_table or uses_join) and not uses_set_op
    return True
617
+
618
+
619
# Main PPO loop: per epoch, sample rollouts with replacement, gate them on
# parsability/length, convert execution results into clipped rewards, and
# run trainer.step() plus a supervised anchor every full buffer.
for epoch in range(1, NUM_EPOCHS + 1):

    # Schema is withheld from the prompt during the warmup epochs.
    use_schema_this_epoch = USE_SCHEMA and (epoch > SCHEMA_WARMUP_EPOCHS)

    epoch_reward_sum = 0
    negative_rewards = 0
    partial_rewards = 0
    correct_rewards = 0

    total_considered = 0
    valid_sql_count = 0
    exec_correct_count = 0
    table_overlap_sum = 0.0
    column_overlap_sum = 0.0
    kl_values = []

    for step in range(1, total_steps + 1):

        example = dataset[random.randrange(len(dataset))]

        question = example["question"]
        gold_sql = example["query"]
        db_id = example["db_id"]
        db_path = get_db_path(db_id)

        # NOTE: sampling-with-replacement provides more rollouts per epoch.

        schema = get_db_schema(db_path)
        question_text = build_prompt(question, schema, use_schema_this_epoch)
        query_tensor = encode_prompt(question, schema, use_schema_this_epoch)

        # ----- generate -----
        # Constrain decoding to SQL keywords + this database's identifiers.
        table_names, column_names = get_db_tables_columns(db_path)
        allowed_ids = _schema_allowed_token_ids(table_names, column_names)
        logits_processor = LogitsProcessorList([SQLVocabularyLogitsProcessor(allowed_ids)])

        response = trainer.generate([query_tensor], logits_processor=logits_processor, **generation_kwargs)[0]
        response_tensor = response.squeeze(0)[:MAX_OUTPUT_TOKENS]

        pred_sql = tokenizer.decode(response_tensor.cpu(), skip_special_tokens=True)

        total_considered += 1

        # PPO must optimize ONLY when SQL parses successfully.
        if not _is_parsable_sql(pred_sql):
            negative_rewards += 1
            continue

        # Reject generations shorter than 6 tokens.
        if int(response_tensor.numel()) < 6:
            negative_rewards += 1
            continue

        # ----- reward -----
        reward_value = execution_reward(pred_sql, db_path, gold_sql)

        # SQL validity gate: if invalid/unparsable -> reward_value is None -> skip PPO entirely.
        if reward_value is None:
            if step % 100 == 0:
                ratio = valid_sql_count / max(total_considered, 1)
                print(f"\nLearning ratio: {valid_sql_count}/{total_considered} ({ratio:.3f})")
                if ratio < 0.15:
                    print("MODEL COLLAPSING")
            continue

        # Clip rewards to [-1, 1]
        reward_value = float(max(-1.0, min(1.0, reward_value)))
        # Penalize repetition in decoded output (token-level heuristic).
        reward_value = float(max(-1.0, min(1.0, reward_value - 0.2 * _repeat_penalty(response_tensor))))
        # Keep rewards on CPU for normalization; move to device only for trainer.step().
        reward_tensor = torch.tensor(reward_value, dtype=torch.float32)

        epoch_reward_sum += reward_value

        # ----- metrics -----
        # "Valid sample" means reward is not None (parsable SQL).
        valid_sql_count += 1

        pred_tables = extract_tables(pred_sql)
        gold_tables = extract_tables(gold_sql)
        pred_cols = extract_columns(pred_sql)
        gold_cols = extract_columns(gold_sql)

        if len(gold_tables) > 0:
            table_overlap_sum += len(pred_tables & gold_tables) / max(len(gold_tables), 1)
        if len(gold_cols) > 0:
            column_overlap_sum += len(pred_cols & gold_cols) / max(len(gold_cols), 1)

        # execution_reward returns 1.0 for correct execution result.
        if reward_value >= 1.0:
            exec_correct_count += 1

        if reward_value <= -1.0:
            negative_rewards += 1
        elif reward_value >= 1.0:
            correct_rewards += 1
        else:
            partial_rewards += 1

        # Train only on informative samples:
        #   - invalid SQL already skipped (reward is None)
        #   - very small magnitude signal skipped
        if abs(reward_value) < 0.1:
            continue

        query_buffer.append(query_tensor)
        response_buffer.append(response_tensor)
        reward_buffer.append(reward_tensor)
        gold_buffer.append(gold_sql)
        query_text_buffer.append(question_text)

        # ----- PPO update -----
        if len(query_buffer) == ppo_config.batch_size:
            # move rewards to device
            reward_buffer = [r.to(device) for r in reward_buffer]

            # run PPO step
            stats = trainer.step(query_buffer, response_buffer, reward_buffer)

            # log KL safely (no control logic)
            kl = safe_get_kl(stats)
            if kl is not None:
                kl_values.append(kl)

            # --- supervised anchor to prevent grammar collapse ---
            # supervised_anchor_step only accumulates gradients; apply them here.
            supervised_anchor_step(model, tokenizer, query_text_buffer, gold_buffer, weight=0.05)
            optimizer.step()
            optimizer.zero_grad()

            # reset buffers
            query_buffer, response_buffer, reward_buffer, gold_buffer = [], [], [], []
            query_text_buffer = []

        # ----- learning ratio logging -----
        if step % 100 == 0:
            ratio = valid_sql_count / max(total_considered, 1)
            print(f"\nLearning ratio: {valid_sql_count}/{total_considered} ({ratio:.3f})")
            if ratio < 0.15:
                print("MODEL COLLAPSING")
                # Increase KL coefficient dynamically when valid_sql_rate drops.
                try:
                    if hasattr(trainer, "kl_ctl") and hasattr(trainer.kl_ctl, "value"):
                        trainer.kl_ctl.value *= 1.5
                        print(f"Increasing KL coef -> {trainer.kl_ctl.value:.4f}")
                except Exception:
                    pass

        # ----- logging -----
        if step % LOG_EVERY == 0:
            avg_reward = epoch_reward_sum / step
            print("\n---------------------------")
            print(f"Epoch {epoch}/{NUM_EPOCHS} | Step {step}/{total_steps} | Avg Reward {avg_reward:.3f}")
            print("DB:", db_id)
            print("Q:", question)
            print("SQL:", pred_sql)
            print("Reward:", reward_value)

    # epoch stats
    print(f"\nEpoch {epoch} stats:")
    print("negative:", negative_rewards)
    print("partial:", partial_rewards)
    print("correct:", correct_rewards)

    denom = max(total_considered, 1)
    print("\nEpoch metrics:")
    print(f"execution_accuracy: {exec_correct_count/denom:.3f}")
    print(f"valid_sql_rate: {valid_sql_count/denom:.3f}")
    print(f"table_match_rate: {table_overlap_sum/denom:.3f}")
    print(f"column_match_rate: {column_overlap_sum/denom:.3f}")
    print(f"avg_reward: {epoch_reward_sum/max(denom,1):.3f}")
    if kl_values:
        avg_kl = sum(kl_values) / max(len(kl_values), 1)
        print(f"avg_kl: {avg_kl:.3f}")
        if avg_kl < -8:
            print("\nKL collapse guard triggered (avg_kl < -8). Stopping early.")
            break

    # 🎯 FIXED: Removed the code that saved intermediate checkpoints at the end of each epoch

    # Only save if this epoch is the best one so far
    epoch_avg_reward = epoch_reward_sum / max(denom, 1)
    if epoch_avg_reward > best_reward:
        best_reward = epoch_avg_reward
        best_epoch = epoch

        print(f"\nNew best model at epoch {epoch} with reward {best_reward:.4f}")

        # 🎯 FIXED: Save directly to checkpoints/rlhf_t5_best, overwriting if needed
        os.makedirs(output_dir, exist_ok=True)

        trainer.model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)


print(f"\nTraining finished.")
print(f"Best epoch: {best_epoch}")
print(f"Best reward: {best_reward:.4f}")
print(f"Best model saved at: {output_dir}")
src/train_rl_bart.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =========================================================
2
+ # RLHF TRAINING FOR TEXT2SQL (OPTIMIZED PPO VERSION - BART)
3
+ # =========================================================
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from datasets import load_dataset
7
+ from transformers import AutoTokenizer
8
+ from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
9
+ from peft import PeftModel
10
+ import os, sys, sqlite3, re, random
11
+
12
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
13
+ from execution_reward import execution_reward, extract_tables, extract_columns
14
+
15
+ try:
16
+ import sqlparse # gate PPO updates on parsable SQL only
17
+ except Exception: # pragma: no cover
18
+ sqlparse = None
19
+
20
# ======================================================
# DEVICE
# ======================================================
# Silence HF tokenizers' fork-related parallelism warnings.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Prefer Apple MPS, then CUDA, then CPU.
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# ======================================================
# TRAINING SETTINGS (🚀 OPTIMIZED FOR SPEED)
# ======================================================
NUM_EPOCHS = 10           # Increased to compensate for faster epochs
LOG_EVERY = 5             # Print logs much more frequently
MAX_SCHEMA_CHARS = 1500   # Hard cap on schema text fed into the prompt
MAX_OUTPUT_TOKENS = 48    # 🚀 Down from 64. 95% of Spider SQL is <40 tokens.
ROLLOUTS_PER_EPOCH = 256  # 🚀 Down from 1024. Epochs will finish 4x faster!

# ======================================================
# PATHS
# ======================================================
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DB_ROOT = os.path.join(PROJECT_ROOT, "data/database")

# 🎯 Strict Input: Load strictly from your SFT BART checkpoint
ADAPTER_PATH = os.path.join(PROJECT_ROOT, "checkpoints/sft_best_bart_2")

# 🎯 Strict Output: Save strictly to rl_best_bart
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "checkpoints/rl_best_bart")

# Base model id, overridable via the BASE_MODEL environment variable.
BASE_MODEL = os.environ.get("BASE_MODEL", "facebook/bart-base")

# Fail fast: RL fine-tuning requires the SFT LoRA adapter as a warm start.
if not os.path.exists(ADAPTER_PATH):
    raise RuntimeError(f"❌ No valid LoRA adapter found at: {ADAPTER_PATH}")

print("Loading base:", BASE_MODEL)
print("Loading adapter:", ADAPTER_PATH)
55
+
56
# ======================================================
# TOKENIZER
# ======================================================
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
if tokenizer.pad_token is None:
    # Generation below relies on a pad token; reuse EOS if none is defined.
    tokenizer.pad_token = tokenizer.eos_token

# ======================================================
# LOAD PPO MODEL
# ======================================================
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32
).to(device)

# Attach the trainable SFT LoRA adapter to the policy.
model.pretrained_model = PeftModel.from_pretrained(
    model.pretrained_model,
    ADAPTER_PATH,
    is_trainable=True
)

# ======================================================
# LOAD REFERENCE MODEL (FROZEN)
# ======================================================
# Same base + adapter as the policy, but frozen; used for the KL penalty.
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32
).to(device)

ref_model.pretrained_model = PeftModel.from_pretrained(
    ref_model.pretrained_model,
    ADAPTER_PATH,
    is_trainable=False
)

ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad = False

# ======================================================
# TRAINABLE PARAMS — ONLY LoRA + VALUE HEAD
# ======================================================
for name, p in model.named_parameters():
    if "lora_" in name or "v_head" in name:
        p.requires_grad = True
    else:
        p.requires_grad = False

model.train()

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable}/{total} ({100*trainable/total:.2f}%)")

# KV caching is incompatible with gradient-bearing PPO forward passes.
model.config.use_cache = False
ref_model.config.use_cache = False
112
+
113
# ======================================================
# DATASET
# ======================================================
print("Loading Spider subset...")
random.seed(0)

# Allow-list of Spider databases used for RL rollouts.
TRAIN_DBS = [
    # already trained
    "flight_1","student_assessment","store_1","bike_1","book_2","chinook_1",
    "academic","aircraft","car_1","cinema","club_1","csu_1",

    # medium difficulty (NEW)
    "college_1","college_2","company_1","company_employee",
    "customer_complaints","department_store","employee_hire_evaluation",
    "museum_visit","products_for_hire","restaurant_1",
    "school_finance","shop_membership","small_bank_1",
    "soccer_1","student_1","tvshow","voter_1","world_1"
]
dataset = load_dataset("spider", split="train")
dataset = dataset.filter(lambda x: x["db_id"] in TRAIN_DBS)
133
+
134
+ def valid_example(x):
135
+ return 5 <= len(x["question"].split()) <= 40
136
+
137
+ dataset = dataset.filter(valid_example)
138
+ print("Filtered dataset size:", len(dataset))
139
+
140
+ def sample_example():
141
+ return dataset[random.randrange(len(dataset))]
142
+
143
+ # ======================================================
144
+ # DB UTILITIES
145
+ # ======================================================
146
+ def get_db_path(db_id):
147
+ return os.path.join(DB_ROOT, db_id, f"{db_id}.sqlite")
148
+
149
+ _SCHEMA_CACHE = {}
150
+
151
+ def get_db_schema_cached(db_path):
152
+ if db_path in _SCHEMA_CACHE:
153
+ return _SCHEMA_CACHE[db_path]
154
+
155
+ schema_text = ""
156
+ try:
157
+ conn = sqlite3.connect(db_path)
158
+ cursor = conn.cursor()
159
+ tables = cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
160
+
161
+ for table in tables:
162
+ table_name = table[0]
163
+ columns = cursor.execute(f"PRAGMA table_info({table_name});").fetchall()
164
+ col_names = [col[1] for col in columns]
165
+ schema_text += f"{table_name}({', '.join(col_names)})\n"
166
+ conn.close()
167
+ except:
168
+ pass
169
+
170
+ _SCHEMA_CACHE[db_path] = schema_text.strip()
171
+ return _SCHEMA_CACHE[db_path]
172
+
173
+ # ======================================================
174
+ # PROMPT
175
+ # ======================================================
176
+ def trim_schema(schema: str, max_chars: int = 1200) -> str:
177
+ if schema is None:
178
+ return ""
179
+ schema = str(schema)
180
+ if len(schema) <= max_chars:
181
+ return schema
182
+ return schema[:max_chars]
183
+
184
+ def build_prompt(question: str, schema: str) -> str:
185
+ schema = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
186
+ return f"Database Schema:\n{schema}\n\nTranslate English to SQL:\n{question}\nSQL:\n"
187
+
188
+ # ======================================================
189
+ # PPO CONFIG (STABLE POLICY LEARNING)
190
+ # ======================================================
191
+ ppo_config = PPOConfig(
192
+ learning_rate=3e-6, # slower = prevents policy jump (very important)
193
+ batch_size=8,
194
+ mini_batch_size=4, # good size, keep this
195
+ gradient_accumulation_steps=2,
196
+
197
+ ppo_epochs=2, # smoother policy update (was 1 → unstable)
198
+
199
+ # ---- KL CONTROL (main fix for negative KL) ----
200
+ init_kl_coef=0.1,
201
+ target_kl=0.08, # 0.02 was too strict → caused oscillation
202
+ adap_kl_ctrl=True,
203
+
204
+ # ---- CLIPPING ----
205
+ cliprange=0.15,
206
+ cliprange_value=0.15,
207
+
208
+ # ---- REWARD STABILITY ----
209
+ whiten_rewards=True, # VERY IMPORTANT for binary execution reward
210
+ kl_penalty="kl",
211
+
212
+ # ---- GRADIENT SAFETY ----
213
+ max_grad_norm=0.3,
214
+ )
215
+ trainer = PPOTrainer(
216
+ config=ppo_config,
217
+ model=model,
218
+ ref_model=ref_model,
219
+ tokenizer=tokenizer,
220
+ )
221
+
222
+ try:
223
+ model.device = torch.device(device)
224
+ except Exception:
225
+ pass
226
+
227
+ # ======================================================
228
+ # GENERATION CONFIG
229
+ # ======================================================
230
+ generation_kwargs = dict(
231
+ max_new_tokens=MAX_OUTPUT_TOKENS,
232
+ do_sample=True,
233
+ temperature=0.7,
234
+ top_p=0.9,
235
+ pad_token_id=tokenizer.pad_token_id,
236
+ eos_token_id=tokenizer.eos_token_id,
237
+ )
238
+ # ======================================================
239
+ # TRAIN LOOP (BATCHED & OPTIMIZED)
240
+ # ======================================================
241
+ print("Starting RL training 🚀 (BART PPO Optimized)")
242
+
243
+ best_reward = -1e9
244
+ global_ppo_step = 0
245
+ model.train()
246
+
247
+ for epoch in range(1, NUM_EPOCHS + 1):
248
+ epoch_reward_sum = 0
249
+ valid_sql_count = 0
250
+ total_seen = 0
251
+
252
+ for step in range(0, ROLLOUTS_PER_EPOCH, ppo_config.batch_size):
253
+
254
+ batch_prompts = []
255
+ batch_meta = []
256
+
257
+ for _ in range(ppo_config.batch_size):
258
+ example = sample_example()
259
+ question = example["question"]
260
+ gold_sql = example["query"]
261
+ db_id = example["db_id"]
262
+ db_path = get_db_path(db_id)
263
+
264
+ schema = get_db_schema_cached(db_path)
265
+ prompt = build_prompt(question, schema)
266
+
267
+ batch_prompts.append(prompt)
268
+ batch_meta.append((question, gold_sql, db_path, db_id))
269
+
270
+ encoded_inputs = tokenizer(
271
+ batch_prompts,
272
+ return_tensors="pt",
273
+ padding=True,
274
+ truncation=True,
275
+ max_length=512,
276
+ pad_to_multiple_of=8
277
+ ).to(device)
278
+
279
+ query_tensors = [encoded_inputs.input_ids[i] for i in range(ppo_config.batch_size)]
280
+
281
+ # 🎯 BYPASS: Native model.generate to prevent TRL's truncation crash
282
+ with torch.no_grad():
283
+ response_tensors_raw = model.generate(
284
+ input_ids=encoded_inputs.input_ids,
285
+ attention_mask=encoded_inputs.attention_mask,
286
+ **generation_kwargs
287
+ )
288
+
289
+ batch_rewards = []
290
+ batch_responses_text = []
291
+ response_tensors = []
292
+
293
+ for i in range(ppo_config.batch_size):
294
+ resp = response_tensors_raw[i]
295
+
296
+ # 🎯 Strip padding safely so TRL's mask calculation never crashes
297
+ non_pad_mask = resp != tokenizer.pad_token_id
298
+ if non_pad_mask.sum() == 0:
299
+ resp = torch.tensor([tokenizer.eos_token_id], device=device)
300
+ non_pad_mask = resp != tokenizer.pad_token_id
301
+
302
+ valid_len = non_pad_mask.nonzero()[-1].item() + 1
303
+ clean_resp = resp[:valid_len]
304
+ response_tensors.append(clean_resp)
305
+
306
+ response = tokenizer.decode(clean_resp, skip_special_tokens=True)
307
+ batch_responses_text.append(response)
308
+
309
+ question, gold_sql, db_path, db_id = batch_meta[i]
310
+ total_seen += 1
311
+
312
+ if "select" not in response.lower():
313
+ batch_rewards.append(torch.tensor(-1.0, dtype=torch.float32).to(device))
314
+ continue
315
+
316
+ reward = execution_reward(response, db_path, gold_sql)
317
+ if reward is None:
318
+ batch_rewards.append(torch.tensor(-1.0, dtype=torch.float32).to(device))
319
+ continue
320
+
321
+ reward = float(reward)
322
+
323
+ pred_tables = extract_tables(response)
324
+ gold_tables = extract_tables(gold_sql)
325
+ if len(gold_tables) > 0:
326
+ reward += 0.25 * (len(pred_tables & gold_tables) / len(gold_tables))
327
+
328
+ pred_cols = extract_columns(response)
329
+ gold_cols = extract_columns(gold_sql)
330
+ if len(gold_cols) > 0:
331
+ reward += 0.15 * (len(pred_cols & gold_cols) / len(gold_cols))
332
+
333
+ reward = max(-1.0, min(1.0, reward))
334
+ batch_rewards.append(torch.tensor(reward, dtype=torch.float32).to(device))
335
+
336
+ epoch_reward_sum += reward
337
+ valid_sql_count += 1
338
+
339
+ # ---------- PPO UPDATE ----------
340
+ try:
341
+ trainer.step(query_tensors, response_tensors, batch_rewards)
342
+ global_ppo_step += 1
343
+ except Exception as e:
344
+ print("⚠️ PPO skipped:", e)
345
+ continue
346
+
347
+ # ---------- LOG ----------
348
+ if step % (LOG_EVERY * ppo_config.batch_size) == 0 and valid_sql_count > 0:
349
+ print("\n---------------------------")
350
+ print(f"Epoch {epoch}/{NUM_EPOCHS} Step {step}/{ROLLOUTS_PER_EPOCH} | Global Update {global_ppo_step}")
351
+ print("Avg Reward:", round(epoch_reward_sum/valid_sql_count,3))
352
+ print("Valid SQL:", valid_sql_count,"/",total_seen)
353
+
354
+ sample_idx = random.randint(0, ppo_config.batch_size - 1)
355
+ print("DB:", batch_meta[sample_idx][3])
356
+ print("Q:", batch_meta[sample_idx][0])
357
+ print("SQL:", batch_responses_text[sample_idx])
358
+ print("Reward:", round(batch_rewards[sample_idx].item(), 3))
359
+
360
+ # ---------- SAVE ONLY THE BEST MODEL ----------
361
+ avg_reward = epoch_reward_sum / max(valid_sql_count, 1)
362
+
363
+ if avg_reward > best_reward:
364
+ best_reward = avg_reward
365
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
366
+
367
+ model.save_pretrained(OUTPUT_DIR)
368
+ tokenizer.save_pretrained(OUTPUT_DIR)
369
+
370
+ print(f"\n✅ Saved BEST RLHF model for Epoch {epoch} (reward {best_reward:.3f}) at {OUTPUT_DIR}")
src/train_rl_codet5.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =========================================================
2
+ # RLHF TRAINING FOR TEXT2SQL (STABLE PPO VERSION)
3
+ # =========================================================
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from datasets import load_dataset
7
+ from transformers import AutoTokenizer
8
+ from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
9
+ from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
10
+ from peft import PeftModel
11
+ import os, sys, sqlite3, re, random
12
+
13
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
14
+ from execution_reward import execution_reward, extract_tables, extract_columns
15
+ try:
16
+ import sqlparse # gate PPO updates on parsable SQL only
17
+ except Exception: # pragma: no cover
18
+ sqlparse = None
19
+
20
+ # ======================================================
21
+ # DEVICE
22
+ # ======================================================
23
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
24
+ device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
25
+ print("Using device:", device)
26
+
27
+ # ======================================================
28
+ # TRAINING SETTINGS
29
+ # ======================================================
30
+ NUM_EPOCHS = 15
31
+ LOG_EVERY = 20
32
+ USE_SCHEMA = True
33
+ SCHEMA_WARMUP_EPOCHS = 2
34
+ MAX_SCHEMA_CHARS = 1500
35
+ MAX_OUTPUT_TOKENS = 64 # 🚀 Speed up: Reduced max tokens
36
+ ROLLOUTS_PER_EPOCH = 1024
37
+
38
+ # ======================================================
39
+ # PATHS
40
+ # ======================================================
41
+ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
42
+ RL_MODEL_PATH = os.path.join(PROJECT_ROOT, "outputs/rlhf_text2sql")
43
+ output_dir = RL_MODEL_PATH
44
+ DB_ROOT = os.path.join(PROJECT_ROOT, "data/database")
45
+
46
+ # Explicit resume checkpoint
47
+ RESUME_CHECKPOINT = os.path.join(PROJECT_ROOT, "checkpoints/milestone_before_more_dbs")
48
+
49
+ ADAPTER_PATH = os.path.abspath(os.path.join(PROJECT_ROOT, "checkpoints/sft_adapter_codet5"))
50
+ FALLBACK_ADAPTER_PATH = ADAPTER_PATH
51
+ FALLBACK_ADAPTER_PATH_2 = os.path.join(PROJECT_ROOT, "checkpoints")
52
+
53
+ BASE_MODEL = os.environ.get("BASE_MODEL", "Salesforce/codet5-base")
54
+
55
+ # ======================================================
56
+ # LOAD MODEL (LoRA)
57
+ # ======================================================
58
def find_valid_adapter(path_candidates):
    """Return the first directory that holds a LoRA adapter, or None.

    The resume milestone (RESUME_CHECKPOINT) always takes precedence when it
    contains an adapter_config.json, so an interrupted run resumes from it
    regardless of the candidates passed in.
    """
    # 🚀 SAFETY & RESUME: Check for existing milestone first
    if os.path.exists(os.path.join(RESUME_CHECKPOINT, "adapter_config.json")):
        print(f"\n✅ Resuming RL training from checkpoint: {RESUME_CHECKPOINT}\n")
        return RESUME_CHECKPOINT

    for candidate in path_candidates:
        if not candidate:
            continue
        if os.path.exists(os.path.join(candidate, "adapter_config.json")):
            return os.path.abspath(candidate)
    return None
68
+
69
+ print("Loading base:", BASE_MODEL)
70
+
71
+ ADAPTER_PATH = find_valid_adapter([
72
+ ADAPTER_PATH,
73
+ FALLBACK_ADAPTER_PATH,
74
+ FALLBACK_ADAPTER_PATH_2,
75
+ ])
76
+
77
+ if ADAPTER_PATH is None:
78
+ raise RuntimeError("❌ No valid LoRA adapter found!")
79
+
80
+ print("Loading adapter:", ADAPTER_PATH)
81
+
82
+ # ======================================================
83
+ # TOKENIZER
84
+ # ======================================================
85
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
86
+ if tokenizer.pad_token is None:
87
+ tokenizer.pad_token = tokenizer.eos_token
88
+
89
+ # ======================================================
90
+ # LOAD PPO MODEL
91
+ # ======================================================
92
+ model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
93
+ BASE_MODEL,
94
+ torch_dtype=torch.float32
95
+ ).to(device)
96
+
97
+ # 🚀 RESUME: Load adapter dynamically and ensure it's trainable
98
+ model.pretrained_model = PeftModel.from_pretrained(
99
+ model.pretrained_model,
100
+ ADAPTER_PATH,
101
+ is_trainable=True
102
+ )
103
+
104
+ # ======================================================
105
+ # LOAD REFERENCE MODEL (FROZEN)
106
+ # ======================================================
107
+ ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
108
+ BASE_MODEL,
109
+ torch_dtype=torch.float32
110
+ ).to(device)
111
+
112
+ ref_model.pretrained_model = PeftModel.from_pretrained(
113
+ ref_model.pretrained_model,
114
+ ADAPTER_PATH,
115
+ is_trainable=False
116
+ )
117
+
118
+ ref_model.eval()
119
+ for p in ref_model.parameters():
120
+ p.requires_grad = False
121
+
122
+ # ======================================================
123
+ # TRAINABLE PARAMS — ONLY LoRA + VALUE HEAD
124
+ # ======================================================
125
+ for name, p in model.named_parameters():
126
+ if "lora_" in name or "v_head" in name:
127
+ p.requires_grad = True
128
+ else:
129
+ p.requires_grad = False
130
+
131
+ model.train()
132
+
133
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
134
+ total = sum(p.numel() for p in model.parameters())
135
+ print(f"Trainable params: {trainable}/{total} ({100*trainable/total:.2f}%)")
136
+
137
+ model.config.use_cache = False
138
+ ref_model.config.use_cache = False
139
+
140
+ # ======================================================
141
+ # DATASET
142
+ # ======================================================
143
+ print("Loading Spider subset...")
144
+ random.seed(0)
145
+
146
+ TRAIN_DBS = [
147
+ # already trained
148
+ "flight_1","student_assessment","store_1","bike_1","book_2","chinook_1",
149
+ "academic","aircraft","car_1","cinema","club_1","csu_1",
150
+
151
+ # medium difficulty (NEW)
152
+ "college_1","college_2","company_1","company_employee",
153
+ "customer_complaints","department_store","employee_hire_evaluation",
154
+ "museum_visit","products_for_hire","restaurant_1",
155
+ "school_finance","shop_membership","small_bank_1",
156
+ "soccer_1","student_1","tvshow","voter_1","world_1"
157
+ ]
158
+
159
+ dataset = load_dataset("spider", split="train")
160
+ dataset = dataset.filter(lambda x: x["db_id"] in TRAIN_DBS)
161
+
162
def valid_example(x):
    """Keep only examples whose question has 5..40 whitespace-separated tokens."""
    n_words = len(x["question"].split())
    return 5 <= n_words <= 40
164
+
165
+ dataset = dataset.filter(valid_example)
166
+ print("Filtered dataset size:", len(dataset))
167
+
168
+ def sample_example():
169
+ return dataset[random.randrange(len(dataset))]
170
+
171
+ # ======================================================
172
+ # DB UTILITIES
173
+ # ======================================================
174
def get_db_path(db_id):
    """Resolve the on-disk SQLite file for a Spider database id."""
    sqlite_name = f"{db_id}.sqlite"
    return os.path.join(DB_ROOT, db_id, sqlite_name)
176
+
177
# 🚀 SPEED OPTIMIZATION: Cache schema so we don't spam disk IO.
# Maps db_path -> flattened schema string (empty string on failure).
_SCHEMA_CACHE = {}

def get_db_schema_cached(db_path):
    """Return a compact one-line schema description for the SQLite DB at *db_path*.

    Format: "table1(colA, colB) table2(colC) " — one "name(cols) " chunk per
    table, trailing space included. Returns "" when the database cannot be
    opened or introspected; the result (including the empty fallback) is
    memoized in _SCHEMA_CACHE so repeated rollouts skip the disk entirely.
    """
    if db_path in _SCHEMA_CACHE:
        return _SCHEMA_CACHE[db_path]

    schema_text = ""
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        tables = cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()

        for table in tables:
            table_name = table[0]
            # NOTE: table_name is interpolated directly into the PRAGMA; names
            # come from sqlite_master of a local Spider DB, so injection risk
            # is minimal, and PRAGMA does not accept bound parameters anyway.
            columns = cursor.execute(f"PRAGMA table_info({table_name});").fetchall()
            col_names = [col[1] for col in columns]
            schema_text += f"{table_name}({', '.join(col_names)}) "
    except Exception:
        # BUGFIX: was a bare `except:` that also swallowed KeyboardInterrupt /
        # SystemExit; narrowed to Exception. Best-effort: a missing or corrupt
        # DB yields an empty schema rather than aborting training.
        pass
    finally:
        # BUGFIX: the original closed the connection only on the success path,
        # leaking it whenever a PRAGMA raised. Always close.
        if conn is not None:
            conn.close()

    _SCHEMA_CACHE[db_path] = schema_text
    return schema_text
201
+
202
+ # ======================================================
203
+ # PROMPT
204
+ # ======================================================
205
def trim_schema(schema: str, max_chars: int = 1200) -> str:
    """Coerce *schema* to str and truncate it to at most *max_chars* characters.

    None is treated as an empty schema.
    """
    if schema is None:
        return ""
    text = str(schema)
    return text if len(text) <= max_chars else text[:max_chars]
212
+
213
def build_prompt(question: str, schema: str, use_schema: bool) -> str:
    """Assemble the model prompt, optionally prefixed with the trimmed DB schema."""
    if use_schema:
        trimmed = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
        return f"### Database Schema:\n{trimmed}\n### Question:\n{question}\n### SQL:"
    return f"### Question:\n{question}\n### SQL:"
218
+
219
+ # ======================================================
220
+ # PPO CONFIG (STABLE POLICY LEARNING)
221
+ # ======================================================
222
+ ppo_config = PPOConfig(
223
+ learning_rate=5e-6,
224
+ batch_size=8,
225
+ mini_batch_size=2,
226
+
227
+ gradient_accumulation_steps=2,
228
+ ppo_epochs=1,
229
+ init_kl_coef=0.2,
230
+ target_kl=0.02,
231
+ adap_kl_ctrl=True,
232
+ cliprange=0.1,
233
+ cliprange_value=0.1,
234
+ whiten_rewards=False,
235
+ kl_penalty="kl",
236
+ max_grad_norm=0.5,
237
+ )
238
+
239
+ trainer = PPOTrainer(
240
+ config=ppo_config,
241
+ model=model,
242
+ ref_model=ref_model,
243
+ tokenizer=tokenizer,
244
+ )
245
+
246
+ try:
247
+ model.device = torch.device(device)
248
+ except Exception:
249
+ pass
250
+
251
+ # ======================================================
252
+ # GENERATION CONFIG
253
+ # ======================================================
254
+ # 🚀 SPEED OPTIMIZATION: generation limits and randomness bypass
255
+ generation_kwargs = dict(
256
+ max_new_tokens=MAX_OUTPUT_TOKENS,
257
+ do_sample=True, # TRL Requires do_sample=True
258
+ temperature=1.0, # Disabled randomness logic
259
+ top_p=1.0, # Disabled randomness logic
260
+ top_k=0, # Disabled randomness logic
261
+ pad_token_id=tokenizer.pad_token_id,
262
+ eos_token_id=tokenizer.eos_token_id,
263
+ )
264
+
265
+ # ======================================================
266
+ # TRAIN LOOP (BATCHED & OPTIMIZED)
267
+ # ======================================================
268
+ print("Starting RL training 🚀 (CodeT5 PPO Stable)")
269
+
270
+ best_reward = -1e9
271
+ global_ppo_step = 0
272
+ model.train()
273
+
274
+ for epoch in range(1, NUM_EPOCHS + 1):
275
+ epoch_reward_sum = 0
276
+ valid_sql_count = 0
277
+ total_seen = 0
278
+
279
+ # Process in exact chunks matching batch_size to avoid buffer remnants
280
+ for step in range(0, ROLLOUTS_PER_EPOCH, ppo_config.batch_size):
281
+
282
+ batch_prompts = []
283
+ batch_meta = [] # Store tuple of (question, gold_sql, db_path, db_id)
284
+
285
+ # 🚀 BATCH PREPARATION
286
+ for _ in range(ppo_config.batch_size):
287
+ example = sample_example()
288
+ question = example["question"]
289
+ gold_sql = example["query"]
290
+ db_id = example["db_id"]
291
+ db_path = get_db_path(db_id)
292
+
293
+ schema = get_db_schema_cached(db_path)
294
+ prompt = build_prompt(question, schema, use_schema=True)
295
+
296
+ batch_prompts.append(prompt)
297
+ batch_meta.append((question, gold_sql, db_path, db_id))
298
+
299
+ # 🚀 SPEED OPTIMIZATION: Padded Batch Tokenization (Multiple of 8)
300
+ encoded_inputs = tokenizer(
301
+ batch_prompts,
302
+ return_tensors="pt",
303
+ padding=True,
304
+ truncation=True,
305
+ max_length=512,
306
+ pad_to_multiple_of=8
307
+ ).to(device)
308
+
309
+ # TRL expects lists of 1D tensors
310
+ query_tensors = [encoded_inputs.input_ids[i] for i in range(ppo_config.batch_size)]
311
+
312
+ # 🚀 SPEED OPTIMIZATION: Disable gradients for generation pass
313
+ with torch.no_grad():
314
+ response_tensors = trainer.generate(
315
+ query_tensors,
316
+ **generation_kwargs
317
+ )
318
+
319
+ batch_rewards = []
320
+ batch_responses_text = []
321
+
322
+ # 🚀 BATCH SQL REWARD EXECUTION (Strictly CPU strings)
323
+ for i in range(ppo_config.batch_size):
324
+ response = tokenizer.decode(response_tensors[i], skip_special_tokens=True)
325
+ batch_responses_text.append(response)
326
+ question, gold_sql, db_path, db_id = batch_meta[i]
327
+
328
+ total_seen += 1
329
+
330
+ # ---------- BASIC SQL FILTER ----------
331
+ if "select" not in response.lower():
332
+ batch_rewards.append(torch.tensor(-1.0, dtype=torch.float32).to(device))
333
+ continue
334
+
335
+ # ---------- EXECUTION REWARD ----------
336
+ reward = execution_reward(response, db_path, gold_sql)
337
+ if reward is None:
338
+ batch_rewards.append(torch.tensor(-1.0, dtype=torch.float32).to(device))
339
+ continue
340
+
341
+ reward = float(reward)
342
+
343
+ # ---------- TABLE BONUS ----------
344
+ pred_tables = extract_tables(response)
345
+ gold_tables = extract_tables(gold_sql)
346
+ if len(gold_tables) > 0:
347
+ reward += 0.25 * (len(pred_tables & gold_tables) / len(gold_tables))
348
+
349
+ # ---------- COLUMN BONUS ----------
350
+ pred_cols = extract_columns(response)
351
+ gold_cols = extract_columns(gold_sql)
352
+ if len(gold_cols) > 0:
353
+ reward += 0.15 * (len(pred_cols & gold_cols) / len(gold_cols))
354
+
355
+ # ---------- CLAMP ----------
356
+ reward = max(-1.0, min(1.0, reward))
357
+ batch_rewards.append(torch.tensor(reward, dtype=torch.float32).to(device))
358
+
359
+ epoch_reward_sum += reward
360
+ valid_sql_count += 1
361
+
362
+ # ---------- PPO UPDATE ----------
363
+ try:
364
+ trainer.step(
365
+ query_tensors,
366
+ response_tensors,
367
+ batch_rewards
368
+ )
369
+ global_ppo_step += 1
370
+ except Exception as e:
371
+ print("⚠️ PPO skipped:", e)
372
+ continue
373
+
374
+ # 🚀 AUTO CHECKPOINT SAVING: Every 200 PPO Updates
375
+ if global_ppo_step > 0 and global_ppo_step % 200 == 0:
376
+ step_save_path = os.path.join(PROJECT_ROOT, f"checkpoints/rl_step_{global_ppo_step}")
377
+ os.makedirs(step_save_path, exist_ok=True)
378
+
379
+ # Saves ONLY the adapter, keeping disk usage tiny!
380
+ model.save_pretrained(step_save_path)
381
+ tokenizer.save_pretrained(step_save_path)
382
+ print(f"\n💾 [AUTO-SAVE] Checkpoint saved at PPO step {global_ppo_step} -> {step_save_path}")
383
+
384
+ # ---------- LOG ----------
385
+ if step % (LOG_EVERY * ppo_config.batch_size) == 0 and valid_sql_count > 0:
386
+ print("\n---------------------------")
387
+ print(f"Epoch {epoch}/{NUM_EPOCHS} Step {step}/{ROLLOUTS_PER_EPOCH} | Global Update {global_ppo_step}")
388
+ print("Avg Reward:", round(epoch_reward_sum/valid_sql_count,3))
389
+ print("Valid SQL:", valid_sql_count,"/",total_seen)
390
+
391
+ # Print sample from latest batch
392
+ sample_idx = random.randint(0, ppo_config.batch_size - 1)
393
+ print("DB:", batch_meta[sample_idx][3])
394
+ print("Q:", batch_meta[sample_idx][0])
395
+ print("SQL:", batch_responses_text[sample_idx])
396
+ print("Reward:", round(batch_rewards[sample_idx].item(), 3))
397
+
398
+ # ---------- SAVE BEST MODEL (INSIDE EPOCH) ----------
399
+ avg_reward = epoch_reward_sum / max(valid_sql_count, 1)
400
+
401
+ if avg_reward > best_reward:
402
+ best_reward = avg_reward
403
+ save_path = os.path.join(PROJECT_ROOT, "checkpoints/best_rlhf_model")
404
+ os.makedirs(save_path, exist_ok=True)
405
+
406
+ model.save_pretrained(save_path)
407
+ tokenizer.save_pretrained(save_path)
408
+
409
+ print(f"\n✅ Saved BEST RLHF model for Epoch {epoch} (reward {best_reward:.3f})")
src/train_rl_lora.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ======================================
2
+ # RLHF Text2SQL — FINAL WORKING VERSION
3
+ # T5-small + LoRA + PPO + Execution Reward
4
+ # Single-sample stable training (Mac MPS safe)
5
+ # ======================================
6
+
7
+ from execution_reward import execution_reward
8
+ import os, gc, json, random, torch
9
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10
+ from trl import PPOTrainer, PPOConfig
11
+ from trl.models.modeling_value_head import AutoModelForSeq2SeqLMWithValueHead
12
+ from peft import LoraConfig, get_peft_model
13
+
14
+ # ---------------- SETTINGS ----------------
15
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
16
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
17
+
18
+ device = "mps" if torch.backends.mps.is_available() else "cpu"
19
+ print("Using device:", device)
20
+
21
+ os.makedirs("rlhf_text2sql_lora", exist_ok=True)
22
+
23
+ # ---------------- MODEL ----------------
24
+ model_name = "google/flan-t5-small"
25
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
26
+ tokenizer.pad_token = tokenizer.eos_token
27
+
28
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
29
+
30
+ # LoRA
31
+ lora_config = LoraConfig(
32
+ r=8,
33
+ lora_alpha=16,
34
+ target_modules=["q","v"],
35
+ lora_dropout=0.05,
36
+ bias="none",
37
+ task_type="SEQ_2_SEQ_LM",
38
+ )
39
+
40
+ base_model = get_peft_model(base_model, lora_config)
41
+
42
+ model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(base_model).to(device)
43
+ ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_name).to(device)
44
+
45
+ model.config.use_cache = False
46
+ ref_model.config.use_cache = False
47
+
48
+ # ---------------- DATA ----------------
49
+ with open("data/train_spider.json") as f:
50
+ dataset = json.load(f)
51
+
52
def build_prompt(example):
    """Format a Spider example's natural-language question as the instruction prompt."""
    question = example["question"]
    return f"Translate to SQL: {question}"
54
+
55
+ # ---------------- PPO ----------------
56
+ ppo_config = PPOConfig(
57
+ batch_size=1,
58
+ mini_batch_size=1,
59
+ learning_rate=2e-6,
60
+ target_kl=0.05,
61
+ adap_kl_ctrl=True,
62
+ init_kl_coef=0.2,
63
+ )
64
+
65
+ ppo_trainer = PPOTrainer(
66
+ config=ppo_config,
67
+ model=model,
68
+ ref_model=ref_model,
69
+ tokenizer=tokenizer,
70
+ )
71
+
72
+ # ---------------- GENERATION ----------------
73
def generate_sql(query_tensors):
    """Greedily decode SQL responses for each query tensor via the PPO trainer.

    Deterministic decoding (do_sample=False, single beam) is used instead of
    sampling; the comments below indicate sampling was producing NaN issues,
    apparently on the MPS backend. Returns a list of response tensors with
    NaN/Inf values replaced by 0.
    """

    # deterministic decoding = prevents NaN explosion
    with torch.no_grad():
        response_tensors = ppo_trainer.generate(
            query_tensors,
            max_new_tokens=64,

            # 🔴 CRITICAL: disable sampling
            do_sample=False,

            # stable decoding
            num_beams=1,
            early_stopping=True,

            # prevents invalid tokens
            # NOTE(review): pads with eos rather than the tokenizer's pad id —
            # presumably intentional for T5-style models; confirm.
            pad_token_id=tokenizer.eos_token_id,
        )

    # extra safety (important on MPS): scrub NaN/Inf from the returned ids
    cleaned = []
    for t in response_tensors:
        t = torch.nan_to_num(t, nan=0, posinf=0, neginf=0)
        cleaned.append(t)

    return cleaned
99
+
100
+ # ---------------- TRAIN ----------------
101
+ MAX_STEPS = 1200
102
+
103
+ for step in range(MAX_STEPS):
104
+
105
+ # pick random Spider example
106
+ example = random.choice(dataset)
107
+
108
+ question = example["question"]
109
+ gold_sql = example["query"]
110
+ db_id = example["db_id"]
111
+ db_path = f"data/database/{db_id}/{db_id}.sqlite"
112
+
113
+ # tokenize
114
+ enc = tokenizer(build_prompt(example), return_tensors="pt")
115
+ query_tensor = enc.input_ids.to(device)
116
+ query_tensors = [query_tensor[0]]
117
+
118
+ # generate SQL
119
+ response_tensors = generate_sql(query_tensors)
120
+ pred_sql = tokenizer.decode(response_tensors[0], skip_special_tokens=True)
121
+
122
+ # -------- EXECUTION REWARD --------
123
+ reward = execution_reward(pred_sql, gold_sql, db_path)
124
+ reward_tensor = torch.tensor([reward], dtype=torch.float32).to(device)
125
+
126
+ # PPO update
127
+ stats = ppo_trainer.step(query_tensors, response_tensors, [reward_tensor])
128
+
129
+ # stabilize
130
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
131
+
132
+ # cleanup
133
+ del query_tensor, response_tensors, reward_tensor
134
+ gc.collect()
135
+ if device == "mps":
136
+ torch.mps.empty_cache()
137
+
138
+ # log
139
+ if step % 20 == 0:
140
+ print(f"\nStep {step}/{MAX_STEPS}")
141
+ print("DB:", db_id)
142
+ print("Q:", question)
143
+ print("Pred:", pred_sql)
144
+ print("Gold:", gold_sql)
145
+ print("Reward:", reward)
146
+
147
+ # ---------------- SAVE ----------------
148
+ model.save_pretrained("rlhf_text2sql_lora")
149
+ tokenizer.save_pretrained("rlhf_text2sql_lora")
150
+
151
+ print("\nTraining complete — model saved!")
src/train_sft.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import torch
5
+ from datasets import load_dataset
6
+ from peft import LoraConfig, get_peft_model
7
+ from transformers import (
8
+ AutoModelForSeq2SeqLM,
9
+ AutoTokenizer,
10
+ DataCollatorForSeq2Seq,
11
+ Seq2SeqTrainer,
12
+ Seq2SeqTrainingArguments,
13
+ )
14
+
15
+ from prompting import clean_gold_sql, get_schema_text, build_prompt
16
+
17
+ # =====================================================
18
+ # SETTINGS
19
+ # =====================================================
20
+ BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
21
+ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
22
+
23
+ # 🎯 FIXED: Save final model to checkpoints/sft_t5 to protect existing models
24
+ OUT_DIR = os.path.join(PROJECT_ROOT, "checkpoints", "sft_t5")
25
+
26
+ TRAIN_SPLIT = "train[:7000]"
27
+ EPOCHS = 8
28
+ LR = 3e-4
29
+ PER_DEVICE_BATCH = 4
30
+ GRAD_ACCUM = 2
31
+
32
+ MAX_INPUT = 512
33
+ MAX_OUTPUT = 128
34
+
35
+ # =====================================================
36
+ # DEVICE
37
+ # =====================================================
38
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
39
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
40
+ print("Using device:", device)
41
+
42
+ # =====================================================
43
+ # TOKENIZER
44
+ # =====================================================
45
+ print("Loading tokenizer/model:", BASE_MODEL)
46
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
47
+
48
+ if tokenizer.pad_token_id is None:
49
+ tokenizer.pad_token = tokenizer.eos_token
50
+
51
+ # =====================================================
52
+ # PREPROCESS FUNCTION (CRITICAL FIXED VERSION)
53
+ # =====================================================
54
def preprocess_function(example):
    """Tokenize one Spider example into model inputs plus loss-masked labels.

    Builds a schema-augmented prompt for the encoder, tokenizes the cleaned
    gold SQL as the decoder target, and replaces label pad positions with
    -100 so they are ignored by the cross-entropy loss.
    """
    db_id = example["db_id"]
    gold_sql = clean_gold_sql(example["query"])

    # ---- Build Prompt ----
    prompt = build_prompt(
        example["question"],
        db_id,
        schema_text=get_schema_text(db_id),
        training_sql=None,
    )

    model_inputs = tokenizer(
        prompt,
        max_length=MAX_INPUT,
        truncation=True,
        padding="max_length",
    )

    # ---- Target SQL ----
    target_ids = tokenizer(
        gold_sql,
        max_length=MAX_OUTPUT,
        truncation=True,
        padding="max_length",
    )["input_ids"]

    # IMPORTANT: mask padding so it does not contribute to the loss
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [-100 if tok == pad_id else tok for tok in target_ids]
    return model_inputs
+ return model_inputs
87
+
88
+ # =====================================================
89
+ # DATASET
90
+ # =====================================================
91
+ print("Loading Spider subset:", TRAIN_SPLIT)
92
+ dataset = load_dataset("spider", split=TRAIN_SPLIT)
93
+ dataset = dataset.train_test_split(test_size=0.1, seed=42)
94
+
95
+ train_ds = dataset["train"]
96
+ eval_ds = dataset["test"]
97
+
98
+ print("Tokenizing dataset (single process, stable)...")
99
+
100
+ train_tok = train_ds.map(
101
+ preprocess_function,
102
+ batched=False,
103
+ num_proc=1, # 🔥 VERY IMPORTANT FIX
104
+ remove_columns=train_ds.column_names,
105
+ load_from_cache_file=False,
106
+ )
107
+
108
+ eval_tok = eval_ds.map(
109
+ preprocess_function,
110
+ batched=False,
111
+ num_proc=1,
112
+ remove_columns=eval_ds.column_names,
113
+ load_from_cache_file=False,
114
+ )
115
+
116
+ print("Train dataset size:", len(train_tok))
117
+ print("Eval dataset size:", len(eval_tok))
118
+
119
+ # =====================================================
120
+ # MODEL + LoRA
121
+ # =====================================================
122
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
123
+
124
+ base_model.config.use_cache = False
125
+ base_model.gradient_checkpointing_enable()
126
+
127
+ lora_config = LoraConfig(
128
+ r=8,
129
+ lora_alpha=16,
130
+ lora_dropout=0.1,
131
+ bias="none",
132
+ task_type="SEQ_2_SEQ_LM",
133
+ target_modules=["q", "v"], # correct for T5
134
+ )
135
+
136
+ model = get_peft_model(base_model, lora_config)
137
+ model.to(device)
138
+
139
+ # =====================================================
140
+ # TRAINER
141
+ # =====================================================
142
+ data_collator = DataCollatorForSeq2Seq(
143
+ tokenizer=tokenizer,
144
+ model=model,
145
+ padding=True,
146
+ )
147
+
148
+ args = Seq2SeqTrainingArguments(
149
+ # 🎯 FIXED: Changed path to prevent mixing logs with your old CodeT5 logs
150
+ output_dir=os.path.join(PROJECT_ROOT, "checkpoints", "sft_t5_runs"),
151
+ num_train_epochs=EPOCHS,
152
+ learning_rate=LR,
153
+ per_device_train_batch_size=PER_DEVICE_BATCH,
154
+ per_device_eval_batch_size=PER_DEVICE_BATCH,
155
+ gradient_accumulation_steps=GRAD_ACCUM,
156
+ dataloader_num_workers=0,
157
+ dataloader_pin_memory=False,
158
+ evaluation_strategy="epoch",
159
+
160
+ # 🎯 FIXED: "no" completely stops intermediate saving! Only the final model will be saved.
161
+ save_strategy="no",
162
+
163
+ logging_steps=50,
164
+ report_to=[],
165
+ fp16=False,
166
+ bf16=False,
167
+ predict_with_generate=True,
168
+ )
169
+
170
+ trainer = Seq2SeqTrainer(
171
+ model=model,
172
+ args=args,
173
+ train_dataset=train_tok,
174
+ eval_dataset=eval_tok,
175
+ tokenizer=tokenizer,
176
+ data_collator=data_collator,
177
+ )
178
+
179
+ # =====================================================
180
+ # TRAIN
181
+ # =====================================================
182
+ trainer.train()
183
+
184
+ # =====================================================
185
+ # SAVE
186
+ # =====================================================
187
+ print("Saving LoRA adapter to:", OUT_DIR)
188
+ os.makedirs(OUT_DIR, exist_ok=True)
189
+ model.save_pretrained(OUT_DIR)
190
+ tokenizer.save_pretrained(OUT_DIR)
191
+
192
+ print("DONE ✔ SFT warmup finished")
src/train_sft_bart.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import torch
5
+ from datasets import load_dataset
6
+ from peft import LoraConfig, get_peft_model
7
+ from transformers import (
8
+ AutoModelForSeq2SeqLM,
9
+ AutoTokenizer,
10
+ DataCollatorForSeq2Seq,
11
+ Seq2SeqTrainer,
12
+ Seq2SeqTrainingArguments,
13
+ )
14
+
15
+ from prompting import clean_gold_sql, get_schema_text, build_prompt
16
+
17
+ # =====================================================
18
+ # SETTINGS
19
+ # =====================================================
20
+ BASE_MODEL = os.environ.get("BASE_MODEL", "facebook/bart-base")
21
+ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
22
+
23
+ OUT_DIR = os.path.join(PROJECT_ROOT, "checkpoints", "sft_best_bart_2")
24
+
25
+ TRAIN_SPLIT = "train[:7000]"
26
+
27
+ EPOCHS = 12
28
+ LR = 3e-4
29
+ PER_DEVICE_BATCH = 16
30
+ GRAD_ACCUM = 4
31
+
32
+ MAX_INPUT = 512
33
+ MAX_OUTPUT = 128
34
+
35
+ # =====================================================
36
+ # DEVICE
37
+ # =====================================================
38
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
39
+ device = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu"))
40
+ print("Using device:", device)
41
+
42
+ # =====================================================
43
+ # TOKENIZER
44
+ # =====================================================
45
+ print("Loading tokenizer/model:", BASE_MODEL)
46
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
47
+
48
+ if tokenizer.pad_token_id is None:
49
+ tokenizer.pad_token = tokenizer.eos_token
50
+
51
+ # =====================================================
52
+ # PREPROCESS FUNCTION
53
+ # =====================================================
54
def preprocess_function(example):
    """Tokenize one Spider example into BART model inputs plus loss-masked labels.

    Builds a schema-augmented prompt for the encoder and tokenizes the cleaned
    gold SQL as the target; label pad positions are replaced with -100 so the
    loss ignores them.
    """

    question = example["question"]
    db_id = example["db_id"]
    gold_sql = clean_gold_sql(example["query"])

    # ---- Build Prompt ----
    schema_text = get_schema_text(db_id)
    prompt = build_prompt(question, db_id, schema_text=schema_text, training_sql=None)

    model_inputs = tokenizer(
        prompt,
        max_length=MAX_INPUT,
        truncation=True,
        padding="max_length",
    )

    # ---- Target SQL ----
    labels = tokenizer(
        gold_sql,
        max_length=MAX_OUTPUT,
        truncation=True,
        padding="max_length",
    )["input_ids"]

    # BUGFIX: ignore padding in loss with -100 — the ignore_index recognized by
    # PyTorch CrossEntropyLoss and HF label smoothing (and what the sibling
    # train_sft.py uses). The previous value of -400 is not an ignore marker
    # and would be passed to the loss as an invalid label id.
    labels = [
        (tok if tok != tokenizer.pad_token_id else -100)
        for tok in labels
    ]

    model_inputs["labels"] = labels
    return model_inputs
87
+
88
+ # =====================================================
89
+ # DATASET
90
+ # =====================================================
91
+ print("Loading Spider subset:", TRAIN_SPLIT)
92
+ dataset = load_dataset("spider", split=TRAIN_SPLIT)
93
+ dataset = dataset.train_test_split(test_size=0.1, seed=42)
94
+
95
+ train_ds = dataset["train"]
96
+ eval_ds = dataset["test"]
97
+
98
+ print("Tokenizing dataset (single process, stable)...")
99
+
100
+ train_tok = train_ds.map(
101
+ preprocess_function,
102
+ batched=False,
103
+ num_proc=1,
104
+ remove_columns=train_ds.column_names,
105
+ load_from_cache_file=False,
106
+ )
107
+
108
+ eval_tok = eval_ds.map(
109
+ preprocess_function,
110
+ batched=False,
111
+ num_proc=1,
112
+ remove_columns=eval_ds.column_names,
113
+ load_from_cache_file=False,
114
+ )
115
+
116
+ print("Train dataset size:", len(train_tok))
117
+ print("Eval dataset size:", len(eval_tok))
118
+
119
+ # =====================================================
120
+ # MODEL + LoRA
121
+ # =====================================================
122
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
123
+
124
+ base_model.config.use_cache = False
125
+
126
+ # 🚀 UPGRADE 1: Expanded LoRA brainpower
127
+ lora_config = LoraConfig(
128
+ r=16, # Increased rank for more learning capacity
129
+ lora_alpha=32, # Alpha is typically 2x the rank
130
+ lora_dropout=0.1,
131
+ bias="none",
132
+ task_type="SEQ_2_SEQ_LM",
133
+ # Target all attention and dense layers in BART
134
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
135
+ )
136
+
137
+ model = get_peft_model(base_model, lora_config)
138
+ model.to(device)
139
+
140
+ # =====================================================
141
+ # TRAINER
142
+ # =====================================================
143
+ data_collator = DataCollatorForSeq2Seq(
144
+ tokenizer=tokenizer,
145
+ model=model,
146
+ padding=True,
147
+ )
148
+
149
+ args = Seq2SeqTrainingArguments(
150
+ output_dir=os.path.join(PROJECT_ROOT, "checkpoints", "sft_bart_runs"),
151
+ num_train_epochs=EPOCHS,
152
+ learning_rate=LR,
153
+ per_device_train_batch_size=PER_DEVICE_BATCH,
154
+ per_device_eval_batch_size=PER_DEVICE_BATCH,
155
+ gradient_accumulation_steps=GRAD_ACCUM,
156
+ dataloader_num_workers=0,
157
+ dataloader_pin_memory=False,
158
+
159
+ # 🚀 UPGRADE 2 & 3: Better optimization & generalization
160
+ warmup_ratio=0.05, # Slowly ramp up learning rate
161
+ weight_decay=0.01, # Penalize over-reliance on single tokens
162
+ label_smoothing_factor=0.1, # Prevent overconfidence in SQL token matching
163
+
164
+ evaluation_strategy="epoch",
165
+ save_strategy="epoch",
166
+
167
+ save_total_limit=1,
168
+ load_best_model_at_end=True,
169
+ metric_for_best_model="eval_loss",
170
+ greater_is_better=False,
171
+
172
+ logging_steps=50,
173
+ report_to=[],
174
+ fp16=False,
175
+ bf16=False,
176
+ predict_with_generate=True,
177
+ )
178
+
179
+ trainer = Seq2SeqTrainer(
180
+ model=model,
181
+ args=args,
182
+ train_dataset=train_tok,
183
+ eval_dataset=eval_tok,
184
+ tokenizer=tokenizer,
185
+ data_collator=data_collator,
186
+ )
187
+
188
+ # =====================================================
189
+ # TRAIN
190
+ # =====================================================
191
+ trainer.train()
192
+
193
+ # =====================================================
194
+ # SAVE BEST MODEL
195
+ # =====================================================
196
+ print("Saving best BART LoRA adapter to:", OUT_DIR)
197
+ os.makedirs(OUT_DIR, exist_ok=True)
198
+
199
+ trainer.model.save_pretrained(OUT_DIR)
200
+ tokenizer.save_pretrained(OUT_DIR)
201
+
202
+ print("DONE ✔ SFT BART finished")
src/train_sft_codet5.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import torch
5
+ from datasets import load_dataset
6
+ from peft import LoraConfig, get_peft_model
7
+ from transformers import (
8
+ AutoModelForSeq2SeqLM,
9
+ AutoTokenizer,
10
+ DataCollatorForSeq2Seq,
11
+ Seq2SeqTrainer,
12
+ Seq2SeqTrainingArguments,
13
+ )
14
+
15
+ from prompting import clean_gold_sql, get_schema_text, build_prompt
16
+
17
# ---------------------------------------------------------------------------
# Hyper-parameters and paths
# ---------------------------------------------------------------------------
BASE_MODEL = "Salesforce/codet5-base"

# Project root = parent of the directory containing this script (src/..).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
OUT_DIR = os.path.join(PROJECT_ROOT, "checkpoints", "sft_adapter_codet5")

TRAIN_SPLIT = "train[:7000]"  # Spider subset used for fine-tuning
EPOCHS = 10
LR = 2e-4
PER_DEVICE_BATCH = 2  # CodeT5 is bigger, so the per-device batch is reduced
GRAD_ACCUM = 4        # effective batch = PER_DEVICE_BATCH * GRAD_ACCUM

MAX_INPUT = 512   # token budget for the prompt (question + schema)
MAX_OUTPUT = 160  # token budget for the generated SQL

# ---------------------------------------------------------------------------
# Device selection: Apple-Silicon MPS when available, otherwise CPU
# ---------------------------------------------------------------------------
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence HF tokenizers fork warning

if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)

# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
print("Loading tokenizer/model:", BASE_MODEL)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Some checkpoints ship without a pad token; fall back to EOS so padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
49
# ---------------------------------------------------------------------------
# Preprocessing: Spider example -> (prompt tokens, SQL label tokens)
# ---------------------------------------------------------------------------
def preprocess_function(example):
    """Convert one Spider row into seq2seq model features.

    Builds the schema-aware prompt for the question, tokenizes it as the
    encoder input, and tokenizes the cleaned gold SQL as the decoder labels.

    Sequences are only truncated here, NOT padded: ``DataCollatorForSeq2Seq``
    (configured with ``padding=True``) pads each batch dynamically and fills
    label padding with -100 (ignored by the loss). This replaces the previous
    fixed ``padding="max_length"`` approach, which padded every example to the
    maximum length and required manually masking pad tokens to -100.

    Args:
        example: dict with "question", "db_id" and "query" keys.

    Returns:
        dict with "input_ids", "attention_mask" and "labels".
    """
    question = example["question"]
    db_id = example["db_id"]
    gold_sql = clean_gold_sql(example["query"])

    # Schema text is looked up per database and embedded in the prompt.
    schema_text = get_schema_text(db_id)
    prompt = build_prompt(question, db_id, schema_text=schema_text, training_sql=None)

    model_inputs = tokenizer(
        prompt,
        max_length=MAX_INPUT,
        truncation=True,
    )

    # No pad tokens are present (no padding above), so the label ids can be
    # used as-is; the collator masks label padding at batch time.
    labels = tokenizer(
        gold_sql,
        max_length=MAX_OUTPUT,
        truncation=True,
    )["input_ids"]

    model_inputs["labels"] = labels
    return model_inputs
79
+
80
+ # =====================================================
81
+ # DATASET
82
+ # =====================================================
83
+ print("Loading Spider subset:", TRAIN_SPLIT)
84
+ dataset = load_dataset("spider", split=TRAIN_SPLIT)
85
+ dataset = dataset.train_test_split(test_size=0.1, seed=42)
86
+
87
+ train_ds = dataset["train"]
88
+ eval_ds = dataset["test"]
89
+
90
+ print("Tokenizing dataset...")
91
+
92
+ train_tok = train_ds.map(
93
+ preprocess_function,
94
+ batched=False,
95
+ num_proc=1,
96
+ remove_columns=train_ds.column_names,
97
+ load_from_cache_file=False,
98
+ )
99
+
100
+ eval_tok = eval_ds.map(
101
+ preprocess_function,
102
+ batched=False,
103
+ num_proc=1,
104
+ remove_columns=eval_ds.column_names,
105
+ load_from_cache_file=False,
106
+ )
107
+
108
+ print("Train dataset size:", len(train_tok))
109
+ print("Eval dataset size:", len(eval_tok))
110
+
111
+ # =====================================================
112
+ # MODEL + LoRA (CODET5 FIXED)
113
+ # =====================================================
114
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
115
+
116
+ base_model.config.use_cache = False
117
+ base_model.gradient_checkpointing_enable()
118
+
119
+ # 🔥 DIFFERENT FROM T5
120
+ lora_config = LoraConfig(
121
+ r=16,
122
+ lora_alpha=32,
123
+ lora_dropout=0.05,
124
+ bias="none",
125
+ task_type="SEQ_2_SEQ_LM",
126
+ target_modules=["q", "v"], # IMPORTANT FOR CODET5
127
+ )
128
+
129
+ model = get_peft_model(base_model, lora_config)
130
+ model.to(device)
131
+
132
+ model.print_trainable_parameters()
133
+
134
+ # =====================================================
135
+ # TRAINER
136
+ # =====================================================
137
+ data_collator = DataCollatorForSeq2Seq(
138
+ tokenizer=tokenizer,
139
+ model=model,
140
+ padding=True,
141
+ )
142
+
143
+ args = Seq2SeqTrainingArguments(
144
+ output_dir=os.path.join(PROJECT_ROOT, "checkpoints", "sft_runs_codet5"),
145
+ num_train_epochs=EPOCHS,
146
+ learning_rate=LR,
147
+ per_device_train_batch_size=PER_DEVICE_BATCH,
148
+ per_device_eval_batch_size=PER_DEVICE_BATCH,
149
+ gradient_accumulation_steps=GRAD_ACCUM,
150
+ dataloader_num_workers=0,
151
+ dataloader_pin_memory=False,
152
+ evaluation_strategy="epoch",
153
+ save_strategy="epoch",
154
+ save_total_limit=1,
155
+ logging_steps=50,
156
+ report_to=[],
157
+ fp16=False,
158
+ bf16=False,
159
+ predict_with_generate=True,
160
+ )
161
+
162
+ trainer = Seq2SeqTrainer(
163
+ model=model,
164
+ args=args,
165
+ train_dataset=train_tok,
166
+ eval_dataset=eval_tok,
167
+ tokenizer=tokenizer,
168
+ data_collator=data_collator,
169
+ )
170
+
171
+ # =====================================================
172
+ # TRAIN
173
+ # =====================================================
174
+ trainer.train()
175
+
176
+ # =====================================================
177
+ # SAVE
178
+ # =====================================================
179
+ # =====================================================
180
+ # SAVE (SAFE PEFT SAVE)
181
+ # =====================================================
182
+ print("Saving LoRA adapter to:", OUT_DIR)
183
+ os.makedirs(OUT_DIR, exist_ok=True)
184
+
185
+ # unwrap trainer model (important!)
186
+ peft_model = trainer.model
187
+
188
+ # ensure on cpu before saving (mac mps bug fix)
189
+ peft_model = peft_model.to("cpu")
190
+
191
+ # save adapter only
192
+ peft_model.save_pretrained(OUT_DIR)
193
+ tokenizer.save_pretrained(OUT_DIR)
194
+
195
+ print("DONE ✔ CodeT5 SFT finished")