Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,8 +2,8 @@
|
|
| 2 |
# - OpenAI randomizes a domain and questions (fallback dataset if unavailable).
|
| 3 |
# - 3–4 related tables with seed rows installed in SQLite.
|
| 4 |
# - Students practice SELECT, WHERE, JOINs (INNER/LEFT), aliases, views, CTAS/SELECT INTO.
|
| 5 |
-
# - Validator
|
| 6 |
-
# - ERD shows all FK edges in light gray and dynamically HIGHLIGHTS edges implied by
|
| 7 |
|
| 8 |
import os
|
| 9 |
import re
|
|
@@ -12,7 +12,6 @@ import time
|
|
| 12 |
import random
|
| 13 |
import sqlite3
|
| 14 |
import threading
|
| 15 |
-
from dataclasses import dataclass
|
| 16 |
from datetime import datetime, timezone
|
| 17 |
from typing import List, Dict, Any, Tuple, Optional, Set
|
| 18 |
|
|
@@ -69,25 +68,21 @@ def draw_dynamic_erd(
|
|
| 69 |
highlight_edges uses (src_table, dst_table) with dst_table = referenced table.
|
| 70 |
"""
|
| 71 |
highlight_tables = set(highlight_tables or [])
|
| 72 |
-
# normalize edges so (A,B) & (B,A) match regardless of direction
|
| 73 |
def _norm_edge(a, b): return tuple(sorted([a, b]))
|
| 74 |
H = set(_norm_edge(*e) for e in (highlight_edges or set()))
|
| 75 |
|
| 76 |
tables = schema.get("tables", [])
|
|
|
|
| 77 |
if not tables:
|
| 78 |
-
fig, ax = plt.subplots(figsize=PLOT_FIGSIZE); ax.axis("off")
|
| 79 |
ax.text(0.5, 0.5, "No tables to diagram.", ha="center", va="center")
|
| 80 |
return _fig_to_pil(fig)
|
| 81 |
|
| 82 |
-
# Layout tables horizontally
|
| 83 |
n = len(tables)
|
| 84 |
-
fig, ax = plt.subplots(figsize=PLOT_FIGSIZE); ax.axis("off")
|
| 85 |
margin = 0.03
|
| 86 |
width = (1 - margin * (n + 1)) / max(n, 1)
|
| 87 |
height = 0.70
|
| 88 |
y = 0.20
|
| 89 |
|
| 90 |
-
# Build quick FK lookup: [(src_table, dst_table)]
|
| 91 |
fk_edges = []
|
| 92 |
for t in tables:
|
| 93 |
for fk in t.get("fks", []) or []:
|
|
@@ -95,20 +90,16 @@ def draw_dynamic_erd(
|
|
| 95 |
if dst:
|
| 96 |
fk_edges.append((t["name"], dst))
|
| 97 |
|
| 98 |
-
# Draw table boxes + columns
|
| 99 |
boxes: Dict[str, Tuple[float,float,float,float]] = {}
|
| 100 |
for i, t in enumerate(tables):
|
| 101 |
tx = margin + i * (width + margin)
|
| 102 |
boxes[t["name"]] = (tx, y, width, height)
|
| 103 |
-
|
| 104 |
-
# Highlight table border if used in current SQL
|
| 105 |
lw = 2.0 if t["name"] in highlight_tables else 1.2
|
| 106 |
ax.add_patch(Rectangle((tx, y), width, height, fill=False, lw=lw))
|
| 107 |
ax.text(tx + 0.01, y + height - 0.04, t["name"], fontsize=10, ha="left", va="top", weight="bold")
|
| 108 |
|
| 109 |
yy = y + height - 0.09
|
| 110 |
pkset = set(t.get("pk", []) or [])
|
| 111 |
-
# For FK annotation by column
|
| 112 |
fk_map: Dict[str, List[Tuple[str, str]]] = {}
|
| 113 |
for fk in t.get("fks", []) or []:
|
| 114 |
ref_tbl = fk.get("ref_table", "")
|
|
@@ -126,7 +117,6 @@ def draw_dynamic_erd(
|
|
| 126 |
ax.text(tx + 0.016, yy, f"{nm}{tag}", fontsize=9, ha="left", va="top")
|
| 127 |
yy -= 0.055
|
| 128 |
|
| 129 |
-
# Draw FK edges: light gray
|
| 130 |
for (src, dst) in fk_edges:
|
| 131 |
if src not in boxes or dst not in boxes:
|
| 132 |
continue
|
|
@@ -137,7 +127,6 @@ def draw_dynamic_erd(
|
|
| 137 |
xytext=(x1 + w1/2.0, y1),
|
| 138 |
arrowprops=dict(arrowstyle="->", lw=1.0, color="#cccccc"))
|
| 139 |
|
| 140 |
-
# Overlay highlighted edges: bold, darker
|
| 141 |
for (src, dst) in fk_edges:
|
| 142 |
if _norm_edge(src, dst) in H:
|
| 143 |
(x1, y1, w1, h1) = boxes[src]
|
|
@@ -150,22 +139,15 @@ def draw_dynamic_erd(
|
|
| 150 |
ax.text(0.5, 0.06, f"Domain: {schema.get('domain','unknown')}", fontsize=9, ha="center")
|
| 151 |
return _fig_to_pil(fig)
|
| 152 |
|
| 153 |
-
# Parse JOINs from SQL and turn them into tables/edges to highlight on ERD
|
| 154 |
JOIN_TBL_RE = re.compile(r"\b(?:from|join)\s+([a-z_]\w*)(?:\s+(?:as\s+)?([a-z_]\w*))?", re.IGNORECASE)
|
| 155 |
EQ_ON_RE = re.compile(r"([a-z_]\w*)\.[a-z_]\w*\s*=\s*([a-z_]\w*)\.[a-z_]\w*", re.IGNORECASE)
|
| 156 |
USING_RE = re.compile(r"\bjoin\s+([a-z_]\w*)(?:\s+(?:as\s+)?([a-z_]\w*))?\s+using\s*\(", re.IGNORECASE)
|
| 157 |
|
| 158 |
def sql_highlights(sql: str, schema: Dict[str, Any]) -> Tuple[Set[str], Set[Tuple[str, str]]]:
|
| 159 |
-
"""
|
| 160 |
-
Returns (highlight_tables, highlight_edges) based on the student's SQL.
|
| 161 |
-
- Tables: any table appearing after FROM or JOIN (by name or alias).
|
| 162 |
-
- Edges: pairs inferred from `a.col = b.col` or JOIN ... USING (...).
|
| 163 |
-
"""
|
| 164 |
if not sql:
|
| 165 |
return set(), set()
|
| 166 |
|
| 167 |
low = " ".join(sql.strip().split())
|
| 168 |
-
# Alias map alias->table and list of tables in join order
|
| 169 |
alias_to_table: Dict[str, str] = {}
|
| 170 |
join_order: List[str] = []
|
| 171 |
|
|
@@ -175,7 +157,6 @@ def sql_highlights(sql: str, schema: Dict[str, Any]) -> Tuple[Set[str], Set[Tupl
|
|
| 175 |
alias_to_table[alias] = table
|
| 176 |
join_order.append(alias)
|
| 177 |
|
| 178 |
-
# Edges from explicit equality ON clauses
|
| 179 |
edges: Set[Tuple[str, str]] = set()
|
| 180 |
for a1, a2 in EQ_ON_RE.findall(low):
|
| 181 |
t1 = alias_to_table.get(a1, a1)
|
|
@@ -183,7 +164,6 @@ def sql_highlights(sql: str, schema: Dict[str, Any]) -> Tuple[Set[str], Set[Tupl
|
|
| 183 |
if t1 != t2:
|
| 184 |
edges.add((t1, t2))
|
| 185 |
|
| 186 |
-
# Heuristic for USING(): connect the previous alias with the current JOIN alias
|
| 187 |
if USING_RE.search(low) and len(join_order) >= 2:
|
| 188 |
for i in range(1, len(join_order)):
|
| 189 |
t_left = alias_to_table.get(join_order[i-1], join_order[i-1])
|
|
@@ -191,14 +171,10 @@ def sql_highlights(sql: str, schema: Dict[str, Any]) -> Tuple[Set[str], Set[Tupl
|
|
| 191 |
if t_left != t_right:
|
| 192 |
edges.add((t_left, t_right))
|
| 193 |
|
| 194 |
-
# Highlight tables used (map aliases back to table names)
|
| 195 |
used_tables = {alias_to_table.get(a, a) for a in join_order}
|
| 196 |
-
|
| 197 |
-
# Normalize edges to actual table names present in schema
|
| 198 |
schema_tables = {t["name"] for t in schema.get("tables", [])}
|
| 199 |
edges = { (a, b) for (a, b) in edges if a in schema_tables and b in schema_tables }
|
| 200 |
used_tables = { t for t in used_tables if t in schema_tables }
|
| 201 |
-
|
| 202 |
return used_tables, edges
|
| 203 |
|
| 204 |
# -------------------- SQLite + locking --------------------
|
|
@@ -405,7 +381,6 @@ def _loose_json_parse(s: str) -> Optional[dict]:
|
|
| 405 |
return None
|
| 406 |
return None
|
| 407 |
|
| 408 |
-
# Canonicalization of LLM output (questions & tables)
|
| 409 |
_SQL_FENCE = re.compile(r"```sql(.*?)```", re.IGNORECASE | re.DOTALL)
|
| 410 |
_CODE_FENCE = re.compile(r"```(.*?)```", re.DOTALL)
|
| 411 |
|
|
@@ -500,7 +475,7 @@ def _canon_tables(tables: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
| 500 |
})
|
| 501 |
return out
|
| 502 |
|
| 503 |
-
def llm_generate_domain_and_questions(prev_domain: Optional[str])
|
| 504 |
if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
|
| 505 |
return None, "OpenAI client not available or OPENAI_API_KEY missing.", None, {"accepted_questions":0,"dropped_questions":0}
|
| 506 |
errors = []
|
|
@@ -605,8 +580,6 @@ def install_schema_and_prepare_questions(prev_domain: Optional[str]):
|
|
| 605 |
install_schema(CONN, schema)
|
| 606 |
if not questions:
|
| 607 |
questions = FALLBACK_QUESTIONS
|
| 608 |
-
info = {"source":"openai+fallback-questions","model":info.get("model"),
|
| 609 |
-
"error":"LLM returned 0 usable questions; used fallback bank.","accepted":0,"dropped":0}
|
| 610 |
return schema, questions, info
|
| 611 |
|
| 612 |
# -------------------- Session globals --------------------
|
|
@@ -647,7 +620,7 @@ def pick_next_question(user_id: str) -> Dict[str,Any]:
|
|
| 647 |
pool = CURRENT_QS if CURRENT_QS else FALLBACK_QUESTIONS
|
| 648 |
df = fetch_attempts(CONN, user_id)
|
| 649 |
stats = topic_stats(df)
|
| 650 |
-
stats = stats.sort_values(by=["accuracy","attempts"], ascending=[True, True])
|
| 651 |
weakest = stats.iloc[0]["category"] if not stats.empty else CATEGORIES_ORDER[0]
|
| 652 |
cands = [q for q in pool if str(q.get("category","")).strip() == weakest] or pool
|
| 653 |
return dict(random.choice(cands))
|
|
@@ -699,16 +672,14 @@ def detect_cartesian(con: sqlite3.Connection, sql: str, df_result: pd.DataFrame)
|
|
| 699 |
return "Possible cartesian product: no join condition detected."
|
| 700 |
return None
|
| 701 |
|
| 702 |
-
# Column enforcement policy — only when the prompt asks for it
|
| 703 |
def should_enforce_columns(q: Dict[str, Any]) -> bool:
|
| 704 |
cat = (q.get("category") or "").strip()
|
| 705 |
if cat in ("SELECT columns", "Aggregation", "VIEW", "CTAS / SELECT INTO"):
|
| 706 |
return True
|
| 707 |
prompt = (q.get("prompt_md") or "").lower()
|
| 708 |
-
|
| 709 |
-
if re.search(r"`[^`]+`", q.get("prompt_md") or ""): # backticked names
|
| 710 |
return True
|
| 711 |
-
if re.search(r"\((?:show|return|display)[^)]+\)", prompt):
|
| 712 |
return True
|
| 713 |
if re.search(r"\b(show|return|display|select)\b[^.]{0,100}\b(columns?|fields?|name|title|price)\b", prompt):
|
| 714 |
return True
|
|
@@ -734,7 +705,6 @@ def results_equal_or_superset(df_student: pd.DataFrame, df_expected: pd.DataFram
|
|
| 734 |
return False, None
|
| 735 |
|
| 736 |
def results_equal_rowcount_only(df_student: pd.DataFrame, df_expected: pd.DataFrame) -> bool:
|
| 737 |
-
# When projection isn't specified, match on row count only.
|
| 738 |
return df_student.shape[0] == df_expected.shape[0]
|
| 739 |
|
| 740 |
def aliases_present(sql: str, required_aliases: List[str]) -> bool:
|
|
@@ -744,7 +714,7 @@ def aliases_present(sql: str, required_aliases: List[str]) -> bool:
|
|
| 744 |
return False
|
| 745 |
return True
|
| 746 |
|
| 747 |
-
def exec_student_sql(sql_text: str)
|
| 748 |
if not sql_text or not sql_text.strip():
|
| 749 |
return None, "Enter a SQL statement.", None, None
|
| 750 |
sql_raw = sql_text.strip().rstrip(";")
|
|
@@ -940,28 +910,15 @@ def show_hint(session: dict):
|
|
| 940 |
}.get(cat, "Identify keys from the schema and join on them.")
|
| 941 |
return gr.update(value=f"**Hint:** {hint}", visible=True)
|
| 942 |
|
| 943 |
-
def export_progress(user_name: str):
|
| 944 |
-
slug = "-".join((user_name or "").lower().split())
|
| 945 |
-
if not slug: return None
|
| 946 |
-
user_id = slug[:64]
|
| 947 |
-
with DB_LOCK:
|
| 948 |
-
df = pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", CONN, params=(user_id,))
|
| 949 |
-
os.makedirs(EXPORT_DIR, exist_ok=True)
|
| 950 |
-
path = os.path.abspath(os.path.join(EXPORT_DIR, f"{user_id}_progress.csv"))
|
| 951 |
-
(pd.DataFrame([{"info":"No attempts yet."}]) if df.empty else df).to_csv(path, index=False)
|
| 952 |
-
return path
|
| 953 |
-
|
| 954 |
def _domain_status_md():
|
| 955 |
-
if CURRENT_INFO.get("source","
|
| 956 |
-
note = " (LLM domain ok; used fallback questions)" if CURRENT_INFO.get("source") == "openai+fallback-questions" else ""
|
| 957 |
accepted = CURRENT_INFO.get("accepted",0); dropped = CURRENT_INFO.get("dropped",0)
|
| 958 |
-
return (f"✅ **Domain via OpenAI** `{CURRENT_INFO.get('model','?')}` → **{CURRENT_SCHEMA.get('domain','?')}
|
| 959 |
f"Accepted questions: {accepted}, dropped: {dropped}. \n"
|
| 960 |
f"Tables: {', '.join(t['name'] for t in CURRENT_SCHEMA.get('tables', []))}.")
|
| 961 |
err = CURRENT_INFO.get("error",""); err_short = (err[:160] + "…") if len(err) > 160 else err
|
| 962 |
return f"⚠️ **OpenAI randomization unavailable** → using fallback **{CURRENT_SCHEMA.get('domain','?')}**.\n\n> Reason: {err_short}"
|
| 963 |
|
| 964 |
-
# ----- UPDATED: regenerate also refreshes tbl list and, if session active, seeds a new question + shows input
|
| 965 |
def list_tables_for_preview():
|
| 966 |
df = run_df(CONN, """
|
| 967 |
SELECT name FROM sqlite_master
|
|
@@ -972,6 +929,7 @@ def list_tables_for_preview():
|
|
| 972 |
""")
|
| 973 |
return df["name"].tolist() if not df.empty else ["(no tables)"]
|
| 974 |
|
|
|
|
| 975 |
def regenerate_domain(session: dict):
|
| 976 |
global CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO
|
| 977 |
prev = CURRENT_SCHEMA.get("domain") if CURRENT_SCHEMA else None
|
|
@@ -979,20 +937,33 @@ def regenerate_domain(session: dict):
|
|
| 979 |
erd = draw_dynamic_erd(CURRENT_SCHEMA)
|
| 980 |
status = _domain_status_md()
|
| 981 |
|
| 982 |
-
#
|
| 983 |
-
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
| 995 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 996 |
|
| 997 |
def preview_table(tbl: str):
|
| 998 |
try:
|
|
@@ -1059,7 +1030,6 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
|
|
| 1059 |
gr.Markdown("### Result Preview")
|
| 1060 |
result_df = gr.Dataframe(value=pd.DataFrame(), interactive=False)
|
| 1061 |
|
| 1062 |
-
# Wire events
|
| 1063 |
start_btn.click(
|
| 1064 |
start_session,
|
| 1065 |
inputs=[name_box, session_state],
|
|
@@ -1086,18 +1056,15 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
|
|
| 1086 |
outputs=[feedback_md],
|
| 1087 |
)
|
| 1088 |
export_btn.click(
|
| 1089 |
-
|
| 1090 |
inputs=[export_name],
|
| 1091 |
outputs=[export_file],
|
| 1092 |
)
|
| 1093 |
-
|
| 1094 |
-
# UPDATED: one callback handles regeneration, dropdown refresh, and (if session) reseeding the next question
|
| 1095 |
-
regen_btn.click(
|
| 1096 |
regenerate_domain,
|
| 1097 |
inputs=[session_state],
|
| 1098 |
-
outputs=[regen_fb, er_image, prompt_md, sql_input, tbl_dd, session_state],
|
| 1099 |
)
|
| 1100 |
-
|
| 1101 |
tbl_btn.click(
|
| 1102 |
preview_table,
|
| 1103 |
inputs=[tbl_dd],
|
|
|
|
| 2 |
# - OpenAI randomizes a domain and questions (fallback dataset if unavailable).
|
| 3 |
# - 3–4 related tables with seed rows installed in SQLite.
|
| 4 |
# - Students practice SELECT, WHERE, JOINs (INNER/LEFT), aliases, views, CTAS/SELECT INTO.
|
| 5 |
+
# - Validator enforces columns only when the prompt asks; otherwise focuses on rows.
|
| 6 |
+
# - ERD shows all FK edges in light gray and dynamically HIGHLIGHTS edges implied by JOINs.
|
| 7 |
|
| 8 |
import os
|
| 9 |
import re
|
|
|
|
| 12 |
import random
|
| 13 |
import sqlite3
|
| 14 |
import threading
|
|
|
|
| 15 |
from datetime import datetime, timezone
|
| 16 |
from typing import List, Dict, Any, Tuple, Optional, Set
|
| 17 |
|
|
|
|
| 68 |
highlight_edges uses (src_table, dst_table) with dst_table = referenced table.
|
| 69 |
"""
|
| 70 |
highlight_tables = set(highlight_tables or [])
|
|
|
|
| 71 |
def _norm_edge(a, b): return tuple(sorted([a, b]))
|
| 72 |
H = set(_norm_edge(*e) for e in (highlight_edges or set()))
|
| 73 |
|
| 74 |
tables = schema.get("tables", [])
|
| 75 |
+
fig, ax = plt.subplots(figsize=PLOT_FIGSIZE); ax.axis("off")
|
| 76 |
if not tables:
|
|
|
|
| 77 |
ax.text(0.5, 0.5, "No tables to diagram.", ha="center", va="center")
|
| 78 |
return _fig_to_pil(fig)
|
| 79 |
|
|
|
|
| 80 |
n = len(tables)
|
|
|
|
| 81 |
margin = 0.03
|
| 82 |
width = (1 - margin * (n + 1)) / max(n, 1)
|
| 83 |
height = 0.70
|
| 84 |
y = 0.20
|
| 85 |
|
|
|
|
| 86 |
fk_edges = []
|
| 87 |
for t in tables:
|
| 88 |
for fk in t.get("fks", []) or []:
|
|
|
|
| 90 |
if dst:
|
| 91 |
fk_edges.append((t["name"], dst))
|
| 92 |
|
|
|
|
| 93 |
boxes: Dict[str, Tuple[float,float,float,float]] = {}
|
| 94 |
for i, t in enumerate(tables):
|
| 95 |
tx = margin + i * (width + margin)
|
| 96 |
boxes[t["name"]] = (tx, y, width, height)
|
|
|
|
|
|
|
| 97 |
lw = 2.0 if t["name"] in highlight_tables else 1.2
|
| 98 |
ax.add_patch(Rectangle((tx, y), width, height, fill=False, lw=lw))
|
| 99 |
ax.text(tx + 0.01, y + height - 0.04, t["name"], fontsize=10, ha="left", va="top", weight="bold")
|
| 100 |
|
| 101 |
yy = y + height - 0.09
|
| 102 |
pkset = set(t.get("pk", []) or [])
|
|
|
|
| 103 |
fk_map: Dict[str, List[Tuple[str, str]]] = {}
|
| 104 |
for fk in t.get("fks", []) or []:
|
| 105 |
ref_tbl = fk.get("ref_table", "")
|
|
|
|
| 117 |
ax.text(tx + 0.016, yy, f"{nm}{tag}", fontsize=9, ha="left", va="top")
|
| 118 |
yy -= 0.055
|
| 119 |
|
|
|
|
| 120 |
for (src, dst) in fk_edges:
|
| 121 |
if src not in boxes or dst not in boxes:
|
| 122 |
continue
|
|
|
|
| 127 |
xytext=(x1 + w1/2.0, y1),
|
| 128 |
arrowprops=dict(arrowstyle="->", lw=1.0, color="#cccccc"))
|
| 129 |
|
|
|
|
| 130 |
for (src, dst) in fk_edges:
|
| 131 |
if _norm_edge(src, dst) in H:
|
| 132 |
(x1, y1, w1, h1) = boxes[src]
|
|
|
|
| 139 |
ax.text(0.5, 0.06, f"Domain: {schema.get('domain','unknown')}", fontsize=9, ha="center")
|
| 140 |
return _fig_to_pil(fig)
|
| 141 |
|
|
|
|
| 142 |
# --- Regexes used to infer ERD highlights from a student's SQL -----------------

