jtdearmon committed on
Commit
f642f6e
·
verified ·
1 Parent(s): 5952084

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -219
app.py CHANGED
@@ -5,9 +5,9 @@
5
  # - Generates 8–12 randomized SQL questions with varied phrasings.
6
  # - Validates answers by executing canonical SQL and comparing result sets.
7
  # - Provides tailored feedback (SQLite dialect, cartesian products, aggregates, aliases).
8
- # - Shows data results at the bottom pane for every run (SELECT or preview for VIEW/CTAS).
9
  #
10
- # Hugging Face Spaces: set OPENAI_API_KEY as a secret to enable LLM randomization.
11
 
12
  import os
13
  import re
@@ -16,7 +16,7 @@ import time
16
  import random
17
  import sqlite3
18
  import threading
19
- from dataclasses import dataclass, asdict
20
  from datetime import datetime, timezone
21
  from typing import List, Dict, Any, Tuple, Optional
22
 
@@ -24,17 +24,10 @@ import gradio as gr
24
  import pandas as pd
25
  import numpy as np
26
 
27
- # Matplotlib for ERD drawing (headless)
28
- import matplotlib
29
- matplotlib.use("Agg")
30
- import matplotlib.pyplot as plt
31
- from io import BytesIO
32
- from PIL import Image
33
-
34
  # -------------------- OpenAI (optional) --------------------
35
  USE_RESPONSES_API = True
36
  OPENAI_AVAILABLE = True
37
- MODEL_ID = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
38
  try:
39
  from openai import OpenAI
40
  _client = OpenAI() # requires OPENAI_API_KEY
@@ -42,6 +35,17 @@ except Exception:
42
  OPENAI_AVAILABLE = False
43
  _client = None
44
 
 
 
 
 
 
 
 
 
 
 
 
45
  # -------------------- Global settings --------------------
46
  DB_DIR = "/data" if os.path.exists("/data") else "."
47
  DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
@@ -51,76 +55,14 @@ RANDOM_SEED = int(os.getenv("RANDOM_SEED", "7"))
51
  random.seed(RANDOM_SEED)
52
  SYS_RAND = random.SystemRandom()
53
 
54
- PLOT_FIGSIZE = (6.8, 3.4)
55
- PLOT_DPI = 110
56
- PLOT_HEIGHT = 300
57
-
58
- # -------------------- ERD helpers --------------------
59
- def _to_pil(fig) -> Image.Image:
60
- buf = BytesIO()
61
- fig.tight_layout()
62
- fig.savefig(buf, format="png", dpi=PLOT_DPI, bbox_inches="tight")
63
- plt.close(fig)
64
- buf.seek(0)
65
- return Image.open(buf)
66
-
67
- def draw_dynamic_erd(schema: Dict[str, Any]) -> Image.Image:
68
- """
69
- Draw a simple ERD for the current randomized schema.
70
- schema = {
71
- "domain": "bookstore",
72
- "tables": [
73
- {"name":"authors","columns":[{"name":"author_id","type":"INTEGER"}, ...],
74
- "pk":["author_id"], "fks":[{"columns":["author_id"],"ref_table":"...","ref_columns":["..."]}],
75
- "rows":[{...}, {...}]}
76
- ]
77
- }
78
- """
79
- fig, ax = plt.subplots(figsize=PLOT_FIGSIZE)
80
- ax.axis("off")
81
- tables = schema.get("tables", [])
82
- n = max(1, len(tables))
83
- # Lay out boxes horizontally
84
- margin = 0.03
85
- width = (1 - margin*(n+1)) / n
86
- height = 0.65
87
- y = 0.25
88
- boxes = {}
89
- for i, t in enumerate(tables):
90
- x = margin + i*(width + margin)
91
- boxes[t["name"]] = (x, y, width, height)
92
- ax.add_patch(plt.Rectangle((x, y), width, height, fill=False))
93
- ax.text(x + 0.01, y + height - 0.05, f"**{t['name']}**", fontsize=10, ha="left", va="top")
94
- yy = y + height - 0.10
95
- pk = set(t.get("pk", []))
96
- cols = t.get("columns", [])
97
- for col in cols:
98
- nm = col["name"]
99
- mark = " (PK)" if nm in pk else ""
100
- ax.text(x + 0.02, yy, f"{nm}{mark}", fontsize=9, ha="left", va="top")
101
- yy -= 0.06
102
-
103
- # Draw FK arrows
104
- for t in tables:
105
- for fk in t.get("fks", []):
106
- src_tbl = t["name"]
107
- dst_tbl = fk.get("ref_table")
108
- if src_tbl in boxes and dst_tbl in boxes:
109
- (x1, y1, w1, h1) = boxes[src_tbl]
110
- (x2, y2, w2, h2) = boxes[dst_tbl]
111
- ax.annotate("", xy=(x2 + w2/2, y2 + h2), xytext=(x1 + w1/2, y1),
112
- arrowprops=dict(arrowstyle="->", lw=1.1))
113
- ax.text(0.5, 0.06, f"Domain: {schema.get('domain','unknown')}", fontsize=9, ha="center")
114
- return _to_pil(fig)
115
-
116
  # -------------------- SQLite connection + locking --------------------
117
  DB_LOCK = threading.RLock()
118
 
119
  def connect_db():
120
  """
121
- Single shared connection that is allowed to be used across threads.
122
  All operations (reads + writes) are serialized via DB_LOCK.
123
- WAL mode improves read concurrency.
124
  """
125
  con = sqlite3.connect(DB_PATH, check_same_thread=False)
126
  con.execute("PRAGMA journal_mode=WAL;")
