jtdearmon commited on
Commit
51df8be
·
verified ·
1 Parent(s): 8d704bf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +965 -0
app.py ADDED
@@ -0,0 +1,965 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adaptive SQL Trainer — Domain Randomized with OpenAI (Gradio + SQLite)
2
+ # - Randomizes a relational domain via OpenAI (bookstore, retail sales, wholesaler,
3
+ # sales tax, oil & gas wells, marketing) OR falls back to a built-in dataset.
4
+ # - Builds 3–4 related tables (schema + seed rows) in SQLite.
5
+ # - Generates 8–12 randomized SQL questions with varied phrasings.
6
+ # - Validates answers by executing canonical SQL and comparing result sets.
7
+ # - Provides tailored feedback (SQLite dialect, cartesian products, aggregates, aliases).
8
+ # - Shows data results at the bottom pane for every run (SELECT or preview for VIEW/CTAS).
9
+ #
10
+ # Hugging Face Spaces: set OPENAI_API_KEY as a secret to enable LLM randomization.
11
+
12
+ import os
13
+ import re
14
+ import json
15
+ import time
16
+ import random
17
+ import sqlite3
18
+ from dataclasses import dataclass, asdict
19
+ from datetime import datetime, timezone
20
+ from typing import List, Dict, Any, Tuple, Optional
21
+
22
+ import gradio as gr
23
+ import pandas as pd
24
+ import numpy as np
25
+
26
+ # Matplotlib for ERD drawing (headless)
27
+ import matplotlib
28
+ matplotlib.use("Agg")
29
+ import matplotlib.pyplot as plt
30
+ from io import BytesIO
31
+ from PIL import Image
32
+
33
# -------------------- OpenAI (optional) --------------------
# The trainer works without OpenAI; these flags record whether the client
# could be constructed so the rest of the app can fall back gracefully.
USE_RESPONSES_API = True   # prefer the Responses API over chat.completions
OPENAI_AVAILABLE = True    # flipped to False below if the client can't start
MODEL_ID = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
try:
    from openai import OpenAI
    _client = OpenAI()  # requires OPENAI_API_KEY
except Exception:
    # Missing package or missing key: run fully offline with the fallback dataset.
    OPENAI_AVAILABLE = False
    _client = None
43
+
44
# -------------------- Global settings --------------------
# Prefer the persistent /data volume (Hugging Face Spaces) when present.
DB_DIR = "/data" if os.path.exists("/data") else "."
DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
EXPORT_DIR = "."
ADMIN_KEY = os.getenv("ADMIN_KEY", "demo")
RANDOM_SEED = int(os.getenv("RANDOM_SEED", "7"))
random.seed(RANDOM_SEED)          # deterministic question selection per deploy
SYS_RAND = random.SystemRandom()  # OS-backed randomness when determinism is unwanted

PLOT_FIGSIZE = (6.8, 3.4)  # ERD figure size in inches (width, height)
PLOT_DPI = 110
PLOT_HEIGHT = 300          # pixel height hint for the UI image component
56
+
57
# -------------------- ERD helpers --------------------
def _to_pil(fig) -> Image.Image:
    """Render a Matplotlib figure to a PIL image and release the figure."""
    fig.tight_layout()
    png_buffer = BytesIO()
    fig.savefig(png_buffer, format="png", dpi=PLOT_DPI, bbox_inches="tight")
    plt.close(fig)  # free the figure so headless rendering doesn't leak memory
    png_buffer.seek(0)
    return Image.open(png_buffer)
65
+
66
def draw_dynamic_erd(schema: Dict[str, Any]) -> Image.Image:
    """Draw a simple one-row ERD for the installed domain schema.

    Expected shape of `schema`:
    schema = {
        "domain": "bookstore",
        "tables": [
            {"name":"authors","columns":[{"name":"author_id","type":"INTEGER",...}, ...],
             "pk":["author_id"], "fks":[{"columns":["author_id"],"ref_table":"...","ref_columns":["..."]}],
             "rows":[{...}, {...}]}
        ]
    }

    Each table becomes an outlined box with its column list (PKs marked);
    FK relationships are drawn as arrows between box centers.
    """
    fig, ax = plt.subplots(figsize=PLOT_FIGSIZE)
    ax.axis("off")
    tables = schema.get("tables", [])
    n = max(1, len(tables))  # avoid division by zero for an empty schema
    # Lay out boxes horizontally in axes coordinates (0..1)
    margin = 0.03
    width = (1 - margin*(n+1)) / n
    height = 0.65
    y = 0.25
    boxes = {}
    for i, t in enumerate(tables):
        x = margin + i*(width + margin)
        boxes[t["name"]] = (x, y, width, height)
        ax.add_patch(plt.Rectangle((x, y), width, height, fill=False))
        ax.text(x + 0.01, y + height - 0.05, f"**{t['name']}**", fontsize=10, ha="left", va="top")
        yy = y + height - 0.10  # running y-cursor for the column list
        pk = set(t.get("pk", []))
        cols = t.get("columns", [])
        for col in cols:
            nm = col["name"]
            mark = " (PK)" if nm in pk else ""
            ax.text(x + 0.02, yy, f"{nm}{mark}", fontsize=9, ha="left", va="top")
            yy -= 0.06

    # Draw FK arrows (only between tables that were actually laid out)
    for t in tables:
        for fk in t.get("fks", []):
            src_tbl = t["name"]
            dst_tbl = fk.get("ref_table")
            if src_tbl in boxes and dst_tbl in boxes:
                (x1, y1, w1, h1) = boxes[src_tbl]
                (x2, y2, w2, h2) = boxes[dst_tbl]
                ax.annotate("", xy=(x2 + w2/2, y2 + h2), xytext=(x1 + w1/2, y1),
                            arrowprops=dict(arrowstyle="->", lw=1.1))
    ax.text(0.5, 0.06, f"Domain: {schema.get('domain','unknown')}", fontsize=9, ha="center")
    return _to_pil(fig)
113
+
114
# -------------------- SQLite helpers --------------------
def connect_db():
    """Open the trainer database with foreign-key enforcement turned on."""
    connection = sqlite3.connect(DB_PATH)
    connection.execute("PRAGMA foreign_keys = ON;")
    return connection

