jtdearmon commited on
Commit
31dc1d4
·
verified ·
1 Parent(s): af6b0e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +361 -327
app.py CHANGED
@@ -1,13 +1,9 @@
1
  # Adaptive SQL Trainer — Domain Randomized with OpenAI (Gradio + SQLite)
2
- # - Randomizes a relational domain via OpenAI (bookstore, retail sales, wholesaler,
3
- # sales tax, oil & gas wells, marketing) OR falls back to a built-in dataset.
4
- # - Builds 3–4 related tables (schema + seed rows) in SQLite.
5
- # - Generates 8–12 randomized SQL questions with varied phrasings.
6
- # - Validates answers by executing canonical SQL and comparing result sets.
7
- # - Provides tailored feedback (SQLite dialect, cartesian products, aggregates, aliases).
8
- # - Always shows data results at the bottom pane.
9
- #
10
- # Hugging Face Spaces: set OPENAI_API_KEY in secrets to enable randomization.
11
 
12
  import os
13
  import re
@@ -18,7 +14,7 @@ import sqlite3
18
  import threading
19
  from dataclasses import dataclass
20
  from datetime import datetime, timezone
21
- from typing import List, Dict, Any, Tuple, Optional
22
 
23
  import gradio as gr
24
  import pandas as pd
@@ -43,7 +39,169 @@ def _candidate_models():
43
  seen = set()
44
  return [m for m in base if m and (m not in seen and not seen.add(m))]
45
 
46
- # -------------------- Global settings --------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  DB_DIR = "/data" if os.path.exists("/data") else "."
48
  DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
49
  EXPORT_DIR = "."
@@ -51,15 +209,9 @@ RANDOM_SEED = int(os.getenv("RANDOM_SEED", "7"))
51
  random.seed(RANDOM_SEED)
52
  SYS_RAND = random.SystemRandom()
53
 
54
- # -------------------- SQLite connection + locking --------------------
55
  DB_LOCK = threading.RLock()
56
 
57
  def connect_db():
58
- """
59
- Single shared connection usable across threads.
60
- All operations (reads + writes) serialized via DB_LOCK.
61
- WAL mode enables concurrent reads.
62
- """
63
  con = sqlite3.connect(DB_PATH, check_same_thread=False)
64
  con.execute("PRAGMA journal_mode=WAL;")
65
  con.execute("PRAGMA synchronous=NORMAL;")
@@ -104,7 +256,7 @@ def init_progress_tables(con: sqlite3.Connection):
104
 
105
  init_progress_tables(CONN)
106
 
