tjhalanigrid commited on
Commit
30e149a
·
0 Parent(s):

clean repo without LFS and binaries

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +36 -0
  2. .gitignore +12 -0
  3. README.md +14 -0
  4. app.py +651 -0
  5. final_databases/academic.sqlite +3 -0
  6. final_databases/activity_1.sqlite +3 -0
  7. final_databases/aircraft.sqlite +3 -0
  8. final_databases/allergy_1.sqlite +3 -0
  9. final_databases/apartment_rentals.sqlite +3 -0
  10. final_databases/architecture.sqlite +3 -0
  11. final_databases/assets_maintenance.sqlite +3 -0
  12. final_databases/battle_death.sqlite +3 -0
  13. final_databases/behavior_monitoring.sqlite +3 -0
  14. final_databases/bike_1.sqlite +3 -0
  15. final_databases/body_builder.sqlite +3 -0
  16. final_databases/book_2.sqlite +3 -0
  17. final_databases/browser_web.sqlite +3 -0
  18. final_databases/candidate_poll.sqlite +3 -0
  19. final_databases/car_1.sqlite +3 -0
  20. final_databases/chinook_1.sqlite +3 -0
  21. final_databases/cinema.sqlite +3 -0
  22. final_databases/city_record.sqlite +3 -0
  23. final_databases/climbing.sqlite +3 -0
  24. final_databases/club_1.sqlite +3 -0
  25. final_databases/coffee_shop.sqlite +3 -0
  26. final_databases/college_1.sqlite +3 -0
  27. final_databases/college_2.sqlite +3 -0
  28. final_databases/college_3.sqlite +3 -0
  29. final_databases/company_1.sqlite +3 -0
  30. final_databases/company_employee.sqlite +3 -0
  31. final_databases/company_office.sqlite +3 -0
  32. final_databases/concert_singer.sqlite +3 -0
  33. final_databases/county_public_safety.sqlite +3 -0
  34. final_databases/course_teach.sqlite +3 -0
  35. final_databases/cre_Doc_Control_Systems.sqlite +3 -0
  36. final_databases/cre_Doc_Template_Mgt.sqlite +3 -0
  37. final_databases/cre_Doc_Tracking_DB.sqlite +3 -0
  38. final_databases/cre_Docs_and_Epenses.sqlite +3 -0
  39. final_databases/cre_Drama_Workshop_Groups.sqlite +3 -0
  40. final_databases/cre_Theme_park.sqlite +3 -0
  41. final_databases/csu_1.sqlite +3 -0
  42. final_databases/culture_company.sqlite +3 -0
  43. final_databases/customer_complaints.sqlite +3 -0
  44. final_databases/customer_deliveries.sqlite +3 -0
  45. final_databases/customers_and_addresses.sqlite +3 -0
  46. final_databases/customers_and_invoices.sqlite +3 -0
  47. final_databases/customers_and_products_contacts.sqlite +3 -0
  48. final_databases/customers_campaigns_ecommerce.sqlite +3 -0
  49. final_databases/customers_card_transactions.sqlite +3 -0
  50. final_databases/debate.sqlite +3 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.sqlite filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ .DS_Store