# Single shared connection used by the whole app.
CONN = connect_db()
121
+
122
def init_progress_tables(con: sqlite3.Connection):
    """Ensure the internal bookkeeping tables exist (users, attempts, session_meta).

    Safe to call repeatedly: every statement uses IF NOT EXISTS.
    """
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS users (
            user_id TEXT PRIMARY KEY,
            name TEXT,
            created_at TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS attempts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id TEXT,
            question_id TEXT,
            category TEXT,
            correct INTEGER,
            sql_text TEXT,
            timestamp TEXT,
            time_taken REAL,
            difficulty INTEGER,
            source TEXT,
            notes TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS session_meta (
            id INTEGER PRIMARY KEY CHECK (id=1),
            domain TEXT,
            schema_json TEXT
        )
        """,
    )
    cursor = con.cursor()
    for ddl in ddl_statements:
        cursor.execute(ddl)
    con.commit()
154
+
155
+ init_progress_tables(CONN)
156
+
157
# -------------------- Fallback dataset (if no OpenAI) --------------------
# Fixed bookstore domain used whenever the OpenAI client is unavailable:
# three linked tables (books N->1 authors, books N->1 bookstores) with small
# hand-written seed rows. Shape matches what the LLM would return.
FALLBACK_SCHEMA = {
    "domain": "bookstore",
    "tables": [
        {
            "name": "authors",
            "pk": ["author_id"],
            "columns": [
                {"name":"author_id","type":"INTEGER"},
                {"name":"name","type":"TEXT"},
                {"name":"country","type":"TEXT"},
                {"name":"birth_year","type":"INTEGER"},
            ],
            "fks": [],
            "rows": [
                {"author_id":1,"name":"Isaac Asimov","country":"USA","birth_year":1920},
                {"author_id":2,"name":"Ursula K. Le Guin","country":"USA","birth_year":1929},
                {"author_id":3,"name":"Haruki Murakami","country":"Japan","birth_year":1949},
                {"author_id":4,"name":"Chinua Achebe","country":"Nigeria","birth_year":1930},
                {"author_id":5,"name":"Jane Austen","country":"UK","birth_year":1775},
                {"author_id":6,"name":"J.K. Rowling","country":"UK","birth_year":1965},
                {"author_id":7,"name":"Yuval Noah Harari","country":"Israel","birth_year":1976},
                # Intentionally has no books, so LEFT JOIN exercises show a zero count.
                {"author_id":8,"name":"New Author","country":"Nowhere","birth_year":1990},
            ],
        },
        {
            "name": "bookstores",
            "pk": ["store_id"],
            "columns": [
                {"name":"store_id","type":"INTEGER"},
                {"name":"name","type":"TEXT"},
                {"name":"city","type":"TEXT"},
                {"name":"state","type":"TEXT"},
            ],
            "fks": [],
            "rows": [
                {"store_id":1,"name":"Downtown Books","city":"Oklahoma City","state":"OK"},
                {"store_id":2,"name":"Harbor Books","city":"Seattle","state":"WA"},
                {"store_id":3,"name":"Desert Pages","city":"Phoenix","state":"AZ"},
            ],
        },
        {
            "name": "books",
            "pk": ["book_id"],
            "columns": [
                {"name":"book_id","type":"INTEGER"},
                {"name":"title","type":"TEXT"},
                {"name":"author_id","type":"INTEGER"},
                {"name":"store_id","type":"INTEGER"},
                {"name":"category","type":"TEXT"},
                {"name":"price","type":"REAL"},
                {"name":"published_year","type":"INTEGER"},
            ],
            "fks": [
                {"columns":["author_id"],"ref_table":"authors","ref_columns":["author_id"]},
                {"columns":["store_id"],"ref_table":"bookstores","ref_columns":["store_id"]},
            ],
            "rows": [
                {"book_id":101,"title":"Foundation","author_id":1,"store_id":1,"category":"Sci-Fi","price":14.99,"published_year":1951},
                {"book_id":102,"title":"I, Robot","author_id":1,"store_id":1,"category":"Sci-Fi","price":12.50,"published_year":1950},
                {"book_id":103,"title":"The Left Hand of Darkness","author_id":2,"store_id":2,"category":"Sci-Fi","price":16.00,"published_year":1969},
                {"book_id":104,"title":"A Wizard of Earthsea","author_id":2,"store_id":2,"category":"Fantasy","price":11.50,"published_year":1968},
                {"book_id":105,"title":"Norwegian Wood","author_id":3,"store_id":3,"category":"Fiction","price":18.00,"published_year":1987},
                {"book_id":106,"title":"Kafka on the Shore","author_id":3,"store_id":1,"category":"Fiction","price":21.00,"published_year":2002},
                {"book_id":107,"title":"Things Fall Apart","author_id":4,"store_id":1,"category":"Fiction","price":10.00,"published_year":1958},
                {"book_id":108,"title":"Pride and Prejudice","author_id":5,"store_id":2,"category":"Fiction","price":9.00,"published_year":1813},
                {"book_id":109,"title":"Harry Potter and the Sorcerer's Stone","author_id":6,"store_id":3,"category":"Children","price":22.00,"published_year":1997},
                {"book_id":110,"title":"Harry Potter and the Chamber of Secrets","author_id":6,"store_id":3,"category":"Children","price":23.00,"published_year":1998},
                {"book_id":111,"title":"Sapiens","author_id":7,"store_id":1,"category":"History","price":26.00,"published_year":2011},
                {"book_id":112,"title":"Homo Deus","author_id":7,"store_id":2,"category":"History","price":28.00,"published_year":2015},
            ],
        },
    ]
}
231
+
232
# Question bank matching FALLBACK_SCHEMA. Each entry mirrors the LLM output
# shape: one or more accepted `answer_sql` variants per question, plus the
# optional alias-requirement flags consumed by aliases_present().
FALLBACK_QUESTIONS = [
    {
        "id":"Q01","category":"SELECT *","difficulty":1,
        "prompt_md":"Select all rows and columns from `authors`.",
        "answer_sql":["SELECT * FROM authors;"],
        "requires_aliases":False,"required_aliases":[]
    },
    {
        "id":"Q02","category":"SELECT columns","difficulty":1,
        "prompt_md":"Show `title` and `price` from `books`.",
        "answer_sql":["SELECT title, price FROM books;"],
        "requires_aliases":False,"required_aliases":[]
    },
    {
        "id":"Q03","category":"WHERE","difficulty":1,
        "prompt_md":"List Sci‑Fi books under $15 (show title, price).",
        "answer_sql":["SELECT title, price FROM books WHERE category='Sci-Fi' AND price < 15;"],
        "requires_aliases":False,"required_aliases":[]
    },
    {
        "id":"Q04","category":"Aliases","difficulty":1,
        "prompt_md":"Using aliases `b` and `a`, join `books` to `authors` and show `b.title` and `a.name` as `author_name`.",
        "answer_sql":["SELECT b.title, a.name AS author_name FROM books b JOIN authors a ON b.author_id=a.author_id;"],
        "requires_aliases":True,"required_aliases":["a","b"]
    },
    {
        "id":"Q05","category":"JOIN (INNER)","difficulty":2,
        "prompt_md":"Inner join `books` and `bookstores`. Return `title`, `name` as `store`.",
        "answer_sql":[
            "SELECT b.title, s.name AS store FROM books b INNER JOIN bookstores s ON b.store_id=s.store_id;"
        ],
        "requires_aliases":False,"required_aliases":[]
    },
    {
        "id":"Q06","category":"JOIN (LEFT)","difficulty":2,
        "prompt_md":"List each author and their number of books (include authors with zero): columns `name`, `book_count`.",
        "answer_sql":[
            "SELECT a.name, COUNT(b.book_id) AS book_count FROM authors a LEFT JOIN books b ON a.author_id=b.author_id GROUP BY a.name;"
        ],
        "requires_aliases":False,"required_aliases":[]
    },
    {
        "id":"Q07","category":"VIEW","difficulty":2,
        "prompt_md":"Create a view `vw_pricy` with `title`, `price` for books priced > 25.",
        "answer_sql":[
            "CREATE VIEW vw_pricy AS SELECT title, price FROM books WHERE price > 25;"
        ],
        "requires_aliases":False,"required_aliases":[]
    },
    {
        "id":"Q08","category":"CTAS / SELECT INTO","difficulty":2,
        "prompt_md":"Create a table `cheap_books` containing books priced < 12. Use CTAS or SELECT INTO.",
        # The SELECT INTO variant relies on rewrite_select_into() to translate
        # it into SQLite's CTAS form before execution.
        "answer_sql":[
            "CREATE TABLE cheap_books AS SELECT * FROM books WHERE price < 12;",
            "SELECT * INTO cheap_books FROM books WHERE price < 12;"
        ],
        "requires_aliases":False,"required_aliases":[]
    },
]
291
+
292
# -------------------- OpenAI prompts --------------------
# JSON schema for structured output: one domain name, 3–4 tables (each with
# pk/fk metadata and seed rows), and 8–12 questions. "strict": True tells the
# API to enforce the schema exactly (no extra properties).
DOMAIN_AND_QUESTIONS_SCHEMA = {
    "name": "DomainSQLPack",
    "schema": {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "domain": {"type":"string"},
            "tables": {
                "type":"array",
                "items": {
                    "type":"object",
                    "additionalProperties": False,
                    "properties": {
                        "name": {"type":"string"},
                        "pk": {"type":"array","items":{"type":"string"}},
                        "columns": {
                            "type":"array",
                            "items": {
                                "type":"object",
                                "additionalProperties": False,
                                "properties": {
                                    "name":{"type":"string"},
                                    "type":{"type":"string"}
                                },
                                "required":["name","type"]
                            }
                        },
                        "fks": {
                            "type":"array",
                            "items": {
                                "type":"object",
                                "additionalProperties": False,
                                "properties": {
                                    "columns":{"type":"array","items":{"type":"string"}},
                                    "ref_table":{"type":"string"},
                                    "ref_columns":{"type":"array","items":{"type":"string"}}
                                },
                                "required":["columns","ref_table","ref_columns"]
                            }
                        },
                        # Rows may come back as column-keyed objects or positional arrays;
                        # install_schema() accepts both.
                        "rows": {"type":"array","items":{"type":["object","array"]}}
                    },
                    "required":["name","pk","columns","fks","rows"]
                },
                "minItems":3,"maxItems":4
            },
            "questions": {
                "type":"array",
                "items": {
                    "type":"object",
                    "additionalProperties": False,
                    "properties": {
                        "id":{"type":"string"},
                        "category":{"type":"string"},
                        "difficulty":{"type":"integer"},
                        "prompt_md":{"type":"string"},
                        "answer_sql":{"type":"array","items":{"type":"string"}},
                        "requires_aliases":{"type":"boolean"},
                        "required_aliases":{"type":"array","items":{"type":"string"}}
                    },
                    "required":["id","category","difficulty","prompt_md","answer_sql"]
                },
                "minItems":8,"maxItems":12
            }
        },
        "required":["domain","tables","questions"]
    },
    "strict": True
}
362
+
363
# User prompt sent with the JSON schema above. It instructs the model to pick
# a random domain, emit 3–4 linked SQLite-friendly tables with seed rows, and
# 8–12 graded questions across the trainer's categories.
DOMAIN_AND_QUESTIONS_PROMPT = """
You are designing a small relational dataset and training questions for SQL basics.

1) Choose ONE domain at random from:
   - bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.

