jtdearmon commited on
Commit
20c4c8c
·
verified ·
1 Parent(s): b159bde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -74
app.py CHANGED
@@ -2,8 +2,8 @@
2
  # - OpenAI randomizes a domain and questions (fallback dataset if unavailable).
3
  # - 3–4 related tables with seed rows installed in SQLite.
4
  # - Students practice SELECT, WHERE, JOINs (INNER/LEFT), aliases, views, CTAS/SELECT INTO.
5
- # - Validator now enforces columns only when the prompt asks for them; otherwise it focuses on rows.
6
- # - ERD shows all FK edges in light gray and dynamically HIGHLIGHTS edges implied by the student’s JOINs.
7
 
8
  import os
9
  import re
@@ -12,7 +12,6 @@ import time
12
  import random
13
  import sqlite3
14
  import threading
15
- from dataclasses import dataclass
16
  from datetime import datetime, timezone
17
  from typing import List, Dict, Any, Tuple, Optional, Set
18
 
@@ -69,25 +68,21 @@ def draw_dynamic_erd(
69
  highlight_edges uses (src_table, dst_table) with dst_table = referenced table.
70
  """
71
  highlight_tables = set(highlight_tables or [])
72
- # normalize edges so (A,B) & (B,A) match regardless of direction
73
  def _norm_edge(a, b): return tuple(sorted([a, b]))
74
  H = set(_norm_edge(*e) for e in (highlight_edges or set()))
75
 
76
  tables = schema.get("tables", [])
 
77
  if not tables:
78
- fig, ax = plt.subplots(figsize=PLOT_FIGSIZE); ax.axis("off")
79
  ax.text(0.5, 0.5, "No tables to diagram.", ha="center", va="center")
80
  return _fig_to_pil(fig)
81
 
82
- # Layout tables horizontally
83
  n = len(tables)
84
- fig, ax = plt.subplots(figsize=PLOT_FIGSIZE); ax.axis("off")
85
  margin = 0.03
86
  width = (1 - margin * (n + 1)) / max(n, 1)
87
  height = 0.70
88
  y = 0.20
89
 
90
- # Build quick FK lookup: [(src_table, dst_table)]
91
  fk_edges = []
92
  for t in tables:
93
  for fk in t.get("fks", []) or []:
@@ -95,20 +90,16 @@ def draw_dynamic_erd(
95
  if dst:
96
  fk_edges.append((t["name"], dst))
97
 
98
- # Draw table boxes + columns
99
  boxes: Dict[str, Tuple[float,float,float,float]] = {}
100
  for i, t in enumerate(tables):
101
  tx = margin + i * (width + margin)
102
  boxes[t["name"]] = (tx, y, width, height)
103
-
104
- # Highlight table border if used in current SQL
105
  lw = 2.0 if t["name"] in highlight_tables else 1.2
106
  ax.add_patch(Rectangle((tx, y), width, height, fill=False, lw=lw))
107
  ax.text(tx + 0.01, y + height - 0.04, t["name"], fontsize=10, ha="left", va="top", weight="bold")
108
 
109
  yy = y + height - 0.09
110
  pkset = set(t.get("pk", []) or [])
111
- # For FK annotation by column
112
  fk_map: Dict[str, List[Tuple[str, str]]] = {}
113
  for fk in t.get("fks", []) or []:
114
  ref_tbl = fk.get("ref_table", "")
@@ -126,7 +117,6 @@ def draw_dynamic_erd(
126
  ax.text(tx + 0.016, yy, f"{nm}{tag}", fontsize=9, ha="left", va="top")
127
  yy -= 0.055
128
 
129
- # Draw FK edges: light gray
130
  for (src, dst) in fk_edges:
131
  if src not in boxes or dst not in boxes:
132
  continue
@@ -137,7 +127,6 @@ def draw_dynamic_erd(
137
  xytext=(x1 + w1/2.0, y1),
138
  arrowprops=dict(arrowstyle="->", lw=1.0, color="#cccccc"))
139
 
140
- # Overlay highlighted edges: bold, darker
141
  for (src, dst) in fk_edges:
142
  if _norm_edge(src, dst) in H:
143
  (x1, y1, w1, h1) = boxes[src]
@@ -150,22 +139,15 @@ def draw_dynamic_erd(
150
  ax.text(0.5, 0.06, f"Domain: {schema.get('domain','unknown')}", fontsize=9, ha="center")
151
  return _fig_to_pil(fig)
152
 
153
- # Parse JOINs from SQL and turn them into tables/edges to highlight on ERD
154
  JOIN_TBL_RE = re.compile(r"\b(?:from|join)\s+([a-z_]\w*)(?:\s+(?:as\s+)?([a-z_]\w*))?", re.IGNORECASE)
155
  EQ_ON_RE = re.compile(r"([a-z_]\w*)\.[a-z_]\w*\s*=\s*([a-z_]\w*)\.[a-z_]\w*", re.IGNORECASE)
156
  USING_RE = re.compile(r"\bjoin\s+([a-z_]\w*)(?:\s+(?:as\s+)?([a-z_]\w*))?\s+using\s*\(", re.IGNORECASE)
157
 
158
  def sql_highlights(sql: str, schema: Dict[str, Any]) -> Tuple[Set[str], Set[Tuple[str, str]]]:
159
- """
160
- Returns (highlight_tables, highlight_edges) based on the student's SQL.
161
- - Tables: any table appearing after FROM or JOIN (by name or alias).
162
- - Edges: pairs inferred from `a.col = b.col` or JOIN ... USING (...).
163
- """
164
  if not sql:
165
  return set(), set()
166
 
167
  low = " ".join(sql.strip().split())
168
- # Alias map alias->table and list of tables in join order
169
  alias_to_table: Dict[str, str] = {}
170
  join_order: List[str] = []
171
 
@@ -175,7 +157,6 @@ def sql_highlights(sql: str, schema: Dict[str, Any]) -> Tuple[Set[str], Set[Tupl
175
  alias_to_table[alias] = table
176
  join_order.append(alias)
177
 
178
- # Edges from explicit equality ON clauses
179
  edges: Set[Tuple[str, str]] = set()
180
  for a1, a2 in EQ_ON_RE.findall(low):
181
  t1 = alias_to_table.get(a1, a1)
@@ -183,7 +164,6 @@ def sql_highlights(sql: str, schema: Dict[str, Any]) -> Tuple[Set[str], Set[Tupl
183
  if t1 != t2:
184
  edges.add((t1, t2))
185
 
186
- # Heuristic for USING(): connect the previous alias with the current JOIN alias
187
  if USING_RE.search(low) and len(join_order) >= 2:
188
  for i in range(1, len(join_order)):
189
  t_left = alias_to_table.get(join_order[i-1], join_order[i-1])
@@ -191,14 +171,10 @@ def sql_highlights(sql: str, schema: Dict[str, Any]) -> Tuple[Set[str], Set[Tupl
191
  if t_left != t_right:
192
  edges.add((t_left, t_right))
193
 
194
- # Highlight tables used (map aliases back to table names)
195
  used_tables = {alias_to_table.get(a, a) for a in join_order}
196
-
197
- # Normalize edges to actual table names present in schema
198
  schema_tables = {t["name"] for t in schema.get("tables", [])}
199
  edges = { (a, b) for (a, b) in edges if a in schema_tables and b in schema_tables }
200
  used_tables = { t for t in used_tables if t in schema_tables }
201
-
202
  return used_tables, edges
203
 
204
  # -------------------- SQLite + locking --------------------
@@ -405,7 +381,6 @@ def _loose_json_parse(s: str) -> Optional[dict]:
405
  return None
406
  return None
407
 
408
- # Canonicalization of LLM output (questions & tables)
409
  _SQL_FENCE = re.compile(r"```sql(.*?)```", re.IGNORECASE | re.DOTALL)
410
  _CODE_FENCE = re.compile(r"```(.*?)```", re.DOTALL)
411
 
@@ -500,7 +475,7 @@ def _canon_tables(tables: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
500
  })
501
  return out
502
 
503
- def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optional[Dict[str,Any]], Optional[str], Optional[str], Dict[str,int]]:
504
  if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
505
  return None, "OpenAI client not available or OPENAI_API_KEY missing.", None, {"accepted_questions":0,"dropped_questions":0}
506
  errors = []
@@ -605,8 +580,6 @@ def install_schema_and_prepare_questions(prev_domain: Optional[str]):
605
  install_schema(CONN, schema)
606
  if not questions:
607
  questions = FALLBACK_QUESTIONS
608
- info = {"source":"openai+fallback-questions","model":info.get("model"),
609
- "error":"LLM returned 0 usable questions; used fallback bank.","accepted":0,"dropped":0}
610
  return schema, questions, info
611
 
612
  # -------------------- Session globals --------------------
@@ -647,7 +620,7 @@ def pick_next_question(user_id: str) -> Dict[str,Any]:
647
  pool = CURRENT_QS if CURRENT_QS else FALLBACK_QUESTIONS
648
  df = fetch_attempts(CONN, user_id)
649
  stats = topic_stats(df)
650
- stats = stats.sort_values(by=["accuracy","attempts"], ascending=[True, True])
651
  weakest = stats.iloc[0]["category"] if not stats.empty else CATEGORIES_ORDER[0]
652
  cands = [q for q in pool if str(q.get("category","")).strip() == weakest] or pool
653
  return dict(random.choice(cands))
@@ -699,16 +672,14 @@ def detect_cartesian(con: sqlite3.Connection, sql: str, df_result: pd.DataFrame)
699
  return "Possible cartesian product: no join condition detected."
700
  return None
701
 
702
- # Column enforcement policy — only when the prompt asks for it
703
  def should_enforce_columns(q: Dict[str, Any]) -> bool:
704
  cat = (q.get("category") or "").strip()
705
  if cat in ("SELECT columns", "Aggregation", "VIEW", "CTAS / SELECT INTO"):
706
  return True
707
  prompt = (q.get("prompt_md") or "").lower()
708
- # Signals that the projection is specified in the prompt
709
- if re.search(r"`[^`]+`", q.get("prompt_md") or ""): # backticked names
710
  return True
711
- if re.search(r"\((?:show|return|display)[^)]+\)", prompt): # e.g., "(show title, price)"
712
  return True
713
  if re.search(r"\b(show|return|display|select)\b[^.]{0,100}\b(columns?|fields?|name|title|price)\b", prompt):
714
  return True
@@ -734,7 +705,6 @@ def results_equal_or_superset(df_student: pd.DataFrame, df_expected: pd.DataFram
734
  return False, None
735
 
736
  def results_equal_rowcount_only(df_student: pd.DataFrame, df_expected: pd.DataFrame) -> bool:
737
- # When projection isn't specified, match on row count only.
738
  return df_student.shape[0] == df_expected.shape[0]
739
 
740
  def aliases_present(sql: str, required_aliases: List[str]) -> bool:
@@ -744,7 +714,7 @@ def aliases_present(sql: str, required_aliases: List[str]) -> bool:
744
  return False
745
  return True
746
 
747
- def exec_student_sql(sql_text: str) -> Tuple[Optional[pd.DataFrame], Optional[str], Optional[str], Optional[str]]:
748
  if not sql_text or not sql_text.strip():
749
  return None, "Enter a SQL statement.", None, None
750
  sql_raw = sql_text.strip().rstrip(";")
@@ -940,28 +910,15 @@ def show_hint(session: dict):
940
  }.get(cat, "Identify keys from the schema and join on them.")
941
  return gr.update(value=f"**Hint:** {hint}", visible=True)
942
 
943
- def export_progress(user_name: str):
944
- slug = "-".join((user_name or "").lower().split())
945
- if not slug: return None
946
- user_id = slug[:64]
947
- with DB_LOCK:
948
- df = pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", CONN, params=(user_id,))
949
- os.makedirs(EXPORT_DIR, exist_ok=True)
950
- path = os.path.abspath(os.path.join(EXPORT_DIR, f"{user_id}_progress.csv"))
951
- (pd.DataFrame([{"info":"No attempts yet."}]) if df.empty else df).to_csv(path, index=False)
952
- return path
953
-
954
  def _domain_status_md():
955
- if CURRENT_INFO.get("source","") in ("openai","openai+fallback-questions"):
956
- note = " (LLM domain ok; used fallback questions)" if CURRENT_INFO.get("source") == "openai+fallback-questions" else ""
957
  accepted = CURRENT_INFO.get("accepted",0); dropped = CURRENT_INFO.get("dropped",0)
958
- return (f"✅ **Domain via OpenAI** `{CURRENT_INFO.get('model','?')}` → **{CURRENT_SCHEMA.get('domain','?')}**{note}. "
959
  f"Accepted questions: {accepted}, dropped: {dropped}. \n"
960
  f"Tables: {', '.join(t['name'] for t in CURRENT_SCHEMA.get('tables', []))}.")
961
  err = CURRENT_INFO.get("error",""); err_short = (err[:160] + "…") if len(err) > 160 else err
962
  return f"⚠️ **OpenAI randomization unavailable** → using fallback **{CURRENT_SCHEMA.get('domain','?')}**.\n\n> Reason: {err_short}"
963
 
964
- # ----- UPDATED: regenerate also refreshes tbl list and, if session active, seeds a new question + shows input
965
  def list_tables_for_preview():
966
  df = run_df(CONN, """
967
  SELECT name FROM sqlite_master
@@ -972,6 +929,7 @@ def list_tables_for_preview():
972
  """)
973
  return df["name"].tolist() if not df.empty else ["(no tables)"]
974
 
 
975
  def regenerate_domain(session: dict):
976
  global CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO
977
  prev = CURRENT_SCHEMA.get("domain") if CURRENT_SCHEMA else None
@@ -979,20 +937,33 @@ def regenerate_domain(session: dict):
979
  erd = draw_dynamic_erd(CURRENT_SCHEMA)
980
  status = _domain_status_md()
981
 
982
- # Refresh the preview dropdown
983
- choices = list_tables_for_preview()
984
- dd_update = gr.update(choices=choices, value=(choices[0] if choices and choices[0]!="(no tables)" else None))
985
-
986
- # If a session is active, show the first question immediately for the new domain
987
- prompt_update = gr.update()
988
- input_update = gr.update()
989
- if session and session.get("user_id"):
990
- q = pick_next_question(session["user_id"])
991
- session["qid"] = q["id"]; session["q"] = q; session["start_ts"] = time.time()
992
- prompt_update = gr.update(value=f"**Question {q['id']}**\n\n{q['prompt_md']}", visible=True)
993
- input_update = gr.update(value="", visible=True)
994
-
995
- return status, erd, prompt_update, input_update, dd_update, session
 
 
 
 
 
 
 
 
 
 
 
 
 
996
 
997
  def preview_table(tbl: str):
998
  try:
@@ -1059,7 +1030,6 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
1059
  gr.Markdown("### Result Preview")
1060
  result_df = gr.Dataframe(value=pd.DataFrame(), interactive=False)
1061
 
1062
- # Wire events
1063
  start_btn.click(
1064
  start_session,
1065
  inputs=[name_box, session_state],
@@ -1086,18 +1056,15 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
1086
  outputs=[feedback_md],
1087
  )
1088
  export_btn.click(
1089
- export_progress,
1090
  inputs=[export_name],
1091
  outputs=[export_file],
1092
  )
1093
-
1094
- # UPDATED: one callback handles regeneration, dropdown refresh, and (if session) reseeding the next question
1095
- regen_btn.click(
1096
  regenerate_domain,
1097
  inputs=[session_state],
1098
- outputs=[regen_fb, er_image, prompt_md, sql_input, tbl_dd, session_state],
1099
  )
1100
-
1101
  tbl_btn.click(
1102
  preview_table,
1103
  inputs=[tbl_dd],
 
2
  # - OpenAI randomizes a domain and questions (fallback dataset if unavailable).
3
  # - 3–4 related tables with seed rows installed in SQLite.
4
  # - Students practice SELECT, WHERE, JOINs (INNER/LEFT), aliases, views, CTAS/SELECT INTO.
5
+ # - Validator enforces columns only when the prompt asks; otherwise focuses on rows.
6
+ # - ERD shows all FK edges in light gray and dynamically HIGHLIGHTS edges implied by JOINs.
7
 
8
  import os
9
  import re
 
12
  import random
13
  import sqlite3
14
  import threading
 
15
  from datetime import datetime, timezone
16
  from typing import List, Dict, Any, Tuple, Optional, Set
17
 
 
68
  highlight_edges uses (src_table, dst_table) with dst_table = referenced table.
69
  """
70
  highlight_tables = set(highlight_tables or [])
 
71
  def _norm_edge(a, b): return tuple(sorted([a, b]))
72
  H = set(_norm_edge(*e) for e in (highlight_edges or set()))
73
 
74
  tables = schema.get("tables", [])
75
+ fig, ax = plt.subplots(figsize=PLOT_FIGSIZE); ax.axis("off")
76
  if not tables:
 
77
  ax.text(0.5, 0.5, "No tables to diagram.", ha="center", va="center")
78
  return _fig_to_pil(fig)
79
 
 
80
  n = len(tables)
 
81
  margin = 0.03
82
  width = (1 - margin * (n + 1)) / max(n, 1)
83
  height = 0.70
84
  y = 0.20
85
 
 
86
  fk_edges = []
87
  for t in tables:
88
  for fk in t.get("fks", []) or []:
 
90
  if dst:
91
  fk_edges.append((t["name"], dst))
92
 
 
93
  boxes: Dict[str, Tuple[float,float,float,float]] = {}
94
  for i, t in enumerate(tables):
95
  tx = margin + i * (width + margin)
96
  boxes[t["name"]] = (tx, y, width, height)
 
 
97
  lw = 2.0 if t["name"] in highlight_tables else 1.2
98
  ax.add_patch(Rectangle((tx, y), width, height, fill=False, lw=lw))
99
  ax.text(tx + 0.01, y + height - 0.04, t["name"], fontsize=10, ha="left", va="top", weight="bold")
100
 
101
  yy = y + height - 0.09
102
  pkset = set(t.get("pk", []) or [])
 
103
  fk_map: Dict[str, List[Tuple[str, str]]] = {}
104
  for fk in t.get("fks", []) or []:
105
  ref_tbl = fk.get("ref_table", "")
 
117
  ax.text(tx + 0.016, yy, f"{nm}{tag}", fontsize=9, ha="left", va="top")
118
  yy -= 0.055
119
 
 
120
  for (src, dst) in fk_edges:
121
  if src not in boxes or dst not in boxes:
122
  continue
 
127
  xytext=(x1 + w1/2.0, y1),
128
  arrowprops=dict(arrowstyle="->", lw=1.0, color="#cccccc"))