107
- # -------------------- Fallback dataset (if no OpenAI) --------------------
108
  FALLBACK_SCHEMA = {
109
  "domain": "bookstore",
110
  "tables": [
@@ -217,10 +369,8 @@ FALLBACK_QUESTIONS = [
217
  "requires_aliases":False,"required_aliases":[]},
218
  ]
219
 
220
- # -------------------- OpenAI JSON request helpers --------------------
221
- DOMAIN_AND_QUESTIONS_SCHEMA = {
222
- "required": ["domain", "tables", "questions"]
223
- }
224
 
225
  def _domain_prompt(prev_domain: Optional[str]) -> str:
226
  extra = f" Avoid using the previous domain '{prev_domain}' if possible." if prev_domain else ""
@@ -232,28 +382,22 @@ Rules:
232
  - One domain chosen from: bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.
233
  - Tables: SQLite-friendly. Use snake_case. Each table has: name, pk (list of column names),
234
  columns (list of {{name,type}}), fks (list of {{columns,ref_table,ref_columns}}), rows (8–15 small seed rows).
235
- - Questions: diverse natural language. Categories: "SELECT *", "SELECT columns", "WHERE", "Aliases",
236
  "JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO".
237
  Include at least one LEFT JOIN, one VIEW creation, one CTAS or SELECT INTO.
238
  Provide 1–3 'answer_sql' strings per question. Prefer SQLite-compatible SQL. Do NOT use RIGHT/FULL OUTER JOIN.
239
  For 1–2 questions, set requires_aliases=true and list required_aliases.
240
 
241
- Example top-level keys (do not include comments in output):
242
- {{
243
- "domain": "retail sales",
244
- "tables": [...],
245
- "questions": [...]
246
- }}
247
  """
248
 
249
  def _loose_json_parse(s: str) -> Optional[dict]:
250
- """Extract the first JSON object from a possibly-wrapped string."""
251
  try:
252
  return json.loads(s)
253
  except Exception:
254
  pass
255
- start = s.find("{")
256
- end = s.rfind("}")
257
  if start != -1 and end != -1 and end > start:
258
  try:
259
  return json.loads(s[start:end+1])
@@ -261,94 +405,64 @@ def _loose_json_parse(s: str) -> Optional[dict]:
261
  return None
262
  return None
263
 
264
- # -------------------- Canonicalization & validation --------------------
265
  _SQL_FENCE = re.compile(r"```sql(.*?)```", re.IGNORECASE | re.DOTALL)
266
  _CODE_FENCE = re.compile(r"```(.*?)```", re.DOTALL)
267
 
268
  def _strip_code_fences(txt: str) -> str:
269
- if txt is None:
270
- return ""
271
  m = _SQL_FENCE.findall(txt)
272
- if m:
273
- return "\n".join([x.strip() for x in m if x.strip()])
274
  m2 = _CODE_FENCE.findall(txt)
275
- if m2:
276
- return "\n".join([x.strip() for x in m2 if x.strip()])
277
  return txt.strip()
278
 
279
def _as_list_of_sql(val) -> List[str]:
    """Normalize an LLM 'answer_sql' field into a flat list of SQL strings.

    Accepts a single string (possibly fenced / multi-line) or a list of
    strings; anything else yields an empty list.
    """
    if val is None:
        return []
    if isinstance(val, str):
        cleaned = _strip_code_fences(val)
        lines = [ln.strip() for ln in cleaned.split("\n") if ln.strip()]
        if lines:
            return lines
        # Single-line (or empty) input: keep it whole.
        return [cleaned] if cleaned else []
    if isinstance(val, list):
        collected = []
        for entry in val:
            if not isinstance(entry, str):
                continue
            stripped = _strip_code_fences(entry)
            if stripped:
                collected.append(stripped)
        return collected
    return []
296
 
297
  def _canon_question(q: Dict[str, Any]) -> Optional[Dict[str, Any]]:
298
- """Normalize one question; return None if missing critical fields."""
299
- if not isinstance(q, dict):
300
- return None
301
- # field mapping/synonyms
302
  cat = q.get("category") or q.get("type") or q.get("topic")
303
  prompt = q.get("prompt_md") or q.get("prompt") or q.get("question") or q.get("text")
304
  answer_sql = q.get("answer_sql") or q.get("answers") or q.get("solutions") or q.get("sql")
305
  diff = q.get("difficulty") or 1
306
  req_alias = bool(q.get("requires_aliases", False))
307
  req_aliases = q.get("required_aliases") or []
308
-
309
  cat = str(cat).strip() if cat is not None else ""
310
  prompt = str(prompt).strip() if prompt is not None else ""
311
  answers = _as_list_of_sql(answer_sql)
312
-
313
- if not cat or not prompt or not answers:
314
- return None
315
-
316
- # keep only known categories if provided; otherwise accept free text
317
  known = {
318
  "SELECT *","SELECT columns","WHERE","Aliases",
319
  "JOIN (INNER)","JOIN (LEFT)","Aggregation","VIEW","CTAS / SELECT INTO"
320
  }
321
  if cat not in known:
322
- # Try to map rough names to our set
323
  low = cat.lower()
324
- if "select *" in low:
325
- cat = "SELECT *"
326
- elif "select col" in low or "columns" in low:
327
- cat = "SELECT columns"
328
- elif "where" in low or "filter" in low:
329
- cat = "WHERE"
330
- elif "alias" in low:
331
- cat = "Aliases"
332
- elif "left" in low:
333
- cat = "JOIN (LEFT)"
334
- elif "inner" in low or "join" in low:
335
- cat = "JOIN (INNER)"
336
- elif "agg" in low or "group" in low:
337
- cat = "Aggregation"
338
- elif "view" in low:
339
- cat = "VIEW"
340
- elif "into" in low or "ctas" in low or "create table" in low:
341
- cat = "CTAS / SELECT INTO"
342
- else:
343
- # leave as-is; still usable for practice buckets
344
- pass
345
-
346
- # normalize aliases list
347
  if isinstance(req_aliases, str):
348
  req_aliases = [a.strip() for a in re.split(r"[,\s]+", req_aliases) if a.strip()]
349
  elif not isinstance(req_aliases, list):
350
  req_aliases = []
351
-
352
  return {
353
  "id": str(q.get("id") or f"LLM_{int(time.time()*1000)}_{random.randint(100,999)}"),
354
  "category": cat,
@@ -362,22 +476,17 @@ def _canon_question(q: Dict[str, Any]) -> Optional[Dict[str, Any]]:
362
  def _canon_tables(tables: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
363
  out = []
364
  for t in (tables or []):
365
- if not isinstance(t, dict):
366
- continue
367
  name = str(t.get("name","")).strip()
368
- if not name:
369
- continue
370
  cols = t.get("columns") or []
371
  good_cols = []
372
  for c in cols:
373
- if not isinstance(c, dict):
374
- continue
375
  cname = str(c.get("name","")).strip()
376
  ctype = str(c.get("type","TEXT")).strip() or "TEXT"
377
- if cname:
378
- good_cols.append({"name": cname, "type": ctype})
379
- if not good_cols:
380
- continue
381
  pk = t.get("pk") or []
382
  if isinstance(pk, str): pk = [pk]
383
  fks = t.get("fks") or []
@@ -391,31 +500,22 @@ def _canon_tables(tables: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
391
  })
392
  return out
393
 
394
- # -------------------- LLM call --------------------
395
  def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optional[Dict[str,Any]], Optional[str], Optional[str], Dict[str,int]]:
396
- """
397
- Returns (obj, error_message, model_used, stats_dict).
398
- stats_dict contains {"accepted_questions": n, "dropped_questions": m}
399
- """
400
  if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
401
  return None, "OpenAI client not available or OPENAI_API_KEY missing.", None, {"accepted_questions":0,"dropped_questions":0}
402
-
403
  errors = []
404
  prompt = _domain_prompt(prev_domain)
405
-
406
  for model in _candidate_models():
407
  try:
408
- # Try JSON mode first
409
  try:
410
  chat = _client.chat.completions.create(
411
  model=model,
412
  messages=[{"role":"user","content": prompt}],
413
  temperature=0.6,
414
- response_format={"type":"json_object"} # newer SDKs
415
  )
416
  data_text = chat.choices[0].message.content
417
  except TypeError:
418
- # Older SDKs: no response_format ⇒ plain chat + strict instructions
419
  chat = _client.chat.completions.create(
420
  model=model,
421
  messages=[{"role":"system","content":"Return ONLY a JSON object. No markdown."},
@@ -423,52 +523,35 @@ def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optio
423
  temperature=0.6
424
  )
425
  data_text = chat.choices[0].message.content
426
-
427
  obj_raw = _loose_json_parse(data_text or "")
428
  if not obj_raw:
429
  raise RuntimeError("Could not parse JSON from model output.")
430
-
431
- # Minimal top-level validation
432
  for k in DOMAIN_AND_QUESTIONS_SCHEMA["required"]:
433
  if k not in obj_raw:
434
  raise RuntimeError(f"Missing key '{k}'")
435
-
436
- # Canonicalize tables
437
  tables = _canon_tables(obj_raw.get("tables", []))
438
- if not tables:
439
- raise RuntimeError("No usable tables in LLM output.")
440
  obj_raw["tables"] = tables
441
-
442
- # Canonicalize questions
443
  dropped = 0
444
  clean_qs = []
445
  for q in obj_raw.get("questions", []):
446
  cq = _canon_question(q)
447
- if not cq:
448
- dropped += 1
449
- continue
450
- # Strip RIGHT/FULL joins from answers
451
  answers = [a for a in cq["answer_sql"] if " right join " not in a.lower() and " full " not in a.lower()]
452
- if not answers:
453
- dropped += 1
454
- continue
455
  cq["answer_sql"] = answers
456
  clean_qs.append(cq)
457
-
458
  if not clean_qs:
459
  raise RuntimeError("No usable questions after canonicalization.")
460
  stats = {"accepted_questions": len(clean_qs), "dropped_questions": dropped}
461
-
462
  obj_raw["questions"] = clean_qs
463
  return obj_raw, None, model, stats
464
-
465
  except Exception as e:
466
  errors.append(f"{model}: {e}")
467
  continue
468
-
469
  return None, "; ".join(errors) if errors else "Unknown LLM error.", None, {"accepted_questions":0,"dropped_questions":0}
470
 
471
- # -------------------- Schema install & question handling --------------------
472
  def drop_existing_domain_tables(con: sqlite3.Connection, keep_internal=True):
473
  with DB_LOCK:
474
  cur = con.cursor()
@@ -487,131 +570,30 @@ def install_schema(con: sqlite3.Connection, schema: Dict[str,Any]):
487
  drop_existing_domain_tables(con, keep_internal=True)
488
  with DB_LOCK:
489
  cur = con.cursor()
490
- # Create tables
491
  for t in schema.get("tables", []):
492
  cols_sql = []
493
  pk = t.get("pk", [])
494
  for c in t.get("columns", []):
495
- cname = c["name"]
496
- ctype = c.get("type","TEXT")
497
- cols_sql.append(f"{cname} {ctype}")
498
- if pk:
499
- cols_sql.append(f"PRIMARY KEY ({', '.join(pk)})")
500
  create_sql = f"CREATE TABLE {t['name']} ({', '.join(cols_sql)})"
501
  cur.execute(create_sql)
502
- # Insert rows
503
  for t in schema.get("tables", []):
504
- if not t.get("rows"):
505
- continue
506
  cols = [c["name"] for c in t.get("columns", [])]
507
  qmarks = ",".join(["?"]*len(cols))
508
  insert_sql = f"INSERT INTO {t['name']} ({', '.join(cols)}) VALUES ({qmarks})"
509
  for r in t["rows"]:
510
- if isinstance(r, dict):
511
- vals = [r.get(col, None) for col in cols]
512
  elif isinstance(r, (list, tuple)):
513
- vals = list(r) + [None]*(len(cols)-len(r))
514
- vals = vals[:len(cols)]
515
- else:
516
- continue
517
  cur.execute(insert_sql, vals)
518
  con.commit()
519
- # Persist schema JSON
520
  cur.execute("INSERT OR REPLACE INTO session_meta(id, domain, schema_json) VALUES (1, ?, ?)",
521
  (schema.get("domain","unknown"), json.dumps(schema)))
522
  con.commit()
523
 
524
def run_df(con: sqlite3.Connection, sql: str) -> pd.DataFrame:
    """Execute a query under the global DB lock and return the rows as a DataFrame."""
    with DB_LOCK:
        return pd.read_sql_query(sql, con)
527
-
528
def rewrite_select_into(sql: str) -> Tuple[str, Optional[str]]:
    """Rewrite T-SQL style `SELECT ... INTO tbl FROM ...` as SQLite CTAS.

    Returns (possibly-rewritten SQL, name of the table created) — the original
    SQL and None when the statement is not a SELECT ... INTO.
    """
    stmt = sql.strip().strip(";")
    looks_like_into = re.search(
        r"\bselect\b.+\binto\b.+\bfrom\b", stmt, flags=re.IGNORECASE | re.DOTALL
    )
    if looks_like_into:
        m = re.match(
            r"(?is)^\s*select\s+(.*?)\s+into\s+([A-Za-z_][A-Za-z0-9_]*)\s+from\s+(.*)$",
            stmt,
        )
        if m:
            select_list, table_name, tail = m.groups()
            return f"CREATE TABLE {table_name} AS SELECT {select_list} FROM {tail}", table_name
    return sql, None
536
-
537
def detect_unsupported_joins(sql: str) -> Optional[str]:
    """Return a coaching message when the SQL uses constructs SQLite lacks, else None."""
    checks = (
        ((" right join ",),
         "SQLite does not support RIGHT JOIN. Use LEFT JOIN in the opposite direction."),
        ((" full join ", " full outer join "),
         "SQLite does not support FULL OUTER JOIN. Use LEFT JOIN plus UNION for the other side."),
        ((" ilike ",),
         "SQLite has no ILIKE. Use LOWER(col) LIKE LOWER('%pattern%')."),
    )
    lowered = sql.lower()
    for needles, message in checks:
        if any(needle in lowered for needle in needles):
            return message
    return None
546
-
547
def detect_cartesian(con: sqlite3.Connection, sql: str, df_result: pd.DataFrame) -> Optional[str]:
    """Heuristically flag a likely cartesian product in a student query.

    Returns a warning string, or None when no red flag is found.
    """
    low = sql.lower()
    if " cross join " in low:
        return "Query uses CROSS JOIN (cartesian product). Ensure this is intended."
    # Two smells: comma-style `FROM t1, t2`, or a JOIN with no ON/USING/NATURAL clause.
    comma_from = re.search(r"\bfrom\b\s+([a-z_]\w*)\s*,\s*([a-z_]\w*)", low)
    missing_on = (" join " in low) and (" on " not in low) and (" using " not in low) and (" natural " not in low)
    if comma_from or missing_on:
        try:
            with DB_LOCK:
                cur = con.cursor()
                if comma_from:
                    t1, t2 = comma_from.groups()
                else:
                    # Pull the first table named after FROM and the first after JOIN.
                    m = re.search(r"\bfrom\b\s+([a-z_]\w*)", low)
                    j = re.search(r"\bjoin\b\s+([a-z_]\w*)", low)
                    if not m or not j:
                        return "Possible cartesian product: no join condition detected."
                    t1, t2 = m.group(1), j.group(1)
                # A result exactly the size of |t1| x |t2| strongly suggests
                # the join condition was omitted.
                cur.execute(f"SELECT COUNT(*) FROM {t1}")
                n1 = cur.fetchone()[0]
                cur.execute(f"SELECT COUNT(*) FROM {t2}")
                n2 = cur.fetchone()[0]
                prod = n1 * n2
                if len(df_result) == prod and prod > 0:
                    return f"Result row count equals {n1}×{n2}={prod}. Likely cartesian product (missing join)."
        except Exception:
            # The parsed names may be aliases/subqueries the COUNT(*) probe
            # cannot resolve; fall back to the generic warning.
            return "Possible cartesian product: no join condition detected."
    return None
575
-
576
def results_equal(df_a: pd.DataFrame, df_b: pd.DataFrame) -> bool:
    """Compare two query results ignoring row order and column-name case.

    Shapes must match exactly; rows are canonicalized by sorting on every
    column before the cell-wise comparison.
    """
    if df_a.shape != df_b.shape:
        return False
    left, right = df_a.copy(), df_b.copy()
    for frame in (left, right):
        frame.columns = [name.lower() for name in frame.columns]
    left = left.sort_values(list(left.columns)).reset_index(drop=True)
    right = right.sort_values(list(right.columns)).reset_index(drop=True)
    return left.equals(right)
586
-
587
def aliases_present(sql: str, required_aliases: List[str]) -> bool:
    """Heuristically check each required alias appears as `alias.` or `AS alias`."""
    # Collapse whitespace so the substring probes are stable across formatting.
    flattened = re.sub(r"\s+", " ", sql.lower())
    return all(
        f" {alias}." in flattened or f" as {alias} " in flattened
        for alias in required_aliases
    )
593
-
594
- # -------------------- Question model helpers --------------------
595
@dataclass
class SQLQuestion:
    """Structured representation of one generated practice question."""
    id: str                      # unique question identifier
    category: str                # practice bucket, e.g. "WHERE", "JOIN (INNER)"
    difficulty: int              # 1 = easiest
    prompt_md: str               # markdown prompt shown to the student
    answer_sql: List[str]        # one or more canonical solution statements
    requires_aliases: bool = False
    required_aliases: Optional[List[str]] = None  # aliases the student must use, if any
604
-
605
def to_question_dict(q) -> Dict[str,Any]:
    """Copy a question mapping, guaranteeing the alias-related keys exist."""
    out = dict(q)
    if "requires_aliases" not in out:
        out["requires_aliases"] = False
    if "required_aliases" not in out:
        out["required_aliases"] = []
    return out
610
-
611
def load_questions(obj_list: List[Dict[str,Any]]) -> List[Dict[str,Any]]:
    """Normalize every raw question mapping in the list."""
    return list(map(to_question_dict, obj_list))
613
-
614
- # -------------------- Domain bootstrap --------------------
615
  def bootstrap_domain_with_llm_or_fallback(prev_domain: Optional[str]):
616
  obj, err, model_used, stats = llm_generate_domain_and_questions(prev_domain)
617
  if obj is None:
@@ -621,13 +603,13 @@ def bootstrap_domain_with_llm_or_fallback(prev_domain: Optional[str]):
621
  def install_schema_and_prepare_questions(prev_domain: Optional[str]):
622
  schema, questions, info = bootstrap_domain_with_llm_or_fallback(prev_domain)
623
  install_schema(CONN, schema)
624
- # Safety: if questions empty, fall back
625
  if not questions:
626
  questions = FALLBACK_QUESTIONS
627
- info = {"source":"openai+fallback-questions","model":info.get("model"),"error":"LLM returned 0 usable questions; used fallback bank.","accepted":0,"dropped":0}
 
628
  return schema, questions, info
629
 
630
- # -------------------- Session state --------------------
631
  CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO = install_schema_and_prepare_questions(prev_domain=None)
632
 
633
  # -------------------- Progress + mastery --------------------
@@ -662,7 +644,6 @@ def fetch_attempts(con: sqlite3.Connection, user_id: str) -> pd.DataFrame:
662
  return pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", con, params=(user_id,))
663
 
664
  def pick_next_question(user_id: str) -> Dict[str,Any]:
665
- # Defensive: ensure we always have a pool
666
  pool = CURRENT_QS if CURRENT_QS else FALLBACK_QUESTIONS
667
  df = fetch_attempts(CONN, user_id)
668
  stats = topic_stats(df)
@@ -671,21 +652,99 @@ def pick_next_question(user_id: str) -> Dict[str,Any]:
671
  cands = [q for q in pool if str(q.get("category","")).strip() == weakest] or pool
672
  return dict(random.choice(cands))
673
 
674
- # -------------------- Execution & feedback --------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675
  def exec_student_sql(sql_text: str) -> Tuple[Optional[pd.DataFrame], Optional[str], Optional[str], Optional[str]]:
676
  if not sql_text or not sql_text.strip():
677
  return None, "Enter a SQL statement.", None, None
678
-
679
  sql_raw = sql_text.strip().rstrip(";")
680
  sql_rew, created_tbl = rewrite_select_into(sql_raw)
681
- note = None
682
- if sql_rew != sql_raw:
683
- note = "Rewrote `SELECT ... INTO` to `CREATE TABLE ... AS SELECT ...` for SQLite."
684
-
685
  unsup = detect_unsupported_joins(sql_rew)
686
- if unsup:
687
- return None, unsup, None, note
688
-
689
  try:
690
  low = sql_rew.lower()
691
  if low.startswith("select"):
@@ -695,72 +754,56 @@ def exec_student_sql(sql_text: str) -> Tuple[Optional[pd.DataFrame], Optional[st
695
  else:
696
  with DB_LOCK:
697
  cur = CONN.cursor()
698
- cur.execute(sql_rew)
699
- CONN.commit()
700
- # Preview newly created objects
701
  if low.startswith("create view"):
702
  m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+(select.*)$", low)
703
  name = m.group(2) if m else None
704
  if name:
705
- try:
706
- df = pd.read_sql_query(f"SELECT * FROM {name}", CONN)
707
- return df, None, None, note
708
- except Exception:
709
- return None, "View created but could not be queried.", None, note
710
  if low.startswith("create table"):
711
  tbl = created_tbl
712
  if not tbl:
713
  m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
714
  tbl = m.group(2) if m else None
715
  if tbl:
716
- try:
717
- df = pd.read_sql_query(f"SELECT * FROM {tbl}", CONN)
718
- return df, None, None, note
719
- except Exception:
720
- return None, "Table created but could not be queried.", None, note
721
  return pd.DataFrame(), None, None, note
722
  except Exception as e:
723
  msg = str(e)
724
- if "no such table" in msg.lower():
725
- return None, f"{msg}. Check table names for this randomized domain.", None, note
726
- if "no such column" in msg.lower():
727
- return None, f"{msg}. Use correct column names or prefixes (alias.column).", None, note
728
- if "ambiguous column name" in msg.lower():
729
- return None, f"{msg}. Qualify the column with a table alias.", None, note
730
  if "misuse of aggregate" in msg.lower() or "aggregate functions are not allowed in" in msg.lower():
731
  return None, f"{msg}. You might need a GROUP BY for non-aggregated columns.", None, note
732
  if "near \"into\"" in msg.lower() and "syntax error" in msg.lower():
733
  return None, "SQLite doesn’t support `SELECT ... INTO`. I can rewrite it automatically—try again.", None, note
734
  if "syntax error" in msg.lower():
735
- return None, f"Syntax error. Check commas, keywords, and parentheses. Raw error: {msg}", None, note
736
  return None, f"SQL error: {msg}", None, note
737
 
738
  def answer_df(answer_sql: List[str]) -> Optional[pd.DataFrame]:
739
  for sql in answer_sql:
740
  try:
741
  low = sql.strip().lower()
742
- if low.startswith("select"):
743
- return run_df(CONN, sql)
744
  if low.startswith("create view"):
745
  m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
746
  view_name = m.group(2) if m else "vw_tmp"
747
  with DB_LOCK:
748
  cur = CONN.cursor()
749
  cur.execute(f"DROP VIEW IF EXISTS {view_name}")
750
- cur.execute(sql)
751
- CONN.commit()
752
  return run_df(CONN, f"SELECT * FROM {view_name}")
753
  if low.startswith("create table"):
754
  m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
755
  tbl = m.group(2) if m else None
756
  with DB_LOCK:
757
  cur = CONN.cursor()
758
- if tbl:
759
- cur.execute(f"DROP TABLE IF EXISTS {tbl}")
760
- cur.execute(sql)
761
- CONN.commit()
762
- if tbl:
763
- return run_df(CONN, f"SELECT * FROM {tbl}")
764
  except Exception:
765
  continue
766
  return None
@@ -771,7 +814,18 @@ def validate_answer(q: Dict[str,Any], student_sql: str, df_student: Optional[pd.
771
  return (df_student is not None), f"**Explanation:** Your statement executed successfully for this task."
772
  if df_student is None:
773
  return False, f"**Explanation:** Expected data result differs."
774
- return results_equal(df_student, df_expected), f"**Explanation:** Compare your result to a canonical solution."
 
 
 
 
 
 
 
 
 
 
 
775
 
776
  def log_attempt(user_id: str, qid: str, category: str, correct: bool, sql_text: str,
777
  time_taken: float, difficulty: int, source: str, notes: str):
@@ -792,6 +846,7 @@ def start_session(name: str, session: dict):
792
  gr.update(value="Please enter your name to begin.", visible=True),
793
  gr.update(visible=False),
794
  gr.update(visible=False),
 
795
  gr.update(visible=False),
796
  pd.DataFrame(),
797
  pd.DataFrame())
@@ -804,21 +859,25 @@ def start_session(name: str, session: dict):
804
 
805
  prompt = q["prompt_md"]
806
  stats = topic_stats(fetch_attempts(CONN, user_id))
 
807
  return (session,
808
  gr.update(value=f"**Question {q['id']}**\n\n{prompt}", visible=True),
809
- gr.update(visible=True), # show SQL input
810
- gr.update(value="", visible=True), # preview block
811
- gr.update(visible=False), # next btn hidden until submit
 
812
  stats,
813
  pd.DataFrame())
814
 
815
  def render_preview(sql_text: str, session: dict):
816
  if not session or "q" not in session:
817
- return gr.update(value="", visible=False)
818
  s = (sql_text or "").strip()
819
  if not s:
820
- return gr.update(value="", visible=False)
821
- return gr.update(value=f"**Preview:**\n\n```sql\n{s}\n```", visible=True)
 
 
822
 
823
  def submit_answer(sql_text: str, session: dict):
824
  if not session or "user_id" not in session or "q" not in session:
@@ -826,7 +885,6 @@ def submit_answer(sql_text: str, session: dict):
826
  user_id = session["user_id"]
827
  q = session["q"]
828
  elapsed = max(0.0, time.time() - session.get("start_ts", time.time()))
829
-
830
  df, err, warn, note = exec_student_sql(sql_text)
831
  details = []
832
  if note: details.append(f"ℹ️ {note}")
@@ -836,35 +894,27 @@ def submit_answer(sql_text: str, session: dict):
836
  log_attempt(user_id, q.get("id","?"), q.get("category","?"), False, sql_text, elapsed, int(q.get("difficulty",1)), "bank", " | ".join([err] + details))
837
  stats = topic_stats(fetch_attempts(CONN, user_id))
838
  return gr.update(value=fb, visible=True), pd.DataFrame(), gr.update(visible=True), stats
839
-
840
  alias_msg = None
841
- if q.get("requires_aliases"):
842
- if not aliases_present(sql_text, q.get("required_aliases", [])):
843
- alias_msg = f"⚠️ This task asked for aliases {q.get('required_aliases', [])}. I didn’t detect them."
844
-
845
  is_correct, explanation = validate_answer(q, sql_text, df)
846
  if warn: details.append(f"⚠️ {warn}")
847
  if alias_msg: details.append(alias_msg)
848
-
849
  prefix = "✅ **Correct!**" if is_correct else "❌ **Not quite.**"
850
  feedback = prefix
851
- if details:
852
- feedback += "\n\n" + "\n".join(details)
853
  feedback += "\n\n" + explanation + "\n\n**One acceptable solution:**\n```sql\n" + q["answer_sql"][0].rstrip(";") + ";\n```"
854
-
855
  log_attempt(user_id, q["id"], q.get("category","?"), bool(is_correct), sql_text, elapsed, int(q.get("difficulty",1)), "bank", " | ".join(details))
856
  stats = topic_stats(fetch_attempts(CONN, user_id))
857
  return gr.update(value=feedback, visible=True), (df if df is not None else pd.DataFrame()), gr.update(visible=True), stats
858
 
859
  def next_question(session: dict):
860
  if not session or "user_id" not in session:
861
- return session, gr.update(value="Start a session first.", visible=True), gr.update(visible=False), gr.update(visible=False)
862
  user_id = session["user_id"]
863
  q = pick_next_question(user_id)
864
- session["qid"] = q["id"]
865
- session["q"] = q
866
- session["start_ts"] = time.time()
867
- return session, gr.update(value=f"**Question {q['id']}**\n\n{q['prompt_md']}", visible=True), gr.update(value="", visible=True), gr.update(visible=False)
868
 
869
  def show_hint(session: dict):
870
  if not session or "q" not in session:
@@ -885,8 +935,7 @@ def show_hint(session: dict):
885
 
886
  def export_progress(user_name: str):
887
  slug = "-".join((user_name or "").lower().split())
888
- if not slug:
889
- return None
890
  user_id = slug[:64]
891
  with DB_LOCK:
892
  df = pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", CONN, params=(user_id,))
@@ -897,25 +946,20 @@ def export_progress(user_name: str):
897
 
898
  def _domain_status_md():
899
  if CURRENT_INFO.get("source","") in ("openai","openai+fallback-questions"):
900
- note = ""
901
- if CURRENT_INFO.get("source") == "openai+fallback-questions":
902
- note = " (LLM domain ok; used fallback questions)"
903
- accepted = CURRENT_INFO.get("accepted",0)
904
- dropped = CURRENT_INFO.get("dropped",0)
905
- return (
906
- f"✅ **Domain via OpenAI** `{CURRENT_INFO.get('model','?')}` → **{CURRENT_SCHEMA.get('domain','?')}**{note}. "
907
- f"Accepted questions: {accepted}, dropped: {dropped}. \n"
908
- f"Tables: {', '.join(t['name'] for t in CURRENT_SCHEMA.get('tables', []))}."
909
- )
910
- err = CURRENT_INFO.get("error","")
911
- err_short = (err[:160] + "…") if len(err) > 160 else err
912
  return f"⚠️ **OpenAI randomization unavailable** → using fallback **{CURRENT_SCHEMA.get('domain','?')}**.\n\n> Reason: {err_short}"
913
 
914
  def regenerate_domain():
915
  global CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO
916
  prev = CURRENT_SCHEMA.get("domain") if CURRENT_SCHEMA else None
917
  CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO = install_schema_and_prepare_questions(prev_domain=prev)
918
- return gr.update(value=_domain_status_md(), visible=True)
 
919
 
920
  def preview_table(tbl: str):
921
  try:
@@ -925,8 +969,7 @@ def preview_table(tbl: str):
925
 
926
  def list_tables_for_preview():
927
  df = run_df(CONN, "SELECT name, type FROM sqlite_master WHERE type in ('table','view') AND name NOT IN ('users','attempts','session_meta') ORDER BY type, name")
928
- if df.empty:
929
- return ["(no tables)"]
930
  return df["name"].tolist()
931
 
932
  # -------------------- UI --------------------
@@ -937,14 +980,11 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
937
  - Uses **OpenAI** (if configured) to randomize a domain (bookstore, retail sales, wholesaler,
938
  sales tax, oil & gas wells, marketing), generate **3–4 tables** and **8–12** questions.
939
  - Practice `SELECT`, `WHERE`, `JOIN` (INNER/LEFT), **aliases**, **views**, and **CTAS / SELECT INTO**.
940
- - The app explains **SQLite quirks** (no RIGHT/FULL JOIN) and flags likely **cartesian products**.
941
-
942
- > Set your `OPENAI_API_KEY` in Space secrets to enable randomization.
943
  """
944
  )