2) Produce exactly 3–4 tables that fit together (SQLite-friendly):
   - Use snake_case, avoid reserved words.
   - Types: INTEGER, REAL, TEXT, NUMERIC, DATE (but no advanced features).
   - Primary keys (pk) and foreign keys (fks) must align.
   - Provide 8–15 small, realistic seed rows per table (not huge).

3) Generate 8–12 SQL questions covering basics with varied, natural language:
   - Categories from: "SELECT *", "SELECT columns", "WHERE", "Aliases",
     "JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO".
   - Include a few joins and at least one LEFT JOIN.
   - Include one view creation.
   - Include one table creation from SELECT (either CTAS or SELECT INTO).
   - Prefer SQLite-compatible SQL. DO NOT use RIGHT/FULL OUTER JOIN.
   - Offer 1–3 acceptable answer_sql variants per question.
   - For 1–2 questions, require table aliases (set requires_aliases=true and list required_aliases).

Return JSON only.
"""
387
+
388
def llm_generate_domain_and_questions() -> Optional[Dict[str,Any]]:
    """Ask OpenAI for a complete domain pack (schema + questions).

    Returns the parsed JSON object, or None when the client is unavailable,
    the call fails, or the response cannot be parsed — callers then fall back
    to the built-in bookstore dataset.
    """
    if not OPENAI_AVAILABLE:
        return None
    try:
        if USE_RESPONSES_API:
            # Structured output: the JSON schema constrains the response shape.
            # NOTE(review): parameter shape follows one revision of the
            # Responses API — confirm against the installed openai version.
            resp = _client.responses.create(
                model=MODEL_ID,
                response_format={"type":"json_schema","json_schema":DOMAIN_AND_QUESTIONS_SCHEMA},
                input=[{"role":"user","content": DOMAIN_AND_QUESTIONS_PROMPT}],
                temperature=0.6,
            )
            data_text = getattr(resp, "output_text", None)
        else:
            chat = _client.chat.completions.create(
                model=MODEL_ID,
                messages=[{"role":"user","content": DOMAIN_AND_QUESTIONS_PROMPT}],
                temperature=0.6
            )
            data_text = chat.choices[0].message.content
        obj = json.loads(data_text) if data_text else None
        return obj
    except Exception:
        # Best-effort by design: any API or JSON error silently selects the fallback.
        return None
411
+
412
# -------------------- Schema install & question handling --------------------
def drop_existing_domain_tables(con: sqlite3.Connection, keep_internal=True):
    """Drop every user-created table and view from the database.

    Args:
        con: Open SQLite connection.
        keep_internal: When True, preserve the trainer's bookkeeping tables
            (users, attempts, session_meta).
    """
    cur = con.cursor()
    # Materialize the list first; we must not iterate a cursor whose source
    # we are mutating.
    cur.execute("SELECT name, type FROM sqlite_master WHERE type IN ('table','view')")
    items = cur.fetchall()
    for name, typ in items:
        if keep_internal and name in ("users", "attempts", "session_meta"):
            continue
        if name.startswith("sqlite_"):
            # Internal SQLite objects (e.g. sqlite_sequence) cannot be dropped.
            continue
        try:
            # Quote the identifier so names with odd characters don't break the DDL.
            cur.execute(f'DROP {typ.upper()} IF EXISTS "{name}"')
        except Exception:
            pass  # best-effort cleanup; a stubborn object shouldn't abort the reset
    con.commit()
425
+
426
def install_schema(con: sqlite3.Connection, schema: Dict[str, Any]):
    """Replace the current domain tables with `schema` and persist its JSON."""
    drop_existing_domain_tables(con, keep_internal=True)
    cur = con.cursor()
    # Pass 1: CREATE TABLE statements. Foreign keys are recorded in the schema
    # JSON but not declared in DDL (SQLite only accepts them inline at create
    # time), so they are validated logically elsewhere.
    for table in schema.get("tables", []):
        column_defs = [
            f"{col['name']} {col.get('type', 'TEXT')}"
            for col in table.get("columns", [])
        ]
        pk_cols = table.get("pk", [])
        if pk_cols:
            column_defs.append(f"PRIMARY KEY ({', '.join(pk_cols)})")
        cur.execute(f"CREATE TABLE {table['name']} ({', '.join(column_defs)})")
    # Pass 2: seed rows — dict rows keyed by column, or positional sequences.
    for table in schema.get("tables", []):
        seed_rows = table.get("rows")
        if not seed_rows:
            continue
        col_names = [col["name"] for col in table.get("columns", [])]
        placeholders = ",".join("?" * len(col_names))
        stmt = f"INSERT INTO {table['name']} ({', '.join(col_names)}) VALUES ({placeholders})"
        for row in seed_rows:
            if isinstance(row, dict):
                values = [row.get(name, None) for name in col_names]
            elif isinstance(row, (list, tuple)):
                # Pad short rows with NULLs, truncate long ones.
                padded = list(row) + [None] * (len(col_names) - len(row))
                values = padded[:len(col_names)]
            else:
                continue  # unrecognized row shape — skip silently
            cur.execute(stmt, values)
    con.commit()
    # Remember the installed domain so later sessions/tools can inspect it.
    cur.execute(
        "INSERT OR REPLACE INTO session_meta(id, domain, schema_json) VALUES (1, ?, ?)",
        (schema.get("domain", "unknown"), json.dumps(schema)),
    )
    con.commit()
464
+
465
def run_df(con: sqlite3.Connection, sql: str) -> pd.DataFrame:
    """Execute a SELECT against `con` and return the result set as a DataFrame."""
    return pd.read_sql_query(sql, con)
467
+
468
def rewrite_select_into(sql: str) -> Tuple[str, Optional[str]]:
    """Translate T-SQL style `SELECT ... INTO t FROM ...` into SQLite CTAS.

    Returns (possibly rewritten SQL, created table name or None). Statements
    without an INTO clause pass through untouched.
    """
    trimmed = sql.strip().strip(";")
    has_into = re.search(r"\bselect\b.+\binto\b.+\bfrom\b", trimmed,
                         flags=re.IGNORECASE | re.DOTALL)
    if not has_into:
        return sql, None
    match = re.match(
        r"(?is)^\s*select\s+(.*?)\s+into\s+([A-Za-z_][A-Za-z0-9_]*)\s+from\s+(.*)$",
        trimmed,
    )
    if match is None:
        return sql, None
    column_list, table_name, tail = match.groups()
    return f"CREATE TABLE {table_name} AS SELECT {column_list} FROM {tail}", table_name
476
+
477
def detect_unsupported_joins(sql: str) -> Optional[str]:
    """Return a dialect-help message if `sql` uses a construct SQLite lacks.

    Uses word-boundary regexes instead of the old single-space substring
    tests, which missed `RIGHT OUTER JOIN`, keywords separated by newlines
    or multiple spaces, and ILIKE at the end of a statement.
    """
    low = sql.lower()
    if re.search(r"\bright\s+(outer\s+)?join\b", low):
        return "SQLite does not support RIGHT JOIN. Use LEFT JOIN in the opposite direction."
    if re.search(r"\bfull\s+(outer\s+)?join\b", low):
        return "SQLite does not support FULL OUTER JOIN. Use LEFT JOIN plus UNION for the other side."
    if re.search(r"\bilike\b", low):
        return "SQLite has no ILIKE. Use `LOWER(col) LIKE LOWER('%pattern%')`."
    return None
486
+
487
def detect_cartesian(con: sqlite3.Connection, sql: str, df_result: pd.DataFrame) -> Optional[str]:
    """Heuristically warn when a query looks like an unintended cartesian product.

    Flags explicit CROSS JOINs, comma-style FROM lists, and JOINs missing any
    ON/USING/NATURAL clause; then confirms by checking whether the result row
    count equals |t1| x |t2|. Returns a warning string or None.
    """
    low = sql.lower()
    if " cross join " in low:
        return "Query uses CROSS JOIN (cartesian product). Ensure this is intended."
    # Old-style implicit join: FROM t1, t2
    comma_from = re.search(r"\bfrom\b\s+([a-z_]\w*)\s*,\s*([a-z_]\w*)", low)
    # JOIN keyword present but no join condition of any kind
    missing_on = (" join " in low) and (" on " not in low) and (" using " not in low) and (" natural " not in low)
    if comma_from or missing_on:
        try:
            cur = con.cursor()
            if comma_from:
                t1, t2 = comma_from.groups()
            else:
                # Pull the first FROM table and the first JOIN table.
                m = re.search(r"\bfrom\b\s+([a-z_]\w*)", low)
                j = re.search(r"\bjoin\b\s+([a-z_]\w*)", low)
                if not m or not j:
                    return "Possible cartesian product: no join condition detected."
                t1, t2 = m.group(1), j.group(1)
            cur.execute(f"SELECT COUNT(*) FROM {t1}")
            n1 = cur.fetchone()[0]
            cur.execute(f"SELECT COUNT(*) FROM {t2}")
            n2 = cur.fetchone()[0]
            prod = n1 * n2
            # A result sized exactly |t1| x |t2| is the telltale sign.
            if len(df_result) == prod and prod > 0:
                return f"Result row count equals {n1}×{n2}={prod}. Likely cartesian product (missing join)."
        except Exception:
            # Table lookups failed (e.g. aliased/derived tables) — fall back to a generic warning.
            return "Possible cartesian product: no join condition detected."
    return None
514
+
515
def results_equal(df_a: pd.DataFrame, df_b: pd.DataFrame) -> bool:
    """Compare two result sets ignoring row order and column-name case."""
    if df_a.shape != df_b.shape:
        return False
    normalized = []
    for frame in (df_a, df_b):
        tmp = frame.copy()
        tmp.columns = [name.lower() for name in tmp.columns]
        # Canonical row order so ORDER BY differences don't matter.
        tmp = tmp.sort_values(list(tmp.columns)).reset_index(drop=True)
        normalized.append(tmp)
    return normalized[0].equals(normalized[1])
525
+
526
def aliases_present(sql: str, required_aliases: List[str]) -> bool:
    """Check that every required table alias is used (or declared) in `sql`.

    An alias counts if it is dereferenced (`alias.column`) or declared with
    `AS alias`. Matching is done with word-boundary regexes: the previous
    space-prefixed substring test missed aliases used inside parentheses,
    e.g. `COUNT(b.book_id)`.
    """
    low = re.sub(r"\s+", " ", sql.lower())
    for alias in required_aliases:
        token = re.escape(alias.lower())
        used = re.search(rf"\b{token}\s*\.", low)
        declared = re.search(rf"\bas\s+{token}\b", low)
        if not used and not declared:
            return False
    return True
532
+
533
# -------------------- Question model --------------------
@dataclass
class SQLQuestion:
    """One training question: prompt plus acceptable canonical answers."""
    id: str
    category: str
    difficulty: int            # 1 = easiest
    prompt_md: str             # Markdown prompt shown to the student
    answer_sql: List[str]      # one or more accepted canonical solutions
    requires_aliases: bool = False
    # A mutable default must come from a factory. The original `= None`
    # contradicted the List[str] annotation and forced None-checks on callers.
    required_aliases: List[str] = field(default_factory=list)
543
+
544
def to_question_dict(q) -> Dict[str,Any]:
    """Copy a raw question mapping, guaranteeing the alias-related keys exist."""
    question = dict(q)
    if "requires_aliases" not in question:
        question["requires_aliases"] = False
    if "required_aliases" not in question:
        question["required_aliases"] = []
    return question
549
+
550
def load_questions(obj_list: List[Dict[str,Any]]) -> List[Dict[str,Any]]:
    """Normalize every raw question mapping via to_question_dict."""
    return [to_question_dict(item) for item in obj_list]
555
+
556
# -------------------- Domain bootstrap --------------------
def bootstrap_domain_with_llm_or_fallback() -> Tuple[Dict[str,Any], List[Dict[str,Any]]]:
    """Get a (schema, questions) pack from the LLM, or the built-in bookstore set."""
    pack = llm_generate_domain_and_questions()
    if pack is None:
        return FALLBACK_SCHEMA, FALLBACK_QUESTIONS
    # Guardrails: drop answer variants that use joins SQLite cannot execute.
    kept_questions = []
    for question in pack.get("questions", []):
        safe_answers = [
            candidate for candidate in question.get("answer_sql", [])
            if " right join " not in candidate.lower() and " full " not in candidate.lower()
        ]
        if not safe_answers:
            continue  # no runnable canonical answer left — discard the question
        question["answer_sql"] = safe_answers
        question.setdefault("requires_aliases", False)
        question.setdefault("required_aliases", [])
        kept_questions.append(question)
    pack["questions"] = kept_questions
    return pack, kept_questions
573
+
574
def install_new_domain():
    """Generate (or fall back to) a domain pack and install it into the live DB.

    Returns (schema, questions) for the freshly installed domain.
    """
    schema, questions = bootstrap_domain_with_llm_or_fallback()
    install_schema(CONN, schema)
    return schema, questions
578
+
579
# -------------------- Session state --------------------
# Install one randomized (or fallback) domain at import time; these module
# globals are read by the question picker and the UI callbacks.
CURRENT_SCHEMA, CURRENT_QS = install_new_domain()
581
+
582
# -------------------- Progress + mastery --------------------
def upsert_user(con: sqlite3.Connection, user_id: str, name: str):
    """Insert the user if new; otherwise refresh the stored display name."""
    cur = con.cursor()
    cur.execute("SELECT user_id FROM users WHERE user_id = ?", (user_id,))
    already_known = cur.fetchone() is not None
    if already_known:
        cur.execute("UPDATE users SET name=? WHERE user_id=?", (name, user_id))
    else:
        created_at = datetime.now(timezone.utc).isoformat()
        cur.execute(
            "INSERT INTO users (user_id, name, created_at) VALUES (?, ?, ?)",
            (user_id, name, created_at),
        )
    con.commit()
592
+
593
# Canonical question categories, ordered easiest-first; topic_stats() reports
# in this order and pick_next_question() defaults to the first entry.
CATEGORIES_ORDER = [
    "SELECT *", "SELECT columns", "WHERE", "Aliases",
    "JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO"
]
597
+
598
def topic_stats(df_attempts: pd.DataFrame) -> pd.DataFrame:
    """Summarize attempts/correct/accuracy per category (in CATEGORIES_ORDER)."""
    summary = []
    for cat in CATEGORIES_ORDER:
        if df_attempts.empty:
            attempt_count, correct_count = 0, 0
        else:
            subset = df_attempts[df_attempts["category"] == cat]
            attempt_count = int(subset.shape[0])
            correct_count = int(subset["correct"].sum()) if not subset.empty else 0
        summary.append({
            "category": cat,
            "attempts": attempt_count,
            "correct": correct_count,
            # max(..., 1) keeps accuracy at 0.0 for untouched categories.
            "accuracy": float(correct_count / max(attempt_count, 1)),
        })
    return pd.DataFrame(summary)
607
+
608
def fetch_attempts(con: sqlite3.Connection, user_id: str) -> pd.DataFrame:
    """Load one user's attempt history, newest attempt first."""
    query = "SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC"
    return pd.read_sql_query(query, con, params=(user_id,))
