tjhalanigrid commited on
Commit
862cf68
·
1 Parent(s): 14b487f
README.md CHANGED
@@ -1,13 +1,14 @@
1
  ---
2
  title: Text2sql Demo
3
- emoji: 📊
4
- colorFrom: blue
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.8.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- python_version: 3.10.13
12
- short_description: 'Text to SQL with RLHF'
13
- ---
 
 
1
  ---
2
  title: Text2sql Demo
3
+ emoji: 🐨
4
+ colorFrom: yellow
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 6.8.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: 'to show the gradio interface '
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- GRADIO DEMO UI - LAZY LOADING EDITION
3
  NL → SQL → Result Table
4
  """
5
 
@@ -7,563 +7,273 @@ import gradio as gr
7
  import pandas as pd
8
  import re
9
  import time
10
- import os
11
- import torch
12
- import sys
13
- import json
14
- import subprocess
15
- import base64
16
- import io
17
- from pathlib import Path
18
- from typing import Iterator
19
-
20
- # ==========================================
21
- # RELATIVE PATH RESOLUTION (GLOBAL)
22
- # ==========================================
23
- try:
24
- PROJECT_ROOT = Path(__file__).resolve().parent
25
- except NameError:
26
- PROJECT_ROOT = Path(".").resolve()
27
-
28
- if (PROJECT_ROOT / "data" / "database").exists():
29
- DB_ROOT = PROJECT_ROOT / "data" / "database"
30
- else:
31
- DB_ROOT = PROJECT_ROOT / "final_databases"
32
-
33
- def get_db_path(db_id: str) -> str:
34
- path1 = DB_ROOT / db_id / f"{db_id}.sqlite"
35
- path2 = DB_ROOT / f"{db_id}.sqlite"
36
- return str(path1) if path1.exists() else str(path2)
37
-
38
- # ==========================================
39
- # 🔥 CUDA MOCK PATCH FOR MAC (MPS) / CPU
40
- # ==========================================
41
- if not torch.cuda.is_available():
42
- class MockCUDAEvent:
43
- def __init__(self, enable_timing=False, blocking=False, interprocess=False):
44
- self.t = 0.0
45
- def record(self, stream=None):
46
- self.t = time.perf_counter()
47
- def elapsed_time(self, end_event):
48
- return (end_event.t - self.t) * 1000.0
49
-
50
- torch.cuda.Event = MockCUDAEvent
51
- if not hasattr(torch.cuda, 'synchronize'):
52
- torch.cuda.synchronize = lambda: None
53
-
54
- # ==========================================
55
- # IMPORTS & ENGINE SETUP
56
- # ==========================================
57
- from src.quantized_text2sql_engine import QuantizedText2SQLEngine
58
- from src.schema_encoder import SchemaEncoder
59
-
60
- DEFAULT_QUANT_ARTIFACT = str(PROJECT_ROOT / "int8_dynamic")
61
-
62
- _ENGINE_CACHE = {}
63
- _QUERY_LOG = []
64
- _PERF_LOG = []
65
- _SUCCESS_LOG = []
66
-
67
- _OP_STATS = {
68
- "SELECT": {"ok": 0, "fail": 0}, "WHERE": {"ok": 0, "fail": 0}, "JOIN": {"ok": 0, "fail": 0},
69
- "GROUP_BY": {"ok": 0, "fail": 0}, "ORDER_BY": {"ok": 0, "fail": 0}, "HAVING": {"ok": 0, "fail": 0}, "LIMIT": {"ok": 0, "fail": 0},
70
- }
71
-
72
- def get_quant_engine(artifact_dir: str, use_constrained: bool = False, exec_workers: int = 8, use_cache: bool = True):
73
- key = (artifact_dir, bool(use_constrained), int(exec_workers), bool(use_cache))
74
- if key not in _ENGINE_CACHE:
75
- try:
76
- _ENGINE_CACHE[key] = QuantizedText2SQLEngine(artifact_dir, device="cpu", use_constrained=bool(use_constrained), exec_workers=int(exec_workers), use_cache=bool(use_cache))
77
- except TypeError:
78
- _ENGINE_CACHE[key] = QuantizedText2SQLEngine(artifact_dir)
79
- return _ENGINE_CACHE[key]
80
-
81
- # 🚨 LAZY LOADING: We DO NOT load the model here! We only load the fast Schema Encoder.
82
- quant_engine = None
83
- try:
84
- schema_encoder = SchemaEncoder(DB_ROOT)
85
- except Exception as e:
86
- print(f"Warning: SchemaEncoder failed to load: {e}")
87
- schema_encoder = None
88
 
 
 
 
 
 
89
  SAMPLES = [
90
- ("Show 10 distinct employee first names.", "chinook_1"), ("Which artist has the most albums?", "chinook_1"),
91
- ("List all the tracks that belong to the 'Rock' genre.", "chinook_1"), ("What are the names of all the cities?", "flight_1"),
92
- ("Find the flight number and cost of the cheapest flight.", "flight_1"), ("List the airlines that fly out of New York.", "flight_1"),
93
- ("Which campus was opened between 1935 and 1939?", "csu_1"), ("Count the number of students in each department.", "college_2"),
94
- ("List the names of all clubs.", "club_1"), ("How many members does each club have?", "club_1"),
95
- ("Show the names of all cinemas.", "cinema"), ("Which cinema has the most screens?", "cinema")
 
 
 
 
 
 
96
  ]
 
97
  SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
98
 
 
 
 
99
  def explain_sql(sql):
100
- if not sql: return ""
101
  explanation = "This SQL query retrieves information from the database."
102
  sql_lower = sql.lower()
103
- if "join" in sql_lower: explanation += "\n• It combines data from multiple tables using JOIN."
104
- if "where" in sql_lower: explanation += "\n• It filters rows using a WHERE condition."
105
- if "group by" in sql_lower: explanation += "\n• It groups results using GROUP BY."
106
- if "order by" in sql_lower: explanation += "\n• It sorts the results using ORDER BY."
107
- if "limit" in sql_lower: explanation += "\n• It limits the number of returned rows."
 
 
 
 
 
 
 
108
  return explanation
109
 
110
- def sql_ops(sql: str) -> list[str]:
111
- s = (sql or "").lower()
112
- ops = ["SELECT"]
113
- if " where " in f" {s} ": ops.append("WHERE")
114
- if " join " in f" {s} ": ops.append("JOIN")
115
- if " group by " in f" {s} ": ops.append("GROUP_BY")
116
- if " order by " in f" {s} ": ops.append("ORDER_BY")
117
- if " having " in f" {s} ": ops.append("HAVING")
118
- if " limit " in f" {s} ": ops.append("LIMIT")
119
- return ops
120
-
121
- def classify_error(sql: str, error_msg: str | None = None, *, timed_out: bool = False):
122
- s = (sql or "").lower()
123
- m = (error_msg or "").lower()
124
- if timed_out or "interrupted" in m or "timeout" in m: return "timeout"
125
- if not s.strip().startswith(("select", "with")): return "syntax_error"
126
- if " join " in f" {s} " and " on " not in f" {s} ": return "missing_join"
127
- if " where " in f" {s} " and not any(op in s for op in ["=", ">", "<", " in ", " like ", " between ", " is null", " is not null"]): return "wrong_where"
128
- if ("is null" in s or "is not null" in s) and ("no such column" in m or "misuse" in m): return "null_handling"
129
- if "no such table" in m: return "missing_table"
130
- if "no such column" in m: return "missing_column"
131
- if "ambiguous column name" in m: return "ambiguous_column"
132
- if "datatype mismatch" in m or "type mismatch" in m: return "type_mismatch"
133
- if "misuse of aggregate" in m or "misuse of aggregate function" in m: return "wrong_aggregation"
134
- if "syntax error" in m: return "syntax_error"
135
- if "near" in m and "syntax error" in m: return "syntax_error"
136
- if "runtime" in m or "constraint failed" in m: return "runtime_error"
137
- return "other"
138
-
139
- def get_hint(error_type):
140
- hints = {
141
- "missing_join": "Check JOIN conditions between tables.", "wrong_aggregation": "Use proper aggregation like avg(column).",
142
- "wrong_where": "Check WHERE condition syntax.", "syntax_error": "Ensure SQL starts with SELECT.",
143
- "missing_table": "Use only tables from the provided schema.", "missing_column": "Use only columns from the provided schema.",
144
- "ambiguous_column": "Disambiguate by using table.column.", "timeout": "Query took too long; simplify joins.", "other": "Review SQL logic."
145
- }
146
- return hints.get(error_type, "Review query.")
147
-
148
- def is_relevant_to_schema(question, db_id):
149
- if schema_encoder is None: return True
150
- try: raw_schema = schema_encoder.structured_schema(db_id).lower()
151
- except: return True
152
- schema_words = set(re.findall(r'[a-z0-9_]+', raw_schema))
153
- q_words = re.findall(r'[a-z0-9_]+', question.lower())
154
- stop_words = {"show", "list", "all", "what", "is", "the", "how", "many", "count", "find", "get", "me", "a", "an", "of", "in", "for", "from", "with", "which", "are", "there", "give", "tell", "details", "info", "data", "everything"}
155
- meaningful_q_words = [w for w in q_words if w not in stop_words and not w.isdigit()]
156
- if not meaningful_q_words: return True
157
- for word in meaningful_q_words:
158
- singular_word = word[:-1] if word.endswith('s') else word
159
- if word in schema_words or singular_word in schema_words: return True
160
- return False
161
 
 
 
 
162
  def run_query(method, sample_q, custom_q, db_id):
163
- global quant_engine
164
 
165
- # 🚨 LAZY LOADING: We load the heavy AI model ONLY when the button is clicked.
166
- if quant_engine is None:
167
- print(f"First request detected! Loading AI model from {DEFAULT_QUANT_ARTIFACT}...", flush=True)
168
- try:
169
- quant_engine = get_quant_engine(DEFAULT_QUANT_ARTIFACT, use_constrained=False, exec_workers=8, use_cache=True)
170
- if quant_engine is None:
171
- return "-- ❌ ENGINE CRASH", pd.DataFrame(columns=["Error"]), "Failed to load model. Did you move the tokenizer files and add config.json to int8_dynamic/?"
172
- except Exception as e:
173
- return f"-- ❌ ENGINE CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"Critical failure loading model: {e}"
174
-
175
- def _log(error_type: str, *, question: str, db_id_val: str, sql: str = "", error_msg: str = "") -> None:
176
- _QUERY_LOG.append({"t": time.time(), "db_id": str(db_id_val), "question": str(question), "sql": str(sql), "error_type": str(error_type), "error_msg": str(error_msg)})
177
-
178
- def _perf_log(payload: dict) -> None:
179
- _PERF_LOG.append(payload)
180
- if len(_PERF_LOG) > 1000: del _PERF_LOG[:200]
181
-
182
- raw_question = sample_q if method == "💡 Pick a Sample" else custom_q
183
-
184
- if not raw_question or str(raw_question).strip() == "":
185
- return "-- No input provided", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a question."
186
- if not db_id or str(db_id).strip() == "":
187
- return "-- No database selected", pd.DataFrame(columns=["Warning"]), "⚠️ Please select a database."
188
-
189
- typo_corrections = [(r'\bshaw\b', 'show'), (r'\bshw\b', 'show'), (r'\bsho\b', 'show'), (r'\blsit\b', 'list'), (r'\blis\b', 'list'), (r'\bfidn\b', 'find'), (r'\bfnd\b', 'find'), (r'\bgte\b', 'get')]
190
- question = str(raw_question)
191
- for bad, good in typo_corrections: question = re.sub(bad, good, question, flags=re.IGNORECASE)
192
- q_lower = question.strip().lower()
193
 
194
- if len(q_lower.split()) < 2:
195
- _log("gibberish", question=question, db_id_val=str(db_id), error_msg="gibberish filtered")
196
- return "-- Input Blocked", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a clear, meaningful natural language question (more than one word)."
197
-
198
- if re.search(r'\b(delete|update|insert|drop|alter|truncate)\b', q_lower):
199
- _log("blocked_dml", question=question, db_id_val=str(db_id), error_msg="DML blocked")
200
- return "-- ❌ BLOCKED: Data Modification", pd.DataFrame(columns=["Security Alert"]), "🛑 Security Alert: Modifying or deleting data is strictly prohibited."
201
-
202
- if not is_relevant_to_schema(question, db_id):
203
- _log("out_of_domain", question=question, db_id_val=str(db_id), error_msg="out of domain")
204
- return "-- ❌ BLOCKED: Out of Domain", pd.DataFrame(columns=["Domain Alert"]), f"🛑 Relevance Alert: I don't see anything related to your question in the '{db_id}' schema."
205
 
206
  start_time = time.time()
207
- t0 = time.perf_counter()
208
- ui_warnings = ""
209
 
 
210
  try:
211
- try:
212
- result = quant_engine.ask(question, str(db_id), num_beams=4, max_new_tokens=120, timeout_s=2.0)
213
- except TypeError:
214
- result = quant_engine.ask(question, str(db_id))
215
  except Exception as e:
216
- _log("backend_crash", question=question, db_id_val=str(db_id), error_msg=str(e))
217
- return f"-- ❌ BACKEND CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"❌ CRITICAL BACKEND CRASH:\n{str(e)}"
218
-
219
- final_sql = str(result.get("sql", ""))
220
- model_sql = final_sql
221
-
222
- num_match = re.search(r'\b(?:show|list|top|limit|get|first|last|sample|of)\s+(?:[a-zA-Z_]+\s+)?(\d+)\b', q_lower)
223
- if not num_match and q_lower.startswith(("show", "list", "get")):
224
- num_match = re.search(r'\b(\d+)\b', q_lower)
225
-
226
- if num_match and final_sql:
227
- limit_val = num_match.group(1)
228
- final_sql = re.sub(rf"(?i)\s*(?:where|having|and)?\s*count\s*\(\s*\*\s*\)\s*=\s*{limit_val}", "", final_sql)
229
- final_sql = re.sub(rf"(?i)\s*(?:where|and)\s+[a-zA-Z0-9_.]+\s*=\s*['\"]?{limit_val}['\"]?", "", final_sql)
230
- final_sql = re.sub(r"(?i)\s*where\s*$", "", final_sql)
231
- final_sql = re.sub(r"(?i)\s*where\s+(group by|order by|limit)", r" \1", final_sql)
232
-
233
- agg_kws = ["most", "top", "highest", "lowest", "count", "many", "group", "frequent", "popular"]
234
- if not any(k in q_lower for k in agg_kws):
235
- final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
236
- final_sql = re.sub(r"(?i)\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
237
- final_sql = re.sub(r"(?i),\s*count\(\*\)", "", final_sql)
238
- final_sql = re.sub(r"(?i)count\(\*\)\s*,", "", final_sql)
239
-
240
- if "group by" in final_sql.lower() and not re.search(r'(?i)\b(count|sum|avg|max|min)\b\(', final_sql):
241
- final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+", "", final_sql)
242
-
243
- if "limit" not in final_sql.lower():
244
- final_sql = f"{final_sql.strip().rstrip(';')} LIMIT {limit_val}"
245
-
246
- # Execution
247
- from src.sql_validator import validate_sql_schema
248
- db_path = get_db_path(str(db_id))
249
-
250
- try: strict_valid, _ = validate_sql_schema(final_sql, db_path)
251
- except Exception: strict_valid = False
252
 
253
- error_msg = None
254
- rows, cols = [], []
255
- sqlite_success = False
 
256
 
257
- try:
258
- rows, cols = quant_engine._execute_one(final_sql, db_path, timeout_s=2.0)
259
- sqlite_success = True
260
- except Exception as e:
261
- error_msg = str(e)
262
- sqlite_success = False
263
-
264
- if not sqlite_success and model_sql and model_sql != final_sql:
265
- try:
266
- alt_rows, alt_cols = quant_engine._execute_one(model_sql, db_path, timeout_s=2.0)
267
- final_sql = model_sql
268
- rows, cols = alt_rows, alt_cols
269
- sqlite_success = True
270
- error_msg = None
271
- except Exception: pass
272
-
273
- valid = sqlite_success
274
-
275
- if error_msg or not valid:
276
- et = classify_error(final_sql, str(error_msg or ""), timed_out=("interrupted" in str(error_msg or "").lower()))
277
- _log(et, question=str(question), db_id_val=str(db_id), sql=str(final_sql), error_msg=str(error_msg or "Execution failed"))
278
-
279
- latency = round(time.time() - start_time, 3)
280
- t1 = time.perf_counter()
281
-
282
- engine_stats_after = quant_engine.stats() if hasattr(quant_engine, 'stats') else {}
283
-
284
- perf = {
285
- "db_id": str(db_id), "use_constrained_decoding": False, "num_beams": 4,
286
- "latency_total_ms": round((t1 - t0) * 1000.0, 2), "constraint_ok": bool(strict_valid), "has_error": bool(error_msg),
287
- "exec_cache_hit_rate": float(engine_stats_after.get("exec_cache_hit_rate", 0.0) or 0.0),
288
- }
289
- _perf_log(perf)
290
-
291
- window = _PERF_LOG[-50:]
292
- avg_ms = sum(float(x.get("latency_total_ms", 0.0) or 0.0) for x in window) / len(window) if window else 0.0
293
- constraint_rate = sum(1 for x in window if x.get("constraint_ok")) / len(window) if window else 0.0
294
-
295
- perf_block = (
296
- "\n\n---\nPerformance (task impact)\n"
297
- f"- Total latency (ms): {perf['latency_total_ms']}\n"
298
- f"- Strict Python Validator OK (Task 3): {perf['constraint_ok']}\n"
299
- f"- Exec cache hit-rate (Task 1/5): {round(perf['exec_cache_hit_rate'], 3)}\n"
300
- f"- Rolling avg latency last 50 (ms): {round(avg_ms, 2)}\n"
301
- f"- Rolling constraint rate last 50: {round(constraint_rate, 3)}\n"
302
- )
303
 
304
- if error_msg or not valid:
305
- display_sql = final_sql if final_sql.strip() else "-- ❌ INVALID SQL"
306
- explanation = f"{ui_warnings}Error Details:\n\n"
307
- if error_msg: explanation += f"{error_msg}\n\n"
308
-
309
- error_type = classify_error(final_sql, str(error_msg or ""))
310
- explanation += f"Error Type: {error_type}\nHint: {get_hint(error_type)}"
311
- explanation += perf_block
312
- ops = sql_ops(final_sql)
313
- for op in ops:
314
- if op in _OP_STATS: _OP_STATS[op]["fail"] += 1
315
- return display_sql, pd.DataFrame(columns=["Execution Notice"]), explanation
316
-
317
- safe_cols = cols if cols else ["Result"]
318
- explanation = f"{ui_warnings}✅ Query executed successfully\n\nRows returned: {len(rows)}\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}{perf_block}"
319
-
320
- ops = sql_ops(final_sql)
321
- for op in ops:
322
- if op in _OP_STATS: _OP_STATS[op]["ok"] += 1
323
- _SUCCESS_LOG.append({"t": time.time(), "db_id": str(db_id), "question": question, "sql": final_sql, "ops": ops})
324
 
325
  limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
326
- if limit_match and len(rows) < int(limit_match.group(1)):
327
- explanation += f"\n\nℹ️ Query allowed up to {int(limit_match.group(1))} rows but only {len(rows)} matched."
328
-
329
- return final_sql, pd.DataFrame(rows, columns=safe_cols), explanation
330
-
331
- def task1_benchmark(n_rollouts: int, max_workers: int) -> Iterator[tuple[str, str]]:
332
- project_root = str(PROJECT_ROOT)
333
- env = os.environ.copy()
334
- env["PYTHONPATH"] = project_root + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
335
- env.setdefault("MPLBACKEND", "Agg")
336
- env.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")
337
- try: os.makedirs(env["MPLCONFIGDIR"], exist_ok=True)
338
- except Exception: pass
339
-
340
- cmd = [sys.executable, "-u", "scripts/benchmark_parallel_reward.py", "--n", str(int(n_rollouts)), "--max-workers", str(int(max_workers)), "--skip-profile"]
341
- proc = subprocess.Popen(cmd, cwd=project_root, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
342
- last_yield = time.perf_counter()
343
- lines: list[str] = []
344
- yield "Running Task 1 benchmark...\n", "<i>Running...</i>"
345
-
346
- assert proc.stdout is not None
347
- for line in proc.stdout:
348
- lines.append(line)
349
- now = time.perf_counter()
350
- if now - last_yield >= 0.5:
351
- last_yield = now
352
- yield "".join(lines[-200:]).strip(), "<i>Running...</i>"
353
-
354
- proc.wait()
355
- out = "".join(lines).strip()
356
-
357
- plot_path = str(PROJECT_ROOT / "results" / "task1_plot.png")
358
- if os.path.exists(plot_path):
359
- try:
360
- b64 = base64.b64encode(Path(plot_path).read_bytes()).decode("ascii")
361
- yield out, f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
362
- return
363
- except Exception:
364
- yield out, f"<pre>{plot_path}</pre>"
365
- return
366
-
367
- yield out, "<i>No plot generated</i>"
368
-
369
- def task2_dashboard_structured():
370
- if not _QUERY_LOG:
371
- empty_counts = pd.DataFrame(columns=["error_type", "count", "hint"])
372
- empty_recent = pd.DataFrame(columns=["time", "db_id", "error_type", "question", "error_msg"])
373
- return empty_counts, empty_recent, gr.update(choices=[], value=None)
374
-
375
- counts = {}
376
- for r in _QUERY_LOG[-1000:]:
377
- k = r.get("error_type") or "other"
378
- counts[k] = counts.get(k, 0) + 1
379
- rows = [{"error_type": k, "count": int(v), "hint": get_hint(k)} for k, v in sorted(counts.items(), key=lambda x: (-x[1], x[0]))]
380
- counts_df = pd.DataFrame(rows)
381
-
382
- recent = []
383
- for r in _QUERY_LOG[-100:]:
384
- ts = r.get("t")
385
- try: ts_s = time.strftime("%H:%M:%S", time.localtime(float(ts))) if ts else ""
386
- except Exception: ts_s = ""
387
- recent.append({"time": ts_s, "db_id": r.get("db_id", ""), "error_type": r.get("error_type", ""), "question": r.get("question", ""), "error_msg": r.get("error_msg", "")})
388
- recent_df = pd.DataFrame(recent)
389
-
390
- choices = [str(x["error_type"]) for x in rows]
391
- default = choices[0] if choices else None
392
- return counts_df, recent_df, gr.update(choices=choices, value=default)
393
-
394
- def task2_error_examples(error_type: str) -> str:
395
- if not error_type: return ""
396
- hint = get_hint(error_type)
397
- matches = [r for r in reversed(_QUERY_LOG) if (r.get("error_type") or "") == str(error_type)][:3]
398
- if not matches: return f"Error type: {error_type}\nHint: {hint}\n\nNo examples yet."
399
- out = [f"Error type: {error_type}", f"Hint: {hint}", ""]
400
- for i, r in enumerate(matches, 1):
401
- out.extend([f"Example {i}", f"DB: {r.get('db_id','')}", f"Q: {r.get('question','')}", f"SQL: {r.get('sql','')}", f"Msg: {r.get('error_msg','')}", ""])
402
- return "\n".join(out).strip()
403
-
404
- def _plot_op_stats_html() -> str:
405
- try:
406
- import matplotlib.pyplot as plt
407
- labels = list(_OP_STATS.keys())
408
- oks = [int(_OP_STATS[k]["ok"]) for k in labels]
409
- fails = [int(_OP_STATS[k]["fail"]) for k in labels]
410
-
411
- fig, ax = plt.subplots(figsize=(9, 3.5))
412
- x = list(range(len(labels)))
413
- ax.bar(x, oks, label="ok", color="#16a34a")
414
- ax.bar(x, fails, bottom=oks, label="fail", color="#dc2626")
415
- ax.set_xticks(x)
416
- ax.set_xticklabels(labels, rotation=30, ha="right")
417
- ax.set_title("Success/Failure by SQL operation")
418
- ax.legend()
419
- fig.tight_layout()
420
-
421
- buf = io.BytesIO()
422
- fig.savefig(buf, format="png", dpi=160)
423
- plt.close(fig)
424
- b64 = base64.b64encode(buf.getvalue()).decode("ascii")
425
- return f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
426
- except Exception as e: return f"<pre>Plot error: {e}</pre>"
427
-
428
- def task2_ops_table():
429
- rows = []
430
- for op, d in _OP_STATS.items():
431
- ok = int(d.get("ok", 0))
432
- fail = int(d.get("fail", 0))
433
- total = ok + fail
434
- rows.append({"op": op, "ok": ok, "fail": fail, "total": total, "success_rate": (ok / total) if total else 0.0})
435
- return pd.DataFrame(rows), _plot_op_stats_html()
436
 
437
  def toggle_input_method(method, current_sample):
438
  if method == "💡 Pick a Sample":
 
439
  db = next((db for q, db in SAMPLES if q == current_sample), "chinook_1")
440
- return (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value=db, interactive=False))
441
- return (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(interactive=True))
 
 
 
 
 
 
 
 
 
 
 
 
442
 
443
  def load_sample(selected_question):
444
- if not selected_question: return gr.update()
445
- return gr.update(value=next((db for q, db in SAMPLES if q == selected_question), "chinook_1"))
 
 
 
446
 
447
  def clear_inputs():
448
- return (gr.update(value="💡 Pick a Sample"), gr.update(value=SAMPLE_QUESTIONS[0], visible=True), gr.update(visible=False), gr.update(value="", visible=False), gr.update(value="chinook_1", interactive=False), "", pd.DataFrame(), "")
 
 
 
 
 
 
 
449
 
450
  def update_schema(db_id):
451
- if not db_id or schema_encoder is None: return ""
 
452
  try:
453
- raw_schema = schema_encoder.structured_schema(db_id)
454
  html_output = "<div style='max-height: 250px; overflow-y: auto; background: #f8fafc; padding: 12px; border-radius: 8px; border: 1px solid #e2e8f0; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; font-size: 0.9em; line-height: 1.6;'>"
455
  for line in raw_schema.strip().split('\n'):
456
  line = line.strip()
457
  if not line: continue
458
  match = re.search(r'^([a-zA-Z0-9_]+)\s*\((.*)\)', line)
459
- if match: html_output += f"<div style='margin-bottom: 8px;'><strong style='color: #0f172a; font-size: 1.05em; font-weight: 800;'>{match.group(1).upper()}</strong> <span style='color: #64748b;'>( {match.group(2).lower()} )</span></div>"
460
- else: html_output += f"<div style='color: #475569;'>{line}</div>"
 
 
 
 
461
  html_output += "</div>"
462
  return html_output
463
- except Exception as e: return f"<div style='color: red;'>Error loading schema: {str(e)}</div>"
 
 
464
 
465
  # =========================
466
  # UI LAYOUT
467
  # =========================
468
- with gr.Blocks(title="Text-to-SQL RLHF") as demo:
469
- gr.HTML("""
 
 
470
  <div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">
471
  <h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>
472
  <p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>
473
  </div>
474
- """)
 
475
 
476
- DBS = sorted(["flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1", "academic", "aircraft", "car_1", "cinema", "club_1", "csu_1", "college_1", "college_2", "company_1", "company_employee", "customer_complaints", "department_store", "employee_hire_evaluation", "museum_visit", "products_for_hire", "restaurant_1", "school_finance", "shop_membership", "small_bank_1", "student_1", "tvshow", "voter_1", "world_1"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
 
478
- with gr.Tabs():
479
- with gr.Tab("Inference"):
480
  with gr.Row():
481
- with gr.Column(scale=1):
482
- gr.Markdown("### 1. Configuration & Input")
483
- input_method = gr.Radio(choices=["💡 Pick a Sample", "✍️ Type my own"], value="💡 Pick a Sample", label="How do you want to ask?")
484
- sample_dropdown = gr.Dropdown(choices=SAMPLE_QUESTIONS, value=SAMPLE_QUESTIONS[0], label="Select a Sample Question", info="The database will be selected automatically.", visible=True)
485
- type_own_warning = gr.Markdown("**⚠️ Please select a Database first, then type your custom question below:**", visible=False)
486
- gr.Markdown("---")
487
- db_id = gr.Dropdown(choices=DBS, value="chinook_1", label="Select Database", interactive=False)
488
- custom_question = gr.Textbox(label="Ask your Custom Question", placeholder="Type your own question here...", lines=3, visible=False)
489
-
490
- gr.Markdown("#### 📋 Database Structure")
491
- gr.HTML("<p style='font-size: 0.85em; color: #64748b; margin-top: -10px; margin-bottom: 5px;'>Use these exact names! Table names are <strong>Dark</strong>, Column names are <span style='color: #94a3b8;'>Light</span>.</p>")
492
- schema_display = gr.HTML(value=update_schema("chinook_1"))
493
-
494
- with gr.Row():
495
- clear_btn = gr.Button("🗑️ Clear", variant="secondary")
496
- run_btn = gr.Button(" Generate & Run SQL", variant="primary")
497
-
498
- with gr.Column(scale=2):
499
- gr.Markdown("### 2. Execution Results")
500
- final_sql = gr.Code(language="sql", label="Final Executed SQL")
501
- result_table = gr.Dataframe(label="Query Result Table", interactive=False, wrap=True)
502
- explanation = gr.Textbox(label="AI Explanation + Execution Details", lines=8)
503
-
504
- with gr.Tab("Diagnostics"):
505
- gr.Markdown("## Diagnostics & Telemetry")
506
-
507
- with gr.Accordion("Task 1: Parallel Reward Benchmark", open=False):
508
- gr.Markdown("*(Simulates the heavy RLHF training workload by running hundreds of complex SQL queries concurrently to test SQLite multi-threading performance.)*")
509
- t1_n = gr.Number(value=20, precision=0, label="Rollouts (n)")
510
- t1_workers = gr.Number(value=10, precision=0, label="Max workers")
511
- t1_run = gr.Button("Run Task 1 benchmark")
512
- t1_out = gr.Textbox(label="Output", lines=12)
513
- t1_plot = gr.HTML(label="Plot (if generated)")
514
- t1_run.click(fn=task1_benchmark, inputs=[t1_n, t1_workers], outputs=[t1_out, t1_plot])
515
-
516
- with gr.Accordion("Task 2: Error Dashboard", open=True):
517
- gr.Markdown("*(Live telemetry tracking the most common SQL failures. Populates automatically when queries fail in the Inference tab.)*")
518
- t2_refresh = gr.Button("Refresh dashboard")
519
- t2_counts = gr.Dataframe(label="Error counts", interactive=False, wrap=True)
520
- t2_recent = gr.Dataframe(label="Recent errors", interactive=False, wrap=True)
521
- t2_type = gr.Dropdown(choices=[], value=None, label="Select error type")
522
- t2_examples = gr.Textbox(label="Examples + hint", lines=10)
523
-
524
- t2_refresh.click(fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type])
525
- t2_type.change(fn=task2_error_examples, inputs=[t2_type], outputs=[t2_examples])
526
-
527
- with gr.Accordion("Task 2: Clause Telemetry", open=False):
528
- gr.Markdown("*(Analyzes which specific SQL clauses—SELECT, WHERE, JOIN, etc.—are most prone to errors during natural language generation.)*")
529
- t2_ops_refresh = gr.Button("Refresh SQL-op stats")
530
- t2_ops_tbl = gr.Dataframe(label="Success/failure by op", interactive=False, wrap=True)
531
- t2_ops_plot = gr.HTML(label="Op plot")
532
- t2_ops_refresh.click(fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot])
533
-
534
- # EVENT BINDING: The .then() forces the diagnostic tab to update live in the background!
535
- input_method.change(fn=toggle_input_method, inputs=[input_method, sample_dropdown], outputs=[sample_dropdown, type_own_warning, custom_question, db_id])
536
  sample_dropdown.change(fn=load_sample, inputs=[sample_dropdown], outputs=[db_id])
 
537
  db_id.change(fn=update_schema, inputs=[db_id], outputs=[schema_display])
538
 
539
  run_btn.click(
540
- fn=run_query,
541
- inputs=[input_method, sample_dropdown, custom_question, db_id],
542
  outputs=[final_sql, result_table, explanation]
543
- ).then(
544
- fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type]
545
- ).then(
546
- fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot]
547
  )
548
 
549
- clear_btn.click(fn=clear_inputs, inputs=[], outputs=[input_method, sample_dropdown, type_own_warning, custom_question, db_id, final_sql, result_table, explanation])
 
 
 
 
 
550
 
551
  if __name__ == "__main__":
552
- server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
553
- base_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
554
- max_retries = 10
555
-
556
- for port in range(base_port, base_port + max_retries):
557
- try:
558
- print(f"Attempting to start Gradio UI on {server_name}:{port}...", flush=True)
559
- demo.launch(server_name=server_name, server_port=port)
560
- break # If successful, exit the loop
561
- except OSError as e:
562
- if "Cannot find empty port" in str(e) or "Address already in use" in str(e):
563
- print(f"⚠️ Port {port} is in use, trying next port...")
564
- continue
565
- else:
566
- # If it's a different OSError, raise it normally
567
- raise e
568
- else:
569
- print(f"❌ Could not find an open port between {base_port} and {base_port + max_retries - 1}.")
 
1
  """
2
+ GRADIO DEMO UI
3
  NL → SQL → Result Table
4
  """
5
 
 
7
  import pandas as pd
8
  import re
9
  import time
10
+ from src.text2sql_engine import get_engine
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ engine = get_engine()
13
+
14
+ # =========================
15
+ # SAMPLE QUESTIONS DATA
16
+ # =========================
17
  SAMPLES = [
18
+ ("Show 10 distinct employee first names.", "chinook_1"),
19
+ ("Which artist has the most albums?", "chinook_1"),
20
+ ("List all the tracks that belong to the 'Rock' genre.", "chinook_1"),
21
+ ("What are the names of all the cities?", "flight_1"),
22
+ ("Find the flight number and cost of the cheapest flight.", "flight_1"),
23
+ ("List the airlines that fly out of New York.", "flight_1"),
24
+ ("Which campus was opened between 1935 and 1939?", "csu_1"),
25
+ ("Count the number of students in each department.", "college_2"),
26
+ ("List the names of all clubs.", "club_1"),
27
+ ("How many members does each club have?", "club_1"),
28
+ ("Show the names of all cinemas.", "cinema"),
29
+ ("Which cinema has the most screens?", "cinema")
30
  ]
31
+
32
  SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
33
 
34
+ # =========================
35
+ # SQL EXPLAINER
36
+ # =========================
37
  def explain_sql(sql):
 
38
  explanation = "This SQL query retrieves information from the database."
39
  sql_lower = sql.lower()
40
+
41
+ if "join" in sql_lower:
42
+ explanation += "\n• It combines data from multiple tables using JOIN."
43
+ if "where" in sql_lower:
44
+ explanation += "\n• It filters rows using a WHERE condition."
45
+ if "group by" in sql_lower:
46
+ explanation += "\n• It groups results using GROUP BY."
47
+ if "order by" in sql_lower:
48
+ explanation += "\n• It sorts the results using ORDER BY."
49
+ if "limit" in sql_lower:
50
+ explanation += "\n• It limits the number of returned rows."
51
+
52
  return explanation
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ # =========================
56
+ # CORE FUNCTIONS
57
+ # =========================
58
  def run_query(method, sample_q, custom_q, db_id):
 
59
 
60
+ # 1. Safely determine the question
61
+ question = sample_q if method == "💡 Pick a Sample" else custom_q
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ # 2. Validate inputs before hitting the engine
64
+ if not question or str(question).strip() == "":
65
+ return "", pd.DataFrame(), "⚠️ Please enter a question."
66
+
67
+ if not db_id or str(db_id).strip() == "":
68
+ return "", pd.DataFrame(), "⚠️ Please select a database."
 
 
 
 
 
69
 
70
  start_time = time.time()
 
 
71
 
72
+ # 3. GIANT SAFETY NET to prevent infinite loading spinners
73
  try:
74
+ result = engine.ask(str(question), str(db_id))
 
 
 
75
  except Exception as e:
76
+ return "", pd.DataFrame(), f"❌ CRITICAL BACKEND CRASH:\n{str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+ final_sql = result.get("sql", "")
79
+ error_msg = result.get("error", None)
80
+ rows = result.get("rows", [])
81
+ cols = result.get("columns", [])
82
 
83
+ end_time = time.time()
84
+ latency = round(end_time - start_time, 3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ # 4. Handle SQL generation/execution errors
87
+ if error_msg:
88
+ return final_sql, pd.DataFrame(), f"❌ SQL Error:\n{error_msg}"
89
+
90
+ # 5. Handle Zero Rows gracefully
91
+ if not rows:
92
+ df = pd.DataFrame(columns=cols if cols else [])
93
+ explanation = f"✅ Query executed successfully\n\nRows returned: 0\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}"
94
+ return final_sql, df, explanation
95
+
96
+ # 6. Handle successful execution
97
+ df = pd.DataFrame(rows, columns=cols)
98
+ actual_rows = len(rows)
99
+
100
+ explanation = f"✅ Query executed successfully\n\nRows returned: {actual_rows}\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}"
 
 
 
 
 
101
 
102
  limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
103
+ if limit_match:
104
+ requested_limit = int(limit_match.group(1))
105
+ if actual_rows < requested_limit:
106
+ explanation += f"\n\nℹ️ Query allowed up to {requested_limit} rows but only {actual_rows} matched."
107
+
108
+ return final_sql, df, explanation
109
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  def toggle_input_method(method, current_sample):
112
  if method == "💡 Pick a Sample":
113
+ # Find the DB matching the current sample (fallback to 'chinook_1')
114
  db = next((db for q, db in SAMPLES if q == current_sample), "chinook_1")
115
+ return (
116
+ gr.update(visible=True), # Show sample_dropdown
117
+ gr.update(visible=False), # Hide type_own_warning
118
+ gr.update(visible=False), # Hide custom_question
119
+ gr.update(value=db, interactive=False) # Lock and reset db_id
120
+ )
121
+ else:
122
+ return (
123
+ gr.update(visible=False), # Hide sample_dropdown
124
+ gr.update(visible=True), # Show type_own_warning
125
+ gr.update(visible=True), # Show custom_question
126
+ gr.update(interactive=True) # Unlock db_id
127
+ )
128
+
129
 
130
  def load_sample(selected_question):
131
+ if not selected_question:
132
+ return gr.update()
133
+ db = next((db for q, db in SAMPLES if q == selected_question), "chinook_1")
134
+ return gr.update(value=db)
135
+
136
 
137
  def clear_inputs():
138
+ return (
139
+ gr.update(value="💡 Pick a Sample"),
140
+ gr.update(value=SAMPLE_QUESTIONS[0], visible=True), # sample_dropdown
141
+ gr.update(visible=False), # type_own_warning
142
+ gr.update(value="", visible=False), # custom_question
143
+ gr.update(value="chinook_1", interactive=False), # db_id
144
+ "", pd.DataFrame(), "" # Outputs (SQL, Table, Explanation)
145
+ )
146
 
147
  def update_schema(db_id):
148
+ if not db_id:
149
+ return ""
150
  try:
151
+ raw_schema = engine.get_schema(db_id)
152
  html_output = "<div style='max-height: 250px; overflow-y: auto; background: #f8fafc; padding: 12px; border-radius: 8px; border: 1px solid #e2e8f0; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; font-size: 0.9em; line-height: 1.6;'>"
153
  for line in raw_schema.strip().split('\n'):
154
  line = line.strip()
155
  if not line: continue
156
  match = re.search(r'^([a-zA-Z0-9_]+)\s*\((.*)\)', line)
157
+ if match:
158
+ table_name = match.group(1).upper()
159
+ columns = match.group(2).lower()
160
+ html_output += f"<div style='margin-bottom: 8px;'><strong style='color: #0f172a; font-size: 1.05em; font-weight: 800;'>{table_name}</strong> <span style='color: #64748b;'>( {columns} )</span></div>"
161
+ else:
162
+ html_output += f"<div style='color: #475569;'>{line}</div>"
163
  html_output += "</div>"
164
  return html_output
165
+ except Exception as e:
166
+ return f"<div style='color: red;'>Error loading schema: {str(e)}</div>"
167
+
168
 
169
  # =========================
170
  # UI LAYOUT
171
  # =========================
172
+ with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL RLHF") as demo:
173
+
174
+ gr.HTML(
175
+ """
176
  <div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">
177
  <h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>
178
  <p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>
179
  </div>
180
+ """
181
+ )
182
 
183
+ DBS = sorted([
184
+ "flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1",
185
+ "academic", "aircraft", "car_1", "cinema", "club_1", "csu_1",
186
+ "college_1", "college_2", "company_1", "company_employee",
187
+ "customer_complaints", "department_store", "employee_hire_evaluation",
188
+ "museum_visit", "products_for_hire", "restaurant_1",
189
+ "school_finance", "shop_membership", "small_bank_1",
190
+ "soccer_1", "student_1", "tvshow", "voter_1", "world_1"
191
+ ])
192
+
193
+ with gr.Row():
194
+ with gr.Column(scale=1):
195
+ gr.Markdown("### 1. Configuration & Input")
196
+
197
+ input_method = gr.Radio(
198
+ choices=["💡 Pick a Sample", "✍️ Type my own"],
199
+ value="💡 Pick a Sample",
200
+ label="How do you want to ask?"
201
+ )
202
+
203
+ # --- SAMPLE SECTION ---
204
+ sample_dropdown = gr.Dropdown(
205
+ choices=SAMPLE_QUESTIONS,
206
+ value=SAMPLE_QUESTIONS[0],
207
+ label="Select a Sample Question",
208
+ info="The database will be selected automatically.",
209
+ visible=True
210
+ )
211
+
212
+ # --- CUSTOM TYPE WARNING ---
213
+ type_own_warning = gr.Markdown(
214
+ "**⚠️ Please select a Database first, then type your custom question below:**",
215
+ visible=False
216
+ )
217
+
218
+ gr.Markdown("---")
219
+
220
+ # --- DATABASE SELECTION (Moved Up) ---
221
+ db_id = gr.Dropdown(
222
+ choices=DBS,
223
+ value="chinook_1",
224
+ label="Select Database",
225
+ interactive=False
226
+ )
227
+
228
+ # --- CUSTOM QUESTION BOX ---
229
+ custom_question = gr.Textbox(
230
+ label="Ask your Custom Question",
231
+ placeholder="Type your own question here...",
232
+ lines=3,
233
+ visible=False
234
+ )
235
+
236
+ gr.Markdown("#### 📋 Database Structure")
237
+ gr.HTML("<p style='font-size: 0.85em; color: #64748b; margin-top: -10px; margin-bottom: 5px;'>Use these exact names! Table names are <strong>Dark</strong>, Column names are <span style='color: #94a3b8;'>Light</span>.</p>")
238
+ schema_display = gr.HTML(value=update_schema("chinook_1"))
239
 
 
 
240
  with gr.Row():
241
+ clear_btn = gr.Button("🗑️ Clear", variant="secondary")
242
+ run_btn = gr.Button(" Generate & Run SQL", variant="primary")
243
+
244
+ with gr.Column(scale=2):
245
+ gr.Markdown("### 2. Execution Results")
246
+ final_sql = gr.Code(language="sql", label="Final Executed SQL")
247
+ result_table = gr.Dataframe(label="Query Result Table", interactive=False, wrap=True)
248
+ explanation = gr.Textbox(label="AI Explanation + Execution Details", lines=8)
249
+
250
+ # =========================
251
+ # EVENT LISTENERS
252
+ # =========================
253
+
254
+ # Updated to handle the new Markdown warning toggle
255
+ input_method.change(
256
+ fn=toggle_input_method,
257
+ inputs=[input_method, sample_dropdown],
258
+ outputs=[sample_dropdown, type_own_warning, custom_question, db_id]
259
+ )
260
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  sample_dropdown.change(fn=load_sample, inputs=[sample_dropdown], outputs=[db_id])
262
+
263
  db_id.change(fn=update_schema, inputs=[db_id], outputs=[schema_display])
264
 
265
  run_btn.click(
266
+ fn=run_query,
267
+ inputs=[input_method, sample_dropdown, custom_question, db_id],
268
  outputs=[final_sql, result_table, explanation]
 
 
 
 
269
  )
270
 
271
+ clear_btn.click(
272
+ fn=clear_inputs,
273
+ inputs=[],
274
+ # Output list matches the updated clear_inputs() return values
275
+ outputs=[input_method, sample_dropdown, type_own_warning, custom_question, db_id, final_sql, result_table, explanation]
276
+ )
277
 
278
  if __name__ == "__main__":
279
+ demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
best_rlhf_model/README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: peft
3
+ ---
4
+ ## Training procedure
5
+
6
+ ### Framework versions
7
+
8
+
9
+ - PEFT 0.4.0
best_rlhf_model/adapter_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_mapping": null,
3
+ "base_model_name_or_path": "Salesforce/codet5-base",
4
+ "bias": "none",
5
+ "fan_in_fan_out": false,
6
+ "inference_mode": true,
7
+ "init_lora_weights": true,
8
+ "layers_pattern": null,
9
+ "layers_to_transform": null,
10
+ "lora_alpha": 32,
11
+ "lora_dropout": 0.05,
12
+ "modules_to_save": null,
13
+ "peft_type": "LORA",
14
+ "r": 16,
15
+ "revision": null,
16
+ "target_modules": [
17
+ "q",
18
+ "v"
19
+ ],
20
+ "task_type": "SEQ_2_SEQ_LM"
21
+ }
{int8_dynamic/tokenizer → best_rlhf_model}/merges.txt RENAMED
File without changes
{int8_dynamic/tokenizer → best_rlhf_model}/special_tokens_map.json RENAMED
File without changes
{int8_dynamic/tokenizer → best_rlhf_model}/tokenizer_config.json RENAMED
@@ -954,6 +954,5 @@
954
  "pad_token": "<pad>",
955
  "sep_token": "</s>",
956
  "tokenizer_class": "RobertaTokenizer",
957
- "trim_offsets": true,
958
  "unk_token": "<unk>"
959
  }
 
954
  "pad_token": "<pad>",
955
  "sep_token": "</s>",
956
  "tokenizer_class": "RobertaTokenizer",
 
957
  "unk_token": "<unk>"
958
  }
best_rlhf_model/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
int8_dynamic/meta.json DELETED
@@ -1,7 +0,0 @@
1
- {
2
- "mode": "int8_dynamic",
3
- "base_model": "Salesforce/codet5-base",
4
- "adapter_path": "checkpoints/best_rlhf_model_2",
5
- "created_at_s": 1774418718.320342,
6
- "estimated_model_bytes": 98804736
7
- }
 
 
 
 
 
 
 
 
int8_dynamic/model.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f398e044cd49fc84553b746d26ad79beb1dd565d90cf8f6f5e50d27f48d08228
3
- size 322871519
 
 
 
 
int8_dynamic/tokenizer/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
int8_dynamic/tokenizer/vocab.json DELETED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,10 +1,8 @@
1
- gradio==5.8.0
2
  pandas
3
  sqlparse
4
  transformers
5
- torch --index-url https://download.pytorch.org/whl/cpu
6
  peft
7
  trl
8
- sentencepiece
9
- matplotlib
10
- huggingface_hub
 
1
+ gradio
2
  pandas
3
  sqlparse
4
  transformers
5
+ torch
6
  peft
7
  trl
8
+ sentencepiece
 
 
scripts/benchmark_parallel_reward.py DELETED
@@ -1,202 +0,0 @@
1
- import os
2
- # Ensure headless-safe matplotlib + writable cache when called from Gradio/subprocess.
3
- os.environ.setdefault("MPLBACKEND", "Agg")
4
- os.environ.setdefault("MPLCONFIGDIR", os.environ.get("MPLCONFIGDIR", "/tmp/mplconfig"))
5
- import time
6
- import json
7
- import argparse
8
- import matplotlib.pyplot as plt
9
- import numpy as np
10
- import sys
11
- from pathlib import Path
12
-
13
- # ==========================================
14
- # RELATIVE PATH RESOLUTION
15
- # ==========================================
16
- PROJECT_ROOT = Path(__file__).resolve().parent.parent
17
- sys.path.append(str(PROJECT_ROOT))
18
-
19
- # Dynamically resolve where the databases are kept
20
- if (PROJECT_ROOT / "data" / "database").exists() and list((PROJECT_ROOT / "data" / "database").rglob("*.sqlite")):
21
- DB_ROOT = PROJECT_ROOT / "data" / "database"
22
- else:
23
- DB_ROOT = PROJECT_ROOT / "final_databases"
24
-
25
- from src.execution_reward import (
26
- execution_reward_batch_sequential,
27
- execution_reward_batch_parallel,
28
- execution_reward_batch_parallel_by_db,
29
- execution_reward_timed,
30
- set_use_cache,
31
- set_use_schema_validation,
32
- clear_result_cache
33
- )
34
-
35
- def generate_mock_rollouts(num_rollouts: int = 100, heavy_n: int = 500_000):
36
- """Generates heavy queries across multiple databases to properly test true concurrency."""
37
- print(f"\nGenerating {num_rollouts} heavy rollouts to simulate RLHF query workload...", flush=True)
38
-
39
- # Smart search for real databases
40
- real_dbs = [str(p) for p in DB_ROOT.rglob("*.sqlite")]
41
-
42
- if real_dbs:
43
- print(f"Found {len(real_dbs)} real SQLite databases in {DB_ROOT}. Distributing workload...", flush=True)
44
- else:
45
- print(f"❌ CRITICAL ERROR: No real databases found in {DB_ROOT}. Cannot run benchmark.", flush=True)
46
- sys.exit(1)
47
-
48
- rollouts = []
49
- for i in range(num_rollouts):
50
- db_path = real_dbs[i % len(real_dbs)]
51
-
52
- # Heavy deterministic CPU-ish query (may be cut off by the 2s timeout depending on machine).
53
- heavy_sql = f"""
54
- WITH RECURSIVE cnt(x) AS (
55
- SELECT 1
56
- UNION ALL
57
- SELECT x+1 FROM cnt WHERE x < {heavy_n + (i % 10_000)}
58
- )
59
- SELECT sum(x) FROM cnt;
60
- """
61
- clean_sql = heavy_sql.replace("\n", " ").strip()
62
- rollouts.append((clean_sql, db_path, clean_sql))
63
- if num_rollouts >= 500 and (i + 1) % 250 == 0:
64
- print(f" generated {i + 1}/{num_rollouts}...", flush=True)
65
-
66
- return rollouts
67
-
68
- def profile_bottlenecks(rollouts, sample_size: int = 20, print_every: int = 5):
69
- """Profiles CPU usage to identify time spent in parsing, planning, and execution."""
70
- print("\n" + "="*65)
71
- print(" 🔍 CPU PROFILING: IDENTIFYING BOTTLENECKS (100 Rollouts)")
72
- print("="*65)
73
-
74
- clear_result_cache()
75
- set_use_cache(False) # Disable cache to force real work
76
- set_use_schema_validation(False) # CTE-heavy benchmark queries may fail schema validation
77
-
78
- total_parse = 0.0
79
- total_plan = 0.0
80
- total_exec = 0.0
81
-
82
- # Profile a small subset by default so the script prints quickly.
83
- sample_size = min(int(sample_size), len(rollouts))
84
- sample_rollouts = rollouts[:sample_size]
85
-
86
- for i, (pred, db, gold) in enumerate(sample_rollouts, 1):
87
- _, timings = execution_reward_timed(pred, db, gold, measure_plan=True)
88
- total_parse += timings['parse_s']
89
- total_plan += timings['plan_s']
90
- total_exec += timings['exec_s']
91
- if print_every and (i % int(print_every) == 0 or i == sample_size):
92
- print(f" profiled {i}/{sample_size}...", flush=True)
93
-
94
- total_time = total_parse + total_plan + total_exec
95
- if total_time == 0: total_time = 0.0001 # Prevent div by zero
96
-
97
- print(f"{'Phase':<15} | {'Avg Time (ms)':<15} | {'% of Total CPU':<15}")
98
- print("-" * 65)
99
- print(f"{'Regex Parsing':<15} | {(total_parse/sample_size)*1000:<15.2f} | {(total_parse/total_time)*100:<14.1f}%")
100
- print(f"{'Query Planning':<15} | {(total_plan/sample_size)*1000:<15.2f} | {(total_plan/total_time)*100:<14.1f}%")
101
- print(f"{'DB Execution':<15} | {(total_exec/sample_size)*1000:<15.2f} | {(total_exec/total_time)*100:<14.1f}%")
102
- print("="*65 + "\n")
103
-
104
- def run_benchmark_for_setting(rollouts, use_cache: bool, max_workers: int):
105
- set_use_cache(use_cache)
106
- set_use_schema_validation(False) # benchmark focuses on execution speed
107
-
108
- # Sequential
109
- clear_result_cache()
110
- start_time = time.perf_counter()
111
- execution_reward_batch_sequential(rollouts)
112
- sequential_s = time.perf_counter() - start_time
113
-
114
- # Parallel
115
- clear_result_cache()
116
- start_time = time.perf_counter()
117
- # 1 thread per DB (recommended)
118
- execution_reward_batch_parallel_by_db(rollouts, max_workers=max_workers)
119
- parallel_s = time.perf_counter() - start_time
120
-
121
- speedup = sequential_s / parallel_s if parallel_s > 0 else 0
122
-
123
- return {
124
- "sequential_s": sequential_s,
125
- "parallel_s": parallel_s,
126
- "speedup": speedup
127
- }
128
-
129
- def print_comparison_table(results):
130
- print("="*65)
131
- print(f"{'Setting':<16} | {'Sequential (s)':<14} | {'Parallel (s)':<14} | {'Speedup':<10}")
132
- print("-" * 65)
133
- for setting, key in [("With Cache", "with_cache"), ("Without Cache", "without_cache")]:
134
- seq = results[key]['sequential_s']
135
- par = results[key]['parallel_s']
136
- spd = results[key]['speedup']
137
- print(f"{setting:<16} | {seq:<14.4f} | {par:<14.4f} | {spd:<9.2f}x")
138
- print("="*65 + "\n")
139
-
140
- def plot_results(results, output_path: str):
141
- labels = ['With Cache', 'Without Cache']
142
- seq_times = [results['with_cache']['sequential_s'], results['without_cache']['sequential_s']]
143
- par_times = [results['with_cache']['parallel_s'], results['without_cache']['parallel_s']]
144
-
145
- x = np.arange(len(labels))
146
- width = 0.35
147
-
148
- fig, ax = plt.subplots(figsize=(8, 6))
149
- ax.bar(x - width/2, seq_times, width, label='Sequential', color='#4C72B0')
150
- ax.bar(x + width/2, par_times, width, label='Parallel', color='#DD8452')
151
-
152
- ax.set_ylabel('Execution Time (seconds)')
153
- ax.set_title('Text2SQL Reward Execution: Sequential vs Parallel')
154
- ax.set_xticks(x)
155
- ax.set_xticklabels(labels)
156
- ax.legend()
157
-
158
- for container in ax.containers:
159
- ax.bar_label(container, fmt='%.2f', padding=3)
160
-
161
- fig.tight_layout()
162
- plt.savefig(output_path, dpi=300)
163
- plt.close()
164
-
165
- def main():
166
- parser = argparse.ArgumentParser(description="Benchmark SQL Execution Reward")
167
- parser.add_argument("--n", type=int, default=1000, help="Number of rollouts to benchmark")
168
- parser.add_argument("--max-workers", type=int, default=20, help="Max workers for parallel execution")
169
- parser.add_argument("--heavy-n", type=int, default=200_000, help="Recursive CTE upper bound (controls heaviness)")
170
- parser.add_argument("--skip-profile", action="store_true", help="Skip the CPU profiling section for faster startup")
171
- parser.add_argument("--profile-n", type=int, default=20, help="Number of rollouts to use for CPU profiling")
172
- args = parser.parse_args()
173
-
174
- os.makedirs(str(PROJECT_ROOT / "results"), exist_ok=True)
175
-
176
- rollouts = generate_mock_rollouts(args.n, heavy_n=args.heavy_n)
177
-
178
- if not args.skip_profile:
179
- profile_bottlenecks(rollouts, sample_size=args.profile_n)
180
-
181
- print("Starting Main Scalability Benchmarks...")
182
-
183
- print("Running Experiment A: Cache ENABLED...")
184
- results_with_cache = run_benchmark_for_setting(rollouts, use_cache=True, max_workers=args.max_workers)
185
-
186
- print("Running Experiment B: Cache DISABLED...")
187
- results_without_cache = run_benchmark_for_setting(rollouts, use_cache=False, max_workers=args.max_workers)
188
-
189
- final_results = {
190
- "with_cache": results_with_cache,
191
- "without_cache": results_without_cache
192
- }
193
-
194
- json_path = str(PROJECT_ROOT / "results" / "task1_results.json")
195
- with open(json_path, 'w') as f:
196
- json.dump(final_results, f, indent=4)
197
-
198
- print_comparison_table(final_results)
199
- plot_results(final_results, str(PROJECT_ROOT / "results" / "task1_plot.png"))
200
-
201
- if __name__ == "__main__":
202
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/benchmark_quantization.py DELETED
@@ -1,108 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import argparse
4
- import json
5
- import os
6
- import time
7
- from pathlib import Path
8
- from typing import Dict, List, Tuple
9
-
10
- import numpy as np
11
- import torch
12
-
13
- from src.execution_reward import execution_reward
14
- from src.prompting import encode_prompt
15
- from src.quantization_utils import load_fp32_model, load_quant_artifact
16
-
17
-
18
- def _load_dev_items(root: Path, n: int, seed: int = 42) -> List[dict]:
19
- data = json.loads((root / "data" / "dev.json").read_text())
20
- if n >= len(data):
21
- return data
22
- rng = np.random.default_rng(seed)
23
- idxs = rng.choice(len(data), size=n, replace=False)
24
- return [data[int(i)] for i in idxs]
25
-
26
-
27
- def _bench_variant(name: str, tok, model, items: List[dict], device: str) -> Dict[str, float]:
28
- latencies: List[float] = []
29
- ex = 0
30
-
31
- # Warmup (1 item)
32
- if items:
33
- it = items[0]
34
- _ = encode_prompt(tok, it["question"], it["db_id"], device=device, max_input_tokens=512).unsqueeze(0)
35
-
36
- for it in items:
37
- db_id = it["db_id"]
38
- q = it["question"]
39
- gold = it["query"]
40
- db_path = str(Path("data") / "database" / db_id / f"{db_id}.sqlite")
41
-
42
- input_ids = encode_prompt(tok, q, db_id, device=device, max_input_tokens=512).unsqueeze(0)
43
- t0 = time.perf_counter()
44
- out = model.generate(input_ids=input_ids, max_new_tokens=120, num_beams=8, repetition_penalty=1.2)
45
- dt = time.perf_counter() - t0
46
- latencies.append(dt)
47
-
48
- pred = tok.decode(out[0], skip_special_tokens=True).strip()
49
- r = execution_reward(pred, db_path, gold)
50
- if float(r) >= 1.0:
51
- ex += 1
52
-
53
- p50 = float(np.percentile(latencies, 50)) if latencies else 0.0
54
- p90 = float(np.percentile(latencies, 90)) if latencies else 0.0
55
- mean = float(np.mean(latencies)) if latencies else 0.0
56
- return {
57
- "n": float(len(items)),
58
- "ex": float(ex / max(len(items), 1)),
59
- "lat_mean_s": mean,
60
- "lat_p50_s": p50,
61
- "lat_p90_s": p90,
62
- }
63
-
64
-
65
- def main() -> None:
66
- p = argparse.ArgumentParser(description="Benchmark fp32 vs quantized artifacts (CPU-focused).")
67
- p.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
68
- p.add_argument("--adapter", default="", help="Optional adapter for fp32 baseline.")
69
- p.add_argument("--artifact_int8", default="", help="Artifact dir exported by scripts/quantize_export.py")
70
- p.add_argument("--artifact_int8_decoder", default="", help="Artifact dir for decoder-only int8")
71
- p.add_argument("--num_samples", type=int, default=100)
72
- p.add_argument("--seed", type=int, default=42)
73
- p.add_argument("--out", default="results/task5_quant_bench.json")
74
- p.add_argument("--local_only", action="store_true")
75
- args = p.parse_args()
76
-
77
- device = "cpu"
78
- root = Path(".")
79
- items = _load_dev_items(root, args.num_samples, args.seed)
80
-
81
- report: Dict[str, Dict[str, float]] = {}
82
-
83
- tok, fp32 = load_fp32_model(
84
- args.base_model,
85
- adapter_path=args.adapter.strip() or None,
86
- device=device,
87
- local_only=args.local_only,
88
- )
89
- report["fp32"] = _bench_variant("fp32", tok, fp32, items, device)
90
-
91
- if args.artifact_int8:
92
- tok8, m8, _meta = load_quant_artifact(args.artifact_int8, device=device, local_only=True)
93
- report["int8_dynamic"] = _bench_variant("int8_dynamic", tok8, m8, items, device)
94
-
95
- if args.artifact_int8_decoder:
96
- tokd, md, _meta = load_quant_artifact(args.artifact_int8_decoder, device=device, local_only=True)
97
- report["int8_decoder_dynamic"] = _bench_variant("int8_decoder_dynamic", tokd, md, items, device)
98
-
99
- out_path = Path(args.out)
100
- out_path.parent.mkdir(parents=True, exist_ok=True)
101
- out_path.write_text(json.dumps(report, indent=2))
102
- print(json.dumps(report, indent=2))
103
-
104
-
105
- if __name__ == "__main__":
106
- torch.set_grad_enabled(False)
107
- main()
108
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/benchmark_rollout_generation.py DELETED
@@ -1,66 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import argparse
4
- import json
5
- import os
6
- import time
7
- from pathlib import Path
8
- from typing import List
9
-
10
- import numpy as np
11
- import torch
12
-
13
- from src.prompting import encode_prompt
14
- from src.quantization_utils import load_fp32_model, load_quant_artifact
15
-
16
-
17
- def _load_items(root: Path, n: int, seed: int = 42) -> List[dict]:
18
- data = json.loads((root / "data" / "dev.json").read_text())
19
- if n >= len(data):
20
- return data
21
- rng = np.random.default_rng(seed)
22
- idxs = rng.choice(len(data), size=n, replace=False)
23
- return [data[int(i)] for i in idxs]
24
-
25
-
26
- def _bench_generate(tok, model, items: List[dict], device: str) -> float:
27
- t0 = time.perf_counter()
28
- for it in items:
29
- input_ids = encode_prompt(tok, it["question"], it["db_id"], device=device, max_input_tokens=512).unsqueeze(0)
30
- _ = model.generate(input_ids=input_ids, max_new_tokens=64, num_beams=4)
31
- return time.perf_counter() - t0
32
-
33
-
34
- def main() -> None:
35
- p = argparse.ArgumentParser(description="Benchmark rollout generation latency for RL loops.")
36
- p.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
37
- p.add_argument("--adapter", default="")
38
- p.add_argument("--artifact", default="", help="Quantized artifact dir (optional).")
39
- p.add_argument("--num_rollouts", type=int, default=128)
40
- p.add_argument("--seed", type=int, default=42)
41
- p.add_argument("--local_only", action="store_true")
42
- args = p.parse_args()
43
-
44
- device = "cpu"
45
- root = Path(".")
46
- items = _load_items(root, args.num_rollouts, args.seed)
47
-
48
- tok, fp32 = load_fp32_model(
49
- args.base_model,
50
- adapter_path=args.adapter.strip() or None,
51
- device=device,
52
- local_only=args.local_only,
53
- )
54
- t_fp32 = _bench_generate(tok, fp32, items, device)
55
- print(f"fp32: {t_fp32:.2f}s for {len(items)} rollouts ({len(items)/max(t_fp32,1e-9):.2f} rollouts/s)")
56
-
57
- if args.artifact:
58
- tokq, mq, meta = load_quant_artifact(args.artifact, device=device, local_only=True)
59
- t_q = _bench_generate(tokq, mq, items, device)
60
- mode = meta.get("mode", "quant")
61
- print(f"{mode}: {t_q:.2f}s for {len(items)} rollouts ({len(items)/max(t_q,1e-9):.2f} rollouts/s)")
62
-
63
-
64
- if __name__ == "__main__":
65
- torch.set_grad_enabled(False)
66
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/error_dashboard.py DELETED
@@ -1,99 +0,0 @@
1
-
2
- import json
3
- from collections import Counter
4
-
5
- # ==============================
6
- # LOAD LOGS
7
- # ==============================
8
- with open("results/error_logs.json") as f:
9
- logs = json.load(f)
10
-
11
- total_errors = len(logs)
12
-
13
- # ==============================
14
- # ERROR DISTRIBUTION
15
- # ==============================
16
- error_counts = Counter([e["error_type"] for e in logs])
17
-
18
- print("\n" + "="*50)
19
- print("📊 TEXT-to-SQL ERROR DASHBOARD")
20
- print("="*50)
21
-
22
- print(f"\n🔢 Total Errors Logged: {total_errors}")
23
-
24
- print("\n📊 ERROR DISTRIBUTION:")
25
- print("-"*30)
26
- for k, v in error_counts.items():
27
- percent = (v / total_errors) * 100
28
- print(f"{k:<20} : {v:>4} ({percent:.1f}%)")
29
-
30
- # ==============================
31
- # TOP ERROR
32
- # ==============================
33
- top_error = error_counts.most_common(1)[0]
34
-
35
- print("\n🔥 MOST COMMON ERROR:")
36
- print("-"*30)
37
- print(f"{top_error[0]} ({top_error[1]} times)")
38
-
39
- # ==============================
40
- # SQL OPERATION ANALYSIS
41
- # ==============================
42
- join_count = 0
43
- where_count = 0
44
- group_count = 0
45
- order_count = 0
46
-
47
- for e in logs:
48
- sql = e["sql"].lower()
49
-
50
- if "join" in sql:
51
- join_count += 1
52
- if "where" in sql:
53
- where_count += 1
54
- if "group by" in sql:
55
- group_count += 1
56
- if "order by" in sql:
57
- order_count += 1
58
-
59
- print("\n🧠 SQL OPERATION ANALYSIS:")
60
- print("-"*30)
61
- print(f"JOIN used in : {join_count} queries")
62
- print(f"WHERE used in : {where_count} queries")
63
- print(f"GROUP BY used in : {group_count} queries")
64
- print(f"ORDER BY used in : {order_count} queries")
65
-
66
- # ==============================
67
- # SAMPLE ERRORS
68
- # ==============================
69
- print("\n🧪 SAMPLE ERROR CASES:")
70
- print("-"*50)
71
-
72
- for i, e in enumerate(logs[:3], 1):
73
- print(f"\nCase {i}:")
74
- print(f"Q : {e['question']}")
75
- print(f"SQL : {e['sql']}")
76
- print(f"Type: {e['error_type']}")
77
-
78
- # ==============================
79
- # FINAL INSIGHT
80
- # ==============================
81
- print("\n📌 FINAL INSIGHT:")
82
- print("-"*30)
83
-
84
- if top_error[0] == "wrong_column":
85
- print("⚠️ Model struggles with column selection (schema understanding issue).")
86
-
87
- elif top_error[0] == "wrong_table":
88
- print("⚠️ Model struggles with correct table mapping.")
89
-
90
- elif top_error[0] == "syntax_error":
91
- print("⚠️ Model generates invalid SQL syntax.")
92
-
93
- else:
94
- print("⚠️ Mixed errors — needs general improvement.")
95
-
96
- print("\n" + "="*50)
97
- print("✅ DASHBOARD COMPLETE")
98
- print("="*50)
99
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/evaluate.py DELETED
@@ -1,170 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import os
4
- import sqlite3
5
- from contextlib import closing
6
- from typing import Dict, List
7
-
8
- import torch
9
- from datasets import load_dataset
10
- from peft import PeftModel
11
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
- from trl import AutoModelForSeq2SeqLMWithValueHead
13
-
14
- import sys
15
-
16
- PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17
- sys.path.append(PROJECT_ROOT)
18
- from src.execution_reward import execution_reward # noqa: E402
19
-
20
-
21
- BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
22
- DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")
23
-
24
- # Prefer RL best model if present; otherwise fall back.
25
- RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql", "best_model")
26
- if not os.path.isdir(RL_DIR):
27
- RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql")
28
-
29
- SPLIT = "train[:100]" # quick sanity check
30
- MAX_NEW_TOKENS = 128
31
-
32
- PREFIX = "translate English to SQL:"
33
- MAX_SCHEMA_CHARS = 1500
34
- MAX_INPUT_TOKENS = 512
35
-
36
-
37
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
38
- device = "mps" if torch.backends.mps.is_available() else "cpu"
39
- print("Using device:", device)
40
-
41
-
42
- def get_db_path(db_id: str) -> str:
43
- return os.path.join(DB_ROOT, db_id, f"{db_id}.sqlite")
44
-
45
-
46
- _SCHEMA_CACHE: Dict[str, str] = {}
47
-
48
-
49
- def get_db_schema_text(db_path: str) -> str:
50
- if db_path in _SCHEMA_CACHE:
51
- return _SCHEMA_CACHE[db_path]
52
- schema_text = ""
53
- try:
54
- with closing(sqlite3.connect(db_path)) as conn:
55
- cur = conn.cursor()
56
- tables = cur.execute(
57
- "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
58
- ).fetchall()
59
- for (tname,) in tables:
60
- cols = cur.execute(f'PRAGMA table_info(\"{tname}\")').fetchall()
61
- col_names = [c[1] for c in cols if c and isinstance(c[1], str)]
62
- schema_text += f"{tname}({', '.join(col_names)}) "
63
- except Exception:
64
- schema_text = ""
65
- if len(schema_text) > MAX_SCHEMA_CHARS:
66
- schema_text = schema_text[:MAX_SCHEMA_CHARS]
67
- _SCHEMA_CACHE[db_path] = schema_text
68
- return schema_text
69
-
70
-
71
- def encode_prompt(tokenizer, question: str, schema: str) -> torch.Tensor:
72
- schema = (schema or "")[:MAX_SCHEMA_CHARS]
73
- prefix_schema = f"{PREFIX}\n\nSchema:\n"
74
- mid = "\n\nQuestion:\n"
75
- suffix = f"{question}\n\nSQL:"
76
-
77
- prefix_ids = tokenizer.encode(prefix_schema, add_special_tokens=False)
78
- schema_ids = tokenizer.encode(schema, add_special_tokens=False)
79
- mid_ids = tokenizer.encode(mid, add_special_tokens=False)
80
- suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)
81
-
82
- eos_id = tokenizer.eos_token_id
83
- max_without_eos = MAX_INPUT_TOKENS - (1 if eos_id is not None else 0)
84
-
85
- fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)
86
- if fixed_len > max_without_eos:
87
- keep = max(0, max_without_eos - (len(prefix_ids) + len(mid_ids)))
88
- suffix_ids = suffix_ids[:keep]
89
- fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)
90
-
91
- remaining_for_schema = max_without_eos - fixed_len
92
- if remaining_for_schema < 0:
93
- remaining_for_schema = 0
94
- schema_ids = schema_ids[:remaining_for_schema]
95
-
96
- ids = (prefix_ids + schema_ids + mid_ids + suffix_ids)[:max_without_eos]
97
- if eos_id is not None:
98
- ids = ids + [eos_id]
99
-
100
- return torch.tensor(ids, dtype=torch.long).to(device)
101
-
102
-
103
- def load_model_and_tokenizer():
104
- # Try loading the PPO-saved value-head model directly.
105
- try:
106
- tok = AutoTokenizer.from_pretrained(RL_DIR)
107
- mdl = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(RL_DIR).to(device)
108
- return tok, mdl
109
- except Exception:
110
- pass
111
-
112
- # Fallback: treat RL_DIR as a LoRA adapter directory.
113
- tok = AutoTokenizer.from_pretrained(BASE_MODEL)
114
- if tok.pad_token_id is None:
115
- tok.pad_token = tok.eos_token
116
- base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
117
- try:
118
- base = PeftModel.from_pretrained(base, RL_DIR)
119
- except Exception:
120
- # Final fallback: use SFT adapter (if RL adapter not found)
121
- sft_dir = os.path.join(PROJECT_ROOT, "checkpoints", "sft_adapter")
122
- base = PeftModel.from_pretrained(base, sft_dir)
123
- return tok, base
124
-
125
-
126
- def main() -> None:
127
- tokenizer, model = load_model_and_tokenizer()
128
- model.eval()
129
-
130
- ds = load_dataset("spider", split=SPLIT)
131
-
132
- correct = 0
133
- valid = 0
134
-
135
- for i, ex in enumerate(ds, start=1):
136
- question = ex["question"]
137
- gold_sql = ex["query"]
138
- db_id = ex["db_id"]
139
- db_path = get_db_path(db_id)
140
- schema = get_db_schema_text(db_path)
141
-
142
- inp = encode_prompt(tokenizer, question, schema)
143
- with torch.no_grad():
144
- out = model.generate(
145
- input_ids=inp.unsqueeze(0),
146
- max_new_tokens=MAX_NEW_TOKENS,
147
- do_sample=False,
148
- num_beams=1,
149
- pad_token_id=tokenizer.pad_token_id,
150
- eos_token_id=tokenizer.eos_token_id,
151
- )
152
- pred_sql = tokenizer.decode(out[0], skip_special_tokens=True)
153
- r = execution_reward(pred_sql, db_path, gold_sql)
154
- if r > -1.0:
155
- valid += 1
156
- if r >= 1.0:
157
- correct += 1
158
-
159
- if i % 25 == 0:
160
- print(f"Evaluated {i}/{len(ds)}")
161
-
162
- n = len(ds)
163
- print("\nRESULTS")
164
- print(f"examples: {n}")
165
- print(f"execution_accuracy: {correct/n:.3f}")
166
- print(f"valid_sql_rate: {valid/n:.3f}")
167
-
168
-
169
- if __name__ == "__main__":
170
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/plot_task2.py DELETED
@@ -1,58 +0,0 @@
1
- import matplotlib.pyplot as plt
2
- import seaborn as sns
3
-
4
- # ==========================================
5
- # 1. EXTRACTED DATA FROM TERMINAL
6
- # ==========================================
7
- # Error Distribution Data
8
- error_types = ['wrong_column', 'wrong_table', 'ambiguous_column', 'other']
9
- error_counts = [61, 11, 4, 1]
10
-
11
- # SQL Operation Analysis Data
12
- sql_ops = ['WHERE', 'JOIN', 'ORDER BY', 'GROUP BY']
13
- op_counts = [55, 36, 20, 14]
14
-
15
- # ==========================================
16
- # 2. SET UP THE DASHBOARD LAYOUT
17
- # ==========================================
18
- # Use a clean, modern aesthetic
19
- sns.set_theme(style="whitegrid")
20
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
21
-
22
- # ==========================================
23
- # 3. PLOT 1: ERROR DISTRIBUTION (Horizontal Bar)
24
- # ==========================================
25
- sns.barplot(x=error_counts, y=error_types, ax=ax1, palette="flare")
26
- ax1.set_title('Primary Cause of Failure (Total: 77 Errors)', fontsize=14, pad=15, fontweight='bold')
27
- ax1.set_xlabel('Number of Queries')
28
- ax1.set_ylabel('')
29
-
30
- # Add actual numbers next to the bars
31
- for i, v in enumerate(error_counts):
32
- ax1.text(v + 1.5, i, f"{v}", color='#333333', va='center', fontweight='bold')
33
-
34
- # ==========================================
35
- # 4. PLOT 2: SQL OPERATIONS (Vertical Bar)
36
- # ==========================================
37
- sns.barplot(x=sql_ops, y=op_counts, ax=ax2, palette="crest")
38
- ax2.set_title('Clauses Present in Failed Queries', fontsize=14, pad=15, fontweight='bold')
39
- ax2.set_ylabel('Frequency')
40
- ax2.set_xlabel('')
41
-
42
- # Add actual numbers on top of the bars
43
- for i, v in enumerate(op_counts):
44
- ax2.text(i, v + 1, str(v), color='#333333', ha='center', fontweight='bold')
45
-
46
- # ==========================================
47
- # 5. RENDER AND SAVE
48
- # ==========================================
49
- plt.suptitle('Text-to-SQL Error Diagnostic Dashboard', fontsize=18, fontweight='heavy', y=1.05)
50
- sns.despine(left=True, bottom=True) # Removes clunky borders
51
- plt.tight_layout()
52
-
53
- # Save the plot as a high-res image for your report!
54
- plt.savefig('error_diagnostic_plot.png', dpi=300, bbox_inches='tight')
55
- print("✅ Plot successfully saved as 'error_diagnostic_plot.png'")
56
-
57
- # Display the plot
58
- plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/plot_task3.py DELETED
@@ -1,15 +0,0 @@
1
- import matplotlib.pyplot as plt
2
-
3
- labels = ["Without", "With"]
4
- constraint = [0, 88]
5
-
6
- plt.figure()
7
- plt.bar(labels, constraint)
8
-
9
- plt.title("Constraint Satisfaction (Task 3)")
10
- plt.ylabel("Percentage")
11
-
12
- plt.savefig("task3_constraint.png")
13
- plt.show()
14
-
15
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/plot_task3_plotly.py DELETED
@@ -1,103 +0,0 @@
1
- import plotly.graph_objects as go
2
- from plotly.subplots import make_subplots
3
-
4
- # ==========================================
5
- # 1. YOUR DATA
6
- # ==========================================
7
- models = ['FP32 (Base)', 'INT8 Dynamic', 'INT8 Decoder-Only']
8
-
9
- # Accuracy (multiplied by 100 for percentage)
10
- accuracy = [36.0, 36.0, 38.0]
11
-
12
- # Latency metrics
13
- lat_mean = [3.11, 1.65, 1.66]
14
- lat_p50 = [2.94, 1.54, 1.56]
15
- lat_p90 = [4.64, 2.44, 2.48]
16
-
17
- # ==========================================
18
- # 2. SET UP THE SIDE-BY-SIDE LAYOUT
19
- # ==========================================
20
- fig = make_subplots(
21
- rows=1, cols=2,
22
- subplot_titles=(
23
- "<b>Model Accuracy (Execution)</b>",
24
- "<b>Inference Latency Profile</b>"
25
- ),
26
- horizontal_spacing=0.1
27
- )
28
-
29
- # ==========================================
30
- # 3. LEFT CHART: ACCURACY
31
- # ==========================================
32
- fig.add_trace(go.Bar(
33
- x=models,
34
- y=accuracy,
35
- name="Execution Accuracy",
36
- marker_color=['#94a3b8', '#38bdf8', '#10b981'], # Gray, Blue, Green
37
- text=[f"{val:.1f}%" for val in accuracy],
38
- textposition='auto',
39
- textfont=dict(size=14, color='white', family="Arial Black"),
40
- showlegend=False
41
- ), row=1, col=1)
42
-
43
- # ==========================================
44
- # 4. RIGHT CHART: LATENCY PROFILE
45
- # ==========================================
46
- # P50 Latency
47
- fig.add_trace(go.Bar(
48
- x=models, y=lat_p50,
49
- name="Median (P50)",
50
- marker_color="#ece80a" # Light Blue
51
- ), row=1, col=2)
52
-
53
- # Mean Latency
54
- fig.add_trace(go.Bar(
55
- x=models, y=lat_mean,
56
- name="Mean Latency",
57
- marker_color="#3b4da9" # Standard Blue
58
- ), row=1, col=2)
59
-
60
- # P90 Latency
61
- fig.add_trace(go.Bar(
62
- x=models, y=lat_p90,
63
- name="90th Percentile (P90)",
64
- marker_color="#d974e2" # Dark Blue
65
- ), row=1, col=2)
66
-
67
- # ==========================================
68
- # 5. APPLY ULTRA-MODERN STYLING
69
- # ==========================================
70
- fig.update_layout(
71
- title=dict(
72
- text="<b>Task 5: FP32 vs. INT8 Quantization Performance</b>",
73
- font=dict(size=22, color='#1e293b'),
74
- x=0.5
75
- ),
76
- plot_bgcolor='white',
77
- paper_bgcolor='white',
78
- barmode='group',
79
- legend=dict(
80
- orientation="h",
81
- yanchor="bottom", y=1.05,
82
- xanchor="center", x=0.8,
83
- bgcolor='rgba(255,255,255,0.8)'
84
- ),
85
- font=dict(family="-apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif"),
86
- margin=dict(t=120, b=60, l=60, r=40)
87
- )
88
-
89
- # Style Left Axes
90
- fig.update_yaxes(title_text="<b>Accuracy (%)</b>", range=[0, 45], gridcolor='#f1f5f9', row=1, col=1)
91
- fig.update_xaxes(tickfont=dict(weight='bold'), row=1, col=1)
92
-
93
- # Style Right Axes
94
- fig.update_yaxes(title_text="<b>Seconds per Query</b>", gridcolor='#f1f5f9', row=1, col=2)
95
- fig.update_xaxes(tickfont=dict(weight='bold'), row=1, col=2)
96
-
97
- # ==========================================
98
- # 6. RENDER AND SAVE
99
- # ==========================================
100
- html_file = "task5_quantization_dashboard.html"
101
- fig.write_html(html_file)
102
- print(f"✅ Interactive Plotly Dashboard saved to: {html_file}")
103
- fig.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/quantize_export.py DELETED
@@ -1,86 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import argparse
4
- import os
5
- from pathlib import Path
6
-
7
- import torch
8
-
9
- from src.quantization_utils import (
10
- load_bnb_quantized_model,
11
- load_fp32_model,
12
- quantize_dynamic_int8,
13
- quantize_dynamic_int8_decoder_only,
14
- save_quant_artifact,
15
- )
16
-
17
-
18
- def main() -> None:
19
- p = argparse.ArgumentParser(description="Export quantized Seq2Seq model artifacts for CPU inference.")
20
- p.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
21
- p.add_argument("--adapter", default="", help="Optional LoRA adapter directory.")
22
- p.add_argument("--out_dir", required=True, help="Output directory for artifact.")
23
- p.add_argument(
24
- "--mode",
25
- required=True,
26
- choices=["fp32", "int8_dynamic", "int8_decoder_dynamic", "int8_bnb", "int4_bnb"],
27
- )
28
- p.add_argument("--device", default="cpu", help="cpu|cuda (bnb requires cuda)")
29
- p.add_argument("--local_only", action="store_true", help="Do not hit network; use HF cache only.")
30
- args = p.parse_args()
31
-
32
- adapter = args.adapter.strip() or None
33
- out_dir = Path(args.out_dir)
34
-
35
- if args.mode == "fp32":
36
- tok, model = load_fp32_model(args.base_model, adapter_path=adapter, device=args.device, local_only=args.local_only)
37
- save_quant_artifact(out_dir, mode="fp32", base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)
38
- return
39
-
40
- if args.mode == "int8_dynamic":
41
- tok, model = load_fp32_model(args.base_model, adapter_path=adapter, device="cpu", local_only=args.local_only)
42
- model = quantize_dynamic_int8(model)
43
- save_quant_artifact(out_dir, mode="int8_dynamic", base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)
44
- return
45
-
46
- if args.mode == "int8_decoder_dynamic":
47
- tok, model = load_fp32_model(args.base_model, adapter_path=adapter, device="cpu", local_only=args.local_only)
48
- model = quantize_dynamic_int8_decoder_only(model)
49
- save_quant_artifact(
50
- out_dir,
51
- mode="int8_decoder_dynamic",
52
- base_model=args.base_model,
53
- adapter_path=adapter,
54
- tokenizer=tok,
55
- model=model,
56
- )
57
- return
58
-
59
- if args.mode == "int8_bnb":
60
- tok, model = load_bnb_quantized_model(
61
- args.base_model,
62
- adapter_path=adapter,
63
- device=args.device,
64
- local_only=args.local_only,
65
- load_in_8bit=True,
66
- )
67
- # Note: saving bnb quantized weights in a portable way is non-trivial; we still save state_dict for reference.
68
- save_quant_artifact(out_dir, mode="int8_bnb", base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)
69
- return
70
-
71
- if args.mode == "int4_bnb":
72
- tok, model = load_bnb_quantized_model(
73
- args.base_model,
74
- adapter_path=adapter,
75
- device=args.device,
76
- local_only=args.local_only,
77
- load_in_4bit=True,
78
- )
79
- save_quant_artifact(out_dir, mode="int4_bnb", base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)
80
- return
81
-
82
-
83
- if __name__ == "__main__":
84
- torch.set_grad_enabled(False)
85
- main()
86
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/quantized_infer_harness.py DELETED
@@ -1,46 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import argparse
4
- import json
5
- import time
6
- from pathlib import Path
7
-
8
- from src.quantized_text2sql_engine import QuantizedText2SQLEngine
9
-
10
-
11
- def main() -> None:
12
- p = argparse.ArgumentParser(description="Production-style inference harness for quantized artifacts.")
13
- p.add_argument("--artifact", required=True, help="Quant artifact dir from scripts/quantize_export.py")
14
- p.add_argument("--num_samples", type=int, default=128)
15
- p.add_argument("--out", default="results/task5_quant_infer.json")
16
- args = p.parse_args()
17
-
18
- root = Path(".")
19
- dev = json.loads((root / "data" / "dev.json").read_text())
20
- dev = dev[: args.num_samples]
21
-
22
- engine = QuantizedText2SQLEngine(args.artifact, device="cpu")
23
- pairs = [(x["question"], x["db_id"]) for x in dev]
24
-
25
- t0 = time.perf_counter()
26
- results = engine.ask_batch_execute(pairs)
27
- dt = time.perf_counter() - t0
28
-
29
- out = {
30
- "n": len(results),
31
- "seconds": dt,
32
- "qps": len(results) / max(dt, 1e-9),
33
- "artifact": args.artifact,
34
- "meta": engine.meta,
35
- "results": results[:10], # sample
36
- }
37
-
38
- out_path = Path(args.out)
39
- out_path.parent.mkdir(parents=True, exist_ok=True)
40
- out_path.write_text(json.dumps(out, indent=2))
41
- print(json.dumps(out, indent=2))
42
-
43
-
44
- if __name__ == "__main__":
45
- main()
46
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/constrained_decoding.py DELETED
@@ -1,1058 +0,0 @@
1
- # from __future__ import annotations
2
-
3
- # import re
4
- # import threading
5
- # from dataclasses import dataclass
6
- # from typing import Dict, Iterable, List, Optional, Sequence, Set
7
-
8
- # import torch
9
- # from transformers.generation.logits_process import LogitsProcessor
10
-
11
- # from schema_constraints import ConstraintGraph, build_constraint_graph
12
-
13
-
14
- # def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
15
- # s = re.sub(r"\s+", " ", prefix_text.lower())
16
- # last_from = s.rfind(" from ")
17
- # last_join = s.rfind(" join ")
18
- # last_select = s.rfind(" select ")
19
- # last_where = s.rfind(" where ")
20
- # last_on = s.rfind(" on ")
21
- # last_group = s.rfind(" group by ")
22
- # last_order = s.rfind(" order by ")
23
- # last_having = s.rfind(" having ")
24
-
25
- # last_table_kw = max(last_from, last_join)
26
- # last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
27
-
28
- # if last_table_kw < 0 and last_col_kw < 0:
29
- # return None
30
- # if last_table_kw > last_col_kw:
31
- # return "table"
32
- # if last_col_kw > last_table_kw:
33
- # return "column"
34
- # return None
35
-
36
-
37
- # class _TrieNode:
38
- # __slots__ = ("children", "terminal")
39
-
40
- # def __init__(self) -> None:
41
- # self.children: Dict[int, _TrieNode] = {}
42
- # self.terminal: bool = False
43
-
44
- # def insert(self, token_ids: Sequence[int]) -> None:
45
- # node: _TrieNode = self
46
- # for tid in token_ids:
47
- # tid_i = int(tid)
48
- # nxt = node.children.get(tid_i)
49
- # if nxt is None:
50
- # nxt = _TrieNode()
51
- # node.children[tid_i] = nxt
52
- # node = nxt
53
- # node.terminal = True
54
-
55
- # def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
56
- # node: _TrieNode = self
57
- # for tid in prefix:
58
- # node = node.children.get(int(tid)) # type: ignore[assignment]
59
- # if node is None:
60
- # return None
61
- # return node
62
-
63
-
64
- # def _encode_identifier(tokenizer, name: str) -> List[int]:
65
- # # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
66
- # return tokenizer.encode(" " + name, add_special_tokens=False)
67
-
68
-
69
- # def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
70
- # trie = _TrieNode()
71
- # for n in names:
72
- # if not n:
73
- # continue
74
- # try:
75
- # ids = _encode_identifier(tokenizer, n)
76
- # except Exception:
77
- # continue
78
- # if ids:
79
- # trie.insert(ids)
80
- # return trie
81
-
82
-
83
- # def _allow_always_token_ids(tokenizer) -> torch.Tensor:
84
- # # Allow common delimiters so the model can end an identifier.
85
- # toks = [",", ")", "(", "\n", ".", ";"]
86
- # ids: Set[int] = set()
87
- # for t in toks:
88
- # try:
89
- # for tid in tokenizer.encode(t, add_special_tokens=False):
90
- # ids.add(int(tid))
91
- # except Exception:
92
- # continue
93
- # return torch.tensor(sorted(ids), dtype=torch.long)
94
-
95
-
96
- # @dataclass
97
- # class _PerDbTokenSets:
98
- # fp: str
99
- # table_trie: _TrieNode
100
- # column_trie: _TrieNode
101
- # allow_always: torch.Tensor
102
-
103
-
104
- # _DB_TOKENSET_LOCK = threading.Lock()
105
- # _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
106
-
107
-
108
- # def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
109
- # with _DB_TOKENSET_LOCK:
110
- # cached = _DB_TOKENSETS.get(graph.db_path)
111
- # if cached is not None and cached.fp == graph.fingerprint:
112
- # return cached
113
-
114
- # out = _PerDbTokenSets(
115
- # fp=graph.fingerprint,
116
- # table_trie=_build_trie(tokenizer, graph.tables),
117
- # column_trie=_build_trie(tokenizer, graph.all_columns),
118
- # allow_always=_allow_always_token_ids(tokenizer),
119
- # )
120
- # with _DB_TOKENSET_LOCK:
121
- # _DB_TOKENSETS[graph.db_path] = out
122
- # return out
123
-
124
-
125
- # class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
126
- # """
127
- # Schema-aware constrained decoding per item in the generation batch.
128
- # Uses a tokenizer-based trie so multi-token identifiers can be constrained.
129
- # """
130
-
131
- # def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
132
- # self.tokenizer = tokenizer
133
- # self.db_paths = list(db_paths)
134
- # self.max_prefix_tokens = int(max_prefix_tokens)
135
-
136
- # self._graphs = [build_constraint_graph(p) for p in self.db_paths]
137
- # self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
138
-
139
- # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
140
- # if input_ids.dim() != 2 or scores.dim() != 2:
141
- # return scores
142
-
143
- # batch = input_ids.size(0)
144
- # if batch != len(self._graphs):
145
- # return scores
146
-
147
- # for i in range(batch):
148
- # tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
149
- # prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
150
- # expected = _infer_expected_identifier(prefix_text)
151
- # if expected is None:
152
- # continue
153
-
154
- # if expected == "table":
155
- # m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
156
- # partial = m.group(1) if m else None
157
- # if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
158
- # continue
159
- # trie = self._token_sets[i].table_trie
160
- # else:
161
- # m = re.search(
162
- # r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
163
- # prefix_text,
164
- # flags=re.I,
165
- # )
166
- # partial = m.group(1) if m else None
167
- # if partial is None and not re.search(
168
- # r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
169
- # ):
170
- # continue
171
- # trie = self._token_sets[i].column_trie
172
-
173
- # if not partial:
174
- # prefix_token_ids: List[int] = []
175
- # else:
176
- # try:
177
- # prefix_token_ids = _encode_identifier(self.tokenizer, partial)
178
- # except Exception:
179
- # continue
180
-
181
- # node = trie.walk(prefix_token_ids)
182
- # if node is None or node.terminal:
183
- # continue
184
-
185
- # allowed_next = sorted(node.children.keys())
186
- # if not allowed_next:
187
- # continue
188
-
189
- # allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
190
- # allow_always = self._token_sets[i].allow_always.to(scores.device)
191
- # keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
192
-
193
- # kept_scores = scores[i, keep].clone()
194
- # scores[i, :] = -float("inf")
195
- # scores[i, keep] = kept_scores
196
-
197
- # return scores
198
-
199
-
200
- # # Backwards-compatible names used elsewhere in the repo.
201
- # class SchemaConstraintGraph:
202
- # def __init__(self, db_path: str):
203
- # self._graph = build_constraint_graph(db_path)
204
- # self.tables = sorted(self._graph.tables)
205
- # self.columns = sorted(self._graph.all_columns)
206
-
207
-
208
- # class SchemaConstrainedLogitsProcessor(LogitsProcessor):
209
- # def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
210
- # self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
211
-
212
- # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
213
- # return self._proc(input_ids, scores)
214
-
215
-
216
-
217
-
218
- # from __future__ import annotations
219
-
220
- # import re
221
- # import threading
222
- # from dataclasses import dataclass
223
- # from typing import Dict, Iterable, List, Optional, Sequence, Set
224
-
225
- # import torch
226
- # from transformers.generation.logits_process import LogitsProcessor
227
-
228
- # from schema_constraints import ConstraintGraph, build_constraint_graph
229
-
230
-
231
- # # =========================================================
232
- # # 🔍 IDENTIFIER TYPE DETECTION
233
- # # =========================================================
234
- # def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
235
- # s = re.sub(r"\s+", " ", prefix_text.lower())
236
-
237
- # last_from = s.rfind(" from ")
238
- # last_join = s.rfind(" join ")
239
- # last_select = s.rfind(" select ")
240
- # last_where = s.rfind(" where ")
241
- # last_on = s.rfind(" on ")
242
- # last_group = s.rfind(" group by ")
243
- # last_order = s.rfind(" order by ")
244
- # last_having = s.rfind(" having ")
245
-
246
- # last_table_kw = max(last_from, last_join)
247
- # last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
248
-
249
- # if last_table_kw < 0 and last_col_kw < 0:
250
- # return None
251
- # if last_table_kw > last_col_kw:
252
- # return "table"
253
- # if last_col_kw > last_table_kw:
254
- # return "column"
255
- # return None
256
-
257
-
258
- # # =========================================================
259
- # # 🌳 TRIE STRUCTURE
260
- # # =========================================================
261
- # class _TrieNode:
262
- # __slots__ = ("children", "terminal")
263
-
264
- # def __init__(self) -> None:
265
- # self.children: Dict[int, _TrieNode] = {}
266
- # self.terminal: bool = False
267
-
268
- # def insert(self, token_ids: Sequence[int]) -> None:
269
- # node = self
270
- # for tid in token_ids:
271
- # tid = int(tid)
272
- # if tid not in node.children:
273
- # node.children[tid] = _TrieNode()
274
- # node = node.children[tid]
275
- # node.terminal = True
276
-
277
- # def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
278
- # node = self
279
- # for tid in prefix:
280
- # node = node.children.get(int(tid))
281
- # if node is None:
282
- # return None
283
- # return node
284
-
285
-
286
- # # =========================================================
287
- # # 🔤 TOKEN ENCODING
288
- # # =========================================================
289
- # def _encode_identifier(tokenizer, name: str) -> List[int]:
290
- # return tokenizer.encode(" " + name, add_special_tokens=False)
291
-
292
-
293
- # def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
294
- # trie = _TrieNode()
295
- # for name in names:
296
- # try:
297
- # ids = _encode_identifier(tokenizer, name)
298
- # if ids:
299
- # trie.insert(ids)
300
- # except Exception:
301
- # continue
302
- # return trie
303
-
304
-
305
- # def _allow_always_token_ids(tokenizer) -> torch.Tensor:
306
- # tokens = [",", ")", "(", ".", ";", "\n"]
307
- # ids: Set[int] = set()
308
-
309
- # for t in tokens:
310
- # try:
311
- # ids.update(tokenizer.encode(t, add_special_tokens=False))
312
- # except:
313
- # pass
314
-
315
- # return torch.tensor(sorted(ids), dtype=torch.long)
316
-
317
-
318
- # # =========================================================
319
- # # 📦 PER-DB CACHE
320
- # # =========================================================
321
- # @dataclass
322
- # class _PerDbTokenSets:
323
- # fp: str
324
- # table_trie: _TrieNode
325
- # column_trie: _TrieNode
326
- # allow_always: torch.Tensor
327
-
328
-
329
- # _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
330
- # _DB_LOCK = threading.Lock()
331
-
332
-
333
- # def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
334
- # with _DB_LOCK:
335
- # cached = _DB_CACHE.get(graph.db_path)
336
- # if cached and cached.fp == graph.fingerprint:
337
- # return cached
338
-
339
- # obj = _PerDbTokenSets(
340
- # fp=graph.fingerprint,
341
- # table_trie=_build_trie(tokenizer, graph.tables),
342
- # column_trie=_build_trie(tokenizer, graph.all_columns),
343
- # allow_always=_allow_always_token_ids(tokenizer),
344
- # )
345
-
346
- # with _DB_LOCK:
347
- # _DB_CACHE[graph.db_path] = obj
348
-
349
- # return obj
350
-
351
-
352
- # # =========================================================
353
- # # 🚀 MAIN LOGITS PROCESSOR
354
- # # =========================================================
355
- # class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
356
- # def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
357
- # self.tokenizer = tokenizer
358
- # self.db_paths = list(db_paths)
359
- # self.max_prefix_tokens = max_prefix_tokens
360
-
361
- # self._graphs = [build_constraint_graph(p) for p in db_paths]
362
- # self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
363
-
364
- # # 📊 Metrics (IMPORTANT FOR REPORT)
365
- # self.total_steps = 0
366
- # self.constrained_steps = 0
367
-
368
- # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
369
- # batch = input_ids.size(0)
370
-
371
- # for i in range(batch):
372
- # self.total_steps += 1
373
-
374
- # tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
375
- # prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
376
-
377
- # expected = _infer_expected_identifier(prefix_text)
378
- # if expected is None:
379
- # continue
380
-
381
- # self.constrained_steps += 1
382
-
383
- # # =========================
384
- # # SELECT TRIE
385
- # # =========================
386
- # if expected == "table":
387
- # trie = self._token_sets[i].table_trie
388
- # else:
389
- # trie = self._token_sets[i].column_trie
390
-
391
- # # =========================
392
- # # PARTIAL TOKEN MATCH
393
- # # =========================
394
- # match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
395
- # partial = match.group(1) if match else ""
396
-
397
- # try:
398
- # prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
399
- # except:
400
- # continue
401
-
402
- # node = trie.walk(prefix_ids)
403
- # if node is None or node.terminal:
404
- # continue
405
-
406
- # allowed_next = list(node.children.keys())
407
- # if not allowed_next:
408
- # continue
409
-
410
- # allowed_next = torch.tensor(allowed_next, device=scores.device)
411
- # allow_always = self._token_sets[i].allow_always.to(scores.device)
412
-
413
- # keep = torch.cat([allowed_next, allow_always])
414
-
415
- # kept_scores = scores[i, keep].clone()
416
- # scores[i, :] = -float("inf")
417
- # scores[i, keep] = kept_scores
418
-
419
- # return scores
420
-
421
- # # =========================================================
422
- # # 📊 METRICS FOR REPORT
423
- # # =========================================================
424
- # def get_constraint_stats(self):
425
- # if self.total_steps == 0:
426
- # return 0
427
- # return self.constrained_steps / self.total_steps
428
-
429
-
430
- # # =========================================================
431
- # # 🔁 BACKWARD COMPATIBILITY
432
- # # =========================================================
433
- # class SchemaConstraintGraph:
434
- # def __init__(self, db_path: str):
435
- # self._graph = build_constraint_graph(db_path)
436
- # self.tables = sorted(self._graph.tables)
437
- # self.columns = sorted(self._graph.all_columns)
438
-
439
-
440
- # class SchemaConstrainedLogitsProcessor(LogitsProcessor):
441
- # def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
442
- # self.proc = BatchSchemaConstrainedLogitsProcessor(
443
- # tokenizer, [schema_graph._graph.db_path]
444
- # )
445
-
446
- # def __call__(self, input_ids, scores):
447
- # return self.proc(input_ids, scores)
448
-
449
-
450
-
451
-
452
-
453
-
454
- # from __future__ import annotations
455
-
456
- # import re
457
- # import threading
458
- # from dataclasses import dataclass
459
- # from typing import Dict, Iterable, List, Optional, Sequence, Set
460
-
461
- # import torch
462
- # from transformers.generation.logits_process import LogitsProcessor
463
-
464
- # from schema_constraints import ConstraintGraph, build_constraint_graph
465
-
466
-
467
- # def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
468
- # s = re.sub(r"\s+", " ", prefix_text.lower())
469
- # last_from = s.rfind(" from ")
470
- # last_join = s.rfind(" join ")
471
- # last_select = s.rfind(" select ")
472
- # last_where = s.rfind(" where ")
473
- # last_on = s.rfind(" on ")
474
- # last_group = s.rfind(" group by ")
475
- # last_order = s.rfind(" order by ")
476
- # last_having = s.rfind(" having ")
477
-
478
- # last_table_kw = max(last_from, last_join)
479
- # last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
480
-
481
- # if last_table_kw < 0 and last_col_kw < 0:
482
- # return None
483
- # if last_table_kw > last_col_kw:
484
- # return "table"
485
- # if last_col_kw > last_table_kw:
486
- # return "column"
487
- # return None
488
-
489
-
490
- # class _TrieNode:
491
- # __slots__ = ("children", "terminal")
492
-
493
- # def __init__(self) -> None:
494
- # self.children: Dict[int, _TrieNode] = {}
495
- # self.terminal: bool = False
496
-
497
- # def insert(self, token_ids: Sequence[int]) -> None:
498
- # node: _TrieNode = self
499
- # for tid in token_ids:
500
- # tid_i = int(tid)
501
- # nxt = node.children.get(tid_i)
502
- # if nxt is None:
503
- # nxt = _TrieNode()
504
- # node.children[tid_i] = nxt
505
- # node = nxt
506
- # node.terminal = True
507
-
508
- # def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
509
- # node: _TrieNode = self
510
- # for tid in prefix:
511
- # node = node.children.get(int(tid)) # type: ignore[assignment]
512
- # if node is None:
513
- # return None
514
- # return node
515
-
516
-
517
- # def _encode_identifier(tokenizer, name: str) -> List[int]:
518
- # # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
519
- # return tokenizer.encode(" " + name, add_special_tokens=False)
520
-
521
-
522
- # def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
523
- # trie = _TrieNode()
524
- # for n in names:
525
- # if not n:
526
- # continue
527
- # try:
528
- # ids = _encode_identifier(tokenizer, n)
529
- # except Exception:
530
- # continue
531
- # if ids:
532
- # trie.insert(ids)
533
- # return trie
534
-
535
-
536
- # def _allow_always_token_ids(tokenizer) -> torch.Tensor:
537
- # # Allow common delimiters so the model can end an identifier.
538
- # toks = [",", ")", "(", "\n", ".", ";"]
539
- # ids: Set[int] = set()
540
- # for t in toks:
541
- # try:
542
- # for tid in tokenizer.encode(t, add_special_tokens=False):
543
- # ids.add(int(tid))
544
- # except Exception:
545
- # continue
546
- # return torch.tensor(sorted(ids), dtype=torch.long)
547
-
548
-
549
- # @dataclass
550
- # class _PerDbTokenSets:
551
- # fp: str
552
- # table_trie: _TrieNode
553
- # column_trie: _TrieNode
554
- # allow_always: torch.Tensor
555
-
556
-
557
- # _DB_TOKENSET_LOCK = threading.Lock()
558
- # _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
559
-
560
-
561
- # def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
562
- # with _DB_TOKENSET_LOCK:
563
- # cached = _DB_TOKENSETS.get(graph.db_path)
564
- # if cached is not None and cached.fp == graph.fingerprint:
565
- # return cached
566
-
567
- # out = _PerDbTokenSets(
568
- # fp=graph.fingerprint,
569
- # table_trie=_build_trie(tokenizer, graph.tables),
570
- # column_trie=_build_trie(tokenizer, graph.all_columns),
571
- # allow_always=_allow_always_token_ids(tokenizer),
572
- # )
573
- # with _DB_TOKENSET_LOCK:
574
- # _DB_TOKENSETS[graph.db_path] = out
575
- # return out
576
-
577
-
578
- # class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
579
- # """
580
- # Schema-aware constrained decoding per item in the generation batch.
581
- # Uses a tokenizer-based trie so multi-token identifiers can be constrained.
582
- # """
583
-
584
- # def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
585
- # self.tokenizer = tokenizer
586
- # self.db_paths = list(db_paths)
587
- # self.max_prefix_tokens = int(max_prefix_tokens)
588
-
589
- # self._graphs = [build_constraint_graph(p) for p in self.db_paths]
590
- # self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
591
-
592
- # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
593
- # if input_ids.dim() != 2 or scores.dim() != 2:
594
- # return scores
595
-
596
- # batch = input_ids.size(0)
597
- # if batch != len(self._graphs):
598
- # return scores
599
-
600
- # for i in range(batch):
601
- # tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
602
- # prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
603
- # expected = _infer_expected_identifier(prefix_text)
604
- # if expected is None:
605
- # continue
606
-
607
- # if expected == "table":
608
- # m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
609
- # partial = m.group(1) if m else None
610
- # if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
611
- # continue
612
- # trie = self._token_sets[i].table_trie
613
- # else:
614
- # m = re.search(
615
- # r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
616
- # prefix_text,
617
- # flags=re.I,
618
- # )
619
- # partial = m.group(1) if m else None
620
- # if partial is None and not re.search(
621
- # r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
622
- # ):
623
- # continue
624
- # trie = self._token_sets[i].column_trie
625
-
626
- # if not partial:
627
- # prefix_token_ids: List[int] = []
628
- # else:
629
- # try:
630
- # prefix_token_ids = _encode_identifier(self.tokenizer, partial)
631
- # except Exception:
632
- # continue
633
-
634
- # node = trie.walk(prefix_token_ids)
635
- # if node is None or node.terminal:
636
- # continue
637
-
638
- # allowed_next = sorted(node.children.keys())
639
- # if not allowed_next:
640
- # continue
641
-
642
- # allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
643
- # allow_always = self._token_sets[i].allow_always.to(scores.device)
644
- # keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
645
-
646
- # kept_scores = scores[i, keep].clone()
647
- # scores[i, :] = -float("inf")
648
- # scores[i, keep] = kept_scores
649
-
650
- # return scores
651
-
652
-
653
- # # Backwards-compatible names used elsewhere in the repo.
654
- # class SchemaConstraintGraph:
655
- # def __init__(self, db_path: str):
656
- # self._graph = build_constraint_graph(db_path)
657
- # self.tables = sorted(self._graph.tables)
658
- # self.columns = sorted(self._graph.all_columns)
659
-
660
-
661
- # class SchemaConstrainedLogitsProcessor(LogitsProcessor):
662
- # def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
663
- # self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
664
-
665
- # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
666
- # return self._proc(input_ids, scores)
667
-
668
-
669
-
670
-
671
- # from __future__ import annotations
672
-
673
- # import re
674
- # import threading
675
- # from dataclasses import dataclass
676
- # from typing import Dict, Iterable, List, Optional, Sequence, Set
677
-
678
- # import torch
679
- # from transformers.generation.logits_process import LogitsProcessor
680
-
681
- # from schema_constraints import ConstraintGraph, build_constraint_graph
682
-
683
-
684
- # # =========================================================
685
- # # 🔍 IDENTIFIER TYPE DETECTION
686
- # # =========================================================
687
- # def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
688
- # s = re.sub(r"\s+", " ", prefix_text.lower())
689
-
690
- # last_from = s.rfind(" from ")
691
- # last_join = s.rfind(" join ")
692
- # last_select = s.rfind(" select ")
693
- # last_where = s.rfind(" where ")
694
- # last_on = s.rfind(" on ")
695
- # last_group = s.rfind(" group by ")
696
- # last_order = s.rfind(" order by ")
697
- # last_having = s.rfind(" having ")
698
-
699
- # last_table_kw = max(last_from, last_join)
700
- # last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
701
-
702
- # if last_table_kw < 0 and last_col_kw < 0:
703
- # return None
704
- # if last_table_kw > last_col_kw:
705
- # return "table"
706
- # if last_col_kw > last_table_kw:
707
- # return "column"
708
- # return None
709
-
710
-
711
- # # =========================================================
712
- # # 🌳 TRIE STRUCTURE
713
- # # =========================================================
714
- # class _TrieNode:
715
- # __slots__ = ("children", "terminal")
716
-
717
- # def __init__(self) -> None:
718
- # self.children: Dict[int, _TrieNode] = {}
719
- # self.terminal: bool = False
720
-
721
- # def insert(self, token_ids: Sequence[int]) -> None:
722
- # node = self
723
- # for tid in token_ids:
724
- # tid = int(tid)
725
- # if tid not in node.children:
726
- # node.children[tid] = _TrieNode()
727
- # node = node.children[tid]
728
- # node.terminal = True
729
-
730
- # def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
731
- # node = self
732
- # for tid in prefix:
733
- # node = node.children.get(int(tid))
734
- # if node is None:
735
- # return None
736
- # return node
737
-
738
-
739
- # # =========================================================
740
- # # 🔤 TOKEN ENCODING
741
- # # =========================================================
742
- # def _encode_identifier(tokenizer, name: str) -> List[int]:
743
- # return tokenizer.encode(" " + name, add_special_tokens=False)
744
-
745
-
746
- # def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
747
- # trie = _TrieNode()
748
- # for name in names:
749
- # try:
750
- # ids = _encode_identifier(tokenizer, name)
751
- # if ids:
752
- # trie.insert(ids)
753
- # except Exception:
754
- # continue
755
- # return trie
756
-
757
-
758
- # def _allow_always_token_ids(tokenizer) -> torch.Tensor:
759
- # tokens = [",", ")", "(", ".", ";", "\n"]
760
- # ids: Set[int] = set()
761
-
762
- # for t in tokens:
763
- # try:
764
- # ids.update(tokenizer.encode(t, add_special_tokens=False))
765
- # except:
766
- # pass
767
-
768
- # return torch.tensor(sorted(ids), dtype=torch.long)
769
-
770
-
771
- # # =========================================================
772
- # # 📦 PER-DB CACHE
773
- # # =========================================================
774
- # @dataclass
775
- # class _PerDbTokenSets:
776
- # fp: str
777
- # table_trie: _TrieNode
778
- # column_trie: _TrieNode
779
- # allow_always: torch.Tensor
780
-
781
-
782
- # _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
783
- # _DB_LOCK = threading.Lock()
784
-
785
-
786
- # def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
787
- # with _DB_LOCK:
788
- # cached = _DB_CACHE.get(graph.db_path)
789
- # if cached and cached.fp == graph.fingerprint:
790
- # return cached
791
-
792
- # obj = _PerDbTokenSets(
793
- # fp=graph.fingerprint,
794
- # table_trie=_build_trie(tokenizer, graph.tables),
795
- # column_trie=_build_trie(tokenizer, graph.all_columns),
796
- # allow_always=_allow_always_token_ids(tokenizer),
797
- # )
798
-
799
- # with _DB_LOCK:
800
- # _DB_CACHE[graph.db_path] = obj
801
-
802
- # return obj
803
-
804
-
805
- # # =========================================================
806
- # # 🚀 MAIN LOGITS PROCESSOR
807
- # # =========================================================
808
- # class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
809
- # def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
810
- # self.tokenizer = tokenizer
811
- # self.db_paths = list(db_paths)
812
- # self.max_prefix_tokens = max_prefix_tokens
813
-
814
- # self._graphs = [build_constraint_graph(p) for p in db_paths]
815
- # self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
816
-
817
- # # 📊 Metrics (IMPORTANT FOR REPORT)
818
- # self.total_steps = 0
819
- # self.constrained_steps = 0
820
-
821
- # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
822
- # batch = input_ids.size(0)
823
-
824
- # for i in range(batch):
825
- # self.total_steps += 1
826
-
827
- # tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
828
- # prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
829
-
830
- # expected = _infer_expected_identifier(prefix_text)
831
- # if expected is None:
832
- # continue
833
-
834
- # self.constrained_steps += 1
835
-
836
- # # =========================
837
- # # SELECT TRIE
838
- # # =========================
839
- # if expected == "table":
840
- # trie = self._token_sets[i].table_trie
841
- # else:
842
- # trie = self._token_sets[i].column_trie
843
-
844
- # # =========================
845
- # # PARTIAL TOKEN MATCH
846
- # # =========================
847
- # match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
848
- # partial = match.group(1) if match else ""
849
-
850
- # try:
851
- # prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
852
- # except:
853
- # continue
854
-
855
- # node = trie.walk(prefix_ids)
856
- # if node is None or node.terminal:
857
- # continue
858
-
859
- # allowed_next = list(node.children.keys())
860
- # if not allowed_next:
861
- # continue
862
-
863
- # allowed_next = torch.tensor(allowed_next, device=scores.device)
864
- # allow_always = self._token_sets[i].allow_always.to(scores.device)
865
-
866
- # keep = torch.cat([allowed_next, allow_always])
867
-
868
- # kept_scores = scores[i, keep].clone()
869
- # scores[i, :] = -float("inf")
870
- # scores[i, keep] = kept_scores
871
-
872
- # return scores
873
-
874
- # # =========================================================
875
- # # 📊 METRICS FOR REPORT
876
- # # =========================================================
877
- # def get_constraint_stats(self):
878
- # if self.total_steps == 0:
879
- # return 0
880
- # return self.constrained_steps / self.total_steps
881
-
882
-
883
- # # =========================================================
884
- # # 🔁 BACKWARD COMPATIBILITY
885
- # # =========================================================
886
- # class SchemaConstraintGraph:
887
- # def __init__(self, db_path: str):
888
- # self._graph = build_constraint_graph(db_path)
889
- # self.tables = sorted(self._graph.tables)
890
- # self.columns = sorted(self._graph.all_columns)
891
-
892
-
893
- # class SchemaConstrainedLogitsProcessor(LogitsProcessor):
894
- # def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
895
- # self.proc = BatchSchemaConstrainedLogitsProcessor(
896
- # tokenizer, [schema_graph._graph.db_path]
897
- # )
898
-
899
- # def __call__(self, input_ids, scores):
900
- # return self.proc(input_ids, scores)
901
-
902
-
903
-
904
-
905
-
906
-
907
-
908
-
909
- # ********* after task 3
910
-
911
- import re
912
- import threading
913
- from functools import lru_cache
914
-
915
- import torch
916
- from transformers import LogitsProcessor
917
-
918
- from src.schema_utils import get_constraint_graph
919
-
920
-
921
- _TOKEN_CACHE_LOCK = threading.Lock()
922
- _TOKEN_ID_CACHE = {} # (id(tokenizer), db_path) -> (allowed_ids_tensor, always_allow_ids_tensor)
923
-
924
-
925
- def _encode_variants(tokenizer, text: str) -> list[int]:
926
- ids: list[int] = []
927
- for variant in (text, " " + text):
928
- try:
929
- ids.extend(tokenizer.encode(variant, add_special_tokens=False))
930
- except Exception:
931
- continue
932
- # de-dup while keeping order
933
- seen = set()
934
- out = []
935
- for i in ids:
936
- if int(i) not in seen:
937
- seen.add(int(i))
938
- out.append(int(i))
939
- return out
940
-
941
-
942
- def _always_allow_ids(tokenizer) -> list[int]:
943
- """
944
- Tokens we should never block, otherwise decoding can get stuck or generate garbage:
945
- - EOS/PAD
946
- - punctuation/operators needed for SQL formatting
947
- - digits/quotes
948
- """
949
- ids: list[int] = []
950
- for special in [getattr(tokenizer, "eos_token_id", None), getattr(tokenizer, "pad_token_id", None)]:
951
- if special is not None:
952
- ids.append(int(special))
953
-
954
- # Common SQL punctuation/operators
955
- pieces = [
956
- " ", "\n", "\t",
957
- ",", ".", "(", ")", ";",
958
- "=", "!=", "<>", "<", ">", "<=", ">=",
959
- "*", "+", "-", "/", "%",
960
- "'", '"',
961
- ]
962
- for p in pieces:
963
- ids.extend(_encode_variants(tokenizer, p))
964
-
965
- # digits
966
- for d in "0123456789":
967
- ids.extend(_encode_variants(tokenizer, d))
968
-
969
- seen = set()
970
- out = []
971
- for i in ids:
972
- if int(i) not in seen:
973
- seen.add(int(i))
974
- out.append(int(i))
975
- return out
976
-
977
-
978
- def _infer_expected_identifier_tail(tail_text: str):
979
- """
980
- Returns ("table"|"column", partial_or_empty) if the tail looks like it's currently
981
- emitting a table/column identifier. Otherwise returns None.
982
- """
983
- t = re.sub(r"\s+", " ", (tail_text or "")).lower()
984
-
985
- m = re.search(r"(?:from|join)\s+([a-z_][a-z0-9_]*)?$", t)
986
- if m:
987
- partial = m.group(1) or ""
988
- # ensure we are actually after keyword (not elsewhere)
989
- if re.search(r"(?:from|join)\s*$", t) or partial:
990
- return "table", partial
991
-
992
- m = re.search(
993
- r"(?:select|where|on|group by|order by|having)\s+([a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)?)?$",
994
- t,
995
- )
996
- if m:
997
- partial = m.group(1) or ""
998
- if re.search(r"(?:select|where|on|group by|order by|having)\s*$", t) or partial:
999
- return "column", partial
1000
-
1001
- return None
1002
-
1003
-
1004
- class SchemaConstrainedLogitsProcessor(LogitsProcessor):
1005
- def __init__(self, tokenizer, db_path):
1006
- self.tokenizer = tokenizer
1007
-
1008
- graph = get_constraint_graph(db_path)
1009
-
1010
- key = (id(tokenizer), str(db_path))
1011
- with _TOKEN_CACHE_LOCK:
1012
- cached = _TOKEN_ID_CACHE.get(key)
1013
- if cached is None:
1014
- allowed_tokens = set(graph.get("tables", set())) | set(graph.get("columns", set()))
1015
-
1016
- sql_keywords = {
1017
- "select", "from", "where", "join", "on",
1018
- "group", "by", "order", "limit", "having",
1019
- "and", "or", "desc", "asc",
1020
- "count", "avg", "min", "max", "sum",
1021
- "distinct", "as", "in", "like", "between",
1022
- "is", "null",
1023
- }
1024
- allowed_tokens |= sql_keywords
1025
-
1026
- allowed_ids: list[int] = []
1027
- for tok in sorted(allowed_tokens):
1028
- allowed_ids.extend(_encode_variants(tokenizer, tok))
1029
- always_ids = _always_allow_ids(tokenizer)
1030
-
1031
- allowed_ids_t = torch.tensor(sorted(set(allowed_ids)), dtype=torch.long)
1032
- always_ids_t = torch.tensor(sorted(set(always_ids)), dtype=torch.long)
1033
- cached = (allowed_ids_t, always_ids_t)
1034
- with _TOKEN_CACHE_LOCK:
1035
- _TOKEN_ID_CACHE[key] = cached
1036
-
1037
- self._allowed_ids_t, self._always_ids_t = cached
1038
-
1039
- def __call__(self, input_ids, scores):
1040
- # Decode only a tail window for speed (beam search calls this a lot).
1041
- try:
1042
- tail_ids = input_ids[0][-128:]
1043
- except Exception:
1044
- tail_ids = input_ids[0]
1045
- tail = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
1046
-
1047
- inferred = _infer_expected_identifier_tail(tail)
1048
- if inferred is None:
1049
- return scores
1050
-
1051
- keep = torch.cat([self._allowed_ids_t.to(scores.device), self._always_ids_t.to(scores.device)])
1052
- if keep.numel() == 0:
1053
- return scores
1054
-
1055
- kept_scores = scores[:, keep].clone()
1056
- scores[:] = -float("inf")
1057
- scores[:, keep] = kept_scores
1058
- return scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/constrained_decoding_sample.py DELETED
@@ -1,516 +0,0 @@
1
- # from __future__ import annotations
2
-
3
- # import re
4
- # import threading
5
- # from dataclasses import dataclass
6
- # from typing import Dict, Iterable, List, Optional, Sequence, Set
7
-
8
- # import torch
9
- # from transformers.generation.logits_process import LogitsProcessor
10
-
11
- # from schema_constraints import ConstraintGraph, build_constraint_graph
12
-
13
-
14
- # def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
15
- # s = re.sub(r"\s+", " ", prefix_text.lower())
16
- # last_from = s.rfind(" from ")
17
- # last_join = s.rfind(" join ")
18
- # last_select = s.rfind(" select ")
19
- # last_where = s.rfind(" where ")
20
- # last_on = s.rfind(" on ")
21
- # last_group = s.rfind(" group by ")
22
- # last_order = s.rfind(" order by ")
23
- # last_having = s.rfind(" having ")
24
-
25
- # last_table_kw = max(last_from, last_join)
26
- # last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
27
-
28
- # if last_table_kw < 0 and last_col_kw < 0:
29
- # return None
30
- # if last_table_kw > last_col_kw:
31
- # return "table"
32
- # if last_col_kw > last_table_kw:
33
- # return "column"
34
- # return None
35
-
36
-
37
- # class _TrieNode:
38
- # __slots__ = ("children", "terminal")
39
-
40
- # def __init__(self) -> None:
41
- # self.children: Dict[int, _TrieNode] = {}
42
- # self.terminal: bool = False
43
-
44
- # def insert(self, token_ids: Sequence[int]) -> None:
45
- # node: _TrieNode = self
46
- # for tid in token_ids:
47
- # tid_i = int(tid)
48
- # nxt = node.children.get(tid_i)
49
- # if nxt is None:
50
- # nxt = _TrieNode()
51
- # node.children[tid_i] = nxt
52
- # node = nxt
53
- # node.terminal = True
54
-
55
- # def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
56
- # node: _TrieNode = self
57
- # for tid in prefix:
58
- # node = node.children.get(int(tid)) # type: ignore[assignment]
59
- # if node is None:
60
- # return None
61
- # return node
62
-
63
-
64
- # def _encode_identifier(tokenizer, name: str) -> List[int]:
65
- # # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
66
- # return tokenizer.encode(" " + name, add_special_tokens=False)
67
-
68
-
69
- # def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
70
- # trie = _TrieNode()
71
- # for n in names:
72
- # if not n:
73
- # continue
74
- # try:
75
- # ids = _encode_identifier(tokenizer, n)
76
- # except Exception:
77
- # continue
78
- # if ids:
79
- # trie.insert(ids)
80
- # return trie
81
-
82
-
83
- # def _allow_always_token_ids(tokenizer) -> torch.Tensor:
84
- # # Allow common delimiters so the model can end an identifier.
85
- # toks = [",", ")", "(", "\n", ".", ";"]
86
- # ids: Set[int] = set()
87
- # for t in toks:
88
- # try:
89
- # for tid in tokenizer.encode(t, add_special_tokens=False):
90
- # ids.add(int(tid))
91
- # except Exception:
92
- # continue
93
- # return torch.tensor(sorted(ids), dtype=torch.long)
94
-
95
-
96
- # @dataclass
97
- # class _PerDbTokenSets:
98
- # fp: str
99
- # table_trie: _TrieNode
100
- # column_trie: _TrieNode
101
- # allow_always: torch.Tensor
102
-
103
-
104
- # _DB_TOKENSET_LOCK = threading.Lock()
105
- # _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
106
-
107
-
108
- # def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
109
- # with _DB_TOKENSET_LOCK:
110
- # cached = _DB_TOKENSETS.get(graph.db_path)
111
- # if cached is not None and cached.fp == graph.fingerprint:
112
- # return cached
113
-
114
- # out = _PerDbTokenSets(
115
- # fp=graph.fingerprint,
116
- # table_trie=_build_trie(tokenizer, graph.tables),
117
- # column_trie=_build_trie(tokenizer, graph.all_columns),
118
- # allow_always=_allow_always_token_ids(tokenizer),
119
- # )
120
- # with _DB_TOKENSET_LOCK:
121
- # _DB_TOKENSETS[graph.db_path] = out
122
- # return out
123
-
124
-
125
- # class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
126
- # """
127
- # Schema-aware constrained decoding per item in the generation batch.
128
- # Uses a tokenizer-based trie so multi-token identifiers can be constrained.
129
- # """
130
-
131
- # def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
132
- # self.tokenizer = tokenizer
133
- # self.db_paths = list(db_paths)
134
- # self.max_prefix_tokens = int(max_prefix_tokens)
135
-
136
- # self._graphs = [build_constraint_graph(p) for p in self.db_paths]
137
- # self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
138
-
139
- # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
140
- # if input_ids.dim() != 2 or scores.dim() != 2:
141
- # return scores
142
-
143
- # batch = input_ids.size(0)
144
- # if batch != len(self._graphs):
145
- # return scores
146
-
147
- # for i in range(batch):
148
- # tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
149
- # prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
150
- # expected = _infer_expected_identifier(prefix_text)
151
- # if expected is None:
152
- # continue
153
-
154
- # if expected == "table":
155
- # m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
156
- # partial = m.group(1) if m else None
157
- # if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
158
- # continue
159
- # trie = self._token_sets[i].table_trie
160
- # else:
161
- # m = re.search(
162
- # r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
163
- # prefix_text,
164
- # flags=re.I,
165
- # )
166
- # partial = m.group(1) if m else None
167
- # if partial is None and not re.search(
168
- # r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
169
- # ):
170
- # continue
171
- # trie = self._token_sets[i].column_trie
172
-
173
- # if not partial:
174
- # prefix_token_ids: List[int] = []
175
- # else:
176
- # try:
177
- # prefix_token_ids = _encode_identifier(self.tokenizer, partial)
178
- # except Exception:
179
- # continue
180
-
181
- # node = trie.walk(prefix_token_ids)
182
- # if node is None or node.terminal:
183
- # continue
184
-
185
- # allowed_next = sorted(node.children.keys())
186
- # if not allowed_next:
187
- # continue
188
-
189
- # allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
190
- # allow_always = self._token_sets[i].allow_always.to(scores.device)
191
- # keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
192
-
193
- # kept_scores = scores[i, keep].clone()
194
- # scores[i, :] = -float("inf")
195
- # scores[i, keep] = kept_scores
196
-
197
- # return scores
198
-
199
-
200
- # # Backwards-compatible names used elsewhere in the repo.
201
- # class SchemaConstraintGraph:
202
- # def __init__(self, db_path: str):
203
- # self._graph = build_constraint_graph(db_path)
204
- # self.tables = sorted(self._graph.tables)
205
- # self.columns = sorted(self._graph.all_columns)
206
-
207
-
208
- # class SchemaConstrainedLogitsProcessor(LogitsProcessor):
209
- # def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
210
- # self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
211
-
212
- # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
213
- # return self._proc(input_ids, scores)
214
-
215
-
216
-
217
-
218
- # from __future__ import annotations
219
-
220
- # import re
221
- # import threading
222
- # from dataclasses import dataclass
223
- # from typing import Dict, Iterable, List, Optional, Sequence, Set
224
-
225
- # import torch
226
- # from transformers.generation.logits_process import LogitsProcessor
227
-
228
- # from schema_constraints import ConstraintGraph, build_constraint_graph
229
-
230
-
231
- # # =========================================================
232
- # # 🔍 IDENTIFIER TYPE DETECTION
233
- # # =========================================================
234
- # def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
235
- # s = re.sub(r"\s+", " ", prefix_text.lower())
236
-
237
- # last_from = s.rfind(" from ")
238
- # last_join = s.rfind(" join ")
239
- # last_select = s.rfind(" select ")
240
- # last_where = s.rfind(" where ")
241
- # last_on = s.rfind(" on ")
242
- # last_group = s.rfind(" group by ")
243
- # last_order = s.rfind(" order by ")
244
- # last_having = s.rfind(" having ")
245
-
246
- # last_table_kw = max(last_from, last_join)
247
- # last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
248
-
249
- # if last_table_kw < 0 and last_col_kw < 0:
250
- # return None
251
- # if last_table_kw > last_col_kw:
252
- # return "table"
253
- # if last_col_kw > last_table_kw:
254
- # return "column"
255
- # return None
256
-
257
-
258
- # # =========================================================
259
- # # 🌳 TRIE STRUCTURE
260
- # # =========================================================
261
- # class _TrieNode:
262
- # __slots__ = ("children", "terminal")
263
-
264
- # def __init__(self) -> None:
265
- # self.children: Dict[int, _TrieNode] = {}
266
- # self.terminal: bool = False
267
-
268
- # def insert(self, token_ids: Sequence[int]) -> None:
269
- # node = self
270
- # for tid in token_ids:
271
- # tid = int(tid)
272
- # if tid not in node.children:
273
- # node.children[tid] = _TrieNode()
274
- # node = node.children[tid]
275
- # node.terminal = True
276
-
277
- # def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
278
- # node = self
279
- # for tid in prefix:
280
- # node = node.children.get(int(tid))
281
- # if node is None:
282
- # return None
283
- # return node
284
-
285
-
286
- # # =========================================================
287
- # # 🔤 TOKEN ENCODING
288
- # # =========================================================
289
- # def _encode_identifier(tokenizer, name: str) -> List[int]:
290
- # return tokenizer.encode(" " + name, add_special_tokens=False)
291
-
292
-
293
- # def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
294
- # trie = _TrieNode()
295
- # for name in names:
296
- # try:
297
- # ids = _encode_identifier(tokenizer, name)
298
- # if ids:
299
- # trie.insert(ids)
300
- # except Exception:
301
- # continue
302
- # return trie
303
-
304
-
305
- # def _allow_always_token_ids(tokenizer) -> torch.Tensor:
306
- # tokens = [",", ")", "(", ".", ";", "\n"]
307
- # ids: Set[int] = set()
308
-
309
- # for t in tokens:
310
- # try:
311
- # ids.update(tokenizer.encode(t, add_special_tokens=False))
312
- # except:
313
- # pass
314
-
315
- # return torch.tensor(sorted(ids), dtype=torch.long)
316
-
317
-
318
- # # =========================================================
319
- # # 📦 PER-DB CACHE
320
- # # =========================================================
321
- # @dataclass
322
- # class _PerDbTokenSets:
323
- # fp: str
324
- # table_trie: _TrieNode
325
- # column_trie: _TrieNode
326
- # allow_always: torch.Tensor
327
-
328
-
329
- # _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
330
- # _DB_LOCK = threading.Lock()
331
-
332
-
333
- # def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
334
- # with _DB_LOCK:
335
- # cached = _DB_CACHE.get(graph.db_path)
336
- # if cached and cached.fp == graph.fingerprint:
337
- # return cached
338
-
339
- # obj = _PerDbTokenSets(
340
- # fp=graph.fingerprint,
341
- # table_trie=_build_trie(tokenizer, graph.tables),
342
- # column_trie=_build_trie(tokenizer, graph.all_columns),
343
- # allow_always=_allow_always_token_ids(tokenizer),
344
- # )
345
-
346
- # with _DB_LOCK:
347
- # _DB_CACHE[graph.db_path] = obj
348
-
349
- # return obj
350
-
351
-
352
- # # =========================================================
353
- # # 🚀 MAIN LOGITS PROCESSOR
354
- # # =========================================================
355
- # class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
356
- # def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
357
- # self.tokenizer = tokenizer
358
- # self.db_paths = list(db_paths)
359
- # self.max_prefix_tokens = max_prefix_tokens
360
-
361
- # self._graphs = [build_constraint_graph(p) for p in db_paths]
362
- # self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
363
-
364
- # # 📊 Metrics (IMPORTANT FOR REPORT)
365
- # self.total_steps = 0
366
- # self.constrained_steps = 0
367
-
368
- # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
369
- # batch = input_ids.size(0)
370
-
371
- # for i in range(batch):
372
- # self.total_steps += 1
373
-
374
- # tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
375
- # prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
376
-
377
- # expected = _infer_expected_identifier(prefix_text)
378
- # if expected is None:
379
- # continue
380
-
381
- # self.constrained_steps += 1
382
-
383
- # # =========================
384
- # # SELECT TRIE
385
- # # =========================
386
- # if expected == "table":
387
- # trie = self._token_sets[i].table_trie
388
- # else:
389
- # trie = self._token_sets[i].column_trie
390
-
391
- # # =========================
392
- # # PARTIAL TOKEN MATCH
393
- # # =========================
394
- # match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
395
- # partial = match.group(1) if match else ""
396
-
397
- # try:
398
- # prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
399
- # except:
400
- # continue
401
-
402
- # node = trie.walk(prefix_ids)
403
- # if node is None or node.terminal:
404
- # continue
405
-
406
- # allowed_next = list(node.children.keys())
407
- # if not allowed_next:
408
- # continue
409
-
410
- # allowed_next = torch.tensor(allowed_next, device=scores.device)
411
- # allow_always = self._token_sets[i].allow_always.to(scores.device)
412
-
413
- # keep = torch.cat([allowed_next, allow_always])
414
-
415
- # kept_scores = scores[i, keep].clone()
416
- # scores[i, :] = -float("inf")
417
- # scores[i, keep] = kept_scores
418
-
419
- # return scores
420
-
421
- # # =========================================================
422
- # # 📊 METRICS FOR REPORT
423
- # # =========================================================
424
- # def get_constraint_stats(self):
425
- # if self.total_steps == 0:
426
- # return 0
427
- # return self.constrained_steps / self.total_steps
428
-
429
-
430
- # # =========================================================
431
- # # 🔁 BACKWARD COMPATIBILITY
432
- # # =========================================================
433
- # class SchemaConstraintGraph:
434
- # def __init__(self, db_path: str):
435
- # self._graph = build_constraint_graph(db_path)
436
- # self.tables = sorted(self._graph.tables)
437
- # self.columns = sorted(self._graph.all_columns)
438
-
439
-
440
- # class SchemaConstrainedLogitsProcessor(LogitsProcessor):
441
- # def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
442
- # self.proc = BatchSchemaConstrainedLogitsProcessor(
443
- # tokenizer, [schema_graph._graph.db_path]
444
- # )
445
-
446
- # def __call__(self, input_ids, scores):
447
- # return self.proc(input_ids, scores)
448
-
449
-
450
-
451
-
452
-
453
-
454
-
455
-
456
- # ********* after task 3
457
-
458
- import re
459
- import torch
460
- from transformers import LogitsProcessor
461
- from src.schema_utils import get_constraint_graph
462
-
463
-
464
- def _infer_expected_identifier(prefix_text: str):
465
- s = prefix_text.lower()
466
-
467
- if " from " in s or " join " in s:
468
- return "table"
469
- if any(k in s for k in ["select", "where", "on", "group by", "order by"]):
470
- return "column"
471
-
472
- return None
473
-
474
-
475
- class SchemaConstrainedLogitsProcessor(LogitsProcessor):
476
- def __init__(self, tokenizer, db_path):
477
- self.tokenizer = tokenizer
478
-
479
- graph = get_constraint_graph(db_path)
480
-
481
- self.allowed_tokens = set(graph["tables"]) | set(graph["columns"])
482
-
483
- self.sql_keywords = {
484
- "select", "from", "where", "join", "on",
485
- "group", "by", "order", "limit",
486
- "and", "or", "desc", "asc",
487
- "count", "avg", "min", "max", "sum", "*"
488
- }
489
-
490
- self.allowed_tokens |= self.sql_keywords
491
-
492
- self.allowed_token_ids = set()
493
- for token in self.allowed_tokens:
494
- ids = tokenizer.encode(token, add_special_tokens=False)
495
- for i in ids:
496
- self.allowed_token_ids.add(i)
497
-
498
- def __call__(self, input_ids, scores):
499
-
500
- prefix = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
501
-
502
- # 🔥 SOFT CONSTRAINT (FIX)
503
- if len(prefix.strip()) < 10:
504
- return scores
505
-
506
- expected = _infer_expected_identifier(prefix)
507
-
508
- if expected not in ["table", "column"]:
509
- return scores
510
-
511
- mask = torch.full_like(scores, float("-inf"))
512
-
513
- for token_id in self.allowed_token_ids:
514
- mask[:, token_id] = scores[:, token_id]
515
-
516
- return mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/eval_rl_fixed.py CHANGED
@@ -1,756 +1,466 @@
1
  # import json
2
- # import subprocess
3
- # import sys
4
- # import argparse
5
- # import random
6
  # import sqlite3
7
- # import time
8
- # import re
9
- # import os
10
  # from pathlib import Path
11
-
12
  # import torch
13
  # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
14
  # from peft import PeftModel
15
 
16
- # from prompting import encode_prompt
 
 
 
 
17
 
18
- # # -------------------------------
19
- # # NORMALIZATION
20
- # # -------------------------------
21
- # def normalize_sql(sql):
22
- # sql = sql.replace('"', "'")
23
- # sql = re.sub(r"\s+", " ", sql)
24
- # return sql.strip().lower().rstrip(";")
25
-
26
-
27
- # # -------------------------------
28
- # # 🔥 SAFE RESULT NORMALIZATION (FIX)
29
- # # -------------------------------
30
- # def normalize_result(res):
31
- # try:
32
- # return sorted([str(r) for r in res])
33
- # except:
34
- # return []
35
 
 
 
 
 
36
 
37
- # # -------------------------------
38
- # # EXECUTION CHECK (FIXED)
39
- # # -------------------------------
40
- # def check_execution(pred_sql, gold_sql, db_path):
41
- # try:
42
- # conn = sqlite3.connect(db_path)
43
- # conn.text_factory = lambda b: b.decode(errors='ignore')
44
 
45
- # start_time = time.monotonic()
 
 
 
 
46
 
47
- # def timeout_handler():
48
- # return 1 if (time.monotonic() - start_time) > 2.0 else 0
49
 
50
- # conn.set_progress_handler(timeout_handler, 10000)
51
 
52
- # cursor = conn.cursor()
 
 
 
 
53
 
54
- # cursor.execute(pred_sql)
55
- # pred_res = cursor.fetchall()
56
 
57
- # cursor.execute(gold_sql)
58
- # gold_res = cursor.fetchall()
59
 
60
  # conn.close()
61
-
62
- # # 🔥 FIXED COMPARISON
63
- # return normalize_result(pred_res) == normalize_result(gold_res)
64
 
65
  # except Exception:
66
  # return False
67
 
68
 
69
- # # -------------------------------
70
- # # SPIDER PARSER
71
- # # -------------------------------
72
- # def _parse_spider_accuracy(stdout: str, metric_type: str):
73
- # for line in stdout.splitlines():
74
- # if metric_type == "exec" and line.strip().startswith("execution"):
75
- # try:
76
- # return float(line.split()[-1])
77
- # except:
78
- # pass
79
- # elif metric_type == "match" and line.strip().startswith("exact"):
80
- # try:
81
- # return float(line.split()[-1])
82
- # except:
83
- # pass
84
- # return None
85
-
86
-
87
- # # -------------------------------
88
- # # MAIN
89
- # # -------------------------------
90
  # def main():
91
  # parser = argparse.ArgumentParser()
92
  # parser.add_argument("--adapter", type=str, required=True)
93
- # parser.add_argument("--num_samples", type=int, default=700)
94
- # parser.add_argument("--shuffle_dev", action="store_true")
95
- # parser.add_argument("--shuffle_seed", type=int, default=42)
96
  # args = parser.parse_args()
97
 
98
  # project_root = Path(__file__).resolve().parents[1]
99
- # adapter_dir = project_root / args.adapter
100
 
101
- # db_root = project_root / "data" / "database"
102
- # table_json = project_root / "data" / "tables.json"
103
  # dev_json = project_root / "data" / "dev.json"
 
104
 
105
- # pred_path = project_root / "temp_predictions.txt"
106
- # temp_gold_path = project_root / "temp_gold.sql"
107
 
108
- # if not adapter_dir.exists():
109
- # raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
 
 
 
 
110
 
111
- # device = "mps" if torch.backends.mps.is_available() else (
112
- # "cuda" if torch.cuda.is_available() else "cpu"
113
- # )
114
- # print(f"Using device: {device}")
115
 
116
- # BASE_MODEL = "Salesforce/codet5-base"
117
- # tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
118
 
119
- # if tokenizer.pad_token is None:
120
- # tokenizer.pad_token = tokenizer.eos_token
121
 
122
- # print(f"\n📦 Loading Model: {args.adapter}")
 
 
 
123
 
124
- # base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
 
125
 
126
- # adapter_for_peft = os.path.relpath(adapter_dir, project_root)
127
 
128
- # model = PeftModel.from_pretrained(
129
- # base,
130
- # adapter_for_peft,
131
- # local_files_only=True
132
- # ).to(device)
133
-
134
- # model = model.merge_and_unload()
135
- # model.eval()
136
-
137
- # # -------------------------------
138
- # # LOAD DATA
139
- # # -------------------------------
140
- # with dev_json.open() as f:
141
- # dev = json.load(f)
142
-
143
- # if args.shuffle_dev:
144
- # rng = random.Random(args.shuffle_seed)
145
- # rng.shuffle(dev)
146
-
147
- # dev = dev[: args.num_samples]
148
- # total = len(dev)
149
-
150
- # gen_kwargs = dict(
151
- # max_new_tokens=160,
152
- # num_beams=8,
153
- # length_penalty=0.8,
154
- # do_sample=False,
155
- # early_stopping=True,
156
- # pad_token_id=tokenizer.pad_token_id,
157
- # eos_token_id=tokenizer.eos_token_id,
158
- # )
159
-
160
- # print(f"\n🚀 Evaluating {total} samples...\n")
161
-
162
- # em_correct = 0
163
- # ex_correct = 0
164
-
165
- # with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
166
- # for i, ex in enumerate(dev, start=1):
167
-
168
- # db_id = ex["db_id"]
169
- # question = ex["question"]
170
- # gold_query = ex["query"]
171
- # db_path = db_root / db_id / f"{db_id}.sqlite"
172
-
173
- # # -------------------------------
174
- # # GENERATE SQL
175
- # # -------------------------------
176
- # input_ids = encode_prompt(
177
- # tokenizer,
178
- # question,
179
- # db_id,
180
- # device=device,
181
- # max_input_tokens=512
182
- # )
183
-
184
- # input_ids = input_ids.unsqueeze(0).to(device)
185
- # attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
186
 
 
187
  # outputs = model.generate(
188
- # input_ids=input_ids,
189
- # attention_mask=attention_mask,
190
- # **gen_kwargs
 
191
  # )
192
 
193
- # pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
194
-
195
- # # -------------------------------
196
- # # SAVE FOR SPIDER EVAL
197
- # # -------------------------------
198
- # out_pred.write(f"{pred_sql}\n")
199
- # out_gold.write(f"{gold_query}\t{db_id}\n")
200
-
201
- # # -------------------------------
202
- # # LIVE METRICS
203
- # # -------------------------------
204
- # if normalize_sql(pred_sql) == normalize_sql(gold_query):
205
- # em_correct += 1
206
-
207
- # if check_execution(pred_sql, gold_query, db_path):
208
- # ex_correct += 1
209
-
210
- # if i % 20 == 0 or i == total:
211
- # print(
212
- # f"Progress: {i}/{total} | "
213
- # f"EM: {(em_correct/i)*100:.2f}% | "
214
- # f"EX: {(ex_correct/i)*100:.2f}%"
215
- # )
216
 
217
- # print("\n🚀 Running Official Spider Evaluation...\n")
 
218
 
219
- # eval_script = project_root / "spider_eval" / "evaluation.py"
220
 
221
- # # EXACT MATCH
222
- # cmd_match = [
223
- # sys.executable, str(eval_script),
224
- # "--gold", str(temp_gold_path),
225
- # "--pred", str(pred_path),
226
- # "--etype", "match",
227
- # "--db", str(db_root),
228
- # "--table", str(table_json),
229
- # ]
230
 
231
- # proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
232
- # exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")
233
-
234
- # # EXECUTION
235
- # cmd_exec = [
236
- # sys.executable, str(eval_script),
237
- # "--gold", str(temp_gold_path),
238
- # "--pred", str(pred_path),
239
- # "--etype", "exec",
240
- # "--db", str(db_root),
241
- # "--table", str(table_json),
242
- # ]
243
-
244
- # proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
245
- # exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")
246
-
247
- # print("==========================================")
248
- # print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
249
- # print("==========================================")
250
-
251
- # print(f"Exact Match Accuracy : {exact_acc*100:.2f}%" if exact_acc else "EM parsing failed")
252
- # print(f"Execution Accuracy : {exec_acc*100:.2f}%" if exec_acc else "EX parsing failed")
253
 
254
- # print("==========================================\n")
 
 
255
 
256
 
257
  # if __name__ == "__main__":
258
  # main()
259
 
260
 
261
-
262
-
263
-
264
-
265
  # import json
266
  # import sqlite3
267
- # import re
268
- # import time
269
- # import sys
270
  # import argparse
 
271
  # from pathlib import Path
 
 
 
272
 
273
- # # ==========================================
274
- # # PATH SETUP
275
- # # ==========================================
276
- # PROJECT_ROOT = Path(__file__).resolve().parents[1]
277
- # if str(PROJECT_ROOT) not in sys.path:
278
- # sys.path.insert(0, str(PROJECT_ROOT))
279
-
280
- # from src.text2sql_engine import get_engine
281
- # from src.sql_validator import validate_sql_schema
282
-
283
- # # ==========================================
284
- # # CONFIG
285
- # # ==========================================
286
- # DATA_PATH = PROJECT_ROOT / "data" / "dev.json"
287
- # DB_ROOT = PROJECT_ROOT / "data" / "database"
288
-
289
- # # ==========================================
290
- # # NORMALIZATION
291
- # # ==========================================
292
- # def normalize_sql(sql):
293
- # if not isinstance(sql, str):
294
- # return ""
295
- # sql = sql.replace('"', "'")
296
- # sql = re.sub(r"\s+", " ", sql)
297
- # return sql.strip().lower().rstrip(";")
298
-
299
- # def normalize_result(res):
300
- # try:
301
- # return sorted([tuple(map(str, r)) for r in res])
302
- # except:
303
- # return []
304
-
305
- # # ==========================================
306
- # # EXECUTION
307
- # # ==========================================
308
- # def execute_sql(db_path, sql):
309
- # try:
310
- # conn = sqlite3.connect(db_path)
311
-
312
- # start = time.time()
313
- # def timeout():
314
- # return 1 if (time.time() - start) > 2 else 0
315
-
316
- # conn.set_progress_handler(timeout, 10000)
317
-
318
- # cur = conn.cursor()
319
- # cur.execute(sql)
320
- # res = cur.fetchall()
321
-
322
- # conn.close()
323
- # return res
324
-
325
- # except Exception:
326
- # return None
327
-
328
- # # ==========================================
329
- # # EVALUATION
330
- # # ==========================================
331
- # def evaluate(engine, data, is_constrained=False, debug=False):
332
-
333
- # attempted = 0
334
- # total = 0
335
- # exact_match = 0
336
- # execution_match = 0
337
- # constraint_ok = 0
338
-
339
- # skipped_missing_db = 0
340
- # skipped_exception = 0
341
- # skipped_no_sql = 0
342
-
343
- # total_time = 0
344
 
345
- # for i, item in enumerate(data, 1):
 
 
 
346
 
347
- # question = item.get("question", "")
348
- # gold_sql = item.get("query", "")
349
- # db_id = item.get("db_id", "")
 
350
 
351
- # db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
 
 
352
 
353
- # if not db_path.exists():
354
- # skipped_missing_db += 1
355
- # continue
 
 
356
 
357
- # try:
358
- # start = time.time()
359
- # result = engine.ask(question, db_id)
360
- # total_time += (time.time() - start)
361
- # except Exception:
362
- # skipped_exception += 1
363
- # continue
364
 
365
- # if not isinstance(result, dict):
366
- # continue
367
 
368
- # pred_sql = result.get("sql", "")
 
 
 
 
 
 
 
 
 
369
 
370
- # # DEBUG
371
- # if debug:
372
- # print(f"\nQ: {question}")
373
- # print(f"PRED: {pred_sql}")
374
- # print(f"GOLD: {gold_sql}")
375
 
376
- # if not pred_sql:
377
- # skipped_no_sql += 1
378
- # continue
379
 
380
- # attempted += 1
381
- # total += 1
382
 
383
- # # CONSTRAINT CHECK
384
- # if is_constrained:
385
- # try:
386
- # is_valid, _ = validate_sql_schema(pred_sql, str(db_path))
387
- # if is_valid:
388
- # constraint_ok += 1
389
- # except:
390
- # pass
391
 
392
- # # EXACT MATCH
393
- # if normalize_sql(pred_sql) == normalize_sql(gold_sql):
394
- # exact_match += 1
395
 
396
- # # EXECUTION MATCH
397
- # pred_res = execute_sql(str(db_path), pred_sql)
398
- # gold_res = execute_sql(str(db_path), gold_sql)
399
 
400
- # if pred_res is not None and gold_res is not None:
401
- # if normalize_result(pred_res) == normalize_result(gold_res):
402
- # execution_match += 1
 
 
 
403
 
404
- # # PROGRESS
405
- # if i % 10 == 0:
406
- # print(
407
- # f"[{i}/{len(data)}] "
408
- # f"EM: {exact_match/max(total,1):.3f} | "
409
- # f"EX: {execution_match/max(total,1):.3f} | "
410
- # f"Constraint: {(constraint_ok/max(total,1)) if is_constrained else 0:.3f}"
411
- # )
412
 
413
- # avg_latency = total_time / max(attempted, 1)
414
-
415
- # return {
416
- # "exact_match": exact_match / total if total > 0 else 0,
417
- # "execution_accuracy": execution_match / total if total > 0 else 0,
418
- # "constraint_rate": (constraint_ok / total if (is_constrained and total > 0) else 0),
419
- # "avg_latency": avg_latency,
420
- # "total": total,
421
- # "attempted": attempted,
422
- # "skipped_missing_db": skipped_missing_db,
423
- # "skipped_exception": skipped_exception,
424
- # "skipped_no_sql": skipped_no_sql,
425
- # }
426
-
427
- # # ==========================================
428
- # # MAIN
429
- # # ==========================================
430
- # if __name__ == "__main__":
431
 
432
- # ap = argparse.ArgumentParser()
433
- # ap.add_argument("--num-samples", type=int, default=100)
434
- # ap.add_argument("--adapter", type=str, default="checkpoints/best_rlhf_model")
435
- # ap.add_argument("--debug", action="store_true")
436
- # args = ap.parse_args()
 
 
 
 
 
 
 
437
 
438
- # print(f"\n📥 Loading dataset from {DATA_PATH}...")
 
439
 
440
- # with open(str(DATA_PATH)) as f:
441
- # data = json.load(f)[: args.num_samples]
442
 
443
- # # ==========================================
444
- # # 🔴 BASE MODEL
445
- # # ==========================================
446
- # print("\n🚀 Running BASE MODEL...\n")
447
 
448
- # engine_base = get_engine(
449
- # adapter_path="checkpoints/sft_adapter_codet5" , # 🔥 change this
450
- # use_lora=True,
451
- # use_constrained=False
452
- # )
453
 
454
- # res_base = evaluate(engine_base, data, is_constrained=False, debug=args.debug)
 
455
 
456
- # # ==========================================
457
- # # 🟡 RLHF (NO CONSTRAINT)
458
- # # ==========================================
459
- # print("\n🚀 Running RLHF (NO CONSTRAINT)...\n")
460
 
461
- # engine_rlhf = get_engine(
462
- # adapter_path="checkpoints/best_rlhf_model",
463
- # use_lora=True,
464
- # use_constrained=False
465
- # )
466
 
467
- # res_rlhf = evaluate(engine_rlhf, data, is_constrained=False, debug=args.debug)
 
 
 
 
 
 
468
 
469
- # # ==========================================
470
- # # 🟢 RLHF + CONSTRAINT
471
- # # ==========================================
472
- # print("\n🚀 Running RLHF + CONSTRAINED...\n")
473
 
474
- # engine_const = get_engine(
475
- # adapter_path="checkpoints/best_rlhf_model_2",
476
- # use_lora=True,
477
- # use_constrained=True
478
- # )
479
 
480
- # res_const = evaluate(engine_const, data, is_constrained=True, debug=args.debug)
481
 
482
- # # ==========================================
483
- # # FINAL RESULTS
484
- # # ==========================================
485
- # print("\n==========================================")
486
- # print("🎯 FINAL RESULTS (3-WAY COMPARISON)")
487
- # print("==========================================")
488
 
489
- # print(f"Base Model → EM: {res_base['exact_match']*100:.2f}% | "
490
- # f"EX: {res_base['execution_accuracy']*100:.2f}%")
491
 
492
- # print(f"RLHF → EM: {res_rlhf['exact_match']*100:.2f}% | "
493
- # f"EX: {res_rlhf['execution_accuracy']*100:.2f}%")
 
494
 
495
- # print(f"RLHF + Constrain → EM: {res_const['exact_match']*100:.2f}% | "
496
- # f"EX: {res_const['execution_accuracy']*100:.2f}% | "
497
- # f"Constraint: {res_const['constraint_rate']*100:.2f}%")
498
 
499
- # print("==========================================\n")
 
500
 
501
 
502
  import json
 
 
503
  import argparse
 
504
  import sqlite3
505
  import time
506
  import re
507
- import os
508
  from pathlib import Path
509
 
510
  import torch
511
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
512
  from peft import PeftModel
513
 
514
- # Import handling
515
- try:
516
- from prompting import encode_prompt
517
- from src.sql_validator import validate_sql_schema
518
- except ImportError:
519
- import sys
520
- sys.path.append(str(Path(__file__).resolve().parents[1]))
521
- from src.prompting import encode_prompt
522
- from src.sql_validator import validate_sql_schema
523
-
524
- # =========================================================
525
- # ERROR LOGGING
526
- # =========================================================
527
- ERROR_LOG_FILE = "results/error_logs.json"
528
-
529
- def classify_error(sql, error_msg=""):
530
- sql = sql.lower()
531
- error_msg = str(error_msg).lower()
532
-
533
- if "no such column" in error_msg:
534
- return "wrong_column"
535
- if "no such table" in error_msg:
536
- return "wrong_table"
537
- if "syntax error" in error_msg:
538
- return "syntax_error"
539
- if "ambiguous column" in error_msg:
540
- return "ambiguous_column"
541
- if "join" in sql and " on " not in sql:
542
- return "missing_join"
543
-
544
- return "other"
545
-
546
- def log_error(question, sql, error, error_type):
547
- os.makedirs(os.path.dirname(ERROR_LOG_FILE), exist_ok=True)
548
-
549
- entry = {
550
- "question": question,
551
- "sql": sql,
552
- "error": str(error),
553
- "error_type": error_type,
554
- "timestamp": time.time()
555
- }
556
-
557
- logs = []
558
- if os.path.exists(ERROR_LOG_FILE):
559
- try:
560
- with open(ERROR_LOG_FILE, "r") as f:
561
- content = f.read().strip()
562
- if content:
563
- logs = json.loads(content)
564
- except:
565
- logs = []
566
-
567
- logs.append(entry)
568
-
569
- with open(ERROR_LOG_FILE, "w") as f:
570
- json.dump(logs, f, indent=2)
571
-
572
- # =========================================================
573
- # 🔥 FINAL FIX_SQL (BALANCED VERSION)
574
- # =========================================================
575
- def fix_sql(sql):
576
- if not sql:
577
- return "SELECT 1"
578
-
579
- s = str(sql).strip()
580
-
581
- # Extract SQL only
582
- match = re.search(r"(?i)(select|with)[\s\S]*", s)
583
- if match:
584
- s = match.group(0)
585
-
586
- s = s.split(";")[0].strip()
587
-
588
- # NULL fixes
589
- s = re.sub(r'(?i)=\s*null', 'IS NULL', s)
590
- s = re.sub(r'(?i)!=\s*null', 'IS NOT NULL', s)
591
-
592
- # Fix commas
593
- s = re.sub(r',\s*,+', ',', s)
594
- s = re.sub(r'(?i),\s*from', ' FROM', s)
595
-
596
- # 🔥 LIGHT COLUMN SAFETY (main improvement)
597
- if "select" in s.lower():
598
- if len(re.findall(r'\w+\.\w+', s)) > 3:
599
- s = re.sub(r'(?i)select\s+.*?\s+from', 'SELECT * FROM', s)
600
-
601
- # 🔥 JOIN fix
602
- if "join" in s.lower() and " on " not in s.lower():
603
- s = re.sub(r'join\s+(\w+)', r'JOIN \1 ON 1=1', s, flags=re.I)
604
-
605
- # Ensure valid SQL
606
- if not s.lower().startswith(("select", "with")):
607
- return "SELECT 1"
608
-
609
- return s.strip()
610
-
611
- # =========================================================
612
- # NORMALIZATION
613
- # =========================================================
614
- def normalize_sql(sql):
615
- if not sql:
616
- return ""
617
- return re.sub(r"\s+", " ", str(sql)).strip().lower()
618
 
619
- def normalize_result(res):
620
- if not res:
621
- return []
622
- try:
623
- normalized = [tuple(sorted(str(x) for x in row)) for row in res]
624
- return sorted(normalized)
625
- except:
626
- return sorted([str(r) for r in res])
627
-
628
- # =========================================================
629
- # EXECUTION HELPERS
630
- # =========================================================
631
- def is_executable(sql, db_path):
632
- try:
633
- conn = sqlite3.connect(db_path)
634
- cur = conn.cursor()
635
- cur.execute(sql)
636
- conn.close()
637
- return True
638
- except:
639
- return False
640
 
641
- def check_execution(pred_sql, gold_sql, db_path, question):
 
642
  try:
643
  conn = sqlite3.connect(db_path)
644
  conn.text_factory = lambda b: b.decode(errors='ignore')
645
- cur = conn.cursor()
646
-
647
- cur.execute(gold_sql)
648
- gold_res = cur.fetchall()
649
-
650
- cur.execute(pred_sql)
651
- pred_res = cur.fetchall()
652
-
 
 
 
 
 
653
  conn.close()
654
-
655
- return normalize_result(pred_res) == normalize_result(gold_res)
656
-
657
- except Exception as e:
658
- error_type = classify_error(pred_sql, str(e))
659
- log_error(question, pred_sql, str(e), error_type)
660
  return False
661
 
662
- # =========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
663
  # MAIN
664
- # =========================================================
665
  def main():
666
  parser = argparse.ArgumentParser()
667
- parser.add_argument("--adapter", type=str, required=True)
668
- parser.add_argument("--num_samples", type=int, default=700)
 
 
669
  args = parser.parse_args()
670
 
671
- project_root = Path(__file__).resolve().parent
672
- if project_root.name in ["scripts", "src"]:
673
- project_root = project_root.parent
674
 
675
  db_root = project_root / "data" / "database"
 
676
  dev_json = project_root / "data" / "dev.json"
 
 
 
677
 
678
- device = "mps" if torch.backends.mps.is_available() else "cpu"
 
679
 
680
- print(f"Loading model on {device}...")
 
681
 
682
- tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
683
- base_model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base").to(device)
 
 
684
 
685
- model = PeftModel.from_pretrained(base_model, args.adapter).to(device)
 
 
686
  model = model.merge_and_unload()
687
  model.eval()
688
 
689
- with open(dev_json, "r") as f:
690
- dev_data = json.load(f)[:args.num_samples]
691
-
692
- em_correct = 0
693
- ex_correct = 0
694
- constraint_ok = 0
695
-
696
- print(f"\n🚀 Evaluating {len(dev_data)} samples...\n")
697
-
698
- for i, ex in enumerate(dev_data, 1):
699
- db_id = ex["db_id"]
700
- question = ex["question"]
701
- gold_query = ex["query"]
702
-
703
- db_path = db_root / db_id / f"{db_id}.sqlite"
704
-
705
- input_tensor = encode_prompt(tokenizer, question, db_id, device=device).unsqueeze(0)
706
-
707
- with torch.no_grad():
708
- outputs = model.generate(
709
- input_ids=input_tensor,
710
- max_new_tokens=128,
711
- num_beams=8,
712
- num_return_sequences=8
713
- )
714
 
715
- best_sql = ""
 
 
716
 
717
- # 🔥 EXECUTION-GUIDED SELECTION
718
- for out in outputs:
719
- raw_pred = tokenizer.decode(out, skip_special_tokens=True)
720
- candidate_sql = fix_sql(raw_pred)
721
 
722
- if is_executable(candidate_sql, str(db_path)):
723
- best_sql = candidate_sql
724
- break
 
 
 
 
 
725
 
726
- if not best_sql:
727
- best_sql = fix_sql(tokenizer.decode(outputs[0], skip_special_tokens=True))
728
 
729
- try:
730
- is_valid, _ = validate_sql_schema(best_sql, str(db_path))
731
- except:
732
- is_valid = False
733
-
734
- if is_valid:
735
- constraint_ok += 1
736
-
737
- if normalize_sql(best_sql) == normalize_sql(gold_query):
738
- em_correct += 1
739
-
740
- if check_execution(best_sql, gold_query, str(db_path), question):
741
- ex_correct += 1
742
-
743
- if i % 50 == 0:
744
- print(f"{i}/{len(dev_data)} done")
745
 
746
- print("\n========================================")
747
- print("🎯 FINAL EVALUATION RESULTS")
748
- print("========================================")
749
- print(f"Exact Match (EM): {(em_correct/len(dev_data))*100:.2f}%")
750
- print(f"Execution Acc (EX): {(ex_correct/len(dev_data))*100:.2f}%")
751
- print(f"Constraint Rate: {(constraint_ok/len(dev_data))*100:.2f}%")
752
- print("========================================")
753
- print(f"Errors logged to: {ERROR_LOG_FILE}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
 
755
  if __name__ == "__main__":
756
- main()
 
1
  # import json
 
 
 
 
2
  # import sqlite3
3
+ # import argparse
 
 
4
  # from pathlib import Path
 
5
  # import torch
6
  # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
  # from peft import PeftModel
8
 
9
+ # # ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
10
+ # def build_prompt(question, schema):
11
+ # return f"""
12
+ # Database Schema:
13
+ # {schema}
14
 
15
+ # Translate English to SQL:
16
+ # {question}
17
+ # SQL:
18
+ # """
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ # # ---------------- LOAD SCHEMA ----------------
21
+ # def load_schema(db_path):
22
+ # conn = sqlite3.connect(db_path)
23
+ # cursor = conn.cursor()
24
 
25
+ # tables = cursor.execute(
26
+ # "SELECT name FROM sqlite_master WHERE type='table';"
27
+ # ).fetchall()
 
 
 
 
28
 
29
+ # schema = ""
30
+ # for (table,) in tables:
31
+ # cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
32
+ # col_names = [c[1] for c in cols]
33
+ # schema += f"{table}({', '.join(col_names)})\n"
34
 
35
+ # conn.close()
36
+ # return schema
37
 
 
38
 
39
+ # # ---------------- EXECUTION CHECK ----------------
40
+ # def execution_match(pred_sql, gold_sql, db_path):
41
+ # try:
42
+ # conn = sqlite3.connect(db_path)
43
+ # cur = conn.cursor()
44
 
45
+ # cur.execute(pred_sql)
46
+ # pred = cur.fetchall()
47
 
48
+ # cur.execute(gold_sql)
49
+ # gold = cur.fetchall()
50
 
51
  # conn.close()
52
+ # return pred == gold
 
 
53
 
54
  # except Exception:
55
  # return False
56
 
57
 
58
+ # # ---------------- MAIN ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  # def main():
60
  # parser = argparse.ArgumentParser()
61
  # parser.add_argument("--adapter", type=str, required=True)
62
+ # parser.add_argument("--num_samples", type=int, default=1034)
 
 
63
  # args = parser.parse_args()
64
 
65
  # project_root = Path(__file__).resolve().parents[1]
 
66
 
 
 
67
  # dev_json = project_root / "data" / "dev.json"
68
+ # db_root = project_root / "data" / "database"
69
 
70
+ # device = "mps" if torch.backends.mps.is_available() else "cpu"
 
71
 
72
+ # # load model
73
+ # base_model = "Salesforce/codet5-base"
74
+ # tokenizer = AutoTokenizer.from_pretrained(args.adapter)
75
+ # base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
76
+ # model = PeftModel.from_pretrained(base, args.adapter).to(device)
77
+ # model = model.merge_and_unload()
78
 
79
+ # with open(dev_json) as f:
80
+ # dev = json.load(f)[: args.num_samples]
 
 
81
 
82
+ # correct = 0
 
83
 
84
+ # print(f"Evaluating {len(dev)} examples...\n")
 
85
 
86
+ # for i, ex in enumerate(dev, 1):
87
+ # question = ex["question"]
88
+ # db_id = ex["db_id"]
89
+ # gold_sql = ex["query"]
90
 
91
+ # db_path = db_root / db_id / f"{db_id}.sqlite"
92
+ # schema = load_schema(db_path)
93
 
94
+ # prompt = build_prompt(question, schema)
95
 
96
+ # inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ # with torch.no_grad():
99
  # outputs = model.generate(
100
+ # **inputs,
101
+ # max_new_tokens=80,
102
+ # do_sample=False,
103
+ # num_beams=4,
104
  # )
105
 
106
+ # pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
+ # if "SQL:" in pred_sql:
109
+ # pred_sql = pred_sql.split("SQL:")[-1].strip()
110
 
111
+ # match = execution_match(pred_sql, gold_sql, db_path)
112
 
113
+ # if match:
114
+ # correct += 1
 
 
 
 
 
 
 
115
 
116
+ # if i % 10 == 0:
117
+ # print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
+ # print("\n=============================")
120
+ # print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
121
+ # print("=============================")
122
 
123
 
124
  # if __name__ == "__main__":
125
  # main()
126
 
127
 
 
 
 
 
128
  # import json
129
  # import sqlite3
 
 
 
130
  # import argparse
131
+ # import time
132
  # from pathlib import Path
133
+ # import torch
134
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
135
+ # from peft import PeftModel
136
 
137
+ # # ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
138
+ # def build_prompt(question, schema):
139
+ # return f"""
140
+ # Database Schema:
141
+ # {schema}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ # Translate English to SQL:
144
+ # {question}
145
+ # SQL:
146
+ # """
147
 
148
+ # # ---------------- LOAD SCHEMA ----------------
149
+ # def load_schema(db_path):
150
+ # conn = sqlite3.connect(db_path)
151
+ # cursor = conn.cursor()
152
 
153
+ # tables = cursor.execute(
154
+ # "SELECT name FROM sqlite_master WHERE type='table';"
155
+ # ).fetchall()
156
 
157
+ # schema = ""
158
+ # for (table,) in tables:
159
+ # cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
160
+ # col_names = [c[1] for c in cols]
161
+ # schema += f"{table}({', '.join(col_names)})\n"
162
 
163
+ # conn.close()
164
+ # return schema
 
 
 
 
 
165
 
 
 
166
 
167
+ # # ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
168
+ # def execution_match(pred_sql, gold_sql, db_path):
169
+ # try:
170
+ # conn = sqlite3.connect(db_path)
171
+
172
+ # # --- 5-SECOND TIMEOUT SO EVALUATION DOESN'T FREEZE ---
173
+ # start_time = time.monotonic()
174
+ # def timeout_handler():
175
+ # return 1 if (time.monotonic() - start_time) > 5.0 else 0
176
+ # conn.set_progress_handler(timeout_handler, 10000)
177
 
178
+ # cur = conn.cursor()
 
 
 
 
179
 
180
+ # cur.execute(pred_sql)
181
+ # pred = cur.fetchall()
 
182
 
183
+ # cur.execute(gold_sql)
184
+ # gold = cur.fetchall()
185
 
186
+ # conn.close()
187
+ # return pred == gold
 
 
 
 
 
 
188
 
189
+ # except Exception:
190
+ # return False
 
191
 
 
 
 
192
 
193
+ # # ---------------- MAIN ----------------
194
+ # def main():
195
+ # parser = argparse.ArgumentParser()
196
+ # parser.add_argument("--adapter", type=str, required=True)
197
+ # parser.add_argument("--num_samples", type=int, default=1034)
198
+ # args = parser.parse_args()
199
 
200
+ # project_root = Path(__file__).resolve().parents[1]
 
 
 
 
 
 
 
201
 
202
+ # dev_json = project_root / "data" / "dev.json"
203
+ # db_root = project_root / "data" / "database"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
+ # # 🎯 Added CUDA support for Nvidia GPUs
206
+ # device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
207
+
208
+ # # load model
209
+ # base_model = "Salesforce/codet5-base"
210
+ # print(f"Loading Base: {base_model}")
211
+ # print(f"Loading Adapter: {args.adapter}")
212
+
213
+ # tokenizer = AutoTokenizer.from_pretrained(args.adapter)
214
+ # base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
215
+ # model = PeftModel.from_pretrained(base, args.adapter).to(device)
216
+ # model = model.merge_and_unload()
217
 
218
+ # with open(dev_json) as f:
219
+ # dev = json.load(f)[: args.num_samples]
220
 
221
+ # correct = 0
 
222
 
223
+ # print(f"Evaluating {len(dev)} examples...\n")
 
 
 
224
 
225
+ # for i, ex in enumerate(dev, 1):
226
+ # question = ex["question"]
227
+ # db_id = ex["db_id"]
228
+ # gold_sql = ex["query"]
 
229
 
230
+ # db_path = db_root / db_id / f"{db_id}.sqlite"
231
+ # schema = load_schema(db_path)
232
 
233
+ # prompt = build_prompt(question, schema)
 
 
 
234
 
235
+ # inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
 
 
 
236
 
237
+ # with torch.no_grad():
238
+ # outputs = model.generate(
239
+ # **inputs,
240
+ # max_new_tokens=80,
241
+ # do_sample=False,
242
+ # num_beams=4,
243
+ # )
244
 
245
+ # pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
246
 
247
+ # if "SQL:" in pred_sql:
248
+ # pred_sql = pred_sql.split("SQL:")[-1].strip()
 
 
 
249
 
250
+ # match = execution_match(pred_sql, gold_sql, db_path)
251
 
252
+ # if match:
253
+ # correct += 1
 
 
 
 
254
 
255
+ # if i % 10 == 0:
256
+ # print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")
257
 
258
+ # print("\n=============================")
259
+ # print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
260
+ # print("=============================")
261
 
 
 
 
262
 
263
+ # if __name__ == "__main__":
264
+ # main()
265
 
266
 
267
  import json
268
+ import subprocess
269
+ import sys
270
  import argparse
271
+ import random
272
  import sqlite3
273
  import time
274
  import re
 
275
  from pathlib import Path
276
 
277
  import torch
278
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
279
  from peft import PeftModel
280
 
281
+ # Assuming you have a prompting.py that has encode_prompt
282
+ from prompting import encode_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
+ # -------------------------------
285
+ # LIVE CHECK HELPERS
286
+ # -------------------------------
287
+ def normalize_sql(sql):
288
+ """Basic normalization for the live progress bar."""
289
+ sql = sql.replace('"', "'")
290
+ sql = re.sub(r"\s+", " ", sql)
291
+ return sql.strip().lower().rstrip(";")
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
+ def check_execution(pred_sql, gold_sql, db_path):
294
+ """Basic execution check for the live progress bar."""
295
  try:
296
  conn = sqlite3.connect(db_path)
297
  conn.text_factory = lambda b: b.decode(errors='ignore')
298
+
299
+ # 2-second timeout so the live tracker doesn't freeze forever
300
+ start_time = time.monotonic()
301
+ def timeout_handler():
302
+ return 1 if (time.monotonic() - start_time) > 2.0 else 0
303
+ conn.set_progress_handler(timeout_handler, 10000)
304
+
305
+ cursor = conn.cursor()
306
+ cursor.execute(pred_sql)
307
+ pred_res = cursor.fetchall()
308
+
309
+ cursor.execute(gold_sql)
310
+ gold_res = cursor.fetchall()
311
  conn.close()
312
+
313
+ # Simple sorted check for the live tracker
314
+ return sorted(pred_res) == sorted(gold_res)
315
+ except Exception:
 
 
316
  return False
317
 
318
+ # -------------------------------
319
+ # SPIDER PARSER
320
+ # -------------------------------
321
+ def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
322
+ for line in stdout.splitlines():
323
+ if metric_type == "exec" and line.strip().startswith("execution"):
324
+ try: return float(line.split()[-1])
325
+ except: pass
326
+ elif metric_type == "match" and line.strip().startswith("exact"):
327
+ try: return float(line.split()[-1])
328
+ except: pass
329
+ return None
330
+
331
+ # -------------------------------
332
  # MAIN
333
+ # -------------------------------
334
  def main():
335
  parser = argparse.ArgumentParser()
336
+ parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
337
+ parser.add_argument("--num_samples", type=int, default=700, help="Number of samples to evaluate")
338
+ parser.add_argument("--shuffle_dev", action="store_true")
339
+ parser.add_argument("--shuffle_seed", type=int, default=42)
340
  args = parser.parse_args()
341
 
342
+ project_root = Path(__file__).resolve().parents[1]
343
+ adapter_dir = project_root / args.adapter
 
344
 
345
  db_root = project_root / "data" / "database"
346
+ table_json = project_root / "data" / "tables.json"
347
  dev_json = project_root / "data" / "dev.json"
348
+
349
+ pred_path = project_root / "temp_predictions.txt"
350
+ temp_gold_path = project_root / "temp_gold.sql"
351
 
352
+ if not adapter_dir.exists():
353
+ raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
354
 
355
+ device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
356
+ print(f"Using device: {device}")
357
 
358
+ BASE_MODEL = "Salesforce/codet5-base"
359
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
360
+ if tokenizer.pad_token is None:
361
+ tokenizer.pad_token = tokenizer.eos_token
362
 
363
+ print(f"Loading Model: {args.adapter}...")
364
+ base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
365
+ model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
366
  model = model.merge_and_unload()
367
  model.eval()
368
 
369
+ with dev_json.open() as f:
370
+ dev = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
 
372
+ if args.shuffle_dev:
373
+ rng = random.Random(args.shuffle_seed)
374
+ rng.shuffle(dev)
375
 
376
+ dev = dev[: args.num_samples]
377
+ total = len(dev)
 
 
378
 
379
+ gen_kwargs = dict(
380
+ max_new_tokens=160,
381
+ num_beams=4,
382
+ do_sample=False,
383
+ early_stopping=True,
384
+ pad_token_id=tokenizer.pad_token_id,
385
+ eos_token_id=tokenizer.eos_token_id,
386
+ )
387
 
388
+ print(f"\n🚀 Generating and live-tracking {total} samples...\n")
 
389
 
390
+ em_correct = 0
391
+ ex_correct = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
 
393
+ with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
394
+ for i, ex in enumerate(dev, start=1):
395
+ db_id = ex["db_id"]
396
+ question = ex["question"]
397
+ gold_query = ex["query"]
398
+ db_path = db_root / db_id / f"{db_id}.sqlite"
399
+
400
+ # Generate
401
+ input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
402
+ input_ids = input_ids.unsqueeze(0).to(device)
403
+ attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
404
+
405
+ outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
406
+ pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
407
+
408
+ # Write to files for official spider eval later
409
+ out_pred.write(f"{pred_sql}\n")
410
+ out_gold.write(f"{gold_query}\t{db_id}\n")
411
+
412
+ # --- LIVE TRACKING CHECKS ---
413
+ if normalize_sql(pred_sql) == normalize_sql(gold_query):
414
+ em_correct += 1
415
+ if check_execution(pred_sql, gold_query, db_path):
416
+ ex_correct += 1
417
+
418
+ # Print progress every 50 loops
419
+ if i % 10 == 0 or i == total:
420
+ print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")
421
+
422
+ print("\nGeneration finished. Running Official Spider Evaluations for final numbers...\n")
423
+
424
+ eval_script = project_root / "spider_eval" / "evaluation.py"
425
+
426
+ # 1. RUN EXACT MATCH EVAL
427
+ cmd_match = [
428
+ sys.executable, str(eval_script),
429
+ "--gold", str(temp_gold_path),
430
+ "--pred", str(pred_path),
431
+ "--etype", "match",
432
+ "--db", str(db_root),
433
+ "--table", str(table_json),
434
+ ]
435
+ proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
436
+ exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")
437
+
438
+ # 2. RUN EXECUTION EVAL
439
+ cmd_exec = [
440
+ sys.executable, str(eval_script),
441
+ "--gold", str(temp_gold_path),
442
+ "--pred", str(pred_path),
443
+ "--etype", "exec",
444
+ "--db", str(db_root),
445
+ "--table", str(table_json),
446
+ ]
447
+ proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
448
+ exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")
449
+
450
+ print("==========================================")
451
+ print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
452
+ print("==========================================")
453
+
454
+ if exact_acc is not None:
455
+ print(f"Exact Set Match Accuracy : {exact_acc*100:.2f}%")
456
+ else:
457
+ print("Exact Set Match Accuracy : Could not parse output")
458
+
459
+ if exec_acc is not None:
460
+ print(f"Execution Accuracy : {exec_acc*100:.2f}%")
461
+ else:
462
+ print("Execution Accuracy : Could not parse output")
463
+ print("==========================================\n")
464
 
465
  if __name__ == "__main__":
466
+ main()
src/evaluate_without_constraied.py DELETED
@@ -1,503 +0,0 @@
1
-
2
- # *********** code till task 3 ************
3
-
4
- # import json
5
- # import subprocess
6
- # import sys
7
- # import argparse
8
- # import random
9
- # import sqlite3
10
- # import time
11
- # import re
12
- # import os
13
- # from pathlib import Path
14
-
15
- # import torch
16
- # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
17
- # from peft import PeftModel
18
-
19
- # from prompting import encode_prompt
20
-
21
- # # -------------------------------
22
- # # NORMALIZATION
23
- # # -------------------------------
24
- # def normalize_sql(sql):
25
- # sql = sql.replace('"', "'")
26
- # sql = re.sub(r"\s+", " ", sql)
27
- # return sql.strip().lower().rstrip(";")
28
-
29
-
30
- # # -------------------------------
31
- # # 🔥 SAFE RESULT NORMALIZATION (FIX)
32
- # # -------------------------------
33
- # def normalize_result(res):
34
- # try:
35
- # return sorted([str(r) for r in res])
36
- # except:
37
- # return []
38
-
39
-
40
- # # -------------------------------
41
- # # EXECUTION CHECK (FIXED)
42
- # # -------------------------------
43
- # def check_execution(pred_sql, gold_sql, db_path):
44
- # try:
45
- # conn = sqlite3.connect(db_path)
46
- # conn.text_factory = lambda b: b.decode(errors='ignore')
47
-
48
- # start_time = time.monotonic()
49
-
50
- # def timeout_handler():
51
- # return 1 if (time.monotonic() - start_time) > 2.0 else 0
52
-
53
- # conn.set_progress_handler(timeout_handler, 10000)
54
-
55
- # cursor = conn.cursor()
56
-
57
- # cursor.execute(pred_sql)
58
- # pred_res = cursor.fetchall()
59
-
60
- # cursor.execute(gold_sql)
61
- # gold_res = cursor.fetchall()
62
-
63
- # conn.close()
64
-
65
- # # 🔥 FIXED COMPARISON
66
- # return normalize_result(pred_res) == normalize_result(gold_res)
67
-
68
- # except Exception:
69
- # return False
70
-
71
-
72
- # # -------------------------------
73
- # # SPIDER PARSER
74
- # # -------------------------------
75
- # def _parse_spider_accuracy(stdout: str, metric_type: str):
76
- # for line in stdout.splitlines():
77
- # if metric_type == "exec" and line.strip().startswith("execution"):
78
- # try:
79
- # return float(line.split()[-1])
80
- # except:
81
- # pass
82
- # elif metric_type == "match" and line.strip().startswith("exact"):
83
- # try:
84
- # return float(line.split()[-1])
85
- # except:
86
- # pass
87
- # return None
88
-
89
-
90
- # # -------------------------------
91
- # # MAIN
92
- # # -------------------------------
93
- # def main():
94
- # parser = argparse.ArgumentParser()
95
- # parser.add_argument("--adapter", type=str, required=True)
96
- # parser.add_argument("--num_samples", type=int, default= 500)
97
- # parser.add_argument("--shuffle_dev", action="store_true")
98
- # parser.add_argument("--shuffle_seed", type=int, default=42)
99
- # args = parser.parse_args()
100
-
101
- # project_root = Path(__file__).resolve().parents[1]
102
- # adapter_dir = project_root / args.adapter
103
-
104
- # db_root = project_root / "data" / "database"
105
- # table_json = project_root / "data" / "tables.json"
106
- # dev_json = project_root / "data" / "dev.json"
107
-
108
- # pred_path = project_root / "temp_predictions.txt"
109
- # temp_gold_path = project_root / "temp_gold.sql"
110
-
111
- # if not adapter_dir.exists():
112
- # raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
113
-
114
- # device = "mps" if torch.backends.mps.is_available() else (
115
- # "cuda" if torch.cuda.is_available() else "cpu"
116
- # )
117
- # print(f"Using device: {device}")
118
-
119
- # BASE_MODEL = "Salesforce/codet5-base"
120
- # tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
121
-
122
- # if tokenizer.pad_token is None:
123
- # tokenizer.pad_token = tokenizer.eos_token
124
-
125
- # print(f"\n📦 Loading Model: {args.adapter}")
126
-
127
- # base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
128
-
129
- # adapter_for_peft = os.path.relpath(adapter_dir, project_root)
130
-
131
- # model = PeftModel.from_pretrained(
132
- # base,
133
- # adapter_for_peft,
134
- # local_files_only=True
135
- # ).to(device)
136
-
137
- # model = model.merge_and_unload()
138
- # model.eval()
139
-
140
- # # -------------------------------
141
- # # LOAD DATA
142
- # # -------------------------------
143
- # with dev_json.open() as f:
144
- # dev = json.load(f)
145
-
146
- # if args.shuffle_dev:
147
- # rng = random.Random(args.shuffle_seed)
148
- # rng.shuffle(dev)
149
-
150
- # dev = dev[: args.num_samples]
151
- # total = len(dev)
152
-
153
- # gen_kwargs = dict(
154
- # max_new_tokens=160,
155
- # num_beams=8,
156
- # length_penalty=0.8,
157
- # do_sample=False,
158
- # early_stopping=True,
159
- # pad_token_id=tokenizer.pad_token_id,
160
- # eos_token_id=tokenizer.eos_token_id,
161
- # )
162
-
163
- # print(f"\n🚀 Evaluating {total} samples...\n")
164
-
165
- # em_correct = 0
166
- # ex_correct = 0
167
-
168
- # with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
169
- # for i, ex in enumerate(dev, start=1):
170
-
171
- # db_id = ex["db_id"]
172
- # question = ex["question"]
173
- # gold_query = ex["query"]
174
- # db_path = db_root / db_id / f"{db_id}.sqlite"
175
-
176
- # # -------------------------------
177
- # # GENERATE SQL
178
- # # -------------------------------
179
- # input_ids = encode_prompt(
180
- # tokenizer,
181
- # question,
182
- # db_id,
183
- # device=device,
184
- # max_input_tokens=512
185
- # )
186
-
187
- # input_ids = input_ids.unsqueeze(0).to(device)
188
- # attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
189
-
190
- # outputs = model.generate(
191
- # input_ids=input_ids,
192
- # attention_mask=attention_mask,
193
- # **gen_kwargs
194
- # )
195
-
196
- # pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
197
-
198
- # # -------------------------------
199
- # # SAVE FOR SPIDER EVAL
200
- # # -------------------------------
201
- # out_pred.write(f"{pred_sql}\n")
202
- # out_gold.write(f"{gold_query}\t{db_id}\n")
203
-
204
- # # -------------------------------
205
- # # LIVE METRICS
206
- # # -------------------------------
207
- # if normalize_sql(pred_sql) == normalize_sql(gold_query):
208
- # em_correct += 1
209
-
210
- # if check_execution(pred_sql, gold_query, db_path):
211
- # ex_correct += 1
212
-
213
- # if i % 20 == 0 or i == total:
214
- # print(
215
- # f"Progress: {i}/{total} | "
216
- # f"EM: {(em_correct/i)*100:.2f}% | "
217
- # f"EX: {(ex_correct/i)*100:.2f}%"
218
- # )
219
-
220
- # print("\n🚀 Running Official Spider Evaluation...\n")
221
-
222
- # eval_script = project_root / "spider_eval" / "evaluation.py"
223
-
224
- # # EXACT MATCH
225
- # cmd_match = [
226
- # sys.executable, str(eval_script),
227
- # "--gold", str(temp_gold_path),
228
- # "--pred", str(pred_path),
229
- # "--etype", "match",
230
- # "--db", str(db_root),
231
- # "--table", str(table_json),
232
- # ]
233
-
234
- # proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
235
- # exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")
236
-
237
- # # EXECUTION
238
- # cmd_exec = [
239
- # sys.executable, str(eval_script),
240
- # "--gold", str(temp_gold_path),
241
- # "--pred", str(pred_path),
242
- # "--etype", "exec",
243
- # "--db", str(db_root),
244
- # "--table", str(table_json),
245
- # ]
246
-
247
- # proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
248
- # exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")
249
-
250
- # print("==========================================")
251
- # print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
252
- # print("==========================================")
253
-
254
- # print(f"Exact Match Accuracy : {exact_acc*100:.2f}%" if exact_acc else "EM parsing failed")
255
- # print(f"Execution Accuracy : {exec_acc*100:.2f}%" if exec_acc else "EX parsing failed")
256
-
257
- # print("==========================================\n")
258
-
259
-
260
- # if __name__ == "__main__":
261
- # main()
262
-
263
-
264
-
265
-
266
- # *********** for task 2 ****************************************
267
- import json
268
- import argparse
269
- import random
270
- import sqlite3
271
- import re
272
- import os
273
- from pathlib import Path
274
- from collections import defaultdict
275
-
276
- import torch
277
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
278
- from peft import PeftModel
279
-
280
- from prompting import encode_prompt
281
-
282
- # -------------------------------
283
- # NORMALIZATION
284
- # -------------------------------
285
- def normalize_sql(sql):
286
- sql = sql.replace('"', "'")
287
- sql = re.sub(r"\s+", " ", sql)
288
- return sql.strip().lower().rstrip(";")
289
-
290
- def normalize_result(res):
291
- try:
292
- return sorted([str(r) for r in res])
293
- except:
294
- return []
295
-
296
- # -------------------------------
297
- # STEP 1: EXECUTION
298
- # -------------------------------
299
- def execute_with_error(sql, db_path):
300
- try:
301
- conn = sqlite3.connect(db_path)
302
- cur = conn.cursor()
303
- cur.execute(sql)
304
- res = cur.fetchall()
305
- conn.close()
306
- return res, None
307
- except Exception as e:
308
- return None, str(e)
309
-
310
- # -------------------------------
311
- # STEP 2: ERROR CLASSIFICATION
312
- # -------------------------------
313
- def classify_error(sql, error_msg):
314
- if error_msg is None:
315
- return "correct"
316
-
317
- err = error_msg.lower()
318
- sql_l = sql.lower()
319
-
320
- if "syntax" in err:
321
- return "syntax_error"
322
- if "no such table" in err:
323
- return "wrong_table"
324
- if "no such column" in err:
325
- return "wrong_column"
326
- if "ambiguous" in err:
327
- return "missing_join"
328
- if "datatype mismatch" in err:
329
- return "type_error"
330
- if "where" not in sql_l and any(x in sql_l for x in ["=", ">", "<"]):
331
- return "missing_where"
332
-
333
- return "other"
334
-
335
- # -------------------------------
336
- # STEP 4: HINTS
337
- # -------------------------------
338
- def generate_hint(error_type):
339
- hints = {
340
- "missing_join": "Try using JOIN between related tables.",
341
- "wrong_column": "Check column names in schema.",
342
- "missing_where": "Add WHERE condition.",
343
- "syntax_error": "Fix SQL syntax.",
344
- "wrong_table": "Verify table names.",
345
- "type_error": "Check data types.",
346
- "other": "Review SQL logic."
347
- }
348
- return hints.get(error_type, "")
349
-
350
- # -------------------------------
351
- # STEP 2 EXTRA: LIGHT ATTRIBUTION
352
- # -------------------------------
353
- def extract_keywords(question):
354
- return [w for w in re.findall(r"\w+", question.lower()) if len(w) > 3]
355
-
356
- # -------------------------------
357
- # MAIN
358
- # -------------------------------
359
- def main():
360
- parser = argparse.ArgumentParser()
361
- parser.add_argument("--adapter", type=str, required=True)
362
- parser.add_argument("--num_samples", type=int, default=200)
363
- args = parser.parse_args()
364
-
365
- project_root = Path(__file__).resolve().parents[1]
366
- db_root = project_root / "data" / "database"
367
- dev_json = project_root / "data" / "dev.json"
368
-
369
- device = "mps" if torch.backends.mps.is_available() else "cpu"
370
-
371
- tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
372
- base = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base").to(device)
373
-
374
- model = PeftModel.from_pretrained(
375
- base,
376
- os.path.relpath(project_root / args.adapter, project_root),
377
- local_files_only=True
378
- ).to(device)
379
-
380
- model = model.merge_and_unload()
381
- model.eval()
382
-
383
- with open(dev_json) as f:
384
- dev = json.load(f)
385
-
386
- dev = dev[:args.num_samples]
387
-
388
- # STORAGE
389
- error_counter = defaultdict(int)
390
- error_examples = defaultdict(list)
391
- success_examples = []
392
- hint_examples = defaultdict(list)
393
- operation_counter = defaultdict(int)
394
- attribution_map = defaultdict(list)
395
-
396
- em, ex = 0, 0
397
-
398
- print(f"\n🚀 Evaluating {len(dev)} samples...\n")
399
-
400
- for i, sample in enumerate(dev, 1):
401
-
402
- db_id = sample["db_id"]
403
- q = sample["question"]
404
- gold = sample["query"]
405
- db_path = db_root / db_id / f"{db_id}.sqlite"
406
-
407
- input_ids = encode_prompt(tokenizer, q, db_id, device=device).unsqueeze(0)
408
-
409
- out = model.generate(input_ids=input_ids, max_new_tokens=120, num_beams=8)
410
- pred = tokenizer.decode(out[0], skip_special_tokens=True).strip()
411
-
412
- # operation analysis
413
- s = pred.lower()
414
- if "select" in s: operation_counter["SELECT"] += 1
415
- if "where" in s: operation_counter["WHERE"] += 1
416
- if "join" in s: operation_counter["JOIN"] += 1
417
- if "group by" in s: operation_counter["GROUP_BY"] += 1
418
- if "order by" in s: operation_counter["ORDER_BY"] += 1
419
-
420
- pred_res, err = execute_with_error(pred, db_path)
421
- gold_res, _ = execute_with_error(gold, db_path)
422
-
423
- error_type = classify_error(pred, err)
424
- error_counter[error_type] += 1
425
-
426
- # attribution
427
- if err:
428
- attribution_map[error_type].append(extract_keywords(q))
429
-
430
- # examples
431
- if len(error_examples[error_type]) < 3:
432
- error_examples[error_type].append(pred)
433
-
434
- # hints
435
- if error_type != "correct":
436
- hint = generate_hint(error_type)
437
- if len(hint_examples[error_type]) < 3:
438
- hint_examples[error_type].append((pred, hint))
439
-
440
- # metrics
441
- if normalize_sql(pred) == normalize_sql(gold):
442
- em += 1
443
-
444
- if pred_res and gold_res and normalize_result(pred_res) == normalize_result(gold_res):
445
- ex += 1
446
- if len(success_examples) < 5:
447
- success_examples.append(pred)
448
-
449
- if i % 20 == 0:
450
- print(f"[{i}] EM: {em/i:.2f} | EX: {ex/i:.2f}")
451
-
452
- # -------------------------------
453
- # OUTPUT
454
- # -------------------------------
455
- print("\n🎯 FINAL RESULTS")
456
- print(f"EM: {em/len(dev)*100:.2f}%")
457
- print(f"EX: {ex/len(dev)*100:.2f}%")
458
-
459
- print("\n🔥 ERROR SUMMARY")
460
- for k, v in error_counter.items():
461
- print(k, ":", v)
462
-
463
- print("\n🔥 ERROR EXAMPLES")
464
- for k in error_examples:
465
- print("\n", k)
466
- for e in error_examples[k]:
467
- print(" ", e)
468
-
469
- print("\n🔥 HINTS")
470
- for k in hint_examples:
471
- print("\n", k)
472
- for sql, h in hint_examples[k]:
473
- print(" ", sql)
474
- print(" →", h)
475
-
476
- print("\n🔥 ATTRIBUTION (KEYWORDS)")
477
- for k in attribution_map:
478
- print(k, ":", attribution_map[k][:3])
479
-
480
- print("\n🔥 SQL OPERATIONS")
481
- for k, v in operation_counter.items():
482
- print(k, ":", v)
483
-
484
- # -------------------------------
485
- # ADVERSARIAL
486
- # -------------------------------
487
- print("\n🔥 ADVERSARIAL TESTS")
488
-
489
- adv = [
490
- "Find most expensive product",
491
- "Top 3 students by marks",
492
- "Average salary per department"
493
- ]
494
-
495
- for q in adv:
496
- inp = encode_prompt(tokenizer, q, dev[0]["db_id"], device=device).unsqueeze(0)
497
- out = model.generate(input_ids=inp, max_new_tokens=120)
498
- print("\nQ:", q)
499
- print("SQL:", tokenizer.decode(out[0], skip_special_tokens=True))
500
-
501
-
502
- if __name__ == "__main__":
503
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/execution_reward copy.py DELETED
@@ -1,831 +0,0 @@
1
-
2
-
3
- # from __future__ import annotations
4
-
5
- # import hashlib
6
- # import os
7
- # import queue
8
- # import re
9
- # import sqlite3
10
- # import threading
11
- # import time
12
- # from concurrent.futures import ThreadPoolExecutor, as_completed
13
- # from dataclasses import dataclass
14
- # from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
15
-
16
- # # --- CACHE CONTROL ---
17
- # USE_CACHE = True
18
- # _REWARD_CACHE: Dict[str, float] = {}
19
-
20
- # def set_use_cache(enabled: bool):
21
- # """Dynamically toggle the reward cache for benchmarks."""
22
- # global USE_CACHE
23
- # USE_CACHE = enabled
24
-
25
- # def _normalize_sql(sql: str) -> str:
26
- # if not isinstance(sql, str):
27
- # return ""
28
- # s = sql.strip()
29
- # if s.startswith("```"):
30
- # s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
31
- # s = re.sub(r"\n?```$", "", s).strip()
32
- # if s.lower().startswith("sql:"):
33
- # s = s[4:].strip()
34
- # if ";" in s:
35
- # s = s.split(";", 1)[0].strip()
36
- # return s
37
-
38
- # def _connect_readonly(db_path: str) -> sqlite3.Connection:
39
- # uri = f"file:{os.path.abspath(db_path)}?mode=ro"
40
- # conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
41
- # conn.execute("PRAGMA query_only = ON;")
42
- # conn.execute("PRAGMA foreign_keys = ON;")
43
- # return conn
44
-
45
- # DEFAULT_QUERY_TIMEOUT_S = 2.0
46
-
47
- # def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S) -> None:
48
- # start = time.monotonic()
49
- # def _handler() -> int:
50
- # return 1 if (time.monotonic() - start) > timeout_s else 0
51
- # conn.set_progress_handler(_handler, 10_000)
52
-
53
- # def _list_tables(conn: sqlite3.Connection) -> List[str]:
54
- # try:
55
- # cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
56
- # return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
57
- # except sqlite3.Error:
58
- # return []
59
-
60
- # def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool:
61
- # s = sql.lower()
62
- # for t in table_names:
63
- # tl = t.lower()
64
- # if not tl:
65
- # continue
66
- # if re.search(rf"\b{re.escape(tl)}\b", s):
67
- # return True
68
- # return False
69
-
70
- # def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
71
- # try:
72
- # _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
73
- # conn.execute(f"EXPLAIN QUERY PLAN {sql}")
74
- # return True
75
- # except sqlite3.Error:
76
- # return False
77
-
78
- # def _execute(conn: sqlite3.Connection, sql: str, max_rows: int = 1000) -> Tuple[bool, List[Tuple], Optional[str]]:
79
- # try:
80
- # _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
81
- # cur = conn.execute(sql)
82
- # rows = cur.fetchmany(max_rows)
83
- # norm_rows = [tuple(r) for r in rows]
84
- # return True, norm_rows, None
85
- # except sqlite3.Error as e:
86
- # return False, [], str(e)
87
-
88
- # _SQL_KEYWORDS_TO_IGNORE = {
89
- # "select", "from", "where", "join", "inner", "left", "right", "full", "outer",
90
- # "on", "group", "by", "order", "limit", "having", "distinct", "union", "intersect",
91
- # "except", "as", "and", "or", "not", "in", "is", "null", "like", "between", "case",
92
- # "when", "then", "else", "end", "asc", "desc"
93
- # }
94
-
95
- # _SQL_FUNCTIONS_TO_IGNORE = {
96
- # "count", "avg", "min", "max", "sum", "lower", "upper", "substr", "coalesce",
97
- # "round", "date", "datetime", "strftime"
98
- # }
99
-
100
- # # --- LIGHTWEIGHT PARSING ---
101
- # def is_valid_select(sql: str):
102
- # sql = sql.strip().lower()
103
- # return sql.startswith("select") or sql.startswith("with")
104
-
105
- # def extract_tables(sql: str) -> List[str]:
106
- # sql = sql.lower()
107
- # if "join" not in sql:
108
- # tables = re.findall(r'from\s+(\w+)', sql)
109
- # return list(set(tables))
110
-
111
- # tables = re.findall(r'from\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
112
- # joins = re.findall(r'join\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
113
- # return list(set(tables + joins))
114
-
115
- # def extract_columns(sql: str) -> List[str]:
116
- # sql = sql.lower()
117
- # match = re.search(r'select\s+(.*?)\s+from', sql)
118
- # if not match:
119
- # return []
120
- # cols = match.group(1)
121
- # if cols.strip() == "*":
122
- # return ["*"]
123
- # return [c.strip() for c in cols.split(",")]
124
-
125
- # def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]:
126
- # tables = set()
127
- # columns = set()
128
- # for t in _list_tables(conn):
129
- # tl = t.lower()
130
- # if not tl:
131
- # continue
132
- # tables.add(tl)
133
- # try:
134
- # cur = conn.execute(f'PRAGMA table_info("{t}")')
135
- # for row in cur.fetchall():
136
- # if row and isinstance(row[1], str):
137
- # columns.add(row[1].lower())
138
- # except sqlite3.Error:
139
- # continue
140
- # return tables, columns
141
-
142
- # def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool:
143
- # return a == b
144
-
145
- # @dataclass
146
- # class RewardDebugStats:
147
- # total: int = 0
148
- # parsed_ok: int = 0
149
- # table_match: int = 0
150
- # column_match: int = 0
151
- # executed_ok: int = 0
152
- # exact_match: int = 0
153
-
154
- # _DEBUG = RewardDebugStats()
155
-
156
- # def reset_debug_metrics() -> None:
157
- # global _DEBUG
158
- # _DEBUG = RewardDebugStats()
159
-
160
- # def get_debug_metrics() -> dict:
161
- # denom = max(_DEBUG.total, 1)
162
- # return {
163
- # "valid_sql_rate": _DEBUG.parsed_ok / denom,
164
- # "table_match_rate": _DEBUG.table_match / denom,
165
- # "column_match_rate": _DEBUG.column_match / denom,
166
- # "execution_accuracy": _DEBUG.exact_match / denom,
167
- # }
168
-
169
- # EXECUTION_ERROR = "EXECUTION_ERROR"
170
-
171
- # _RESULT_CACHE_LOCK = threading.Lock()
172
- # _RESULT_CACHE: "Dict[str, Union[List[Tuple], str]]" = {}
173
- # _RESULT_CACHE_MAX = 100_000
174
-
175
- # def clear_result_cache() -> None:
176
- # """Clear both DB query cache and reward cache."""
177
- # with _RESULT_CACHE_LOCK:
178
- # _RESULT_CACHE.clear()
179
- # _REWARD_CACHE.clear()
180
-
181
- # def _db_state_fingerprint(db_path: str) -> str:
182
- # try:
183
- # st = os.stat(db_path)
184
- # return f"{st.st_mtime_ns}:{st.st_size}"
185
- # except OSError:
186
- # return "missing"
187
-
188
- # def _result_cache_key(db_path: str, sql: str) -> str:
189
- # fp = _db_state_fingerprint(db_path)
190
- # payload = f"{fp}\0{sql}".encode("utf-8", errors="ignore")
191
- # return hashlib.sha256(payload).hexdigest()
192
-
193
- # class _ConnectionPool:
194
- # def __init__(self, db_path: str, maxsize: int = 1) -> None:
195
- # self.db_path = db_path
196
- # self.pool = queue.LifoQueue(maxsize=maxsize)
197
- # self.lock = threading.Lock()
198
-
199
- # def acquire(self) -> sqlite3.Connection:
200
- # try:
201
- # return self.pool.get_nowait()
202
- # except queue.Empty:
203
- # with self.lock:
204
- # try:
205
- # return self.pool.get_nowait()
206
- # except queue.Empty:
207
- # return _connect_readonly(self.db_path)
208
-
209
- # def release(self, conn: sqlite3.Connection) -> None:
210
- # try:
211
- # self.pool.put_nowait(conn)
212
- # except queue.Full:
213
- # try:
214
- # conn.close()
215
- # except Exception:
216
- # pass
217
-
218
- # _POOL_LOCK = threading.Lock()
219
- # _POOLS: Dict[str, _ConnectionPool] = {}
220
-
221
- # def _get_pool(db_path: str) -> _ConnectionPool:
222
- # with _POOL_LOCK:
223
- # pool = _POOLS.get(db_path)
224
- # if pool is None:
225
- # pool = _ConnectionPool(db_path=db_path, maxsize=1)
226
- # _POOLS[db_path] = pool
227
- # return pool
228
-
229
- # class _PooledConnection:
230
- # def __init__(self, db_path: str) -> None:
231
- # self.db_path = db_path
232
- # self.pool = _get_pool(db_path)
233
- # self.conn: Optional[sqlite3.Connection] = None
234
-
235
- # def __enter__(self) -> sqlite3.Connection:
236
- # self.conn = self.pool.acquire()
237
- # return self.conn
238
-
239
- # def __exit__(self, exc_type, exc, tb) -> None:
240
- # if self.conn is not None:
241
- # self.pool.release(self.conn)
242
- # self.conn = None
243
-
244
- # def _cache_get(key: str) -> Optional[Union[List[Tuple], str]]:
245
- # with _RESULT_CACHE_LOCK:
246
- # return _RESULT_CACHE.get(key)
247
-
248
- # def _cache_put(key: str, value: Union[List[Tuple], str]) -> None:
249
- # with _RESULT_CACHE_LOCK:
250
- # if len(_RESULT_CACHE) >= _RESULT_CACHE_MAX:
251
- # _RESULT_CACHE.clear()
252
- # _RESULT_CACHE[key] = value
253
-
254
- # def execute_sql(conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
255
- # try:
256
- # _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
257
- # cur = conn.execute(sql)
258
- # rows = cur.fetchmany(max_rows)
259
- # return [tuple(r) for r in rows]
260
- # except Exception:
261
- # return EXECUTION_ERROR
262
-
263
- # def execute_sql_cached(db_path: str, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
264
- # if not USE_CACHE:
265
- # with _PooledConnection(db_path) as conn:
266
- # return execute_sql(conn, sql, max_rows=max_rows)
267
-
268
- # key = _result_cache_key(db_path, sql)
269
- # cached = _cache_get(key)
270
- # if cached is not None:
271
- # return cached
272
- # with _PooledConnection(db_path) as conn:
273
- # res = execute_sql(conn, sql, max_rows=max_rows)
274
- # _cache_put(key, res)
275
- # return res
276
-
277
- # def execution_reward_timed(
278
- # pred_sql: str, db_path: str, gold_sql: str, *, measure_plan: bool = False,
279
- # ) -> Tuple[float, Dict[str, float]]:
280
- # timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}
281
- # t0 = time.perf_counter()
282
- # sql = _normalize_sql(pred_sql)
283
- # gold = _normalize_sql(gold_sql)
284
-
285
- # if not is_valid_select(sql):
286
- # timings["parse_s"] = time.perf_counter() - t0
287
- # return 0.0, timings
288
-
289
- # t1 = time.perf_counter()
290
- # timings["parse_s"] = t1 - t0
291
-
292
- # if measure_plan:
293
- # with _PooledConnection(db_path) as conn:
294
- # p0 = time.perf_counter()
295
- # _explain_query_plan(conn, sql)
296
- # _explain_query_plan(conn, gold)
297
- # timings["plan_s"] = time.perf_counter() - p0
298
-
299
- # e0 = time.perf_counter()
300
- # pred_res = execute_sql_cached(db_path, sql)
301
- # if pred_res == EXECUTION_ERROR:
302
- # timings["exec_s"] = time.perf_counter() - e0
303
- # return 0.0, timings
304
- # gold_res = execute_sql_cached(db_path, gold)
305
- # timings["exec_s"] = time.perf_counter() - e0
306
- # if gold_res == EXECUTION_ERROR:
307
- # return 0.0, timings
308
-
309
- # reward = -0.2
310
- # reward += 0.2
311
- # if _safe_results_equal(pred_res, gold_res):
312
- # return 1.0, timings
313
- # return max(-1.0, min(1.0, reward)), timings
314
-
315
- # def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
316
- # try:
317
- # sql = _normalize_sql(pred_sql)
318
- # gold = _normalize_sql(gold_sql)
319
-
320
- # if not is_valid_select(sql):
321
- # return -1.0
322
-
323
- # reward = -0.2
324
-
325
- # pred_tables = set(extract_tables(sql))
326
- # gold_tables = set(extract_tables(gold))
327
-
328
- # if pred_tables == gold_tables and len(gold_tables) > 0:
329
- # reward += 0.3
330
-
331
- # pred_cols = set(extract_columns(sql))
332
- # gold_cols = set(extract_columns(gold))
333
-
334
- # if gold_cols:
335
- # overlap = len(pred_cols & gold_cols) / len(gold_cols)
336
- # reward += 0.3 * overlap
337
-
338
- # pred_res = execute_sql_cached(db_path, sql)
339
- # if pred_res == EXECUTION_ERROR:
340
- # return 0.0
341
- # reward += 0.2
342
-
343
- # gold_res = execute_sql_cached(db_path, gold)
344
- # if gold_res == EXECUTION_ERROR:
345
- # return 0.0
346
- # if _safe_results_equal(pred_res, gold_res):
347
- # return 1.0
348
-
349
- # return max(-1.0, min(1.0, reward))
350
-
351
- # except Exception:
352
- # return 0.0
353
-
354
- # def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
355
- # if not USE_CACHE:
356
- # return execution_reward(pred_sql, db_path, gold_sql)
357
-
358
- # key = f"{db_path}|{pred_sql}|{gold_sql}"
359
- # if key not in _REWARD_CACHE:
360
- # _REWARD_CACHE[key] = execution_reward(pred_sql, db_path, gold_sql)
361
- # return _REWARD_CACHE[key]
362
-
363
- # def execution_reward_batch_sequential(rollouts: Sequence[Tuple[str, str, str]]) -> List[float]:
364
- # return [cached_execution_reward(pred_sql, db_path, gold_sql) for pred_sql, db_path, gold_sql in rollouts]
365
-
366
- # def execution_reward_batch_parallel(rollouts: Sequence[Tuple[str, str, str]], *, max_workers: int = 20) -> List[float]:
367
- # if not rollouts:
368
- # return []
369
-
370
- # unique_dbs = {db_path for _, db_path, _ in rollouts}
371
- # worker_count = max(1, min(max_workers, len(unique_dbs)))
372
- # results: List[Optional[float]] = [None] * len(rollouts)
373
-
374
- # with ThreadPoolExecutor(max_workers=worker_count) as executor:
375
- # futures = {
376
- # executor.submit(cached_execution_reward, pred_sql, db_path, gold_sql): i
377
- # for i, (pred_sql, db_path, gold_sql) in enumerate(rollouts)
378
- # }
379
- # for fut in as_completed(futures):
380
- # idx = futures[fut]
381
- # try:
382
- # results[idx] = float(fut.result())
383
- # except Exception:
384
- # results[idx] = 0.0
385
-
386
- # return [r if r is not None else 0.0 for r in results]
387
-
388
- from __future__ import annotations
389
-
390
- import os
391
- import re
392
- import sqlite3
393
- import threading
394
- import time
395
- import json
396
- from concurrent.futures import ThreadPoolExecutor, as_completed
397
- from dataclasses import dataclass
398
- from typing import Dict, List
399
-
400
- from src.sql_validator import validate_sql_schema
401
-
402
- # =========================================================
403
- # 🔥 CONFIG FLAGS
404
- # =========================================================
405
- USE_SCHEMA_VALIDATION = True
406
- USE_CACHE = True
407
- DEFAULT_QUERY_TIMEOUT_S = 2.0
408
-
409
- EXECUTION_ERROR = "EXECUTION_ERROR"
410
-
411
- _REWARD_CACHE: Dict[str, float] = {}
412
-
413
- # =========================================================
414
- # 🔥 TASK 2: ERROR ANALYSIS + LOGGING
415
- # =========================================================
416
- ERROR_LOG_FILE = "results/error_logs.json"
417
-
418
-
419
- def classify_error(sql: str) -> str:
420
- sql = sql.lower()
421
-
422
- if "join" in sql and " on " not in sql:
423
- return "missing_join"
424
-
425
- if "where" in sql and "=" not in sql and ">" not in sql and "<" not in sql:
426
- return "wrong_where"
427
-
428
- if "null" in sql:
429
- return "null_handling"
430
-
431
- if "group by" in sql and "count" not in sql:
432
- return "wrong_groupby"
433
-
434
- return "other"
435
-
436
-
437
- def get_hint(error_type: str) -> str:
438
- hints = {
439
- "missing_join": "Add proper JOIN condition using ON.",
440
- "wrong_where": "Check WHERE clause conditions.",
441
- "null_handling": "Handle NULL values using IS NULL.",
442
- "wrong_groupby": "Use aggregation functions with GROUP BY.",
443
- "other": "Check SQL syntax and logic."
444
- }
445
- return hints.get(error_type, "Check query.")
446
-
447
-
448
def log_error(question: str, sql: str, error: str, error_type: str):
    """Append one classified failure to ERROR_LOG_FILE (results/error_logs.json).

    Creates the results directory on demand. A missing, truncated, or
    otherwise corrupted log file is replaced with a fresh list instead of
    raising — this helper runs inside the reward pipeline and must never
    crash a training step.
    """
    os.makedirs("results", exist_ok=True)

    entry = {
        "question": question,
        "sql": sql,
        "error": error,
        "error_type": error_type,
        "timestamp": time.time(),
    }

    logs = []
    if os.path.exists(ERROR_LOG_FILE):
        try:
            with open(ERROR_LOG_FILE, "r") as f:
                logs = json.load(f)
            # Guard against a hand-edited file that holds a non-list payload.
            if not isinstance(logs, list):
                logs = []
        except (json.JSONDecodeError, OSError):
            # Corrupt/unreadable log: start over rather than crash.
            logs = []

    logs.append(entry)

    with open(ERROR_LOG_FILE, "w") as f:
        json.dump(logs, f, indent=2)
469
-
470
- # =========================================================
471
- # CACHE/VALIDATION TOGGLES (Task 1)
472
- # =========================================================
473
def set_use_cache(enabled: bool) -> None:
    """Toggle the result/reward caches on or off (benchmark hook)."""
    global USE_CACHE
    USE_CACHE = True if enabled else False
476
-
477
-
478
def set_use_schema_validation(enabled: bool) -> None:
    """Toggle the pre-execution schema validation pass (benchmark hook)."""
    global USE_SCHEMA_VALIDATION
    USE_SCHEMA_VALIDATION = True if enabled else False
481
-
482
-
483
- # =========================================================
484
- # SQL CLEANING
485
- # =========================================================
486
- def _normalize_sql(sql: str) -> str:
487
- if not isinstance(sql, str):
488
- return ""
489
- s = sql.strip()
490
-
491
- if s.startswith("```"):
492
- s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
493
- s = re.sub(r"\n?```$", "", s).strip()
494
-
495
- if s.lower().startswith("sql:"):
496
- s = s[4:].strip()
497
-
498
- if ";" in s:
499
- s = s.split(";", 1)[0].strip()
500
-
501
- return s
502
-
503
-
504
- # =========================================================
505
- # DB EXECUTION
506
- # =========================================================
507
- def _connect_readonly(db_path: str):
508
- uri = f"file:{os.path.abspath(db_path)}?mode=ro"
509
- conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
510
- conn.execute("PRAGMA query_only = ON;")
511
- conn.execute("PRAGMA foreign_keys = ON;")
512
- return conn
513
-
514
-
515
def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S):
    """Install a progress handler that aborts queries after timeout_s seconds.

    sqlite invokes the handler every 10k VM instructions; a non-zero
    return value interrupts the running statement (OperationalError).
    """
    deadline = time.monotonic() + timeout_s

    def handler():
        # Non-zero => abort the current query.
        return 1 if time.monotonic() > deadline else 0

    conn.set_progress_handler(handler, 10_000)
522
-
523
-
524
def execute_sql(conn, sql):
    """Run sql on conn under the default timeout.

    Returns the fetched rows, or the EXECUTION_ERROR sentinel on any
    failure (syntax error, missing table, timeout, ...).
    """
    try:
        _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
        return conn.execute(sql).fetchall()
    except Exception:
        # Callers score failures via the sentinel instead of exceptions.
        return EXECUTION_ERROR
531
-
532
-
533
# Cross-call cache of raw query results, keyed by "db_path|sql".
_RESULT_CACHE = {}
# Guards _RESULT_CACHE against concurrent access from the batch executors.
_RESULT_LOCK = threading.Lock()
535
-
536
-
537
def execute_sql_cached(db_path, sql):
    """Execute sql against db_path, consulting the shared result cache.

    Opens (and closes) a fresh readonly connection per call; when
    USE_CACHE is on, identical (db_path, sql) pairs hit the cache.
    """
    key = f"{db_path}|{sql}"

    if USE_CACHE:
        with _RESULT_LOCK:
            try:
                return _RESULT_CACHE[key]
            except KeyError:
                pass

    conn = _connect_readonly(db_path)
    result = execute_sql(conn, sql)
    conn.close()

    if USE_CACHE:
        with _RESULT_LOCK:
            _RESULT_CACHE[key] = result

    return result
554
-
555
-
556
def execute_sql_cached_conn(conn: sqlite3.Connection, db_path: str, sql: str):
    """
    Cached execution that reuses an already-open connection.

    Same caching contract as execute_sql_cached(); meant for the
    one-thread-per-DB batch path (Task 1) where each worker owns one
    readonly connection.
    """
    key = f"{db_path}|{sql}"

    if USE_CACHE:
        with _RESULT_LOCK:
            hit = _RESULT_CACHE.get(key)
        # Cached values are row lists or the sentinel string, never None.
        if hit is not None:
            return hit

    result = execute_sql(conn, sql)

    if USE_CACHE:
        with _RESULT_LOCK:
            _RESULT_CACHE[key] = result

    return result
574
-
575
-
576
def clear_result_cache() -> None:
    """Empty both the raw-result cache and the reward cache."""
    with _RESULT_LOCK:
        for cache in (_RESULT_CACHE, _REWARD_CACHE):
            cache.clear()
581
-
582
-
583
- # =========================================================
584
- # SQL PARSING
585
- # =========================================================
586
def is_valid_select(sql):
    """Return True if sql looks like a read-only query (SELECT or WITH ...).

    Leading whitespace is ignored so un-normalized model output is not
    rejected spuriously; previously "  SELECT ..." was treated as invalid.
    """
    head = sql.strip().lower()
    return head.startswith(("select", "with"))
588
-
589
-
590
def extract_tables(sql):
    """Return table names referenced in FROM and JOIN clauses (lowercased).

    Regex heuristic: previously only FROM targets were found, so every
    JOINed table was silently dropped. FROM tables come first, then JOIN
    tables; duplicates are removed preserving first occurrence.
    """
    lowered = sql.lower()
    names = re.findall(r'from\s+(\w+)', lowered)
    names += re.findall(r'join\s+(\w+)', lowered)
    # dict.fromkeys dedupes while keeping insertion order.
    return list(dict.fromkeys(names))
592
-
593
-
594
def extract_columns(sql):
    """Return the projected column expressions between SELECT and FROM.

    Yields ["*"] for a star projection and [] when no SELECT ... FROM
    pattern is present.
    """
    match = re.search(r'select\s+(.*?)\s+from', sql.lower())
    if match is None:
        return []
    projection = match.group(1).strip()
    if projection == "*":
        return ["*"]
    return [col.strip() for col in projection.split(",")]
600
-
601
-
602
def get_sql_operations(sql: str):
    """List the coarse SQL operations present in sql (substring heuristic)."""
    lowered = sql.lower()
    checks = [
        ("select", "SELECT"),
        ("where", "WHERE"),
        ("join", "JOIN"),
        ("group by", "GROUP_BY"),
        ("order by", "ORDER_BY"),
    ]
    return [label for needle, label in checks if needle in lowered]
613
-
614
-
615
def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
    """Return True if sqlite can build a query plan for sql.

    Acts as a cheap compile check: the query is planned, not executed.
    """
    try:
        _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
        conn.execute(f"EXPLAIN QUERY PLAN {sql}")
    except Exception:
        return False
    return True
622
-
623
-
624
def execution_reward_timed(pred_sql: str, db_path: str, gold_sql: str, measure_plan: bool = False):
    """
    Profile-friendly variant of execution_reward().

    Returns (reward, timings) where timings has keys parse_s, plan_s and
    exec_s (seconds). Used by the Task-1 benchmark to locate bottlenecks;
    invalid or failing queries score 0.0 here rather than the training
    penalties.
    """
    timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}

    parse_start = time.perf_counter()
    sql = _normalize_sql(pred_sql)
    gold = _normalize_sql(gold_sql)
    valid = is_valid_select(sql)
    timings["parse_s"] = time.perf_counter() - parse_start

    if not valid:
        return 0.0, timings

    conn = _connect_readonly(db_path)
    try:
        if measure_plan:
            plan_start = time.perf_counter()
            _explain_query_plan(conn, sql)
            _explain_query_plan(conn, gold)
            timings["plan_s"] = time.perf_counter() - plan_start

        exec_start = time.perf_counter()
        pred_res = execute_sql_cached_conn(conn, db_path, sql)
        if pred_res == EXECUTION_ERROR:
            timings["exec_s"] = time.perf_counter() - exec_start
            return 0.0, timings

        gold_res = execute_sql_cached_conn(conn, db_path, gold)
        timings["exec_s"] = time.perf_counter() - exec_start
        if gold_res == EXECUTION_ERROR:
            return 0.0, timings

        # Exact result match scores 1.0; otherwise the base reward terms
        # (-0.2 invalid-penalty + 0.2 executed-bonus) cancel out to 0.0.
        if pred_res == gold_res:
            return 1.0, timings
        return 0.0, timings
    finally:
        try:
            conn.close()
        except Exception:
            pass
669
-
670
-
671
- # =========================================================
672
- # 🔥 FINAL REWARD FUNCTION (TASK 2 INTEGRATED)
673
- # =========================================================
674
def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
    """
    Score a predicted SQL query against a gold query by executing both.

    Reward shape:
      -1.0  prediction is not a SELECT/WITH query
       0.1  schema-invalid, execution error, or gold query itself fails
       1.0  result sets match exactly
       0.0  prediction runs but results differ (clamped base reward)

    Failures are classified (Task 2) and appended to the error log.
    Any unexpected exception is logged and scored 0.0.
    """
    try:
        sql = _normalize_sql(pred_sql)
        gold = _normalize_sql(gold_sql)

        if not is_valid_select(sql):
            return -1.0

        reward = -0.2

        # =========================
        # SCHEMA VALIDATION (Task 3)
        # =========================
        if USE_SCHEMA_VALIDATION:
            valid, _ = validate_sql_schema(sql, db_path)
            if not valid:
                error_type = classify_error(sql)
                log_error("UNKNOWN", sql, "schema_invalid", error_type)
                return 0.1

        # =========================
        # EXECUTION
        # =========================
        pred_res = execute_sql_cached(db_path, sql)

        # Compare against the shared sentinel constant, not a re-spelled
        # string literal, so the sentinel value stays defined in one place.
        if pred_res == EXECUTION_ERROR:
            error_type = classify_error(sql)

            log_error(
                question="UNKNOWN",
                sql=sql,
                error="execution_error",
                error_type=error_type,
            )

            print(f"[ERROR] {error_type}")
            print(f"[HINT] {get_hint(error_type)}")

            return 0.1

        reward += 0.2  # prediction executed cleanly

        gold_res = execute_sql_cached(db_path, gold)

        if gold_res == EXECUTION_ERROR:
            # Broken gold query: neutral-ish score, not the model's fault.
            return 0.1

        if pred_res == gold_res:
            return 1.0

        return max(-1.0, min(1.0, reward))

    except Exception as e:
        log_error("UNKNOWN", pred_sql, str(e), "runtime_error")
        return 0.0
729
-
730
-
731
- # =========================================================
732
- # BATCH EXECUTION (Task 1)
733
- # =========================================================
734
def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
    """Memoized wrapper around execution_reward(), keyed on all three args.

    Cache access is guarded by _RESULT_LOCK: this function is submitted to
    ThreadPoolExecutor workers, and the previous unsynchronized dict access
    was inconsistent with the lock-guarded result cache. The reward itself
    is computed outside the lock (execution_reward() re-acquires it via
    execute_sql_cached, and the lock is not reentrant).
    """
    if not USE_CACHE:
        return float(execution_reward(pred_sql, db_path, gold_sql))

    key = f"{db_path}|{pred_sql}|{gold_sql}"
    with _RESULT_LOCK:
        if key in _REWARD_CACHE:
            return float(_REWARD_CACHE[key])

    r = float(execution_reward(pred_sql, db_path, gold_sql))

    with _RESULT_LOCK:
        _REWARD_CACHE[key] = r
    return r
743
-
744
-
745
def execution_reward_batch_sequential(rollouts):
    """Score rollouts one by one; each rollout is (pred_sql, db_path, gold_sql)."""
    rewards = []
    for pred_sql, db_path, gold_sql in rollouts:
        rewards.append(cached_execution_reward(pred_sql, db_path, gold_sql))
    return rewards
747
-
748
-
749
def execution_reward_batch_parallel(rollouts, max_workers=10):
    """Score rollouts with a thread pool; output order matches input order.

    A worker that raises degrades to a neutral 0.0 reward instead of
    failing the whole batch.
    """
    rewards = [0.0] * len(rollouts)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {}
        for position, (pred_sql, db_path, gold_sql) in enumerate(rollouts):
            fut = pool.submit(cached_execution_reward, pred_sql, db_path, gold_sql)
            pending[fut] = position

        for fut in as_completed(pending):
            position = pending[fut]
            try:
                rewards[position] = fut.result()
            except Exception:
                rewards[position] = 0.0

    return rewards
766
-
767
-
768
def execution_reward_batch_parallel_by_db(rollouts, max_workers: int = 20):
    """
    1 thread per DB path. Reuses a single readonly connection per DB worker.
    Preserves input order.

    rollouts: sequence of (pred_sql, db_path, gold_sql) triples.
    Returns a list of float rewards aligned with the input order.
    """
    if not rollouts:
        return []

    # Group rollout indices by database so each worker thread can keep a
    # single readonly connection open for all of its queries.
    by_db = {}
    for idx, (pred_sql, db_path, gold_sql) in enumerate(rollouts):
        by_db.setdefault(db_path, []).append((idx, pred_sql, gold_sql))

    results = [0.0 for _ in range(len(rollouts))]

    def _reward_with_conn(conn: sqlite3.Connection, pred_sql: str, db_path: str, gold_sql: str) -> float:
        # Connection-reusing twin of execution_reward(): same reward shape
        # (-1.0 invalid / 0.1 failed / 1.0 exact match / clamped base otherwise).
        try:
            sql = _normalize_sql(pred_sql)
            gold = _normalize_sql(gold_sql)

            if not is_valid_select(sql):
                return -1.0

            reward = -0.2

            if USE_SCHEMA_VALIDATION:
                valid, _ = validate_sql_schema(sql, db_path)
                if not valid:
                    error_type = classify_error(sql)
                    log_error("UNKNOWN", sql, "schema_invalid", error_type)
                    return 0.1

            pred_res = execute_sql_cached_conn(conn, db_path, sql)
            if pred_res == EXECUTION_ERROR:
                error_type = classify_error(sql)
                log_error("UNKNOWN", sql, "execution_error", error_type)
                return 0.1

            reward += 0.2  # prediction executed cleanly
            gold_res = execute_sql_cached_conn(conn, db_path, gold)
            if gold_res == EXECUTION_ERROR:
                # Broken gold query: neutral-ish score, not the model's fault.
                return 0.1
            if pred_res == gold_res:
                return 1.0
            return max(-1.0, min(1.0, reward))
        except Exception:
            # Any unexpected failure scores 0.0 rather than killing the batch.
            return 0.0

    def _worker(db_path: str, items):
        # One worker per DB: open one readonly connection, score all items,
        # and close the connection even if scoring raises.
        conn = _connect_readonly(db_path)
        try:
            for idx, pred, gold in items:
                results[idx] = _reward_with_conn(conn, pred, db_path, gold)
        finally:
            try:
                conn.close()
            except Exception:
                pass

    with ThreadPoolExecutor(max_workers=int(max_workers)) as ex:
        futures = [ex.submit(_worker, db_path, items) for db_path, items in by_db.items()]
        for fut in as_completed(futures):
            # Propagate worker exceptions (none expected; rewards swallow
            # their own errors above).
            fut.result()

    return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/execution_reward.py CHANGED
@@ -1,510 +1,41 @@
1
-
2
-
3
- # from __future__ import annotations
4
-
5
- # import hashlib
6
- # import os
7
- # import queue
8
- # import re
9
- # import sqlite3
10
- # import threading
11
- # import time
12
- # from concurrent.futures import ThreadPoolExecutor, as_completed
13
- # from dataclasses import dataclass
14
- # from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
15
-
16
- # # --- CACHE CONTROL ---
17
- # USE_CACHE = True
18
- # _REWARD_CACHE: Dict[str, float] = {}
19
-
20
- # def set_use_cache(enabled: bool):
21
- # """Dynamically toggle the reward cache for benchmarks."""
22
- # global USE_CACHE
23
- # USE_CACHE = enabled
24
-
25
- # def _normalize_sql(sql: str) -> str:
26
- # if not isinstance(sql, str):
27
- # return ""
28
- # s = sql.strip()
29
- # if s.startswith("```"):
30
- # s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
31
- # s = re.sub(r"\n?```$", "", s).strip()
32
- # if s.lower().startswith("sql:"):
33
- # s = s[4:].strip()
34
- # if ";" in s:
35
- # s = s.split(";", 1)[0].strip()
36
- # return s
37
-
38
- # def _connect_readonly(db_path: str) -> sqlite3.Connection:
39
- # uri = f"file:{os.path.abspath(db_path)}?mode=ro"
40
- # conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
41
- # conn.execute("PRAGMA query_only = ON;")
42
- # conn.execute("PRAGMA foreign_keys = ON;")
43
- # return conn
44
-
45
- # DEFAULT_QUERY_TIMEOUT_S = 2.0
46
-
47
- # def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S) -> None:
48
- # start = time.monotonic()
49
- # def _handler() -> int:
50
- # return 1 if (time.monotonic() - start) > timeout_s else 0
51
- # conn.set_progress_handler(_handler, 10_000)
52
-
53
- # def _list_tables(conn: sqlite3.Connection) -> List[str]:
54
- # try:
55
- # cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
56
- # return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
57
- # except sqlite3.Error:
58
- # return []
59
-
60
- # def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool:
61
- # s = sql.lower()
62
- # for t in table_names:
63
- # tl = t.lower()
64
- # if not tl:
65
- # continue
66
- # if re.search(rf"\b{re.escape(tl)}\b", s):
67
- # return True
68
- # return False
69
-
70
- # def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
71
- # try:
72
- # _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
73
- # conn.execute(f"EXPLAIN QUERY PLAN {sql}")
74
- # return True
75
- # except sqlite3.Error:
76
- # return False
77
-
78
- # def _execute(conn: sqlite3.Connection, sql: str, max_rows: int = 1000) -> Tuple[bool, List[Tuple], Optional[str]]:
79
- # try:
80
- # _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
81
- # cur = conn.execute(sql)
82
- # rows = cur.fetchmany(max_rows)
83
- # norm_rows = [tuple(r) for r in rows]
84
- # return True, norm_rows, None
85
- # except sqlite3.Error as e:
86
- # return False, [], str(e)
87
-
88
- # _SQL_KEYWORDS_TO_IGNORE = {
89
- # "select", "from", "where", "join", "inner", "left", "right", "full", "outer",
90
- # "on", "group", "by", "order", "limit", "having", "distinct", "union", "intersect",
91
- # "except", "as", "and", "or", "not", "in", "is", "null", "like", "between", "case",
92
- # "when", "then", "else", "end", "asc", "desc"
93
- # }
94
-
95
- # _SQL_FUNCTIONS_TO_IGNORE = {
96
- # "count", "avg", "min", "max", "sum", "lower", "upper", "substr", "coalesce",
97
- # "round", "date", "datetime", "strftime"
98
- # }
99
-
100
- # # --- LIGHTWEIGHT PARSING ---
101
- # def is_valid_select(sql: str):
102
- # sql = sql.strip().lower()
103
- # return sql.startswith("select") or sql.startswith("with")
104
-
105
- # def extract_tables(sql: str) -> List[str]:
106
- # sql = sql.lower()
107
- # if "join" not in sql:
108
- # tables = re.findall(r'from\s+(\w+)', sql)
109
- # return list(set(tables))
110
-
111
- # tables = re.findall(r'from\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
112
- # joins = re.findall(r'join\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
113
- # return list(set(tables + joins))
114
-
115
- # def extract_columns(sql: str) -> List[str]:
116
- # sql = sql.lower()
117
- # match = re.search(r'select\s+(.*?)\s+from', sql)
118
- # if not match:
119
- # return []
120
- # cols = match.group(1)
121
- # if cols.strip() == "*":
122
- # return ["*"]
123
- # return [c.strip() for c in cols.split(",")]
124
-
125
- # def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]:
126
- # tables = set()
127
- # columns = set()
128
- # for t in _list_tables(conn):
129
- # tl = t.lower()
130
- # if not tl:
131
- # continue
132
- # tables.add(tl)
133
- # try:
134
- # cur = conn.execute(f'PRAGMA table_info("{t}")')
135
- # for row in cur.fetchall():
136
- # if row and isinstance(row[1], str):
137
- # columns.add(row[1].lower())
138
- # except sqlite3.Error:
139
- # continue
140
- # return tables, columns
141
-
142
- # def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool:
143
- # return a == b
144
-
145
- # @dataclass
146
- # class RewardDebugStats:
147
- # total: int = 0
148
- # parsed_ok: int = 0
149
- # table_match: int = 0
150
- # column_match: int = 0
151
- # executed_ok: int = 0
152
- # exact_match: int = 0
153
-
154
- # _DEBUG = RewardDebugStats()
155
-
156
- # def reset_debug_metrics() -> None:
157
- # global _DEBUG
158
- # _DEBUG = RewardDebugStats()
159
-
160
- # def get_debug_metrics() -> dict:
161
- # denom = max(_DEBUG.total, 1)
162
- # return {
163
- # "valid_sql_rate": _DEBUG.parsed_ok / denom,
164
- # "table_match_rate": _DEBUG.table_match / denom,
165
- # "column_match_rate": _DEBUG.column_match / denom,
166
- # "execution_accuracy": _DEBUG.exact_match / denom,
167
- # }
168
-
169
- # EXECUTION_ERROR = "EXECUTION_ERROR"
170
-
171
- # _RESULT_CACHE_LOCK = threading.Lock()
172
- # _RESULT_CACHE: "Dict[str, Union[List[Tuple], str]]" = {}
173
- # _RESULT_CACHE_MAX = 100_000
174
-
175
- # def clear_result_cache() -> None:
176
- # """Clear both DB query cache and reward cache."""
177
- # with _RESULT_CACHE_LOCK:
178
- # _RESULT_CACHE.clear()
179
- # _REWARD_CACHE.clear()
180
-
181
- # def _db_state_fingerprint(db_path: str) -> str:
182
- # try:
183
- # st = os.stat(db_path)
184
- # return f"{st.st_mtime_ns}:{st.st_size}"
185
- # except OSError:
186
- # return "missing"
187
-
188
- # def _result_cache_key(db_path: str, sql: str) -> str:
189
- # fp = _db_state_fingerprint(db_path)
190
- # payload = f"{fp}\0{sql}".encode("utf-8", errors="ignore")
191
- # return hashlib.sha256(payload).hexdigest()
192
-
193
- # class _ConnectionPool:
194
- # def __init__(self, db_path: str, maxsize: int = 1) -> None:
195
- # self.db_path = db_path
196
- # self.pool = queue.LifoQueue(maxsize=maxsize)
197
- # self.lock = threading.Lock()
198
-
199
- # def acquire(self) -> sqlite3.Connection:
200
- # try:
201
- # return self.pool.get_nowait()
202
- # except queue.Empty:
203
- # with self.lock:
204
- # try:
205
- # return self.pool.get_nowait()
206
- # except queue.Empty:
207
- # return _connect_readonly(self.db_path)
208
-
209
- # def release(self, conn: sqlite3.Connection) -> None:
210
- # try:
211
- # self.pool.put_nowait(conn)
212
- # except queue.Full:
213
- # try:
214
- # conn.close()
215
- # except Exception:
216
- # pass
217
-
218
- # _POOL_LOCK = threading.Lock()
219
- # _POOLS: Dict[str, _ConnectionPool] = {}
220
-
221
- # def _get_pool(db_path: str) -> _ConnectionPool:
222
- # with _POOL_LOCK:
223
- # pool = _POOLS.get(db_path)
224
- # if pool is None:
225
- # pool = _ConnectionPool(db_path=db_path, maxsize=1)
226
- # _POOLS[db_path] = pool
227
- # return pool
228
-
229
- # class _PooledConnection:
230
- # def __init__(self, db_path: str) -> None:
231
- # self.db_path = db_path
232
- # self.pool = _get_pool(db_path)
233
- # self.conn: Optional[sqlite3.Connection] = None
234
-
235
- # def __enter__(self) -> sqlite3.Connection:
236
- # self.conn = self.pool.acquire()
237
- # return self.conn
238
-
239
- # def __exit__(self, exc_type, exc, tb) -> None:
240
- # if self.conn is not None:
241
- # self.pool.release(self.conn)
242
- # self.conn = None
243
-
244
- # def _cache_get(key: str) -> Optional[Union[List[Tuple], str]]:
245
- # with _RESULT_CACHE_LOCK:
246
- # return _RESULT_CACHE.get(key)
247
-
248
- # def _cache_put(key: str, value: Union[List[Tuple], str]) -> None:
249
- # with _RESULT_CACHE_LOCK:
250
- # if len(_RESULT_CACHE) >= _RESULT_CACHE_MAX:
251
- # _RESULT_CACHE.clear()
252
- # _RESULT_CACHE[key] = value
253
-
254
- # def execute_sql(conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
255
- # try:
256
- # _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
257
- # cur = conn.execute(sql)
258
- # rows = cur.fetchmany(max_rows)
259
- # return [tuple(r) for r in rows]
260
- # except Exception:
261
- # return EXECUTION_ERROR
262
-
263
- # def execute_sql_cached(db_path: str, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
264
- # if not USE_CACHE:
265
- # with _PooledConnection(db_path) as conn:
266
- # return execute_sql(conn, sql, max_rows=max_rows)
267
-
268
- # key = _result_cache_key(db_path, sql)
269
- # cached = _cache_get(key)
270
- # if cached is not None:
271
- # return cached
272
- # with _PooledConnection(db_path) as conn:
273
- # res = execute_sql(conn, sql, max_rows=max_rows)
274
- # _cache_put(key, res)
275
- # return res
276
-
277
- # def execution_reward_timed(
278
- # pred_sql: str, db_path: str, gold_sql: str, *, measure_plan: bool = False,
279
- # ) -> Tuple[float, Dict[str, float]]:
280
- # timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}
281
- # t0 = time.perf_counter()
282
- # sql = _normalize_sql(pred_sql)
283
- # gold = _normalize_sql(gold_sql)
284
-
285
- # if not is_valid_select(sql):
286
- # timings["parse_s"] = time.perf_counter() - t0
287
- # return 0.0, timings
288
-
289
- # t1 = time.perf_counter()
290
- # timings["parse_s"] = t1 - t0
291
-
292
- # if measure_plan:
293
- # with _PooledConnection(db_path) as conn:
294
- # p0 = time.perf_counter()
295
- # _explain_query_plan(conn, sql)
296
- # _explain_query_plan(conn, gold)
297
- # timings["plan_s"] = time.perf_counter() - p0
298
-
299
- # e0 = time.perf_counter()
300
- # pred_res = execute_sql_cached(db_path, sql)
301
- # if pred_res == EXECUTION_ERROR:
302
- # timings["exec_s"] = time.perf_counter() - e0
303
- # return 0.0, timings
304
- # gold_res = execute_sql_cached(db_path, gold)
305
- # timings["exec_s"] = time.perf_counter() - e0
306
- # if gold_res == EXECUTION_ERROR:
307
- # return 0.0, timings
308
-
309
- # reward = -0.2
310
- # reward += 0.2
311
- # if _safe_results_equal(pred_res, gold_res):
312
- # return 1.0, timings
313
- # return max(-1.0, min(1.0, reward)), timings
314
-
315
- # def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
316
- # try:
317
- # sql = _normalize_sql(pred_sql)
318
- # gold = _normalize_sql(gold_sql)
319
-
320
- # if not is_valid_select(sql):
321
- # return -1.0
322
-
323
- # reward = -0.2
324
-
325
- # pred_tables = set(extract_tables(sql))
326
- # gold_tables = set(extract_tables(gold))
327
-
328
- # if pred_tables == gold_tables and len(gold_tables) > 0:
329
- # reward += 0.3
330
-
331
- # pred_cols = set(extract_columns(sql))
332
- # gold_cols = set(extract_columns(gold))
333
-
334
- # if gold_cols:
335
- # overlap = len(pred_cols & gold_cols) / len(gold_cols)
336
- # reward += 0.3 * overlap
337
-
338
- # pred_res = execute_sql_cached(db_path, sql)
339
- # if pred_res == EXECUTION_ERROR:
340
- # return 0.0
341
- # reward += 0.2
342
-
343
- # gold_res = execute_sql_cached(db_path, gold)
344
- # if gold_res == EXECUTION_ERROR:
345
- # return 0.0
346
- # if _safe_results_equal(pred_res, gold_res):
347
- # return 1.0
348
-
349
- # return max(-1.0, min(1.0, reward))
350
-
351
- # except Exception:
352
- # return 0.0
353
-
354
- # def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
355
- # if not USE_CACHE:
356
- # return execution_reward(pred_sql, db_path, gold_sql)
357
-
358
- # key = f"{db_path}|{pred_sql}|{gold_sql}"
359
- # if key not in _REWARD_CACHE:
360
- # _REWARD_CACHE[key] = execution_reward(pred_sql, db_path, gold_sql)
361
- # return _REWARD_CACHE[key]
362
-
363
- # def execution_reward_batch_sequential(rollouts: Sequence[Tuple[str, str, str]]) -> List[float]:
364
- # return [cached_execution_reward(pred_sql, db_path, gold_sql) for pred_sql, db_path, gold_sql in rollouts]
365
-
366
- # def execution_reward_batch_parallel(rollouts: Sequence[Tuple[str, str, str]], *, max_workers: int = 20) -> List[float]:
367
- # if not rollouts:
368
- # return []
369
-
370
- # unique_dbs = {db_path for _, db_path, _ in rollouts}
371
- # worker_count = max(1, min(max_workers, len(unique_dbs)))
372
- # results: List[Optional[float]] = [None] * len(rollouts)
373
-
374
- # with ThreadPoolExecutor(max_workers=worker_count) as executor:
375
- # futures = {
376
- # executor.submit(cached_execution_reward, pred_sql, db_path, gold_sql): i
377
- # for i, (pred_sql, db_path, gold_sql) in enumerate(rollouts)
378
- # }
379
- # for fut in as_completed(futures):
380
- # idx = futures[fut]
381
- # try:
382
- # results[idx] = float(fut.result())
383
- # except Exception:
384
- # results[idx] = 0.0
385
-
386
- # return [r if r is not None else 0.0 for r in results]
387
-
388
  from __future__ import annotations
389
 
390
  import os
391
  import re
392
  import sqlite3
393
- import threading
394
  import time
395
- import json
396
- from concurrent.futures import ThreadPoolExecutor, as_completed
397
  from dataclasses import dataclass
398
- from typing import Dict, List
399
-
400
- from src.sql_validator import validate_sql_schema
401
-
402
- # =========================================================
403
- # 🔥 CONFIG FLAGS
404
- # =========================================================
405
- USE_SCHEMA_VALIDATION = True
406
- USE_CACHE = True
407
- DEFAULT_QUERY_TIMEOUT_S = 2.0
408
-
409
- EXECUTION_ERROR = "EXECUTION_ERROR"
410
-
411
- _REWARD_CACHE: Dict[str, float] = {}
412
-
413
- # =========================================================
414
- # 🔥 TASK 2: ERROR ANALYSIS + LOGGING
415
- # =========================================================
416
- ERROR_LOG_FILE = "results/error_logs.json"
417
-
418
-
419
- def classify_error(sql: str) -> str:
420
- sql = sql.lower()
421
-
422
- if "join" in sql and " on " not in sql:
423
- return "missing_join"
424
-
425
- if "where" in sql and "=" not in sql and ">" not in sql and "<" not in sql:
426
- return "wrong_where"
427
-
428
- if "null" in sql:
429
- return "null_handling"
430
-
431
- if "group by" in sql and "count" not in sql:
432
- return "wrong_groupby"
433
 
434
- return "other"
 
 
 
 
 
 
 
435
 
436
 
437
- def get_hint(error_type: str) -> str:
438
- hints = {
439
- "missing_join": "Add proper JOIN condition using ON.",
440
- "wrong_where": "Check WHERE clause conditions.",
441
- "null_handling": "Handle NULL values using IS NULL.",
442
- "wrong_groupby": "Use aggregation functions with GROUP BY.",
443
- "other": "Check SQL syntax and logic."
444
- }
445
- return hints.get(error_type, "Check query.")
446
-
447
-
448
- def log_error(question: str, sql: str, error: str, error_type: str):
449
- os.makedirs("results", exist_ok=True)
450
-
451
- entry = {
452
- "question": question,
453
- "sql": sql,
454
- "error": error,
455
- "error_type": error_type,
456
- "timestamp": time.time()
457
- }
458
-
459
- if os.path.exists(ERROR_LOG_FILE):
460
- with open(ERROR_LOG_FILE, "r") as f:
461
- logs = json.load(f)
462
- else:
463
- logs = []
464
-
465
- logs.append(entry)
466
-
467
- with open(ERROR_LOG_FILE, "w") as f:
468
- json.dump(logs, f, indent=2)
469
-
470
- # =========================================================
471
- # CACHE/VALIDATION TOGGLES (Task 1)
472
- # =========================================================
473
- def set_use_cache(enabled: bool) -> None:
474
- global USE_CACHE
475
- USE_CACHE = bool(enabled)
476
-
477
-
478
- def set_use_schema_validation(enabled: bool) -> None:
479
- global USE_SCHEMA_VALIDATION
480
- USE_SCHEMA_VALIDATION = bool(enabled)
481
-
482
-
483
- # =========================================================
484
- # SQL CLEANING
485
- # =========================================================
486
  def _normalize_sql(sql: str) -> str:
487
  if not isinstance(sql, str):
488
  return ""
489
  s = sql.strip()
490
-
491
  if s.startswith("```"):
 
492
  s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
493
  s = re.sub(r"\n?```$", "", s).strip()
494
-
495
  if s.lower().startswith("sql:"):
496
  s = s[4:].strip()
497
-
498
  if ";" in s:
499
  s = s.split(";", 1)[0].strip()
500
-
501
  return s
502
 
503
 
504
- # =========================================================
505
- # DB EXECUTION
506
- # =========================================================
507
- def _connect_readonly(db_path: str):
508
  uri = f"file:{os.path.abspath(db_path)}?mode=ro"
509
  conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
510
  conn.execute("PRAGMA query_only = ON;")
@@ -512,320 +43,367 @@ def _connect_readonly(db_path: str):
512
  return conn
513
 
514
 
515
- def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S):
516
  start = time.monotonic()
517
 
518
- def handler():
519
  return 1 if (time.monotonic() - start) > timeout_s else 0
520
 
521
- conn.set_progress_handler(handler, 10_000)
 
522
 
523
 
524
- def execute_sql(conn, sql):
525
  try:
526
- _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
527
- cur = conn.execute(sql)
528
- return cur.fetchall()
529
- except Exception:
530
- return EXECUTION_ERROR
531
-
532
-
533
- _RESULT_CACHE = {}
534
- _RESULT_LOCK = threading.Lock()
535
-
536
-
537
- def execute_sql_cached(db_path, sql):
538
- key = f"{db_path}|{sql}"
539
 
540
- if USE_CACHE:
541
- with _RESULT_LOCK:
542
- if key in _RESULT_CACHE:
543
- return _RESULT_CACHE[key]
544
 
545
- conn = _connect_readonly(db_path)
546
- result = execute_sql(conn, sql)
547
- conn.close()
 
 
 
 
 
 
548
 
549
- if USE_CACHE:
550
- with _RESULT_LOCK:
551
- _RESULT_CACHE[key] = result
552
 
553
- return result
 
 
 
 
 
 
554
 
555
 
556
- def execute_sql_cached_conn(conn: sqlite3.Connection, db_path: str, sql: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  """
558
- Like execute_sql_cached(), but reuses an existing connection.
559
- Intended for 1-thread-per-DB workloads (Task 1).
560
  """
561
- key = f"{db_path}|{sql}"
562
- if USE_CACHE:
563
- with _RESULT_LOCK:
564
- if key in _RESULT_CACHE:
565
- return _RESULT_CACHE[key]
 
 
 
 
566
 
567
- result = execute_sql(conn, sql)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
 
569
- if USE_CACHE:
570
- with _RESULT_LOCK:
571
- _RESULT_CACHE[key] = result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
572
 
573
- return result
574
 
 
 
 
575
 
576
- def clear_result_cache() -> None:
577
- global _RESULT_CACHE, _REWARD_CACHE
578
- with _RESULT_LOCK:
579
- _RESULT_CACHE.clear()
580
- _REWARD_CACHE.clear()
581
 
 
 
 
 
 
 
 
 
582
 
583
- # =========================================================
584
- # SQL PARSING
585
- # =========================================================
586
- def is_valid_select(sql):
587
- return sql.lower().startswith("select") or sql.lower().startswith("with")
588
 
 
589
 
590
- def extract_tables(sql):
591
- return re.findall(r'from\s+(\w+)', sql.lower())
592
 
 
 
 
593
 
594
- def extract_columns(sql):
595
- match = re.search(r'select\s+(.*?)\s+from', sql.lower())
596
- if not match:
597
- return []
598
- cols = match.group(1)
599
- return ["*"] if cols.strip() == "*" else [c.strip() for c in cols.split(",")]
600
 
 
 
 
 
 
 
 
 
601
 
602
- def get_sql_operations(sql: str):
603
- sql = sql.lower()
604
- ops = []
605
-
606
- if "select" in sql: ops.append("SELECT")
607
- if "where" in sql: ops.append("WHERE")
608
- if "join" in sql: ops.append("JOIN")
609
- if "group by" in sql: ops.append("GROUP_BY")
610
- if "order by" in sql: ops.append("ORDER_BY")
611
 
612
- return ops
613
 
 
 
 
614
 
615
- def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
 
616
  try:
617
- _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
618
- conn.execute(f"EXPLAIN QUERY PLAN {sql}")
619
- return True
 
620
  except Exception:
621
- return False
622
 
623
 
624
- def execution_reward_timed(pred_sql: str, db_path: str, gold_sql: str, measure_plan: bool = False):
625
  """
626
- Returns (reward, timings) where timings keys: parse_s, plan_s, exec_s.
627
- Used by Task-1 benchmark to profile bottlenecks.
 
628
  """
629
- timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}
630
- t0 = time.perf_counter()
631
-
632
- sql = _normalize_sql(pred_sql)
633
- gold = _normalize_sql(gold_sql)
634
-
635
- if not is_valid_select(sql):
636
- timings["parse_s"] = time.perf_counter() - t0
637
- return 0.0, timings
638
-
639
- t1 = time.perf_counter()
640
- timings["parse_s"] = t1 - t0
641
-
642
- conn = _connect_readonly(db_path)
643
  try:
644
- if measure_plan:
645
- p0 = time.perf_counter()
646
- _explain_query_plan(conn, sql)
647
- _explain_query_plan(conn, gold)
648
- timings["plan_s"] = time.perf_counter() - p0
649
-
650
- e0 = time.perf_counter()
651
- pred_res = execute_sql_cached_conn(conn, db_path, sql)
652
- if pred_res == EXECUTION_ERROR:
653
- timings["exec_s"] = time.perf_counter() - e0
654
- return 0.0, timings
655
- gold_res = execute_sql_cached_conn(conn, db_path, gold)
656
- timings["exec_s"] = time.perf_counter() - e0
657
- if gold_res == EXECUTION_ERROR:
658
- return 0.0, timings
659
-
660
- reward = -0.2 + 0.2
661
- if pred_res == gold_res:
662
- return 1.0, timings
663
- return max(-1.0, min(1.0, reward)), timings
664
- finally:
665
- try:
666
- conn.close()
667
- except Exception:
668
- pass
669
-
670
 
671
- # =========================================================
672
- # 🔥 FINAL REWARD FUNCTION (TASK 2 INTEGRATED)
673
- # =========================================================
674
  def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
675
  try:
676
  sql = _normalize_sql(pred_sql)
677
  gold = _normalize_sql(gold_sql)
678
 
679
- if not is_valid_select(sql):
680
  return -1.0
681
 
682
- reward = -0.2
683
-
684
- # =========================
685
- # SCHEMA VALIDATION (Task 3)
686
- # =========================
687
- if USE_SCHEMA_VALIDATION:
688
- valid, _ = validate_sql_schema(sql, db_path)
689
- if not valid:
690
- error_type = classify_error(sql)
691
- log_error("UNKNOWN", sql, "schema_invalid", error_type)
692
- return 0.1
693
-
694
- # =========================
695
- # EXECUTION
696
- # =========================
697
- pred_res = execute_sql_cached(db_path, sql)
698
-
699
- if pred_res == "EXECUTION_ERROR":
700
- error_type = classify_error(sql)
701
-
702
- log_error(
703
- question="UNKNOWN",
704
- sql=sql,
705
- error="execution_error",
706
- error_type=error_type
707
- )
708
-
709
- print(f"[ERROR] {error_type}")
710
- print(f"[HINT] {get_hint(error_type)}")
711
-
712
- return 0.1
713
-
714
- reward += 0.2
715
-
716
- gold_res = execute_sql_cached(db_path, gold)
717
-
718
- if gold_res == "EXECUTION_ERROR":
719
- return 0.1
720
-
721
- if pred_res == gold_res:
722
- return 1.0
723
-
724
- return max(-1.0, min(1.0, reward))
725
-
726
- except Exception as e:
727
- log_error("UNKNOWN", pred_sql, str(e), "runtime_error")
728
- return 0.0
729
-
730
-
731
- # =========================================================
732
- # BATCH EXECUTION (Task 1)
733
- # =========================================================
734
- def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
735
- if not USE_CACHE:
736
- return float(execution_reward(pred_sql, db_path, gold_sql))
737
- key = f"{db_path}|{pred_sql}|{gold_sql}"
738
- if key in _REWARD_CACHE:
739
- return float(_REWARD_CACHE[key])
740
- r = float(execution_reward(pred_sql, db_path, gold_sql))
741
- _REWARD_CACHE[key] = r
742
- return r
743
-
744
-
745
- def execution_reward_batch_sequential(rollouts):
746
- return [cached_execution_reward(p, d, g) for (p, d, g) in rollouts]
747
-
748
-
749
- def execution_reward_batch_parallel(rollouts, max_workers=10):
750
- results = [0.0] * len(rollouts)
751
-
752
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
753
- futures = {
754
- executor.submit(cached_execution_reward, p, d, g): i
755
- for i, (p, d, g) in enumerate(rollouts)
756
- }
757
 
758
- for fut in as_completed(futures):
759
- idx = futures[fut]
760
- try:
761
- results[idx] = fut.result()
762
- except Exception:
763
- results[idx] = 0.0
764
 
765
- return results
 
766
 
 
 
767
 
768
- def execution_reward_batch_parallel_by_db(rollouts, max_workers: int = 20):
769
- """
770
- 1 thread per DB path. Reuses a single readonly connection per DB worker.
771
- Preserves input order.
772
- """
773
- if not rollouts:
774
- return []
775
 
776
- by_db = {}
777
- for idx, (pred_sql, db_path, gold_sql) in enumerate(rollouts):
778
- by_db.setdefault(db_path, []).append((idx, pred_sql, gold_sql))
779
 
780
- results = [0.0 for _ in range(len(rollouts))]
 
 
 
781
 
782
- def _reward_with_conn(conn: sqlite3.Connection, pred_sql: str, db_path: str, gold_sql: str) -> float:
783
- try:
784
- sql = _normalize_sql(pred_sql)
785
- gold = _normalize_sql(gold_sql)
786
-
787
- if not is_valid_select(sql):
788
- return -1.0
789
-
790
- reward = -0.2
791
-
792
- if USE_SCHEMA_VALIDATION:
793
- valid, _ = validate_sql_schema(sql, db_path)
794
- if not valid:
795
- error_type = classify_error(sql)
796
- log_error("UNKNOWN", sql, "schema_invalid", error_type)
797
- return 0.1
798
-
799
- pred_res = execute_sql_cached_conn(conn, db_path, sql)
800
- if pred_res == EXECUTION_ERROR:
801
- error_type = classify_error(sql)
802
- log_error("UNKNOWN", sql, "execution_error", error_type)
803
- return 0.1
804
-
805
- reward += 0.2
806
- gold_res = execute_sql_cached_conn(conn, db_path, gold)
807
- if gold_res == EXECUTION_ERROR:
808
- return 0.1
809
- if pred_res == gold_res:
810
  return 1.0
811
- return max(-1.0, min(1.0, reward))
812
- except Exception:
813
- return 0.0
814
 
815
- def _worker(db_path: str, items):
816
- conn = _connect_readonly(db_path)
817
- try:
818
- for idx, pred, gold in items:
819
- results[idx] = _reward_with_conn(conn, pred, db_path, gold)
820
- finally:
821
- try:
822
- conn.close()
823
- except Exception:
824
- pass
825
-
826
- with ThreadPoolExecutor(max_workers=int(max_workers)) as ex:
827
- futures = [ex.submit(_worker, db_path, items) for db_path, items in by_db.items()]
828
- for fut in as_completed(futures):
829
- fut.result()
830
 
831
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import os
4
  import re
5
  import sqlite3
 
6
  import time
 
 
7
  from dataclasses import dataclass
8
+ from typing import List, Optional, Sequence, Set, Tuple, Union
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ try:
11
+ import sqlparse
12
+ from sqlparse.sql import Function, Identifier, IdentifierList, Statement, Token, Where
13
+ from sqlparse.tokens import DML, Keyword, Name, Number, Punctuation, String, Whitespace
14
+ except Exception: # pragma: no cover
15
+ sqlparse = None # type: ignore[assignment]
16
+ Statement = object # type: ignore[misc,assignment]
17
+ Token = object # type: ignore[misc,assignment]
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def _normalize_sql(sql: str) -> str:
21
  if not isinstance(sql, str):
22
  return ""
23
  s = sql.strip()
 
24
  if s.startswith("```"):
25
+ # Strip markdown fences if present.
26
  s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
27
  s = re.sub(r"\n?```$", "", s).strip()
 
28
  if s.lower().startswith("sql:"):
29
  s = s[4:].strip()
30
+ # Keep only the first statement to avoid accidental multi-statement execution.
31
  if ";" in s:
32
  s = s.split(";", 1)[0].strip()
 
33
  return s
34
 
35
 
36
+ def _connect_readonly(db_path: str) -> sqlite3.Connection:
37
+ # Read-only prevents any accidental mutation during reward computation.
38
+ # Note: requires SQLite URI support (built-in).
 
39
  uri = f"file:{os.path.abspath(db_path)}?mode=ro"
40
  conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
41
  conn.execute("PRAGMA query_only = ON;")
 
43
  return conn
44
 
45
 
46
+ def _with_timeout(conn: sqlite3.Connection, timeout_s: float = 1.0) -> None:
47
  start = time.monotonic()
48
 
49
+ def _handler() -> int:
50
  return 1 if (time.monotonic() - start) > timeout_s else 0
51
 
52
+ # Call handler every N VM opcodes.
53
+ conn.set_progress_handler(_handler, 10_000)
54
 
55
 
56
+ def _list_tables(conn: sqlite3.Connection) -> List[str]:
57
  try:
58
+ cur = conn.execute(
59
+ "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
60
+ )
61
+ return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
62
+ except sqlite3.Error:
63
+ return []
 
 
 
 
 
 
 
64
 
 
 
 
 
65
 
66
+ def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool:
67
+ s = sql.lower()
68
+ for t in table_names:
69
+ tl = t.lower()
70
+ if not tl:
71
+ continue
72
+ if re.search(rf"\b{re.escape(tl)}\b", s):
73
+ return True
74
+ return False
75
 
 
 
 
76
 
77
+ def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
78
+ try:
79
+ _with_timeout(conn, timeout_s=1.0)
80
+ conn.execute(f"EXPLAIN QUERY PLAN {sql}")
81
+ return True
82
+ except sqlite3.Error:
83
+ return False
84
 
85
 
86
+ def _execute(conn: sqlite3.Connection, sql: str, max_rows: int = 1000) -> Tuple[bool, List[Tuple], Optional[str]]:
87
+ try:
88
+ _with_timeout(conn, timeout_s=1.0)
89
+ cur = conn.execute(sql)
90
+ rows = cur.fetchmany(max_rows)
91
+ # Normalize to plain tuples for deterministic comparison.
92
+ norm_rows = [tuple(r) for r in rows]
93
+ return True, norm_rows, None
94
+ except sqlite3.Error as e:
95
+ return False, [], str(e)
96
+
97
+
98
+ _SQL_KEYWORDS_TO_IGNORE = {
99
+ "select",
100
+ "from",
101
+ "where",
102
+ "join",
103
+ "inner",
104
+ "left",
105
+ "right",
106
+ "full",
107
+ "outer",
108
+ "on",
109
+ "group",
110
+ "by",
111
+ "order",
112
+ "limit",
113
+ "having",
114
+ "distinct",
115
+ "union",
116
+ "intersect",
117
+ "except",
118
+ "as",
119
+ "and",
120
+ "or",
121
+ "not",
122
+ "in",
123
+ "is",
124
+ "null",
125
+ "like",
126
+ "between",
127
+ "case",
128
+ "when",
129
+ "then",
130
+ "else",
131
+ "end",
132
+ "asc",
133
+ "desc",
134
+ }
135
+
136
+ _SQL_FUNCTIONS_TO_IGNORE = {
137
+ "count",
138
+ "avg",
139
+ "min",
140
+ "max",
141
+ "sum",
142
+ "lower",
143
+ "upper",
144
+ "substr",
145
+ "coalesce",
146
+ "round",
147
+ "date",
148
+ "datetime",
149
+ "strftime",
150
+ }
151
+
152
+
153
+ def extract_tables(sql: str) -> Set[str]:
154
  """
155
+ Best-effort table extraction from SQL using sqlparse.
156
+ Returns lowercase table names (unqualified).
157
  """
158
+ sql = _normalize_sql(sql)
159
+ if not sql:
160
+ return set()
161
+ if sqlparse is None:
162
+ # Fallback: naive regex for FROM/JOIN.
163
+ found = set()
164
+ for m in re.finditer(r"\b(from|join)\s+([a-zA-Z_][\w$]*)", sql, flags=re.I):
165
+ found.add(m.group(2).lower())
166
+ return found
167
 
168
+ try:
169
+ statements = sqlparse.parse(sql)
170
+ except Exception:
171
+ return set()
172
+
173
+ tables: Set[str] = set()
174
+
175
+ def _add_identifier_as_table(ident: Identifier) -> None:
176
+ # Prefer real name over alias; strip any schema prefix.
177
+ name = ident.get_real_name() or ident.get_name()
178
+ if not name:
179
+ return
180
+ tables.add(name.lower())
181
+
182
+ for st in statements:
183
+ if not isinstance(st, Statement):
184
+ continue
185
+ seen_from = False
186
+ for tok in st.flatten():
187
+ if tok.ttype in Whitespace:
188
+ continue
189
+ if tok.ttype is Keyword and tok.value.upper() in {"FROM", "JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN", "FULL JOIN"}:
190
+ seen_from = True
191
+ continue
192
+ if not seen_from:
193
+ continue
194
+
195
+ if isinstance(tok, Identifier):
196
+ _add_identifier_as_table(tok)
197
+ seen_from = False
198
+ elif tok.ttype is Name:
199
+ tables.add(tok.value.lower())
200
+ seen_from = False
201
+ elif tok.ttype is Keyword and tok.value.upper() in {"WHERE", "GROUP", "ORDER", "HAVING", "LIMIT"}:
202
+ seen_from = False
203
+
204
+ return tables
205
+
206
+
207
+ def extract_columns(sql: str) -> Set[str]:
208
+ """
209
+ Best-effort column extraction from SQL using sqlparse.
210
+ Returns lowercase column names (unqualified).
211
+ """
212
+ sql = _normalize_sql(sql)
213
+ if not sql:
214
+ return set()
215
+ if sqlparse is None:
216
+ # Fallback: naive dotted identifiers and bare names after SELECT/WHERE/etc.
217
+ cols = set()
218
+ for m in re.finditer(r"\b([a-zA-Z_][\w$]*)\b", sql):
219
+ w = m.group(1).lower()
220
+ if w in _SQL_KEYWORDS_TO_IGNORE or w in _SQL_FUNCTIONS_TO_IGNORE:
221
+ continue
222
+ cols.add(w)
223
+ return cols
224
 
225
+ try:
226
+ statements = sqlparse.parse(sql)
227
+ except Exception:
228
+ return set()
229
+
230
+ cols: Set[str] = set()
231
+
232
+ def _maybe_add_col(name: Optional[str]) -> None:
233
+ if not name:
234
+ return
235
+ n = name.strip().strip('"').strip("'").lower()
236
+ if not n or n == "*":
237
+ return
238
+ if n in _SQL_KEYWORDS_TO_IGNORE or n in _SQL_FUNCTIONS_TO_IGNORE:
239
+ return
240
+ cols.add(n)
241
+
242
+ def _handle_identifier(ident: Identifier) -> None:
243
+ # If qualified (t.col), keep only col for overlap/hallucination checks.
244
+ _maybe_add_col(ident.get_real_name() or ident.get_name())
245
+
246
+ for st in statements:
247
+ if not isinstance(st, Statement):
248
+ continue
249
+ for tok in st.flatten():
250
+ # Skip whitespace/punctuation/string literals/numbers.
251
+ if getattr(tok, "ttype", None) in (Whitespace, Punctuation, String, Number):
252
+ continue
253
+
254
+ if isinstance(tok, Function):
255
+ fname = tok.get_name()
256
+ if fname:
257
+ # Don't treat function name as a column.
258
+ pass
259
+ continue
260
+
261
+ if isinstance(tok, IdentifierList):
262
+ for ident in tok.get_identifiers():
263
+ if isinstance(ident, Identifier):
264
+ _handle_identifier(ident)
265
+ continue
266
+
267
+ if isinstance(tok, Identifier):
268
+ _handle_identifier(tok)
269
+ continue
270
+
271
+ if getattr(tok, "ttype", None) is Name:
272
+ _maybe_add_col(tok.value)
273
+
274
+ return cols
275
+
276
+
277
+ def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]:
278
+ """
279
+ Return (tables, columns) sets from SQLite schema; all lowercased.
280
+ Columns are returned as a global set (unqualified).
281
+ """
282
+ tables = set()
283
+ columns = set()
284
+ for t in _list_tables(conn):
285
+ tl = t.lower()
286
+ if not tl:
287
+ continue
288
+ tables.add(tl)
289
+ try:
290
+ cur = conn.execute(f'PRAGMA table_info("{t}")')
291
+ for row in cur.fetchall():
292
+ if row and isinstance(row[1], str):
293
+ columns.add(row[1].lower())
294
+ except sqlite3.Error:
295
+ continue
296
+ return tables, columns
297
 
 
298
 
299
+ def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool:
300
+ # Deterministic comparison: compare exact row tuples in order.
301
+ return a == b
302
 
 
 
 
 
 
303
 
304
+ @dataclass
305
+ class RewardDebugStats:
306
+ total: int = 0
307
+ parsed_ok: int = 0
308
+ table_match: int = 0
309
+ column_match: int = 0
310
+ executed_ok: int = 0
311
+ exact_match: int = 0
312
 
 
 
 
 
 
313
 
314
+ _DEBUG = RewardDebugStats()
315
 
 
 
316
 
317
+ def reset_debug_metrics() -> None:
318
+ global _DEBUG
319
+ _DEBUG = RewardDebugStats()
320
 
 
 
 
 
 
 
321
 
322
+ def get_debug_metrics() -> dict:
323
+ denom = max(_DEBUG.total, 1)
324
+ return {
325
+ "valid_sql_rate": _DEBUG.parsed_ok / denom,
326
+ "table_match_rate": _DEBUG.table_match / denom,
327
+ "column_match_rate": _DEBUG.column_match / denom,
328
+ "execution_accuracy": _DEBUG.exact_match / denom,
329
+ }
330
 
331
+ EXECUTION_ERROR = "EXECUTION_ERROR"
 
 
 
 
 
 
 
 
332
 
 
333
 
334
+ def execute_sql(conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
335
+ """
336
+ Execute SQL safely.
337
 
338
+ If sqlite raises ANY exception, return EXECUTION_ERROR (NOT empty list).
339
+ """
340
  try:
341
+ _with_timeout(conn, timeout_s=1.0)
342
+ cur = conn.execute(sql)
343
+ rows = cur.fetchmany(max_rows)
344
+ return [tuple(r) for r in rows]
345
  except Exception:
346
+ return EXECUTION_ERROR
347
 
348
 
349
+ def _sqlparse_valid_select(sql: str) -> bool:
350
  """
351
+ Parse validation using sqlparse:
352
+ - parse() non-empty
353
+ - contains a SELECT statement
354
  """
355
+ if sqlparse is None:
356
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
357
  try:
358
+ stmts = sqlparse.parse(sql)
359
+ if not stmts:
360
+ return False
361
+ for st in stmts:
362
+ try:
363
+ if hasattr(st, "get_type") and st.get_type() == "SELECT":
364
+ return True
365
+ except Exception:
366
+ continue
367
+ return False
368
+ except Exception:
369
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
 
 
 
371
  def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
372
  try:
373
  sql = _normalize_sql(pred_sql)
374
  gold = _normalize_sql(gold_sql)
375
 
376
+ if not sql or "SELECT" not in sql.upper():
377
  return -1.0
378
 
379
+ if not _sqlparse_valid_select(sql):
380
+ return -1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
+ reward = -0.2 # valid SQL baseline
 
 
 
 
 
383
 
384
+ pred_tables = extract_tables(sql)
385
+ gold_tables = extract_tables(gold)
386
 
387
+ if pred_tables == gold_tables and len(gold_tables) > 0:
388
+ reward += 0.3
389
 
390
+ pred_cols = extract_columns(sql)
391
+ gold_cols = extract_columns(gold)
 
 
 
 
 
392
 
393
+ if gold_cols:
394
+ overlap = len(pred_cols & gold_cols) / len(gold_cols)
395
+ reward += 0.3 * overlap
396
 
397
+ with _connect_readonly(db_path) as conn:
398
+ pred_res = execute_sql(conn, sql)
399
+ if pred_res != EXECUTION_ERROR:
400
+ reward += 0.2
401
 
402
+ gold_res = execute_sql(conn, gold)
403
+ if pred_res != EXECUTION_ERROR and _safe_results_equal(pred_res, gold_res):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  return 1.0
 
 
 
405
 
406
+ return max(-1.0, min(1.0, reward))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
408
+ except Exception:
409
+ return -1.0
src/execution_reward_soft.py DELETED
@@ -1,211 +0,0 @@
1
- import random
2
- import threading
3
- from collections import Counter
4
- from concurrent.futures import ThreadPoolExecutor, as_completed
5
- from src.execution_reward import (
6
- _normalize_sql,
7
- is_valid_select,
8
- execute_sql_cached,
9
- execute_sql_cached_conn,
10
- EXECUTION_ERROR,
11
- validate_sql_schema,
12
- USE_SCHEMA_VALIDATION,
13
- _connect_readonly,
14
- )
15
-
16
- # =========================================================
17
- # 🔥 SOFT REWARD CORE
18
- # =========================================================
19
- def compute_soft_reward(pred_res, gold_res, sample_k=10):
20
- try:
21
- # =================================================
22
- # 1. EDGE CASES
23
- # =================================================
24
- if not gold_res:
25
- return 1.0 if not pred_res else 0.3
26
-
27
- if not pred_res:
28
- return -0.05
29
-
30
- # =================================================
31
- # 2. SAFE HASHING
32
- # =================================================
33
- def make_hashable(row):
34
- return tuple(str(item) for item in row)
35
-
36
- pred_counter = Counter(make_hashable(r) for r in pred_res)
37
-
38
- # =================================================
39
- # 3. SAMPLING
40
- # =================================================
41
- k = min(sample_k, len(gold_res))
42
- sample = random.sample(gold_res, k)
43
-
44
- # =================================================
45
- # 4. MATCH COUNT
46
- # =================================================
47
- match = 0
48
- for row in sample:
49
- key = make_hashable(row)
50
- if pred_counter.get(key, 0) > 0:
51
- pred_counter[key] -= 1
52
- match += 1
53
-
54
- score = match / max(len(sample), 1)
55
-
56
- # =================================================
57
- # 5. 🔥 ANTI-CHEAT LENGTH PENALTY
58
- # =================================================
59
- len_ratio = len(pred_res) / max(len(gold_res), 1)
60
-
61
- if len_ratio > 1.5:
62
- score = score / (len_ratio ** 0.5) # 🔥 smoother penalty
63
-
64
- # =================================================
65
- # 6. CLAMP SCORE (IMPORTANT FOR STABILITY)
66
- # =================================================
67
- score = max(0.0, min(1.0, score))
68
-
69
- # =================================================
70
- # 7. FINAL REWARD
71
- # =================================================
72
- return 0.3 + 0.7 * score
73
-
74
- except Exception:
75
- return -0.05
76
-
77
-
78
- # =========================================================
79
- # 🔥 MAIN EXECUTION REWARD
80
- # =========================================================
81
- _TLS = threading.local()
82
-
83
-
84
- def _get_thread_conn(db_path: str):
85
- conns = getattr(_TLS, "conns", None)
86
- if conns is None:
87
- conns = {}
88
- _TLS.conns = conns
89
- conn = conns.get(db_path)
90
- if conn is None:
91
- conn = _connect_readonly(db_path)
92
- conns[db_path] = conn
93
- return conn
94
-
95
-
96
- def execution_reward_soft_pooled(pred_sql, db_path, gold_sql, *, sample_k: int = 10):
97
- """
98
- Soft execution reward, but reuses a per-thread read-only SQLite connection.
99
- This avoids connect/close overhead in RL loops.
100
- """
101
- try:
102
- sql = _normalize_sql(pred_sql)
103
- gold = _normalize_sql(gold_sql)
104
-
105
- if not is_valid_select(sql):
106
- return -0.05
107
-
108
- if USE_SCHEMA_VALIDATION:
109
- ok, _ = validate_sql_schema(sql, db_path)
110
- if not ok:
111
- return -0.05
112
-
113
- conn = _get_thread_conn(db_path)
114
- pred_res = execute_sql_cached_conn(conn, db_path, sql)
115
- if pred_res == EXECUTION_ERROR:
116
- return -0.05
117
-
118
- gold_res = execute_sql_cached_conn(conn, db_path, gold)
119
- if gold_res == EXECUTION_ERROR:
120
- return -0.05
121
-
122
- return compute_soft_reward(pred_res, gold_res, sample_k=int(sample_k))
123
- except Exception:
124
- return -0.05
125
-
126
-
127
- def execution_reward_soft(pred_sql, db_path, gold_sql):
128
- try:
129
- sql = _normalize_sql(pred_sql)
130
- gold = _normalize_sql(gold_sql)
131
-
132
- # =================================================
133
- # BASIC VALIDATION
134
- # =================================================
135
- if not is_valid_select(sql):
136
- return -0.05
137
-
138
- if USE_SCHEMA_VALIDATION:
139
- ok, _ = validate_sql_schema(sql, db_path)
140
- if not ok:
141
- return -0.05
142
-
143
- # =================================================
144
- # EXECUTION
145
- # =================================================
146
- pred_res = execute_sql_cached(db_path, sql)
147
- if pred_res == EXECUTION_ERROR:
148
- return -0.05
149
-
150
- gold_res = execute_sql_cached(db_path, gold)
151
- if gold_res == EXECUTION_ERROR:
152
- return -0.05
153
-
154
- return compute_soft_reward(pred_res, gold_res)
155
-
156
- except Exception:
157
- return -0.05
158
-
159
-
160
- def execution_reward_soft_batch_parallel_by_db(rollouts, *, max_workers: int = 20, sample_k: int = 10):
161
- """
162
- rollouts: Sequence[(pred_sql, db_path, gold_sql)]
163
- Executes with 1-thread-per-DB grouping for better connection reuse.
164
- Returns rewards in the same order as input.
165
- """
166
- if not rollouts:
167
- return []
168
-
169
- # Group by DB so each worker can hold a single connection and reuse it.
170
- by_db = {}
171
- for idx, (pred_sql, db_path, gold_sql) in enumerate(rollouts):
172
- by_db.setdefault(db_path, []).append((idx, pred_sql, gold_sql))
173
-
174
- out = [0.0 for _ in range(len(rollouts))]
175
-
176
- def _worker(db_path: str, items):
177
- conn = _connect_readonly(db_path)
178
- try:
179
- for idx, pred_sql, gold_sql in items:
180
- # Do NOT use the global thread-local here; this worker owns the connection.
181
- try:
182
- sql = _normalize_sql(pred_sql)
183
- gold = _normalize_sql(gold_sql)
184
- if not is_valid_select(sql):
185
- out[idx] = -0.05
186
- continue
187
- if USE_SCHEMA_VALIDATION:
188
- ok, _ = validate_sql_schema(sql, db_path)
189
- if not ok:
190
- out[idx] = -0.05
191
- continue
192
- pred_res = execute_sql_cached_conn(conn, db_path, sql)
193
- if pred_res == EXECUTION_ERROR:
194
- out[idx] = -0.05
195
- continue
196
- gold_res = execute_sql_cached_conn(conn, db_path, gold)
197
- if gold_res == EXECUTION_ERROR:
198
- out[idx] = -0.05
199
- continue
200
- out[idx] = float(compute_soft_reward(pred_res, gold_res, sample_k=int(sample_k)))
201
- except Exception:
202
- out[idx] = -0.05
203
- finally:
204
- conn.close()
205
-
206
- with ThreadPoolExecutor(max_workers=int(max_workers)) as ex:
207
- futures = [ex.submit(_worker, db_path, items) for db_path, items in by_db.items()]
208
- for fut in as_completed(futures):
209
- fut.result()
210
-
211
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/load_lora_model.py CHANGED
@@ -1,84 +1,21 @@
1
- # import torch
2
- # from transformers import T5ForConditionalGeneration, T5Tokenizer
3
- # from peft import LoraConfig, get_peft_model, TaskType
4
-
5
- # device = "mps" if torch.backends.mps.is_available() else "cpu"
6
-
7
- # MODEL_PATH = "../outputs/model" # your supervised trained model
8
-
9
- # print("Loading base model...")
10
- # model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH).to(device)
11
-
12
- # tokenizer = T5Tokenizer.from_pretrained("t5-small")
13
-
14
- # # ---------------- LoRA CONFIG ----------------
15
- # lora_config = LoraConfig(
16
- # r=8, # rank (small brain attachment)
17
- # lora_alpha=16,
18
- # target_modules=["q", "v"], # attention matrices only
19
- # lora_dropout=0.05,
20
- # bias="none",
21
- # task_type=TaskType.SEQ_2_SEQ_LM
22
- # )
23
-
24
- # print("Attaching LoRA adapters...")
25
- # model = get_peft_model(model, lora_config)
26
-
27
- # model.print_trainable_parameters()
28
-
29
- # print("READY ✔ LoRA model loaded")
30
-
31
- # ****************** task 5 @#$%^&*I(O)(*&^%$#$%^&*(*&^%$#$%^&*^%$#%^)
32
- # )
33
- #
34
- #
35
  import torch
36
  from transformers import T5ForConditionalGeneration, T5Tokenizer
37
  from peft import LoraConfig, get_peft_model, TaskType
38
 
39
- # ---------------- DEVICE SETUP ----------------
40
  device = "mps" if torch.backends.mps.is_available() else "cpu"
41
 
42
- MODEL_PATH = "../outputs/model"
43
-
44
- # ---------------- LOAD TOKENIZER ----------------
45
- tokenizer = T5Tokenizer.from_pretrained("t5-small")
46
-
47
- # ---------------- LOAD MODEL WITH QUANTIZATION ----------------
48
- def load_model(quantization=None):
49
- print(f"Loading model with quantization = {quantization}")
50
-
51
- if quantization == "int8":
52
- model = T5ForConditionalGeneration.from_pretrained(
53
- MODEL_PATH,
54
- load_in_8bit=True,
55
- device_map="auto"
56
- )
57
-
58
- elif quantization == "int4":
59
- model = T5ForConditionalGeneration.from_pretrained(
60
- MODEL_PATH,
61
- load_in_4bit=True,
62
- device_map="auto"
63
- )
64
-
65
- else: # fp32
66
- model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH).to(device)
67
 
68
- return model
69
-
70
-
71
- # 👉 CHANGE THIS VALUE TO TEST
72
- QUANTIZATION = "int8" # options: None, "int8", "int4"
73
-
74
- model = load_model(QUANTIZATION)
75
 
 
76
 
77
  # ---------------- LoRA CONFIG ----------------
78
  lora_config = LoraConfig(
79
- r=8,
80
  lora_alpha=16,
81
- target_modules=["q", "v"],
82
  lora_dropout=0.05,
83
  bias="none",
84
  task_type=TaskType.SEQ_2_SEQ_LM
@@ -89,4 +26,5 @@ model = get_peft_model(model, lora_config)
89
 
90
  model.print_trainable_parameters()
91
 
92
- print("READY ✔ LoRA + Quantized model loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from transformers import T5ForConditionalGeneration, T5Tokenizer
3
  from peft import LoraConfig, get_peft_model, TaskType
4
 
 
5
  device = "mps" if torch.backends.mps.is_available() else "cpu"
6
 
7
+ MODEL_PATH = "../outputs/model" # your supervised trained model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ print("Loading base model...")
10
+ model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH).to(device)
 
 
 
 
 
11
 
12
+ tokenizer = T5Tokenizer.from_pretrained("t5-small")
13
 
14
  # ---------------- LoRA CONFIG ----------------
15
  lora_config = LoraConfig(
16
+ r=8, # rank (small brain attachment)
17
  lora_alpha=16,
18
+ target_modules=["q", "v"], # attention matrices only
19
  lora_dropout=0.05,
20
  bias="none",
21
  task_type=TaskType.SEQ_2_SEQ_LM
 
26
 
27
  model.print_trainable_parameters()
28
 
29
+ print("READY ✔ LoRA model loaded")
30
+
src/quantization_utils.py DELETED
@@ -1,222 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- import os
5
- import time
6
- from dataclasses import dataclass
7
- from pathlib import Path
8
- from typing import Any, Dict, Optional, Tuple
9
-
10
- import torch
11
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
-
13
- try:
14
- from transformers import BitsAndBytesConfig # type: ignore
15
- except Exception: # pragma: no cover
16
- BitsAndBytesConfig = None # type: ignore
17
-
18
- try:
19
- from peft import PeftModel
20
- except Exception as e: # pragma: no cover
21
- PeftModel = None # type: ignore
22
-
23
-
24
- @dataclass(frozen=True)
25
- class QuantArtifact:
26
- out_dir: Path
27
- mode: str # fp32 | int8_dynamic | int8_decoder_dynamic | int8_bnb | int4_bnb
28
- base_model: str
29
- adapter_path: Optional[str]
30
- created_at_s: float
31
-
32
-
33
- def _bool_env(name: str, default: str = "0") -> bool:
34
- return os.environ.get(name, default).strip() in {"1", "true", "True", "yes", "Y"}
35
-
36
-
37
- def estimate_model_bytes(model: torch.nn.Module) -> int:
38
- total = 0
39
- for p in model.parameters():
40
- total += p.numel() * p.element_size()
41
- for b in model.buffers():
42
- total += b.numel() * b.element_size()
43
- return int(total)
44
-
45
-
46
- def _load_tokenizer(base_model: str, *, local_only: bool) -> Any:
47
- tok = AutoTokenizer.from_pretrained(base_model, local_files_only=local_only)
48
- if tok.pad_token_id is None and getattr(tok, "eos_token_id", None) is not None:
49
- tok.pad_token = tok.eos_token
50
- return tok
51
-
52
-
53
- def load_fp32_model(
54
- base_model: str,
55
- *,
56
- adapter_path: Optional[str] = None,
57
- device: str = "cpu",
58
- local_only: bool = True,
59
- torch_dtype: torch.dtype = torch.float32,
60
- merge_lora: bool = True,
61
- ) -> Tuple[Any, torch.nn.Module]:
62
- tok = _load_tokenizer(base_model, local_only=local_only)
63
- model = AutoModelForSeq2SeqLM.from_pretrained(
64
- base_model,
65
- local_files_only=local_only,
66
- torch_dtype=torch_dtype,
67
- ).to(device)
68
-
69
- if adapter_path:
70
- if PeftModel is None:
71
- raise RuntimeError("peft is required to load adapters.")
72
- model = PeftModel.from_pretrained(model, adapter_path).to(device)
73
- if merge_lora and hasattr(model, "merge_and_unload"):
74
- model = model.merge_and_unload()
75
- model = model.to(device)
76
-
77
- model.eval()
78
- return tok, model
79
-
80
-
81
- def quantize_dynamic_int8(model: torch.nn.Module) -> torch.nn.Module:
82
- # CPU-only; quantized kernels run on CPU.
83
- # Ensure a quantization engine is selected (PyTorch may default to "none" on macOS).
84
- try:
85
- supported = list(getattr(torch.backends.quantized, "supported_engines", []))
86
- current = getattr(torch.backends.quantized, "engine", "none")
87
- if current in {"none", None, ""}:
88
- if "fbgemm" in supported:
89
- torch.backends.quantized.engine = "fbgemm"
90
- elif "qnnpack" in supported:
91
- torch.backends.quantized.engine = "qnnpack"
92
- except Exception: # pragma: no cover
93
- pass
94
- return torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
95
-
96
-
97
- def quantize_dynamic_int8_decoder_only(model: Any) -> Any:
98
- """
99
- Mixed-precision (Task 5): encoder fp32, decoder int8 dynamic quantized.
100
- """
101
- if not hasattr(model, "decoder"):
102
- raise ValueError("Model has no decoder attribute.")
103
- try:
104
- supported = list(getattr(torch.backends.quantized, "supported_engines", []))
105
- current = getattr(torch.backends.quantized, "engine", "none")
106
- if current in {"none", None, ""}:
107
- if "fbgemm" in supported:
108
- torch.backends.quantized.engine = "fbgemm"
109
- elif "qnnpack" in supported:
110
- torch.backends.quantized.engine = "qnnpack"
111
- except Exception: # pragma: no cover
112
- pass
113
- model.decoder = torch.quantization.quantize_dynamic(model.decoder, {torch.nn.Linear}, dtype=torch.qint8)
114
- return model
115
-
116
-
117
- def load_bnb_quantized_model(
118
- base_model: str,
119
- *,
120
- adapter_path: Optional[str],
121
- device: str,
122
- local_only: bool,
123
- load_in_8bit: bool = False,
124
- load_in_4bit: bool = False,
125
- ) -> Tuple[Any, torch.nn.Module]:
126
- """
127
- bitsandbytes int8/int4 (requires bitsandbytes + CUDA). Not supported on CPU/MPS.
128
- """
129
- if BitsAndBytesConfig is None:
130
- raise RuntimeError("transformers BitsAndBytesConfig not available; upgrade transformers or install extras.")
131
- if device != "cuda":
132
- raise RuntimeError("bitsandbytes quantization requires CUDA (device=cuda).")
133
- if not (load_in_8bit or load_in_4bit):
134
- raise ValueError("Specify load_in_8bit or load_in_4bit.")
135
-
136
- tok = _load_tokenizer(base_model, local_only=local_only)
137
- qconf = BitsAndBytesConfig(load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit)
138
- model = AutoModelForSeq2SeqLM.from_pretrained(
139
- base_model,
140
- local_files_only=local_only,
141
- quantization_config=qconf,
142
- device_map="auto",
143
- )
144
- if adapter_path:
145
- if PeftModel is None:
146
- raise RuntimeError("peft is required to load adapters.")
147
- model = PeftModel.from_pretrained(model, adapter_path)
148
- model.eval()
149
- return tok, model
150
-
151
-
152
- def save_quant_artifact(
153
- out_dir: str | Path,
154
- *,
155
- mode: str,
156
- base_model: str,
157
- adapter_path: Optional[str],
158
- tokenizer: Any,
159
- model: torch.nn.Module,
160
- ) -> QuantArtifact:
161
- out = Path(out_dir)
162
- out.mkdir(parents=True, exist_ok=True)
163
- (out / "tokenizer").mkdir(exist_ok=True)
164
-
165
- tokenizer.save_pretrained(out / "tokenizer")
166
- torch.save(model.state_dict(), out / "model.pt")
167
-
168
- meta: Dict[str, Any] = {
169
- "mode": mode,
170
- "base_model": base_model,
171
- "adapter_path": adapter_path,
172
- "created_at_s": time.time(),
173
- "estimated_model_bytes": estimate_model_bytes(model),
174
- }
175
- (out / "meta.json").write_text(json.dumps(meta, indent=2))
176
-
177
- return QuantArtifact(
178
- out_dir=out,
179
- mode=mode,
180
- base_model=base_model,
181
- adapter_path=adapter_path,
182
- created_at_s=float(meta["created_at_s"]),
183
- )
184
-
185
-
186
- def load_quant_artifact(
187
- artifact_dir: str | Path,
188
- *,
189
- device: str = "cpu",
190
- local_only: bool = True,
191
- ) -> Tuple[Any, torch.nn.Module, Dict[str, Any]]:
192
- """
193
- Loads a previously exported quant artifact.
194
- For dynamic quant modes, we reconstruct the architecture, apply the same quantization,
195
- then load the saved state_dict.
196
- """
197
- adir = Path(artifact_dir)
198
- meta = json.loads((adir / "meta.json").read_text())
199
- mode = meta["mode"]
200
- base_model = meta["base_model"]
201
-
202
- tok = AutoTokenizer.from_pretrained(adir / "tokenizer", local_files_only=True)
203
- if tok.pad_token_id is None and getattr(tok, "eos_token_id", None) is not None:
204
- tok.pad_token = tok.eos_token
205
-
206
- model = AutoModelForSeq2SeqLM.from_pretrained(base_model, local_files_only=local_only).to(device)
207
- model.eval()
208
-
209
- if mode == "int8_dynamic":
210
- model = quantize_dynamic_int8(model)
211
- elif mode == "int8_decoder_dynamic":
212
- model = quantize_dynamic_int8_decoder_only(model)
213
- elif mode in {"fp32"}:
214
- pass
215
- else:
216
- raise RuntimeError(f"Unsupported artifact mode for local loading: {mode}")
217
-
218
- state = torch.load(adir / "model.pt", map_location=device)
219
- model.load_state_dict(state, strict=False)
220
- model.to(device)
221
- model.eval()
222
- return tok, model, meta
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/quantized_text2sql_engine.py DELETED
@@ -1,243 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import os
4
- import sqlite3
5
- import threading
6
- import time
7
- from concurrent.futures import ThreadPoolExecutor, as_completed
8
- from collections import OrderedDict
9
- from pathlib import Path
10
- from typing import Any, Dict, List, Sequence, Tuple
11
-
12
- import torch
13
-
14
- from src.quantization_utils import load_quant_artifact
15
- from src.schema_encoder import SchemaEncoder
16
- from src.sql_validator import validate_sql_schema
17
-
18
- # ==========================================
19
- # RELATIVE PATH RESOLUTION (GLOBAL)
20
- # ==========================================
21
- PROJECT_ROOT = Path(__file__).resolve().parent.parent
22
-
23
- if (PROJECT_ROOT / "data" / "database").exists():
24
- DB_ROOT = PROJECT_ROOT / "data" / "database"
25
- else:
26
- DB_ROOT = PROJECT_ROOT / "final_databases"
27
-
28
-
29
- class QuantizedText2SQLEngine:
30
- def __init__(
31
- self,
32
- artifact_dir: str,
33
- *,
34
- device: str = "cpu",
35
- use_constrained: bool = False,
36
- exec_workers: int | None = None,
37
- default_timeout_s: float = 2.0,
38
- use_cache: bool = True,
39
- cache_max_entries: int = 50_000,
40
- ):
41
- self.device = device
42
- self.use_constrained = bool(use_constrained)
43
- self.tokenizer, self.model, self.meta = load_quant_artifact(artifact_dir, device=device, local_only=True)
44
- self.schema_encoder = SchemaEncoder(DB_ROOT)
45
-
46
- if exec_workers is None:
47
- exec_workers = int(os.environ.get("SQL_EXEC_WORKERS", "8"))
48
-
49
- self.exec_pool = ThreadPoolExecutor(max_workers=int(exec_workers))
50
- self.default_timeout_s = float(default_timeout_s)
51
- self.use_cache = bool(use_cache)
52
- self.cache_max_entries = int(cache_max_entries)
53
- self._cache: "OrderedDict[tuple[str, str], tuple[list, list]]" = OrderedDict()
54
- self._cache_lock = threading.Lock()
55
- self._stats_lock = threading.Lock()
56
- self._exec_cache_hits = 0
57
- self._exec_cache_misses = 0
58
- self._exec_calls = 0
59
- self._tls = threading.local()
60
-
61
- def _get_db_path(self, db_id: str) -> str:
62
- """Smart resolver for flat vs nested database folders"""
63
- path1 = DB_ROOT / db_id / f"{db_id}.sqlite"
64
- path2 = DB_ROOT / f"{db_id}.sqlite"
65
- return str(path1) if path1.exists() else str(path2)
66
-
67
- def build_prompt(self, question: str, db_id: str) -> str:
68
- schema = self.schema_encoder.structured_schema(db_id)
69
- return (
70
- "You are a SQLite expert.\n\n"
71
- f"Database: {db_id}\n\n"
72
- "Schema:\n"
73
- f"{schema}\n\n"
74
- "Question:\n"
75
- f"{question}\n\n"
76
- "SQL:"
77
- )
78
-
79
- def generate_sql_batch(
80
- self,
81
- pairs: Sequence[Tuple[str, str]],
82
- *,
83
- max_new_tokens: int = 120,
84
- num_beams: int = 8,
85
- repetition_penalty: float = 1.2,
86
- ) -> List[str]:
87
- prompts = [self.build_prompt(q, db_id) for q, db_id in pairs]
88
-
89
- if self.use_constrained:
90
- from transformers.generation.logits_process import LogitsProcessorList
91
- from src.constrained_decoding import SchemaConstrainedLogitsProcessor
92
-
93
- sqls: List[str] = []
94
- for (q, db_id), prompt in zip(pairs, prompts):
95
- db_path = self._get_db_path(db_id)
96
- enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device)
97
- proc = LogitsProcessorList([SchemaConstrainedLogitsProcessor(self.tokenizer, db_path)])
98
-
99
- out = self.model.generate(
100
- **enc,
101
- max_new_tokens=int(max_new_tokens),
102
- num_beams=int(num_beams),
103
- repetition_penalty=float(repetition_penalty),
104
- logits_processor=proc,
105
- )
106
- sqls.append(self.tokenizer.decode(out[0], skip_special_tokens=True).strip())
107
- return sqls
108
-
109
- enc = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device)
110
- out = self.model.generate(
111
- **enc,
112
- max_new_tokens=int(max_new_tokens),
113
- num_beams=int(num_beams),
114
- repetition_penalty=float(repetition_penalty),
115
- )
116
- return [self.tokenizer.decode(x, skip_special_tokens=True).strip() for x in out]
117
-
118
- def _get_thread_conn(self, db_path: str) -> sqlite3.Connection:
119
- conns = getattr(self._tls, "conns", None)
120
- if conns is None:
121
- conns = {}
122
- self._tls.conns = conns
123
- conn = conns.get(db_path)
124
- if conn is None:
125
- conn = sqlite3.connect(db_path)
126
- conn.text_factory = lambda b: b.decode(errors="ignore")
127
- conns[db_path] = conn
128
- return conn
129
-
130
- def _cache_get(self, key: tuple[str, str]) -> tuple[list, list] | None:
131
- if not self.use_cache: return None
132
- with self._cache_lock:
133
- hit = self._cache.get(key)
134
- if hit is None: return None
135
- self._cache.move_to_end(key)
136
- return hit
137
-
138
- def _cache_put(self, key: tuple[str, str], value: tuple[list, list]) -> None:
139
- if not self.use_cache: return
140
- with self._cache_lock:
141
- self._cache[key] = value
142
- self._cache.move_to_end(key)
143
- while len(self._cache) > self.cache_max_entries:
144
- self._cache.popitem(last=False)
145
-
146
- def _execute_one(self, sql: str, db_path: str, timeout_s: float | None = None):
147
- timeout_s = float(self.default_timeout_s if timeout_s is None else timeout_s)
148
- key = (db_path, sql)
149
- cached = self._cache_get(key)
150
-
151
- with self._stats_lock: self._exec_calls += 1
152
-
153
- if cached is not None:
154
- with self._stats_lock: self._exec_cache_hits += 1
155
- return cached
156
-
157
- with self._stats_lock: self._exec_cache_misses += 1
158
-
159
- conn = self._get_thread_conn(db_path)
160
- start_t = time.monotonic()
161
-
162
- def handler():
163
- return 1 if (time.monotonic() - start_t) > timeout_s else 0
164
-
165
- conn.set_progress_handler(handler, 10_000)
166
- cur = conn.cursor()
167
- cur.execute(sql)
168
- rows = cur.fetchall()
169
- cols = [d[0] for d in cur.description] if cur.description else []
170
- out = (rows, cols)
171
- self._cache_put(key, out)
172
- return out
173
-
174
- def stats(self) -> Dict[str, Any]:
175
- with self._stats_lock:
176
- calls, hits, misses = int(self._exec_calls), int(self._exec_cache_hits), int(self._exec_cache_misses)
177
-
178
- hit_rate = (hits / calls) if calls else 0.0
179
- return {
180
- "exec_calls": calls, "exec_cache_hits": hits, "exec_cache_misses": misses,
181
- "exec_cache_hit_rate": float(hit_rate), "use_cache": bool(self.use_cache),
182
- "exec_workers": int(getattr(self.exec_pool, "_max_workers", 0) or 0),
183
- }
184
-
185
- def reset_stats(self) -> None:
186
- with self._stats_lock:
187
- self._exec_calls = self._exec_cache_hits = self._exec_cache_misses = 0
188
-
189
- def execute_sql(self, sql: str, db_id: str, *, timeout_s: float | None = None, validate_schema: bool = True):
190
- db_path = self._get_db_path(db_id)
191
- if validate_schema:
192
- try: ok, _ = validate_sql_schema(sql, db_path)
193
- except Exception: ok = False
194
- if not ok: raise ValueError("Invalid schema")
195
- return self._execute_one(sql, db_path, timeout_s=timeout_s)
196
-
197
- def ask(
198
- self,
199
- question: str,
200
- db_id: str,
201
- *,
202
- max_new_tokens: int = 120,
203
- num_beams: int = 8,
204
- repetition_penalty: float = 1.2,
205
- timeout_s: float | None = None,
206
- ) -> Dict[str, Any]:
207
- sql = self.generate_sql_batch(
208
- [(question, db_id)],
209
- max_new_tokens=max_new_tokens,
210
- num_beams=num_beams,
211
- repetition_penalty=repetition_penalty,
212
- )[0]
213
-
214
- db_path = self._get_db_path(db_id)
215
-
216
- try: ok, _ = validate_sql_schema(sql, db_path)
217
- except Exception: ok = False
218
-
219
- if not ok: return {"sql": sql, "rows": [], "columns": [], "error": "Invalid schema"}
220
-
221
- try:
222
- rows, cols = self._execute_one(sql, db_path, timeout_s=timeout_s)
223
- return {"sql": sql, "rows": rows, "columns": cols, "error": None}
224
- except Exception as e:
225
- return {"sql": sql, "rows": [], "columns": [], "error": str(e)}
226
-
227
- def ask_batch_execute(self, pairs: Sequence[Tuple[str, str]]) -> List[Dict[str, Any]]:
228
- sqls = self.generate_sql_batch(pairs)
229
- results: List[Dict[str, Any]] = []
230
- futures = {}
231
- for (q, db_id), sql in zip(pairs, sqls):
232
- db_path = self._get_db_path(db_id)
233
- futures[self.exec_pool.submit(self._execute_one, sql, db_path)] = (sql, db_id)
234
-
235
- for fut in as_completed(futures):
236
- sql, db_id = futures[fut]
237
- try:
238
- rows, cols = fut.result()
239
- results.append({"db_id": db_id, "sql": sql, "rows": rows, "columns": cols, "error": None})
240
- except Exception as e:
241
- results.append({"db_id": db_id, "sql": sql, "rows": [], "columns": [], "error": str(e)})
242
-
243
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/schema_encoder.py CHANGED
@@ -1,38 +1,54 @@
1
  import sqlite3
2
- from pathlib import Path
3
 
4
  class SchemaEncoder:
 
5
  def __init__(self, db_root):
6
- self.db_root = Path(db_root)
7
-
8
- def _get_db_path(self, db_id: str) -> Path:
9
- # Check standard Spider format (subfolder)
10
- path1 = self.db_root / db_id / f"{db_id}.sqlite"
11
- # Check flat format (no subfolder)
12
- path2 = self.db_root / f"{db_id}.sqlite"
13
-
14
- if path1.exists():
15
- return path1
16
- if path2.exists():
17
- return path2
18
-
19
- raise FileNotFoundError(f"unable to open database file. Looked in:\n1. {path1}\n2. {path2}")
20
-
21
- def structured_schema(self, db_id: str) -> str:
22
- db_path = self._get_db_path(db_id)
23
-
24
- conn = sqlite3.connect(str(db_path))
25
- cur = conn.cursor()
26
-
27
- # Get all tables
28
- cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
29
- tables = [r[0] for r in cur.fetchall() if r[0] != "sqlite_sequence"]
30
-
31
- schema_str = ""
32
- for table in tables:
33
- cur.execute(f"PRAGMA table_info(`{table}`);")
34
- cols = [c[1] for c in cur.fetchall()]
35
- schema_str += f"{table} ({', '.join(cols)})\n"
36
-
37
  conn.close()
38
- return schema_str.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import sqlite3
2
+
3
 
4
  class SchemaEncoder:
5
+
6
  def __init__(self, db_root):
7
+ self.db_root = db_root
8
+
9
+ def get_tables_and_columns(self, db_id):
10
+
11
+ # FIXED PATH
12
+ db_path = self.db_root / f"{db_id}.sqlite"
13
+
14
+ conn = sqlite3.connect(db_path)
15
+ cursor = conn.cursor()
16
+
17
+ tables = cursor.execute(
18
+ "SELECT name FROM sqlite_master WHERE type='table';"
19
+ ).fetchall()
20
+
21
+ schema = {}
22
+
23
+ for (table,) in tables:
24
+ cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
25
+ col_names = [c[1] for c in cols]
26
+ schema[table] = col_names
27
+
 
 
 
 
 
 
 
 
 
 
28
  conn.close()
29
+ return schema
30
+
31
+ # -----------------------------------
32
+ # Strategy 1: Structured
33
+ # -----------------------------------
34
+ def structured_schema(self, db_id):
35
+ schema = self.get_tables_and_columns(db_id)
36
+
37
+ lines = []
38
+ for table, cols in schema.items():
39
+ lines.append(f"{table}({', '.join(cols)})")
40
+
41
+ return "\n".join(lines)
42
+
43
+ # -----------------------------------
44
+ # Strategy 2: Natural Language
45
+ # -----------------------------------
46
+ def natural_language_schema(self, db_id):
47
+ schema = self.get_tables_and_columns(db_id)
48
+
49
+ lines = []
50
+ for table, cols in schema.items():
51
+ col_text = ", ".join(cols)
52
+ lines.append(f"The table '{table}' contains the columns: {col_text}.")
53
+
54
+ return "\n".join(lines)
src/schema_utils.py DELETED
@@ -1,222 +0,0 @@
1
- # import os
2
- # import sqlite3
3
- # import threading
4
- # from typing import Dict, List, Set, Tuple
5
-
6
- # def get_schema(db_path):
7
- # schema_map = get_table_to_columns(db_path)
8
- # schema_text = ""
9
- # for table, col_names in schema_map.items():
10
- # schema_text += f"{table}({', '.join(col_names)})\n"
11
- # return schema_text
12
-
13
- # _SCHEMA_LOCK = threading.Lock()
14
- # _SCHEMA_CACHE: Dict[str, Tuple[str, Dict[str, List[str]]]] = {}
15
-
16
- # def _db_state_fingerprint(db_path: str) -> str:
17
- # try:
18
- # st = os.stat(db_path)
19
- # return f"{st.st_mtime_ns}:{st.st_size}"
20
- # except OSError:
21
- # return "missing"
22
-
23
- # def _connect_readonly(db_path: str) -> sqlite3.Connection:
24
- # uri = f"file:{os.path.abspath(db_path)}?mode=ro"
25
- # conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
26
- # conn.execute("PRAGMA query_only = ON;")
27
- # conn.execute("PRAGMA foreign_keys = ON;")
28
- # return conn
29
-
30
- # def get_table_to_columns(db_path: str) -> Dict[str, List[str]]:
31
- # """
32
- # Return mapping of table -> column names for the SQLite DB at db_path.
33
- # Tables and columns are returned lowercased.
34
- # """
35
- # fp = _db_state_fingerprint(db_path)
36
- # with _SCHEMA_LOCK:
37
- # cached = _SCHEMA_CACHE.get(db_path)
38
- # if cached is not None and cached[0] == fp:
39
- # return cached[1]
40
-
41
- # schema: Dict[str, List[str]] = {}
42
- # with _connect_readonly(db_path) as conn:
43
- # cur = conn.execute(
44
- # "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
45
- # )
46
- # tables = [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
47
- # for table in tables:
48
- # table_l = table.lower()
49
- # try:
50
- # cur = conn.execute(f'PRAGMA table_info("{table}")')
51
- # cols = [row[1].lower() for row in cur.fetchall() if row and isinstance(row[1], str)]
52
- # schema[table_l] = cols
53
- # except sqlite3.Error:
54
- # schema[table_l] = []
55
-
56
- # with _SCHEMA_LOCK:
57
- # _SCHEMA_CACHE[db_path] = (fp, schema)
58
- # return schema
59
-
60
- # def get_db_tables_and_columns(db_path: str) -> Tuple[Set[str], Set[str]]:
61
- # schema = get_table_to_columns(db_path)
62
- # tables = set(schema.keys())
63
- # columns: Set[str] = set()
64
- # for cols in schema.values():
65
- # columns.update(cols)
66
- # return tables, columns
67
-
68
-
69
- import os
70
- import sqlite3
71
- import threading
72
- from typing import Dict, List, Set, Tuple
73
-
74
-
75
- # ===============================
76
- # 🔥 SCHEMA TEXT (for prompting)
77
- # ===============================
78
- def get_schema(db_path: str) -> str:
79
- schema_map = get_table_to_columns(db_path)
80
- schema_text = ""
81
-
82
- for table, col_names in schema_map.items():
83
- schema_text += f"{table}({', '.join(col_names)})\n"
84
-
85
- return schema_text
86
-
87
-
88
- # ===============================
89
- # 🔥 CACHE + LOCK
90
- # ===============================
91
- _SCHEMA_LOCK = threading.Lock()
92
- _SCHEMA_CACHE: Dict[str, Tuple[str, Dict[str, List[str]]]] = {}
93
-
94
-
95
- def _db_state_fingerprint(db_path: str) -> str:
96
- try:
97
- st = os.stat(db_path)
98
- return f"{st.st_mtime_ns}:{st.st_size}"
99
- except OSError:
100
- return "missing"
101
-
102
-
103
- def _connect_readonly(db_path: str) -> sqlite3.Connection:
104
- uri = f"file:{os.path.abspath(db_path)}?mode=ro"
105
- conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
106
- conn.execute("PRAGMA query_only = ON;")
107
- conn.execute("PRAGMA foreign_keys = ON;")
108
- return conn
109
-
110
-
111
- # ===============================
112
- # 🔥 CORE: TABLE → COLUMNS
113
- # ===============================
114
- def get_table_to_columns(db_path: str) -> Dict[str, List[str]]:
115
- """
116
- Return mapping of table -> column names (ONLY names, no types).
117
- """
118
- fp = _db_state_fingerprint(db_path)
119
-
120
- with _SCHEMA_LOCK:
121
- cached = _SCHEMA_CACHE.get(db_path)
122
- if cached is not None and cached[0] == fp:
123
- return cached[1]
124
-
125
- schema: Dict[str, List[str]] = {}
126
-
127
- with _connect_readonly(db_path) as conn:
128
- cur = conn.execute(
129
- "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
130
- )
131
-
132
- tables = [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
133
-
134
- for table in tables:
135
- table_l = table.lower()
136
-
137
- try:
138
- cur = conn.execute(f'PRAGMA table_info("{table}")')
139
-
140
- cols = []
141
- for row in cur.fetchall():
142
- col_name = row[1].lower()
143
- cols.append(col_name)
144
-
145
- schema[table_l] = list(set(cols)) # remove duplicates
146
-
147
- except sqlite3.Error:
148
- schema[table_l] = []
149
-
150
- with _SCHEMA_LOCK:
151
- _SCHEMA_CACHE[db_path] = (fp, schema)
152
-
153
- return schema
154
-
155
-
156
- # ===============================
157
- # 🔥 TABLE + COLUMN SETS
158
- # ===============================
159
- def get_db_tables_and_columns(db_path: str) -> Tuple[Set[str], Set[str]]:
160
- schema = get_table_to_columns(db_path)
161
-
162
- tables = set(schema.keys())
163
- columns: Set[str] = set()
164
-
165
- for cols in schema.values():
166
- columns.update(cols)
167
-
168
- return tables, columns
169
-
170
-
171
- # ===============================
172
- # 🔥 FOREIGN KEYS (IMPORTANT)
173
- # ===============================
174
- def get_foreign_keys(db_path: str) -> List[Tuple[str, str, str, str]]:
175
- """
176
- Returns list of foreign key relations:
177
- (table, column, ref_table, ref_column)
178
- """
179
- fks = []
180
-
181
- with _connect_readonly(db_path) as conn:
182
- cur = conn.execute(
183
- "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
184
- )
185
- tables = [r[0] for r in cur.fetchall()]
186
-
187
- for table in tables:
188
- try:
189
- cur = conn.execute(f'PRAGMA foreign_key_list("{table}")')
190
-
191
- for row in cur.fetchall():
192
- fks.append((
193
- table.lower(),
194
- row[3].lower(), # column
195
- row[2].lower(), # referenced table
196
- row[4].lower() # referenced column
197
- ))
198
-
199
- except sqlite3.Error:
200
- continue
201
-
202
- return fks
203
-
204
-
205
- # ===============================
206
- # 🔥 FINAL: CONSTRAINT GRAPH
207
- # ===============================
208
- def get_constraint_graph(db_path: str):
209
- """
210
- Build schema constraint graph:
211
- - tables
212
- - columns
213
- - foreign key relations
214
- """
215
- tables, columns = get_db_tables_and_columns(db_path)
216
- fks = get_foreign_keys(db_path)
217
-
218
- return {
219
- "tables": tables,
220
- "columns": columns,
221
- "foreign_keys": fks
222
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/sql_validator.py CHANGED
@@ -1,209 +1,6 @@
1
- # import re
2
- # from pathlib import Path
3
- # from typing import Optional, Set, Tuple
4
-
5
- # from schema_utils import get_db_tables_and_columns, get_table_to_columns
6
-
7
- # class SQLValidator:
8
-
9
- # def __init__(self, db_root):
10
- # self.db_root = Path(db_root)
11
-
12
- # # ---------------------------
13
- # # Load schema
14
- # # ---------------------------
15
- # def load_schema(self, db_id):
16
- # db_path = self.db_root / db_id / f"{db_id}.sqlite"
17
- # return get_table_to_columns(str(db_path))
18
-
19
-
20
- # # ---------------------------
21
- # # Basic syntax check
22
- # # ---------------------------
23
- # def basic_structure_valid(self, sql):
24
- # s = sql.lower()
25
-
26
- # if "select" not in s or "from" not in s:
27
- # return False, "Missing SELECT or FROM"
28
-
29
- # if len(s.split()) < 4:
30
- # return False, "Too short to be SQL"
31
-
32
- # return True, None
33
-
34
-
35
- # # ---------------------------
36
- # # Extract identifiers
37
- # # ---------------------------
38
- # def extract_identifiers(self, sql):
39
- # tokens = re.findall(r"[A-Za-z_]+", sql.lower())
40
- # return set(tokens)
41
-
42
-
43
- # # ---------------------------
44
- # # Table validation
45
- # # ---------------------------
46
- # def validate_tables(self, sql, schema):
47
- # words = self.extract_identifiers(sql)
48
- # tables = set(schema.keys())
49
-
50
- # used_tables = [w for w in words if w in tables]
51
-
52
- # if not used_tables:
53
- # return False, "No valid table used"
54
-
55
- # return True, None
56
-
57
-
58
- # # ---------------------------
59
- # # Column validation
60
- # # ---------------------------
61
- # def validate_columns(self, sql, schema):
62
- # words = self.extract_identifiers(sql)
63
-
64
- # valid_columns = set()
65
- # for cols in schema.values():
66
- # valid_columns.update(cols)
67
-
68
- # # ignore SQL keywords
69
- # keywords = {
70
- # "select","from","where","join","on","group","by",
71
- # "order","limit","count","sum","avg","min","max",
72
- # "and","or","in","like","distinct","asc","desc"
73
- # }
74
-
75
- # invalid = []
76
- # for w in words:
77
- # if w not in valid_columns and w not in schema and w not in keywords:
78
- # if not w.isdigit():
79
- # invalid.append(w)
80
-
81
- # # allow small hallucinations but block many
82
- # if len(invalid) > 3:
83
- # return False, f"Too many unknown identifiers: {invalid[:5]}"
84
-
85
- # return True, None
86
-
87
-
88
- # # ---------------------------
89
- # # Dangerous query protection
90
- # # ---------------------------
91
- # def block_dangerous(self, sql):
92
- # bad = ["drop", "delete", "update", "insert", "alter"]
93
-
94
- # s = sql.lower()
95
- # for b in bad:
96
- # if b in s:
97
- # return False, f"Dangerous keyword detected: {b}"
98
-
99
- # return True, None
100
-
101
-
102
- # # ---------------------------
103
- # # Main validation
104
- # # ---------------------------
105
- # def validate(self, sql, db_id):
106
-
107
- # schema = self.load_schema(db_id)
108
-
109
- # checks = [
110
- # self.block_dangerous(sql),
111
- # self.basic_structure_valid(sql),
112
- # self.validate_tables(sql, schema),
113
- # self.validate_columns(sql, schema),
114
- # ]
115
-
116
- # for ok, msg in checks:
117
- # if not ok:
118
- # return False, msg
119
-
120
- # return True, None
121
-
122
-
123
- # _VALIDATION_CACHE = {}
124
- # _VALIDATION_CACHE_MAX = 100_000
125
-
126
-
127
- # def _db_state_fingerprint(db_path: str) -> str:
128
- # try:
129
- # st = Path(db_path).stat()
130
- # return f"{st.st_mtime_ns}:{st.st_size}"
131
- # except OSError:
132
- # return "missing"
133
-
134
-
135
- # def _extract_referenced_tables(sql: str) -> Set[str]:
136
- # # Best-effort: FROM/JOIN targets (unquoted identifiers).
137
- # tokens = re.findall(r"\b(from|join)\s+([a-zA-Z_][\w$]*)", sql, flags=re.I)
138
- # return {t[1].lower() for t in tokens if t and len(t) > 1}
139
-
140
-
141
- # def validate_sql_schema(sql: str, db_path: str) -> Tuple[bool, Optional[str]]:
142
- # """
143
- # Strict schema validation for reward computation.
144
- # - References must resolve to real tables/columns in the target DB.
145
- # - Returns (ok, message). On failure, message is a short reason.
146
- # """
147
- # fp = _db_state_fingerprint(db_path)
148
- # key = f"{fp}|{sql}"
149
- # cached = _VALIDATION_CACHE.get(key)
150
- # if cached is not None:
151
- # return cached
152
-
153
- # valid_tables, valid_columns = get_db_tables_and_columns(db_path)
154
-
155
- # referenced_tables = _extract_referenced_tables(sql)
156
- # unknown_tables = sorted(t for t in referenced_tables if t not in valid_tables)
157
- # if unknown_tables:
158
- # out = (False, f"Unknown table(s): {unknown_tables[:5]}")
159
- # if len(_VALIDATION_CACHE) >= _VALIDATION_CACHE_MAX:
160
- # _VALIDATION_CACHE.clear()
161
- # _VALIDATION_CACHE[key] = out
162
- # return out
163
-
164
- # # Column-level correctness is hard to do reliably with regex alone; rely on SQLite compilation.
165
- # # This does not execute the query, but will fail for unknown tables/columns.
166
- # try:
167
- # import sqlite3 # local import to keep module lightweight
168
-
169
- # uri = f"file:{Path(db_path).resolve()}?mode=ro"
170
- # conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
171
- # try:
172
- # conn.execute("PRAGMA query_only = ON;")
173
- # conn.execute("PRAGMA foreign_keys = ON;")
174
- # conn.execute(f"EXPLAIN QUERY PLAN {sql}")
175
- # finally:
176
- # conn.close()
177
- # except Exception as e:
178
- # msg = str(e).lower()
179
- # if "no such table" in msg:
180
- # out = (False, "Unknown table")
181
- # elif "no such column" in msg:
182
- # out = (False, "Unknown column")
183
- # else:
184
- # out = (False, "Schema validation failed")
185
-
186
- # if len(_VALIDATION_CACHE) >= _VALIDATION_CACHE_MAX:
187
- # _VALIDATION_CACHE.clear()
188
- # _VALIDATION_CACHE[key] = out
189
- # return out
190
-
191
- # out = (True, None)
192
- # if len(_VALIDATION_CACHE) >= _VALIDATION_CACHE_MAX:
193
- # _VALIDATION_CACHE.clear()
194
- # _VALIDATION_CACHE[key] = out
195
- # return out
196
-
197
-
198
-
199
-
200
-
201
  import re
202
  from pathlib import Path
203
- from typing import Optional, Set, Tuple, Dict, List
204
-
205
- from src.schema_utils import get_db_tables_and_columns, get_table_to_columns, get_constraint_graph
206
-
207
 
208
  class SQLValidator:
209
 
@@ -215,7 +12,23 @@ class SQLValidator:
215
  # ---------------------------
216
  def load_schema(self, db_id):
217
  db_path = self.db_root / db_id / f"{db_id}.sqlite"
218
- return get_table_to_columns(str(db_path))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  # ---------------------------
221
  # Basic syntax check
@@ -231,13 +44,15 @@ class SQLValidator:
231
 
232
  return True, None
233
 
 
234
  # ---------------------------
235
  # Extract identifiers
236
  # ---------------------------
237
  def extract_identifiers(self, sql):
238
- tokens = re.findall(r"[A-Za-z_][A-Za-z0-9_]*", sql.lower())
239
  return set(tokens)
240
 
 
241
  # ---------------------------
242
  # Table validation
243
  # ---------------------------
@@ -252,6 +67,7 @@ class SQLValidator:
252
 
253
  return True, None
254
 
 
255
  # ---------------------------
256
  # Column validation
257
  # ---------------------------
@@ -262,29 +78,26 @@ class SQLValidator:
262
  for cols in schema.values():
263
  valid_columns.update(cols)
264
 
 
265
  keywords = {
266
  "select","from","where","join","on","group","by",
267
  "order","limit","count","sum","avg","min","max",
268
- "and","or","in","like","distinct","asc","desc",
269
- "having","as","inner","left","right","outer"
270
  }
271
 
272
  invalid = []
273
  for w in words:
274
- if (
275
- w not in valid_columns
276
- and w not in schema
277
- and w not in keywords
278
- and not w.isdigit()
279
- ):
280
- invalid.append(w)
281
-
282
- # stricter than before
283
- if len(invalid) > 2:
284
- return False, f"Unknown identifiers: {invalid[:5]}"
285
 
286
  return True, None
287
 
 
288
  # ---------------------------
289
  # Dangerous query protection
290
  # ---------------------------
@@ -298,18 +111,6 @@ class SQLValidator:
298
 
299
  return True, None
300
 
301
- # ---------------------------
302
- # FK-aware JOIN validation (NEW 🔥)
303
- # ---------------------------
304
- def validate_joins(self, db_id):
305
- db_path = self.db_root / db_id / f"{db_id}.sqlite"
306
- graph = get_constraint_graph(str(db_path))
307
-
308
- # not strict enforcement, just check FK existence
309
- if len(graph["foreign_keys"]) == 0:
310
- return True, None
311
-
312
- return True, None # placeholder (safe for now)
313
 
314
  # ---------------------------
315
  # Main validation
@@ -330,86 +131,3 @@ class SQLValidator:
330
  return False, msg
331
 
332
  return True, None
333
-
334
-
335
- # ===============================
336
- # 🔥 FAST SCHEMA VALIDATION (REWARD)
337
- # ===============================
338
- _VALIDATION_CACHE = {}
339
- _VALIDATION_CACHE_MAX = 100_000
340
-
341
-
342
- def _db_state_fingerprint(db_path: str) -> str:
343
- try:
344
- st = Path(db_path).stat()
345
- return f"{st.st_mtime_ns}:{st.st_size}"
346
- except OSError:
347
- return "missing"
348
-
349
-
350
- def _extract_referenced_tables(sql: str) -> Set[str]:
351
- tokens = re.findall(r"\b(from|join)\s+([a-zA-Z_][\w$]*)", sql, flags=re.I)
352
- return {t[1].lower() for t in tokens if t and len(t) > 1}
353
-
354
-
355
- def validate_sql_schema(sql: str, db_path: str) -> Tuple[bool, Optional[str]]:
356
- """
357
- STRICT schema validation (Task 3 core)
358
- """
359
-
360
- fp = _db_state_fingerprint(db_path)
361
- key = f"{fp}|{sql}"
362
-
363
- cached = _VALIDATION_CACHE.get(key)
364
- if cached is not None:
365
- return cached
366
-
367
- valid_tables, valid_columns = get_db_tables_and_columns(db_path)
368
-
369
- # ---------------------------
370
- # Table validation
371
- # ---------------------------
372
- referenced_tables = _extract_referenced_tables(sql)
373
-
374
- unknown_tables = [t for t in referenced_tables if t not in valid_tables]
375
-
376
- if unknown_tables:
377
- out = (False, f"Unknown table(s): {unknown_tables[:3]}")
378
- _VALIDATION_CACHE[key] = out
379
- return out
380
-
381
- # ---------------------------
382
- # Column validation via SQLite planner
383
- # ---------------------------
384
- try:
385
- import sqlite3
386
-
387
- uri = f"file:{Path(db_path).resolve()}?mode=ro"
388
- conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
389
-
390
- try:
391
- conn.execute("PRAGMA query_only = ON;")
392
- conn.execute("PRAGMA foreign_keys = ON;")
393
-
394
- # 🔥 Key idea: no execution, only planning
395
- conn.execute(f"EXPLAIN QUERY PLAN {sql}")
396
-
397
- finally:
398
- conn.close()
399
-
400
- except Exception as e:
401
- msg = str(e).lower()
402
-
403
- if "no such table" in msg:
404
- out = (False, "Unknown table")
405
- elif "no such column" in msg:
406
- out = (False, "Unknown column")
407
- else:
408
- out = (False, "Invalid SQL")
409
-
410
- _VALIDATION_CACHE[key] = out
411
- return out
412
-
413
- out = (True, None)
414
- _VALIDATION_CACHE[key] = out
415
- return out
 
1
+ import sqlite3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import re
3
  from pathlib import Path
 
 
 
 
4
 
5
  class SQLValidator:
6
 
 
12
  # ---------------------------
13
  def load_schema(self, db_id):
14
  db_path = self.db_root / db_id / f"{db_id}.sqlite"
15
+
16
+ conn = sqlite3.connect(db_path)
17
+ cursor = conn.cursor()
18
+
19
+ tables = cursor.execute(
20
+ "SELECT name FROM sqlite_master WHERE type='table';"
21
+ ).fetchall()
22
+
23
+ schema = {}
24
+
25
+ for (table,) in tables:
26
+ cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
27
+ schema[table.lower()] = [c[1].lower() for c in cols]
28
+
29
+ conn.close()
30
+ return schema
31
+
32
 
33
  # ---------------------------
34
  # Basic syntax check
 
44
 
45
  return True, None
46
 
47
+
48
  # ---------------------------
49
  # Extract identifiers
50
  # ---------------------------
51
  def extract_identifiers(self, sql):
52
+ tokens = re.findall(r"[A-Za-z_]+", sql.lower())
53
  return set(tokens)
54
 
55
+
56
  # ---------------------------
57
  # Table validation
58
  # ---------------------------
 
67
 
68
  return True, None
69
 
70
+
71
  # ---------------------------
72
  # Column validation
73
  # ---------------------------
 
78
  for cols in schema.values():
79
  valid_columns.update(cols)
80
 
81
+ # ignore SQL keywords
82
  keywords = {
83
  "select","from","where","join","on","group","by",
84
  "order","limit","count","sum","avg","min","max",
85
+ "and","or","in","like","distinct","asc","desc"
 
86
  }
87
 
88
  invalid = []
89
  for w in words:
90
+ if w not in valid_columns and w not in schema and w not in keywords:
91
+ if not w.isdigit():
92
+ invalid.append(w)
93
+
94
+ # allow small hallucinations but block many
95
+ if len(invalid) > 3:
96
+ return False, f"Too many unknown identifiers: {invalid[:5]}"
 
 
 
 
97
 
98
  return True, None
99
 
100
+
101
  # ---------------------------
102
  # Dangerous query protection
103
  # ---------------------------
 
111
 
112
  return True, None
113
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  # ---------------------------
116
  # Main validation
 
131
  return False, msg
132
 
133
  return True, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/text2sql_engine.py CHANGED
@@ -1,223 +1,3 @@
1
- # import sqlite3
2
- # import torch
3
- # import re
4
- # import time
5
- # from pathlib import Path
6
- # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
- # from peft import PeftModel
8
- # from src.sql_validator import SQLValidator
9
- # from src.schema_encoder import SchemaEncoder
10
-
11
- # PROJECT_ROOT = Path(__file__).resolve().parents[1]
12
-
13
- # # ================================
14
- # # DATABASE PATH AUTO DETECTION
15
- # # ================================
16
- # if (PROJECT_ROOT / "data/database").exists():
17
- # DB_ROOT = PROJECT_ROOT / "data/database"
18
- # else:
19
- # DB_ROOT = PROJECT_ROOT / "final_databases"
20
-
21
-
22
- # def normalize_question(q: str):
23
- # q = q.lower().strip()
24
- # q = re.sub(r"distinct\s+(\d+)", r"\1 distinct", q)
25
- # q = re.sub(r"\s+", " ", q)
26
- # return q
27
-
28
-
29
- # def semantic_fix(question, sql):
30
- # q = question.lower().strip()
31
- # s = sql.lower()
32
-
33
- # num_match = re.search(r'\b(?:show|list|top|limit|get|first|last)\s+(\d+)\b', q)
34
-
35
- # if num_match and "limit" not in s and "count(" not in s:
36
- # limit_val = num_match.group(1)
37
- # sql = sql.rstrip(";")
38
- # sql = f"{sql.strip()} LIMIT {limit_val}"
39
-
40
- # return sql
41
-
42
-
43
- # class Text2SQLEngine:
44
- # def __init__(self,
45
- # adapter_path=None,
46
- # base_model_name="Salesforce/codet5-base",
47
- # use_lora=True):
48
-
49
- # self.device = "mps" if torch.backends.mps.is_available() else (
50
- # "cuda" if torch.cuda.is_available() else "cpu"
51
- # )
52
-
53
- # self.validator = SQLValidator(DB_ROOT)
54
- # self.schema_encoder = SchemaEncoder(DB_ROOT)
55
-
56
- # self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'
57
-
58
- # print("Loading base model...")
59
- # base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
60
-
61
- # if not use_lora:
62
- # self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
63
- # self.model = base.to(self.device)
64
- # self.model.eval()
65
- # return
66
-
67
- # if (PROJECT_ROOT / "checkpoints/best_rlhf_model").exists():
68
- # adapter_path = PROJECT_ROOT / "checkpoints/best_rlhf_model"
69
- # else:
70
- # adapter_path = PROJECT_ROOT / "best_rlhf_model"
71
-
72
- # adapter_path = adapter_path.resolve()
73
-
74
- # print("Loading tokenizer and LoRA adapter...")
75
-
76
- # try:
77
- # self.tokenizer = AutoTokenizer.from_pretrained(
78
- # str(adapter_path),
79
- # local_files_only=True
80
- # )
81
- # except Exception:
82
- # self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
83
-
84
- # self.model = PeftModel.from_pretrained(base, str(adapter_path)).to(self.device)
85
- # self.model.eval()
86
-
87
- # print("✅ RLHF model ready\n")
88
-
89
- # def build_prompt(self, question, schema):
90
- # return f"""You are an expert SQL generator.
91
- # Database schema:
92
- # {schema}
93
- # Generate a valid SQLite query for the question.
94
- # Question:
95
- # {question}
96
- # SQL:
97
- # """
98
-
99
- # def get_schema(self, db_id):
100
- # return self.schema_encoder.structured_schema(db_id)
101
-
102
- # def extract_sql(self, text: str):
103
-
104
- # text = text.strip()
105
-
106
- # if "SQL:" in text:
107
- # text = text.split("SQL:")[-1]
108
-
109
- # match = re.search(r"select[\s\S]*", text, re.IGNORECASE)
110
-
111
- # if match:
112
- # text = match.group(0)
113
-
114
- # return text.split(";")[0].strip()
115
-
116
- # def clean_sql(self, sql: str):
117
-
118
- # sql = sql.replace('"', "'")
119
- # sql = re.sub(r"\s+", " ", sql)
120
-
121
- # return sql.strip()
122
-
123
- # def generate_sql(self, prompt):
124
-
125
- # inputs = self.tokenizer(
126
- # prompt,
127
- # return_tensors="pt",
128
- # truncation=True,
129
- # max_length=512
130
- # ).to(self.device)
131
-
132
- # with torch.no_grad():
133
-
134
- # outputs = self.model.generate(
135
- # **inputs,
136
- # max_new_tokens=128,
137
- # num_beams=5,
138
- # early_stopping=True
139
- # )
140
-
141
- # decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
142
-
143
- # return self.clean_sql(self.extract_sql(decoded))
144
-
145
- # def execute_sql(self, question, sql, db_id):
146
-
147
- # if re.search(self.dml_keywords, sql, re.IGNORECASE):
148
- # return sql, [], [], "❌ Security Alert"
149
-
150
- # # FIXED DATABASE PATH
151
- # db_path = DB_ROOT / f"{db_id}.sqlite"
152
-
153
- # sql = self.clean_sql(sql)
154
- # sql = semantic_fix(question, sql)
155
-
156
- # try:
157
-
158
- # conn = sqlite3.connect(db_path)
159
-
160
- # cursor = conn.cursor()
161
-
162
- # cursor.execute(sql)
163
-
164
- # rows = cursor.fetchall()
165
-
166
- # columns = [d[0] for d in cursor.description] if cursor.description else []
167
-
168
- # conn.close()
169
-
170
- # return sql, columns, rows, None
171
-
172
- # except Exception as e:
173
-
174
- # return sql, [], [], str(e)
175
-
176
- # def ask(self, question, db_id):
177
-
178
- # question = normalize_question(question)
179
-
180
- # if re.search(self.dml_keywords, question, re.IGNORECASE):
181
-
182
- # return {
183
- # "question": question,
184
- # "sql": "-- BLOCKED",
185
- # "columns": [],
186
- # "rows": [],
187
- # "error": "Malicious prompt"
188
- # }
189
-
190
- # schema = self.get_schema(db_id)
191
-
192
- # prompt = self.build_prompt(question, schema)
193
-
194
- # raw_sql = self.generate_sql(prompt)
195
-
196
- # final_sql, cols, rows, error = self.execute_sql(question, raw_sql, db_id)
197
-
198
- # return {
199
- # "question": question,
200
- # "sql": final_sql,
201
- # "columns": cols,
202
- # "rows": rows,
203
- # "error": error
204
- # }
205
-
206
-
207
- # _engine = None
208
-
209
-
210
- # def get_engine():
211
-
212
- # global _engine
213
-
214
- # if _engine is None:
215
- # _engine = Text2SQLEngine()
216
-
217
- # return _engine
218
-
219
-
220
-
221
  import sqlite3
222
  import torch
223
  import re
@@ -226,7 +6,7 @@ from pathlib import Path
226
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
227
  from peft import PeftModel
228
  from src.sql_validator import SQLValidator
229
- from src.schema_encoder import SchemaEncoder, build_schema_graph # Added build_schema_graph
230
 
231
  PROJECT_ROOT = Path(__file__).resolve().parents[1]
232
 
@@ -239,92 +19,6 @@ else:
239
  DB_ROOT = PROJECT_ROOT / "final_databases"
240
 
241
 
242
- # ==========================================
243
- # INPUT VALIDATION & RELEVANCE (From Code 1)
244
- # ==========================================
245
- def is_valid_question(q: str):
246
- """Extremely relaxed valid question checker. As long as there is 1 word, it passes."""
247
- words = re.findall(r"[a-zA-Z0-9]+", q)
248
- return len(words) >= 1
249
-
250
- def is_relevant_to_db(question: str, schema_graph: dict):
251
- """
252
- Lexical heuristic to block completely out-of-domain questions
253
- while allowing valid plurals.
254
- """
255
- q_words = set(re.findall(r'\b[a-z]{3,}\b', question.lower()))
256
- stop_words = {"show", "list", "all", "and", "the", "get", "find", "how", "many", "what", "where", "which", "who", "give", "display", "count", "from", "for", "with", "that", "have", "has", "are", "there"}
257
- q_words = q_words - stop_words
258
-
259
- if not q_words:
260
- return True
261
-
262
- schema_words = set()
263
- for table, cols in schema_graph.items():
264
- schema_words.update(re.findall(r'\b[a-z]{3,}\b', table.lower()))
265
- for col in cols:
266
- schema_words.update(re.findall(r'\b[a-z]{3,}\b', col.lower()))
267
-
268
- synonyms = {
269
- "customer": ["client", "buyer", "shopper", "person", "people", "user"],
270
- "employee": ["staff", "worker", "boss", "manager", "person", "people"],
271
- "track": ["song", "music", "audio", "tune"],
272
- "album": ["record", "cd", "music"],
273
- "artist": ["singer", "band", "musician", "creator"],
274
- "invoice": ["bill", "receipt", "purchase", "sale", "order", "buy", "bought", "cost"],
275
- "city": ["town", "location", "place"],
276
- "country": ["nation", "location", "place"],
277
- "flight": ["plane", "airline", "trip", "fly", "airport"],
278
- "student": ["pupil", "learner", "kid", "child"],
279
- "club": ["group", "organization", "team"],
280
- "course": ["class", "subject"],
281
- "cinema": ["movie", "film", "theater", "screen"]
282
- }
283
-
284
- extended_schema_words = set(schema_words)
285
- for db_word in schema_words:
286
- if db_word in synonyms:
287
- extended_schema_words.update(synonyms[db_word])
288
-
289
- extended_schema_words.update({"id", "name", "total", "sum", "average", "avg", "min", "max", "number", "amount", "record", "data", "info", "information", "detail", "first", "last", "most", "least", "cheapest", "expensive", "best"})
290
-
291
- # Check if the word OR its singular form is in the schema
292
- for qw in q_words:
293
- qw_singular = qw[:-1] if qw.endswith('s') else qw
294
- if qw in extended_schema_words or qw_singular in extended_schema_words:
295
- return True
296
-
297
- return False
298
-
299
- # ==========================================
300
- # SCHEMA CONSTRAINTS (From Code 1)
301
- # ==========================================
302
- def apply_schema_constraints(sql, schema_graph):
303
- sql = sql.lower()
304
-
305
- used_tables = [t[1] for t in re.findall(r'(from|join)\s+(\w+)', sql)]
306
- for t in used_tables:
307
- if t not in schema_graph:
308
- return None
309
-
310
- valid_columns = set()
311
- for cols in schema_graph.values():
312
- valid_columns.update(cols)
313
-
314
- col_blocks = re.findall(r'select\s+(.*?)\s+from', sql)
315
- for block in col_blocks:
316
- for c in block.split(","):
317
- c = c.strip().split()[-1]
318
- if "." in c:
319
- c = c.split(".")[-1]
320
-
321
- if c != "*" and "(" not in c and c != "":
322
- if c not in valid_columns:
323
- return None
324
-
325
- return sql
326
-
327
-
328
  def normalize_question(q: str):
329
  q = q.lower().strip()
330
  q = re.sub(r"distinct\s+(\d+)", r"\1 distinct", q)
@@ -350,8 +44,7 @@ class Text2SQLEngine:
350
  def __init__(self,
351
  adapter_path=None,
352
  base_model_name="Salesforce/codet5-base",
353
- use_lora=True,
354
- use_constrained_decoding=True): # Added constrained decoding flag
355
 
356
  self.device = "mps" if torch.backends.mps.is_available() else (
357
  "cuda" if torch.cuda.is_available() else "cpu"
@@ -359,7 +52,6 @@ class Text2SQLEngine:
359
 
360
  self.validator = SQLValidator(DB_ROOT)
361
  self.schema_encoder = SchemaEncoder(DB_ROOT)
362
- self.use_constrained_decoding = use_constrained_decoding
363
 
364
  self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'
365
 
@@ -408,20 +100,28 @@ SQL:
408
  return self.schema_encoder.structured_schema(db_id)
409
 
410
  def extract_sql(self, text: str):
 
411
  text = text.strip()
 
412
  if "SQL:" in text:
413
  text = text.split("SQL:")[-1]
 
414
  match = re.search(r"select[\s\S]*", text, re.IGNORECASE)
 
415
  if match:
416
  text = match.group(0)
 
417
  return text.split(";")[0].strip()
418
 
419
  def clean_sql(self, sql: str):
 
420
  sql = sql.replace('"', "'")
421
  sql = re.sub(r"\s+", " ", sql)
 
422
  return sql.strip()
423
 
424
  def generate_sql(self, prompt):
 
425
  inputs = self.tokenizer(
426
  prompt,
427
  return_tensors="pt",
@@ -430,6 +130,7 @@ SQL:
430
  ).to(self.device)
431
 
432
  with torch.no_grad():
 
433
  outputs = self.model.generate(
434
  **inputs,
435
  max_new_tokens=128,
@@ -438,85 +139,64 @@ SQL:
438
  )
439
 
440
  decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
441
  return self.clean_sql(self.extract_sql(decoded))
442
 
443
  def execute_sql(self, question, sql, db_id):
 
444
  if re.search(self.dml_keywords, sql, re.IGNORECASE):
445
  return sql, [], [], "❌ Security Alert"
446
 
447
- # FIXED DATABASE PATH (From Code 2)
448
  db_path = DB_ROOT / f"{db_id}.sqlite"
449
 
450
  sql = self.clean_sql(sql)
451
  sql = semantic_fix(question, sql)
452
 
453
  try:
 
454
  conn = sqlite3.connect(db_path)
 
455
  cursor = conn.cursor()
 
456
  cursor.execute(sql)
 
457
  rows = cursor.fetchall()
 
458
  columns = [d[0] for d in cursor.description] if cursor.description else []
 
459
  conn.close()
 
460
  return sql, columns, rows, None
 
461
  except Exception as e:
 
462
  return sql, [], [], str(e)
463
 
464
  def ask(self, question, db_id):
465
- # 1. Normalize
466
- question_norm = normalize_question(question)
467
- question_context = f"Database question: {question_norm}"
468
 
469
- # 2. Block dangerous inputs
470
- if re.search(self.dml_keywords, question_context, re.IGNORECASE):
 
 
471
  return {
472
- "question": question_norm,
473
  "sql": "-- BLOCKED",
474
  "columns": [],
475
  "rows": [],
476
- "error": "Malicious prompt"
477
  }
478
 
479
- # 3. Check basic validity of question
480
- if not is_valid_question(question_context):
481
- return {"sql": "", "error": "❌ Invalid input. Please type words."}
482
-
483
  schema = self.get_schema(db_id)
484
- schema_graph = build_schema_graph(schema)
485
 
486
- # 4. LEXICAL RELEVANCE GUARDRAIL
487
- if not is_relevant_to_db(question_norm, schema_graph):
488
- return {"sql": "", "error": "❌ Question is completely out of domain for the selected database."}
489
 
490
- # 5. INITIAL GENERATION
491
- prompt = self.build_prompt(question_context, schema)
492
  raw_sql = self.generate_sql(prompt)
493
 
494
- # 6. STRONGER CONSTRAINT LOGIC
495
- if self.use_constrained_decoding:
496
- filtered_sql = apply_schema_constraints(raw_sql, schema_graph)
497
-
498
- if filtered_sql is None:
499
- constraint_prompt = f"""Use ONLY valid schema.
500
- Database schema:
501
- {schema}
502
- Generate a valid SQLite query for the question.
503
- Question:
504
- {question_context}
505
- SQL:
506
- """
507
- sql_retry = self.generate_sql(constraint_prompt)
508
- filtered_sql = apply_schema_constraints(sql_retry, schema_graph)
509
-
510
- if filtered_sql:
511
- raw_sql = filtered_sql
512
- else:
513
- raw_sql = sql_retry
514
-
515
- # 7. EXECUTION
516
- final_sql, cols, rows, error = self.execute_sql(question_norm, raw_sql, db_id)
517
 
518
  return {
519
- "question": question_norm,
520
  "sql": final_sql,
521
  "columns": cols,
522
  "rows": rows,
@@ -526,338 +206,12 @@ SQL:
526
 
527
  _engine = None
528
 
529
- def get_engine(use_constrained=True): # Added parameter to control constraints
 
 
530
  global _engine
531
 
532
  if _engine is None:
533
- _engine = Text2SQLEngine(use_constrained_decoding=use_constrained)
534
 
535
  return _engine
536
-
537
- # import sqlite3
538
- # import torch
539
- # import re
540
- # import os
541
- # from pathlib import Path
542
- # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
543
- # from peft import PeftModel
544
- # from src.sql_validator import SQLValidator
545
- # from src.schema_encoder import SchemaEncoder # Removed build_schema_graph import
546
-
547
- # PROJECT_ROOT = Path(__file__).resolve().parents[1]
548
-
549
- # # ================================
550
- # # DATABASE PATH AUTO DETECTION
551
- # # ================================
552
- # if (PROJECT_ROOT / "data/database").exists():
553
- # DB_ROOT = PROJECT_ROOT / "data/database"
554
- # else:
555
- # DB_ROOT = PROJECT_ROOT / "final_databases"
556
-
557
-
558
- # # ==========================================
559
- # # SCHEMA PARSING
560
- # # ==========================================
561
- # def build_schema_graph(schema_text):
562
- # """
563
- # Parses a structured schema text string into a dictionary graph.
564
- # Matches formats like: table_name(col1, col2, col3)
565
- # """
566
- # tables = {}
567
- # for match in re.findall(r'(\w+)\s*\((.*?)\)', schema_text):
568
- # table = match[0]
569
- # cols = [c.strip().split()[0] for c in match[1].split(",")]
570
- # tables[table] = cols
571
- # return tables
572
-
573
-
574
- # # ==========================================
575
- # # INPUT VALIDATION & RELEVANCE
576
- # # ==========================================
577
- # def is_valid_question(q: str):
578
- # q = q.strip().lower()
579
-
580
- # if len(q) < 3:
581
- # return False
582
-
583
- # words = re.findall(r"[a-zA-Z]+", q)
584
- # if len(words) < 1:
585
- # return False
586
-
587
- # return True
588
-
589
-
590
- # def is_relevant_to_db(question: str, schema_graph: dict):
591
- # q_words = set(re.findall(r'\b[a-z]{3,}\b', question.lower()))
592
- # stop_words = {"show", "list", "all", "and", "the", "get", "find", "how", "many", "what", "where", "which", "who", "give", "display", "count", "from", "for", "with", "that", "have", "has", "are", "there"}
593
- # q_words = q_words - stop_words
594
-
595
- # if not q_words:
596
- # return True
597
-
598
- # schema_words = set()
599
- # for table, cols in schema_graph.items():
600
- # schema_words.update(re.findall(r'\b[a-z]{3,}\b', table.lower()))
601
- # for col in cols:
602
- # schema_words.update(re.findall(r'\b[a-z]{3,}\b', col.lower()))
603
-
604
- # synonyms = {
605
- # "customer": ["client", "buyer", "shopper", "person", "people", "user"],
606
- # "employee": ["staff", "worker", "boss", "manager", "person", "people"],
607
- # "track": ["song", "music", "audio", "tune"],
608
- # "album": ["record", "cd", "music"],
609
- # "artist": ["singer", "band", "musician", "creator"],
610
- # "invoice": ["bill", "receipt", "purchase", "sale", "order", "buy", "bought", "cost"],
611
- # "city": ["town", "location", "place"],
612
- # "country": ["nation", "location", "place"],
613
- # "flight": ["plane", "airline", "trip", "fly", "airport"],
614
- # "student": ["pupil", "learner", "kid", "child"],
615
- # "club": ["group", "organization", "team"],
616
- # "course": ["class", "subject"],
617
- # "cinema": ["movie", "film", "theater", "screen"]
618
- # }
619
-
620
- # extended_schema_words = set(schema_words)
621
- # for db_word in schema_words:
622
- # if db_word in synonyms:
623
- # extended_schema_words.update(synonyms[db_word])
624
-
625
- # extended_schema_words.update({"id", "name", "total", "sum", "average", "avg", "min", "max", "number", "amount", "record", "data", "info", "information", "detail", "first", "last", "most", "least", "cheapest", "expensive", "best"})
626
-
627
- # for qw in q_words:
628
- # qw_singular = qw[:-1] if qw.endswith('s') else qw
629
- # if qw in extended_schema_words or qw_singular in extended_schema_words:
630
- # return True
631
-
632
- # return False
633
-
634
- # def normalize_question(q: str):
635
- # return re.sub(r"\s+", " ", q.lower().strip())
636
-
637
- # def semantic_fix(question, sql):
638
- # q = question.lower()
639
- # num_match = re.search(r'\b(?:show|list|top|get)\s+(\d+)\b', q)
640
-
641
- # if num_match and "limit" not in sql.lower():
642
- # sql = f"{sql} LIMIT {num_match.group(1)}"
643
-
644
- # return sql
645
-
646
- # # ==========================================
647
- # # SCHEMA CONSTRAINTS (SIMULATED LOGIT BLOCKING)
648
- # # ==========================================
649
- # def apply_schema_constraints(sql, schema_graph):
650
- # sql = sql.lower()
651
-
652
- # used_tables = [t[1] for t in re.findall(r'(from|join)\s+(\w+)', sql)]
653
- # for t in used_tables:
654
- # if t not in schema_graph:
655
- # return None
656
-
657
- # valid_columns = set()
658
- # for cols in schema_graph.values():
659
- # valid_columns.update(cols)
660
-
661
- # col_blocks = re.findall(r'select\s+(.*?)\s+from', sql)
662
- # for block in col_blocks:
663
- # for c in block.split(","):
664
- # c = c.strip().split()[-1]
665
- # if "." in c:
666
- # c = c.split(".")[-1]
667
-
668
- # if c != "*" and "(" not in c and c != "":
669
- # if c not in valid_columns:
670
- # return None
671
-
672
- # return sql
673
-
674
- # # ==========================================
675
- # # ENGINE
676
- # # ==========================================
677
- # class Text2SQLEngine:
678
-
679
- # def __init__(self,
680
- # adapter_path="checkpoints/best_rlhf_model_2",
681
- # base_model_name="Salesforce/codet5-base",
682
- # use_lora=True,
683
- # use_constrained_decoding=False):
684
-
685
- # self.device = "mps" if torch.backends.mps.is_available() else (
686
- # "cuda" if torch.cuda.is_available() else "cpu"
687
- # )
688
-
689
- # self.validator = SQLValidator(DB_ROOT)
690
- # self.schema_encoder = SchemaEncoder(DB_ROOT)
691
-
692
- # self.use_constrained_decoding = use_constrained_decoding
693
- # self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate|create)\b'
694
-
695
- # print(f"\n📦 Loading model on {self.device}...")
696
-
697
- # base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
698
-
699
- # # Override the redundant special tokens to prevent the tokenizer crash
700
- # self.tokenizer = AutoTokenizer.from_pretrained(
701
- # base_model_name,
702
- # use_fast=False,
703
- # additional_special_tokens=[]
704
- # )
705
-
706
- # # 🔥 FIXED LOADA ADAPTER PATH LOGIC
707
- # if use_lora:
708
- # if adapter_path and (PROJECT_ROOT / adapter_path).exists():
709
- # adapter_path = PROJECT_ROOT / adapter_path
710
- # elif (PROJECT_ROOT / "checkpoints/best_rlhf_model_2").exists():
711
- # adapter_path = PROJECT_ROOT / "checkpoints/best_rlhf_model_2"
712
- # else:
713
- # adapter_path = PROJECT_ROOT / "best_rlhf_model_2"
714
-
715
- # adapter_path = adapter_path.resolve()
716
-
717
- # if adapter_path.exists():
718
- # try:
719
- # self.model = PeftModel.from_pretrained(
720
- # base,
721
- # str(adapter_path),
722
- # local_files_only=True
723
- # ).to(self.device)
724
- # print(f"✅ LoRA loaded from {adapter_path}")
725
- # except Exception as e:
726
- # print(f"⚠️ LoRA load failed: {e}")
727
- # self.model = base.to(self.device)
728
- # else:
729
- # print(f"⚠️ Adapter not found at {adapter_path}, using base model")
730
- # self.model = base.to(self.device)
731
- # else:
732
- # self.model = base.to(self.device)
733
-
734
- # self.model.eval()
735
-
736
- # def build_prompt(self, question, schema):
737
- # return f"""
738
- # You are an expert SQL generator.
739
-
740
- # IMPORTANT:
741
- # - Use correct tables and columns
742
- # - Use JOINs when needed
743
-
744
- # Schema:
745
- # {schema}
746
-
747
- # Question:
748
- # {question}
749
-
750
- # SQL:
751
- # """
752
-
753
- # def get_schema(self, db_id):
754
- # return self.schema_encoder.structured_schema(db_id)
755
-
756
- # def extract_sql(self, text):
757
- # match = re.search(r"(select|with)[\s\S]*", text, re.IGNORECASE)
758
- # return match.group(0).split(";")[0].strip() if match else ""
759
-
760
- # def clean_sql(self, sql):
761
- # return re.sub(r"\s+", " ", sql.replace('"', "'")).strip()
762
-
763
- # def generate_sql(self, prompt):
764
- # inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
765
-
766
- # with torch.no_grad():
767
- # outputs = self.model.generate(
768
- # **inputs,
769
- # max_new_tokens=128,
770
- # num_beams=8,
771
- # length_penalty=0.8,
772
- # early_stopping=True
773
- # )
774
-
775
- # decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
776
- # return self.clean_sql(self.extract_sql(decoded))
777
-
778
- # def execute_sql(self, question, sql, db_id):
779
-
780
- # if re.search(self.dml_keywords, sql, re.IGNORECASE):
781
- # return "", [], [], "❌ Blocked malicious SQL (Contains INSERT/UPDATE/DELETE/DROP)"
782
-
783
- # # 🔥 FIXED DATABASE PATH
784
- # db_path = DB_ROOT / f"{db_id}.sqlite"
785
- # sql = semantic_fix(question, sql)
786
-
787
- # try:
788
- # conn = sqlite3.connect(db_path)
789
- # cursor = conn.cursor()
790
- # cursor.execute(sql)
791
-
792
- # rows = cursor.fetchall()
793
- # columns = [d[0] for d in cursor.description] if cursor.description else []
794
-
795
- # conn.close()
796
- # return sql, columns, rows, None
797
-
798
- # except Exception as e:
799
- # return sql, [], [], str(e)
800
-
801
- # def ask(self, question, db_id):
802
-
803
- # question = normalize_question(question)
804
- # question_context = f"Database question: {question}"
805
-
806
- # if re.search(self.dml_keywords, question_context, re.IGNORECASE):
807
- # return {"sql": "", "error": "❌ Blocked dangerous query from input text."}
808
-
809
- # if not is_valid_question(question_context):
810
- # return {"sql": "", "error": "❌ Invalid input. Please type words."}
811
-
812
- # schema = self.get_schema(db_id)
813
- # schema_graph = build_schema_graph(schema)
814
-
815
- # if not is_relevant_to_db(question, schema_graph):
816
- # return {"sql": "", "error": "❌ Question is completely out of domain for the selected database."}
817
-
818
- # sql = self.generate_sql(self.build_prompt(question_context, schema))
819
-
820
- # if self.use_constrained_decoding:
821
- # filtered_sql = apply_schema_constraints(sql, schema_graph)
822
-
823
- # if filtered_sql is None:
824
- # constraint_prompt = f"""
825
- # Use ONLY valid schema.
826
- # Schema:
827
- # {schema}
828
-
829
- # Question:
830
- # {question_context}
831
-
832
- # SQL:
833
- # """
834
- # sql_retry = self.generate_sql(constraint_prompt)
835
- # filtered_sql = apply_schema_constraints(sql_retry, schema_graph)
836
-
837
- # if filtered_sql:
838
- # sql = filtered_sql
839
- # else:
840
- # sql = sql_retry
841
-
842
- # final_sql, cols, rows, error = self.execute_sql(question_context, sql, db_id)
843
-
844
- # return {
845
- # "question": question_context,
846
- # "sql": final_sql,
847
- # "columns": cols,
848
- # "rows": rows,
849
- # "error": error
850
- # }
851
-
852
- # def get_engine(
853
- # adapter_path="checkpoints/best_rlhf_model_2",
854
- # base_model_name="Salesforce/codet5-base",
855
- # use_lora=True,
856
- # use_constrained=True
857
- # ):
858
- # return Text2SQLEngine(
859
- # adapter_path=adapter_path,
860
- # base_model_name=base_model_name,
861
- # use_lora=use_lora,
862
- # use_constrained_decoding=use_constrained
863
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import sqlite3
2
  import torch
3
  import re
 
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
  from peft import PeftModel
8
  from src.sql_validator import SQLValidator
9
+ from src.schema_encoder import SchemaEncoder
10
 
11
  PROJECT_ROOT = Path(__file__).resolve().parents[1]
12
 
 
19
  DB_ROOT = PROJECT_ROOT / "final_databases"
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def normalize_question(q: str):
23
  q = q.lower().strip()
24
  q = re.sub(r"distinct\s+(\d+)", r"\1 distinct", q)
 
44
  def __init__(self,
45
  adapter_path=None,
46
  base_model_name="Salesforce/codet5-base",
47
+ use_lora=True):
 
48
 
49
  self.device = "mps" if torch.backends.mps.is_available() else (
50
  "cuda" if torch.cuda.is_available() else "cpu"
 
52
 
53
  self.validator = SQLValidator(DB_ROOT)
54
  self.schema_encoder = SchemaEncoder(DB_ROOT)
 
55
 
56
  self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'
57
 
 
100
  return self.schema_encoder.structured_schema(db_id)
101
 
102
  def extract_sql(self, text: str):
103
+
104
  text = text.strip()
105
+
106
  if "SQL:" in text:
107
  text = text.split("SQL:")[-1]
108
+
109
  match = re.search(r"select[\s\S]*", text, re.IGNORECASE)
110
+
111
  if match:
112
  text = match.group(0)
113
+
114
  return text.split(";")[0].strip()
115
 
116
  def clean_sql(self, sql: str):
117
+
118
  sql = sql.replace('"', "'")
119
  sql = re.sub(r"\s+", " ", sql)
120
+
121
  return sql.strip()
122
 
123
  def generate_sql(self, prompt):
124
+
125
  inputs = self.tokenizer(
126
  prompt,
127
  return_tensors="pt",
 
130
  ).to(self.device)
131
 
132
  with torch.no_grad():
133
+
134
  outputs = self.model.generate(
135
  **inputs,
136
  max_new_tokens=128,
 
139
  )
140
 
141
  decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
142
+
143
  return self.clean_sql(self.extract_sql(decoded))
144
 
145
  def execute_sql(self, question, sql, db_id):
146
+
147
  if re.search(self.dml_keywords, sql, re.IGNORECASE):
148
  return sql, [], [], "❌ Security Alert"
149
 
150
+ # FIXED DATABASE PATH
151
  db_path = DB_ROOT / f"{db_id}.sqlite"
152
 
153
  sql = self.clean_sql(sql)
154
  sql = semantic_fix(question, sql)
155
 
156
  try:
157
+
158
  conn = sqlite3.connect(db_path)
159
+
160
  cursor = conn.cursor()
161
+
162
  cursor.execute(sql)
163
+
164
  rows = cursor.fetchall()
165
+
166
  columns = [d[0] for d in cursor.description] if cursor.description else []
167
+
168
  conn.close()
169
+
170
  return sql, columns, rows, None
171
+
172
  except Exception as e:
173
+
174
  return sql, [], [], str(e)
175
 
176
  def ask(self, question, db_id):
 
 
 
177
 
178
+ question = normalize_question(question)
179
+
180
+ if re.search(self.dml_keywords, question, re.IGNORECASE):
181
+
182
  return {
183
+ "question": question,
184
  "sql": "-- BLOCKED",
185
  "columns": [],
186
  "rows": [],
187
+ "error": "Malicious prompt"
188
  }
189
 
 
 
 
 
190
  schema = self.get_schema(db_id)
 
191
 
192
+ prompt = self.build_prompt(question, schema)
 
 
193
 
 
 
194
  raw_sql = self.generate_sql(prompt)
195
 
196
+ final_sql, cols, rows, error = self.execute_sql(question, raw_sql, db_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  return {
199
+ "question": question,
200
  "sql": final_sql,
201
  "columns": cols,
202
  "rows": rows,
 
206
 
207
  _engine = None
208
 
209
+
210
+ def get_engine():
211
+
212
  global _engine
213
 
214
  if _engine is None:
215
+ _engine = Text2SQLEngine()
216
 
217
  return _engine