tjhalanigrid committed on
Commit
b5ae35c
·
1 Parent(s): 2c420d1

fix tokenizer folder structure for custom backend

Browse files
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  title: Text2sql Demo
3
- emoji: 🐨
4
- colorFrom: yellow
5
- colorTo: pink
6
- sdk: streamlit
 
7
  app_file: app.py
8
  pinned: false
9
  license: mit
10
  python_version: 3.10.13
11
- short_description: 'to show the streamlit interface'
12
  ---
 
1
  ---
2
  title: Text2sql Demo
3
+ emoji: 📊
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 5.8.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  python_version: 3.10.13
12
+ short_description: 'Text to SQL with RLHF'
13
  ---
app.py CHANGED
@@ -1,4 +1,9 @@
1
- import streamlit as st
 
 
 
 
 
2
  import pandas as pd
3
  import re
4
  import time
@@ -10,11 +15,7 @@ import subprocess
10
  import base64
11
  import io
12
  from pathlib import Path
13
-
14
- # ==========================================
15
- # PAGE CONFIG
16
- # ==========================================
17
- st.set_page_config(page_title="Text-to-SQL RLHF", layout="wide")
18
 
19
  # ==========================================
20
  # RELATIVE PATH RESOLUTION (GLOBAL)
@@ -39,45 +40,52 @@ def get_db_path(db_id: str) -> str:
39
  # ==========================================
40
  if not torch.cuda.is_available():
41
  class MockCUDAEvent:
42
- def __init__(self, enable_timing=False, blocking=False, interprocess=False): self.t = 0.0
43
- def record(self, stream=None): self.t = time.perf_counter()
44
- def elapsed_time(self, end_event): return (end_event.t - self.t) * 1000.0
 
 
 
 
45
  torch.cuda.Event = MockCUDAEvent
46
  if not hasattr(torch.cuda, 'synchronize'):
47
  torch.cuda.synchronize = lambda: None
48
 
49
  # ==========================================
50
- # IMPORTS & GLOBAL STATE
51
  # ==========================================
52
  from src.quantized_text2sql_engine import QuantizedText2SQLEngine
53
  from src.schema_encoder import SchemaEncoder
54
 
55
  DEFAULT_QUANT_ARTIFACT = str(PROJECT_ROOT / "int8_dynamic")
56
 
57
- # Use st.session_state to persist logs across UI reruns safely
58
- if '_QUERY_LOG' not in st.session_state:
59
- st.session_state._QUERY_LOG = []
60
- st.session_state._PERF_LOG = []
61
- st.session_state._SUCCESS_LOG = []
62
- st.session_state._OP_STATS = {
63
- "SELECT": {"ok": 0, "fail": 0}, "WHERE": {"ok": 0, "fail": 0}, "JOIN": {"ok": 0, "fail": 0},
64
- "GROUP_BY": {"ok": 0, "fail": 0}, "ORDER_BY": {"ok": 0, "fail": 0}, "HAVING": {"ok": 0, "fail": 0}, "LIMIT": {"ok": 0, "fail": 0},
65
- }
66
 
67
- # 🚨 LAZY LOADING: Streamlit caches the engine so it only loads ONCE, and ONLY when called.
68
- @st.cache_resource(show_spinner=False)
69
- def load_engine_and_schema():
70
- engine = None
71
- try:
72
- engine = QuantizedText2SQLEngine(DEFAULT_QUANT_ARTIFACT, device="cpu", use_constrained=False, exec_workers=8, use_cache=True)
73
- except Exception as e:
74
- print(f"Failed to load engine: {e}")
75
- encoder = SchemaEncoder(DB_ROOT)
76
- return engine, encoder
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- # ==========================================
79
- # HELPER FUNCTIONS
80
- # ==========================================
81
  SAMPLES = [
82
  ("Show 10 distinct employee first names.", "chinook_1"), ("Which artist has the most albums?", "chinook_1"),
83
  ("List all the tracks that belong to the 'Rock' genre.", "chinook_1"), ("What are the names of all the cities?", "flight_1"),
@@ -87,7 +95,6 @@ SAMPLES = [
87
  ("Show the names of all cinemas.", "cinema"), ("Which cinema has the most screens?", "cinema")
88
  ]
89
  SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
90
- DBS = sorted(["flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1", "academic", "aircraft", "car_1", "cinema", "club_1", "csu_1", "college_1", "college_2", "company_1", "company_employee", "customer_complaints", "department_store", "employee_hire_evaluation", "museum_visit", "products_for_hire", "restaurant_1", "school_finance", "shop_membership", "small_bank_1", "student_1", "tvshow", "voter_1", "world_1"])
91
 
92
  def explain_sql(sql):
93
  if not sql: return ""
@@ -138,7 +145,8 @@ def get_hint(error_type):
138
  }
139
  return hints.get(error_type, "Review query.")
140
 
141
- def is_relevant_to_schema(question, db_id, schema_encoder):
 
142
  try: raw_schema = schema_encoder.structured_schema(db_id).lower()
143
  except: return True
144
  schema_words = set(re.findall(r'[a-z0-9_]+', raw_schema))
@@ -151,319 +159,401 @@ def is_relevant_to_schema(question, db_id, schema_encoder):
151
  if word in schema_words or singular_word in schema_words: return True
152
  return False
153
 
154
- def _log(error_type: str, *, question: str, db_id_val: str, sql: str = "", error_msg: str = "") -> None:
155
- st.session_state._QUERY_LOG.append({"t": time.time(), "db_id": str(db_id_val), "question": str(question), "sql": str(sql), "error_type": str(error_type), "error_msg": str(error_msg)})
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
- def _perf_log(payload: dict) -> None:
158
- st.session_state._PERF_LOG.append(payload)
159
- if len(st.session_state._PERF_LOG) > 1000: del st.session_state._PERF_LOG[:200]
160
 
161
- # ==========================================
162
- # MAIN UI
163
- # ==========================================
164
- st.markdown("""
165
- <div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">
166
- <h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>
167
- <p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>
168
- </div>
169
- """, unsafe_allow_html=True)
 
 
 
 
 
 
170
 
171
- tab1, tab2 = st.tabs(["Inference", "Diagnostics"])
 
 
172
 
173
- with tab1:
174
- col1, col2 = st.columns([1, 2])
 
175
 