610
+
611
def pick_next_question(user_id: str) -> Dict[str,Any]:
    """Pick a question from the user's weakest category.

    "Weakest" = lowest accuracy, ties broken by fewest attempts. Falls back to
    the whole bank when the weakest category has no questions in CURRENT_QS.
    Returns a copy so callers can't mutate the shared bank entry.
    """
    history = fetch_attempts(CONN, user_id)
    ranked = topic_stats(history).sort_values(
        by=["accuracy", "attempts"], ascending=[True, True]
    )
    weakest = CATEGORIES_ORDER[0] if ranked.empty else ranked.iloc[0]["category"]
    pool = [q for q in CURRENT_QS if q["category"] == weakest] or CURRENT_QS
    return dict(random.choice(pool))
618
+
619
# -------------------- Execution & feedback --------------------
def exec_student_sql(sql_text: str) -> Tuple[Optional[pd.DataFrame], Optional[str], Optional[str], Optional[str]]:
    """Run a student's SQL against the shared connection with tailored feedback.

    Returns a 4-tuple (df, error, warning, note):
      df      — result rows for SELECTs, a preview for CREATE VIEW/TABLE,
                an empty frame for other statements, or None on error.
      error   — human-readable failure message, or None on success.
      warning — cartesian-product heuristic warning (SELECTs only).
      note    — informational note (currently only the SELECT INTO rewrite).
    """
    if not sql_text or not sql_text.strip():
        return None, "Enter a SQL statement.", None, None

    sql_raw = sql_text.strip().rstrip(";")
    # Transparently convert `SELECT ... INTO t` to SQLite's CTAS form.
    sql_rew, created_tbl = rewrite_select_into(sql_raw)
    note = None
    if sql_rew != sql_raw:
        note = "Rewrote `SELECT ... INTO` to `CREATE TABLE ... AS SELECT ...` for SQLite."

    # Reject constructs SQLite cannot run before touching the database.
    unsup = detect_unsupported_joins(sql_rew)
    if unsup:
        return None, unsup, None, note

    try:
        low = sql_rew.lower()
        if low.startswith("select"):
            df = run_df(CONN, sql_rew)
            warn = detect_cartesian(CONN, sql_rew, df)
            return df, None, warn, note
        else:
            # DDL/DML path: execute, commit, then try to show something useful.
            cur = CONN.cursor()
            cur.execute(sql_rew)
            CONN.commit()
            # Preview newly created objects
            if low.startswith("create view"):
                m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+(select.*)$", low)
                name = m.group(2) if m else None
                if name:
                    try:
                        df = run_df(CONN, f"SELECT * FROM {name}")
                        return df, None, None, note
                    except Exception:
                        return None, "View created but could not be queried.", None, note
            if low.startswith("create table"):
                # Prefer the table name captured by the SELECT INTO rewrite.
                tbl = created_tbl
                if not tbl:
                    m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
                    tbl = m.group(2) if m else None
                if tbl:
                    try:
                        df = run_df(CONN, f"SELECT * FROM {tbl}")
                        return df, None, None, note
                    except Exception:
                        return None, "Table created but could not be queried.", None, note
            # Statement succeeded but produced nothing to display.
            return pd.DataFrame(), None, None, note
    except Exception as e:
        # Tailored messages — checked most-specific first; order matters because
        # several substrings can co-occur in one SQLite error string.
        msg = str(e)
        if "no such table" in msg.lower():
            return None, f"{msg}. Check table names for this randomized domain.", None, note
        if "no such column" in msg.lower():
            return None, f"{msg}. Use correct column names or prefixes (alias.column).", None, note
        if "ambiguous column name" in msg.lower():
            return None, f"{msg}. Qualify the column with a table alias.", None, note
        if "misuse of aggregate" in msg.lower() or "aggregate functions are not allowed in" in msg.lower():
            return None, f"{msg}. You might need a GROUP BY for non-aggregated columns.", None, note
        if "near \"into\"" in msg.lower() and "syntax error" in msg.lower():
            return None, "SQLite doesn’t support `SELECT ... INTO`. I can rewrite it automatically—try again.", None, note
        if "syntax error" in msg.lower():
            return None, f"Syntax error. Check commas, keywords, and parentheses. Raw error: {msg}", None, note
        return None, f"SQL error: {msg}", None, note
