Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,13 +1,9 @@
|
|
| 1 |
# Adaptive SQL Trainer — Domain Randomized with OpenAI (Gradio + SQLite)
|
| 2 |
-
# -
|
| 3 |
-
#
|
| 4 |
-
# -
|
| 5 |
-
# -
|
| 6 |
-
# -
|
| 7 |
-
# - Provides tailored feedback (SQLite dialect, cartesian products, aggregates, aliases).
|
| 8 |
-
# - Always shows data results at the bottom pane.
|
| 9 |
-
#
|
| 10 |
-
# Hugging Face Spaces: set OPENAI_API_KEY in secrets to enable randomization.
|
| 11 |
|
| 12 |
import os
|
| 13 |
import re
|
|
@@ -18,7 +14,7 @@ import sqlite3
|
|
| 18 |
import threading
|
| 19 |
from dataclasses import dataclass
|
| 20 |
from datetime import datetime, timezone
|
| 21 |
-
from typing import List, Dict, Any, Tuple, Optional
|
| 22 |
|
| 23 |
import gradio as gr
|
| 24 |
import pandas as pd
|
|
@@ -43,7 +39,169 @@ def _candidate_models():
|
|
| 43 |
seen = set()
|
| 44 |
return [m for m in base if m and (m not in seen and not seen.add(m))]
|
| 45 |
|
| 46 |
-
# --------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
DB_DIR = "/data" if os.path.exists("/data") else "."
|
| 48 |
DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
|
| 49 |
EXPORT_DIR = "."
|
|
@@ -51,15 +209,9 @@ RANDOM_SEED = int(os.getenv("RANDOM_SEED", "7"))
|
|
| 51 |
random.seed(RANDOM_SEED)
|
| 52 |
SYS_RAND = random.SystemRandom()
|
| 53 |
|
| 54 |
-
# -------------------- SQLite connection + locking --------------------
|
| 55 |
DB_LOCK = threading.RLock()
|
| 56 |
|
| 57 |
def connect_db():
|
| 58 |
-
"""
|
| 59 |
-
Single shared connection usable across threads.
|
| 60 |
-
All operations (reads + writes) serialized via DB_LOCK.
|
| 61 |
-
WAL mode enables concurrent reads.
|
| 62 |
-
"""
|
| 63 |
con = sqlite3.connect(DB_PATH, check_same_thread=False)
|
| 64 |
con.execute("PRAGMA journal_mode=WAL;")
|
| 65 |
con.execute("PRAGMA synchronous=NORMAL;")
|
|
@@ -104,7 +256,7 @@ def init_progress_tables(con: sqlite3.Connection):
|
|
| 104 |
|
| 105 |
init_progress_tables(CONN)
|
| 106 |
|
| 107 |
-
# -------------------- Fallback dataset
|
| 108 |
FALLBACK_SCHEMA = {
|
| 109 |
"domain": "bookstore",
|
| 110 |
"tables": [
|
|
@@ -217,10 +369,8 @@ FALLBACK_QUESTIONS = [
|
|
| 217 |
"requires_aliases":False,"required_aliases":[]},
|
| 218 |
]
|
| 219 |
|
| 220 |
-
#
|
| 221 |
-
DOMAIN_AND_QUESTIONS_SCHEMA = {
|
| 222 |
-
"required": ["domain", "tables", "questions"]
|
| 223 |
-
}
|
| 224 |
|
| 225 |
def _domain_prompt(prev_domain: Optional[str]) -> str:
|
| 226 |
extra = f" Avoid using the previous domain '{prev_domain}' if possible." if prev_domain else ""
|
|
@@ -232,28 +382,22 @@ Rules:
|
|
| 232 |
- One domain chosen from: bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.
|
| 233 |
- Tables: SQLite-friendly. Use snake_case. Each table has: name, pk (list of column names),
|
| 234 |
columns (list of {{name,type}}), fks (list of {{columns,ref_table,ref_columns}}), rows (8–15 small seed rows).
|
| 235 |
-
- Questions:
|
| 236 |
"JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO".
|
| 237 |
Include at least one LEFT JOIN, one VIEW creation, one CTAS or SELECT INTO.
|
| 238 |
Provide 1–3 'answer_sql' strings per question. Prefer SQLite-compatible SQL. Do NOT use RIGHT/FULL OUTER JOIN.
|
| 239 |
For 1–2 questions, set requires_aliases=true and list required_aliases.
|
| 240 |
|
| 241 |
-
Example top-level keys
|
| 242 |
-
{{
|
| 243 |
-
"domain": "retail sales",
|
| 244 |
-
"tables": [...],
|
| 245 |
-
"questions": [...]
|
| 246 |
-
}}
|
| 247 |
"""
|
| 248 |
|
| 249 |
def _loose_json_parse(s: str) -> Optional[dict]:
|
| 250 |
-
"""Extract the first JSON object from a possibly-wrapped string."""
|
| 251 |
try:
|
| 252 |
return json.loads(s)
|
| 253 |
except Exception:
|
| 254 |
pass
|
| 255 |
-
start = s.find("{")
|
| 256 |
-
end = s.rfind("}")
|
| 257 |
if start != -1 and end != -1 and end > start:
|
| 258 |
try:
|
| 259 |
return json.loads(s[start:end+1])
|
|
@@ -261,94 +405,64 @@ def _loose_json_parse(s: str) -> Optional[dict]:
|
|
| 261 |
return None
|
| 262 |
return None
|
| 263 |
|
| 264 |
-
#
|
| 265 |
_SQL_FENCE = re.compile(r"```sql(.*?)```", re.IGNORECASE | re.DOTALL)
|
| 266 |
_CODE_FENCE = re.compile(r"```(.*?)```", re.DOTALL)
|
| 267 |
|
| 268 |
def _strip_code_fences(txt: str) -> str:
|
| 269 |
-
if txt is None:
|
| 270 |
-
return ""
|
| 271 |
m = _SQL_FENCE.findall(txt)
|
| 272 |
-
if m:
|
| 273 |
-
return "\n".join([x.strip() for x in m if x.strip()])
|
| 274 |
m2 = _CODE_FENCE.findall(txt)
|
| 275 |
-
if m2:
|
| 276 |
-
return "\n".join([x.strip() for x in m2 if x.strip()])
|
| 277 |
return txt.strip()
|
| 278 |
|
| 279 |
def _as_list_of_sql(val) -> List[str]:
|
| 280 |
-
if val is None:
|
| 281 |
-
return []
|
| 282 |
if isinstance(val, str):
|
| 283 |
s = _strip_code_fences(val)
|
| 284 |
parts = [p.strip() for p in s.split("\n") if p.strip()]
|
| 285 |
-
# if it’s a single long line, keep as is
|
| 286 |
return parts or ([s] if s else [])
|
| 287 |
if isinstance(val, list):
|
| 288 |
out = []
|
| 289 |
for v in val:
|
| 290 |
if isinstance(v, str):
|
| 291 |
s = _strip_code_fences(v)
|
| 292 |
-
if s:
|
| 293 |
-
out.append(s)
|
| 294 |
return out
|
| 295 |
return []
|
| 296 |
|
| 297 |
def _canon_question(q: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
| 298 |
-
|
| 299 |
-
if not isinstance(q, dict):
|
| 300 |
-
return None
|
| 301 |
-
# field mapping/synonyms
|
| 302 |
cat = q.get("category") or q.get("type") or q.get("topic")
|
| 303 |
prompt = q.get("prompt_md") or q.get("prompt") or q.get("question") or q.get("text")
|
| 304 |
answer_sql = q.get("answer_sql") or q.get("answers") or q.get("solutions") or q.get("sql")
|
| 305 |
diff = q.get("difficulty") or 1
|
| 306 |
req_alias = bool(q.get("requires_aliases", False))
|
| 307 |
req_aliases = q.get("required_aliases") or []
|
| 308 |
-
|
| 309 |
cat = str(cat).strip() if cat is not None else ""
|
| 310 |
prompt = str(prompt).strip() if prompt is not None else ""
|
| 311 |
answers = _as_list_of_sql(answer_sql)
|
| 312 |
-
|
| 313 |
-
if not cat or not prompt or not answers:
|
| 314 |
-
return None
|
| 315 |
-
|
| 316 |
-
# keep only known categories if provided; otherwise accept free text
|
| 317 |
known = {
|
| 318 |
"SELECT *","SELECT columns","WHERE","Aliases",
|
| 319 |
"JOIN (INNER)","JOIN (LEFT)","Aggregation","VIEW","CTAS / SELECT INTO"
|
| 320 |
}
|
| 321 |
if cat not in known:
|
| 322 |
-
# Try to map rough names to our set
|
| 323 |
low = cat.lower()
|
| 324 |
-
if "select *" in low:
|
| 325 |
-
|
| 326 |
-
elif "
|
| 327 |
-
|
| 328 |
-
elif "
|
| 329 |
-
|
| 330 |
-
elif "
|
| 331 |
-
|
| 332 |
-
elif "
|
| 333 |
-
cat = "JOIN (LEFT)"
|
| 334 |
-
elif "inner" in low or "join" in low:
|
| 335 |
-
cat = "JOIN (INNER)"
|
| 336 |
-
elif "agg" in low or "group" in low:
|
| 337 |
-
cat = "Aggregation"
|
| 338 |
-
elif "view" in low:
|
| 339 |
-
cat = "VIEW"
|
| 340 |
-
elif "into" in low or "ctas" in low or "create table" in low:
|
| 341 |
-
cat = "CTAS / SELECT INTO"
|
| 342 |
-
else:
|
| 343 |
-
# leave as-is; still usable for practice buckets
|
| 344 |
-
pass
|
| 345 |
-
|
| 346 |
-
# normalize aliases list
|
| 347 |
if isinstance(req_aliases, str):
|
| 348 |
req_aliases = [a.strip() for a in re.split(r"[,\s]+", req_aliases) if a.strip()]
|
| 349 |
elif not isinstance(req_aliases, list):
|
| 350 |
req_aliases = []
|
| 351 |
-
|
| 352 |
return {
|
| 353 |
"id": str(q.get("id") or f"LLM_{int(time.time()*1000)}_{random.randint(100,999)}"),
|
| 354 |
"category": cat,
|
|
@@ -362,22 +476,17 @@ def _canon_question(q: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
|
| 362 |
def _canon_tables(tables: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 363 |
out = []
|
| 364 |
for t in (tables or []):
|
| 365 |
-
if not isinstance(t, dict):
|
| 366 |
-
continue
|
| 367 |
name = str(t.get("name","")).strip()
|
| 368 |
-
if not name:
|
| 369 |
-
continue
|
| 370 |
cols = t.get("columns") or []
|
| 371 |
good_cols = []
|
| 372 |
for c in cols:
|
| 373 |
-
if not isinstance(c, dict):
|
| 374 |
-
continue
|
| 375 |
cname = str(c.get("name","")).strip()
|
| 376 |
ctype = str(c.get("type","TEXT")).strip() or "TEXT"
|
| 377 |
-
if cname:
|
| 378 |
-
|
| 379 |
-
if not good_cols:
|
| 380 |
-
continue
|
| 381 |
pk = t.get("pk") or []
|
| 382 |
if isinstance(pk, str): pk = [pk]
|
| 383 |
fks = t.get("fks") or []
|
|
@@ -391,31 +500,22 @@ def _canon_tables(tables: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
| 391 |
})
|
| 392 |
return out
|
| 393 |
|
| 394 |
-
# -------------------- LLM call --------------------
|
| 395 |
def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optional[Dict[str,Any]], Optional[str], Optional[str], Dict[str,int]]:
|
| 396 |
-
"""
|
| 397 |
-
Returns (obj, error_message, model_used, stats_dict).
|
| 398 |
-
stats_dict contains {"accepted_questions": n, "dropped_questions": m}
|
| 399 |
-
"""
|
| 400 |
if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
|
| 401 |
return None, "OpenAI client not available or OPENAI_API_KEY missing.", None, {"accepted_questions":0,"dropped_questions":0}
|
| 402 |
-
|
| 403 |
errors = []
|
| 404 |
prompt = _domain_prompt(prev_domain)
|
| 405 |
-
|
| 406 |
for model in _candidate_models():
|
| 407 |
try:
|
| 408 |
-
# Try JSON mode first
|
| 409 |
try:
|
| 410 |
chat = _client.chat.completions.create(
|
| 411 |
model=model,
|
| 412 |
messages=[{"role":"user","content": prompt}],
|
| 413 |
temperature=0.6,
|
| 414 |
-
response_format={"type":"json_object"}
|
| 415 |
)
|
| 416 |
data_text = chat.choices[0].message.content
|
| 417 |
except TypeError:
|
| 418 |
-
# Older SDKs: no response_format ⇒ plain chat + strict instructions
|
| 419 |
chat = _client.chat.completions.create(
|
| 420 |
model=model,
|
| 421 |
messages=[{"role":"system","content":"Return ONLY a JSON object. No markdown."},
|
|
@@ -423,52 +523,35 @@ def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optio
|
|
| 423 |
temperature=0.6
|
| 424 |
)
|
| 425 |
data_text = chat.choices[0].message.content
|
| 426 |
-
|
| 427 |
obj_raw = _loose_json_parse(data_text or "")
|
| 428 |
if not obj_raw:
|
| 429 |
raise RuntimeError("Could not parse JSON from model output.")
|
| 430 |
-
|
| 431 |
-
# Minimal top-level validation
|
| 432 |
for k in DOMAIN_AND_QUESTIONS_SCHEMA["required"]:
|
| 433 |
if k not in obj_raw:
|
| 434 |
raise RuntimeError(f"Missing key '{k}'")
|
| 435 |
-
|
| 436 |
-
# Canonicalize tables
|
| 437 |
tables = _canon_tables(obj_raw.get("tables", []))
|
| 438 |
-
if not tables:
|
| 439 |
-
raise RuntimeError("No usable tables in LLM output.")
|
| 440 |
obj_raw["tables"] = tables
|
| 441 |
-
|
| 442 |
-
# Canonicalize questions
|
| 443 |
dropped = 0
|
| 444 |
clean_qs = []
|
| 445 |
for q in obj_raw.get("questions", []):
|
| 446 |
cq = _canon_question(q)
|
| 447 |
-
if not cq:
|
| 448 |
-
dropped += 1
|
| 449 |
-
continue
|
| 450 |
-
# Strip RIGHT/FULL joins from answers
|
| 451 |
answers = [a for a in cq["answer_sql"] if " right join " not in a.lower() and " full " not in a.lower()]
|
| 452 |
-
if not answers:
|
| 453 |
-
dropped += 1
|
| 454 |
-
continue
|
| 455 |
cq["answer_sql"] = answers
|
| 456 |
clean_qs.append(cq)
|
| 457 |
-
|
| 458 |
if not clean_qs:
|
| 459 |
raise RuntimeError("No usable questions after canonicalization.")
|
| 460 |
stats = {"accepted_questions": len(clean_qs), "dropped_questions": dropped}
|
| 461 |
-
|
| 462 |
obj_raw["questions"] = clean_qs
|
| 463 |
return obj_raw, None, model, stats
|
| 464 |
-
|
| 465 |
except Exception as e:
|
| 466 |
errors.append(f"{model}: {e}")
|
| 467 |
continue
|
| 468 |
-
|
| 469 |
return None, "; ".join(errors) if errors else "Unknown LLM error.", None, {"accepted_questions":0,"dropped_questions":0}
|
| 470 |
|
| 471 |
-
# --------------------
|
| 472 |
def drop_existing_domain_tables(con: sqlite3.Connection, keep_internal=True):
|
| 473 |
with DB_LOCK:
|
| 474 |
cur = con.cursor()
|
|
@@ -487,131 +570,30 @@ def install_schema(con: sqlite3.Connection, schema: Dict[str,Any]):
|
|
| 487 |
drop_existing_domain_tables(con, keep_internal=True)
|
| 488 |
with DB_LOCK:
|
| 489 |
cur = con.cursor()
|
| 490 |
-
# Create tables
|
| 491 |
for t in schema.get("tables", []):
|
| 492 |
cols_sql = []
|
| 493 |
pk = t.get("pk", [])
|
| 494 |
for c in t.get("columns", []):
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
cols_sql.append(f"{cname} {ctype}")
|
| 498 |
-
if pk:
|
| 499 |
-
cols_sql.append(f"PRIMARY KEY ({', '.join(pk)})")
|
| 500 |
create_sql = f"CREATE TABLE {t['name']} ({', '.join(cols_sql)})"
|
| 501 |
cur.execute(create_sql)
|
| 502 |
-
# Insert rows
|
| 503 |
for t in schema.get("tables", []):
|
| 504 |
-
if not t.get("rows"):
|
| 505 |
-
continue
|
| 506 |
cols = [c["name"] for c in t.get("columns", [])]
|
| 507 |
qmarks = ",".join(["?"]*len(cols))
|
| 508 |
insert_sql = f"INSERT INTO {t['name']} ({', '.join(cols)}) VALUES ({qmarks})"
|
| 509 |
for r in t["rows"]:
|
| 510 |
-
if isinstance(r, dict):
|
| 511 |
-
vals = [r.get(col, None) for col in cols]
|
| 512 |
elif isinstance(r, (list, tuple)):
|
| 513 |
-
vals = list(r) + [None]*(len(cols)-len(r))
|
| 514 |
-
|
| 515 |
-
else:
|
| 516 |
-
continue
|
| 517 |
cur.execute(insert_sql, vals)
|
| 518 |
con.commit()
|
| 519 |
-
# Persist schema JSON
|
| 520 |
cur.execute("INSERT OR REPLACE INTO session_meta(id, domain, schema_json) VALUES (1, ?, ?)",
|
| 521 |
(schema.get("domain","unknown"), json.dumps(schema)))
|
| 522 |
con.commit()
|
| 523 |
|
| 524 |
-
def run_df(con: sqlite3.Connection, sql: str) -> pd.DataFrame:
|
| 525 |
-
with DB_LOCK:
|
| 526 |
-
return pd.read_sql_query(sql, con)
|
| 527 |
-
|
| 528 |
-
def rewrite_select_into(sql: str) -> Tuple[str, Optional[str]]:
|
| 529 |
-
s = sql.strip().strip(";")
|
| 530 |
-
if re.search(r"\bselect\b.+\binto\b.+\bfrom\b", s, flags=re.IGNORECASE|re.DOTALL):
|
| 531 |
-
m = re.match(r"(?is)^\s*select\s+(.*?)\s+into\s+([A-Za-z_][A-Za-z0-9_]*)\s+from\s+(.*)$", s)
|
| 532 |
-
if m:
|
| 533 |
-
cols, tbl, rest = m.groups()
|
| 534 |
-
return f"CREATE TABLE {tbl} AS SELECT {cols} FROM {rest}", tbl
|
| 535 |
-
return sql, None
|
| 536 |
-
|
| 537 |
-
def detect_unsupported_joins(sql: str) -> Optional[str]:
|
| 538 |
-
low = sql.lower()
|
| 539 |
-
if " right join " in low:
|
| 540 |
-
return "SQLite does not support RIGHT JOIN. Use LEFT JOIN in the opposite direction."
|
| 541 |
-
if " full join " in low or " full outer join " in low:
|
| 542 |
-
return "SQLite does not support FULL OUTER JOIN. Use LEFT JOIN plus UNION for the other side."
|
| 543 |
-
if " ilike " in low:
|
| 544 |
-
return "SQLite has no ILIKE. Use LOWER(col) LIKE LOWER('%pattern%')."
|
| 545 |
-
return None
|
| 546 |
-
|
| 547 |
-
def detect_cartesian(con: sqlite3.Connection, sql: str, df_result: pd.DataFrame) -> Optional[str]:
|
| 548 |
-
low = sql.lower()
|
| 549 |
-
if " cross join " in low:
|
| 550 |
-
return "Query uses CROSS JOIN (cartesian product). Ensure this is intended."
|
| 551 |
-
comma_from = re.search(r"\bfrom\b\s+([a-z_]\w*)\s*,\s*([a-z_]\w*)", low)
|
| 552 |
-
missing_on = (" join " in low) and (" on " not in low) and (" using " not in low) and (" natural " not in low)
|
| 553 |
-
if comma_from or missing_on:
|
| 554 |
-
try:
|
| 555 |
-
with DB_LOCK:
|
| 556 |
-
cur = con.cursor()
|
| 557 |
-
if comma_from:
|
| 558 |
-
t1, t2 = comma_from.groups()
|
| 559 |
-
else:
|
| 560 |
-
m = re.search(r"\bfrom\b\s+([a-z_]\w*)", low)
|
| 561 |
-
j = re.search(r"\bjoin\b\s+([a-z_]\w*)", low)
|
| 562 |
-
if not m or not j:
|
| 563 |
-
return "Possible cartesian product: no join condition detected."
|
| 564 |
-
t1, t2 = m.group(1), j.group(1)
|
| 565 |
-
cur.execute(f"SELECT COUNT(*) FROM {t1}")
|
| 566 |
-
n1 = cur.fetchone()[0]
|
| 567 |
-
cur.execute(f"SELECT COUNT(*) FROM {t2}")
|
| 568 |
-
n2 = cur.fetchone()[0]
|
| 569 |
-
prod = n1 * n2
|
| 570 |
-
if len(df_result) == prod and prod > 0:
|
| 571 |
-
return f"Result row count equals {n1}×{n2}={prod}. Likely cartesian product (missing join)."
|
| 572 |
-
except Exception:
|
| 573 |
-
return "Possible cartesian product: no join condition detected."
|
| 574 |
-
return None
|
| 575 |
-
|
| 576 |
-
def results_equal(df_a: pd.DataFrame, df_b: pd.DataFrame) -> bool:
|
| 577 |
-
if df_a.shape != df_b.shape:
|
| 578 |
-
return False
|
| 579 |
-
a = df_a.copy()
|
| 580 |
-
b = df_b.copy()
|
| 581 |
-
a.columns = [c.lower() for c in a.columns]
|
| 582 |
-
b.columns = [c.lower() for c in b.columns]
|
| 583 |
-
a = a.sort_values(list(a.columns)).reset_index(drop=True)
|
| 584 |
-
b = b.sort_values(list(b.columns)).reset_index(drop=True)
|
| 585 |
-
return a.equals(b)
|
| 586 |
-
|
| 587 |
-
def aliases_present(sql: str, required_aliases: List[str]) -> bool:
|
| 588 |
-
low = re.sub(r"\s+", " ", sql.lower())
|
| 589 |
-
for al in required_aliases:
|
| 590 |
-
if f" {al}." not in low and f" as {al} " not in low:
|
| 591 |
-
return False
|
| 592 |
-
return True
|
| 593 |
-
|
| 594 |
-
# -------------------- Question model helpers --------------------
|
| 595 |
-
@dataclass
|
| 596 |
-
class SQLQuestion:
|
| 597 |
-
id: str
|
| 598 |
-
category: str
|
| 599 |
-
difficulty: int
|
| 600 |
-
prompt_md: str
|
| 601 |
-
answer_sql: List[str]
|
| 602 |
-
requires_aliases: bool = False
|
| 603 |
-
required_aliases: List[str] = None
|
| 604 |
-
|
| 605 |
-
def to_question_dict(q) -> Dict[str,Any]:
|
| 606 |
-
d = dict(q)
|
| 607 |
-
d.setdefault("requires_aliases", False)
|
| 608 |
-
d.setdefault("required_aliases", [])
|
| 609 |
-
return d
|
| 610 |
-
|
| 611 |
-
def load_questions(obj_list: List[Dict[str,Any]]) -> List[Dict[str,Any]]:
|
| 612 |
-
return [to_question_dict(o) for o in obj_list]
|
| 613 |
-
|
| 614 |
-
# -------------------- Domain bootstrap --------------------
|
| 615 |
def bootstrap_domain_with_llm_or_fallback(prev_domain: Optional[str]):
|
| 616 |
obj, err, model_used, stats = llm_generate_domain_and_questions(prev_domain)
|
| 617 |
if obj is None:
|
|
@@ -621,13 +603,13 @@ def bootstrap_domain_with_llm_or_fallback(prev_domain: Optional[str]):
|
|
| 621 |
def install_schema_and_prepare_questions(prev_domain: Optional[str]):
|
| 622 |
schema, questions, info = bootstrap_domain_with_llm_or_fallback(prev_domain)
|
| 623 |
install_schema(CONN, schema)
|
| 624 |
-
# Safety: if questions empty, fall back
|
| 625 |
if not questions:
|
| 626 |
questions = FALLBACK_QUESTIONS
|
| 627 |
-
info = {"source":"openai+fallback-questions","model":info.get("model"),
|
|
|
|
| 628 |
return schema, questions, info
|
| 629 |
|
| 630 |
-
# -------------------- Session
|
| 631 |
CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO = install_schema_and_prepare_questions(prev_domain=None)
|
| 632 |
|
| 633 |
# -------------------- Progress + mastery --------------------
|
|
@@ -662,7 +644,6 @@ def fetch_attempts(con: sqlite3.Connection, user_id: str) -> pd.DataFrame:
|
|
| 662 |
return pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", con, params=(user_id,))
|
| 663 |
|
| 664 |
def pick_next_question(user_id: str) -> Dict[str,Any]:
|
| 665 |
-
# Defensive: ensure we always have a pool
|
| 666 |
pool = CURRENT_QS if CURRENT_QS else FALLBACK_QUESTIONS
|
| 667 |
df = fetch_attempts(CONN, user_id)
|
| 668 |
stats = topic_stats(df)
|
|
@@ -671,21 +652,99 @@ def pick_next_question(user_id: str) -> Dict[str,Any]:
|
|
| 671 |
cands = [q for q in pool if str(q.get("category","")).strip() == weakest] or pool
|
| 672 |
return dict(random.choice(cands))
|
| 673 |
|
| 674 |
-
# --------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
def exec_student_sql(sql_text: str) -> Tuple[Optional[pd.DataFrame], Optional[str], Optional[str], Optional[str]]:
|
| 676 |
if not sql_text or not sql_text.strip():
|
| 677 |
return None, "Enter a SQL statement.", None, None
|
| 678 |
-
|
| 679 |
sql_raw = sql_text.strip().rstrip(";")
|
| 680 |
sql_rew, created_tbl = rewrite_select_into(sql_raw)
|
| 681 |
-
note = None
|
| 682 |
-
if sql_rew != sql_raw:
|
| 683 |
-
note = "Rewrote `SELECT ... INTO` to `CREATE TABLE ... AS SELECT ...` for SQLite."
|
| 684 |
-
|
| 685 |
unsup = detect_unsupported_joins(sql_rew)
|
| 686 |
-
if unsup:
|
| 687 |
-
return None, unsup, None, note
|
| 688 |
-
|
| 689 |
try:
|
| 690 |
low = sql_rew.lower()
|
| 691 |
if low.startswith("select"):
|
|
@@ -695,72 +754,56 @@ def exec_student_sql(sql_text: str) -> Tuple[Optional[pd.DataFrame], Optional[st
|
|
| 695 |
else:
|
| 696 |
with DB_LOCK:
|
| 697 |
cur = CONN.cursor()
|
| 698 |
-
cur.execute(sql_rew)
|
| 699 |
-
CONN.commit()
|
| 700 |
-
# Preview newly created objects
|
| 701 |
if low.startswith("create view"):
|
| 702 |
m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+(select.*)$", low)
|
| 703 |
name = m.group(2) if m else None
|
| 704 |
if name:
|
| 705 |
-
try:
|
| 706 |
-
|
| 707 |
-
return df, None, None, note
|
| 708 |
-
except Exception:
|
| 709 |
-
return None, "View created but could not be queried.", None, note
|
| 710 |
if low.startswith("create table"):
|
| 711 |
tbl = created_tbl
|
| 712 |
if not tbl:
|
| 713 |
m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
|
| 714 |
tbl = m.group(2) if m else None
|
| 715 |
if tbl:
|
| 716 |
-
try:
|
| 717 |
-
|
| 718 |
-
return df, None, None, note
|
| 719 |
-
except Exception:
|
| 720 |
-
return None, "Table created but could not be queried.", None, note
|
| 721 |
return pd.DataFrame(), None, None, note
|
| 722 |
except Exception as e:
|
| 723 |
msg = str(e)
|
| 724 |
-
if "no such table" in msg.lower():
|
| 725 |
-
|
| 726 |
-
if "
|
| 727 |
-
return None, f"{msg}. Use correct column names or prefixes (alias.column).", None, note
|
| 728 |
-
if "ambiguous column name" in msg.lower():
|
| 729 |
-
return None, f"{msg}. Qualify the column with a table alias.", None, note
|
| 730 |
if "misuse of aggregate" in msg.lower() or "aggregate functions are not allowed in" in msg.lower():
|
| 731 |
return None, f"{msg}. You might need a GROUP BY for non-aggregated columns.", None, note
|
| 732 |
if "near \"into\"" in msg.lower() and "syntax error" in msg.lower():
|
| 733 |
return None, "SQLite doesn’t support `SELECT ... INTO`. I can rewrite it automatically—try again.", None, note
|
| 734 |
if "syntax error" in msg.lower():
|
| 735 |
-
return None, f"Syntax error. Check commas, keywords,
|
| 736 |
return None, f"SQL error: {msg}", None, note
|
| 737 |
|
| 738 |
def answer_df(answer_sql: List[str]) -> Optional[pd.DataFrame]:
|
| 739 |
for sql in answer_sql:
|
| 740 |
try:
|
| 741 |
low = sql.strip().lower()
|
| 742 |
-
if low.startswith("select"):
|
| 743 |
-
return run_df(CONN, sql)
|
| 744 |
if low.startswith("create view"):
|
| 745 |
m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
|
| 746 |
view_name = m.group(2) if m else "vw_tmp"
|
| 747 |
with DB_LOCK:
|
| 748 |
cur = CONN.cursor()
|
| 749 |
cur.execute(f"DROP VIEW IF EXISTS {view_name}")
|
| 750 |
-
cur.execute(sql)
|
| 751 |
-
CONN.commit()
|
| 752 |
return run_df(CONN, f"SELECT * FROM {view_name}")
|
| 753 |
if low.startswith("create table"):
|
| 754 |
m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
|
| 755 |
tbl = m.group(2) if m else None
|
| 756 |
with DB_LOCK:
|
| 757 |
cur = CONN.cursor()
|
| 758 |
-
if tbl:
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
CONN.commit()
|
| 762 |
-
if tbl:
|
| 763 |
-
return run_df(CONN, f"SELECT * FROM {tbl}")
|
| 764 |
except Exception:
|
| 765 |
continue
|
| 766 |
return None
|
|
@@ -771,7 +814,18 @@ def validate_answer(q: Dict[str,Any], student_sql: str, df_student: Optional[pd.
|
|
| 771 |
return (df_student is not None), f"**Explanation:** Your statement executed successfully for this task."
|
| 772 |
if df_student is None:
|
| 773 |
return False, f"**Explanation:** Expected data result differs."
|
| 774 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 775 |
|
| 776 |
def log_attempt(user_id: str, qid: str, category: str, correct: bool, sql_text: str,
|
| 777 |
time_taken: float, difficulty: int, source: str, notes: str):
|
|
@@ -792,6 +846,7 @@ def start_session(name: str, session: dict):
|
|
| 792 |
gr.update(value="Please enter your name to begin.", visible=True),
|
| 793 |
gr.update(visible=False),
|
| 794 |
gr.update(visible=False),
|
|
|
|
| 795 |
gr.update(visible=False),
|
| 796 |
pd.DataFrame(),
|
| 797 |
pd.DataFrame())
|
|
@@ -804,21 +859,25 @@ def start_session(name: str, session: dict):
|
|
| 804 |
|
| 805 |
prompt = q["prompt_md"]
|
| 806 |
stats = topic_stats(fetch_attempts(CONN, user_id))
|
|
|
|
| 807 |
return (session,
|
| 808 |
gr.update(value=f"**Question {q['id']}**\n\n{prompt}", visible=True),
|
| 809 |
-
gr.update(visible=True),
|
| 810 |
-
gr.update(value="", visible=True),
|
| 811 |
-
|
|
|
|
| 812 |
stats,
|
| 813 |
pd.DataFrame())
|
| 814 |
|
| 815 |
def render_preview(sql_text: str, session: dict):
|
| 816 |
if not session or "q" not in session:
|
| 817 |
-
return gr.update(value="", visible=False)
|
| 818 |
s = (sql_text or "").strip()
|
| 819 |
if not s:
|
| 820 |
-
return gr.update(value="", visible=False)
|
| 821 |
-
|
|
|
|
|
|
|
| 822 |
|
| 823 |
def submit_answer(sql_text: str, session: dict):
|
| 824 |
if not session or "user_id" not in session or "q" not in session:
|
|
@@ -826,7 +885,6 @@ def submit_answer(sql_text: str, session: dict):
|
|
| 826 |
user_id = session["user_id"]
|
| 827 |
q = session["q"]
|
| 828 |
elapsed = max(0.0, time.time() - session.get("start_ts", time.time()))
|
| 829 |
-
|
| 830 |
df, err, warn, note = exec_student_sql(sql_text)
|
| 831 |
details = []
|
| 832 |
if note: details.append(f"ℹ️ {note}")
|
|
@@ -836,35 +894,27 @@ def submit_answer(sql_text: str, session: dict):
|
|
| 836 |
log_attempt(user_id, q.get("id","?"), q.get("category","?"), False, sql_text, elapsed, int(q.get("difficulty",1)), "bank", " | ".join([err] + details))
|
| 837 |
stats = topic_stats(fetch_attempts(CONN, user_id))
|
| 838 |
return gr.update(value=fb, visible=True), pd.DataFrame(), gr.update(visible=True), stats
|
| 839 |
-
|
| 840 |
alias_msg = None
|
| 841 |
-
if q.get("requires_aliases"):
|
| 842 |
-
|
| 843 |
-
alias_msg = f"⚠️ This task asked for aliases {q.get('required_aliases', [])}. I didn’t detect them."
|
| 844 |
-
|
| 845 |
is_correct, explanation = validate_answer(q, sql_text, df)
|
| 846 |
if warn: details.append(f"⚠️ {warn}")
|
| 847 |
if alias_msg: details.append(alias_msg)
|
| 848 |
-
|
| 849 |
prefix = "✅ **Correct!**" if is_correct else "❌ **Not quite.**"
|
| 850 |
feedback = prefix
|
| 851 |
-
if details:
|
| 852 |
-
feedback += "\n\n" + "\n".join(details)
|
| 853 |
feedback += "\n\n" + explanation + "\n\n**One acceptable solution:**\n```sql\n" + q["answer_sql"][0].rstrip(";") + ";\n```"
|
| 854 |
-
|
| 855 |
log_attempt(user_id, q["id"], q.get("category","?"), bool(is_correct), sql_text, elapsed, int(q.get("difficulty",1)), "bank", " | ".join(details))
|
| 856 |
stats = topic_stats(fetch_attempts(CONN, user_id))
|
| 857 |
return gr.update(value=feedback, visible=True), (df if df is not None else pd.DataFrame()), gr.update(visible=True), stats
|
| 858 |
|
| 859 |
def next_question(session: dict):
|
| 860 |
if not session or "user_id" not in session:
|
| 861 |
-
return session, gr.update(value="Start a session first.", visible=True), gr.update(visible=False), gr.update(visible=False)
|
| 862 |
user_id = session["user_id"]
|
| 863 |
q = pick_next_question(user_id)
|
| 864 |
-
session["qid"] = q["id"]
|
| 865 |
-
session
|
| 866 |
-
session["start_ts"] = time.time()
|
| 867 |
-
return session, gr.update(value=f"**Question {q['id']}**\n\n{q['prompt_md']}", visible=True), gr.update(value="", visible=True), gr.update(visible=False)
|
| 868 |
|
| 869 |
def show_hint(session: dict):
|
| 870 |
if not session or "q" not in session:
|
|
@@ -885,8 +935,7 @@ def show_hint(session: dict):
|
|
| 885 |
|
| 886 |
def export_progress(user_name: str):
|
| 887 |
slug = "-".join((user_name or "").lower().split())
|
| 888 |
-
if not slug:
|
| 889 |
-
return None
|
| 890 |
user_id = slug[:64]
|
| 891 |
with DB_LOCK:
|
| 892 |
df = pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", CONN, params=(user_id,))
|
|
@@ -897,25 +946,20 @@ def export_progress(user_name: str):
|
|
| 897 |
|
| 898 |
def _domain_status_md():
|
| 899 |
if CURRENT_INFO.get("source","") in ("openai","openai+fallback-questions"):
|
| 900 |
-
note = ""
|
| 901 |
-
|
| 902 |
-
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
f"✅ **Domain via OpenAI** `{CURRENT_INFO.get('model','?')}` → **{CURRENT_SCHEMA.get('domain','?')}**{note}. "
|
| 907 |
-
f"Accepted questions: {accepted}, dropped: {dropped}. \n"
|
| 908 |
-
f"Tables: {', '.join(t['name'] for t in CURRENT_SCHEMA.get('tables', []))}."
|
| 909 |
-
)
|
| 910 |
-
err = CURRENT_INFO.get("error","")
|
| 911 |
-
err_short = (err[:160] + "…") if len(err) > 160 else err
|
| 912 |
return f"⚠️ **OpenAI randomization unavailable** → using fallback **{CURRENT_SCHEMA.get('domain','?')}**.\n\n> Reason: {err_short}"
|
| 913 |
|
| 914 |
def regenerate_domain():
|
| 915 |
global CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO
|
| 916 |
prev = CURRENT_SCHEMA.get("domain") if CURRENT_SCHEMA else None
|
| 917 |
CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO = install_schema_and_prepare_questions(prev_domain=prev)
|
| 918 |
-
|
|
|
|
| 919 |
|
| 920 |
def preview_table(tbl: str):
|
| 921 |
try:
|
|
@@ -925,8 +969,7 @@ def preview_table(tbl: str):
|
|
| 925 |
|
| 926 |
def list_tables_for_preview():
|
| 927 |
df = run_df(CONN, "SELECT name, type FROM sqlite_master WHERE type in ('table','view') AND name NOT IN ('users','attempts','session_meta') ORDER BY type, name")
|
| 928 |
-
if df.empty:
|
| 929 |
-
return ["(no tables)"]
|
| 930 |
return df["name"].tolist()
|
| 931 |
|
| 932 |
# -------------------- UI --------------------
|
|
@@ -937,14 +980,11 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
|
|
| 937 |
- Uses **OpenAI** (if configured) to randomize a domain (bookstore, retail sales, wholesaler,
|
| 938 |
sales tax, oil & gas wells, marketing), generate **3–4 tables** and **8–12** questions.
|
| 939 |
- Practice `SELECT`, `WHERE`, `JOIN` (INNER/LEFT), **aliases**, **views**, and **CTAS / SELECT INTO**.
|
| 940 |
-
-
|
| 941 |
-
|
| 942 |
-
> Set your `OPENAI_API_KEY` in Space secrets to enable randomization.
|
| 943 |
"""
|
| 944 |
)
|
| 945 |
|
| 946 |
with gr.Row():
|
| 947 |
-
# -------- Left column: controls + quick preview ----------
|
| 948 |
with gr.Column(scale=1):
|
| 949 |
name_box = gr.Textbox(label="Your Name", placeholder="e.g., Jordan Alvarez")
|
| 950 |
start_btn = gr.Button("Start / Resume Session", variant="primary")
|
|
@@ -967,12 +1007,11 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
|
|
| 967 |
tbl_btn = gr.Button("Preview")
|
| 968 |
preview_df = gr.Dataframe(value=pd.DataFrame(), interactive=False)
|
| 969 |
|
| 970 |
-
# -------- Right column: task + feedback + mastery + results ----------
|
| 971 |
with gr.Column(scale=2):
|
| 972 |
prompt_md = gr.Markdown(visible=False)
|
| 973 |
sql_input = gr.Textbox(label="Your SQL", placeholder="Type SQL here (end ; optional).", lines=6, visible=False)
|
| 974 |
-
|
| 975 |
preview_md = gr.Markdown(visible=False)
|
|
|
|
| 976 |
|
| 977 |
with gr.Row():
|
| 978 |
submit_btn = gr.Button("Run & Submit", variant="primary")
|
|
@@ -983,12 +1022,8 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
|
|
| 983 |
|
| 984 |
gr.Markdown("---")
|
| 985 |
gr.Markdown("### Your Progress by Category")
|
| 986 |
-
mastery_df = gr.Dataframe(
|
| 987 |
-
|
| 988 |
-
col_count=(4, "dynamic"),
|
| 989 |
-
row_count=(0, "dynamic"),
|
| 990 |
-
interactive=False
|
| 991 |
-
)
|
| 992 |
|
| 993 |
gr.Markdown("---")
|
| 994 |
gr.Markdown("### Result Preview")
|
|
@@ -998,12 +1033,12 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
|
|
| 998 |
start_btn.click(
|
| 999 |
start_session,
|
| 1000 |
inputs=[name_box, session_state],
|
| 1001 |
-
outputs=[session_state, prompt_md, sql_input, preview_md, next_btn, mastery_df, result_df],
|
| 1002 |
)
|
| 1003 |
sql_input.change(
|
| 1004 |
render_preview,
|
| 1005 |
inputs=[sql_input, session_state],
|
| 1006 |
-
outputs=[preview_md],
|
| 1007 |
)
|
| 1008 |
submit_btn.click(
|
| 1009 |
submit_answer,
|
|
@@ -1013,7 +1048,7 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
|
|
| 1013 |
next_btn.click(
|
| 1014 |
next_question,
|
| 1015 |
inputs=[session_state],
|
| 1016 |
-
outputs=[session_state, prompt_md, sql_input, next_btn],
|
| 1017 |
)
|
| 1018 |
hint_btn.click(
|
| 1019 |
show_hint,
|
|
@@ -1028,15 +1063,14 @@ with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
|
|
| 1028 |
regen_btn.click(
|
| 1029 |
regenerate_domain,
|
| 1030 |
inputs=[],
|
| 1031 |
-
outputs=[regen_fb],
|
| 1032 |
)
|
| 1033 |
tbl_btn.click(
|
| 1034 |
lambda name: preview_table(name),
|
| 1035 |
inputs=[tbl_dd],
|
| 1036 |
outputs=[preview_df]
|
| 1037 |
)
|
| 1038 |
-
#
|
| 1039 |
-
regen_btn.click(
|
| 1040 |
lambda: gr.update(choices=list_tables_for_preview()),
|
| 1041 |
inputs=[],
|
| 1042 |
outputs=[tbl_dd]
|
|
|
|
| 1 |
# Adaptive SQL Trainer — Domain Randomized with OpenAI (Gradio + SQLite)
|
| 2 |
+
# - OpenAI randomizes a domain and questions (fallback dataset if unavailable).
|
| 3 |
+
# - 3–4 related tables with seed rows installed in SQLite.
|
| 4 |
+
# - Students practice SELECT, WHERE, JOINs (INNER/LEFT), aliases, views, CTAS/SELECT INTO.
|
| 5 |
+
# - Validator now enforces columns only when the prompt asks for them; otherwise it focuses on rows.
|
| 6 |
+
# - ERD shows all FK edges in light gray and dynamically HIGHLIGHTS edges implied by the student’s JOINs.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
import os
|
| 9 |
import re
|
|
|
|
| 14 |
import threading
|
| 15 |
from dataclasses import dataclass
|
| 16 |
from datetime import datetime, timezone
|
| 17 |
+
from typing import List, Dict, Any, Tuple, Optional, Set
|
| 18 |
|
| 19 |
import gradio as gr
|
| 20 |
import pandas as pd
|
|
|
|
| 39 |
seen = set()
|
| 40 |
return [m for m in base if m and (m not in seen and not seen.add(m))]
|
| 41 |
|
| 42 |
+
# -------------------- ERD drawing (headless) --------------------
|
| 43 |
+
import matplotlib
|
| 44 |
+
matplotlib.use("Agg")
|
| 45 |
+
import matplotlib.pyplot as plt
|
| 46 |
+
from matplotlib.patches import Rectangle
|
| 47 |
+
from io import BytesIO
|
| 48 |
+
from PIL import Image
|
| 49 |
+
|
| 50 |
+
# Shared Matplotlib settings for rendering the ERD image.
PLOT_FIGSIZE = (7.6, 3.8)  # figure size in inches (width, height) — used by plt.subplots below
PLOT_DPI = 120  # raster resolution passed to savefig when exporting the figure to PNG
PLOT_HEIGHT = 300  # NOTE(review): presumably the display height (px) of a UI image widget — confirm at the usage site
|
| 53 |
+
|
| 54 |
+
def _fig_to_pil(fig) -> Image.Image:
    """Render a Matplotlib figure to an in-memory PNG and return it as a PIL image.

    The figure is closed after rendering so headless use doesn't leak figures.
    """
    png_buffer = BytesIO()
    fig.tight_layout()
    fig.savefig(png_buffer, format="png", dpi=PLOT_DPI, bbox_inches="tight")
    plt.close(fig)  # release the figure; we only need the rendered bytes
    png_buffer.seek(0)
    return Image.open(png_buffer)
|
| 61 |
+
|
| 62 |
+
def draw_dynamic_erd(
    schema: Dict[str, Any],
    highlight_tables: Optional[Set[str]] = None,
    highlight_edges: Optional[Set[Tuple[str, str]]] = None,
) -> Image.Image:
    """
    Draw tables + FK edges. If highlight_* provided, overlay those tables/edges in bold.
    highlight_edges uses (src_table, dst_table) with dst_table = referenced table.

    Coordinates are in the axes' default 0–1 data range (axis is turned off),
    so all positions/sizes below are fractions of the drawing area.
    """
    highlight_tables = set(highlight_tables or [])
    # normalize edges so (A,B) & (B,A) match regardless of direction
    def _norm_edge(a, b): return tuple(sorted([a, b]))
    H = set(_norm_edge(*e) for e in (highlight_edges or set()))

    tables = schema.get("tables", [])
    if not tables:
        # Empty schema: return a placeholder image instead of an empty diagram.
        fig, ax = plt.subplots(figsize=PLOT_FIGSIZE); ax.axis("off")
        ax.text(0.5, 0.5, "No tables to diagram.", ha="center", va="center")
        return _fig_to_pil(fig)

    # Layout tables horizontally
    n = len(tables)
    fig, ax = plt.subplots(figsize=PLOT_FIGSIZE); ax.axis("off")
    margin = 0.03
    width = (1 - margin * (n + 1)) / max(n, 1)  # equal-width boxes with margins between
    height = 0.70
    y = 0.20

    # Build quick FK lookup: [(src_table, dst_table)]
    fk_edges = []
    for t in tables:
        for fk in t.get("fks", []) or []:
            dst = fk.get("ref_table")
            if dst:
                fk_edges.append((t["name"], dst))

    # Draw table boxes + columns; remember each box's geometry for edge drawing.
    boxes: Dict[str, Tuple[float,float,float,float]] = {}
    for i, t in enumerate(tables):
        tx = margin + i * (width + margin)
        boxes[t["name"]] = (tx, y, width, height)

        # Highlight table border if used in current SQL
        lw = 2.0 if t["name"] in highlight_tables else 1.2
        ax.add_patch(Rectangle((tx, y), width, height, fill=False, lw=lw))
        ax.text(tx + 0.01, y + height - 0.04, t["name"], fontsize=10, ha="left", va="top", weight="bold")

        yy = y + height - 0.09
        pkset = set(t.get("pk", []) or [])
        # For FK annotation by column: column name -> [(ref_table, ref_column), ...]
        fk_map: Dict[str, List[Tuple[str, str]]] = {}
        for fk in t.get("fks", []) or []:
            ref_tbl = fk.get("ref_table", "")
            for c, rc in zip(fk.get("columns", []) or [], fk.get("ref_columns", []) or []):
                fk_map.setdefault(c, []).append((ref_tbl, rc))

        for col in t.get("columns", []):
            nm = col.get("name", "")
            tag = ""
            if nm in pkset:
                tag = " (PK)"
            if nm in fk_map:
                # Only the first FK reference is shown; merged into the PK tag when both apply.
                ref = fk_map[nm][0]
                tag = f" (FK→{ref[0]}.{ref[1]})" if not tag else tag.replace(")", f", FK→{ref[0]}.{ref[1]})")
            ax.text(tx + 0.016, yy, f"{nm}{tag}", fontsize=9, ha="left", va="top")
            yy -= 0.055  # fixed vertical spacing between column rows

    # Draw FK edges: light gray
    for (src, dst) in fk_edges:
        if src not in boxes or dst not in boxes:
            continue  # FK references a table that isn't drawn; skip the edge
        (x1, y1, w1, h1) = boxes[src]
        (x2, y2, w2, h2) = boxes[dst]
        ax.annotate("",
                    xy=(x2 + w2/2.0, y2 + h2),
                    xytext=(x1 + w1/2.0, y1),
                    arrowprops=dict(arrowstyle="->", lw=1.0, color="#cccccc"))

    # Overlay highlighted edges: bold, darker (drawn on top of the gray baseline)
    for (src, dst) in fk_edges:
        if _norm_edge(src, dst) in H:
            (x1, y1, w1, h1) = boxes[src]
            (x2, y2, w2, h2) = boxes[dst]
            ax.annotate("",
                        xy=(x2 + w2/2.0, y2 + h2),
                        xytext=(x1 + w1/2.0, y1),
                        arrowprops=dict(arrowstyle="->", lw=2.6, color="#2b6cb0"))

    ax.text(0.5, 0.06, f"Domain: {schema.get('domain','unknown')}", fontsize=9, ha="center")
    return _fig_to_pil(fig)
|
| 152 |
+
|
| 153 |
+
# Parse JOINs from SQL and turn them into tables/edges to highlight on ERD
|
| 154 |
+
JOIN_TBL_RE = re.compile(r"\b(?:from|join)\s+([a-z_]\w*)(?:\s+(?:as\s+)?([a-z_]\w*))?", re.IGNORECASE)
|
| 155 |
+
EQ_ON_RE = re.compile(r"([a-z_]\w*)\.[a-z_]\w*\s*=\s*([a-z_]\w*)\.[a-z_]\w*", re.IGNORECASE)
|
| 156 |
+
USING_RE = re.compile(r"\bjoin\s+([a-z_]\w*)(?:\s+(?:as\s+)?([a-z_]\w*))?\s+using\s*\(", re.IGNORECASE)
|
| 157 |
+
|
| 158 |
+
def sql_highlights(sql: str, schema: Dict[str, Any]) -> Tuple[Set[str], Set[Tuple[str, str]]]:
    """
    Returns (highlight_tables, highlight_edges) based on the student's SQL.
    - Tables: any table appearing after FROM or JOIN (by name or alias).
    - Edges: pairs inferred from `a.col = b.col` or JOIN ... USING (...).
    """
    if not sql:
        return set(), set()

    collapsed = " ".join(sql.strip().split())

    # Map each alias (or bare table name) to its table, preserving FROM/JOIN order.
    alias_map: Dict[str, str] = {}
    ordered_aliases: List[str] = []
    for match in JOIN_TBL_RE.finditer(collapsed):
        table_name = match.group(1)
        alias = match.group(2) or table_name
        alias_map[alias] = table_name
        ordered_aliases.append(alias)

    # Edges from explicit `a.col = b.col` equality predicates.
    edge_set: Set[Tuple[str, str]] = set()
    for lhs, rhs in EQ_ON_RE.findall(collapsed):
        left_tbl = alias_map.get(lhs, lhs)
        right_tbl = alias_map.get(rhs, rhs)
        if left_tbl != right_tbl:
            edge_set.add((left_tbl, right_tbl))

    # USING(...) gives no qualified columns; heuristically chain each JOIN
    # to the table/alias that immediately precedes it.
    if USING_RE.search(collapsed) and len(ordered_aliases) >= 2:
        for idx in range(1, len(ordered_aliases)):
            prev_tbl = alias_map.get(ordered_aliases[idx - 1], ordered_aliases[idx - 1])
            curr_tbl = alias_map.get(ordered_aliases[idx], ordered_aliases[idx])
            if prev_tbl != curr_tbl:
                edge_set.add((prev_tbl, curr_tbl))

    # Map aliases back to table names for the highlight set.
    referenced = {alias_map.get(a, a) for a in ordered_aliases}

    # Keep only names that actually exist in the installed schema.
    known_tables = {t["name"] for t in schema.get("tables", [])}
    edge_set = {(a, b) for (a, b) in edge_set if a in known_tables and b in known_tables}
    referenced = {t for t in referenced if t in known_tables}

    return referenced, edge_set
|
| 203 |
+
|
| 204 |
+
# -------------------- SQLite + locking --------------------
|
| 205 |
DB_DIR = "/data" if os.path.exists("/data") else "."
|
| 206 |
DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
|
| 207 |
EXPORT_DIR = "."
|
|
|
|
| 209 |
random.seed(RANDOM_SEED)
|
| 210 |
SYS_RAND = random.SystemRandom()
|
| 211 |
|
|
|
|
| 212 |
DB_LOCK = threading.RLock()
|
| 213 |
|
| 214 |
def connect_db():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
con = sqlite3.connect(DB_PATH, check_same_thread=False)
|
| 216 |
con.execute("PRAGMA journal_mode=WAL;")
|
| 217 |
con.execute("PRAGMA synchronous=NORMAL;")
|
|
|
|
| 256 |
|
| 257 |
init_progress_tables(CONN)
|
| 258 |
|
| 259 |
+
# -------------------- Fallback dataset & questions --------------------
|
| 260 |
FALLBACK_SCHEMA = {
|
| 261 |
"domain": "bookstore",
|
| 262 |
"tables": [
|
|
|
|
| 369 |
"requires_aliases":False,"required_aliases":[]},
|
| 370 |
]
|
| 371 |
|
| 372 |
+
# --------------- OpenAI prompts + parsing helpers ---------------
|
| 373 |
+
DOMAIN_AND_QUESTIONS_SCHEMA = {"required": ["domain", "tables", "questions"]}
|
|
|
|
|
|
|
| 374 |
|
| 375 |
def _domain_prompt(prev_domain: Optional[str]) -> str:
|
| 376 |
extra = f" Avoid using the previous domain '{prev_domain}' if possible." if prev_domain else ""
|
|
|
|
| 382 |
- One domain chosen from: bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.
|
| 383 |
- Tables: SQLite-friendly. Use snake_case. Each table has: name, pk (list of column names),
|
| 384 |
columns (list of {{name,type}}), fks (list of {{columns,ref_table,ref_columns}}), rows (8–15 small seed rows).
|
| 385 |
+
- Questions: categories among "SELECT *", "SELECT columns", "WHERE", "Aliases",
|
| 386 |
"JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO".
|
| 387 |
Include at least one LEFT JOIN, one VIEW creation, one CTAS or SELECT INTO.
|
| 388 |
Provide 1–3 'answer_sql' strings per question. Prefer SQLite-compatible SQL. Do NOT use RIGHT/FULL OUTER JOIN.
|
| 389 |
For 1–2 questions, set requires_aliases=true and list required_aliases.
|
| 390 |
|
| 391 |
+
Example top-level keys:
|
| 392 |
+
{{"domain":"retail sales","tables":[...],"questions":[...]}}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
"""
|
| 394 |
|
| 395 |
def _loose_json_parse(s: str) -> Optional[dict]:
|
|
|
|
| 396 |
try:
|
| 397 |
return json.loads(s)
|
| 398 |
except Exception:
|
| 399 |
pass
|
| 400 |
+
start = s.find("{"); end = s.rfind("}")
|
|
|
|
| 401 |
if start != -1 and end != -1 and end > start:
|
| 402 |
try:
|
| 403 |
return json.loads(s[start:end+1])
|
|
|
|
| 405 |
return None
|
| 406 |
return None
|
| 407 |
|
| 408 |
+
# Canonicalization of LLM output (questions & tables)
|
| 409 |
_SQL_FENCE = re.compile(r"```sql(.*?)```", re.IGNORECASE | re.DOTALL)
|
| 410 |
_CODE_FENCE = re.compile(r"```(.*?)```", re.DOTALL)
|
| 411 |
|
| 412 |
def _strip_code_fences(txt: str) -> str:
|
| 413 |
+
if txt is None: return ""
|
|
|
|
| 414 |
m = _SQL_FENCE.findall(txt)
|
| 415 |
+
if m: return "\n".join([x.strip() for x in m if x.strip()])
|
|
|
|
| 416 |
m2 = _CODE_FENCE.findall(txt)
|
| 417 |
+
if m2: return "\n".join([x.strip() for x in m2 if x.strip()])
|
|
|
|
| 418 |
return txt.strip()
|
| 419 |
|
| 420 |
def _as_list_of_sql(val) -> List[str]:
|
| 421 |
+
if val is None: return []
|
|
|
|
| 422 |
if isinstance(val, str):
|
| 423 |
s = _strip_code_fences(val)
|
| 424 |
parts = [p.strip() for p in s.split("\n") if p.strip()]
|
|
|
|
| 425 |
return parts or ([s] if s else [])
|
| 426 |
if isinstance(val, list):
|
| 427 |
out = []
|
| 428 |
for v in val:
|
| 429 |
if isinstance(v, str):
|
| 430 |
s = _strip_code_fences(v)
|
| 431 |
+
if s: out.append(s)
|
|
|
|
| 432 |
return out
|
| 433 |
return []
|
| 434 |
|
| 435 |
def _canon_question(q: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
| 436 |
+
if not isinstance(q, dict): return None
|
|
|
|
|
|
|
|
|
|
| 437 |
cat = q.get("category") or q.get("type") or q.get("topic")
|
| 438 |
prompt = q.get("prompt_md") or q.get("prompt") or q.get("question") or q.get("text")
|
| 439 |
answer_sql = q.get("answer_sql") or q.get("answers") or q.get("solutions") or q.get("sql")
|
| 440 |
diff = q.get("difficulty") or 1
|
| 441 |
req_alias = bool(q.get("requires_aliases", False))
|
| 442 |
req_aliases = q.get("required_aliases") or []
|
|
|
|
| 443 |
cat = str(cat).strip() if cat is not None else ""
|
| 444 |
prompt = str(prompt).strip() if prompt is not None else ""
|
| 445 |
answers = _as_list_of_sql(answer_sql)
|
| 446 |
+
if not cat or not prompt or not answers: return None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
known = {
|
| 448 |
"SELECT *","SELECT columns","WHERE","Aliases",
|
| 449 |
"JOIN (INNER)","JOIN (LEFT)","Aggregation","VIEW","CTAS / SELECT INTO"
|
| 450 |
}
|
| 451 |
if cat not in known:
|
|
|
|
| 452 |
low = cat.lower()
|
| 453 |
+
if "select *" in low: cat = "SELECT *"
|
| 454 |
+
elif "columns" in low: cat = "SELECT columns"
|
| 455 |
+
elif "where" in low or "filter" in low: cat = "WHERE"
|
| 456 |
+
elif "alias" in low: cat = "Aliases"
|
| 457 |
+
elif "left" in low: cat = "JOIN (LEFT)"
|
| 458 |
+
elif "inner" in low or "join" in low: cat = "JOIN (INNER)"
|
| 459 |
+
elif "agg" in low or "group" in low: cat = "Aggregation"
|
| 460 |
+
elif "view" in low: cat = "VIEW"
|
| 461 |
+
elif "into" in low or "ctas" in low: cat = "CTAS / SELECT INTO"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
if isinstance(req_aliases, str):
|
| 463 |
req_aliases = [a.strip() for a in re.split(r"[,\s]+", req_aliases) if a.strip()]
|
| 464 |
elif not isinstance(req_aliases, list):
|
| 465 |
req_aliases = []
|
|
|
|
| 466 |
return {
|
| 467 |
"id": str(q.get("id") or f"LLM_{int(time.time()*1000)}_{random.randint(100,999)}"),
|
| 468 |
"category": cat,
|
|
|
|
| 476 |
def _canon_tables(tables: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 477 |
out = []
|
| 478 |
for t in (tables or []):
|
| 479 |
+
if not isinstance(t, dict): continue
|
|
|
|
| 480 |
name = str(t.get("name","")).strip()
|
| 481 |
+
if not name: continue
|
|
|
|
| 482 |
cols = t.get("columns") or []
|
| 483 |
good_cols = []
|
| 484 |
for c in cols:
|
| 485 |
+
if not isinstance(c, dict): continue
|
|
|
|
| 486 |
cname = str(c.get("name","")).strip()
|
| 487 |
ctype = str(c.get("type","TEXT")).strip() or "TEXT"
|
| 488 |
+
if cname: good_cols.append({"name": cname, "type": ctype})
|
| 489 |
+
if not good_cols: continue
|
|
|
|
|
|
|
| 490 |
pk = t.get("pk") or []
|
| 491 |
if isinstance(pk, str): pk = [pk]
|
| 492 |
fks = t.get("fks") or []
|
|
|
|
| 500 |
})
|
| 501 |
return out
|
| 502 |
|
|
|
|
| 503 |
def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optional[Dict[str,Any]], Optional[str], Optional[str], Dict[str,int]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
|
| 505 |
return None, "OpenAI client not available or OPENAI_API_KEY missing.", None, {"accepted_questions":0,"dropped_questions":0}
|
|
|
|
| 506 |
errors = []
|
| 507 |
prompt = _domain_prompt(prev_domain)
|
|
|
|
| 508 |
for model in _candidate_models():
|
| 509 |
try:
|
|
|
|
| 510 |
try:
|
| 511 |
chat = _client.chat.completions.create(
|
| 512 |
model=model,
|
| 513 |
messages=[{"role":"user","content": prompt}],
|
| 514 |
temperature=0.6,
|
| 515 |
+
response_format={"type":"json_object"}
|
| 516 |
)
|
| 517 |
data_text = chat.choices[0].message.content
|
| 518 |
except TypeError:
|
|
|
|
| 519 |
chat = _client.chat.completions.create(
|
| 520 |
model=model,
|
| 521 |
messages=[{"role":"system","content":"Return ONLY a JSON object. No markdown."},
|
|
|
|
| 523 |
temperature=0.6
|
| 524 |
)
|
| 525 |
data_text = chat.choices[0].message.content
|
|
|
|
| 526 |
obj_raw = _loose_json_parse(data_text or "")
|
| 527 |
if not obj_raw:
|
| 528 |
raise RuntimeError("Could not parse JSON from model output.")
|
|
|
|
|
|
|
| 529 |
for k in DOMAIN_AND_QUESTIONS_SCHEMA["required"]:
|
| 530 |
if k not in obj_raw:
|
| 531 |
raise RuntimeError(f"Missing key '{k}'")
|
|
|
|
|
|
|
| 532 |
tables = _canon_tables(obj_raw.get("tables", []))
|
| 533 |
+
if not tables: raise RuntimeError("No usable tables in LLM output.")
|
|
|
|
| 534 |
obj_raw["tables"] = tables
|
|
|
|
|
|
|
| 535 |
dropped = 0
|
| 536 |
clean_qs = []
|
| 537 |
for q in obj_raw.get("questions", []):
|
| 538 |
cq = _canon_question(q)
|
| 539 |
+
if not cq: dropped += 1; continue
|
|
|
|
|
|
|
|
|
|
| 540 |
answers = [a for a in cq["answer_sql"] if " right join " not in a.lower() and " full " not in a.lower()]
|
| 541 |
+
if not answers: dropped += 1; continue
|
|
|
|
|
|
|
| 542 |
cq["answer_sql"] = answers
|
| 543 |
clean_qs.append(cq)
|
|
|
|
| 544 |
if not clean_qs:
|
| 545 |
raise RuntimeError("No usable questions after canonicalization.")
|
| 546 |
stats = {"accepted_questions": len(clean_qs), "dropped_questions": dropped}
|
|
|
|
| 547 |
obj_raw["questions"] = clean_qs
|
| 548 |
return obj_raw, None, model, stats
|
|
|
|
| 549 |
except Exception as e:
|
| 550 |
errors.append(f"{model}: {e}")
|
| 551 |
continue
|
|
|
|
| 552 |
return None, "; ".join(errors) if errors else "Unknown LLM error.", None, {"accepted_questions":0,"dropped_questions":0}
|
| 553 |
|
| 554 |
+
# -------------------- Install schema & prepare questions --------------------
|
| 555 |
def drop_existing_domain_tables(con: sqlite3.Connection, keep_internal=True):
|
| 556 |
with DB_LOCK:
|
| 557 |
cur = con.cursor()
|
|
|
|
| 570 |
drop_existing_domain_tables(con, keep_internal=True)
|
| 571 |
with DB_LOCK:
|
| 572 |
cur = con.cursor()
|
|
|
|
| 573 |
for t in schema.get("tables", []):
|
| 574 |
cols_sql = []
|
| 575 |
pk = t.get("pk", [])
|
| 576 |
for c in t.get("columns", []):
|
| 577 |
+
cols_sql.append(f"{c['name']} {c.get('type','TEXT')}")
|
| 578 |
+
if pk: cols_sql.append(f"PRIMARY KEY ({', '.join(pk)})")
|
|
|
|
|
|
|
|
|
|
| 579 |
create_sql = f"CREATE TABLE {t['name']} ({', '.join(cols_sql)})"
|
| 580 |
cur.execute(create_sql)
|
|
|
|
| 581 |
for t in schema.get("tables", []):
|
| 582 |
+
if not t.get("rows"): continue
|
|
|
|
| 583 |
cols = [c["name"] for c in t.get("columns", [])]
|
| 584 |
qmarks = ",".join(["?"]*len(cols))
|
| 585 |
insert_sql = f"INSERT INTO {t['name']} ({', '.join(cols)}) VALUES ({qmarks})"
|
| 586 |
for r in t["rows"]:
|
| 587 |
+
if isinstance(r, dict): vals = [r.get(col, None) for col in cols]
|
|
|
|
| 588 |
elif isinstance(r, (list, tuple)):
|
| 589 |
+
vals = list(r) + [None]*(len(cols)-len(r)); vals = vals[:len(cols)]
|
| 590 |
+
else: continue
|
|
|
|
|
|
|
| 591 |
cur.execute(insert_sql, vals)
|
| 592 |
con.commit()
|
|
|
|
| 593 |
cur.execute("INSERT OR REPLACE INTO session_meta(id, domain, schema_json) VALUES (1, ?, ?)",
|
| 594 |
(schema.get("domain","unknown"), json.dumps(schema)))
|
| 595 |
con.commit()
|
| 596 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 597 |
def bootstrap_domain_with_llm_or_fallback(prev_domain: Optional[str]):
|
| 598 |
obj, err, model_used, stats = llm_generate_domain_and_questions(prev_domain)
|
| 599 |
if obj is None:
|
|
|
|
| 603 |
def install_schema_and_prepare_questions(prev_domain: Optional[str]):
    """Generate (or fall back to) a domain, install it into SQLite, and return
    the (schema, questions, info) triple used as the session's working set.

    When the LLM yields a schema but zero usable questions, the schema is kept
    and the static fallback question bank is substituted, with `info` updated
    to record that mixed provenance.
    """
    schema, questions, info = bootstrap_domain_with_llm_or_fallback(prev_domain)
    install_schema(CONN, schema)
    if questions:
        return schema, questions, info
    fallback_info = {
        "source": "openai+fallback-questions",
        "model": info.get("model"),
        "error": "LLM returned 0 usable questions; used fallback bank.",
        "accepted": 0,
        "dropped": 0,
    }
    return schema, FALLBACK_QUESTIONS, fallback_info
|
| 611 |
|
| 612 |
+
# -------------------- Session globals --------------------
|
| 613 |
CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO = install_schema_and_prepare_questions(prev_domain=None)
|
| 614 |
|
| 615 |
# -------------------- Progress + mastery --------------------
|
|
|
|
| 644 |
return pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", con, params=(user_id,))
|
| 645 |
|
| 646 |
def pick_next_question(user_id: str) -> Dict[str,Any]:
|
|
|
|
| 647 |
pool = CURRENT_QS if CURRENT_QS else FALLBACK_QUESTIONS
|
| 648 |
df = fetch_attempts(CONN, user_id)
|
| 649 |
stats = topic_stats(df)
|
|
|
|
| 652 |
cands = [q for q in pool if str(q.get("category","")).strip() == weakest] or pool
|
| 653 |
return dict(random.choice(cands))
|
| 654 |
|
| 655 |
+
# -------------------- SQL execution & grading --------------------
|
| 656 |
+
def run_df(con: sqlite3.Connection, sql: str) -> pd.DataFrame:
    """Execute *sql* on *con* and return the full result set as a DataFrame.

    Acquires the module-wide DB_LOCK so statements on the shared connection
    are serialized across threads.
    """
    with DB_LOCK:
        return pd.read_sql_query(sql, con)
|
| 659 |
+
|
| 660 |
+
def rewrite_select_into(sql: str) -> Tuple[str, Optional[str]]:
    """
    Translate `SELECT ... INTO tbl FROM ...` (not supported by SQLite) into
    the equivalent `CREATE TABLE tbl AS SELECT ... FROM ...`.

    Returns (possibly-rewritten SQL, name of the created table or None).
    SQL that doesn't match the SELECT-INTO shape is handed back untouched.
    """
    trimmed = sql.strip().strip(";")
    has_select_into = re.search(
        r"\bselect\b.+\binto\b.+\bfrom\b", trimmed, flags=re.IGNORECASE | re.DOTALL
    )
    if has_select_into:
        parsed = re.match(
            r"(?is)^\s*select\s+(.*?)\s+into\s+([A-Za-z_][A-Za-z0-9_]*)\s+from\s+(.*)$",
            trimmed,
        )
        if parsed:
            select_list, table_name, tail = parsed.groups()
            return f"CREATE TABLE {table_name} AS SELECT {select_list} FROM {tail}", table_name
    return sql, None
|
| 668 |
+
|
| 669 |
+
def detect_unsupported_joins(sql: str) -> Optional[str]:
    """Return a remediation hint when *sql* uses a construct SQLite lacks, else None.

    Checks are ordered: RIGHT JOIN first, then FULL (OUTER) JOIN, then ILIKE.
    """
    lowered = sql.lower()
    checks = (
        ((" right join ",),
         "SQLite does not support RIGHT JOIN. Use LEFT JOIN in the opposite direction."),
        ((" full join ", " full outer join "),
         "SQLite does not support FULL OUTER JOIN. Use LEFT JOIN plus UNION."),
        ((" ilike ",),
         "SQLite has no ILIKE. Use LOWER(col) LIKE LOWER('%pattern%')."),
    )
    for needles, message in checks:
        if any(needle in lowered for needle in needles):
            return message
    return None
|
| 678 |
+
|
| 679 |
+
def detect_cartesian(con: sqlite3.Connection, sql: str, df_result: pd.DataFrame) -> Optional[str]:
    """
    Heuristically flag a likely cartesian product in the student's query.

    Signals: an explicit CROSS JOIN; a comma-separated FROM list; or a JOIN
    with no ON/USING/NATURAL clause. For the latter two, the row count of the
    result is compared against |t1| * |t2| to confirm the suspicion.
    Returns a human-readable warning string, or None when nothing looks wrong.
    """
    lowered = sql.lower()
    if " cross join " in lowered:
        return "Query uses CROSS JOIN (cartesian product). Ensure this is intended."

    comma_from = re.search(r"\bfrom\b\s+([a-z_]\w*)\s*,\s*([a-z_]\w*)", lowered)
    missing_on = (" join " in lowered) and (" on " not in lowered) and (" using " not in lowered) and (" natural " not in lowered)
    if not (comma_from or missing_on):
        return None

    try:
        with DB_LOCK:
            cur = con.cursor()
            if comma_from:
                t1, t2 = comma_from.groups()
            else:
                from_match = re.search(r"\bfrom\b\s+([a-z_]\w*)", lowered)
                join_match = re.search(r"\bjoin\b\s+([a-z_]\w*)", lowered)
                if not from_match or not join_match:
                    return "Possible cartesian product: no join condition detected."
                t1, t2 = from_match.group(1), join_match.group(1)
            # Compare the result's row count with the full cross-product size.
            cur.execute(f"SELECT COUNT(*) FROM {t1}")
            n1 = cur.fetchone()[0]
            cur.execute(f"SELECT COUNT(*) FROM {t2}")
            n2 = cur.fetchone()[0]
            prod = n1 * n2
            if len(df_result) == prod and prod > 0:
                return f"Result row count equals {n1}×{n2}={prod}. Likely cartesian product (missing join)."
    except Exception:
        # Counting failed (e.g. table name wasn't a real table) — warn generically.
        return "Possible cartesian product: no join condition detected."
    return None
|
| 701 |
+
|
| 702 |
+
# Column enforcement policy — only when the prompt asks for it
|
| 703 |
+
def should_enforce_columns(q: Dict[str, Any]) -> bool:
    """
    Decide whether grading should also compare the projected columns, rather
    than rows alone. True when the question category implies a fixed
    projection, or when the prompt text spells out which columns to return.
    """
    category = (q.get("category") or "").strip()
    if category in ("SELECT columns", "Aggregation", "VIEW", "CTAS / SELECT INTO"):
        return True

    raw_prompt = q.get("prompt_md") or ""
    prompt = raw_prompt.lower()

    # Backticked identifiers usually name specific columns to return.
    if re.search(r"`[^`]+`", raw_prompt):
        return True
    # Parenthesized directives such as "(show title, price)".
    if re.search(r"\((?:show|return|display)[^)]+\)", prompt):
        return True
    # A projection verb followed shortly by column-ish vocabulary.
    if re.search(r"\b(show|return|display|select)\b[^.]{0,100}\b(columns?|fields?|name|title|price)\b", prompt):
        return True
    return False
|
| 716 |
+
|
| 717 |
+
def _normalize_columns(df: pd.DataFrame) -> pd.DataFrame:
|
| 718 |
+
out = df.copy()
|
| 719 |
+
out.columns = [str(c).strip().lower() for c in out.columns]
|
| 720 |
+
return out
|
| 721 |
+
|
| 722 |
+
def results_equal_or_superset(df_student: pd.DataFrame, df_expected: pd.DataFrame) -> Tuple[bool, Optional[str]]:
    """
    Order-insensitive comparison of a student result against the expected one.

    Returns (matched, note):
      * identical column sets      -> rows compared after sorting; note is None
      * expected cols are a subset -> (True, "extra_columns") when the projected rows match
      * anything else              -> (False, None)
    """
    student = _normalize_columns(df_student)
    expected = _normalize_columns(df_expected)

    if set(student.columns) == set(expected.columns):
        order = sorted(student.columns)
        lhs = student[order].sort_values(order).reset_index(drop=True)
        rhs = expected[order].sort_values(order).reset_index(drop=True)
        return (lhs.equals(rhs), None)

    if set(expected.columns).issubset(set(student.columns)):
        # Project the student's frame down to the expected columns, then compare rows.
        sort_keys = list(expected.columns)
        lhs = student[expected.columns].sort_values(sort_keys).reset_index(drop=True)
        rhs = expected.sort_values(sort_keys).reset_index(drop=True)
        if lhs.equals(rhs):
            return True, "extra_columns"
    return False, None
|
| 735 |
+
|
| 736 |
+
def results_equal_rowcount_only(df_student: pd.DataFrame, df_expected: pd.DataFrame) -> bool:
    """Loose check used when the prompt fixes no projection: row counts must agree."""
    return len(df_student.index) == len(df_expected.index)
|
| 739 |
+
|
| 740 |
def exec_student_sql(sql_text: str) -> Tuple[Optional[pd.DataFrame], Optional[str], Optional[str], Optional[str]]:
|
| 741 |
if not sql_text or not sql_text.strip():
|
| 742 |
return None, "Enter a SQL statement.", None, None
|
|
|
|
| 743 |
sql_raw = sql_text.strip().rstrip(";")
|
| 744 |
sql_rew, created_tbl = rewrite_select_into(sql_raw)
|
| 745 |
+
note = "Rewrote `SELECT ... INTO` to `CREATE TABLE ... AS SELECT ...` for SQLite." if sql_rew != sql_raw else None
|
|
|
|
|
|
|
|
|
|
| 746 |
unsup = detect_unsupported_joins(sql_rew)
|
| 747 |
+
if unsup: return None, unsup, None, note
|
|
|
|
|
|
|
| 748 |
try:
|
| 749 |
low = sql_rew.lower()
|
| 750 |
if low.startswith("select"):
|
|
|
|
| 754 |
else:
|
| 755 |
with DB_LOCK:
|
| 756 |
cur = CONN.cursor()
|
| 757 |
+
cur.execute(sql_rew); CONN.commit()
|
|
|
|
|
|
|
| 758 |
if low.startswith("create view"):
|
| 759 |
m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+(select.*)$", low)
|
| 760 |
name = m.group(2) if m else None
|
| 761 |
if name:
|
| 762 |
+
try: return pd.read_sql_query(f"SELECT * FROM {name}", CONN), None, None, note
|
| 763 |
+
except Exception: return None, "View created but could not be queried.", None, note
|
|
|
|
|
|
|
|
|
|
| 764 |
if low.startswith("create table"):
|
| 765 |
tbl = created_tbl
|
| 766 |
if not tbl:
|
| 767 |
m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
|
| 768 |
tbl = m.group(2) if m else None
|
| 769 |
if tbl:
|
| 770 |
+
try: return pd.read_sql_query(f"SELECT * FROM {tbl}", CONN), None, None, note
|
| 771 |
+
except Exception: return None, "Table created but could not be queried.", None, note
|
|
|
|
|
|
|
|
|
|
| 772 |
return pd.DataFrame(), None, None, note
|
| 773 |
except Exception as e:
|
| 774 |
msg = str(e)
|
| 775 |
+
if "no such table" in msg.lower(): return None, f"{msg}. Check table names for this randomized domain.", None, note
|
| 776 |
+
if "no such column" in msg.lower(): return None, f"{msg}. Use correct column names or prefixes (alias.column).", None, note
|
| 777 |
+
if "ambiguous column name" in msg.lower(): return None, f"{msg}. Qualify the column with a table alias.", None, note
|
|
|
|
|
|
|
|
|
|
| 778 |
if "misuse of aggregate" in msg.lower() or "aggregate functions are not allowed in" in msg.lower():
|
| 779 |
return None, f"{msg}. You might need a GROUP BY for non-aggregated columns.", None, note
|
| 780 |
if "near \"into\"" in msg.lower() and "syntax error" in msg.lower():
|
| 781 |
return None, "SQLite doesn’t support `SELECT ... INTO`. I can rewrite it automatically—try again.", None, note
|
| 782 |
if "syntax error" in msg.lower():
|
| 783 |
+
return None, f"Syntax error. Check commas, keywords, parentheses. Raw error: {msg}", None, note
|
| 784 |
return None, f"SQL error: {msg}", None, note
|
| 785 |
|
| 786 |
def answer_df(answer_sql: List[str]) -> Optional[pd.DataFrame]:
    """Return row data for the first canonical solution that executes cleanly.

    Each candidate statement is tried in order. Plain SELECTs run directly;
    CREATE VIEW / CREATE TABLE ... AS SELECT statements are (re)created and
    then queried, so the caller always receives a DataFrame of rows.
    Returns None when no candidate could be executed.
    """
    for candidate in answer_sql:
        try:
            lowered = candidate.strip().lower()
            if lowered.startswith("select"):
                return run_df(CONN, candidate)
            if lowered.startswith("create view"):
                match = re.match(
                    r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$",
                    lowered,
                )
                view_name = match.group(2) if match else "vw_tmp"
                with DB_LOCK:
                    cursor = CONN.cursor()
                    # Drop any stale copy so re-running a solution never collides.
                    cursor.execute(f"DROP VIEW IF EXISTS {view_name}")
                    cursor.execute(candidate)
                    CONN.commit()
                return run_df(CONN, f"SELECT * FROM {view_name}")
            if lowered.startswith("create table"):
                match = re.match(
                    r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$",
                    lowered,
                )
                table_name = match.group(2) if match else None
                with DB_LOCK:
                    cursor = CONN.cursor()
                    if table_name:
                        cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
                    cursor.execute(candidate)
                    CONN.commit()
                if table_name:
                    return run_df(CONN, f"SELECT * FROM {table_name}")
        except Exception:
            # Best-effort: a failing candidate just means we try the next one.
            continue
    return None
|
|
|
|
| 814 |
return (df_student is not None), f"**Explanation:** Your statement executed successfully for this task."
|
| 815 |
if df_student is None:
|
| 816 |
return False, f"**Explanation:** Expected data result differs."
|
| 817 |
+
if should_enforce_columns(q):
|
| 818 |
+
ok, note = results_equal_or_superset(df_student, df_expected)
|
| 819 |
+
if ok and note == "extra_columns":
|
| 820 |
+
return True, "**Note:** You returned extra columns. The rows match; try selecting only the requested columns next time."
|
| 821 |
+
if ok:
|
| 822 |
+
return True, "**Explanation:** Your result matches a canonical solution."
|
| 823 |
+
return False, f"**Explanation:** Compare your result to a canonical solution."
|
| 824 |
+
else:
|
| 825 |
+
ok = results_equal_rowcount_only(df_student, df_expected)
|
| 826 |
+
if ok:
|
| 827 |
+
return True, "**Explanation:** Columns weren’t specified for this task; row count matches the canonical answer."
|
| 828 |
+
return False, "**Explanation:** For this task we compared row counts (projection not enforced) and they didn’t match."
|
| 829 |
|
| 830 |
def log_attempt(user_id: str, qid: str, category: str, correct: bool, sql_text: str,
|
| 831 |
time_taken: float, difficulty: int, source: str, notes: str):
|
|
|
|
| 846 |
gr.update(value="Please enter your name to begin.", visible=True),
|
| 847 |
gr.update(visible=False),
|
| 848 |
gr.update(visible=False),
|
| 849 |
+
draw_dynamic_erd(CURRENT_SCHEMA),
|
| 850 |
gr.update(visible=False),
|
| 851 |
pd.DataFrame(),
|
| 852 |
pd.DataFrame())
|
|
|
|
| 859 |
|
| 860 |
prompt = q["prompt_md"]
|
| 861 |
stats = topic_stats(fetch_attempts(CONN, user_id))
|
| 862 |
+
erd = draw_dynamic_erd(CURRENT_SCHEMA)
|
| 863 |
return (session,
|
| 864 |
gr.update(value=f"**Question {q['id']}**\n\n{prompt}", visible=True),
|
| 865 |
+
gr.update(visible=True),
|
| 866 |
+
gr.update(value="", visible=True),
|
| 867 |
+
erd,
|
| 868 |
+
gr.update(visible=False),
|
| 869 |
stats,
|
| 870 |
pd.DataFrame())
|
| 871 |
|
| 872 |
def render_preview(sql_text: str, session: dict):
    """Live-update the SQL preview pane and ERD highlighting as the user types.

    Returns a (markdown update, ERD image) pair. With no active question, or
    nothing typed yet, the preview is hidden and the plain (unhighlighted)
    ERD is shown.
    """
    trimmed = (sql_text or "").strip()
    no_active_question = not session or "q" not in session
    if no_active_question or not trimmed:
        return gr.update(value="", visible=False), draw_dynamic_erd(CURRENT_SCHEMA)
    tables, edges = sql_highlights(trimmed, CURRENT_SCHEMA)
    diagram = draw_dynamic_erd(CURRENT_SCHEMA, highlight_tables=tables, highlight_edges=edges)
    preview = gr.update(value=f"**Preview:**\n\n```sql\n{trimmed}\n```", visible=True)
    return preview, diagram
|
| 881 |
|
| 882 |
def submit_answer(sql_text: str, session: dict):
|
| 883 |
if not session or "user_id" not in session or "q" not in session:
|
|
|
|
| 885 |
user_id = session["user_id"]
|
| 886 |
q = session["q"]
|
| 887 |
elapsed = max(0.0, time.time() - session.get("start_ts", time.time()))
|
|
|
|
| 888 |
df, err, warn, note = exec_student_sql(sql_text)
|
| 889 |
details = []
|
| 890 |
if note: details.append(f"ℹ️ {note}")
|
|
|
|
| 894 |
log_attempt(user_id, q.get("id","?"), q.get("category","?"), False, sql_text, elapsed, int(q.get("difficulty",1)), "bank", " | ".join([err] + details))
|
| 895 |
stats = topic_stats(fetch_attempts(CONN, user_id))
|
| 896 |
return gr.update(value=fb, visible=True), pd.DataFrame(), gr.update(visible=True), stats
|
|
|
|
| 897 |
alias_msg = None
|
| 898 |
+
if q.get("requires_aliases") and not aliases_present(sql_text, q.get("required_aliases", [])):
|
| 899 |
+
alias_msg = f"⚠️ This task asked for aliases {q.get('required_aliases', [])}. I didn’t detect them."
|
|
|
|
|
|
|
| 900 |
is_correct, explanation = validate_answer(q, sql_text, df)
|
| 901 |
if warn: details.append(f"⚠️ {warn}")
|
| 902 |
if alias_msg: details.append(alias_msg)
|
|
|
|
| 903 |
prefix = "✅ **Correct!**" if is_correct else "❌ **Not quite.**"
|
| 904 |
feedback = prefix
|
| 905 |
+
if details: feedback += "\n\n" + "\n".join(details)
|
|
|
|
| 906 |
feedback += "\n\n" + explanation + "\n\n**One acceptable solution:**\n```sql\n" + q["answer_sql"][0].rstrip(";") + ";\n```"
|
|
|
|
| 907 |
log_attempt(user_id, q["id"], q.get("category","?"), bool(is_correct), sql_text, elapsed, int(q.get("difficulty",1)), "bank", " | ".join(details))
|
| 908 |
stats = topic_stats(fetch_attempts(CONN, user_id))
|
| 909 |
return gr.update(value=feedback, visible=True), (df if df is not None else pd.DataFrame()), gr.update(visible=True), stats
|
| 910 |
|
| 911 |
def next_question(session: dict):
    """Advance the session to a freshly picked question.

    Returns updates for (session state, question prompt, SQL input box,
    ERD image, the now-hidden "next" button). Without an active session,
    shows a reminder to start one instead.
    """
    if not session or "user_id" not in session:
        return (session,
                gr.update(value="Start a session first.", visible=True),
                gr.update(visible=False),
                draw_dynamic_erd(CURRENT_SCHEMA),
                gr.update(visible=False))
    question = pick_next_question(session["user_id"])
    session["qid"] = question["id"]
    session["q"] = question
    # Restart the per-question timer for elapsed-time logging.
    session["start_ts"] = time.time()
    prompt = gr.update(value=f"**Question {question['id']}**\n\n{question['prompt_md']}", visible=True)
    return (session,
            prompt,
            gr.update(value="", visible=True),
            draw_dynamic_erd(CURRENT_SCHEMA),
            gr.update(visible=False))
|
|
|
|
|
|
|
| 918 |
|
| 919 |
def show_hint(session: dict):
|
| 920 |
if not session or "q" not in session:
|
|
|
|
| 935 |
|
| 936 |
def export_progress(user_name: str):
|
| 937 |
slug = "-".join((user_name or "").lower().split())
|
| 938 |
+
if not slug: return None
|
|
|
|
| 939 |
user_id = slug[:64]
|
| 940 |
with DB_LOCK:
|
| 941 |
df = pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", CONN, params=(user_id,))
|
|
|
|
| 946 |
|
| 947 |
def _domain_status_md():
    """Build the markdown status line describing where the current domain came from.

    A successful OpenAI generation (possibly with fallback questions) yields a
    green summary with question counts and table names; otherwise a warning
    with a truncated error reason is returned.
    """
    source = CURRENT_INFO.get("source", "")
    if source in ("openai", "openai+fallback-questions"):
        if source == "openai+fallback-questions":
            note = " (LLM domain ok; used fallback questions)"
        else:
            note = ""
        accepted = CURRENT_INFO.get("accepted", 0)
        dropped = CURRENT_INFO.get("dropped", 0)
        table_names = ", ".join(t["name"] for t in CURRENT_SCHEMA.get("tables", []))
        return (f"✅ **Domain via OpenAI** `{CURRENT_INFO.get('model','?')}` → **{CURRENT_SCHEMA.get('domain','?')}**{note}. "
                f"Accepted questions: {accepted}, dropped: {dropped}. \n"
                f"Tables: {table_names}.")
    err = CURRENT_INFO.get("error", "")
    # Keep the reason short enough for a one-line banner.
    err_short = err if len(err) <= 160 else err[:160] + "…"
    return f"⚠️ **OpenAI randomization unavailable** → using fallback **{CURRENT_SCHEMA.get('domain','?')}**.\n\n> Reason: {err_short}"
|
| 956 |
|
| 957 |
def regenerate_domain():
    """Swap in a brand-new randomized domain, avoiding a repeat of the current one.

    Rebuilds the module-level schema/questions/info globals and returns the
    refreshed status markdown plus the new ERD image.
    """
    global CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO
    previous = CURRENT_SCHEMA.get("domain") if CURRENT_SCHEMA else None
    CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO = install_schema_and_prepare_questions(
        prev_domain=previous
    )
    status = gr.update(value=_domain_status_md(), visible=True)
    return status, draw_dynamic_erd(CURRENT_SCHEMA)
|
| 963 |
|
| 964 |
def preview_table(tbl: str):
|
| 965 |
try:
|
|
|
|
| 969 |
|
| 970 |
def list_tables_for_preview():
    """List user-facing tables and views for the preview dropdown.

    Internal bookkeeping tables (users, attempts, session_meta) are excluded;
    a one-element placeholder list is returned when the schema is empty.
    """
    catalog = run_df(
        CONN,
        "SELECT name, type FROM sqlite_master WHERE type in ('table','view') AND name NOT IN ('users','attempts','session_meta') ORDER BY type, name",
    )
    return ["(no tables)"] if catalog.empty else catalog["name"].tolist()
|
| 974 |
|
| 975 |
# -------------------- UI --------------------
|
|
|
|
| 980 |
- Uses **OpenAI** (if configured) to randomize a domain (bookstore, retail sales, wholesaler,
|
| 981 |
sales tax, oil & gas wells, marketing), generate **3–4 tables** and **8–12** questions.
|
| 982 |
- Practice `SELECT`, `WHERE`, `JOIN` (INNER/LEFT), **aliases**, **views**, and **CTAS / SELECT INTO**.
|
| 983 |
+
- **ERD highlights your JOINs** as you type; all FK edges remain visible in light gray.
|
|
|
|
|
|
|
| 984 |
"""
|
| 985 |
)
|
| 986 |
|
| 987 |
with gr.Row():
|
|
|
|
| 988 |
with gr.Column(scale=1):
|
| 989 |
name_box = gr.Textbox(label="Your Name", placeholder="e.g., Jordan Alvarez")
|
| 990 |
start_btn = gr.Button("Start / Resume Session", variant="primary")
|
|
|
|
| 1007 |
tbl_btn = gr.Button("Preview")
|
| 1008 |
preview_df = gr.Dataframe(value=pd.DataFrame(), interactive=False)
|
| 1009 |
|
|
|
|
| 1010 |
with gr.Column(scale=2):
|
| 1011 |
prompt_md = gr.Markdown(visible=False)
|
| 1012 |
sql_input = gr.Textbox(label="Your SQL", placeholder="Type SQL here (end ; optional).", lines=6, visible=False)
|
|
|
|
| 1013 |
preview_md = gr.Markdown(visible=False)
|
| 1014 |
+
er_image = gr.Image(label="Entity Diagram", value=draw_dynamic_erd(CURRENT_SCHEMA), height=PLOT_HEIGHT)
|
| 1015 |
|
| 1016 |
with gr.Row():
|
| 1017 |
submit_btn = gr.Button("Run & Submit", variant="primary")
|
|
|
|
| 1022 |
|
| 1023 |
gr.Markdown("---")
|
| 1024 |
gr.Markdown("### Your Progress by Category")
|
| 1025 |
+
mastery_df = gr.Dataframe(headers=["category","attempts","correct","accuracy"],
|
| 1026 |
+
col_count=(4,"dynamic"), row_count=(0,"dynamic"), interactive=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1027 |
|
| 1028 |
gr.Markdown("---")
|
| 1029 |
gr.Markdown("### Result Preview")
|
|
|
|
| 1033 |
start_btn.click(
|
| 1034 |
start_session,
|
| 1035 |
inputs=[name_box, session_state],
|
| 1036 |
+
outputs=[session_state, prompt_md, sql_input, preview_md, er_image, next_btn, mastery_df, result_df],
|
| 1037 |
)
|
| 1038 |
sql_input.change(
|
| 1039 |
render_preview,
|
| 1040 |
inputs=[sql_input, session_state],
|
| 1041 |
+
outputs=[preview_md, er_image],
|
| 1042 |
)
|
| 1043 |
submit_btn.click(
|
| 1044 |
submit_answer,
|
|
|
|
| 1048 |
next_btn.click(
|
| 1049 |
next_question,
|
| 1050 |
inputs=[session_state],
|
| 1051 |
+
outputs=[session_state, prompt_md, sql_input, er_image, next_btn],
|
| 1052 |
)
|
| 1053 |
hint_btn.click(
|
| 1054 |
show_hint,
|
|
|
|
| 1063 |
regen_btn.click(
|
| 1064 |
regenerate_domain,
|
| 1065 |
inputs=[],
|
| 1066 |
+
outputs=[regen_fb, er_image],
|
| 1067 |
)
|
| 1068 |
tbl_btn.click(
|
| 1069 |
lambda name: preview_table(name),
|
| 1070 |
inputs=[tbl_dd],
|
| 1071 |
outputs=[preview_df]
|
| 1072 |
)
|
| 1073 |
+
regen_btn.click( # refresh list after regeneration
|
|
|
|
| 1074 |
lambda: gr.update(choices=list_tables_for_preview()),
|
| 1075 |
inputs=[],
|
| 1076 |
outputs=[tbl_dd]
|