# SQL keywords that can directly follow a table name. They must never be taken
# for a table alias: capturing e.g. "join" as the alias of the FROM table would
# consume the keyword and hide the table that follows it from the next match.
_NON_ALIAS_KEYWORDS = (
    "as|on|using|join|inner|left|right|full|outer|cross|natural|"
    "where|group|order|having|limit|union|except|intersect"
)

# Optional "[AS] alias" suffix after a table name (one capture group: the alias).
# The negative lookahead rejects keywords; the trailing \b keeps identifiers
# that merely start with a keyword (e.g. "orders_x") valid as aliases.
_ALIAS_SUFFIX = r"(?:\s+(?:as\s+)?(?!(?:" + _NON_ALIAS_KEYWORDS + r")\b)([a-z_]\w*))?"

# "FROM tbl [AS alias]" / "JOIN tbl [AS alias]" -> groups: (table, alias or "").
JOIN_TBL_RE = re.compile(r"\b(?:from|join)\s+([a-z_]\w*)" + _ALIAS_SUFFIX, re.IGNORECASE)

# "x.col = y.col" -> groups: (left qualifier, right qualifier).
EQ_ON_RE = re.compile(r"([a-z_]\w*)\.[a-z_]\w*\s*=\s*([a-z_]\w*)\.[a-z_]\w*", re.IGNORECASE)