129
 
 
130
  for (src, dst) in fk_edges:
131
  if _norm_edge(src, dst) in H:
132
  (x1, y1, w1, h1) = boxes[src]
 
139
  ax.text(0.5, 0.06, f"Domain: {schema.get('domain','unknown')}", fontsize=9, ha="center")
140
  return _fig_to_pil(fig)
141
 
 
142
  JOIN_TBL_RE = re.compile(r"\b(?:from|join)\s+([a-z_]\w*)(?:\s+(?:as\s+)?([a-z_]\w*))?", re.IGNORECASE)
143
  EQ_ON_RE = re.compile(r"([a-z_]\w*)\.[a-z_]\w*\s*=\s*([a-z_]\w*)\.[a-z_]\w*", re.IGNORECASE)
144
  USING_RE = re.compile(r"\bjoin\s+([a-z_]\w*)(?:\s+(?:as\s+)?([a-z_]\w*))?\s+using\s*\(", re.IGNORECASE)
145
 
146
  def sql_highlights(sql: str, schema: Dict[str, Any]) -> Tuple[Set[str], Set[Tuple[str, str]]]:
 
 
 
 
 
147
  if not sql:
148
  return set(), set()
149
 
150
  low = " ".join(sql.strip().split())
 
151
  alias_to_table: Dict[str, str] = {}