682
+
683
def answer_df(answer_sql: List[str]) -> Optional[pd.DataFrame]:
    """Execute canonical answer variants until one yields a result set.

    SELECT answers are run directly; CREATE VIEW / CREATE TABLE ... AS SELECT
    answers are (re)built in the live database — note this side effect — and
    then previewed with `SELECT *`. Returns the first successful DataFrame,
    or None when every variant fails (callers treat that as a DDL-only task).
    """
    for sql in answer_sql:
        try:
            low = sql.strip().lower()
            if low.startswith("select"):
                return run_df(CONN, sql)
            if low.startswith("create view"):
                # temp preview: rebuild the view so repeated grading doesn't fail
                m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
                view_name = m.group(2) if m else "vw_tmp"
                cur = CONN.cursor()
                cur.execute(f"DROP VIEW IF EXISTS {view_name}")
                cur.execute(sql)
                CONN.commit()
                return run_df(CONN, f"SELECT * FROM {view_name}")
            if low.startswith("create table"):
                m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
                tbl = m.group(2) if m else None
                cur = CONN.cursor()
                if tbl:
                    # Drop first so a re-run of the same answer doesn't error out.
                    cur.execute(f"DROP TABLE IF EXISTS {tbl}")
                cur.execute(sql)
                CONN.commit()
                if tbl:
                    return run_df(CONN, f"SELECT * FROM {tbl}")
        except Exception:
            continue  # try the next accepted variant
    return None