# "JOIN tbl [AS alias] USING (" -> groups: (table, alias or None).
USING_RE = re.compile(r"\bjoin\s+([a-z_]\w*)" + _ALIAS_SUFFIX + r"\s+using\s*\(", re.IGNORECASE)
|
| 145 |
|
| 146 |
def sql_highlights(sql: str, schema: Dict[str, Any]) -> Tuple[Set[str], Set[Tuple[str, str]]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
if not sql:
|
| 148 |
return set(), set()
|
| 149 |
|
| 150 |
low = " ".join(sql.strip().split())
|
|
|
|
| 151 |
alias_to_table: Dict[str, str] = {}
|
| 152 |
join_order: List[str] = []
|
| 153 |
|
|
|
|
| 157 |
alias_to_table[alias] = table
|
| 158 |
join_order.append(alias)
|
| 159 |
|
|
|
|
| 160 |
edges: Set[Tuple[str, str]] = set()
|
| 161 |
for a1, a2 in EQ_ON_RE.findall(low):
|
| 162 |
t1 = alias_to_table.get(a1, a1)
|
|
|
|
| 164 |
if t1 != t2:
|
| 165 |
edges.add((t1, t2))
|
| 166 |
|
|
|
|
| 167 |
if USING_RE.search(low) and len(join_order) >= 2:
|
| 168 |
for i in range(1, len(join_order)):
|
| 169 |
t_left = alias_to_table.get(join_order[i-1], join_order[i-1])
|
|
|
|
| 171 |
if t_left != t_right:
|
| 172 |
edges.add((t_left, t_right))
|
| 173 |
|
|
|
|
| 174 |
used_tables = {alias_to_table.get(a, a) for a in join_order}
|
|
|
|
|
|
|
| 175 |
schema_tables = {t["name"] for t in schema.get("tables", [])}
|
| 176 |
edges = { (a, b) for (a, b) in edges if a in schema_tables and b in schema_tables }
|
| 177 |
used_tables = { t for t in used_tables if t in schema_tables }
|
|
|
|
| 178 |
return used_tables, edges
|
| 179 |
|
| 180 |
# -------------------- SQLite + locking --------------------
|
|
|
|
| 381 |
return None
|
| 382 |
return None
|
| 383 |
|
|
|
|
| 384 |
_SQL_FENCE = re.compile(r"```sql(.*?)```", re.IGNORECASE | re.DOTALL)
|
| 385 |
_CODE_FENCE = re.compile(r"```(.*?)```", re.DOTALL)
|
| 386 |
|
|
|
|
| 475 |
})
|
| 476 |
return out
|
| 477 |
|
| 478 |
+
def llm_generate_domain_and_questions(prev_domain: Optional[str]):
|
| 479 |
if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
|
| 480 |
return None, "OpenAI client not available or OPENAI_API_KEY missing.", None, {"accepted_questions":0,"dropped_questions":0}
|
| 481 |
errors = []
|
|
|
|
| 580 |
install_schema(CONN, schema)
|
| 581 |
if not questions:
|
| 582 |
questions = FALLBACK_QUESTIONS
|
|
|
|
|
|
|
| 583 |
return schema, questions, info
|
| 584 |
|
| 585 |
# -------------------- Session globals --------------------
|
|
|
|
| 620 |
pool = CURRENT_QS if CURRENT_QS else FALLBACK_QUESTIONS
|
| 621 |
df = fetch_attempts(CONN, user_id)
|
| 622 |
stats = topic_stats(df)
|
| 623 |
+
stats = stats.sort_values(by=["accuracy","attempts"], ascending=[True, True]) if not stats.empty else stats
|
| 624 |
weakest = stats.iloc[0]["category"] if not stats.empty else CATEGORIES_ORDER[0]
|
| 625 |
cands = [q for q in pool if str(q.get("category","")).strip() == weakest] or pool
|
| 626 |
return dict(random.choice(cands))
|
|
|
|
| 672 |
return "Possible cartesian product: no join condition detected."
|
| 673 |
return None
|
| 674 |
|
|
|
|
| 675 |
def should_enforce_columns(q: Dict[str, Any]) -> bool:
|
| 676 |
cat = (q.get("category") or "").strip()
|
| 677 |
if cat in ("SELECT columns", "Aggregation", "VIEW", "CTAS / SELECT INTO"):
|
| 678 |
return True
|
| 679 |
prompt = (q.get("prompt_md") or "").lower()
|
| 680 |
+
if re.search(r"`[^`]+`", q.get("prompt_md") or ""):
|
|
|
|
| 681 |
return True
|
| 682 |
+
if re.search(r"\((?:show|return|display)[^)]+\)", prompt):
|
| 683 |
return True
|
| 684 |
if re.search(r"\b(show|return|display|select)\b[^.]{0,100}\b(columns?|fields?|name|title|price)\b", prompt):
|
| 685 |
return True
|
|
|
|
| 705 |
return False, None
|
| 706 |
|
| 707 |
def results_equal_rowcount_only(df_student: pd.DataFrame, df_expected: pd.DataFrame) -> bool:
    """Return True when both result frames contain the same number of rows.

    Only the row counts are compared; column names, column order and cell
    values are ignored.

    Args:
        df_student: Result frame produced by the student's query.
        df_expected: Result frame produced by the reference query.

    Returns:
        True if the two frames have an equal number of rows, else False.
    """
    student_rows = df_student.shape[0]
    expected_rows = df_expected.shape[0]
    return student_rows == expected_rows