152
  join_order: List[str] = []
153
 
 
157
  alias_to_table[alias] = table
158
  join_order.append(alias)
159
 
 
160
  edges: Set[Tuple[str, str]] = set()
161
  for a1, a2 in EQ_ON_RE.findall(low):
162
  t1 = alias_to_table.get(a1, a1)
 
164
  if t1 != t2:
165
  edges.add((t1, t2))
166
 
 
167
  if USING_RE.search(low) and len(join_order) >= 2:
168
  for i in range(1, len(join_order)):
169
  t_left = alias_to_table.get(join_order[i-1], join_order[i-1])
 
171
  if t_left != t_right:
172
  edges.add((t_left, t_right))
173
 
 
174
  used_tables = {alias_to_table.get(a, a) for a in join_order}
 
 
175
  schema_tables = {t["name"] for t in schema.get("tables", [])}
176
  edges = { (a, b) for (a, b) in edges if a in schema_tables and b in schema_tables }
177
  used_tables = { t for t in used_tables if t in schema_tables }
 
178
  return used_tables, edges
179
 
180
  # -------------------- SQLite + locking --------------------
 
381
  return None
382
  return None
383
 
 
384
  _SQL_FENCE = re.compile(r"```sql(.*?)```", re.IGNORECASE | re.DOTALL)