176
- with col1:
177
- st.markdown("### 1. Configuration & Input")
178
- method = st.radio("How do you want to ask?", ["💡 Pick a Sample", "✍️ Type my own"])
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
- if method == "💡 Pick a Sample":
181
- sample_q = st.selectbox("Select a Sample Question", SAMPLE_QUESTIONS)
182
- db_id = next((db for q, db in SAMPLES if q == sample_q), "chinook_1")
183
- st.text_input("Database", value=db_id, disabled=True)
184
- custom_q = ""
185
- else:
186
- db_id = st.selectbox("Select Database", DBS, index=DBS.index("chinook_1"))
187
- sample_q = ""
188
- custom_q = st.text_area("Ask your Custom Question", placeholder="Type your own question here...", height=100)
189
-
190
- # Schema Viewer
191
- _, schema_encoder = load_engine_and_schema() # Encoder loads instantly
192
- st.markdown("#### 📋 Database Structure")
193
- st.caption("Use these exact names! Table names are **Dark**, Column names are Light.")
194
- with st.container(height=250):
195
- try:
196
- st.code(schema_encoder.structured_schema(db_id), language="sql")
197
- except Exception as e:
198
- st.error(f"Error loading schema: {e}")
199
-
200
- run_btn = st.button("🚀 Generate & Run SQL", type="primary", use_container_width=True)
201
-
202
- with col2:
203
- st.markdown("### 2. Execution Results")
204
- sql_placeholder = st.empty()
205
- df_placeholder = st.empty()
206
- exp_placeholder = st.empty()
207
-
208
- if run_btn:
209
- raw_question = sample_q if method == "💡 Pick a Sample" else custom_q
210
-
211
- if not raw_question or str(raw_question).strip() == "":
212
- exp_placeholder.warning("⚠️ Please enter a question.")
213
- st.stop()
214
-
215
- typo_corrections = [(r'\bshaw\b', 'show'), (r'\bshw\b', 'show'), (r'\bsho\b', 'show'), (r'\blsit\b', 'list'), (r'\blis\b', 'list'), (r'\bfidn\b', 'find'), (r'\bfnd\b', 'find'), (r'\bgte\b', 'get')]
216
- question = str(raw_question)
217
- for bad, good in typo_corrections: question = re.sub(bad, good, question, flags=re.IGNORECASE)
218
- q_lower = question.strip().lower()
219
-
220
- if len(q_lower.split()) < 2:
221
- _log("gibberish", question=question, db_id_val=str(db_id), error_msg="gibberish filtered")
222
- exp_placeholder.warning("⚠️ Please enter a clear, meaningful natural language question (more than one word).")
223
- st.stop()
224
-
225
- if re.search(r'\b(delete|update|insert|drop|alter|truncate)\b', q_lower):
226
- _log("blocked_dml", question=question, db_id_val=str(db_id), error_msg="DML blocked")
227
- exp_placeholder.error("🛑 Security Alert: Modifying or deleting data is strictly prohibited.")
228
- st.stop()
229
-
230
- if not is_relevant_to_schema(question, db_id, schema_encoder):
231
- _log("out_of_domain", question=question, db_id_val=str(db_id), error_msg="out of domain")
232
- exp_placeholder.error(f"🛑 Relevance Alert: I don't see anything related to your question in the '{db_id}' schema.")
233
- st.stop()
234
-
235
- start_time = time.time()
236
- t0 = time.perf_counter()
237
-
238
- # LAZY LOAD TRIGGER: We only spin up the engine here!
239
- with st.spinner("Booting AI Engine & Generating SQL..."):
240
- quant_engine, _ = load_engine_and_schema()
241
- if quant_engine is None:
242
- exp_placeholder.error("❌ CRITICAL BACKEND CRASH: Quantized engine is not available. Ensure 'int8_dynamic' folder is uploaded.")
243
- st.stop()
244
-
245
- try:
246
- result = quant_engine.ask(question, str(db_id), num_beams=4, max_new_tokens=120, timeout_s=2.0)
247
- except TypeError:
248
- result = quant_engine.ask(question, str(db_id))
249
- except Exception as e:
250
- _log("backend_crash", question=question, db_id_val=str(db_id), error_msg=str(e))
251
- exp_placeholder.error(f"❌ CRITICAL BACKEND CRASH:\n{str(e)}")
252
- st.stop()
253
-
254
- final_sql = str(result.get("sql", ""))
255
- model_sql = final_sql
256
-
257
- # Semantic limit cleaner
258
- num_match = re.search(r'\b(?:show|list|top|limit|get|first|last|sample|of)\s+(?:[a-zA-Z_]+\s+)?(\d+)\b', q_lower)
259
- if not num_match and q_lower.startswith(("show", "list", "get")):
260
- num_match = re.search(r'\b(\d+)\b', q_lower)
261
-
262
- if num_match and final_sql:
263
- limit_val = num_match.group(1)
264
- final_sql = re.sub(rf"(?i)\s*(?:where|having|and)?\s*count\s*\(\s*\*\s*\)\s*=\s*{limit_val}", "", final_sql)
265
- final_sql = re.sub(rf"(?i)\s*(?:where|and)\s+[a-zA-Z0-9_.]+\s*=\s*['\"]?{limit_val}['\"]?", "", final_sql)
266
- final_sql = re.sub(r"(?i)\s*where\s*$", "", final_sql)
267
- final_sql = re.sub(r"(?i)\s*where\s+(group by|order by|limit)", r" \1", final_sql)
268
-
269
- agg_kws = ["most", "top", "highest", "lowest", "count", "many", "group", "frequent", "popular"]
270
- if not any(k in q_lower for k in agg_kws):
271
- final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
272
- final_sql = re.sub(r"(?i)\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
273
- final_sql = re.sub(r"(?i),\s*count\(\*\)", "", final_sql)
274
- final_sql = re.sub(r"(?i)count\(\*\)\s*,", "", final_sql)
275
-
276
- if "group by" in final_sql.lower() and not re.search(r'(?i)\b(count|sum|avg|max|min)\b\(', final_sql):
277
- final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+", "", final_sql)
278
-
279
- if "limit" not in final_sql.lower():
280
- final_sql = f"{final_sql.strip().rstrip(';')} LIMIT {limit_val}"
281
-
282
- sql_placeholder.code(final_sql, language="sql")
283
-
284
- # Execution
285
- from src.sql_validator import validate_sql_schema
286
- db_path = get_db_path(str(db_id))
287
 
288
- try: strict_valid, _ = validate_sql_schema(final_sql, db_path)
289
- except Exception: strict_valid = False
 
 
 
 
 
 
 
 
 
 
 
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  error_msg = None
292
- rows, cols = [], []
293
- sqlite_success = False
294
-
295
- with st.spinner("Executing query..."):
296
- try:
297
- rows, cols = quant_engine._execute_one(final_sql, db_path, timeout_s=2.0)
298
- sqlite_success = True
299
- except Exception as e:
300
- error_msg = str(e)
301
- sqlite_success = False
302
-
303
- if not sqlite_success and model_sql and model_sql != final_sql:
304
- try:
305
- alt_rows, alt_cols = quant_engine._execute_one(model_sql, db_path, timeout_s=2.0)
306
- final_sql = model_sql
307
- sql_placeholder.code(final_sql, language="sql")
308
- rows, cols = alt_rows, alt_cols
309
- sqlite_success = True
310
- error_msg = None
311
- except Exception: pass
312
-
313
- valid = sqlite_success
314
-
315
- if error_msg or not valid:
316
- et = classify_error(final_sql, str(error_msg or ""), timed_out=("interrupted" in str(error_msg or "").lower()))
317
- _log(et, question=str(question), db_id_val=str(db_id), sql=str(final_sql), error_msg=str(error_msg or "Execution failed"))
318
-
319
- latency = round(time.time() - start_time, 3)
320
- t1 = time.perf_counter()
321
- engine_stats_after = quant_engine.stats() if hasattr(quant_engine, 'stats') else {}
322
-
323
- perf = {
324
- "db_id": str(db_id), "use_constrained_decoding": False, "num_beams": 4,
325
- "latency_total_ms": round((t1 - t0) * 1000.0, 2), "constraint_ok": bool(strict_valid), "has_error": bool(error_msg),
326
- "exec_cache_hit_rate": float(engine_stats_after.get("exec_cache_hit_rate", 0.0) or 0.0),
327
- }
328
- _perf_log(perf)
329
-
330
- window = st.session_state._PERF_LOG[-50:]
331
- avg_ms = sum(float(x.get("latency_total_ms", 0.0) or 0.0) for x in window) / len(window) if window else 0.0
332
- constraint_rate = sum(1 for x in window if x.get("constraint_ok")) / len(window) if window else 0.0
333
-
334
- perf_block = (
335
- f"\n\n---\n**Performance (task impact)**\n"
336
- f"- Total latency (ms): {perf['latency_total_ms']}\n"
337
- f"- Strict Python Validator OK (Task 3): {perf['constraint_ok']}\n"
338
- f"- Exec cache hit-rate (Task 1/5): {round(perf['exec_cache_hit_rate'], 3)}\n"
339
- f"- Rolling avg latency last 50 (ms): {round(avg_ms, 2)}\n"
340
- f"- Rolling constraint rate last 50: {round(constraint_rate, 3)}\n"
341
- )
342
-
343
- ops = sql_ops(final_sql)
344
-
345
- if error_msg or not valid:
346
- for op in ops:
347
- if op in st.session_state._OP_STATS: st.session_state._OP_STATS[op]["fail"] += 1
348
-
349
- error_type = classify_error(final_sql, str(error_msg or ""))
350
- explanation = f"❌ Error Details:\n\n{error_msg}\n\nError Type: {error_type}\nHint: {get_hint(error_type)}{perf_block}"
351
- exp_placeholder.error(explanation)
352
- else:
353
- safe_cols = cols if cols else ["Result"]
354
- df_placeholder.dataframe(pd.DataFrame(rows, columns=safe_cols), use_container_width=True)
355
-
356
- for op in ops:
357
- if op in st.session_state._OP_STATS: st.session_state._OP_STATS[op]["ok"] += 1
358
- st.session_state._SUCCESS_LOG.append({"t": time.time(), "db_id": str(db_id), "question": question, "sql": final_sql, "ops": ops})
359
-
360
- explanation = f"✅ Query executed successfully\n\nRows returned: {len(rows)}\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}{perf_block}"
361
-
362
- limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
363
- if limit_match and len(rows) < int(limit_match.group(1)):
364
- explanation += f"\n\nℹ️ Query allowed up to {int(limit_match.group(1))} rows but only {len(rows)} matched."
365
-
366
- exp_placeholder.info(explanation)
367
-
368
-
369
- with tab2:
370
- st.markdown("## Diagnostics & Telemetry")
371
 
372
- with st.expander("Task 1: Parallel Reward Benchmark"):
373
- st.markdown("*(Simulates the heavy RLHF training workload by running hundreds of complex SQL queries concurrently to test SQLite multi-threading performance.)*")
374
- t1_n = st.number_input("Rollouts (n)", value=20, step=1)
375
- t1_workers = st.number_input("Max workers", value=10, step=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
377
- if st.button("Run Task 1 benchmark"):
378
- output_container = st.empty()
379
- env = os.environ.copy()
380
- env["PYTHONPATH"] = str(PROJECT_ROOT) + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
381
- env.setdefault("MPLBACKEND", "Agg")
382
- env.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")
383
- os.makedirs(env["MPLCONFIGDIR"], exist_ok=True)
384
-
385
- cmd = [sys.executable, "-u", "scripts/benchmark_parallel_reward.py", "--n", str(int(t1_n)), "--max-workers", str(int(t1_workers)), "--skip-profile"]
386
-
387
- try:
388
- proc = subprocess.Popen(cmd, cwd=str(PROJECT_ROOT), env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
389
- lines = []
390
- last_update = time.time()
391
- for line in proc.stdout:
392
- lines.append(line)
393
- if time.time() - last_update > 0.5:
394
- output_container.text("".join(lines[-50:]))
395
- last_update = time.time()
396
- proc.wait()
397
- output_container.text("".join(lines))
398
-
399
- plot_path = PROJECT_ROOT / "results" / "task1_plot.png"
400
- if plot_path.exists():
401
- st.image(str(plot_path))
402
- else:
403
- st.write("*No plot generated.*")
404
- except Exception as e:
405
- output_container.error(f"Failed to run benchmark: {e}")
406
-
407
- with st.expander("Task 2: Error Dashboard", expanded=True):
408
- st.markdown("*(Live telemetry tracking the most common SQL failures.)*")
409
- if st.button("Refresh Dashboard"):
410
- st.rerun()
411
-
412
- counts = {}
413
- for r in st.session_state._QUERY_LOG[-1000:]:
414
- k = r.get("error_type") or "other"
415
- counts[k] = counts.get(k, 0) + 1
416
-
417
- if not counts:
418
- st.write("No errors logged yet.")
419
- else:
420
- rows = [{"error_type": k, "count": int(v), "hint": get_hint(k)} for k, v in sorted(counts.items(), key=lambda x: (-x[1], x[0]))]
421
- st.dataframe(pd.DataFrame(rows), use_container_width=True)
422
-
423
- recent = []
424
- for r in st.session_state._QUERY_LOG[-100:]:
425
- ts = r.get("t")
426
- ts_s = time.strftime("%H:%M:%S", time.localtime(float(ts))) if ts else ""
427
- recent.append({"time": ts_s, "db_id": r.get("db_id", ""), "error_type": r.get("error_type", ""), "question": r.get("question", ""), "error_msg": r.get("error_msg", "")})
428
- st.dataframe(pd.DataFrame(recent), use_container_width=True)
429
-
430
- choices = [str(x["error_type"]) for x in rows]
431
- if choices:
432
- sel_type = st.selectbox("View Examples for Error Type", choices)
433
- matches = [r for r in reversed(st.session_state._QUERY_LOG) if (r.get("error_type") or "") == str(sel_type)][:3]
434
-
435
- st.write(f"**Hint:** {get_hint(sel_type)}")
436
- for i, r in enumerate(matches, 1):
437
- st.markdown(f"**Example {i}**\n* **DB:** {r.get('db_id','')}\n* **Q:** {r.get('question','')}\n* **SQL:** `{r.get('sql','')}`\n* **Msg:** {r.get('error_msg','')}")
438
-
439
- with st.expander("Task 2: Clause Telemetry"):
440
- st.markdown("*(Analyzes which specific SQL clauses are most prone to errors.)*")
441
- if st.button("Refresh SQL-op stats"):
442
- st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
 
444
- rows = []
445
- for op, d in st.session_state._OP_STATS.items():
446
- ok, fail = int(d.get("ok", 0)), int(d.get("fail", 0))
447
- total = ok + fail
448
- rows.append({"op": op, "ok": ok, "fail": fail, "total": total, "success_rate": (ok / total) if total else 0.0})
449
-
450
- st.dataframe(pd.DataFrame(rows), use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
 
452
- try:
453
- import matplotlib.pyplot as plt
454
- labels = list(st.session_state._OP_STATS.keys())
455
- oks = [int(st.session_state._OP_STATS[k]["ok"]) for k in labels]
456
- fails = [int(st.session_state._OP_STATS[k]["fail"]) for k in labels]
457
-
458
- fig, ax = plt.subplots(figsize=(9, 3.5))
459
- x = list(range(len(labels)))
460
- ax.bar(x, oks, label="ok", color="#16a34a")
461
- ax.bar(x, fails, bottom=oks, label="fail", color="#dc2626")
462
- ax.set_xticks(x)
463
- ax.set_xticklabels(labels, rotation=30, ha="right")
464
- ax.set_title("Success/Failure by SQL operation")
465
- ax.legend()
466
- fig.tight_layout()
467
- st.pyplot(fig)
468
- except Exception as e:
469
- st.error(f"Plot error: {e}")
 
1
+ """
2
+ GRADIO DEMO UI - LAZY LOADING EDITION
3
+ NL → SQL → Result Table
4
+ """
5
+
6
+ import gradio as gr
7
  import pandas as pd
8
  import re
9
  import time
 
15
  import base64
16
  import io
17
  from pathlib import Path
18
+ from typing import Iterator
 
 
 
 
19
 
20
  # ==========================================
21
  # RELATIVE PATH RESOLUTION (GLOBAL)
 
40
  # ==========================================
41
  if not torch.cuda.is_available():
42
  class MockCUDAEvent:
43
+ def __init__(self, enable_timing=False, blocking=False, interprocess=False):
44
+ self.t = 0.0
45
+ def record(self, stream=None):
46
+ self.t = time.perf_counter()
47
+ def elapsed_time(self, end_event):
48
+ return (end_event.t - self.t) * 1000.0
49
+
50
  torch.cuda.Event = MockCUDAEvent
51
  if not hasattr(torch.cuda, 'synchronize'):
52
  torch.cuda.synchronize = lambda: None
53
 
54
  # ==========================================
55
+ # IMPORTS & ENGINE SETUP
56
  # ==========================================
57
  from src.quantized_text2sql_engine import QuantizedText2SQLEngine
58
  from src.schema_encoder import SchemaEncoder
59
 
60
  DEFAULT_QUANT_ARTIFACT = str(PROJECT_ROOT / "int8_dynamic")
61
 
62
+ _ENGINE_CACHE = {}
63
+ _QUERY_LOG = []
64
+ _PERF_LOG = []
65
+ _SUCCESS_LOG = []
 
 
 
 
 
66
 
67
+ _OP_STATS = {
68
+ "SELECT": {"ok": 0, "fail": 0}, "WHERE": {"ok": 0, "fail": 0}, "JOIN": {"ok": 0, "fail": 0},
69
+ "GROUP_BY": {"ok": 0, "fail": 0}, "ORDER_BY": {"ok": 0, "fail": 0}, "HAVING": {"ok": 0, "fail": 0}, "LIMIT": {"ok": 0, "fail": 0},
70
+ }
71
+
72
def get_quant_engine(artifact_dir: str, use_constrained: bool = False, exec_workers: int = 8, use_cache: bool = True):
    """Return a memoized ``QuantizedText2SQLEngine`` for the given settings.

    Engines are cached in the module-level ``_ENGINE_CACHE`` keyed by the
    normalized constructor arguments, so repeated calls with the same
    configuration reuse a single (expensive-to-build) instance.
    """
    cache_key = (artifact_dir, bool(use_constrained), int(exec_workers), bool(use_cache))
    if cache_key not in _ENGINE_CACHE:
        try:
            _ENGINE_CACHE[cache_key] = QuantizedText2SQLEngine(
                artifact_dir,
                device="cpu",
                use_constrained=bool(use_constrained),
                exec_workers=int(exec_workers),
                use_cache=bool(use_cache),
            )
        except TypeError:
            # Older engine builds accept only the artifact directory.
            _ENGINE_CACHE[cache_key] = QuantizedText2SQLEngine(artifact_dir)
    return _ENGINE_CACHE[cache_key]
80
+
81
+ # 🚨 LAZY LOADING: We DO NOT load the model here! We only load the fast Schema Encoder.
82
+ quant_engine = None
83
+ try:
84
+ schema_encoder = SchemaEncoder(DB_ROOT)
85
+ except Exception as e:
86
+ print(f"Warning: SchemaEncoder failed to load: {e}")
87
+ schema_encoder = None
88
 
 
 
 
89
  SAMPLES = [
90
  ("Show 10 distinct employee first names.", "chinook_1"), ("Which artist has the most albums?", "chinook_1"),
91
  ("List all the tracks that belong to the 'Rock' genre.", "chinook_1"), ("What are the names of all the cities?", "flight_1"),
 
95
  ("Show the names of all cinemas.", "cinema"), ("Which cinema has the most screens?", "cinema")
96
  ]
97
  SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
 
98
 
99
  def explain_sql(sql):
100
  if not sql: return ""
 
145
  }
146
  return hints.get(error_type, "Review query.")
147
 
148
+ def is_relevant_to_schema(question, db_id):
149
+ if schema_encoder is None: return True
150
  try: raw_schema = schema_encoder.structured_schema(db_id).lower()
151
  except: return True
152
  schema_words = set(re.findall(r'[a-z0-9_]+', raw_schema))
 
159
  if word in schema_words or singular_word in schema_words: return True
160
  return False
161
 
162
+ def run_query(method, sample_q, custom_q, db_id):
163
+ global quant_engine
164
+
165
+ # 🚨 LAZY LOADING: We load the heavy AI model ONLY when the button is clicked.
166
+ if quant_engine is None:
167
+ print(f"First request detected! Loading AI model from {DEFAULT_QUANT_ARTIFACT}...", flush=True)
168
+ try:
169
+ quant_engine = get_quant_engine(DEFAULT_QUANT_ARTIFACT, use_constrained=False, exec_workers=8, use_cache=True)
170
+ if quant_engine is None:
171
+ return "-- ❌ ENGINE CRASH", pd.DataFrame(columns=["Error"]), "Failed to load model. Did you move the tokenizer files and add config.json to int8_dynamic/?"
172
+ except Exception as e:
173
+ return f"-- ❌ ENGINE CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"Critical failure loading model: {e}"
174
+
175
+ def _log(error_type: str, *, question: str, db_id_val: str, sql: str = "", error_msg: str = "") -> None:
176
+ _QUERY_LOG.append({"t": time.time(), "db_id": str(db_id_val), "question": str(question), "sql": str(sql), "error_type": str(error_type), "error_msg": str(error_msg)})
177
 
178
+ def _perf_log(payload: dict) -> None:
179
+ _PERF_LOG.append(payload)
180
+ if len(_PERF_LOG) > 1000: del _PERF_LOG[:200]
181
 
182
+ raw_question = sample_q if method == "💡 Pick a Sample" else custom_q
183
+
184
+ if not raw_question or str(raw_question).strip() == "":
185
+ return "-- No input provided", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a question."
186
+ if not db_id or str(db_id).strip() == "":
187
+ return "-- No database selected", pd.DataFrame(columns=["Warning"]), "⚠️ Please select a database."
188
+
189
+ typo_corrections = [(r'\bshaw\b', 'show'), (r'\bshw\b', 'show'), (r'\bsho\b', 'show'), (r'\blsit\b', 'list'), (r'\blis\b', 'list'), (r'\bfidn\b', 'find'), (r'\bfnd\b', 'find'), (r'\bgte\b', 'get')]
190
+ question = str(raw_question)
191
+ for bad, good in typo_corrections: question = re.sub(bad, good, question, flags=re.IGNORECASE)
192
+ q_lower = question.strip().lower()
193
+
194
+ if len(q_lower.split()) < 2:
195
+ _log("gibberish", question=question, db_id_val=str(db_id), error_msg="gibberish filtered")
196
+ return "-- Input Blocked", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a clear, meaningful natural language question (more than one word)."
197
 
198
+ if re.search(r'\b(delete|update|insert|drop|alter|truncate)\b', q_lower):
199
+ _log("blocked_dml", question=question, db_id_val=str(db_id), error_msg="DML blocked")
200
+ return "-- ❌ BLOCKED: Data Modification", pd.DataFrame(columns=["Security Alert"]), "🛑 Security Alert: Modifying or deleting data is strictly prohibited."
201
 
202
+ if not is_relevant_to_schema(question, db_id):
203
+ _log("out_of_domain", question=question, db_id_val=str(db_id), error_msg="out of domain")
204
+ return "-- ❌ BLOCKED: Out of Domain", pd.DataFrame(columns=["Domain Alert"]), f"🛑 Relevance Alert: I don't see anything related to your question in the '{db_id}' schema."
205
 
206
+ start_time = time.time()
207
+ t0 = time.perf_counter()
208
+ ui_warnings = ""
209
+
210
+ try:
211
+ try:
212
+ result = quant_engine.ask(question, str(db_id), num_beams=4, max_new_tokens=120, timeout_s=2.0)
213
+ except TypeError:
214
+ result = quant_engine.ask(question, str(db_id))
215
+ except Exception as e:
216
+ _log("backend_crash", question=question, db_id_val=str(db_id), error_msg=str(e))
217
+ return f"-- ❌ BACKEND CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"❌ CRITICAL BACKEND CRASH:\n{str(e)}"
218
+
219
+ final_sql = str(result.get("sql", ""))
220
+ model_sql = final_sql
221
 
222
+ num_match = re.search(r'\b(?:show|list|top|limit|get|first|last|sample|of)\s+(?:[a-zA-Z_]+\s+)?(\d+)\b', q_lower)
223
+ if not num_match and q_lower.startswith(("show", "list", "get")):
224
+ num_match = re.search(r'\b(\d+)\b', q_lower)
225
+
226
+ if num_match and final_sql:
227
+ limit_val = num_match.group(1)
228
+ final_sql = re.sub(rf"(?i)\s*(?:where|having|and)?\s*count\s*\(\s*\*\s*\)\s*=\s*{limit_val}", "", final_sql)
229
+ final_sql = re.sub(rf"(?i)\s*(?:where|and)\s+[a-zA-Z0-9_.]+\s*=\s*['\"]?{limit_val}['\"]?", "", final_sql)
230
+ final_sql = re.sub(r"(?i)\s*where\s*$", "", final_sql)
231
+ final_sql = re.sub(r"(?i)\s*where\s+(group by|order by|limit)", r" \1", final_sql)
232
+
233
+ agg_kws = ["most", "top", "highest", "lowest", "count", "many", "group", "frequent", "popular"]
234
+ if not any(k in q_lower for k in agg_kws):
235
+ final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
236
+ final_sql = re.sub(r"(?i)\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
237
+ final_sql = re.sub(r"(?i),\s*count\(\*\)", "", final_sql)
238
+ final_sql = re.sub(r"(?i)count\(\*\)\s*,", "", final_sql)
239
+
240
+ if "group by" in final_sql.lower() and not re.search(r'(?i)\b(count|sum|avg|max|min)\b\(', final_sql):
241
+ final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+", "", final_sql)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
+ if "limit" not in final_sql.lower():
244
+ final_sql = f"{final_sql.strip().rstrip(';')} LIMIT {limit_val}"
245
+
246
+ # Execution
247
+ from src.sql_validator import validate_sql_schema
248
+ db_path = get_db_path(str(db_id))
249
+
250
+ try: strict_valid, _ = validate_sql_schema(final_sql, db_path)
251
+ except Exception: strict_valid = False
252
+
253
+ error_msg = None
254
+ rows, cols = [], []
255
+ sqlite_success = False
256
 
257
+ try:
258
+ rows, cols = quant_engine._execute_one(final_sql, db_path, timeout_s=2.0)
259
+ sqlite_success = True
260
+ except Exception as e:
261
+ error_msg = str(e)
262
+ sqlite_success = False
263
+
264
+ if not sqlite_success and model_sql and model_sql != final_sql:
265
+ try:
266
+ alt_rows, alt_cols = quant_engine._execute_one(model_sql, db_path, timeout_s=2.0)
267
+ final_sql = model_sql
268
+ rows, cols = alt_rows, alt_cols
269
+ sqlite_success = True
270
  error_msg = None
271
+ except Exception: pass
272
+
273
+ valid = sqlite_success
274
+
275
+ if error_msg or not valid:
276
+ et = classify_error(final_sql, str(error_msg or ""), timed_out=("interrupted" in str(error_msg or "").lower()))
277
+ _log(et, question=str(question), db_id_val=str(db_id), sql=str(final_sql), error_msg=str(error_msg or "Execution failed"))
278
+
279
+ latency = round(time.time() - start_time, 3)
280
+ t1 = time.perf_counter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
+ engine_stats_after = quant_engine.stats() if hasattr(quant_engine, 'stats') else {}
283
+
284
+ perf = {
285
+ "db_id": str(db_id), "use_constrained_decoding": False, "num_beams": 4,
286
+ "latency_total_ms": round((t1 - t0) * 1000.0, 2), "constraint_ok": bool(strict_valid), "has_error": bool(error_msg),
287
+ "exec_cache_hit_rate": float(engine_stats_after.get("exec_cache_hit_rate", 0.0) or 0.0),
288
+ }
289
+ _perf_log(perf)
290
+
291
+ window = _PERF_LOG[-50:]
292
+ avg_ms = sum(float(x.get("latency_total_ms", 0.0) or 0.0) for x in window) / len(window) if window else 0.0
293
+ constraint_rate = sum(1 for x in window if x.get("constraint_ok")) / len(window) if window else 0.0
294
+
295
+ perf_block = (
296
+ "\n\n---\nPerformance (task impact)\n"
297
+ f"- Total latency (ms): {perf['latency_total_ms']}\n"
298
+ f"- Strict Python Validator OK (Task 3): {perf['constraint_ok']}\n"
299
+ f"- Exec cache hit-rate (Task 1/5): {round(perf['exec_cache_hit_rate'], 3)}\n"
300
+ f"- Rolling avg latency last 50 (ms): {round(avg_ms, 2)}\n"
301
+ f"- Rolling constraint rate last 50: {round(constraint_rate, 3)}\n"
302
+ )
303
+
304
+ if error_msg or not valid:
305
+ display_sql = final_sql if final_sql.strip() else "-- ❌ INVALID SQL"
306
+ explanation = f"{ui_warnings}❌ Error Details:\n\n"
307
+ if error_msg: explanation += f"{error_msg}\n\n"
308
 
309
+ error_type = classify_error(final_sql, str(error_msg or ""))
310
+ explanation += f"Error Type: {error_type}\nHint: {get_hint(error_type)}"
311
+ explanation += perf_block
312
+ ops = sql_ops(final_sql)
313
+ for op in ops:
314
+ if op in _OP_STATS: _OP_STATS[op]["fail"] += 1
315
+ return display_sql, pd.DataFrame(columns=["Execution Notice"]), explanation
316
+
317
+ safe_cols = cols if cols else ["Result"]
318
+ explanation = f"{ui_warnings}✅ Query executed successfully\n\nRows returned: {len(rows)}\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}{perf_block}"
319
+
320
+ ops = sql_ops(final_sql)
321
+ for op in ops:
322
+ if op in _OP_STATS: _OP_STATS[op]["ok"] += 1
323
+ _SUCCESS_LOG.append({"t": time.time(), "db_id": str(db_id), "question": question, "sql": final_sql, "ops": ops})
324
+
325
+ limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
326
+ if limit_match and len(rows) < int(limit_match.group(1)):
327
+ explanation += f"\n\nℹ️ Query allowed up to {int(limit_match.group(1))} rows but only {len(rows)} matched."
328
+
329
+ return final_sql, pd.DataFrame(rows, columns=safe_cols), explanation
330
+
331
def task1_benchmark(n_rollouts: int, max_workers: int) -> Iterator[tuple[str, str]]:
    """Stream the Task 1 parallel-reward benchmark as (log_text, plot_html) pairs.

    Launches scripts/benchmark_parallel_reward.py in a subprocess, relays its
    combined stdout/stderr to the UI roughly twice a second, and finishes by
    emitting the generated plot inlined as a base64 <img> when one exists.
    """
    root = str(PROJECT_ROOT)

    # Child environment: make the project importable, and force a headless
    # matplotlib backend with a writable config dir (needed on HF Spaces).
    child_env = os.environ.copy()
    existing_pp = child_env.get("PYTHONPATH")
    child_env["PYTHONPATH"] = root + (os.pathsep + existing_pp if existing_pp else "")
    child_env.setdefault("MPLBACKEND", "Agg")
    child_env.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")
    try:
        os.makedirs(child_env["MPLCONFIGDIR"], exist_ok=True)
    except Exception:
        pass  # best effort — the benchmark may still run without it

    cmd = [
        sys.executable, "-u", "scripts/benchmark_parallel_reward.py",
        "--n", str(int(n_rollouts)),
        "--max-workers", str(int(max_workers)),
        "--skip-profile",
    ]
    proc = subprocess.Popen(
        cmd, cwd=root, env=child_env,
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        text=True, bufsize=1,
    )

    captured: list[str] = []
    last_emit = time.perf_counter()
    yield "Running Task 1 benchmark...\n", "<i>Running...</i>"

    assert proc.stdout is not None
    for raw_line in proc.stdout:
        captured.append(raw_line)
        # Throttle UI refreshes to ~2/sec, showing only the last 200 lines.
        now = time.perf_counter()
        if now - last_emit >= 0.5:
            last_emit = now
            yield "".join(captured[-200:]).strip(), "<i>Running...</i>"

    proc.wait()
    out = "".join(captured).strip()

    plot_path = str(PROJECT_ROOT / "results" / "task1_plot.png")
    if os.path.exists(plot_path):
        try:
            b64 = base64.b64encode(Path(plot_path).read_bytes()).decode("ascii")
            yield out, f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
        except Exception:
            # Could not inline the image — fall back to showing its path.
            yield out, f"<pre>{plot_path}</pre>"
        return

    yield out, "<i>No plot generated</i>"
def task2_dashboard_structured():
    """Build the Task 2 error dashboard.

    Returns (error-count DataFrame, recent-errors DataFrame, dropdown update).
    Counts cover the last 1000 logged queries; the recent table shows the
    last 100 raw entries.
    """
    if not _QUERY_LOG:
        # Nothing logged yet — empty frames and a cleared dropdown.
        return (
            pd.DataFrame(columns=["error_type", "count", "hint"]),
            pd.DataFrame(columns=["time", "db_id", "error_type", "question", "error_msg"]),
            gr.update(choices=[], value=None),
        )

    # Tally error types over the most recent 1000 records.
    tally: dict = {}
    for rec in _QUERY_LOG[-1000:]:
        label = rec.get("error_type") or "other"
        tally[label] = tally.get(label, 0) + 1

    # Most frequent first; ties broken alphabetically.
    ordered = sorted(tally.items(), key=lambda kv: (-kv[1], kv[0]))
    rows = [{"error_type": label, "count": int(n), "hint": get_hint(label)} for label, n in ordered]

    def _fmt_ts(raw):
        # Render the epoch timestamp as HH:MM:SS; blank when absent/invalid.
        try:
            return time.strftime("%H:%M:%S", time.localtime(float(raw))) if raw else ""
        except Exception:
            return ""

    recent_rows = [
        {
            "time": _fmt_ts(rec.get("t")),
            "db_id": rec.get("db_id", ""),
            "error_type": rec.get("error_type", ""),
            "question": rec.get("question", ""),
            "error_msg": rec.get("error_msg", ""),
        }
        for rec in _QUERY_LOG[-100:]
    ]

    choices = [str(r["error_type"]) for r in rows]
    return (
        pd.DataFrame(rows),
        pd.DataFrame(recent_rows),
        gr.update(choices=choices, value=choices[0] if choices else None),
    )
def task2_error_examples(error_type: str) -> str:
    """Return up to three recent logged examples of *error_type*, plus its hint."""
    if not error_type:
        return ""

    hint = get_hint(error_type)
    wanted = str(error_type)

    # Walk the log newest-first and keep the first three matching records.
    matches = []
    for rec in reversed(_QUERY_LOG):
        if (rec.get("error_type") or "") == wanted:
            matches.append(rec)
            if len(matches) == 3:
                break

    if not matches:
        return f"Error type: {error_type}\nHint: {hint}\n\nNo examples yet."

    parts = [f"Error type: {error_type}", f"Hint: {hint}", ""]
    for idx, rec in enumerate(matches, 1):
        parts.append(f"Example {idx}")
        parts.append(f"DB: {rec.get('db_id','')}")
        parts.append(f"Q: {rec.get('question','')}")
        parts.append(f"SQL: {rec.get('sql','')}")
        parts.append(f"Msg: {rec.get('error_msg','')}")
        parts.append("")
    return "\n".join(parts).strip()
+ def _plot_op_stats_html() -> str:
405
+ try:
406
+ import matplotlib.pyplot as plt
407
+ labels = list(_OP_STATS.keys())
408
+ oks = [int(_OP_STATS[k]["ok"]) for k in labels]
409
+ fails = [int(_OP_STATS[k]["fail"]) for k in labels]
410
+
411
+ fig, ax = plt.subplots(figsize=(9, 3.5))
412
+ x = list(range(len(labels)))
413
+ ax.bar(x, oks, label="ok", color="#16a34a")
414
+ ax.bar(x, fails, bottom=oks, label="fail", color="#dc2626")
415
+ ax.set_xticks(x)
416
+ ax.set_xticklabels(labels, rotation=30, ha="right")
417
+ ax.set_title("Success/Failure by SQL operation")
418
+ ax.legend()
419
+ fig.tight_layout()
420
+
421
+ buf = io.BytesIO()
422
+ fig.savefig(buf, format="png", dpi=160)
423
+ plt.close(fig)
424
+ b64 = base64.b64encode(buf.getvalue()).decode("ascii")
425
+ return f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
426
+ except Exception as e: return f"<pre>Plot error: {e}</pre>"
427
+
428
def task2_ops_table():
    """Summarize per-operation success/failure counters.

    Returns (DataFrame with ok/fail/total/success_rate per op, plot HTML).
    """
    def _row(op, counters):
        ok = int(counters.get("ok", 0))
        fail = int(counters.get("fail", 0))
        total = ok + fail
        return {
            "op": op,
            "ok": ok,
            "fail": fail,
            "total": total,
            # Guard against division by zero for ops never attempted yet.
            "success_rate": ok / total if total else 0.0,
        }

    summary = [_row(op, counters) for op, counters in _OP_STATS.items()]
    return pd.DataFrame(summary), _plot_op_stats_html()
def toggle_input_method(method, current_sample):
    """Switch the input widgets between sample-picker and free-text modes.

    Returns updates for (sample_dropdown, type_own_warning, custom_question, db_id).
    """
    if method == "💡 Pick a Sample":
        # Sample mode: lock the DB dropdown to the sample's own database.
        matched_db = "chinook_1"
        for question, db in SAMPLES:
            if question == current_sample:
                matched_db = db
                break
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value=matched_db, interactive=False),
        )
    # Free-text mode: show the question box and let the user pick any DB.
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(interactive=True),
    )
def load_sample(selected_question):
    """Resolve the database that belongs to the chosen sample question."""
    if not selected_question:
        return gr.update()  # no selection — leave the dropdown untouched
    target_db = "chinook_1"  # fallback when the question isn't in SAMPLES
    for question, db in SAMPLES:
        if question == selected_question:
            target_db = db
            break
    return gr.update(value=target_db)
def clear_inputs():
    """Reset the inference tab back to its initial sample-picker state."""
    return (
        gr.update(value="💡 Pick a Sample"),                  # input-method radio
        gr.update(value=SAMPLE_QUESTIONS[0], visible=True),   # sample dropdown
        gr.update(visible=False),                             # "type my own" warning
        gr.update(value="", visible=False),                   # custom question box
        gr.update(value="chinook_1", interactive=False),      # database dropdown
        "",                                                   # generated SQL
        pd.DataFrame(),                                       # result table
        "",                                                   # explanation
    )
def update_schema(db_id):
    """Render the database schema as a scrollable HTML panel.

    Lines of the form "table(col, ...)" are emphasized (table name bold/upper,
    columns muted/lower); anything else is shown as plain secondary text.
    Returns "" when no DB is selected or no schema encoder is available.
    """
    if not db_id or schema_encoder is None:
        return ""
    try:
        raw_schema = schema_encoder.structured_schema(db_id)
        table_pat = re.compile(r'^([a-zA-Z0-9_]+)\s*\((.*)\)')
        parts = ["<div style='max-height: 250px; overflow-y: auto; background: #f8fafc; padding: 12px; border-radius: 8px; border: 1px solid #e2e8f0; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; font-size: 0.9em; line-height: 1.6;'>"]
        for raw_line in raw_schema.strip().split('\n'):
            entry = raw_line.strip()
            if not entry:
                continue
            m = table_pat.search(entry)
            if m:
                parts.append(f"<div style='margin-bottom: 8px;'><strong style='color: #0f172a; font-size: 1.05em; font-weight: 800;'>{m.group(1).upper()}</strong> <span style='color: #64748b;'>( {m.group(2).lower()} )</span></div>")
            else:
                parts.append(f"<div style='color: #475569;'>{entry}</div>")
        parts.append("</div>")
        return "".join(parts)
    except Exception as e:
        return f"<div style='color: red;'>Error loading schema: {str(e)}</div>"
# =========================
# UI LAYOUT
# =========================
with gr.Blocks(title="Text-to-SQL RLHF") as demo:
    # Page banner.
    gr.HTML("""
    <div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">
    <h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>
    <p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>
    </div>
    """)

    # Databases available for querying (Spider-style SQLite DBs).
    DBS = sorted([
        "flight_1", "student_assessment", "store_1", "bike_1", "book_2",
        "chinook_1", "academic", "aircraft", "car_1", "cinema", "club_1",
        "csu_1", "college_1", "college_2", "company_1", "company_employee",
        "customer_complaints", "department_store", "employee_hire_evaluation",
        "museum_visit", "products_for_hire", "restaurant_1", "school_finance",
        "shop_membership", "small_bank_1", "student_1", "tvshow", "voter_1",
        "world_1",
    ])

    with gr.Tabs():
        # ---- Tab 1: interactive inference ----
        with gr.Tab("Inference"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 1. Configuration & Input")
                    input_method = gr.Radio(
                        choices=["💡 Pick a Sample", "✍️ Type my own"],
                        value="💡 Pick a Sample",
                        label="How do you want to ask?",
                    )
                    sample_dropdown = gr.Dropdown(
                        choices=SAMPLE_QUESTIONS,
                        value=SAMPLE_QUESTIONS[0],
                        label="Select a Sample Question",
                        info="The database will be selected automatically.",
                        visible=True,
                    )
                    type_own_warning = gr.Markdown(
                        "**⚠️ Please select a Database first, then type your custom question below:**",
                        visible=False,
                    )
                    gr.Markdown("---")
                    db_id = gr.Dropdown(choices=DBS, value="chinook_1", label="Select Database", interactive=False)
                    custom_question = gr.Textbox(
                        label="Ask your Custom Question",
                        placeholder="Type your own question here...",
                        lines=3,
                        visible=False,
                    )

                    gr.Markdown("#### 📋 Database Structure")
                    gr.HTML("<p style='font-size: 0.85em; color: #64748b; margin-top: -10px; margin-bottom: 5px;'>Use these exact names! Table names are <strong>Dark</strong>, Column names are <span style='color: #94a3b8;'>Light</span>.</p>")
                    schema_display = gr.HTML(value=update_schema("chinook_1"))

                    with gr.Row():
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                        run_btn = gr.Button(" Generate & Run SQL", variant="primary")

                with gr.Column(scale=2):
                    gr.Markdown("### 2. Execution Results")
                    final_sql = gr.Code(language="sql", label="Final Executed SQL")
                    result_table = gr.Dataframe(label="Query Result Table", interactive=False, wrap=True)
                    explanation = gr.Textbox(label="AI Explanation + Execution Details", lines=8)

        # ---- Tab 2: diagnostics / telemetry ----
        with gr.Tab("Diagnostics"):
            gr.Markdown("## Diagnostics & Telemetry")

            with gr.Accordion("Task 1: Parallel Reward Benchmark", open=False):
                gr.Markdown("*(Simulates the heavy RLHF training workload by running hundreds of complex SQL queries concurrently to test SQLite multi-threading performance.)*")
                t1_n = gr.Number(value=20, precision=0, label="Rollouts (n)")
                t1_workers = gr.Number(value=10, precision=0, label="Max workers")
                t1_run = gr.Button("Run Task 1 benchmark")
                t1_out = gr.Textbox(label="Output", lines=12)
                t1_plot = gr.HTML(label="Plot (if generated)")
                t1_run.click(fn=task1_benchmark, inputs=[t1_n, t1_workers], outputs=[t1_out, t1_plot])

            with gr.Accordion("Task 2: Error Dashboard", open=True):
                gr.Markdown("*(Live telemetry tracking the most common SQL failures. Populates automatically when queries fail in the Inference tab.)*")
                t2_refresh = gr.Button("Refresh dashboard")
                t2_counts = gr.Dataframe(label="Error counts", interactive=False, wrap=True)
                t2_recent = gr.Dataframe(label="Recent errors", interactive=False, wrap=True)
                t2_type = gr.Dropdown(choices=[], value=None, label="Select error type")
                t2_examples = gr.Textbox(label="Examples + hint", lines=10)

                t2_refresh.click(fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type])
                t2_type.change(fn=task2_error_examples, inputs=[t2_type], outputs=[t2_examples])

            with gr.Accordion("Task 2: Clause Telemetry", open=False):
                gr.Markdown("*(Analyzes which specific SQL clauses—SELECT, WHERE, JOIN, etc.—are most prone to errors during natural language generation.)*")
                t2_ops_refresh = gr.Button("Refresh SQL-op stats")
                t2_ops_tbl = gr.Dataframe(label="Success/failure by op", interactive=False, wrap=True)
                t2_ops_plot = gr.HTML(label="Op plot")
                t2_ops_refresh.click(fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot])

    # EVENT BINDING: The .then() forces the diagnostic tab to update live in the background!
    input_method.change(fn=toggle_input_method, inputs=[input_method, sample_dropdown], outputs=[sample_dropdown, type_own_warning, custom_question, db_id])
    sample_dropdown.change(fn=load_sample, inputs=[sample_dropdown], outputs=[db_id])
    db_id.change(fn=update_schema, inputs=[db_id], outputs=[schema_display])

    run_btn.click(
        fn=run_query,
        inputs=[input_method, sample_dropdown, custom_question, db_id],
        outputs=[final_sql, result_table, explanation]
    ).then(
        fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type]
    ).then(
        fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot]
    )

    clear_btn.click(fn=clear_inputs, inputs=[], outputs=[input_method, sample_dropdown, type_own_warning, custom_question, db_id, final_sql, result_table, explanation])
 