4
+ # checkpoints/milestone_before_more_dbs
5
+ checkpoints/best_rlhf_codet5_soft
6
+ checkpoints/best_rlhf_model
7
+ results/
8
+ *.png
9
+ *.pt
10
+ *.bin
11
+ *.safetensors
12
+ *.zip
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Text2sql Demo
3
+ emoji: 🐨
4
+ colorFrom: yellow
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 6.8.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: 'to show the gradio interface '
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,651 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GRADIO DEMO UI
3
+ NL → SQL → Result Table
4
+ """
5
+
6
+ import gradio as gr
7
+ import pandas as pd
8
+ import re
9
+ import time
10
+ import os
11
+ import torch
12
+ import sys
13
+ import json
14
+ import subprocess
15
+ import base64
16
+ from pathlib import Path
17
+ from typing import Iterator
18
+ import io
19
+
20
+ import zipfile
21
+
22
MODEL_DIR = "int8_dynamic"  # directory holding the int8-quantized model artifacts

# One-time bootstrap: if the quantized model directory is absent, unpack it
# from the bundled zip archive; fail fast when neither is available.
if not os.path.exists(MODEL_DIR):
    if os.path.exists("int8_model.zip"):
        print("Extracting model...")
        with zipfile.ZipFile("int8_model.zip", 'r') as zip_ref:
            zip_ref.extractall(".")
    else:
        raise FileNotFoundError("Model zip not found!")
31
# ==========================================
# 🔥 CUDA MOCK PATCH FOR MAC (MPS) / CPU
# ==========================================
# When CUDA is unavailable, downstream code that times generation with
# torch.cuda.Event would fail; patch in wall-clock stand-ins.
if not torch.cuda.is_available():
    class MockCUDAEvent:
        # Drop-in stand-in for torch.cuda.Event: records time.perf_counter()
        # timestamps instead of GPU events.
        def __init__(self, enable_timing=False, blocking=False, interprocess=False):
            self.t = 0.0
        def record(self, stream=None):
            # Capture current wall-clock time in seconds.
            self.t = time.perf_counter()
        def elapsed_time(self, end_event):
            # Milliseconds between this record() and end_event's, matching
            # the CUDA Event API's unit.
            return (end_event.t - self.t) * 1000.0

    torch.cuda.Event = MockCUDAEvent
    # NOTE(review): torch.cuda normally always exposes `synchronize`, so this
    # hasattr guard likely never installs the no-op — confirm intended.
    if not hasattr(torch.cuda, 'synchronize'):
        torch.cuda.synchronize = lambda: None
46
+
47
# ==========================================
# RELATIVE PATH RESOLUTION (GLOBAL)
# ==========================================
try:
    # Normal case: resolve paths relative to this file's directory.
    PROJECT_ROOT = Path(__file__).resolve().parent
except NameError:
    # __file__ is undefined in some embedded/interactive contexts.
    PROJECT_ROOT = Path(".").resolve()

# Prefer the nested data/database layout when present; otherwise fall back
# to the flat final_databases/ directory shipped with this repo.
if (PROJECT_ROOT / "data" / "database").exists():
    DB_ROOT = PROJECT_ROOT / "data" / "database"
else:
    DB_ROOT = PROJECT_ROOT / "final_databases"
59
+
60
def get_db_path(db_id: str) -> str:
    """Resolve the sqlite file path for `db_id`.

    Prefers the nested layout DB_ROOT/<db_id>/<db_id>.sqlite; when that file
    does not exist, falls back to the flat DB_ROOT/<db_id>.sqlite path
    (returned even if it does not exist either).
    """
    nested = DB_ROOT / db_id / f"{db_id}.sqlite"
    if nested.exists():
        return str(nested)
    return str(DB_ROOT / f"{db_id}.sqlite")
64
+
65
# ==========================================
# IMPORTS & ENGINE SETUP
# ==========================================
from src.quantized_text2sql_engine import QuantizedText2SQLEngine
from src.schema_encoder import SchemaEncoder

# Pick the adapter directory: prefer the RLHF checkpoint, fall back to the
# SFT adapter when the former is missing.
fallback_adapter = str(PROJECT_ROOT / "best_rlhf_model_2")
if not os.path.exists(fallback_adapter):
    fallback_adapter = str(PROJECT_ROOT / "sft_adapter_codet5")

# Environment overrides for adapter / base model / LoRA usage.
adapter_path = os.environ.get("TEXT2SQL_ADAPTER_PATH", fallback_adapter)
base_model_name = os.environ.get("TEXT2SQL_BASE_MODEL", "Salesforce/codet5-base")
use_lora_env = os.environ.get("TEXT2SQL_USE_LORA", "true").strip().lower()
use_lora = use_lora_env not in {"0", "false", "no"}

# Quantized-artifact dir; an env var set to whitespace falls back to default.
DEFAULT_QUANT_ARTIFACT = os.environ.get("TEXT2SQL_QUANT_ARTIFACT", str(PROJECT_ROOT / "int8_dynamic")).strip()
if not DEFAULT_QUANT_ARTIFACT:
    DEFAULT_QUANT_ARTIFACT = str(PROJECT_ROOT / "int8_dynamic")

# In-process state shared across Gradio callbacks (not persisted anywhere).
_ENGINE_CACHE = {}   # (artifact_dir, constrained, workers, cache) -> engine
_QUERY_LOG = []      # failed/blocked query records for the error dashboard
_PERF_LOG = []       # per-request latency/constraint metrics
_SUCCESS_LOG = []    # successful query records

# Per-SQL-operation success/failure tallies for the ops dashboard.
_OP_STATS = {
    "SELECT": {"ok": 0, "fail": 0}, "WHERE": {"ok": 0, "fail": 0}, "JOIN": {"ok": 0, "fail": 0},
    "GROUP_BY": {"ok": 0, "fail": 0}, "ORDER_BY": {"ok": 0, "fail": 0}, "HAVING": {"ok": 0, "fail": 0}, "LIMIT": {"ok": 0, "fail": 0},
}
93
+
94
def get_quant_engine(artifact_dir: str, use_constrained: bool, exec_workers: int = 8, use_cache: bool = True):
    """Return a memoized QuantizedText2SQLEngine for the given settings.

    Engines are cached per (artifact_dir, constrained, workers, cache) key.
    Older engine builds that do not accept the keyword arguments are handled
    by retrying with the artifact directory alone.
    """
    cache_key = (artifact_dir, bool(use_constrained), int(exec_workers), bool(use_cache))
    if cache_key not in _ENGINE_CACHE:
        try:
            engine = QuantizedText2SQLEngine(
                artifact_dir,
                device="cpu",
                use_constrained=bool(use_constrained),
                exec_workers=int(exec_workers),
                use_cache=bool(use_cache),
            )
        except TypeError:
            # Older engine signature: positional artifact dir only.
            engine = QuantizedText2SQLEngine(artifact_dir)
        _ENGINE_CACHE[cache_key] = engine
    return _ENGINE_CACHE[cache_key]
102
+
103
# Eagerly build the default engine; keep the module importable if this fails
# (run_query raises a friendly error when quant_engine is None).
try:
    quant_engine = get_quant_engine(DEFAULT_QUANT_ARTIFACT, use_constrained=False, exec_workers=8, use_cache=True)
except Exception:
    quant_engine = None

schema_encoder = SchemaEncoder(DB_ROOT)

# (question, db_id) pairs offered in the "Pick a Sample" dropdown.
SAMPLES = [
    ("Show 10 distinct employee first names.", "chinook_1"), ("Which artist has the most albums?", "chinook_1"),
    ("List all the tracks that belong to the 'Rock' genre.", "chinook_1"), ("What are the names of all the cities?", "flight_1"),
    ("Find the flight number and cost of the cheapest flight.", "flight_1"), ("List the airlines that fly out of New York.", "flight_1"),
    ("Which campus was opened between 1935 and 1939?", "csu_1"), ("Count the number of students in each department.", "college_2"),
    ("List the names of all clubs.", "club_1"), ("How many members does each club have?", "club_1"),
    ("Show the names of all cinemas.", "cinema"), ("Which cinema has the most screens?", "cinema")
]
SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
119
+
120
def explain_sql(sql):
    """Build a short plain-English summary of the clauses present in `sql`.

    Returns "" for falsy input. Detection is a case-insensitive substring
    scan, not a SQL parse.
    """
    if not sql:
        return ""
    lowered = sql.lower()
    clause_notes = [
        ("join", "\n• It combines data from multiple tables using JOIN."),
        ("where", "\n• It filters rows using a WHERE condition."),
        ("group by", "\n• It groups results using GROUP BY."),
        ("order by", "\n• It sorts the results using ORDER BY."),
        ("limit", "\n• It limits the number of returned rows."),
    ]
    parts = ["This SQL query retrieves information from the database."]
    parts.extend(note for keyword, note in clause_notes if keyword in lowered)
    return "".join(parts)
130
+
131
def sql_ops(sql: str) -> list[str]:
    """List the SQL operations present in `sql` (always includes SELECT).

    Keywords are matched as space-delimited substrings of the lowercased,
    space-padded query; labels mirror the keys of _OP_STATS.
    """
    padded = f" {(sql or '').lower()} "
    found = ["SELECT"]
    for keyword, label in (
        (" where ", "WHERE"),
        (" join ", "JOIN"),
        (" group by ", "GROUP_BY"),
        (" order by ", "ORDER_BY"),
        (" having ", "HAVING"),
        (" limit ", "LIMIT"),
    ):
        if keyword in padded:
            found.append(label)
    return found
141
+
142
+ def classify_error(sql: str, error_msg: str | None = None, *, timed_out: bool = False):
143
+ s = (sql or "").lower()
144
+ m = (error_msg or "").lower()
145
+ if timed_out or "interrupted" in m or "timeout" in m: return "timeout"
146
+ if not s.strip().startswith(("select", "with")): return "syntax_error"
147
+ if " join " in f" {s} " and " on " not in f" {s} ": return "missing_join"
148
+ if " where " in f" {s} " and not any(op in s for op in ["=", ">", "<", " in ", " like ", " between ", " is null", " is not null"]): return "wrong_where"
149
+ if ("is null" in s or "is not null" in s) and ("no such column" in m or "misuse" in m): return "null_handling"
150
+ if "no such table" in m: return "missing_table"
151
+ if "no such column" in m: return "missing_column"
152
+ if "ambiguous column name" in m: return "ambiguous_column"
153
+ if "datatype mismatch" in m or "type mismatch" in m: return "type_mismatch"
154
+ if "misuse of aggregate" in m or "misuse of aggregate function" in m: return "wrong_aggregation"
155
+ if "syntax error" in m: return "syntax_error"
156
+ if "near" in m and "syntax error" in m: return "syntax_error"
157
+ if "runtime" in m or "constraint failed" in m: return "runtime_error"
158
+ return "other"
159
+
160
def get_hint(error_type):
    """Return a one-line remediation hint for a classify_error() label.

    Unknown labels fall back to the generic "Review query." hint.
    """
    hints = {
        "missing_join": "Check JOIN conditions between tables.", "wrong_aggregation": "Use proper aggregation like avg(column).",
        "wrong_where": "Check WHERE condition syntax.", "syntax_error": "Ensure SQL starts with SELECT.",
        "missing_table": "Use only tables from the provided schema.", "missing_column": "Use only columns from the provided schema.",
        "ambiguous_column": "Disambiguate by using table.column.", "timeout": "Query took too long; simplify joins.", "other": "Review SQL logic.",
        # Labels classify_error() emits that previously fell through to the default:
        "null_handling": "Use IS NULL / IS NOT NULL for missing values.",
        "type_mismatch": "Compare columns against values of the matching type.",
        "runtime_error": "Check constraints and data values involved in the query.",
    }
    return hints.get(error_type, "Review query.")
168
+
169
def is_relevant_to_schema(question, db_id):
    """Cheap lexical relevance gate: does `question` share any meaningful
    word (or its naive singular form) with the database schema text?

    Fails open: if the schema cannot be loaded, or the question contains only
    stop words/digits, the question is treated as relevant so the model still
    gets a chance to answer.
    """
    try:
        raw_schema = schema_encoder.structured_schema(db_id).lower()
    except Exception:  # was a bare `except:`; don't swallow SystemExit/KeyboardInterrupt
        return True
    schema_words = set(re.findall(r'[a-z0-9_]+', raw_schema))
    q_words = re.findall(r'[a-z0-9_]+', question.lower())
    stop_words = {"show", "list", "all", "what", "is", "the", "how", "many", "count", "find", "get", "me", "a", "an", "of", "in", "for", "from", "with", "which", "are", "there", "give", "tell", "details", "info", "data", "everything"}
    meaningful_q_words = [w for w in q_words if w not in stop_words and not w.isdigit()]
    if not meaningful_q_words: return True
    for word in meaningful_q_words:
        # Naive plural -> singular: strip one trailing 's' (irregular plurals
        # like "cities" will not match "city").
        singular_word = word[:-1] if word.endswith('s') else word
        if word in schema_words or singular_word in schema_words: return True
    return False
181
+
182
def run_query(method, sample_q, custom_q, db_id):
    """Main Gradio callback: natural-language question -> SQL -> results.

    Returns a 3-tuple (sql_text, results_dataframe, explanation_text).
    Stages: input guards (empty / gibberish / DML / out-of-domain) -> model
    generation -> heuristic SQL post-processing -> sqlite execution with a
    fallback to the unmodified model SQL -> logging and a perf summary.
    """
    # HARDCODED INFERENCE DEFAULTS
    quant_artifact_dir = DEFAULT_QUANT_ARTIFACT
    use_constrained_decoding = False
    gen_beams = 4
    gen_max_new_tokens = 120
    exec_timeout_s = 2.0
    exec_workers = 8
    exec_cache_on = True

    def _log(error_type: str, *, question: str, db_id_val: str, sql: str = "", error_msg: str = "") -> None:
        # Append a failure/block record for the Task-2 error dashboard.
        _QUERY_LOG.append({"t": time.time(), "db_id": str(db_id_val), "question": str(question), "sql": str(sql), "error_type": str(error_type), "error_msg": str(error_msg)})

    def _perf_log(payload: dict) -> None:
        # Cap the perf log: drop the oldest 200 entries once past 1000.
        _PERF_LOG.append(payload)
        if len(_PERF_LOG) > 1000: del _PERF_LOG[:200]

    raw_question = sample_q if method == "💡 Pick a Sample" else custom_q

    # 1. EMPTY CHECK
    if not raw_question or str(raw_question).strip() == "":
        return "-- No input provided", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a question."
    if not db_id or str(db_id).strip() == "":
        return "-- No database selected", pd.DataFrame(columns=["Warning"]), "⚠️ Please select a database."

    # Fix a few common keyboard typos before any downstream matching.
    typo_corrections = [(r'\bshaw\b', 'show'), (r'\bshw\b', 'show'), (r'\bsho\b', 'show'), (r'\blsit\b', 'list'), (r'\blis\b', 'list'), (r'\bfidn\b', 'find'), (r'\bfnd\b', 'find'), (r'\bgte\b', 'get')]
    question = str(raw_question)
    for bad, good in typo_corrections: question = re.sub(bad, good, question, flags=re.IGNORECASE)
    q_lower = question.strip().lower()

    # 🔥 FIX 1: STRICTER GIBBERISH FILTER
    # Blocks single-word nonsensical inputs like "wdasefgbn"
    if len(q_lower.split()) < 2:
        _log("gibberish", question=question, db_id_val=str(db_id), error_msg="gibberish filtered")
        return "-- Input Blocked", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a clear, meaningful natural language question (more than one word)."

    # 3. DML (DELETE/UPDATE) BLOCKER
    if re.search(r'\b(delete|update|insert|drop|alter|truncate)\b', q_lower):
        _log("blocked_dml", question=question, db_id_val=str(db_id), error_msg="DML blocked")
        return "-- ❌ BLOCKED: Data Modification", pd.DataFrame(columns=["Security Alert"]), "🛑 Security Alert: Modifying or deleting data is strictly prohibited."

    # 4. OUT OF DOMAIN
    if not is_relevant_to_schema(question, db_id):
        _log("out_of_domain", question=question, db_id_val=str(db_id), error_msg="out of domain")
        return "-- ❌ BLOCKED: Out of Domain", pd.DataFrame(columns=["Domain Alert"]), f"🛑 Relevance Alert: I don't see anything related to your question in the '{db_id}' schema."

    start_time = time.time()
    t0 = time.perf_counter()
    ui_warnings = ""

    try:
        engine = quant_engine
        if quant_artifact_dir and str(quant_artifact_dir).strip():
            engine = get_quant_engine(str(quant_artifact_dir).strip(), bool(use_constrained_decoding), exec_workers=int(exec_workers), use_cache=bool(exec_cache_on))
        if engine is None: raise RuntimeError("Quantized engine is not available.")
        try:
            result = engine.ask(question, str(db_id), num_beams=int(gen_beams), max_new_tokens=int(gen_max_new_tokens), timeout_s=float(exec_timeout_s))
        except TypeError:
            # Older engine signature without generation kwargs.
            result = engine.ask(question, str(db_id))

    except Exception as e:
        _log("backend_crash", question=question, db_id_val=str(db_id), error_msg=str(e))
        return f"-- ❌ BACKEND CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"❌ CRITICAL BACKEND CRASH:\n{str(e)}"

    final_sql = str(result.get("sql", ""))
    model_sql = final_sql  # keep the untouched model output as a fallback

    # 🔥 FIX 2: ADVANCED SEMANTIC LIMIT CLEANER
    num_match = re.search(r'\b(?:show|list|top|limit|get|first|last|sample|of)\s+(?:[a-zA-Z_]+\s+)?(\d+)\b', q_lower)
    if not num_match and q_lower.startswith(("show", "list", "get")):
        num_match = re.search(r'\b(\d+)\b', q_lower)

    if num_match and final_sql:
        limit_val = num_match.group(1)

        # 1. Strip hallucinated count(*) = X in WHERE clauses
        final_sql = re.sub(rf"(?i)\s*(?:where|having|and)?\s*count\s*\(\s*\*\s*\)\s*=\s*{limit_val}", "", final_sql)

        # 2. Strip hallucinated WHERE column = '5' (Fixes "show firstname of 5 employees")
        final_sql = re.sub(rf"(?i)\s*(?:where|and)\s+[a-zA-Z0-9_.]+\s*=\s*['\"]?{limit_val}['\"]?", "", final_sql)
        final_sql = re.sub(r"(?i)\s*where\s*$", "", final_sql) # Cleanup empty where
        final_sql = re.sub(r"(?i)\s*where\s+(group by|order by|limit)", r" \1", final_sql) # Cleanup dangling where before other clauses

        # 3. Strip unwarranted complex Groupings
        agg_kws = ["most", "top", "highest", "lowest", "count", "many", "group", "frequent", "popular"]
        if not any(k in q_lower for k in agg_kws):
            final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
            final_sql = re.sub(r"(?i)\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
            final_sql = re.sub(r"(?i),\s*count\(\*\)", "", final_sql)
            final_sql = re.sub(r"(?i)count\(\*\)\s*,", "", final_sql)

        # Drop any GROUP BY left without an aggregate function.
        if "group by" in final_sql.lower() and not re.search(r'(?i)\b(count|sum|avg|max|min)\b\(', final_sql):
            final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+", "", final_sql)

        # 4. Append limits safely
        if "limit" not in final_sql.lower():
            final_sql = f"{final_sql.strip().rstrip(';')} LIMIT {limit_val}"

    # =====================================================================
    # 🔥 ROBUST SQLITE EXECUTION
    # =====================================================================
    from src.sql_validator import validate_sql_schema
    db_path = get_db_path(str(db_id))

    # Strict schema validation is advisory only (surfaced in the perf block).
    try: strict_valid, _ = validate_sql_schema(final_sql, db_path)
    except Exception: strict_valid = False

    error_msg = None
    rows, cols = [], []
    sqlite_success = False

    if final_sql and engine is not None:
        try:
            rows, cols = engine._execute_one(final_sql, db_path, timeout_s=float(exec_timeout_s))
            sqlite_success = True
        except Exception as e:
            error_msg = str(e)
            sqlite_success = False

    # If the post-processed SQL failed, retry the raw model SQL unchanged.
    if not sqlite_success and model_sql and model_sql != final_sql and engine is not None:
        try:
            alt_rows, alt_cols = engine._execute_one(model_sql, db_path, timeout_s=float(exec_timeout_s))
            final_sql = model_sql
            rows, cols = alt_rows, alt_cols
            sqlite_success = True
            error_msg = None
        except Exception: pass

    valid = sqlite_success

    if error_msg or not valid:
        et = classify_error(final_sql, str(error_msg or ""), timed_out=("interrupted" in str(error_msg or "").lower()))
        _log(et, question=str(question), db_id_val=str(db_id), sql=str(final_sql), error_msg=str(error_msg or "Execution failed"))

    latency = round(time.time() - start_time, 3)
    t1 = time.perf_counter()

    engine_stats_after = engine.stats() if hasattr(engine, 'stats') else {}

    perf = {
        "db_id": str(db_id), "use_constrained_decoding": bool(use_constrained_decoding), "num_beams": int(gen_beams),
        "latency_total_ms": round((t1 - t0) * 1000.0, 2), "constraint_ok": bool(strict_valid), "has_error": bool(error_msg),
        "exec_cache_hit_rate": float(engine_stats_after.get("exec_cache_hit_rate", 0.0) or 0.0),
    }
    _perf_log(perf)

    # Rolling stats over the last 50 requests for the perf footer.
    window = _PERF_LOG[-50:]
    avg_ms = sum(float(x.get("latency_total_ms", 0.0) or 0.0) for x in window) / len(window) if window else 0.0
    constraint_rate = sum(1 for x in window if x.get("constraint_ok")) / len(window) if window else 0.0

    perf_block = (
        "\n\n---\nPerformance (task impact)\n"
        f"- Total latency (ms): {perf['latency_total_ms']}\n"
        f"- Strict Python Validator OK (Task 3): {perf['constraint_ok']}\n"
        f"- Exec cache hit-rate (Task 1/5): {round(perf['exec_cache_hit_rate'], 3)}\n"
        f"- Rolling avg latency last 50 (ms): {round(avg_ms, 2)}\n"
        f"- Rolling constraint rate last 50: {round(constraint_rate, 3)}\n"
    )

    if error_msg or not valid:
        display_sql = final_sql if final_sql.strip() else "-- ❌ INVALID SQL"
        explanation = f"{ui_warnings}❌ Error Details:\n\n"
        if error_msg: explanation += f"{error_msg}\n\n"

        error_type = classify_error(final_sql, str(error_msg or ""))
        explanation += f"Error Type: {error_type}\nHint: {get_hint(error_type)}"
        explanation += perf_block
        ops = sql_ops(final_sql)
        for op in ops:
            if op in _OP_STATS: _OP_STATS[op]["fail"] += 1
        return display_sql, pd.DataFrame(columns=["Execution Notice"]), explanation

    safe_cols = cols if cols else ["Result"]
    explanation = f"{ui_warnings}✅ Query executed successfully\n\nRows returned: {len(rows)}\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}{perf_block}"

    ops = sql_ops(final_sql)
    for op in ops:
        if op in _OP_STATS: _OP_STATS[op]["ok"] += 1
    _SUCCESS_LOG.append({"t": time.time(), "db_id": str(db_id), "question": question, "sql": final_sql, "ops": ops})

    limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
    if limit_match and len(rows) < int(limit_match.group(1)):
        explanation += f"\n\nℹ️ Query allowed up to {int(limit_match.group(1))} rows but only {len(rows)} matched."

    return final_sql, pd.DataFrame(rows, columns=safe_cols), explanation
367
+
368
def _run_cmd(cmd: list[str], env: dict | None = None) -> str:
    """Run `cmd` from the project root with PYTHONPATH prefixed by it.

    Inherits os.environ when `env` is falsy. Returns combined stdout plus
    ("\n" + stderr) when stderr is non-empty, stripped of surrounding
    whitespace.
    """
    root = str(PROJECT_ROOT)
    child_env = dict(env) if env else os.environ.copy()
    existing_path = child_env.get("PYTHONPATH")
    child_env["PYTHONPATH"] = f"{root}{os.pathsep}{existing_path}" if existing_path else root
    proc = subprocess.run(cmd, capture_output=True, text=True, env=child_env, cwd=root)
    combined = proc.stdout or ""
    if proc.stderr:
        combined += "\n" + proc.stderr
    return combined.strip()
375
+
376
def task1_benchmark(n_rollouts: int, max_workers: int) -> Iterator[tuple[str, str]]:
    """Stream the Task-1 reward benchmark as (log_text, plot_html) updates.

    Launches scripts/benchmark_parallel_reward.py as a subprocess, yields the
    tail of its output roughly twice per second, then yields the final output
    plus the generated plot (inlined as a base64 <img>) if one exists.
    """
    project_root = str(PROJECT_ROOT)
    env = os.environ.copy()
    env["PYTHONPATH"] = project_root + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
    # Headless matplotlib for the child; writable config dir (e.g. on Spaces).
    env.setdefault("MPLBACKEND", "Agg")
    env.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")
    try: os.makedirs(env["MPLCONFIGDIR"], exist_ok=True)
    except Exception: pass

    cmd = [sys.executable, "-u", "scripts/benchmark_parallel_reward.py", "--n", str(int(n_rollouts)), "--max-workers", str(int(max_workers)), "--skip-profile"]
    # stderr merged into stdout; line-buffered so yields track child output.
    proc = subprocess.Popen(cmd, cwd=project_root, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
    last_yield = time.perf_counter()
    lines: list[str] = []
    yield "Running Task 1 benchmark...\n", "<i>Running...</i>"

    assert proc.stdout is not None
    for line in proc.stdout:
        lines.append(line)
        now = time.perf_counter()
        # Throttle UI updates to ~2/s, showing only the last 200 lines.
        if now - last_yield >= 0.5:
            last_yield = now
            yield "".join(lines[-200:]).strip(), "<i>Running...</i>"

    proc.wait()
    out = "".join(lines).strip()

    plot_path = str(PROJECT_ROOT / "results" / "task1_plot.png")
    if os.path.exists(plot_path):
        try:
            # Inline the plot so the UI needs no static-file route.
            b64 = base64.b64encode(Path(plot_path).read_bytes()).decode("ascii")
            yield out, f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
            return
        except Exception:
            yield out, f"<pre>{plot_path}</pre>"
            return

    yield out, "<i>No plot generated</i>"
413
+
414
def task2_dashboard_structured():
    """Summarize _QUERY_LOG for the error dashboard.

    Returns (error-count DataFrame, recent-failures DataFrame, gr.update
    refreshing the error-type dropdown). All outputs are empty when nothing
    has been logged yet.
    """
    if not _QUERY_LOG:
        empty_counts = pd.DataFrame(columns=["error_type", "count", "hint"])
        empty_recent = pd.DataFrame(columns=["time", "db_id", "error_type", "question", "error_msg"])
        return empty_counts, empty_recent, gr.update(choices=[], value=None)

    # Tally error types over (at most) the last 1000 log records.
    counts = {}
    for r in _QUERY_LOG[-1000:]:
        k = r.get("error_type") or "other"
        counts[k] = counts.get(k, 0) + 1
    # Sort by descending count, then alphabetically for stable ordering.
    rows = [{"error_type": k, "count": int(v), "hint": get_hint(k)} for k, v in sorted(counts.items(), key=lambda x: (-x[1], x[0]))]
    counts_df = pd.DataFrame(rows)

    # Last 100 records rendered with a human-readable local timestamp.
    recent = []
    for r in _QUERY_LOG[-100:]:
        ts = r.get("t")
        try: ts_s = time.strftime("%H:%M:%S", time.localtime(float(ts))) if ts else ""
        except Exception: ts_s = ""
        recent.append({"time": ts_s, "db_id": r.get("db_id", ""), "error_type": r.get("error_type", ""), "question": r.get("question", ""), "error_msg": r.get("error_msg", "")})
    recent_df = pd.DataFrame(recent)

    choices = [str(x["error_type"]) for x in rows]
    default = choices[0] if choices else None
    return counts_df, recent_df, gr.update(choices=choices, value=default)
438
+
439
def task2_error_examples(error_type: str) -> str:
    """Render up to three of the most recent logged examples of `error_type`.

    Returns "" for a falsy error type, and a "No examples yet." message when
    the log holds no matching records.
    """
    if not error_type:
        return ""
    hint = get_hint(error_type)
    wanted = str(error_type)
    matches = []
    for record in reversed(_QUERY_LOG):
        if (record.get("error_type") or "") == wanted:
            matches.append(record)
            if len(matches) == 3:
                break
    if not matches:
        return f"Error type: {error_type}\nHint: {hint}\n\nNo examples yet."
    sections = [f"Error type: {error_type}", f"Hint: {hint}", ""]
    for idx, record in enumerate(matches, 1):
        sections.append(f"Example {idx}")
        sections.append(f"DB: {record.get('db_id','')}")
        sections.append(f"Q: {record.get('question','')}")
        sections.append(f"SQL: {record.get('sql','')}")
        sections.append(f"Msg: {record.get('error_msg','')}")
        sections.append("")
    return "\n".join(sections).strip()
448
+
449
def _plot_op_stats_html() -> str:
    """Render _OP_STATS as a stacked ok/fail bar chart, inlined as HTML.

    Returns a base64-embedded <img> tag, or a <pre> block describing the
    failure when matplotlib is unavailable or plotting raises.
    """
    try:
        import matplotlib.pyplot as plt
        labels = list(_OP_STATS.keys())
        oks = [int(_OP_STATS[k]["ok"]) for k in labels]
        fails = [int(_OP_STATS[k]["fail"]) for k in labels]

        fig, ax = plt.subplots(figsize=(9, 3.5))
        x = list(range(len(labels)))
        ax.bar(x, oks, label="ok", color="#16a34a")
        # Stack failures on top of successes so bar height = total attempts.
        ax.bar(x, fails, bottom=oks, label="fail", color="#dc2626")
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=30, ha="right")
        ax.set_title("Success/Failure by SQL operation")
        ax.legend()
        fig.tight_layout()

        # Serialize to an in-memory PNG and close the figure to free memory.
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=160)
        plt.close(fig)
        b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        return f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
    except Exception as e: return f"<pre>Plot error: {e}</pre>"
472
+
473
def task2_ops_table():
    """Build the per-operation stats table plus its chart HTML.

    Returns (DataFrame with ok/fail/total/success_rate per SQL operation,
    HTML string produced by _plot_op_stats_html()).
    """
    def _row(op, stats):
        # One table row per operation; success_rate is 0.0 with no attempts.
        ok_count = int(stats.get("ok", 0))
        fail_count = int(stats.get("fail", 0))
        attempts = ok_count + fail_count
        return {
            "op": op,
            "ok": ok_count,
            "fail": fail_count,
            "total": attempts,
            "success_rate": ok_count / attempts if attempts else 0.0,
        }

    table = [_row(op, stats) for op, stats in _OP_STATS.items()]
    return pd.DataFrame(table), _plot_op_stats_html()
481
+
482
def run_adversarial_suite():
    """Run a small fixed set of adversarial questions through the engine.

    Each case targets one failure mode (missing JOIN, bad WHERE, NULL
    handling, type mismatch). Returns (results DataFrame, JSON summary of
    status counts), or an "Engine unavailable." message with an empty frame.
    """
    # Same hardcoded inference defaults as run_query.
    quant_artifact_dir = DEFAULT_QUANT_ARTIFACT
    use_constrained = False
    gen_beams = 4
    gen_max_new = 120
    exec_timeout_s = 2.0
    exec_workers = 8
    exec_cache_on = True

    engine = quant_engine
    if quant_artifact_dir and str(quant_artifact_dir).strip():
        engine = get_quant_engine(str(quant_artifact_dir).strip(), bool(use_constrained), exec_workers=int(exec_workers), use_cache=bool(exec_cache_on))
    if engine is None: return pd.DataFrame(columns=["name", "db_id", "expected", "got", "status"]), "Engine unavailable."

    # Fixed probe set: each entry names the expected classify_error() label.
    ADVERSARIAL_CASES = [
        {"name": "Missing JOIN", "db_id": "chinook_1", "question": "List track names and their artist names.", "expected": "missing_join"},
        {"name": "Wrong WHERE", "db_id": "flight_1", "question": "Show flights that cost more than 500.", "expected": "wrong_where"},
        {"name": "NULL handling", "db_id": "student_1", "question": "List students with no advisor.", "expected": "null_handling"},
        {"name": "Type mismatch", "db_id": "store_1", "question": "List orders where order_id equals 'abc'.", "expected": "type_mismatch"},
    ]

    out_rows = []
    for c in ADVERSARIAL_CASES:
        dbid = c["db_id"]
        q = c["question"]
        expected = c.get("expected", "")
        try:
            # Prefer the full-signature ask(); fall back for older engines.
            try: res = engine.ask(q, dbid, num_beams=int(gen_beams), max_new_tokens=int(gen_max_new), timeout_s=float(exec_timeout_s))
            except TypeError: res = engine.ask(q, dbid)
            sql = str(res.get("sql", "") or "")
            err = res.get("error", None)
            got = classify_error(sql, str(err or "")) if err else ""
            status = "ok" if not err else "fail"
        except Exception as e:
            # NOTE(review): `e` is discarded — the crash reason is not surfaced.
            got, status = "backend_crash", "crash"
        out_rows.append({"name": c["name"], "db_id": dbid, "expected": expected, "got": got, "status": status})

    df = pd.DataFrame(out_rows)
    return df, json.dumps({"summary": df["status"].value_counts().to_dict()}, indent=2)
521
+
522
def toggle_input_method(method, current_sample):
    """Swap the input widgets between sample-pick mode and free-text mode.

    Returns four gr.update objects for: sample_dropdown, type_own_warning,
    custom_question, db_id (in that order).
    """
    if method != "💡 Pick a Sample":
        # Free-text mode: hide the sample dropdown, show the warning and the
        # custom-question box, and unlock the database dropdown.
        return (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(interactive=True))
    # Sample mode: look up the database paired with the currently selected
    # sample question; fall back to "chinook_1" when no pair matches.
    db = "chinook_1"
    for question, paired_db in SAMPLES:
        if question == current_sample:
            db = paired_db
            break
    return (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value=db, interactive=False))
527
+
528
def load_sample(selected_question):
    """Sync the database dropdown with the database paired to the chosen sample."""
    if not selected_question:
        # Nothing selected: leave the dropdown untouched.
        return gr.update()
    # Default to "chinook_1" when the question has no registered pairing.
    db = "chinook_1"
    for question, paired_db in SAMPLES:
        if question == selected_question:
            db = paired_db
            break
    return gr.update(value=db)
531
+
532
def clear_inputs():
    """Reset every Inference-tab widget back to its initial state.

    Output order matches the clear_btn.click binding: input_method,
    sample_dropdown, type_own_warning, custom_question, db_id, final_sql,
    result_table, explanation.
    """
    reset_method = gr.update(value="💡 Pick a Sample")
    reset_sample = gr.update(value=SAMPLE_QUESTIONS[0], visible=True)
    reset_warning = gr.update(visible=False)
    reset_question = gr.update(value="", visible=False)
    reset_db = gr.update(value="chinook_1", interactive=False)
    # The last three clear the SQL code box, result table and explanation text.
    return (reset_method, reset_sample, reset_warning, reset_question, reset_db, "", pd.DataFrame(), "")
534
+
535
def update_schema(db_id):
    """Render the structured schema of ``db_id`` as a scrollable HTML panel.

    Lines matching ``table_name(col, col, ...)`` render as a bold upper-cased
    table name followed by its lower-cased column list; any other non-empty
    line renders as plain muted text.

    Args:
        db_id: database identifier understood by ``schema_encoder``.

    Returns:
        str: an HTML fragment, "" when db_id is falsy, or a red error <div>
        when schema loading fails.
    """
    # Local import: only needed for escaping, and avoids touching module scope.
    import html as _html

    if not db_id:
        return ""
    try:
        raw_schema = schema_encoder.structured_schema(db_id)
        html_output = "<div style='max-height: 250px; overflow-y: auto; background: #f8fafc; padding: 12px; border-radius: 8px; border: 1px solid #e2e8f0; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; font-size: 0.9em; line-height: 1.6;'>"
        for line in raw_schema.strip().split('\n'):
            line = line.strip()
            if not line:
                continue
            match = re.search(r'^([a-zA-Z0-9_]+)\s*\((.*)\)', line)
            if match:
                # Escape schema-derived text so stray <, >, & or quotes cannot
                # break out of the generated markup.
                table = _html.escape(match.group(1).upper())
                cols = _html.escape(match.group(2).lower())
                html_output += f"<div style='margin-bottom: 8px;'><strong style='color: #0f172a; font-size: 1.05em; font-weight: 800;'>{table}</strong> <span style='color: #64748b;'>( {cols} )</span></div>"
            else:
                html_output += f"<div style='color: #475569;'>{_html.escape(line)}</div>"
        html_output += "</div>"
        return html_output
    except Exception as e:
        # Escape the exception text too — it may echo user-influenced input.
        return f"<div style='color: red;'>Error loading schema: {_html.escape(str(e))}</div>"
549
+
550
# =========================
# UI LAYOUT
# =========================
with gr.Blocks(title="Text-to-SQL RLHF") as demo:
    # Page banner.
    gr.HTML("""
    <div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">
    <h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>
    <p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>
    </div>
    """)

    # Databases offered in the Inference tab (sorted for a stable dropdown order).
    DBS = sorted(["flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1", "academic", "aircraft", "car_1", "cinema", "club_1", "csu_1", "college_1", "college_2", "company_1", "company_employee", "customer_complaints", "department_store", "employee_hire_evaluation", "museum_visit", "products_for_hire", "restaurant_1", "school_finance", "shop_membership", "small_bank_1", "student_1", "tvshow", "voter_1", "world_1"])

    with gr.Tabs():
        # --- Tab 1: interactive NL -> SQL inference ---
        with gr.Tab("Inference"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 1. Configuration & Input")
                    # Mode toggle: canned sample question vs. free-form text.
                    input_method = gr.Radio(choices=["💡 Pick a Sample", "✍️ Type my own"], value="💡 Pick a Sample", label="How do you want to ask?")
                    sample_dropdown = gr.Dropdown(choices=SAMPLE_QUESTIONS, value=SAMPLE_QUESTIONS[0], label="Select a Sample Question", info="The database will be selected automatically.", visible=True)
                    # Shown only in "Type my own" mode (see toggle_input_method).
                    type_own_warning = gr.Markdown("**⚠️ Please select a Database first, then type your custom question below:**", visible=False)
                    gr.Markdown("---")
                    # Locked in sample mode; unlocked for custom questions.
                    db_id = gr.Dropdown(choices=DBS, value="chinook_1", label="Select Database", interactive=False)
                    custom_question = gr.Textbox(label="Ask your Custom Question", placeholder="Type your own question here...", lines=3, visible=False)

                    # Live schema preview for the currently selected database.
                    gr.Markdown("#### 📋 Database Structure")
                    gr.HTML("<p style='font-size: 0.85em; color: #64748b; margin-top: -10px; margin-bottom: 5px;'>Use these exact names! Table names are <strong>Dark</strong>, Column names are <span style='color: #94a3b8;'>Light</span>.</p>")
                    schema_display = gr.HTML(value=update_schema("chinook_1"))

                    with gr.Row():
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                        run_btn = gr.Button(" Generate & Run SQL", variant="primary")

                with gr.Column(scale=2):
                    gr.Markdown("### 2. Execution Results")
                    final_sql = gr.Code(language="sql", label="Final Executed SQL")
                    result_table = gr.Dataframe(label="Query Result Table", interactive=False, wrap=True)
                    explanation = gr.Textbox(label="AI Explanation + Execution Details", lines=8)

        # --- Tab 2: diagnostics and telemetry panels ---
        with gr.Tab("Diagnostics"):
            gr.Markdown("## Diagnostics & Telemetry")

            with gr.Accordion("Task 1: Parallel Reward Benchmark", open=False):
                gr.Markdown("*(Simulates the heavy RLHF training workload by running hundreds of complex SQL queries concurrently to test SQLite multi-threading performance.)*")
                t1_n = gr.Number(value=20, precision=0, label="Rollouts (n)")
                t1_workers = gr.Number(value=10, precision=0, label="Max workers")
                t1_run = gr.Button("Run Task 1 benchmark")
                t1_out = gr.Textbox(label="Output", lines=12)
                t1_plot = gr.HTML(label="Plot (if generated)")
                t1_run.click(fn=task1_benchmark, inputs=[t1_n, t1_workers], outputs=[t1_out, t1_plot])

            with gr.Accordion("Task 2: Error Dashboard", open=True):
                gr.Markdown("*(Live telemetry tracking the most common SQL failures. Populates automatically when queries fail in the Inference tab.)*")
                t2_refresh = gr.Button("Refresh dashboard")
                t2_counts = gr.Dataframe(label="Error counts", interactive=False, wrap=True)
                t2_recent = gr.Dataframe(label="Recent errors", interactive=False, wrap=True)
                # Choices are filled dynamically by task2_dashboard_structured.
                t2_type = gr.Dropdown(choices=[], value=None, label="Select error type")
                t2_examples = gr.Textbox(label="Examples + hint", lines=10)

                t2_refresh.click(fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type])
                t2_type.change(fn=task2_error_examples, inputs=[t2_type], outputs=[t2_examples])

            with gr.Accordion("Task 2: Clause Telemetry", open=False):
                gr.Markdown("*(Analyzes which specific SQL clauses—SELECT, WHERE, JOIN, etc.—are most prone to errors during natural language generation.)*")
                t2_ops_refresh = gr.Button("Refresh SQL-op stats")
                t2_ops_tbl = gr.Dataframe(label="Success/failure by op", interactive=False, wrap=True)
                t2_ops_plot = gr.HTML(label="Op plot")
                t2_ops_refresh.click(fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot])

            with gr.Accordion("Task 2: Adversarial Suite", open=False):
                gr.Markdown("*(Runs a predefined set of tricky, out-of-distribution natural language questions designed to confuse the model, testing its robustness and constraint mapping.)*")
                adv_run = gr.Button("Run adversarial suite")
                adv_out = gr.Dataframe(label="Adversarial results", interactive=False, wrap=True)
                adv_summary = gr.Textbox(label="Summary", lines=8)
                adv_run.click(fn=run_adversarial_suite, inputs=[], outputs=[adv_out, adv_summary])


    # EVENT BINDING: The .then() forces the diagnostic tab to update live in the background!
    input_method.change(fn=toggle_input_method, inputs=[input_method, sample_dropdown], outputs=[sample_dropdown, type_own_warning, custom_question, db_id])
    sample_dropdown.change(fn=load_sample, inputs=[sample_dropdown], outputs=[db_id])
    db_id.change(fn=update_schema, inputs=[db_id], outputs=[schema_display])

    # Main action: generate+execute SQL, then refresh the diagnostics widgets.
    run_btn.click(
        fn=run_query,
        inputs=[input_method, sample_dropdown, custom_question, db_id],
        outputs=[final_sql, result_table, explanation]
    ).then(
        fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type]
    ).then(
        fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot]
    )

    clear_btn.click(fn=clear_inputs, inputs=[], outputs=[input_method, sample_dropdown, type_own_warning, custom_question, db_id, final_sql, result_table, explanation])

if __name__ == "__main__":
    # Launch options come from the environment so deployments need no code edits.
    share = os.environ.get("GRADIO_SHARE", "0").strip() in {"1", "true", "True", "yes", "Y"}
    server_name = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
    demo.launch(
        share=share,
        server_name=server_name,
        # theme=gr.themes.Soft()
    )
final_databases/academic.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ca59ebaa830731011a222885480e4b9f9d49c3e36849dee25b769fb74f296c2
3
+ size 122880
final_databases/activity_1.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e32fff45bbcf0d0f4304bf646d52c651063f3c128d974d72cb751d5cf105c83
3
+ size 24576
final_databases/aircraft.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b92973e99d8348a324e318fd4d62645ebe428e7346b51277e78385d1ad3b1ef
3
+ size 45056
final_databases/allergy_1.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3ce73d649463526f6f6b04f457f40a778c8c192b8451cc2837437dc7a18a207
3
+ size 20480
final_databases/apartment_rentals.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42121c39f5132a9df8eea63ff01d4cdb446144e020079c073a58a949f9d40e34
3
+ size 53248
final_databases/architecture.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea405f3511d01a3f7d688fbec2b214b7fb82ed9f5795f424a0e1a655f704d406
3
+ size 28672
final_databases/assets_maintenance.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:847af85ebcc94f6be06e06f5bd8a22500bd4e208a29a3e732eed13c72eda43c1
3
+ size 69632
final_databases/battle_death.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12569f4493a9655639c4ea86ba4bd8a4ea6411e1b8ae0e1fdf3d6a995344265d
3
+ size 28672
final_databases/behavior_monitoring.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:689af3fbf11164f50d02f4a75ff6c9a7fb419938de861a1c26477ae577c1443a
3
+ size 65536
final_databases/bike_1.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5ae0e24e3a9d860a38ec6256828e3b9e37691c80931ef554adc202f8eb2950c
3
+ size 1785856
final_databases/body_builder.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76221ea59be7c46e0712e308dd9f0527c63b71669186b2e97c23df78ab974108
3
+ size 20480
final_databases/book_2.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b0bffbf8805809461b29573dc55782d810cc688af526e28088b299326d72407
3
+ size 20480
final_databases/browser_web.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ced1ee75f531e02adf66f387ced5997f416653adbcf519c037813aa69599f17a
3
+ size 28672
final_databases/candidate_poll.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7df58984f9d9796538716294ba0dda9f48a4a54907f03a3859285aa4b5cf8dc5
3
+ size 20480
final_databases/car_1.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d851e396e02997a1de073ae982fe1e4b1769fdffce2fac6e325857a3a938709
3
+ size 65536
final_databases/chinook_1.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70b3eb8c3ffb5351eb7943e3bf05b693a5397c22d200ff022ef99e2ca18ab7b9
3
+ size 901120
final_databases/cinema.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fbcc850e33d813a792417291ddfb1cebc4a31716621284ed28ef53df360a14e
3
+ size 28672
final_databases/city_record.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:112ff582c4552d8b35a0a70205fb7d8562f18c6a231b86fe8ac19beebebc4d4b
3
+ size 36864
final_databases/climbing.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2d3893fc5e79943ffad120b59522358b90ef077a989fb40c7c5fbbe16bfccd9
3
+ size 20480
final_databases/club_1.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c2fc993225d6ed604f7cb95146e025e5d71777f9a997db6ab0d9ff77eb26748
3
+ size 16384
final_databases/coffee_shop.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a182a8a7f9b5df6da8b2976b42aca0a7db6c2b13300b13b398e568a5bf4a6b2e
3
+ size 36864
final_databases/college_1.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bf3ee598b9f10b4d9bef9a2d390f1391bccbcb64986c88241fad24452695d18
3
+ size 53248
final_databases/college_2.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb729920ad0b7f06d38a12f6f678307964acc7d3417af83d98c519c65c90d386
3
+ size 2117632
final_databases/college_3.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f47b8f927f27a63ba4aa35fcf6cc8360fd68bfd267e6cf05c04430f19cd773e
3
+ size 45056
final_databases/company_1.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a58f492277a8b1bfa575746ac8b06dcbaed059c2a24ba6eb6090debc8b6c892f
3
+ size 12288
final_databases/company_employee.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86574c9edbecd80ded437d9439d6f81def0973816b1c8af8d007335e62565378
3
+ size 28672
final_databases/company_office.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4514dd10d9257ccaa08acff4d272595b6d92819575c736393f89e72f956c9915
3
+ size 28672
final_databases/concert_singer.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fa1ba5ab4577e895271088b1dc44aa94be88e25a54293317a67584112ef059d
3
+ size 36864
final_databases/county_public_safety.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d608ce42a7b28ecb7500c4ded8096ade65a57dae6c5eefd56709ba96c3015d4a
3
+ size 20480
final_databases/course_teach.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da45fcdde64ac2b9330146506b7653ad4417489d8e496e509045e1b02245d793
3
+ size 28672
final_databases/cre_Doc_Control_Systems.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8451c4a097dcbf8fabf7993e70e91df3d825c3c941ba96eb41faa8d236897d30
3
+ size 81920
final_databases/cre_Doc_Template_Mgt.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c3fdd03d8795ecde60aae782ce50c4fb1d6c03de1401fcdd6dc177342a53df5
3
+ size 24576
final_databases/cre_Doc_Tracking_DB.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c9df67d3c16e585d159328b627c334383bea333d250c515052f65613a48fc86
3
+ size 57344
final_databases/cre_Docs_and_Epenses.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87b83375d433672b428afaffa8f586fd73456b5e49f2609c1feae42d90cd4c79
3
+ size 40960
final_databases/cre_Drama_Workshop_Groups.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9a9e215fae2314be429307b47ee7fc3ffdc749fba933faba980b97a5f6b9a0a
3
+ size 147456
final_databases/cre_Theme_park.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea37aa49d5cd7905f04eb76c2638f706fab1ef3cd8ce41234507b457a660d261
3
+ size 94208
final_databases/csu_1.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c66e85b742a4941118aab88183c4be7f3e238f9d2558a5e68f8818c504e2710d
3
+ size 102400
final_databases/culture_company.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d73d31db2c91dadeb67b8ed0c9366fbbe7658774dfa110161c168ee13e781820
3
+ size 28672
final_databases/customer_complaints.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:495dbacbcd3ea5afb6c659a99cc10ef524213b66f692a8803942924d93bdb4a9
3
+ size 20480
final_databases/customer_deliveries.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4d269a357190432d3f0b154070dccf1cc5f7f7715571c1597e2fdd9ee47d485
3
+ size 61440
final_databases/customers_and_addresses.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d8ceb2520e7b2db55eca8c2a1e9a52a5f1d9e27d7f0dd11bcd367274e6fee96
3
+ size 32768
final_databases/customers_and_invoices.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19621a734777eefff77544c93f5da387f086e53b9dc0aa6266bc2da55dbd7103
3
+ size 45056
final_databases/customers_and_products_contacts.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d7e27cbc07e177126eda4914c9fee9eba6afdcfef03ffdc3eeee094c68aa399
3
+ size 32768
final_databases/customers_campaigns_ecommerce.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f7ae5bbe53359854f5c18656d5878713e88b1a364dfba05670149bcc00f3544
3
+ size 36864
final_databases/customers_card_transactions.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2cde3ed8c5a33dd1f708a657ef9fab552c1a74ce19010bcc0dbff37ce708d62
3
+ size 20480
final_databases/debate.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e9d72cbaf8e3e695a2e8865e6a35ee8c73042ddd651c30b1816d72fadfccaa7
3
+ size 28672