385
  _CODE_FENCE = re.compile(r"```(.*?)```", re.DOTALL)
386
 
 
475
  })
476
  return out
477
 
478
+ def llm_generate_domain_and_questions(prev_domain: Optional[str]):
479
  if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
480
  return None, "OpenAI client not available or OPENAI_API_KEY missing.", None, {"accepted_questions":0,"dropped_questions":0}
481
  errors = []
 
580
  install_schema(CONN, schema)
581
  if not questions:
582
  questions = FALLBACK_QUESTIONS
 
 
583
  return schema, questions, info
584
 
585
  # -------------------- Session globals --------------------
 
620
  pool = CURRENT_QS if CURRENT_QS else FALLBACK_QUESTIONS
621
  df = fetch_attempts(CONN, user_id)
622
  stats = topic_stats(df)
623
+ stats = stats.sort_values(by=["accuracy","attempts"], ascending=[True, True]) if not stats.empty else stats
624
  weakest = stats.iloc[0]["category"] if not stats.empty else CATEGORIES_ORDER[0]
625
  cands = [q for q in pool if str(q.get("category","")).strip() == weakest] or pool
626
  return dict(random.choice(cands))
 
672
  return "Possible cartesian product: no join condition detected."
673
  return None
674
 
 
675
  def should_enforce_columns(q: Dict[str, Any]) -> bool:
676
  cat = (q.get("category") or "").strip()
677
  if cat in ("SELECT columns", "Aggregation", "VIEW", "CTAS / SELECT INTO"):
678
  return True
679
  prompt = (q.get("prompt_md") or "").lower()
680
+ if re.search(r"`[^`]+`", q.get("prompt_md") or ""):
 
681
  return True
682
+ if re.search(r"\((?:show|return|display)[^)]+\)", prompt):
683
  return True
684
  if re.search(r"\b(show|return|display|select)\b[^.]{0,100}\b(columns?|fields?|name|title|price)\b", prompt):
685
  return True
 
705
  return False, None
706
 
707
  def results_equal_rowcount_only(df_student: pd.DataFrame, df_expected: pd.DataFrame) -> bool:
 
708
  return df_student.shape[0] == df_expected.shape[0]
709
 
710
  def aliases_present(sql: str, required_aliases: List[str]) -> bool:
 
714
  return False
715
  return True
716
 
717
+ def exec_student_sql(sql_text: str):
718
  if not sql_text or not sql_text.strip():
719
  return None, "Enter a SQL statement.", None, None
720
  sql_raw = sql_text.strip().rstrip(";")
 
910
  }.get(cat, "Identify keys from the schema and join on them.")
911
  return gr.update(value=f"**Hint:** {hint}", visible=True)