|
| 709 |
|
| 710 |
def aliases_present(sql: str, required_aliases: List[str]) -> bool:
|
|
|
|
| 714 |
return False
|
| 715 |
return True
|
| 716 |
|
| 717 |
+
def exec_student_sql(sql_text: str):
|
| 718 |
if not sql_text or not sql_text.strip():
|
| 719 |
return None, "Enter a SQL statement.", None, None
|
| 720 |
sql_raw = sql_text.strip().rstrip(";")
|
|
|
|
| 910 |
}.get(cat, "Identify keys from the schema and join on them.")
|
| 911 |
return gr.update(value=f"**Hint:** {hint}", visible=True)
|
| 912 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 913 |
def _domain_status_md():
    """Build the Markdown status message for the currently installed domain.

    Reads the module-level ``CURRENT_INFO`` and ``CURRENT_SCHEMA`` dicts.

    Returns:
        A success message (model, domain, accepted/dropped question counts and
        table names) when the recorded source starts with ``"openai"``;
        otherwise a warning naming the fallback domain and the recorded error,
        shortened to 160 characters.
    """
    # The previous check, `if CURRENT_INFO.get("source", "openai"):`, was truthy
    # for every non-empty source string, so the warning branch below could not
    # be reached once any source was recorded. Compare the value instead.
    # A missing "source" key still counts as "openai", as before.
    source = str(CURRENT_INFO.get("source", "openai") or "")
    domain = CURRENT_SCHEMA.get('domain','?')

    if source.startswith("openai"):
        accepted = CURRENT_INFO.get("accepted",0)
        dropped = CURRENT_INFO.get("dropped",0)
        table_names = ', '.join(t['name'] for t in CURRENT_SCHEMA.get('tables', []))
        return (f"✅ **Domain via OpenAI** `{CURRENT_INFO.get('model','?')}` → **{domain}**. "
                f"Accepted questions: {accepted}, dropped: {dropped}. \n"
                f"Tables: {table_names}.")

    # An explicit None error value would make len() raise; treat it as empty.
    err = str(CURRENT_INFO.get("error","") or "")
    err_short = (err[:160] + "…") if len(err) > 160 else err
    return f"⚠️ **OpenAI randomization unavailable** → using fallback **{domain}**.\n\n> Reason: {err_short}"