711
+
712
def validate_answer(q: Dict[str,Any], student_sql: str, df_student: Optional[pd.DataFrame]) -> Tuple[bool, str]:
    """Grade the student's result set against the canonical answer's result set."""
    df_expected = answer_df(q["answer_sql"])
    if df_expected is None:
        # No canonical result (e.g. pure DDL): any successful execution counts.
        succeeded = df_student is not None
        return succeeded, f"**Explanation:** Your statement executed successfully for this task."
    if df_student is None:
        return False, f"**Explanation:** Expected data result differs."
    matched = results_equal(df_student, df_expected)
    return matched, f"**Explanation:** Compare your result to a canonical solution."
720
+
721
def log_attempt(user_id: str, qid: str, category: str, correct: bool, sql_text: str,
                time_taken: float, difficulty: int, source: str, notes: str):
    """Append one graded attempt to the attempts table (timestamped in UTC)."""
    attempt_row = (
        user_id, qid, category, int(correct), sql_text,
        datetime.now(timezone.utc).isoformat(),
        time_taken, difficulty, source, notes,
    )
    CONN.cursor().execute(
        """
        INSERT INTO attempts (user_id, question_id, category, correct, sql_text, timestamp, time_taken, difficulty, source, notes)
        VALUES (?,?,?,?,?,?,?,?,?,?)
        """,
        attempt_row,
    )
    CONN.commit()