@@ -242,66 +184,44 @@ FALLBACK_SCHEMA = {
242
  }
243
 
244
  FALLBACK_QUESTIONS = [
245
- {
246
- "id":"Q01","category":"SELECT *","difficulty":1,
247
- "prompt_md":"Select all rows and columns from `authors`.",
248
- "answer_sql":["SELECT * FROM authors;"],
249
- "requires_aliases":False,"required_aliases":[]
250
- },
251
- {
252
- "id":"Q02","category":"SELECT columns","difficulty":1,
253
- "prompt_md":"Show `title` and `price` from `books`.",
254
- "answer_sql":["SELECT title, price FROM books;"],
255
- "requires_aliases":False,"required_aliases":[]
256
- },
257
- {
258
- "id":"Q03","category":"WHERE","difficulty":1,
259
- "prompt_md":"List Sci‑Fi books under $15 (show title, price).",
260
- "answer_sql":["SELECT title, price FROM books WHERE category='Sci-Fi' AND price < 15;"],
261
- "requires_aliases":False,"required_aliases":[]
262
- },
263
- {
264
- "id":"Q04","category":"Aliases","difficulty":1,
265
- "prompt_md":"Using aliases `b` and `a`, join `books` to `authors` and show `b.title` and `a.name` as `author_name`.",
266
- "answer_sql":["SELECT b.title, a.name AS author_name FROM books b JOIN authors a ON b.author_id=a.author_id;"],
267
- "requires_aliases":True,"required_aliases":["a","b"]
268
- },
269
- {
270
- "id":"Q05","category":"JOIN (INNER)","difficulty":2,
271
- "prompt_md":"Inner join `books` and `bookstores`. Return `title`, `name` as `store`.",
272
- "answer_sql":[
273
- "SELECT b.title, s.name AS store FROM books b INNER JOIN bookstores s ON b.store_id=s.store_id;"
274
- ],
275
- "requires_aliases":False,"required_aliases":[]
276
- },
277
- {
278
- "id":"Q06","category":"JOIN (LEFT)","difficulty":2,
279
- "prompt_md":"List each author and their number of books (include authors with zero): columns `name`, `book_count`.",
280
- "answer_sql":[
281
- "SELECT a.name, COUNT(b.book_id) AS book_count FROM authors a LEFT JOIN books b ON a.author_id=b.author_id GROUP BY a.name;"
282
- ],
283
- "requires_aliases":False,"required_aliases":[]
284
- },
285
- {
286
- "id":"Q07","category":"VIEW","difficulty":2,
287
- "prompt_md":"Create a view `vw_pricy` with `title`, `price` for books priced > 25.",
288
- "answer_sql":[
289
- "CREATE VIEW vw_pricy AS SELECT title, price FROM books WHERE price > 25;"
290
- ],
291
- "requires_aliases":False,"required_aliases":[]
292
- },
293
- {
294
- "id":"Q08","category":"CTAS / SELECT INTO","difficulty":2,
295
- "prompt_md":"Create a table `cheap_books` containing books priced < 12. Use CTAS or SELECT INTO.",
296
- "answer_sql":[
297
- "CREATE TABLE cheap_books AS SELECT * FROM books WHERE price < 12;",
298
- "SELECT * INTO cheap_books FROM books WHERE price < 12;"
299
- ],
300
- "requires_aliases":False,"required_aliases":[]
301
- },
302
  ]
303
 