551
if __name__ == "__main__":
    # Bind address/port are overridable via Gradio's standard environment
    # variables; defaults match Hugging Face Spaces (0.0.0.0:7860).
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    try:
        server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    except ValueError:
        server_port = 7860  # ignore a malformed override rather than crash

    print(f"Starting Gradio UI on {server_name}:{server_port}...", flush=True)
    try:
        # ssr_mode=False — presumably to avoid SSR startup issues on Spaces;
        # TODO(review): confirm this is still needed for the pinned Gradio.
        demo.launch(server_name=server_name, server_port=server_port, ssr_mode=False)
    except TypeError:
        # Older Gradio versions don't accept ssr_mode — retry without it.
        demo.launch(server_name=server_name, server_port=server_port)
 
 
 
 
 
 
 
 
 
int8_dynamic/{merges.txt → tokenizer/merges.txt} RENAMED
File without changes
int8_dynamic/{special_tokens_map.json → tokenizer/special_tokens_map.json} RENAMED
File without changes
int8_dynamic/{tokenizer.json → tokenizer/tokenizer.json} RENAMED
File without changes
int8_dynamic/{tokenizer_config.json → tokenizer/tokenizer_config.json} RENAMED
File without changes
int8_dynamic/{vocab.json → tokenizer/vocab.json} RENAMED
File without changes
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  gradio==5.8.0
2
- streamlit
3
  pandas
4
  sqlparse
5
  transformers
@@ -8,5 +7,4 @@ peft
8
  trl
9
  sentencepiece
10
  matplotlib
11
- huggingface_hub
12
-
 
1
  gradio==5.8.0
 
2
  pandas
3
  sqlparse
4
  transformers
 
7
  trl
8
  sentencepiece
9
  matplotlib
10
+ huggingface_hub