912
 
 
 
 
 
 
 
 
 
 
 
 
913
  def _domain_status_md():
914
+ if CURRENT_INFO.get("source","openai"):
 
915
  accepted = CURRENT_INFO.get("accepted",0); dropped = CURRENT_INFO.get("dropped",0)
916
+ return (f"✅ **Domain via OpenAI** `{CURRENT_INFO.get('model','?')}` → **{CURRENT_SCHEMA.get('domain','?')}**. "
917
  f"Accepted questions: {accepted}, dropped: {dropped}. \n"
918
  f"Tables: {', '.join(t['name'] for t in CURRENT_SCHEMA.get('tables', []))}.")
919
  err = CURRENT_INFO.get("error",""); err_short = (err[:160] + "…") if len(err) > 160 else err
920
  return f"⚠️ **OpenAI randomization unavailable** → using fallback **{CURRENT_SCHEMA.get('domain','?')}**.\n\n> Reason: {err_short}"
921
 
 
922
  def list_tables_for_preview():
923
  df = run_df(CONN, """
924
  SELECT name FROM sqlite_master
 
929
  """)
930
  return df["name"].tolist() if not df.empty else ["(no tables)"]
931
 
932
+ # >>> FIX: Always reseed a question on randomize (creates a guest session if needed)
933
  def regenerate_domain(session: dict):