304
- # -------------------- OpenAI prompts --------------------
305
  DOMAIN_AND_QUESTIONS_SCHEMA = {
306
  "name": "DomainSQLPack",
307
  "schema": {
@@ -322,10 +242,7 @@ DOMAIN_AND_QUESTIONS_SCHEMA = {
322
  "items": {
323
  "type":"object",
324
  "additionalProperties": False,
325
- "properties": {
326
- "name":{"type":"string"},
327
- "type":{"type":"string"}
328
- },
329
  "required":["name","type"]
330
  }
331
  },
@@ -372,15 +289,17 @@ DOMAIN_AND_QUESTIONS_SCHEMA = {
372
  "strict": True
373
  }
374
 
375
- DOMAIN_AND_QUESTIONS_PROMPT = """
376
- You are designing a small relational dataset and training questions for SQL basics.
 
 
377
 
378
  1) Choose ONE domain at random from:
379
  - bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.
380
 
381
  2) Produce exactly 3–4 tables that fit together (SQLite-friendly):
382
  - Use snake_case, avoid reserved words.
383
- - Types: INTEGER, REAL, TEXT, NUMERIC, DATE (but no advanced features).
384
  - Primary keys (pk) and foreign keys (fks) must align.
385
  - Provide 8–15 small, realistic seed rows per table (not huge).
386
 
@@ -397,29 +316,59 @@ You are designing a small relational dataset and training questions for SQL basi
397
  Return JSON only.
398
  """
399
 
400
- def llm_generate_domain_and_questions() -> Optional[Dict[str,Any]]:
401
- if not OPENAI_AVAILABLE:
402
- return None
403
- try:
404
- if USE_RESPONSES_API:
405
- resp = _client.responses.create(
406
- model=MODEL_ID,
407
- response_format={"type":"json_schema","json_schema":DOMAIN_AND_QUESTIONS_SCHEMA},
408
- input=[{"role":"user","content": DOMAIN_AND_QUESTIONS_PROMPT}],
409
- temperature=0.6,
410
- )
411
- data_text = getattr(resp, "output_text", None)
412
- else:
413
- chat = _client.chat.completions.create(
414
- model=MODEL_ID,
415
- messages=[{"role":"user","content": DOMAIN_AND_QUESTIONS_PROMPT}],
416
- temperature=0.6
417
- )
418
- data_text = chat.choices[0].message.content
419
- obj = json.loads(data_text) if data_text else None
420
- return obj
421
- except Exception:
422
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
 
424
  # -------------------- Schema install & question handling --------------------
425
  def drop_existing_domain_tables(con: sqlite3.Connection, keep_internal=True):
@@ -440,7 +389,7 @@ def install_schema(con: sqlite3.Connection, schema: Dict[str,Any]):
440
  drop_existing_domain_tables(con, keep_internal=True)
441
  with DB_LOCK:
442
  cur = con.cursor()
443
- # Create tables first
444
  for t in schema.get("tables", []):
445
  cols_sql = []
446
  pk = t.get("pk", [])
@@ -494,7 +443,7 @@ def detect_unsupported_joins(sql: str) -> Optional[str]:
494
  if " full join " in low or " full outer join " in low:
495
  return "SQLite does not support FULL OUTER JOIN. Use LEFT JOIN plus UNION for the other side."
496
  if " ilike " in low:
497
- return "SQLite has no ILIKE. Use `LOWER(col) LIKE LOWER('%pattern%')`."
498
  return None
499
 
500
  def detect_cartesian(con: sqlite3.Connection, sql: str, df_result: pd.DataFrame) -> Optional[str]:
@@ -544,7 +493,7 @@ def aliases_present(sql: str, required_aliases: List[str]) -> bool:
544
  return False
545
  return True
546
 
547
- # -------------------- Question model --------------------
548
  @dataclass
549
  class SQLQuestion:
550
  id: str
@@ -562,36 +511,22 @@ def to_question_dict(q) -> Dict[str,Any]:
562
  return d
563
 
564
  def load_questions(obj_list: List[Dict[str,Any]]) -> List[Dict[str,Any]]:
565
- out = []
566
- for o in obj_list:
567
- out.append(to_question_dict(o))
568
- return out
569
 
570
  # -------------------- Domain bootstrap --------------------
571
- def bootstrap_domain_with_llm_or_fallback() -> Tuple[Dict[str,Any], List[Dict[str,Any]]]:
572
- obj = llm_generate_domain_and_questions()
573
  if obj is None:
574
- return FALLBACK_SCHEMA, FALLBACK_QUESTIONS
575
- # Guardrails: strip RIGHT/FULL joins from answers
576
- clean_qs = []
577
- for q in obj.get("questions", []):
578
- answers = [a for a in q.get("answer_sql", []) if " right join " not in a.lower() and " full " not in a.lower()]
579
- if not answers:
580
- continue
581
- q["answer_sql"] = answers
582
- q.setdefault("requires_aliases", False)
583
- q.setdefault("required_aliases", [])
584
- clean_qs.append(q)
585
- obj["questions"] = clean_qs
586
- return obj, clean_qs
587
-
588
- def install_new_domain():
589
- schema, questions = bootstrap_domain_with_llm_or_fallback()
590
  install_schema(CONN, schema)
591
- return schema, questions
592
 
593
  # -------------------- Session state --------------------
594
- CURRENT_SCHEMA, CURRENT_QS = install_new_domain()
595
 
596
  # -------------------- Progress + mastery --------------------
597
  def upsert_user(con: sqlite3.Connection, user_id: str, name: str):
@@ -681,7 +616,6 @@ def exec_student_sql(sql_text: str) -> Tuple[Optional[pd.DataFrame], Optional[st
681
  return None, "Table created but could not be queried.", None, note
682
  return pd.DataFrame(), None, None, note
683
  except Exception as e:
684
- # Tailored messages
685
  msg = str(e)
686
  if "no such table" in msg.lower():
687
  return None, f"{msg}. Check table names for this randomized domain.", None, note
@@ -704,7 +638,6 @@ def answer_df(answer_sql: List[str]) -> Optional[pd.DataFrame]:
704
  if low.startswith("select"):
705
  return run_df(CONN, sql)
706
  if low.startswith("create view"):
707
- # temp preview
708
  m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
709
  view_name = m.group(2) if m else "vw_tmp"
710
  with DB_LOCK:
@@ -730,7 +663,6 @@ def answer_df(answer_sql: List[str]) -> Optional[pd.DataFrame]:
730
 
731
  def validate_answer(q: Dict[str,Any], student_sql: str, df_student: Optional[pd.DataFrame]) -> Tuple[bool, str]:
732
  df_expected = answer_df(q["answer_sql"])
733
- # If we can't build a canonical DF (e.g., DDL side effect), accept any successful execution as correct
734
  if df_expected is None:
735
  return (df_student is not None), f"**Explanation:** Your statement executed successfully for this task."
736
  if df_student is None:
@@ -756,7 +688,6 @@ def start_session(name: str, session: dict):
756
  gr.update(value="Please enter your name to begin.", visible=True),
757
  gr.update(visible=False),
758
  gr.update(visible=False),
759
- None,
760
  gr.update(visible=False),
761
  pd.DataFrame(),
762
  pd.DataFrame())
@@ -769,23 +700,21 @@ def start_session(name: str, session: dict):
769
 
770
  prompt = q["prompt_md"]
771
  stats = topic_stats(fetch_attempts(CONN, user_id))
772
- erd = draw_dynamic_erd(CURRENT_SCHEMA)
773
  return (session,
774
  gr.update(value=f"**Question {q['id']}**\n\n{prompt}", visible=True),
775
  gr.update(visible=True), # show SQL input
776
  gr.update(value="", visible=True), # preview block
777
- erd,
778
  gr.update(visible=False), # next btn hidden until submit
779
  stats,
780
  pd.DataFrame())
781
 
782
- def render_preview_and_erd(sql_text: str, session: dict):
783
  if not session or "q" not in session:
784
- return gr.update(value="", visible=False), draw_dynamic_erd(CURRENT_SCHEMA)
785
  s = (sql_text or "").strip()
786
  if not s:
787
- return gr.update(value="", visible=False), draw_dynamic_erd(CURRENT_SCHEMA)
788
- return gr.update(value=f"**Preview:**\n\n```sql\n{s}\n```", visible=True), draw_dynamic_erd(CURRENT_SCHEMA)
789
 
790
  def submit_answer(sql_text: str, session: dict):
791
  if not session or "user_id" not in session or "q" not in session:
@@ -804,7 +733,6 @@ def submit_answer(sql_text: str, session: dict):
804
  stats = topic_stats(fetch_attempts(CONN, user_id))
805
  return gr.update(value=fb, visible=True), pd.DataFrame(), gr.update(visible=True), stats
806
 
807
- # Validate correctness
808
  alias_msg = None
809
  if q.get("requires_aliases"):
810
  if not aliases_present(sql_text, q.get("required_aliases", [])):
@@ -826,30 +754,29 @@ def submit_answer(sql_text: str, session: dict):
826
 
827
  def next_question(session: dict):
828
  if not session or "user_id" not in session:
829
- return session, gr.update(value="Start a session first.", visible=True), gr.update(visible=False), draw_dynamic_erd(CURRENT_SCHEMA), gr.update(visible=False)
830
  user_id = session["user_id"]
831
  q = pick_next_question(user_id)
832
  session["qid"] = q["id"]
833
  session["q"] = q
834
  session["start_ts"] = time.time()
835
- return session, gr.update(value=f"**Question {q['id']}**\n\n{q['prompt_md']}", visible=True), gr.update(value="", visible=True), draw_dynamic_erd(CURRENT_SCHEMA), gr.update(visible=False)
836
 
837
  def show_hint(session: dict):
838
  if not session or "q" not in session:
839
  return gr.update(value="Start a session first.", visible=True)
840
- # Lightweight hint policy: category-specific guidance
841
  cat = session["q"]["category"]
842
  hint = {
843
  "SELECT *": "Use `SELECT * FROM table_name`.",
844
  "SELECT columns": "List columns: `SELECT col1, col2 FROM table_name`.",
845
  "WHERE": "Filter with `WHERE` and combine conditions using AND/OR.",
846
- "Aliases": "Use `table_name t` and qualify: `t.col`.",
847
  "JOIN (INNER)": "Join with `... INNER JOIN ... ON left.key = right.key`.",
848
  "JOIN (LEFT)": "LEFT JOIN keeps all rows from the left table.",
849
- "Aggregation": "Use aggregate functions and `GROUP BY` non-aggregated columns.",
850
  "VIEW": "`CREATE VIEW view_name AS SELECT ...`.",
851
  "CTAS / SELECT INTO": "SQLite uses `CREATE TABLE name AS SELECT ...`."
852
- }.get(cat, "Read the ER diagram and identify keys to join on.")
853
  return gr.update(value=f"**Hint:** {hint}", visible=True)
854
 
855
  def export_progress(user_name: str):
@@ -863,11 +790,19 @@ def export_progress(user_name: str):
863
  (pd.DataFrame([{"info":"No attempts yet."}]) if df.empty else df).to_csv(path, index=False)
864
  return path
865
 
 
 
 
 
 
 
 
 
866
  def regenerate_domain():
867
- global CURRENT_SCHEMA, CURRENT_QS
868
- CURRENT_SCHEMA, CURRENT_QS = install_new_domain()
869
- erd = draw_dynamic_erd(CURRENT_SCHEMA)
870
- return gr.update(value="✅ Domain regenerated.", visible=True), erd
871
 
872
  def preview_table(tbl: str):
873
  try:
@@ -891,7 +826,7 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
891
  - Practice `SELECT`, `WHERE`, `JOIN` (INNER/LEFT), **aliases**, **views**, and **CTAS / SELECT INTO**.
892
  - The app explains **SQLite quirks** (no RIGHT/FULL JOIN) and flags likely **cartesian products**.
893
 
894
- > Set your `OPENAI_API_KEY` in the Space secrets to enable randomization.
895
  """
896
  )
897
 
@@ -905,7 +840,7 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
905
  gr.Markdown("---")
906
  gr.Markdown("### Dataset Controls")
907
  regen_btn = gr.Button("🔀 Randomize Dataset (OpenAI)")
908
- regen_fb = gr.Markdown(visible=False)
909
 
910
  gr.Markdown("---")
911
  gr.Markdown("### Instructor Tools")
@@ -925,7 +860,6 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
925
  sql_input = gr.Textbox(label="Your SQL", placeholder="Type SQL here (end ; optional).", lines=6, visible=False)
926
 
927
  preview_md = gr.Markdown(visible=False)
928
- er_image = gr.Image(label="Entity Diagram", value=draw_dynamic_erd(CURRENT_SCHEMA), height=PLOT_HEIGHT)
929
 
930
  with gr.Row():
931
  submit_btn = gr.Button("Run & Submit", variant="primary")
@@ -951,12 +885,12 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
951
  start_btn.click(
952
  start_session,
953
  inputs=[name_box, session_state],
954
- outputs=[session_state, prompt_md, sql_input, preview_md, er_image, next_btn, mastery_df, result_df],
955
  )
956
  sql_input.change(
957
- render_preview_and_erd,
958
  inputs=[sql_input, session_state],
959
- outputs=[preview_md, er_image],
960
  )
961
  submit_btn.click(
962
  submit_answer,
@@ -966,7 +900,7 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
966
  next_btn.click(
967
  next_question,
968
  inputs=[session_state],
969
- outputs=[session_state, prompt_md, sql_input, er_image, next_btn],
970
  )
971
  hint_btn.click(
972
  show_hint,
@@ -981,7 +915,7 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
981
  regen_btn.click(
982
  regenerate_domain,
983
  inputs=[],
984
- outputs=[regen_fb, er_image],
985
  )
986
  tbl_btn.click(
987
  lambda name: preview_table(name),
 
5
  # - Generates 8–12 randomized SQL questions with varied phrasings.
6
  # - Validates answers by executing canonical SQL and comparing result sets.
7
  # - Provides tailored feedback (SQLite dialect, cartesian products, aggregates, aliases).
8
+ # - Always shows data results at the bottom pane.
9
  #
10
+ # Hugging Face Spaces: set OPENAI_API_KEY in secrets to enable randomization.
11
 
12
  import os
13
  import re
 
16
  import random
17
  import sqlite3
18
  import threading
19
+ from dataclasses import dataclass
20
  from datetime import datetime, timezone
21
  from typing import List, Dict, Any, Tuple, Optional
22
 
 
24
  import pandas as pd
25
  import numpy as np
26
 
 
 
 
 
 
 
 
27
  # -------------------- OpenAI (optional) --------------------
28
  USE_RESPONSES_API = True
29
  OPENAI_AVAILABLE = True
30
+ DEFAULT_MODEL = os.getenv("OPENAI_MODEL") # optional override
31
  try:
32
  from openai import OpenAI
33
  _client = OpenAI() # requires OPENAI_API_KEY
 
35
  OPENAI_AVAILABLE = False
36
  _client = None
37
 
38
def _candidate_models():
    """
    Return the model IDs to try, in priority order and without duplicates.

    The optional DEFAULT_MODEL override comes first when it is set; an unset
    or empty override is skipped. The built-in candidates follow.
    """
    base = [
        DEFAULT_MODEL,
        "gpt-4o-mini",
        "gpt-4o",
        "gpt-4.1-mini",
        "o3-mini",
    ]
    # dict.fromkeys keeps first-seen order while dropping repeats, which avoids
    # hiding a side effect (seen.add) inside a comprehension condition.
    return list(dict.fromkeys(m for m in base if m))
48
+
49
  # -------------------- Global settings --------------------
50
  DB_DIR = "/data" if os.path.exists("/data") else "."
51
  DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
 
55
  random.seed(RANDOM_SEED)
56
  SYS_RAND = random.SystemRandom()
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  # -------------------- SQLite connection + locking --------------------
59
  DB_LOCK = threading.RLock()
60
 
61
  def connect_db():
62
  """
63
+ Single shared connection that can be used across threads.
64
  All operations (reads + writes) are serialized via DB_LOCK.
65
+ WAL mode enables concurrent reads.
66
  """
67
  con = sqlite3.connect(DB_PATH, check_same_thread=False)
68
  con.execute("PRAGMA journal_mode=WAL;")
 
184
  }
185
 
186
  FALLBACK_QUESTIONS = [
187
+ {"id":"Q01","category":"SELECT *","difficulty":1,
188
+ "prompt_md":"Select all rows and columns from `authors`.",
189
+ "answer_sql":["SELECT * FROM authors;"],
190
+ "requires_aliases":False,"required_aliases":[]},
191
+ {"id":"Q02","category":"SELECT columns","difficulty":1,
192
+ "prompt_md":"Show `title` and `price` from `books`.",
193
+ "answer_sql":["SELECT title, price FROM books;"],
194
+ "requires_aliases":False,"required_aliases":[]},
195
+ {"id":"Q03","category":"WHERE","difficulty":1,
196
+ "prompt_md":"List Sci‑Fi books under $15 (show title, price).",
197
+ "answer_sql":["SELECT title, price FROM books WHERE category='Sci-Fi' AND price < 15;"],
198
+ "requires_aliases":False,"required_aliases":[]},
199
+ {"id":"Q04","category":"Aliases","difficulty":1,
200
+ "prompt_md":"Using aliases `b` and `a`, join `books` to `authors` and show `b.title` and `a.name` as `author_name`.",
201
+ "answer_sql":["SELECT b.title, a.name AS author_name FROM books b JOIN authors a ON b.author_id=a.author_id;"],
202
+ "requires_aliases":True,"required_aliases":["a","b"]},
203
+ {"id":"Q05","category":"JOIN (INNER)","difficulty":2,
204
+ "prompt_md":"Inner join `books` and `bookstores`. Return `title`, `name` as `store`.",
205
+ "answer_sql":["SELECT b.title, s.name AS store FROM books b INNER JOIN bookstores s ON b.store_id=s.store_id;"],
206
+ "requires_aliases":False,"required_aliases":[]},
207
+ {"id":"Q06","category":"JOIN (LEFT)","difficulty":2,
208
+ "prompt_md":"List each author and their number of books (include authors with zero): columns `name`, `book_count`.",
209
+ "answer_sql":["SELECT a.name, COUNT(b.book_id) AS book_count FROM authors a LEFT JOIN books b ON a.author_id=b.author_id GROUP BY a.name;"],
210
+ "requires_aliases":False,"required_aliases":[]},
211
+ {"id":"Q07","category":"VIEW","difficulty":2,
212
+ "prompt_md":"Create a view `vw_pricy` with `title`, `price` for books priced > 25.",
213
+ "answer_sql":["CREATE VIEW vw_pricy AS SELECT title, price FROM books WHERE price > 25;"],
214
+ "requires_aliases":False,"required_aliases":[]},
215
+ {"id":"Q08","category":"CTAS / SELECT INTO","difficulty":2,
216
+ "prompt_md":"Create a table `cheap_books` containing books priced < 12. Use CTAS or SELECT INTO.",
217
+ "answer_sql":[
218
+ "CREATE TABLE cheap_books AS SELECT * FROM books WHERE price < 12;",
219
+ "SELECT * INTO cheap_books FROM books WHERE price < 12;"
220
+ ],
221
+ "requires_aliases":False,"required_aliases":[]},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  ]
223
 
224
+ # -------------------- OpenAI JSON schema --------------------
225
  DOMAIN_AND_QUESTIONS_SCHEMA = {
226
  "name": "DomainSQLPack",
227
  "schema": {
 
242
  "items": {
243
  "type":"object",
244
  "additionalProperties": False,
245
+ "properties": {"name":{"type":"string"}, "type":{"type":"string"}},
 
 
 
246
  "required":["name","type"]
247
  }
248
  },
 
289
  "strict": True
290
  }
291
 
292
+ def _domain_prompt(prev_domain: Optional[str]) -> str:
293
+ extra = f" Avoid using the previous domain '{prev_domain}' if possible." if prev_domain else ""
294
+ return f"""
295
+ You are designing a small relational dataset and training questions for SQL basics.{extra}
296
 
297
  1) Choose ONE domain at random from:
298
  - bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.
299
 
300
  2) Produce exactly 3–4 tables that fit together (SQLite-friendly):
301
  - Use snake_case, avoid reserved words.
302
+ - Types: INTEGER, REAL, TEXT, NUMERIC, DATE (no advanced features).
303
  - Primary keys (pk) and foreign keys (fks) must align.
304
  - Provide 8–15 small, realistic seed rows per table (not huge).
305
 
 
316
  Return JSON only.
317
  """
318
 
319
def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optional[Dict[str,Any]], Optional[str], Optional[str]]:
    """
    Ask OpenAI for a randomized domain pack (schema + questions).

    Each model from _candidate_models() is tried in order; the first one that
    yields parseable JSON with at least one usable question wins.

    Args:
        prev_domain: Domain used previously, forwarded to the prompt builder so
            the model can be asked to avoid it; None when there is none.

    Returns:
        (obj, error_message, model_used). On success error_message is None.
        On failure obj and model_used are None and error_message joins the
        per-model errors.
    """
    if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
        return None, "OpenAI client not available or OPENAI_API_KEY missing.", None

    # The prompt does not depend on the model, so build it once.
    prompt = _domain_prompt(prev_domain)
    errors = []
    for model in _candidate_models():
        try:
            if USE_RESPONSES_API:
                # The Responses API takes structured-output settings under
                # `text.format`; `response_format` is a Chat Completions
                # parameter and is rejected by responses.create().
                resp = _client.responses.create(
                    model=model,
                    text={"format": {
                        "type": "json_schema",
                        "name": DOMAIN_AND_QUESTIONS_SCHEMA["name"],
                        "schema": DOMAIN_AND_QUESTIONS_SCHEMA["schema"],
                        "strict": DOMAIN_AND_QUESTIONS_SCHEMA.get("strict", True),
                    }},
                    input=[{"role":"user","content": prompt}],
                    temperature=0.6,
                )
                data_text = getattr(resp, "output_text", None)
                if not data_text:
                    try:
                        data_text = resp.output[0].content[0].text  # older SDK layout
                    except Exception:
                        data_text = None
            else:
                chat = _client.chat.completions.create(
                    model=model,
                    messages=[{"role":"user","content": prompt}],
                    temperature=0.6
                )
                data_text = chat.choices[0].message.content

            if not data_text:
                raise RuntimeError("Empty response from model.")

            obj = json.loads(data_text)
            clean_qs = _strip_unsupported_join_answers(obj.get("questions", []))
            if not clean_qs:
                # A pack without questions is unusable downstream; treat it as
                # a failure so the next model (or the fallback) is used.
                raise RuntimeError("No usable questions in model output.")
            obj["questions"] = clean_qs
            return obj, None, model
        except Exception as e:
            # Best effort: record why this model failed and try the next one.
            errors.append(f"{model}: {e}")
            continue

    return None, "; ".join(errors) if errors else "Unknown LLM error.", None


def _strip_unsupported_join_answers(questions):
    """Drop RIGHT/FULL join answers; skip questions left with no answer."""
    cleaned = []
    for q in questions:
        answers = [
            a for a in q.get("answer_sql", [])
            if " right join " not in a.lower() and " full " not in a.lower()
        ]
        if not answers:
            continue
        q["answer_sql"] = answers
        q.setdefault("requires_aliases", False)
        q.setdefault("required_aliases", [])
        cleaned.append(q)
    return cleaned
372
 
373
  # -------------------- Schema install & question handling --------------------
374
  def drop_existing_domain_tables(con: sqlite3.Connection, keep_internal=True):
 
389
  drop_existing_domain_tables(con, keep_internal=True)
390
  with DB_LOCK:
391
  cur = con.cursor()
392
+ # Create tables
393
  for t in schema.get("tables", []):
394
  cols_sql = []
395
  pk = t.get("pk", [])
 
443
  if " full join " in low or " full outer join " in low:
444
  return "SQLite does not support FULL OUTER JOIN. Use LEFT JOIN plus UNION for the other side."
445
  if " ilike " in low:
446
+ return "SQLite has no ILIKE. Use LOWER(col) LIKE LOWER('%pattern%')."
447
  return None
448
 
449
  def detect_cartesian(con: sqlite3.Connection, sql: str, df_result: pd.DataFrame) -> Optional[str]:
 
493
  return False
494
  return True
495
 
496
+ # -------------------- Question model helpers --------------------
497
  @dataclass
498
  class SQLQuestion:
499
  id: str
 
511
  return d
512
 
513
  def load_questions(obj_list: List[Dict[str,Any]]) -> List[Dict[str,Any]]:
514
+ return [to_question_dict(o) for o in obj_list]
 
 
 
515
 
516
  # -------------------- Domain bootstrap --------------------
517
def bootstrap_domain_with_llm_or_fallback(prev_domain: Optional[str]):
    """
    Obtain a domain pack from the LLM, or use the built-in fallback.

    Returns (schema, questions, info); info records the "source" ("openai" or
    "fallback"), the "model" used and the "error" text when the LLM path failed.
    """
    generated, failure, model_name = llm_generate_domain_and_questions(prev_domain)
    if generated is not None:
        details = {"source":"openai","model":model_name,"error":None}
        return generated, generated["questions"], details
    details = {"source":"fallback","model":None,"error":failure}
    return FALLBACK_SCHEMA, FALLBACK_QUESTIONS, details
522
+
523
def install_new_domain(prev_domain: Optional[str]):
    """
    Pick a new domain (LLM or fallback) and install its schema via CONN.

    Returns the same (schema, questions, info) triple produced by
    bootstrap_domain_with_llm_or_fallback.
    """
    new_schema, new_questions, source_info = bootstrap_domain_with_llm_or_fallback(prev_domain)
    install_schema(CONN, new_schema)
    return new_schema, new_questions, source_info
527
 
528
  # -------------------- Session state --------------------
529
+ CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO = install_new_domain(prev_domain=None)
530
 
531
  # -------------------- Progress + mastery --------------------
532
  def upsert_user(con: sqlite3.Connection, user_id: str, name: str):
 
616
  return None, "Table created but could not be queried.", None, note
617
  return pd.DataFrame(), None, None, note
618
  except Exception as e:
 
619
  msg = str(e)
620
  if "no such table" in msg.lower():
621
  return None, f"{msg}. Check table names for this randomized domain.", None, note
 
638
  if low.startswith("select"):
639
  return run_df(CONN, sql)
640
  if low.startswith("create view"):
 
641
  m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
642
  view_name = m.group(2) if m else "vw_tmp"
643
  with DB_LOCK:
 
663
 
664
  def validate_answer(q: Dict[str,Any], student_sql: str, df_student: Optional[pd.DataFrame]) -> Tuple[bool, str]:
665
  df_expected = answer_df(q["answer_sql"])
 
666
  if df_expected is None:
667
  return (df_student is not None), f"**Explanation:** Your statement executed successfully for this task."
668
  if df_student is None:
 
688
  gr.update(value="Please enter your name to begin.", visible=True),
689
  gr.update(visible=False),
690
  gr.update(visible=False),
 
691
  gr.update(visible=False),
692
  pd.DataFrame(),
693
  pd.DataFrame())
 
700
 
701
  prompt = q["prompt_md"]
702
  stats = topic_stats(fetch_attempts(CONN, user_id))
 
703
  return (session,
704
  gr.update(value=f"**Question {q['id']}**\n\n{prompt}", visible=True),
705
  gr.update(visible=True), # show SQL input
706
  gr.update(value="", visible=True), # preview block
 
707
  gr.update(visible=False), # next btn hidden until submit
708
  stats,
709
  pd.DataFrame())
710
 
711
def render_preview(sql_text: str, session: dict):
    """
    Build the update for the SQL preview pane.

    The pane is hidden (with an empty value) when no question is active in
    the session or the SQL text is blank; otherwise the stripped SQL is shown
    in a fenced ```sql block.
    """
    question_active = bool(session) and "q" in session
    stripped = (sql_text or "").strip() if question_active else ""
    if not stripped:
        return gr.update(value="", visible=False)
    return gr.update(value=f"**Preview:**\n\n```sql\n{stripped}\n```", visible=True)
718
 
719
  def submit_answer(sql_text: str, session: dict):
720
  if not session or "user_id" not in session or "q" not in session:
 
733
  stats = topic_stats(fetch_attempts(CONN, user_id))
734
  return gr.update(value=fb, visible=True), pd.DataFrame(), gr.update(visible=True), stats
735
 
 
736
  alias_msg = None
737
  if q.get("requires_aliases"):
738
  if not aliases_present(sql_text, q.get("required_aliases", [])):
 
754
 
755
def next_question(session: dict):
    """Advance the session to a new question and return the UI updates.

    Returns a 4-tuple: the session object, an update carrying the prompt
    Markdown, an update for the SQL input, and an update that hides a
    fourth component. When there is no session with a ``"user_id"``, the
    prompt update carries a "Start a session first." notice instead and
    the SQL input is hidden.
    """
    if session and "user_id" in session:
        question = pick_next_question(session["user_id"])
        # Record the active question and restart the timer for this attempt.
        session.update(qid=question["id"], q=question, start_ts=time.time())
        prompt_text = f"**Question {question['id']}**\n\n{question['prompt_md']}"
        return (
            session,
            gr.update(value=prompt_text, visible=True),
            gr.update(value="", visible=True),
            gr.update(visible=False),
        )

    return (
        session,
        gr.update(value="Start a session first.", visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
    )
764
 
765
def show_hint(session: dict):
    """Return a Gradio update showing a short hint for the active question.

    The hint is selected by the question's ``"category"`` value; categories
    without a dedicated entry fall back to a generic join hint. When no
    question is active in ``session``, the update asks the user to start a
    session first.
    """
    if not (session and "q" in session):
        return gr.update(value="Start a session first.", visible=True)

    category = session["q"]["category"]
    hints_by_category = {
        "SELECT *": "Use `SELECT * FROM table_name`.",
        "SELECT columns": "List columns: `SELECT col1, col2 FROM table_name`.",
        "WHERE": "Filter with `WHERE` and combine conditions using AND/OR.",
        "Aliases": "Use `table_name t` and qualify as `t.col`.",
        "JOIN (INNER)": "Join with `... INNER JOIN ... ON left.key = right.key`.",
        "JOIN (LEFT)": "LEFT JOIN keeps all rows from the left table.",
        "Aggregation": "Use aggregates and `GROUP BY` non-aggregated columns.",
        "VIEW": "`CREATE VIEW view_name AS SELECT ...`.",
        "CTAS / SELECT INTO": "SQLite uses `CREATE TABLE name AS SELECT ...`.",
    }
    fallback = "Identify keys from the schema and join on them."
    hint_text = hints_by_category.get(category, fallback)
    return gr.update(value=f"**Hint:** {hint_text}", visible=True)
781
 
782
  def export_progress(user_name: str):
 
790
  (pd.DataFrame([{"info":"No attempts yet."}]) if df.empty else df).to_csv(path, index=False)
791
  return path
792
 
793
def _domain_status_md():
    """Return a Markdown status line describing the current dataset's origin.

    Reads the module-level ``CURRENT_INFO`` and ``CURRENT_SCHEMA`` mappings.
    When ``CURRENT_INFO["source"]`` equals ``"openai"``, the message reports
    the model, the domain and the table names. Otherwise it reports that
    the fallback domain is in use, followed by the recorded error,
    truncated to 160 characters plus an ellipsis.
    """
    domain = CURRENT_SCHEMA.get('domain', '?')
    if CURRENT_INFO.get("source") == "openai":
        model = CURRENT_INFO.get('model', '?')
        table_names = ', '.join(t['name'] for t in CURRENT_SCHEMA.get('tables', []))
        return (
            f"✅ **Domain regenerated via OpenAI** (`{model}`) → **{domain}**. "
            f"Tables: {table_names}."
        )

    # dict.get's default only applies when the key is missing. An explicit
    # None, or an exception object stored under "error", would otherwise
    # break len()/slicing below, so normalise to a plain string first.
    err = str(CURRENT_INFO.get("error") or "")
    err_short = (err[:160] + "…") if len(err) > 160 else err
    return (
        f"⚠️ **OpenAI randomization unavailable** → using fallback **{domain}**."
        f"\n\n> Reason: {err_short}"
    )
800
+
801
def regenerate_domain():
    """Install a new dataset domain and return the refreshed status update.

    Passes the current domain name (``None`` when no schema is loaded) to
    ``install_new_domain`` as ``prev_domain``, rebinds the module-level
    ``CURRENT_SCHEMA``, ``CURRENT_QS`` and ``CURRENT_INFO`` to its result,
    and returns a visible Gradio update carrying the new status Markdown.
    """
    global CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO

    previous_domain = None
    if CURRENT_SCHEMA:
        previous_domain = CURRENT_SCHEMA.get("domain")

    installed = install_new_domain(prev_domain=previous_domain)
    CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO = installed

    status_md = _domain_status_md()
    return gr.update(value=status_md, visible=True)
806
 
807
  def preview_table(tbl: str):
808
  try:
 
826
  - Practice `SELECT`, `WHERE`, `JOIN` (INNER/LEFT), **aliases**, **views**, and **CTAS / SELECT INTO**.
827
  - The app explains **SQLite quirks** (no RIGHT/FULL JOIN) and flags likely **cartesian products**.
828
 
829
+ > Set your `OPENAI_API_KEY` in Space secrets to enable randomization.
830
  """
831
  )
832
 
 
840
  gr.Markdown("---")
841
  gr.Markdown("### Dataset Controls")
842
  regen_btn = gr.Button("🔀 Randomize Dataset (OpenAI)")
843
+ regen_fb = gr.Markdown(_domain_status_md(), visible=True)
844
 
845
  gr.Markdown("---")
846
  gr.Markdown("### Instructor Tools")
 
860
  sql_input = gr.Textbox(label="Your SQL", placeholder="Type SQL here (end ; optional).", lines=6, visible=False)
861
 
862
  preview_md = gr.Markdown(visible=False)
 
863
 
864
  with gr.Row():
865
  submit_btn = gr.Button("Run & Submit", variant="primary")
 
885
  start_btn.click(
886
  start_session,
887
  inputs=[name_box, session_state],
888
+ outputs=[session_state, prompt_md, sql_input, preview_md, next_btn, mastery_df, result_df],
889
  )
890
  sql_input.change(
891
+ render_preview,
892
  inputs=[sql_input, session_state],
893
+ outputs=[preview_md],
894
  )
895
  submit_btn.click(
896
  submit_answer,
 
900
  next_btn.click(
901
  next_question,
902
  inputs=[session_state],
903
+ outputs=[session_state, prompt_md, sql_input, next_btn],
904
  )
905
  hint_btn.click(
906
  show_hint,
 
915
  regen_btn.click(
916
  regenerate_domain,
917
  inputs=[],
918
+ outputs=[regen_fb],
919
  )
920
  tbl_btn.click(
921
  lambda name: preview_table(name),