730
+
731
# -------------------- UI callbacks --------------------
def start_session(name: str, session: dict):
    """Gradio callback: register the user and serve their first question.

    Returns the 8-tuple of component updates the UI expects:
    (session state, question markdown, SQL input, preview block, ERD image,
     next-button visibility, stats table, results table).
    """
    name = (name or "").strip()
    if not name:
        return session, gr.update(value="Please enter your name to begin.", visible=True), gr.update(visible=False), gr.update(visible=False), None, gr.update(visible=False), pd.DataFrame(), pd.DataFrame()

    # Derive a stable user id from the name (slugified, capped at 64 chars).
    slug = "-".join(name.lower().split())
    user_id = slug[:64] if slug else f"user-{int(time.time())}"
    upsert_user(CONN, user_id, name)
    q = pick_next_question(user_id)
    # start_ts lets submit_answer compute time-on-question later.
    session = {"user_id": user_id, "name": name, "qid": q["id"], "start_ts": time.time(), "q": q}

    prompt = q["prompt_md"]
    stats = topic_stats(fetch_attempts(CONN, user_id))
    erd = draw_dynamic_erd(CURRENT_SCHEMA)
    return (session,
            gr.update(value=f"**Question {q['id']}**\n\n{prompt}", visible=True),
            gr.update(visible=True), # show SQL input
            gr.update(value="", visible=True), # preview block
            erd,
            gr.update(visible=False), # next btn hidden until submit
            stats,
            pd.DataFrame())
754
+
755
def render_preview_and_erd(sql_text: str, session: dict):
    """Live-render the student's SQL in a fenced code block; always refresh the ERD.

    Hides the preview when there is no active session or no SQL typed yet.
    """
    erd = draw_dynamic_erd(CURRENT_SCHEMA)
    trimmed = (sql_text or "").strip()
    # Both "no session" and "empty input" collapse to the same hidden preview.
    if not session or "q" not in session or not trimmed:
        return gr.update(value="", visible=False), erd
    return gr.update(value=f"**Preview:**\n\n```sql\n{trimmed}\n```", visible=True), erd
762
+
763
def submit_answer(sql_text: str, session: dict):
    """Execute, grade, and log the student's SQL; build the markdown feedback.

    Returns the 4-tuple wired to `submit_btn.click`:
    (feedback markdown update, result DataFrame, next-button update, mastery stats).
    """
    if not session or "user_id" not in session or "q" not in session:
        return gr.update(value="Start a session first.", visible=True), pd.DataFrame(), gr.update(visible=False), pd.DataFrame()
    user_id = session["user_id"]
    q = session["q"]
    # Time since the question was served; clamped so a missing start_ts never goes negative.
    elapsed = max(0.0, time.time() - session.get("start_ts", time.time()))

    df, err, warn, note = exec_student_sql(sql_text)
    details = []  # extra feedback lines (info notes, warnings, alias nags)
    if note: details.append(f"ℹ️ {note}")
    if err:
        # Execution failed: log as incorrect and surface the error, but still
        # show the Next button so the student isn't stuck.
        fb = f"❌ **Did not run**\n\n{err}"
        if details: fb += "\n\n" + "\n".join(details)
        log_attempt(user_id, q["id"], q["category"], False, sql_text, elapsed, int(q["difficulty"]), "bank", " | ".join([err] + details))
        stats = topic_stats(fetch_attempts(CONN, user_id))
        return gr.update(value=fb, visible=True), pd.DataFrame(), gr.update(visible=True), stats

    # Validate correctness
    alias_msg = None
    if q.get("requires_aliases"):
        # Alias usage is advisory: it adds a warning but does not fail the answer.
        if not aliases_present(sql_text, q.get("required_aliases", [])):
            alias_msg = f"⚠️ This task asked for aliases {q.get('required_aliases', [])}. I didn’t detect them."

    is_correct, explanation = validate_answer(q, sql_text, df)
    if warn: details.append(f"⚠️ {warn}")
    if alias_msg: details.append(alias_msg)

    # Assemble feedback: verdict, details, explanation, then one canonical solution.
    prefix = "✅ **Correct!**" if is_correct else "❌ **Not quite.**"
    feedback = prefix
    if details:
        feedback += "\n\n" + "\n".join(details)
    feedback += "\n\n" + explanation + "\n\n**One acceptable solution:**\n```sql\n" + q["answer_sql"][0].rstrip(";") + ";\n```"

    log_attempt(user_id, q["id"], q["category"], bool(is_correct), sql_text, elapsed, int(q["difficulty"]), "bank", " | ".join(details))
    stats = topic_stats(fetch_attempts(CONN, user_id))
    return gr.update(value=feedback, visible=True), (df if df is not None else pd.DataFrame()), gr.update(visible=True), stats
799
+
800
def next_question(session: dict):
    """Advance the session to a freshly picked question and reset the UI."""
    if not session or "user_id" not in session:
        return (session,
                gr.update(value="Start a session first.", visible=True),
                gr.update(visible=False),
                draw_dynamic_erd(CURRENT_SCHEMA),
                gr.update(visible=False))
    q = pick_next_question(session["user_id"])
    # Refresh the live-question bookkeeping, restarting the timer.
    session["qid"] = q["id"]
    session["q"] = q
    session["start_ts"] = time.time()
    question_md = f"**Question {q['id']}**\n\n{q['prompt_md']}"
    return (session,
            gr.update(value=question_md, visible=True),
            gr.update(value="", visible=True),
            draw_dynamic_erd(CURRENT_SCHEMA),
            gr.update(visible=False))
809
+
810
def show_hint(session: dict):
    """Surface a category-specific nudge for the current question."""
    if not session or "q" not in session:
        return gr.update(value="Start a session first.", visible=True)
    # Lightweight hint policy: one canned hint per question category, with a
    # generic ERD-reading fallback for anything unmapped.
    hints_by_category = {
        "SELECT *": "Use `SELECT * FROM table_name`.",
        "SELECT columns": "List columns: `SELECT col1, col2 FROM table_name`.",
        "WHERE": "Filter with `WHERE` and combine conditions using AND/OR.",
        "Aliases": "Use `table_name t` and qualify: `t.col`.",
        "JOIN (INNER)": "Join with `... INNER JOIN ... ON left.key = right.key`.",
        "JOIN (LEFT)": "LEFT JOIN keeps all rows from the left table.",
        "Aggregation": "Use aggregate functions and `GROUP BY` non-aggregated columns.",
        "VIEW": "`CREATE VIEW view_name AS SELECT ...`.",
        "CTAS / SELECT INTO": "SQLite uses `CREATE TABLE name AS SELECT ...`.",
    }
    fallback = "Read the ER diagram and identify keys to join on."
    hint = hints_by_category.get(session["q"]["category"], fallback)
    return gr.update(value=f"**Hint:** {hint}", visible=True)