|
| 921 |
|
|
|
|
| 922 |
def list_tables_for_preview():
|
| 923 |
df = run_df(CONN, """
|
| 924 |
SELECT name FROM sqlite_master
|
|
|
|
| 929 |
""")
|
| 930 |
return df["name"].tolist() if not df.empty else ["(no tables)"]
|
| 931 |
|
| 932 |
+
# >>> FIX: Always reseed a question on randomize (creates a guest session if needed)
|
| 933 |
def regenerate_domain(session: dict):
|
| 934 |
global CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO
|
| 935 |
prev = CURRENT_SCHEMA.get("domain") if CURRENT_SCHEMA else None
|
|
|
|
| 937 |
erd = draw_dynamic_erd(CURRENT_SCHEMA)
|
| 938 |
status = _domain_status_md()
|
| 939 |
|
| 940 |
+
# Ensure a session (guest if needed)
|
| 941 |
+
if not session or not session.get("user_id"):
|
| 942 |
+
user_id = f"guest-{int(time.time())}"
|
| 943 |
+
upsert_user(CONN, user_id, "Guest")
|
| 944 |
+
session = {"user_id": user_id, "name": "Guest", "qid": None, "start_ts": time.time(), "q": None}
|
| 945 |
+
|
| 946 |
+
# Seed next question for this session
|
| 947 |
+
q = pick_next_question(session["user_id"])
|
| 948 |
+
session.update({"qid": q["id"], "q": q, "start_ts": time.time()})
|
| 949 |
+
|
| 950 |
+
# Fresh mastery and cleared result preview
|
| 951 |
+
stats = topic_stats(fetch_attempts(CONN, session["user_id"]))
|
| 952 |
+
empty_df = pd.DataFrame()
|
| 953 |
+
|
| 954 |
+
# Refresh dropdown
|
| 955 |
+
dd_update = gr.update(choices=list_tables_for_preview(), value=None)
|
| 956 |
+
|
| 957 |
+
return (
|
| 958 |
+
gr.update(value=status, visible=True), # regen_fb
|
| 959 |
+
erd, # er_image
|
| 960 |
+
gr.update(value=f"**Question {q['id']}**\n\n{q['prompt_md']}", visible=True), # prompt_md
|
| 961 |
+
gr.update(value="", visible=True), # sql_input
|
| 962 |
+
dd_update, # tbl_dd
|
| 963 |
+
stats, # mastery_df
|
| 964 |
+
empty_df, # result_df
|
| 965 |
+
session # session_state
|
| 966 |
+
)
|
| 967 |
|
| 968 |
def preview_table(tbl: str):
|
| 969 |
try:
|
|
|
|
| 1030 |
gr.Markdown("### Result Preview")
|
| 1031 |
result_df = gr.Dataframe(value=pd.DataFrame(), interactive=False)
|
| 1032 |
|
|
|
|
| 1033 |
start_btn.click(
|
| 1034 |
start_session,
|
| 1035 |
inputs=[name_box, session_state],
|
|
|
|
| 1056 |
outputs=[feedback_md],
|
| 1057 |
)
|
| 1058 |
export_btn.click(
|
| 1059 |
+
lambda user: os.path.abspath(os.path.join(EXPORT_DIR, f"{'-'.join((user or '').lower().split())[:64]}_progress.csv")),
|
| 1060 |
inputs=[export_name],
|
| 1061 |
outputs=[export_file],
|
| 1062 |
)
|
| 1063 |
+
regen_btn.click( # one callback: reseed question, refresh dropdown, clear previews
|
|
|
|
|
|
|
| 1064 |
regenerate_domain,
|
| 1065 |
inputs=[session_state],
|
| 1066 |
+
outputs=[regen_fb, er_image, prompt_md, sql_input, tbl_dd, mastery_df, result_df, session_state],
|
| 1067 |
)
|
|
|
|
| 1068 |
tbl_btn.click(
|
| 1069 |
preview_table,
|
| 1070 |
inputs=[tbl_dd],
|