945
 
946
  with gr.Row():
947
- # -------- Left column: controls + quick preview ----------
948
  with gr.Column(scale=1):
949
  name_box = gr.Textbox(label="Your Name", placeholder="e.g., Jordan Alvarez")
950
  start_btn = gr.Button("Start / Resume Session", variant="primary")
@@ -967,12 +1007,11 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
967
  tbl_btn = gr.Button("Preview")
968
  preview_df = gr.Dataframe(value=pd.DataFrame(), interactive=False)
969
 
970
- # -------- Right column: task + feedback + mastery + results ----------
971
  with gr.Column(scale=2):
972
  prompt_md = gr.Markdown(visible=False)
973
  sql_input = gr.Textbox(label="Your SQL", placeholder="Type SQL here (end ; optional).", lines=6, visible=False)
974
-
975
  preview_md = gr.Markdown(visible=False)
 
976
 
977
  with gr.Row():
978
  submit_btn = gr.Button("Run & Submit", variant="primary")
@@ -983,12 +1022,8 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
983
 
984
  gr.Markdown("---")
985
  gr.Markdown("### Your Progress by Category")
986
- mastery_df = gr.Dataframe(
987
- headers=["category","attempts","correct","accuracy"],
988
- col_count=(4, "dynamic"),
989
- row_count=(0, "dynamic"),
990
- interactive=False
991
- )
992
 