827
+
828
def export_progress(user_name: str):
    """Export a student's attempt history to CSV; return the file path, or None.

    The user id is re-derived from the name with the same slug rule used by
    `start_session` (lowercased, hyphen-joined, capped at 64 chars).
    """
    slug = "-".join((user_name or "").lower().split())
    if not slug:
        return None
    user_id = slug[:64]

    attempts = fetch_attempts(CONN, user_id)
    os.makedirs(EXPORT_DIR, exist_ok=True)
    out_path = os.path.abspath(os.path.join(EXPORT_DIR, f"{user_id}_progress.csv"))
    # An empty history still produces a file, with a single informational row.
    if attempts.empty:
        attempts = pd.DataFrame([{"info":"No attempts yet."}])
    attempts.to_csv(out_path, index=False)
    return out_path
838
+
839
def regenerate_domain():
    """Install a freshly generated domain (schema + questions) and redraw the ERD."""
    global CURRENT_SCHEMA, CURRENT_QS
    CURRENT_SCHEMA, CURRENT_QS = install_new_domain()
    status = gr.update(value="✅ Domain regenerated.", visible=True)
    return status, draw_dynamic_erd(CURRENT_SCHEMA)
844
+
845
def preview_table(tbl: str):
    """Return the first 20 rows of a table/view, or a one-row error DataFrame.

    `tbl` comes from a UI dropdown, but it was previously interpolated into the
    SQL unquoted. It is now quoted as an SQL identifier (embedded double quotes
    doubled per the SQL standard), so names containing spaces, keywords, or
    injected text cannot break out of the FROM clause.
    """
    try:
        safe = (tbl or "").replace('"', '""')
        return run_df(CONN, f'SELECT * FROM "{safe}" LIMIT 20')
    except Exception as e:
        # Surface the failure in the preview grid instead of crashing the UI
        # (e.g., the "(no tables)" placeholder simply yields a "no such table" row).
        return pd.DataFrame([{"error": str(e)}])
850
+
851
def list_tables_for_preview():
    """List user-facing tables/views, excluding the app's own bookkeeping tables."""
    catalog = run_df(CONN, "SELECT name, type FROM sqlite_master WHERE type in ('table','view') AND name NOT IN ('users','attempts','session_meta') ORDER BY type, name")
    return ["(no tables)"] if catalog.empty else catalog["name"].tolist()
856
+
857
# -------------------- UI --------------------
# Top-level Gradio layout. Left column: session, dataset, instructor, and
# table-preview controls. Right column: the live question, SQL editor, ERD,
# feedback, mastery stats, and result preview.
with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
    gr.Markdown(
        """
# 🧪 Adaptive SQL Trainer — Randomized Domains (SQLite)
- Uses **OpenAI** (if configured) to randomize a domain (bookstore, retail sales, wholesaler,
sales tax, oil & gas wells, marketing), generate **3–4 tables** and **8–12** questions.
- Practice `SELECT`, `WHERE`, `JOIN` (INNER/LEFT), **aliases**, **views**, and **CTAS / SELECT INTO**.
- The app explains **SQLite quirks** (no RIGHT/FULL JOIN) and flags likely **cartesian products**.

> Set your `OPENAI_API_KEY` in the Space secrets to enable randomization.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            name_box = gr.Textbox(label="Your Name", placeholder="e.g., Jordan Alvarez")
            start_btn = gr.Button("Start / Resume Session", variant="primary")
            # Per-browser session state: current user, live question, and start timestamp.
            session_state = gr.State({"user_id": None, "name": None, "qid": None, "start_ts": None, "q": None})

            gr.Markdown("---")
            gr.Markdown("### Dataset Controls")
            regen_btn = gr.Button("🔀 Randomize Dataset (OpenAI)")
            regen_fb = gr.Markdown(visible=False)

            gr.Markdown("---")
            gr.Markdown("### Instructor Tools")
            export_name = gr.Textbox(label="Export a student's progress (enter name)")
            export_btn = gr.Button("Export CSV")
            export_file = gr.File(label="Download progress")

            gr.Markdown("---")
            gr.Markdown("### Quick Table/View Preview (top 20 rows)")
            tbl_dd = gr.Dropdown(choices=list_tables_for_preview(), label="Pick table/view", interactive=True)
            tbl_btn = gr.Button("Preview")
            preview_df = gr.Dataframe(headers=[], interactive=False)

        with gr.Column(scale=2):
            # Hidden until a session starts (start_session toggles visibility).
            prompt_md = gr.Markdown(visible=False)
            sql_input = gr.Textbox(label="Your SQL", placeholder="Type SQL here (end ; optional).", lines=6, visible=False)

            preview_md = gr.Markdown(visible=False)
            er_image = gr.Image(label="Entity Diagram", value=draw_dynamic_erd(CURRENT_SCHEMA), height=PLOT_HEIGHT)

            with gr.Row():
                submit_btn = gr.Button("Run & Submit", variant="primary")
                hint_btn = gr.Button("Hint")
                next_btn = gr.Button("Next Question ▶", visible=False)

            feedback_md = gr.Markdown("")

            gr.Markdown("---")
            gr.Markdown("### Your Progress by Category")
            mastery_df = gr.Dataframe(headers=["category","attempts","correct","accuracy"], row_count=(0, "dynamic"), interactive=False)

            gr.Markdown("---")
            gr.Markdown("### Result Preview")
            result_df = gr.Dataframe(headers=[], interactive=False)

    # Wire events
    start_btn.click(
        start_session,
        inputs=[name_box, session_state],
        outputs=[session_state, prompt_md, sql_input, preview_md, er_image, next_btn, mastery_df, result_df],
    )
    # Live preview: re-renders the fenced SQL block on every keystroke.
    sql_input.change(
        render_preview_and_erd,
        inputs=[sql_input, session_state],
        outputs=[preview_md, er_image],
    )
    submit_btn.click(
        submit_answer,
        inputs=[sql_input, session_state],
        outputs=[feedback_md, result_df, next_btn, mastery_df],
    )
    next_btn.click(
        next_question,
        inputs=[session_state],
        outputs=[session_state, prompt_md, sql_input, er_image, next_btn],
    )
    # Hint and feedback share the same markdown area.
    hint_btn.click(
        show_hint,
        inputs=[session_state],
        outputs=[feedback_md],
    )
    export_btn.click(
        export_progress,
        inputs=[export_name],
        outputs=[export_file],
    )
    regen_btn.click(
        regenerate_domain,
        inputs=[],
        outputs=[regen_fb, er_image],
    )
    tbl_btn.click(
        lambda name: preview_table(name),
        inputs=[tbl_dd],
        outputs=[preview_df]
    )
    # Keep dropdown fresh after regeneration
    # (second click handler on regen_btn; Gradio runs handlers in registration order).
    regen_btn.click(
        lambda: gr.update(choices=list_tables_for_preview()),
        inputs=[],
        outputs=[tbl_dd]
    )
963
+
964
if __name__ == "__main__":
    # Launch the Gradio server when run as a script (e.g., on Hugging Face Spaces).
    demo.launch()