934
  global CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO
935
  prev = CURRENT_SCHEMA.get("domain") if CURRENT_SCHEMA else None
 
937
  erd = draw_dynamic_erd(CURRENT_SCHEMA)
938
  status = _domain_status_md()
939
 
940
+ # Ensure a session (guest if needed)
941
+ if not session or not session.get("user_id"):
942
+ user_id = f"guest-{int(time.time())}"
943
+ upsert_user(CONN, user_id, "Guest")
944
+ session = {"user_id": user_id, "name": "Guest", "qid": None, "start_ts": time.time(), "q": None}
945
+
946
+ # Seed next question for this session
947
+ q = pick_next_question(session["user_id"])
948
+ session.update({"qid": q["id"], "q": q, "start_ts": time.time()})
949
+
950
+ # Fresh mastery and cleared result preview
951
+ stats = topic_stats(fetch_attempts(CONN, session["user_id"]))
952
+ empty_df = pd.DataFrame()
953
+
954
+ # Refresh dropdown
955
+ dd_update = gr.update(choices=list_tables_for_preview(), value=None)
956
+
957
+ return (
958
+ gr.update(value=status, visible=True), # regen_fb
959
+ erd, # er_image
960
+ gr.update(value=f"**Question {q['id']}**\n\n{q['prompt_md']}", visible=True), # prompt_md
961
+ gr.update(value="", visible=True), # sql_input
962
+ dd_update, # tbl_dd
963
+ stats, # mastery_df
964
+ empty_df, # result_df
965
+ session # session_state
966
+ )
967
 
968
  def preview_table(tbl: str):
969
  try:
 
1030
  gr.Markdown("### Result Preview")
1031
  result_df = gr.Dataframe(value=pd.DataFrame(), interactive=False)
1032
 
 
1033
  start_btn.click(
1034
  start_session,
1035
  inputs=[name_box, session_state],
 
1056
  outputs=[feedback_md],
1057
  )
1058
  export_btn.click(
1059
+ lambda user: os.path.abspath(os.path.join(EXPORT_DIR, f"{'-'.join((user or '').lower().split())[:64]}_progress.csv")),
1060
  inputs=[export_name],
1061
  outputs=[export_file],
1062
  )
1063
+ regen_btn.click( # one callback: reseed question, refresh dropdown, clear previews
 
 
1064
  regenerate_domain,
1065
  inputs=[session_state],
1066
+ outputs=[regen_fb, er_image, prompt_md, sql_input, tbl_dd, mastery_df, result_df, session_state],
1067
  )
 
1068
  tbl_btn.click(
1069
  preview_table,
1070
  inputs=[tbl_dd],