993
  gr.Markdown("---")
994
  gr.Markdown("### Result Preview")
@@ -998,12 +1033,12 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
998
  start_btn.click(
999
  start_session,
1000
  inputs=[name_box, session_state],
1001
- outputs=[session_state, prompt_md, sql_input, preview_md, next_btn, mastery_df, result_df],
1002
  )
1003
  sql_input.change(
1004
  render_preview,
1005
  inputs=[sql_input, session_state],
1006
- outputs=[preview_md],
1007
  )
1008
  submit_btn.click(
1009
  submit_answer,
@@ -1013,7 +1048,7 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
1013
  next_btn.click(
1014
  next_question,
1015
  inputs=[session_state],
1016
- outputs=[session_state, prompt_md, sql_input, next_btn],
1017
  )
1018
  hint_btn.click(
1019
  show_hint,
@@ -1028,15 +1063,14 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
1028
  regen_btn.click(
1029
  regenerate_domain,
1030
  inputs=[],
1031
- outputs=[regen_fb],
1032
  )
1033
  tbl_btn.click(
1034
  lambda name: preview_table(name),
1035
  inputs=[tbl_dd],
1036
  outputs=[preview_df]
1037
  )
1038
- # Keep dropdown fresh after regeneration
1039
- regen_btn.click(
1040
  lambda: gr.update(choices=list_tables_for_preview()),
1041
  inputs=[],
1042
  outputs=[tbl_dd]
 
1
  # Adaptive SQL Trainer — Domain Randomized with OpenAI (Gradio + SQLite)
2
+ # - OpenAI randomizes a domain and questions (fallback dataset if unavailable).
3
+ # - 3–4 related tables with seed rows installed in SQLite.
4
+ # - Students practice SELECT, WHERE, JOINs (INNER/LEFT), aliases, views, CTAS/SELECT INTO.
5
+ # - Validator now enforces columns only when the prompt asks for them; otherwise it focuses on rows.
6
+ # - ERD shows all FK edges in light gray and dynamically HIGHLIGHTS edges implied by the student’s JOINs.
 
 
 
 
7
 
8
  import os
9
  import re
 
14
  import threading
15
  from dataclasses import dataclass
16
  from datetime import datetime, timezone
17
+ from typing import List, Dict, Any, Tuple, Optional, Set
18
 
19
  import gradio as gr
20
  import pandas as pd
 
39
  seen = set()
40
  return [m for m in base if m and (m not in seen and not seen.add(m))]
41
 
42
+ # -------------------- ERD drawing (headless) --------------------
43
+ import matplotlib
44
+ matplotlib.use("Agg")
45
+ import matplotlib.pyplot as plt
46
+ from matplotlib.patches import Rectangle
47
+ from io import BytesIO
48
+ from PIL import Image
49
+
50
+ PLOT_FIGSIZE = (7.6, 3.8)
51
+ PLOT_DPI = 120
52
+ PLOT_HEIGHT = 300
53
+
54
def _fig_to_pil(fig) -> Image.Image:
    """Render a Matplotlib figure into an in-memory PNG and return it as a PIL image.

    The figure is closed after rendering so repeated ERD redraws do not leak
    Matplotlib figure objects.
    """
    png_buffer = BytesIO()
    fig.tight_layout()
    fig.savefig(png_buffer, format="png", dpi=PLOT_DPI, bbox_inches="tight")
    plt.close(fig)
    png_buffer.seek(0)
    return Image.open(png_buffer)
61
+
62
def draw_dynamic_erd(
    schema: Dict[str, Any],
    highlight_tables: Optional[Set[str]] = None,
    highlight_edges: Optional[Set[Tuple[str, str]]] = None,
) -> Image.Image:
    """
    Draw the domain's tables with their FK edges.

    All FK edges are rendered in light gray; any table named in
    ``highlight_tables`` gets a bold border, and any FK edge matching
    ``highlight_edges`` is overdrawn bold in blue. Edge matching ignores
    direction: (A, B) and (B, A) are treated as the same edge.
    """
    hot_tables = set(highlight_tables or [])

    # Direction-insensitive edge key so (A,B) and (B,A) compare equal.
    def _undirected(a, b):
        return tuple(sorted([a, b]))

    hot_edges = {_undirected(*edge) for edge in (highlight_edges or set())}

    tables = schema.get("tables", [])
    if not tables:
        fig, ax = plt.subplots(figsize=PLOT_FIGSIZE)
        ax.axis("off")
        ax.text(0.5, 0.5, "No tables to diagram.", ha="center", va="center")
        return _fig_to_pil(fig)

    # Lay the table boxes out in a single horizontal row.
    table_count = len(tables)
    fig, ax = plt.subplots(figsize=PLOT_FIGSIZE)
    ax.axis("off")
    margin = 0.03
    box_w = (1 - margin * (table_count + 1)) / max(table_count, 1)
    box_h = 0.70
    box_y = 0.20

    # Collect every FK edge as (source_table, referenced_table).
    fk_edges = []
    for tbl in tables:
        for fk in tbl.get("fks", []) or []:
            referenced = fk.get("ref_table")
            if referenced:
                fk_edges.append((tbl["name"], referenced))

    # Draw each table box with its column list; remember box geometry for edges.
    boxes: Dict[str, Tuple[float, float, float, float]] = {}
    for idx, tbl in enumerate(tables):
        box_x = margin + idx * (box_w + margin)
        boxes[tbl["name"]] = (box_x, box_y, box_w, box_h)

        # Bold border when the student's SQL references this table.
        border_lw = 2.0 if tbl["name"] in hot_tables else 1.2
        ax.add_patch(Rectangle((box_x, box_y), box_w, box_h, fill=False, lw=border_lw))
        ax.text(box_x + 0.01, box_y + box_h - 0.04, tbl["name"],
                fontsize=10, ha="left", va="top", weight="bold")

        row_y = box_y + box_h - 0.09
        pk_cols = set(tbl.get("pk", []) or [])

        # Map column name -> list of (referenced table, referenced column).
        fk_by_col: Dict[str, List[Tuple[str, str]]] = {}
        for fk in tbl.get("fks", []) or []:
            ref_tbl = fk.get("ref_table", "")
            for col_name, ref_col in zip(fk.get("columns", []) or [], fk.get("ref_columns", []) or []):
                fk_by_col.setdefault(col_name, []).append((ref_tbl, ref_col))

        for col in tbl.get("columns", []):
            col_name = col.get("name", "")
            suffix = ""
            if col_name in pk_cols:
                suffix = " (PK)"
            if col_name in fk_by_col:
                ref = fk_by_col[col_name][0]
                if suffix:
                    # Fold the FK note into the existing "(PK)" annotation.
                    suffix = suffix.replace(")", f", FK→{ref[0]}.{ref[1]})")
                else:
                    suffix = f" (FK→{ref[0]}.{ref[1]})"
            ax.text(box_x + 0.016, row_y, f"{col_name}{suffix}", fontsize=9, ha="left", va="top")
            row_y -= 0.055

    # Base layer: every FK edge in light gray.
    for src, dst in fk_edges:
        if src not in boxes or dst not in boxes:
            continue
        sx, sy, sw, _sh = boxes[src]
        dx, dy, dw, dh = boxes[dst]
        ax.annotate("",
                    xy=(dx + dw / 2.0, dy + dh),
                    xytext=(sx + sw / 2.0, sy),
                    arrowprops=dict(arrowstyle="->", lw=1.0, color="#cccccc"))

    # Overlay: edges implied by the student's JOINs, bold and blue.
    for src, dst in fk_edges:
        if _undirected(src, dst) in hot_edges:
            sx, sy, sw, _sh = boxes[src]
            dx, dy, dw, dh = boxes[dst]
            ax.annotate("",
                        xy=(dx + dw / 2.0, dy + dh),
                        xytext=(sx + sw / 2.0, sy),
                        arrowprops=dict(arrowstyle="->", lw=2.6, color="#2b6cb0"))

    ax.text(0.5, 0.06, f"Domain: {schema.get('domain','unknown')}", fontsize=9, ha="center")
    return _fig_to_pil(fig)
152
+
153
# Parse JOINs from SQL and turn them into tables/edges to highlight on the ERD.
#
# BUG FIX: the optional alias capture must not swallow SQL keywords. Previously
# "FROM orders JOIN customers ON ..." matched "join" as the alias of `orders`,
# consuming the JOIN keyword so `customers` was never detected and non-aliased
# joins were never highlighted. A negative lookahead guards the alias group.
_ALIAS_STOPWORDS = (
    r"(?!(?:on|using|join|inner|left|right|full|outer|cross|natural|"
    r"where|group|order|limit|having|union|as)\b)"
)
JOIN_TBL_RE = re.compile(
    r"\b(?:from|join)\s+([a-z_]\w*)"
    r"(?:\s+(?:as\s+)?" + _ALIAS_STOPWORDS + r"([a-z_]\w*))?",
    re.IGNORECASE,
)
EQ_ON_RE = re.compile(r"([a-z_]\w*)\.[a-z_]\w*\s*=\s*([a-z_]\w*)\.[a-z_]\w*", re.IGNORECASE)
USING_RE = re.compile(r"\bjoin\s+([a-z_]\w*)(?:\s+(?:as\s+)?([a-z_]\w*))?\s+using\s*\(", re.IGNORECASE)

def sql_highlights(sql: str, schema: Dict[str, Any]) -> Tuple[Set[str], Set[Tuple[str, str]]]:
    """
    Infer (highlight_tables, highlight_edges) from the student's SQL.

    - Tables: every name appearing after FROM or JOIN (aliases resolved back
      to their table names).
    - Edges: pairs inferred from `a.col = b.col` equality predicates, plus a
      heuristic for `JOIN ... USING (...)` that links consecutive join items.
    Only names that actually exist in ``schema`` are returned.
    """
    if not sql:
        return set(), set()

    # Collapse all whitespace runs so the regexes see a single-line statement.
    low = " ".join(sql.strip().split())

    # Map alias -> table and record the FROM/JOIN order.
    alias_to_table: Dict[str, str] = {}
    join_order: List[str] = []
    for m in JOIN_TBL_RE.finditer(low):
        table = m.group(1)
        alias = m.group(2) or table
        alias_to_table[alias] = table
        join_order.append(alias)

    # Edges from explicit `a.col = b.col` predicates (ON or WHERE clauses).
    edges: Set[Tuple[str, str]] = set()
    for a1, a2 in EQ_ON_RE.findall(low):
        t1 = alias_to_table.get(a1, a1)
        t2 = alias_to_table.get(a2, a2)
        if t1 != t2:
            edges.add((t1, t2))

    # Heuristic for USING(): connect each JOIN item with its predecessor.
    if USING_RE.search(low) and len(join_order) >= 2:
        for i in range(1, len(join_order)):
            t_left = alias_to_table.get(join_order[i - 1], join_order[i - 1])
            t_right = alias_to_table.get(join_order[i], join_order[i])
            if t_left != t_right:
                edges.add((t_left, t_right))

    # Map aliases back to table names for the highlight set.
    used_tables = {alias_to_table.get(a, a) for a in join_order}

    # Keep only names that exist in the installed schema.
    schema_tables = {t["name"] for t in schema.get("tables", [])}
    edges = {(a, b) for (a, b) in edges if a in schema_tables and b in schema_tables}
    used_tables = {t for t in used_tables if t in schema_tables}

    return used_tables, edges
203
+
204
+ # -------------------- SQLite + locking --------------------
205
  DB_DIR = "/data" if os.path.exists("/data") else "."
206
  DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
207
  EXPORT_DIR = "."
 
209
  random.seed(RANDOM_SEED)
210
  SYS_RAND = random.SystemRandom()
211
 
 
212
  DB_LOCK = threading.RLock()
213
 
214
  def connect_db():
 
 
 
 
 
215
  con = sqlite3.connect(DB_PATH, check_same_thread=False)
216
  con.execute("PRAGMA journal_mode=WAL;")
217
  con.execute("PRAGMA synchronous=NORMAL;")
 
256
 
257
  init_progress_tables(CONN)
258
 
259
+ # -------------------- Fallback dataset & questions --------------------
260
  FALLBACK_SCHEMA = {
261
  "domain": "bookstore",
262
  "tables": [
 
369
  "requires_aliases":False,"required_aliases":[]},
370
  ]
371
 
372
+ # --------------- OpenAI prompts + parsing helpers ---------------
373
+ DOMAIN_AND_QUESTIONS_SCHEMA = {"required": ["domain", "tables", "questions"]}
 
 
374
 
375
  def _domain_prompt(prev_domain: Optional[str]) -> str:
376
  extra = f" Avoid using the previous domain '{prev_domain}' if possible." if prev_domain else ""
 
382
  - One domain chosen from: bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.
383
  - Tables: SQLite-friendly. Use snake_case. Each table has: name, pk (list of column names),
384
  columns (list of {{name,type}}), fks (list of {{columns,ref_table,ref_columns}}), rows (8–15 small seed rows).
385
+ - Questions: categories among "SELECT *", "SELECT columns", "WHERE", "Aliases",
386
  "JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO".
387
  Include at least one LEFT JOIN, one VIEW creation, one CTAS or SELECT INTO.
388
  Provide 1–3 'answer_sql' strings per question. Prefer SQLite-compatible SQL. Do NOT use RIGHT/FULL OUTER JOIN.
389
  For 1–2 questions, set requires_aliases=true and list required_aliases.
390
 
391
+ Example top-level keys:
392
+ {{"domain":"retail sales","tables":[...],"questions":[...]}}
 
 
 
 
393
  """
394
 
395
  def _loose_json_parse(s: str) -> Optional[dict]:
 
396
  try:
397
  return json.loads(s)
398
  except Exception:
399
  pass
400
+ start = s.find("{"); end = s.rfind("}")
 
401
  if start != -1 and end != -1 and end > start:
402
  try:
403
  return json.loads(s[start:end+1])
 
405
  return None
406
  return None
407
 
408
+ # Canonicalization of LLM output (questions & tables)
409
  _SQL_FENCE = re.compile(r"```sql(.*?)```", re.IGNORECASE | re.DOTALL)
410
  _CODE_FENCE = re.compile(r"```(.*?)```", re.DOTALL)
411
 
412
  def _strip_code_fences(txt: str) -> str:
413
+ if txt is None: return ""
 
414
  m = _SQL_FENCE.findall(txt)
415
+ if m: return "\n".join([x.strip() for x in m if x.strip()])
 
416
  m2 = _CODE_FENCE.findall(txt)
417
+ if m2: return "\n".join([x.strip() for x in m2 if x.strip()])
 
418
  return txt.strip()
419
 
420
  def _as_list_of_sql(val) -> List[str]:
421
+ if val is None: return []
 
422
  if isinstance(val, str):
423
  s = _strip_code_fences(val)
424
  parts = [p.strip() for p in s.split("\n") if p.strip()]
 
425
  return parts or ([s] if s else [])
426
  if isinstance(val, list):
427
  out = []
428
  for v in val:
429
  if isinstance(v, str):
430
  s = _strip_code_fences(v)
431
+ if s: out.append(s)
 
432
  return out
433
  return []
434
 
435
  def _canon_question(q: Dict[str, Any]) -> Optional[Dict[str, Any]]:
436
+ if not isinstance(q, dict): return None
 
 
 
437
  cat = q.get("category") or q.get("type") or q.get("topic")
438
  prompt = q.get("prompt_md") or q.get("prompt") or q.get("question") or q.get("text")
439
  answer_sql = q.get("answer_sql") or q.get("answers") or q.get("solutions") or q.get("sql")
440
  diff = q.get("difficulty") or 1
441
  req_alias = bool(q.get("requires_aliases", False))
442
  req_aliases = q.get("required_aliases") or []
 
443
  cat = str(cat).strip() if cat is not None else ""
444
  prompt = str(prompt).strip() if prompt is not None else ""
445
  answers = _as_list_of_sql(answer_sql)
446
+ if not cat or not prompt or not answers: return None
 
 
 
 
447
  known = {
448
  "SELECT *","SELECT columns","WHERE","Aliases",
449
  "JOIN (INNER)","JOIN (LEFT)","Aggregation","VIEW","CTAS / SELECT INTO"
450
  }
451
  if cat not in known:
 
452
  low = cat.lower()
453
+ if "select *" in low: cat = "SELECT *"
454
+ elif "columns" in low: cat = "SELECT columns"
455
+ elif "where" in low or "filter" in low: cat = "WHERE"
456
+ elif "alias" in low: cat = "Aliases"
457
+ elif "left" in low: cat = "JOIN (LEFT)"
458
+ elif "inner" in low or "join" in low: cat = "JOIN (INNER)"
459
+ elif "agg" in low or "group" in low: cat = "Aggregation"
460
+ elif "view" in low: cat = "VIEW"
461
+ elif "into" in low or "ctas" in low: cat = "CTAS / SELECT INTO"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
  if isinstance(req_aliases, str):
463
  req_aliases = [a.strip() for a in re.split(r"[,\s]+", req_aliases) if a.strip()]
464
  elif not isinstance(req_aliases, list):
465
  req_aliases = []
 
466
  return {
467
  "id": str(q.get("id") or f"LLM_{int(time.time()*1000)}_{random.randint(100,999)}"),
468
  "category": cat,
 
476
  def _canon_tables(tables: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
477
  out = []
478
  for t in (tables or []):
479
+ if not isinstance(t, dict): continue
 
480
  name = str(t.get("name","")).strip()
481
+ if not name: continue
 
482
  cols = t.get("columns") or []
483
  good_cols = []
484
  for c in cols:
485
+ if not isinstance(c, dict): continue
 
486
  cname = str(c.get("name","")).strip()
487
  ctype = str(c.get("type","TEXT")).strip() or "TEXT"
488
+ if cname: good_cols.append({"name": cname, "type": ctype})
489
+ if not good_cols: continue
 
 
490
  pk = t.get("pk") or []
491
  if isinstance(pk, str): pk = [pk]
492
  fks = t.get("fks") or []
 
500
  })
501
  return out
502
 
 
503
  def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optional[Dict[str,Any]], Optional[str], Optional[str], Dict[str,int]]:
 
 
 
 
504
  if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
505
  return None, "OpenAI client not available or OPENAI_API_KEY missing.", None, {"accepted_questions":0,"dropped_questions":0}
 
506
  errors = []
507
  prompt = _domain_prompt(prev_domain)
 
508
  for model in _candidate_models():
509
  try:
 
510
  try:
511
  chat = _client.chat.completions.create(
512
  model=model,
513
  messages=[{"role":"user","content": prompt}],
514
  temperature=0.6,
515
+ response_format={"type":"json_object"}
516
  )
517
  data_text = chat.choices[0].message.content
518
  except TypeError:
 
519
  chat = _client.chat.completions.create(
520
  model=model,
521
  messages=[{"role":"system","content":"Return ONLY a JSON object. No markdown."},
 
523
  temperature=0.6
524
  )
525
  data_text = chat.choices[0].message.content
 
526
  obj_raw = _loose_json_parse(data_text or "")
527
  if not obj_raw:
528
  raise RuntimeError("Could not parse JSON from model output.")
 
 
529
  for k in DOMAIN_AND_QUESTIONS_SCHEMA["required"]:
530
  if k not in obj_raw:
531
  raise RuntimeError(f"Missing key '{k}'")
 
 
532
  tables = _canon_tables(obj_raw.get("tables", []))
533
+ if not tables: raise RuntimeError("No usable tables in LLM output.")
 
534
  obj_raw["tables"] = tables
 
 
535
  dropped = 0
536
  clean_qs = []
537
  for q in obj_raw.get("questions", []):
538
  cq = _canon_question(q)
539
+ if not cq: dropped += 1; continue
 
 
 
540
  answers = [a for a in cq["answer_sql"] if " right join " not in a.lower() and " full " not in a.lower()]
541
+ if not answers: dropped += 1; continue
 
 
542
  cq["answer_sql"] = answers
543
  clean_qs.append(cq)
 
544
  if not clean_qs:
545
  raise RuntimeError("No usable questions after canonicalization.")
546
  stats = {"accepted_questions": len(clean_qs), "dropped_questions": dropped}
 
547
  obj_raw["questions"] = clean_qs
548
  return obj_raw, None, model, stats
 
549
  except Exception as e:
550
  errors.append(f"{model}: {e}")
551
  continue
 
552
  return None, "; ".join(errors) if errors else "Unknown LLM error.", None, {"accepted_questions":0,"dropped_questions":0}
553
 
554
+ # -------------------- Install schema & prepare questions --------------------
555
  def drop_existing_domain_tables(con: sqlite3.Connection, keep_internal=True):
556
  with DB_LOCK:
557
  cur = con.cursor()
 
570
  drop_existing_domain_tables(con, keep_internal=True)
571
  with DB_LOCK:
572
  cur = con.cursor()
 
573
  for t in schema.get("tables", []):
574
  cols_sql = []
575
  pk = t.get("pk", [])
576
  for c in t.get("columns", []):
577
+ cols_sql.append(f"{c['name']} {c.get('type','TEXT')}")
578
+ if pk: cols_sql.append(f"PRIMARY KEY ({', '.join(pk)})")
 
 
 
579
  create_sql = f"CREATE TABLE {t['name']} ({', '.join(cols_sql)})"
580
  cur.execute(create_sql)
 
581
  for t in schema.get("tables", []):
582
+ if not t.get("rows"): continue
 
583
  cols = [c["name"] for c in t.get("columns", [])]
584
  qmarks = ",".join(["?"]*len(cols))
585
  insert_sql = f"INSERT INTO {t['name']} ({', '.join(cols)}) VALUES ({qmarks})"
586
  for r in t["rows"]:
587
+ if isinstance(r, dict): vals = [r.get(col, None) for col in cols]
 
588
  elif isinstance(r, (list, tuple)):
589
+ vals = list(r) + [None]*(len(cols)-len(r)); vals = vals[:len(cols)]
590
+ else: continue
 
 
591
  cur.execute(insert_sql, vals)
592
  con.commit()
 
593
  cur.execute("INSERT OR REPLACE INTO session_meta(id, domain, schema_json) VALUES (1, ?, ?)",
594
  (schema.get("domain","unknown"), json.dumps(schema)))
595
  con.commit()
596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
597
  def bootstrap_domain_with_llm_or_fallback(prev_domain: Optional[str]):
598
  obj, err, model_used, stats = llm_generate_domain_and_questions(prev_domain)
599
  if obj is None:
 
603
  def install_schema_and_prepare_questions(prev_domain: Optional[str]):
604
  schema, questions, info = bootstrap_domain_with_llm_or_fallback(prev_domain)
605
  install_schema(CONN, schema)
 
606
  if not questions:
607
  questions = FALLBACK_QUESTIONS
608
+ info = {"source":"openai+fallback-questions","model":info.get("model"),
609
+ "error":"LLM returned 0 usable questions; used fallback bank.","accepted":0,"dropped":0}
610
  return schema, questions, info
611
 
612
+ # -------------------- Session globals --------------------
613
  CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO = install_schema_and_prepare_questions(prev_domain=None)
614
 
615
  # -------------------- Progress + mastery --------------------
 
644
  return pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", con, params=(user_id,))
645
 
646
  def pick_next_question(user_id: str) -> Dict[str,Any]:
 
647
  pool = CURRENT_QS if CURRENT_QS else FALLBACK_QUESTIONS
648
  df = fetch_attempts(CONN, user_id)
649
  stats = topic_stats(df)
 
652
  cands = [q for q in pool if str(q.get("category","")).strip() == weakest] or pool
653
  return dict(random.choice(cands))
654
 
655
+ # -------------------- SQL execution & grading --------------------
656
def run_df(con: sqlite3.Connection, sql: str) -> pd.DataFrame:
    """Execute *sql* against *con* under the global DB lock; return a DataFrame."""
    with DB_LOCK:
        frame = pd.read_sql_query(sql, con)
    return frame
659
+
660
def rewrite_select_into(sql: str) -> Tuple[str, Optional[str]]:
    """
    Translate T-SQL style ``SELECT ... INTO tbl FROM ...`` into SQLite's
    ``CREATE TABLE tbl AS SELECT ... FROM ...``.

    Returns (possibly-rewritten SQL, created table name or None). A statement
    without the SELECT...INTO...FROM shape is returned unchanged with None.
    """
    stripped = sql.strip().strip(";")
    has_into = re.search(r"\bselect\b.+\binto\b.+\bfrom\b", stripped,
                         flags=re.IGNORECASE | re.DOTALL)
    if has_into:
        parts = re.match(r"(?is)^\s*select\s+(.*?)\s+into\s+([A-Za-z_][A-Za-z0-9_]*)\s+from\s+(.*)$",
                         stripped)
        if parts:
            select_list, table_name, tail = parts.groups()
            return f"CREATE TABLE {table_name} AS SELECT {select_list} FROM {tail}", table_name
    return sql, None
668
+
669
def detect_unsupported_joins(sql: str) -> Optional[str]:
    """
    Return a dialect warning for constructs this trainer's SQLite setup does
    not accept, or None when the statement looks fine.

    Uses word-boundary regexes instead of the old padded-substring checks
    (" right join ") so a match also fires at the start/end of the statement,
    across newlines or repeated spaces, and on "RIGHT OUTER JOIN".
    """
    low = sql.lower()
    if re.search(r"\bright\s+(?:outer\s+)?join\b", low):
        return "SQLite does not support RIGHT JOIN. Use LEFT JOIN in the opposite direction."
    if re.search(r"\bfull\s+(?:outer\s+)?join\b", low):
        return "SQLite does not support FULL OUTER JOIN. Use LEFT JOIN plus UNION."
    if re.search(r"\bilike\b", low):
        return "SQLite has no ILIKE. Use LOWER(col) LIKE LOWER('%pattern%')."
    return None
678
+
679
def detect_cartesian(con: sqlite3.Connection, sql: str, df_result: pd.DataFrame) -> Optional[str]:
    """
    Heuristically warn when a result looks like an unintended cartesian product.

    Triggers on explicit CROSS JOIN, comma-separated FROM lists, or a JOIN with
    no ON/USING/NATURAL clause; confirms by comparing the result's row count to
    the product of the two tables' row counts when possible.
    """
    low = sql.lower()
    if " cross join " in low:
        return "Query uses CROSS JOIN (cartesian product). Ensure this is intended."

    comma_from = re.search(r"\bfrom\b\s+([a-z_]\w*)\s*,\s*([a-z_]\w*)", low)
    missing_on = (" join " in low) and (" on " not in low) and (" using " not in low) and (" natural " not in low)
    if not (comma_from or missing_on):
        return None

    try:
        with DB_LOCK:
            cur = con.cursor()
            if comma_from:
                t1, t2 = comma_from.groups()
            else:
                first = re.search(r"\bfrom\b\s+([a-z_]\w*)", low)
                joined = re.search(r"\bjoin\b\s+([a-z_]\w*)", low)
                if not first or not joined:
                    return "Possible cartesian product: no join condition detected."
                t1, t2 = first.group(1), joined.group(1)
            cur.execute(f"SELECT COUNT(*) FROM {t1}")
            n1 = cur.fetchone()[0]
            cur.execute(f"SELECT COUNT(*) FROM {t2}")
            n2 = cur.fetchone()[0]
        prod = n1 * n2
        if len(df_result) == prod and prod > 0:
            return f"Result row count equals {n1}×{n2}={prod}. Likely cartesian product (missing join)."
    except Exception:
        return "Possible cartesian product: no join condition detected."
    return None
701
+
702
# Column enforcement policy — compare projections only when the prompt asks for them.
def should_enforce_columns(q: Dict[str, Any]) -> bool:
    """
    Decide whether grading should compare the selected columns, not just rows.

    True when the question category implies a specific projection, or when the
    prompt itself names columns to show (backticked identifiers, a
    "(show ...)" parenthetical, or show/return/display/select phrasing near a
    column-like word).
    """
    category = (q.get("category") or "").strip()
    if category in ("SELECT columns", "Aggregation", "VIEW", "CTAS / SELECT INTO"):
        return True

    raw_prompt = q.get("prompt_md") or ""
    lowered = raw_prompt.lower()

    # Backticked names signal an explicit projection.
    if re.search(r"`[^`]+`", raw_prompt):
        return True
    # e.g., "(show title, price)"
    if re.search(r"\((?:show|return|display)[^)]+\)", lowered):
        return True
    if re.search(r"\b(show|return|display|select)\b[^.]{0,100}\b(columns?|fields?|name|title|price)\b", lowered):
        return True
    return False
716
+
717
+ def _normalize_columns(df: pd.DataFrame) -> pd.DataFrame:
718
+ out = df.copy()
719
+ out.columns = [str(c).strip().lower() for c in out.columns]
720
+ return out
721
+
722
+ def results_equal_or_superset(df_student: pd.DataFrame, df_expected: pd.DataFrame) -> Tuple[bool, Optional[str]]:
723
+ a = _normalize_columns(df_student); b = _normalize_columns(df_expected)
724
+ if set(a.columns) == set(b.columns):
725
+ a2 = a[sorted(a.columns)].sort_values(sorted(a.columns)).reset_index(drop=True)
726
+ b2 = b[sorted(b.columns)].sort_values(sorted(b.columns)).reset_index(drop=True)
727
+ return (a2.equals(b2), None)
728
+ if set(b.columns).issubset(set(a.columns)):
729
+ a_proj = a[b.columns]
730
+ a2 = a_proj.sort_values(list(b.columns)).reset_index(drop=True)
731
+ b2 = b.sort_values(list(b.columns)).reset_index(drop=True)
732
+ if a2.equals(b2):
733
+ return True, "extra_columns"
734
+ return False, None
735
+
736
def results_equal_rowcount_only(df_student: pd.DataFrame, df_expected: pd.DataFrame) -> bool:
    """Projection unspecified for this task: grade by row count alone."""
    return len(df_student.index) == len(df_expected.index)
739
+
740
  def exec_student_sql(sql_text: str) -> Tuple[Optional[pd.DataFrame], Optional[str], Optional[str], Optional[str]]:
741
  if not sql_text or not sql_text.strip():
742
  return None, "Enter a SQL statement.", None, None
 
743
  sql_raw = sql_text.strip().rstrip(";")
744
  sql_rew, created_tbl = rewrite_select_into(sql_raw)
745
+ note = "Rewrote `SELECT ... INTO` to `CREATE TABLE ... AS SELECT ...` for SQLite." if sql_rew != sql_raw else None
 
 
 
746
  unsup = detect_unsupported_joins(sql_rew)
747
+ if unsup: return None, unsup, None, note
 
 
748
  try:
749
  low = sql_rew.lower()
750
  if low.startswith("select"):
 
754
  else:
755
  with DB_LOCK:
756
  cur = CONN.cursor()
757
+ cur.execute(sql_rew); CONN.commit()
 
 
758
  if low.startswith("create view"):
759
  m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+(select.*)$", low)
760
  name = m.group(2) if m else None
761
  if name:
762
+ try: return pd.read_sql_query(f"SELECT * FROM {name}", CONN), None, None, note
763
+ except Exception: return None, "View created but could not be queried.", None, note
 
 
 
764
  if low.startswith("create table"):
765
  tbl = created_tbl
766
  if not tbl:
767
  m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
768
  tbl = m.group(2) if m else None
769
  if tbl:
770
+ try: return pd.read_sql_query(f"SELECT * FROM {tbl}", CONN), None, None, note
771
+ except Exception: return None, "Table created but could not be queried.", None, note
 
 
 
772
  return pd.DataFrame(), None, None, note
773
  except Exception as e:
774
  msg = str(e)
775
+ if "no such table" in msg.lower(): return None, f"{msg}. Check table names for this randomized domain.", None, note
776
+ if "no such column" in msg.lower(): return None, f"{msg}. Use correct column names or prefixes (alias.column).", None, note
777
+ if "ambiguous column name" in msg.lower(): return None, f"{msg}. Qualify the column with a table alias.", None, note
 
 
 
778
  if "misuse of aggregate" in msg.lower() or "aggregate functions are not allowed in" in msg.lower():
779
  return None, f"{msg}. You might need a GROUP BY for non-aggregated columns.", None, note
780
  if "near \"into\"" in msg.lower() and "syntax error" in msg.lower():
781
  return None, "SQLite doesn’t support `SELECT ... INTO`. I can rewrite it automatically—try again.", None, note
782
  if "syntax error" in msg.lower():
783
+ return None, f"Syntax error. Check commas, keywords, parentheses. Raw error: {msg}", None, note
784
  return None, f"SQL error: {msg}", None, note
785
 
786
  def answer_df(answer_sql: List[str]) -> Optional[pd.DataFrame]:
787
  for sql in answer_sql:
788
  try:
789
  low = sql.strip().lower()
790
+ if low.startswith("select"): return run_df(CONN, sql)
 
791
  if low.startswith("create view"):
792
  m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
793
  view_name = m.group(2) if m else "vw_tmp"
794
  with DB_LOCK:
795
  cur = CONN.cursor()
796
  cur.execute(f"DROP VIEW IF EXISTS {view_name}")
797
+ cur.execute(sql); CONN.commit()
 
798
  return run_df(CONN, f"SELECT * FROM {view_name}")
799
  if low.startswith("create table"):
800
  m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
801
  tbl = m.group(2) if m else None
802
  with DB_LOCK:
803
  cur = CONN.cursor()
804
+ if tbl: cur.execute(f"DROP TABLE IF EXISTS {tbl}")
805
+ cur.execute(sql); CONN.commit()
806
+ if tbl: return run_df(CONN, f"SELECT * FROM {tbl}")
 
 
 
807
  except Exception:
808
  continue
809
  return None
 
814
  return (df_student is not None), f"**Explanation:** Your statement executed successfully for this task."
815
  if df_student is None:
816
  return False, f"**Explanation:** Expected data result differs."
817
+ if should_enforce_columns(q):
818
+ ok, note = results_equal_or_superset(df_student, df_expected)
819
+ if ok and note == "extra_columns":
820
+ return True, "**Note:** You returned extra columns. The rows match; try selecting only the requested columns next time."
821
+ if ok:
822
+ return True, "**Explanation:** Your result matches a canonical solution."
823
+ return False, f"**Explanation:** Compare your result to a canonical solution."
824
+ else:
825
+ ok = results_equal_rowcount_only(df_student, df_expected)
826
+ if ok:
827
+ return True, "**Explanation:** Columns weren’t specified for this task; row count matches the canonical answer."
828
+ return False, "**Explanation:** For this task we compared row counts (projection not enforced) and they didn’t match."
829
 
830
  def log_attempt(user_id: str, qid: str, category: str, correct: bool, sql_text: str,
831
  time_taken: float, difficulty: int, source: str, notes: str):
 
846
  gr.update(value="Please enter your name to begin.", visible=True),
847
  gr.update(visible=False),
848
  gr.update(visible=False),
849
+ draw_dynamic_erd(CURRENT_SCHEMA),
850
  gr.update(visible=False),
851
  pd.DataFrame(),
852
  pd.DataFrame())
 
859
 
860
  prompt = q["prompt_md"]
861
  stats = topic_stats(fetch_attempts(CONN, user_id))
862
+ erd = draw_dynamic_erd(CURRENT_SCHEMA)
863
  return (session,
864
  gr.update(value=f"**Question {q['id']}**\n\n{prompt}", visible=True),
865
+ gr.update(visible=True),
866
+ gr.update(value="", visible=True),
867
+ erd,
868
+ gr.update(visible=False),
869
  stats,
870
  pd.DataFrame())
871
 
872
  def render_preview(sql_text: str, session: dict):
873
  if not session or "q" not in session:
874
+ return gr.update(value="", visible=False), draw_dynamic_erd(CURRENT_SCHEMA)
875
  s = (sql_text or "").strip()
876
  if not s:
877
+ return gr.update(value="", visible=False), draw_dynamic_erd(CURRENT_SCHEMA)
878
+ hi_tables, hi_edges = sql_highlights(s, CURRENT_SCHEMA)
879
+ erd = draw_dynamic_erd(CURRENT_SCHEMA, highlight_tables=hi_tables, highlight_edges=hi_edges)
880
+ return gr.update(value=f"**Preview:**\n\n```sql\n{s}\n```", visible=True), erd
881
 
882
  def submit_answer(sql_text: str, session: dict):
883
  if not session or "user_id" not in session or "q" not in session:
 
885
  user_id = session["user_id"]
886
  q = session["q"]
887
  elapsed = max(0.0, time.time() - session.get("start_ts", time.time()))
 
888
  df, err, warn, note = exec_student_sql(sql_text)
889
  details = []
890
  if note: details.append(f"ℹ️ {note}")
 
894
  log_attempt(user_id, q.get("id","?"), q.get("category","?"), False, sql_text, elapsed, int(q.get("difficulty",1)), "bank", " | ".join([err] + details))
895
  stats = topic_stats(fetch_attempts(CONN, user_id))
896
  return gr.update(value=fb, visible=True), pd.DataFrame(), gr.update(visible=True), stats
 
897
  alias_msg = None
898
+ if q.get("requires_aliases") and not aliases_present(sql_text, q.get("required_aliases", [])):
899
+ alias_msg = f"⚠️ This task asked for aliases {q.get('required_aliases', [])}. I didn’t detect them."
 
 
900
  is_correct, explanation = validate_answer(q, sql_text, df)
901
  if warn: details.append(f"⚠️ {warn}")
902
  if alias_msg: details.append(alias_msg)
 
903
  prefix = "✅ **Correct!**" if is_correct else "❌ **Not quite.**"
904
  feedback = prefix
905
+ if details: feedback += "\n\n" + "\n".join(details)
 
906
  feedback += "\n\n" + explanation + "\n\n**One acceptable solution:**\n```sql\n" + q["answer_sql"][0].rstrip(";") + ";\n```"
 
907
  log_attempt(user_id, q["id"], q.get("category","?"), bool(is_correct), sql_text, elapsed, int(q.get("difficulty",1)), "bank", " | ".join(details))
908
  stats = topic_stats(fetch_attempts(CONN, user_id))
909
  return gr.update(value=feedback, visible=True), (df if df is not None else pd.DataFrame()), gr.update(visible=True), stats
910
 
911
  def next_question(session: dict):
912
  if not session or "user_id" not in session:
913
+ return session, gr.update(value="Start a session first.", visible=True), gr.update(visible=False), draw_dynamic_erd(CURRENT_SCHEMA), gr.update(visible=False)
914
  user_id = session["user_id"]
915
  q = pick_next_question(user_id)
916
+ session["qid"] = q["id"]; session["q"] = q; session["start_ts"] = time.time()
917
+ return session, gr.update(value=f"**Question {q['id']}**\n\n{q['prompt_md']}", visible=True), gr.update(value="", visible=True), draw_dynamic_erd(CURRENT_SCHEMA), gr.update(visible=False)
 
 
918
 
919
  def show_hint(session: dict):
920
  if not session or "q" not in session:
 
935
 
936
  def export_progress(user_name: str):
937
  slug = "-".join((user_name or "").lower().split())
938
+ if not slug: return None
 
939
  user_id = slug[:64]
940
  with DB_LOCK:
941
  df = pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", CONN, params=(user_id,))
 
946
 
947
def _domain_status_md():
    """Build a Markdown status line describing how the current domain was produced.

    Success path: domain came from OpenAI (possibly with fallback questions) —
    reports model, domain name, accepted/dropped question counts, and table names.
    Failure path: reports the (truncated) error and the fallback domain in use.
    """
    source = CURRENT_INFO.get("source", "")
    if source not in ("openai", "openai+fallback-questions"):
        # OpenAI randomization failed — surface a shortened reason for the user.
        err = CURRENT_INFO.get("error", "")
        err_short = err if len(err) <= 160 else err[:160] + "…"
        return f"⚠️ **OpenAI randomization unavailable** → using fallback **{CURRENT_SCHEMA.get('domain','?')}**.\n\n> Reason: {err_short}"
    note = ""
    if source == "openai+fallback-questions":
        note = " (LLM domain ok; used fallback questions)"
    accepted = CURRENT_INFO.get("accepted", 0)
    dropped = CURRENT_INFO.get("dropped", 0)
    table_names = ", ".join(t["name"] for t in CURRENT_SCHEMA.get("tables", []))
    return (f" **Domain via OpenAI** `{CURRENT_INFO.get('model','?')}` **{CURRENT_SCHEMA.get('domain','?')}**{note}. "
            f"Accepted questions: {accepted}, dropped: {dropped}. \n"
            f"Tables: {table_names}.")
956
 
957
def regenerate_domain():
    """Re-randomize the practice domain, rebuilding schema, questions, and ERD.

    Passes the previous domain name so the generator can avoid repeating it.
    Returns updates for (regen_fb status banner, er_image).
    """
    global CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO
    previous_domain = CURRENT_SCHEMA.get("domain") if CURRENT_SCHEMA else None
    CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO = install_schema_and_prepare_questions(
        prev_domain=previous_domain
    )
    return (
        gr.update(value=_domain_status_md(), visible=True),
        draw_dynamic_erd(CURRENT_SCHEMA),
    )
963
 
964
  def preview_table(tbl: str):
965
  try:
 
969
 
970
def list_tables_for_preview():
    """Return names of user-facing tables/views for the preview dropdown.

    Excludes the trainer's own bookkeeping tables (users, attempts, session_meta).
    Returns a one-element placeholder list when nothing is available.
    """
    query = ("SELECT name, type FROM sqlite_master WHERE type in ('table','view') "
             "AND name NOT IN ('users','attempts','session_meta') ORDER BY type, name")
    df = run_df(CONN, query)
    return ["(no tables)"] if df.empty else df["name"].tolist()
974
 
975
  # -------------------- UI --------------------
 
980
  - Uses **OpenAI** (if configured) to randomize a domain (bookstore, retail sales, wholesaler,
981
  sales tax, oil & gas wells, marketing), generate **3–4 tables** and **8–12** questions.
982
  - Practice `SELECT`, `WHERE`, `JOIN` (INNER/LEFT), **aliases**, **views**, and **CTAS / SELECT INTO**.
983
+ - **ERD highlights your JOINs** as you type; all FK edges remain visible in light gray.
 
 
984
  """
985
  )
986
 
987
  with gr.Row():
 
988
  with gr.Column(scale=1):
989
  name_box = gr.Textbox(label="Your Name", placeholder="e.g., Jordan Alvarez")
990
  start_btn = gr.Button("Start / Resume Session", variant="primary")
 
1007
  tbl_btn = gr.Button("Preview")
1008
  preview_df = gr.Dataframe(value=pd.DataFrame(), interactive=False)
1009
 
 
1010
  with gr.Column(scale=2):
1011
  prompt_md = gr.Markdown(visible=False)
1012
  sql_input = gr.Textbox(label="Your SQL", placeholder="Type SQL here (end ; optional).", lines=6, visible=False)
 
1013
  preview_md = gr.Markdown(visible=False)
1014
+ er_image = gr.Image(label="Entity Diagram", value=draw_dynamic_erd(CURRENT_SCHEMA), height=PLOT_HEIGHT)
1015
 
1016
  with gr.Row():
1017
  submit_btn = gr.Button("Run & Submit", variant="primary")
 
1022
 
1023
  gr.Markdown("---")
1024
  gr.Markdown("### Your Progress by Category")
1025
+ mastery_df = gr.Dataframe(headers=["category","attempts","correct","accuracy"],
1026
+ col_count=(4,"dynamic"), row_count=(0,"dynamic"), interactive=False)
 
 
 
 
1027
 
1028
  gr.Markdown("---")
1029
  gr.Markdown("### Result Preview")
 
1033
  start_btn.click(
1034
  start_session,
1035
  inputs=[name_box, session_state],
1036
+ outputs=[session_state, prompt_md, sql_input, preview_md, er_image, next_btn, mastery_df, result_df],
1037
  )
1038
  sql_input.change(
1039
  render_preview,
1040
  inputs=[sql_input, session_state],
1041
+ outputs=[preview_md, er_image],
1042
  )
1043
  submit_btn.click(
1044
  submit_answer,
 
1048
  next_btn.click(
1049
  next_question,
1050
  inputs=[session_state],
1051
+ outputs=[session_state, prompt_md, sql_input, er_image, next_btn],
1052
  )
1053
  hint_btn.click(
1054
  show_hint,
 
1063
  regen_btn.click(
1064
  regenerate_domain,
1065
  inputs=[],
1066
+ outputs=[regen_fb, er_image],
1067
  )
1068
  tbl_btn.click(
1069
  lambda name: preview_table(name),
1070
  inputs=[tbl_dd],
1071
  outputs=[preview_df]
1072
  )
1073
+ regen_btn.click( # refresh list after regeneration
 
1074
  lambda: gr.update(choices=list_tables_for_preview()),
1075
  inputs=[],
1076
  outputs=[tbl_dd]