Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -217,7 +217,7 @@ FALLBACK_QUESTIONS = [
|
|
| 217 |
"requires_aliases":False,"required_aliases":[]},
|
| 218 |
]
|
| 219 |
|
| 220 |
-
# -------------------- OpenAI JSON
|
| 221 |
DOMAIN_AND_QUESTIONS_SCHEMA = {
|
| 222 |
"required": ["domain", "tables", "questions"]
|
| 223 |
}
|
|
@@ -252,7 +252,6 @@ def _loose_json_parse(s: str) -> Optional[dict]:
|
|
| 252 |
return json.loads(s)
|
| 253 |
except Exception:
|
| 254 |
pass
|
| 255 |
-
# Try to find the first {...} block
|
| 256 |
start = s.find("{")
|
| 257 |
end = s.rfind("}")
|
| 258 |
if start != -1 and end != -1 and end > start:
|
|
@@ -262,20 +261,151 @@ def _loose_json_parse(s: str) -> Optional[dict]:
|
|
| 262 |
return None
|
| 263 |
return None
|
| 264 |
|
| 265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
"""
|
| 267 |
-
Returns (obj, error_message, model_used).
|
| 268 |
-
|
| 269 |
"""
|
| 270 |
if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
|
| 271 |
-
return None, "OpenAI client not available or OPENAI_API_KEY missing.", None
|
| 272 |
|
| 273 |
errors = []
|
| 274 |
prompt = _domain_prompt(prev_domain)
|
| 275 |
|
| 276 |
for model in _candidate_models():
|
| 277 |
-
# Try JSON mode first (if supported)
|
| 278 |
try:
|
|
|
|
| 279 |
try:
|
| 280 |
chat = _client.chat.completions.create(
|
| 281 |
model=model,
|
|
@@ -285,7 +415,7 @@ def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optio
|
|
| 285 |
)
|
| 286 |
data_text = chat.choices[0].message.content
|
| 287 |
except TypeError:
|
| 288 |
-
# Older SDKs: no response_format
|
| 289 |
chat = _client.chat.completions.create(
|
| 290 |
model=model,
|
| 291 |
messages=[{"role":"system","content":"Return ONLY a JSON object. No markdown."},
|
|
@@ -294,31 +424,49 @@ def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optio
|
|
| 294 |
)
|
| 295 |
data_text = chat.choices[0].message.content
|
| 296 |
|
| 297 |
-
|
| 298 |
-
if not
|
| 299 |
raise RuntimeError("Could not parse JSON from model output.")
|
| 300 |
|
| 301 |
-
# Minimal validation
|
| 302 |
for k in DOMAIN_AND_QUESTIONS_SCHEMA["required"]:
|
| 303 |
-
if k not in
|
| 304 |
raise RuntimeError(f"Missing key '{k}'")
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
clean_qs = []
|
| 307 |
-
for q in
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
if not answers:
|
|
|
|
| 310 |
continue
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
except Exception as e:
|
| 318 |
errors.append(f"{model}: {e}")
|
| 319 |
continue
|
| 320 |
|
| 321 |
-
return None, "; ".join(errors) if errors else "Unknown LLM error.", None
|
| 322 |
|
| 323 |
# -------------------- Schema install & question handling --------------------
|
| 324 |
def drop_existing_domain_tables(con: sqlite3.Connection, keep_internal=True):
|
|
@@ -465,18 +613,22 @@ def load_questions(obj_list: List[Dict[str,Any]]) -> List[Dict[str,Any]]:
|
|
| 465 |
|
| 466 |
# -------------------- Domain bootstrap --------------------
|
| 467 |
def bootstrap_domain_with_llm_or_fallback(prev_domain: Optional[str]):
|
| 468 |
-
obj, err, model_used = llm_generate_domain_and_questions(prev_domain)
|
| 469 |
if obj is None:
|
| 470 |
-
return FALLBACK_SCHEMA, FALLBACK_QUESTIONS, {"source":"fallback","model":None,"error":err}
|
| 471 |
-
return obj, obj["questions"], {"source":"openai","model":model_used,"error":None}
|
| 472 |
|
| 473 |
-
def
|
| 474 |
schema, questions, info = bootstrap_domain_with_llm_or_fallback(prev_domain)
|
| 475 |
install_schema(CONN, schema)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
return schema, questions, info
|
| 477 |
|
| 478 |
# -------------------- Session state --------------------
|
| 479 |
-
CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO =
|
| 480 |
|
| 481 |
# -------------------- Progress + mastery --------------------
|
| 482 |
def upsert_user(con: sqlite3.Connection, user_id: str, name: str):
|
|
@@ -510,11 +662,13 @@ def fetch_attempts(con: sqlite3.Connection, user_id: str) -> pd.DataFrame:
|
|
| 510 |
return pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", con, params=(user_id,))
|
| 511 |
|
| 512 |
def pick_next_question(user_id: str) -> Dict[str,Any]:
|
|
|
|
|
|
|
| 513 |
df = fetch_attempts(CONN, user_id)
|
| 514 |
stats = topic_stats(df)
|
| 515 |
stats = stats.sort_values(by=["accuracy","attempts"], ascending=[True, True])
|
| 516 |
weakest = stats.iloc[0]["category"] if not stats.empty else CATEGORIES_ORDER[0]
|
| 517 |
-
cands = [q for q in
|
| 518 |
return dict(random.choice(cands))
|
| 519 |
|
| 520 |
# -------------------- Execution & feedback --------------------
|
|
@@ -679,7 +833,7 @@ def submit_answer(sql_text: str, session: dict):
|
|
| 679 |
if err:
|
| 680 |
fb = f"❌ **Did not run**\n\n{err}"
|
| 681 |
if details: fb += "\n\n" + "\n".join(details)
|
| 682 |
-
log_attempt(user_id, q
|
| 683 |
stats = topic_stats(fetch_attempts(CONN, user_id))
|
| 684 |
return gr.update(value=fb, visible=True), pd.DataFrame(), gr.update(visible=True), stats
|
| 685 |
|
|
@@ -698,7 +852,7 @@ def submit_answer(sql_text: str, session: dict):
|
|
| 698 |
feedback += "\n\n" + "\n".join(details)
|
| 699 |
feedback += "\n\n" + explanation + "\n\n**One acceptable solution:**\n```sql\n" + q["answer_sql"][0].rstrip(";") + ";\n```"
|
| 700 |
|
| 701 |
-
log_attempt(user_id, q["id"], q
|
| 702 |
stats = topic_stats(fetch_attempts(CONN, user_id))
|
| 703 |
return gr.update(value=feedback, visible=True), (df if df is not None else pd.DataFrame()), gr.update(visible=True), stats
|
| 704 |
|
|
@@ -715,7 +869,7 @@ def next_question(session: dict):
|
|
| 715 |
def show_hint(session: dict):
|
| 716 |
if not session or "q" not in session:
|
| 717 |
return gr.update(value="Start a session first.", visible=True)
|
| 718 |
-
cat = session["q"]
|
| 719 |
hint = {
|
| 720 |
"SELECT *": "Use `SELECT * FROM table_name`.",
|
| 721 |
"SELECT columns": "List columns: `SELECT col1, col2 FROM table_name`.",
|
|
@@ -734,16 +888,25 @@ def export_progress(user_name: str):
|
|
| 734 |
if not slug:
|
| 735 |
return None
|
| 736 |
user_id = slug[:64]
|
| 737 |
-
|
|
|
|
| 738 |
os.makedirs(EXPORT_DIR, exist_ok=True)
|
| 739 |
path = os.path.abspath(os.path.join(EXPORT_DIR, f"{user_id}_progress.csv"))
|
| 740 |
(pd.DataFrame([{"info":"No attempts yet."}]) if df.empty else df).to_csv(path, index=False)
|
| 741 |
return path
|
| 742 |
|
| 743 |
def _domain_status_md():
|
| 744 |
-
if CURRENT_INFO.get("source")
|
| 745 |
-
|
| 746 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 747 |
err = CURRENT_INFO.get("error","")
|
| 748 |
err_short = (err[:160] + "…") if len(err) > 160 else err
|
| 749 |
return f"⚠️ **OpenAI randomization unavailable** → using fallback **{CURRENT_SCHEMA.get('domain','?')}**.\n\n> Reason: {err_short}"
|
|
@@ -751,7 +914,7 @@ def _domain_status_md():
|
|
| 751 |
def regenerate_domain():
|
| 752 |
global CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO
|
| 753 |
prev = CURRENT_SCHEMA.get("domain") if CURRENT_SCHEMA else None
|
| 754 |
-
CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO =
|
| 755 |
return gr.update(value=_domain_status_md(), visible=True)
|
| 756 |
|
| 757 |
def preview_table(tbl: str):
|
|
|
|
| 217 |
"requires_aliases":False,"required_aliases":[]},
|
| 218 |
]
|
| 219 |
|
| 220 |
+
# -------------------- OpenAI JSON request helpers --------------------
|
| 221 |
# Top-level keys that must be present in the parsed LLM response; the
# generator loops over "required" and rejects a response missing any of them.
DOMAIN_AND_QUESTIONS_SCHEMA = dict(required=["domain", "tables", "questions"])
|
|
|
|
| 252 |
return json.loads(s)
|
| 253 |
except Exception:
|
| 254 |
pass
|
|
|
|
| 255 |
start = s.find("{")
|
| 256 |
end = s.rfind("}")
|
| 257 |
if start != -1 and end != -1 and end > start:
|
|
|
|
| 261 |
return None
|
| 262 |
return None
|
| 263 |
|
| 264 |
+
# -------------------- Canonicalization & validation --------------------
|
| 265 |
+
_SQL_FENCE = re.compile(r"```sql(.*?)```", re.IGNORECASE | re.DOTALL)
|
| 266 |
+
_CODE_FENCE = re.compile(r"```(.*?)```", re.DOTALL)
|
| 267 |
+
|
| 268 |
+
def _strip_code_fences(txt: str) -> str:
|
| 269 |
+
if txt is None:
|
| 270 |
+
return ""
|
| 271 |
+
m = _SQL_FENCE.findall(txt)
|
| 272 |
+
if m:
|
| 273 |
+
return "\n".join([x.strip() for x in m if x.strip()])
|
| 274 |
+
m2 = _CODE_FENCE.findall(txt)
|
| 275 |
+
if m2:
|
| 276 |
+
return "\n".join([x.strip() for x in m2 if x.strip()])
|
| 277 |
+
return txt.strip()
|
| 278 |
+
|
| 279 |
+
def _as_list_of_sql(val) -> List[str]:
|
| 280 |
+
if val is None:
|
| 281 |
+
return []
|
| 282 |
+
if isinstance(val, str):
|
| 283 |
+
s = _strip_code_fences(val)
|
| 284 |
+
parts = [p.strip() for p in s.split("\n") if p.strip()]
|
| 285 |
+
# if it’s a single long line, keep as is
|
| 286 |
+
return parts or ([s] if s else [])
|
| 287 |
+
if isinstance(val, list):
|
| 288 |
+
out = []
|
| 289 |
+
for v in val:
|
| 290 |
+
if isinstance(v, str):
|
| 291 |
+
s = _strip_code_fences(v)
|
| 292 |
+
if s:
|
| 293 |
+
out.append(s)
|
| 294 |
+
return out
|
| 295 |
+
return []
|
| 296 |
+
|
| 297 |
+
def _canon_question(q: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
| 298 |
+
"""Normalize one question; return None if missing critical fields."""
|
| 299 |
+
if not isinstance(q, dict):
|
| 300 |
+
return None
|
| 301 |
+
# field mapping/synonyms
|
| 302 |
+
cat = q.get("category") or q.get("type") or q.get("topic")
|
| 303 |
+
prompt = q.get("prompt_md") or q.get("prompt") or q.get("question") or q.get("text")
|
| 304 |
+
answer_sql = q.get("answer_sql") or q.get("answers") or q.get("solutions") or q.get("sql")
|
| 305 |
+
diff = q.get("difficulty") or 1
|
| 306 |
+
req_alias = bool(q.get("requires_aliases", False))
|
| 307 |
+
req_aliases = q.get("required_aliases") or []
|
| 308 |
+
|
| 309 |
+
cat = str(cat).strip() if cat is not None else ""
|
| 310 |
+
prompt = str(prompt).strip() if prompt is not None else ""
|
| 311 |
+
answers = _as_list_of_sql(answer_sql)
|
| 312 |
+
|
| 313 |
+
if not cat or not prompt or not answers:
|
| 314 |
+
return None
|
| 315 |
+
|
| 316 |
+
# keep only known categories if provided; otherwise accept free text
|
| 317 |
+
known = {
|
| 318 |
+
"SELECT *","SELECT columns","WHERE","Aliases",
|
| 319 |
+
"JOIN (INNER)","JOIN (LEFT)","Aggregation","VIEW","CTAS / SELECT INTO"
|
| 320 |
+
}
|
| 321 |
+
if cat not in known:
|
| 322 |
+
# Try to map rough names to our set
|
| 323 |
+
low = cat.lower()
|
| 324 |
+
if "select *" in low:
|
| 325 |
+
cat = "SELECT *"
|
| 326 |
+
elif "select col" in low or "columns" in low:
|
| 327 |
+
cat = "SELECT columns"
|
| 328 |
+
elif "where" in low or "filter" in low:
|
| 329 |
+
cat = "WHERE"
|
| 330 |
+
elif "alias" in low:
|
| 331 |
+
cat = "Aliases"
|
| 332 |
+
elif "left" in low:
|
| 333 |
+
cat = "JOIN (LEFT)"
|
| 334 |
+
elif "inner" in low or "join" in low:
|
| 335 |
+
cat = "JOIN (INNER)"
|
| 336 |
+
elif "agg" in low or "group" in low:
|
| 337 |
+
cat = "Aggregation"
|
| 338 |
+
elif "view" in low:
|
| 339 |
+
cat = "VIEW"
|
| 340 |
+
elif "into" in low or "ctas" in low or "create table" in low:
|
| 341 |
+
cat = "CTAS / SELECT INTO"
|
| 342 |
+
else:
|
| 343 |
+
# leave as-is; still usable for practice buckets
|
| 344 |
+
pass
|
| 345 |
+
|
| 346 |
+
# normalize aliases list
|
| 347 |
+
if isinstance(req_aliases, str):
|
| 348 |
+
req_aliases = [a.strip() for a in re.split(r"[,\s]+", req_aliases) if a.strip()]
|
| 349 |
+
elif not isinstance(req_aliases, list):
|
| 350 |
+
req_aliases = []
|
| 351 |
+
|
| 352 |
+
return {
|
| 353 |
+
"id": str(q.get("id") or f"LLM_{int(time.time()*1000)}_{random.randint(100,999)}"),
|
| 354 |
+
"category": cat,
|
| 355 |
+
"difficulty": int(diff) if str(diff).isdigit() else 1,
|
| 356 |
+
"prompt_md": prompt,
|
| 357 |
+
"answer_sql": answers,
|
| 358 |
+
"requires_aliases": bool(req_alias),
|
| 359 |
+
"required_aliases": req_aliases,
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
def _canon_tables(tables: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 363 |
+
out = []
|
| 364 |
+
for t in (tables or []):
|
| 365 |
+
if not isinstance(t, dict):
|
| 366 |
+
continue
|
| 367 |
+
name = str(t.get("name","")).strip()
|
| 368 |
+
if not name:
|
| 369 |
+
continue
|
| 370 |
+
cols = t.get("columns") or []
|
| 371 |
+
good_cols = []
|
| 372 |
+
for c in cols:
|
| 373 |
+
if not isinstance(c, dict):
|
| 374 |
+
continue
|
| 375 |
+
cname = str(c.get("name","")).strip()
|
| 376 |
+
ctype = str(c.get("type","TEXT")).strip() or "TEXT"
|
| 377 |
+
if cname:
|
| 378 |
+
good_cols.append({"name": cname, "type": ctype})
|
| 379 |
+
if not good_cols:
|
| 380 |
+
continue
|
| 381 |
+
pk = t.get("pk") or []
|
| 382 |
+
if isinstance(pk, str): pk = [pk]
|
| 383 |
+
fks = t.get("fks") or []
|
| 384 |
+
rows = t.get("rows") or []
|
| 385 |
+
out.append({
|
| 386 |
+
"name": name,
|
| 387 |
+
"pk": [str(x) for x in pk],
|
| 388 |
+
"columns": good_cols,
|
| 389 |
+
"fks": fks if isinstance(fks, list) else [],
|
| 390 |
+
"rows": rows if isinstance(rows, list) else [],
|
| 391 |
+
})
|
| 392 |
+
return out
|
| 393 |
+
|
| 394 |
+
# -------------------- LLM call --------------------
|
| 395 |
+
def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optional[Dict[str,Any]], Optional[str], Optional[str], Dict[str,int]]:
|
| 396 |
"""
|
| 397 |
+
Returns (obj, error_message, model_used, stats_dict).
|
| 398 |
+
stats_dict contains {"accepted_questions": n, "dropped_questions": m}
|
| 399 |
"""
|
| 400 |
if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
|
| 401 |
+
return None, "OpenAI client not available or OPENAI_API_KEY missing.", None, {"accepted_questions":0,"dropped_questions":0}
|
| 402 |
|
| 403 |
errors = []
|
| 404 |
prompt = _domain_prompt(prev_domain)
|
| 405 |
|
| 406 |
for model in _candidate_models():
|
|
|
|
| 407 |
try:
|
| 408 |
+
# Try JSON mode first
|
| 409 |
try:
|
| 410 |
chat = _client.chat.completions.create(
|
| 411 |
model=model,
|
|
|
|
| 415 |
)
|
| 416 |
data_text = chat.choices[0].message.content
|
| 417 |
except TypeError:
|
| 418 |
+
# Older SDKs: no response_format ⇒ plain chat + strict instructions
|
| 419 |
chat = _client.chat.completions.create(
|
| 420 |
model=model,
|
| 421 |
messages=[{"role":"system","content":"Return ONLY a JSON object. No markdown."},
|
|
|
|
| 424 |
)
|
| 425 |
data_text = chat.choices[0].message.content
|
| 426 |
|
| 427 |
+
obj_raw = _loose_json_parse(data_text or "")
|
| 428 |
+
if not obj_raw:
|
| 429 |
raise RuntimeError("Could not parse JSON from model output.")
|
| 430 |
|
| 431 |
+
# Minimal top-level validation
|
| 432 |
for k in DOMAIN_AND_QUESTIONS_SCHEMA["required"]:
|
| 433 |
+
if k not in obj_raw:
|
| 434 |
raise RuntimeError(f"Missing key '{k}'")
|
| 435 |
+
|
| 436 |
+
# Canonicalize tables
|
| 437 |
+
tables = _canon_tables(obj_raw.get("tables", []))
|
| 438 |
+
if not tables:
|
| 439 |
+
raise RuntimeError("No usable tables in LLM output.")
|
| 440 |
+
obj_raw["tables"] = tables
|
| 441 |
+
|
| 442 |
+
# Canonicalize questions
|
| 443 |
+
dropped = 0
|
| 444 |
clean_qs = []
|
| 445 |
+
for q in obj_raw.get("questions", []):
|
| 446 |
+
cq = _canon_question(q)
|
| 447 |
+
if not cq:
|
| 448 |
+
dropped += 1
|
| 449 |
+
continue
|
| 450 |
+
# Strip RIGHT/FULL joins from answers
|
| 451 |
+
answers = [a for a in cq["answer_sql"] if " right join " not in a.lower() and " full " not in a.lower()]
|
| 452 |
if not answers:
|
| 453 |
+
dropped += 1
|
| 454 |
continue
|
| 455 |
+
cq["answer_sql"] = answers
|
| 456 |
+
clean_qs.append(cq)
|
| 457 |
+
|
| 458 |
+
if not clean_qs:
|
| 459 |
+
raise RuntimeError("No usable questions after canonicalization.")
|
| 460 |
+
stats = {"accepted_questions": len(clean_qs), "dropped_questions": dropped}
|
| 461 |
+
|
| 462 |
+
obj_raw["questions"] = clean_qs
|
| 463 |
+
return obj_raw, None, model, stats
|
| 464 |
+
|
| 465 |
except Exception as e:
|
| 466 |
errors.append(f"{model}: {e}")
|
| 467 |
continue
|
| 468 |
|
| 469 |
+
return None, "; ".join(errors) if errors else "Unknown LLM error.", None, {"accepted_questions":0,"dropped_questions":0}
|
| 470 |
|
| 471 |
# -------------------- Schema install & question handling --------------------
|
| 472 |
def drop_existing_domain_tables(con: sqlite3.Connection, keep_internal=True):
|
|
|
|
| 613 |
|
| 614 |
# -------------------- Domain bootstrap --------------------
|
| 615 |
def bootstrap_domain_with_llm_or_fallback(prev_domain: Optional[str]):
    """Obtain a (schema, questions, info) triple, preferring the LLM.

    Asks ``llm_generate_domain_and_questions`` for a new domain. If it yields
    an object, that object is the schema and its ``"questions"`` entry is the
    question list; otherwise the built-in fallback schema and question bank
    are returned together with the error text.

    The ``info`` dict always has the keys ``source``, ``model``, ``error``,
    ``accepted`` and ``dropped``.
    """
    generated, failure, model_name, counts = llm_generate_domain_and_questions(prev_domain)

    if generated is not None:
        provenance = {
            "source": "openai",
            "model": model_name,
            "error": None,
            "accepted": counts["accepted_questions"],
            "dropped": counts["dropped_questions"],
        }
        return generated, generated["questions"], provenance

    provenance = {
        "source": "fallback",
        "model": None,
        "error": failure,
        "accepted": 0,
        "dropped": 0,
    }
    return FALLBACK_SCHEMA, FALLBACK_QUESTIONS, provenance
|
| 620 |
|
| 621 |
+
def install_schema_and_prepare_questions(prev_domain: Optional[str]):
    """Bootstrap a domain, install its schema on ``CONN``, return the triple.

    Returns ``(schema, questions, info)``. When the bootstrap step produced no
    questions, the fallback question bank is substituted and ``info`` is
    replaced by a record whose source is ``"openai+fallback-questions"``.
    """
    schema, bank, meta = bootstrap_domain_with_llm_or_fallback(prev_domain)
    install_schema(CONN, schema)

    if bank:
        return schema, bank, meta

    # Safety net: never hand back an empty question list.
    # NOTE(review): the fallback bank is presumably written against
    # FALLBACK_SCHEMA, while the schema installed above may be a different
    # one - confirm those questions can run against it.
    degraded = {
        "source": "openai+fallback-questions",
        "model": meta.get("model"),
        "error": "LLM returned 0 usable questions; used fallback bank.",
        "accepted": 0,
        "dropped": 0,
    }
    return schema, FALLBACK_QUESTIONS, degraded
|
| 629 |
|
| 630 |
# -------------------- Session state --------------------
|
| 631 |
+
# Module-level session state, populated once at import time. The call below
# bootstraps a domain (no previous domain to avoid) and installs its schema on
# CONN; regenerate_domain() later rebinds these same three globals.
CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO = install_schema_and_prepare_questions(prev_domain=None)
|
| 632 |
|
| 633 |
# -------------------- Progress + mastery --------------------
|
| 634 |
def upsert_user(con: sqlite3.Connection, user_id: str, name: str):
|
|
|
|
| 662 |
return pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", con, params=(user_id,))
|
| 663 |
|
| 664 |
def pick_next_question(user_id: str) -> Dict[str,Any]:
    """Choose the next practice question for ``user_id``.

    The user's per-topic stats are sorted ascending by accuracy, then by
    attempts, and the first row's category is targeted; with no stats the
    first entry of ``CATEGORIES_ORDER`` is used. A random question from that
    category is returned as a fresh dict, or a random question from the whole
    pool when the category has none.
    """
    # Defensive: an empty current bank falls back to the built-in one.
    question_pool = CURRENT_QS or FALLBACK_QUESTIONS

    ranked = topic_stats(fetch_attempts(CONN, user_id)).sort_values(
        by=["accuracy","attempts"], ascending=[True, True]
    )
    if ranked.empty:
        target = CATEGORIES_ORDER[0]
    else:
        target = ranked.iloc[0]["category"]

    matching = []
    for item in question_pool:
        if str(item.get("category","")).strip() == target:
            matching.append(item)
    if not matching:
        matching = question_pool

    return dict(random.choice(matching))
|
| 673 |
|
| 674 |
# -------------------- Execution & feedback --------------------
|
|
|
|
| 833 |
if err:
|
| 834 |
fb = f"❌ **Did not run**\n\n{err}"
|
| 835 |
if details: fb += "\n\n" + "\n".join(details)
|
| 836 |
+
log_attempt(user_id, q.get("id","?"), q.get("category","?"), False, sql_text, elapsed, int(q.get("difficulty",1)), "bank", " | ".join([err] + details))
|
| 837 |
stats = topic_stats(fetch_attempts(CONN, user_id))
|
| 838 |
return gr.update(value=fb, visible=True), pd.DataFrame(), gr.update(visible=True), stats
|
| 839 |
|
|
|
|
| 852 |
feedback += "\n\n" + "\n".join(details)
|
| 853 |
feedback += "\n\n" + explanation + "\n\n**One acceptable solution:**\n```sql\n" + q["answer_sql"][0].rstrip(";") + ";\n```"
|
| 854 |
|
| 855 |
+
log_attempt(user_id, q["id"], q.get("category","?"), bool(is_correct), sql_text, elapsed, int(q.get("difficulty",1)), "bank", " | ".join(details))
|
| 856 |
stats = topic_stats(fetch_attempts(CONN, user_id))
|
| 857 |
return gr.update(value=feedback, visible=True), (df if df is not None else pd.DataFrame()), gr.update(visible=True), stats
|
| 858 |
|
|
|
|
| 869 |
def show_hint(session: dict):
|
| 870 |
if not session or "q" not in session:
|
| 871 |
return gr.update(value="Start a session first.", visible=True)
|
| 872 |
+
cat = session["q"].get("category","?")
|
| 873 |
hint = {
|
| 874 |
"SELECT *": "Use `SELECT * FROM table_name`.",
|
| 875 |
"SELECT columns": "List columns: `SELECT col1, col2 FROM table_name`.",
|
|
|
|
| 888 |
if not slug:
|
| 889 |
return None
|
| 890 |
user_id = slug[:64]
|
| 891 |
+
with DB_LOCK:
|
| 892 |
+
df = pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", CONN, params=(user_id,))
|
| 893 |
os.makedirs(EXPORT_DIR, exist_ok=True)
|
| 894 |
path = os.path.abspath(os.path.join(EXPORT_DIR, f"{user_id}_progress.csv"))
|
| 895 |
(pd.DataFrame([{"info":"No attempts yet."}]) if df.empty else df).to_csv(path, index=False)
|
| 896 |
return path
|
| 897 |
|
| 898 |
def _domain_status_md():
    """Build the Markdown status line describing the current domain.

    Reads the module globals ``CURRENT_INFO`` and ``CURRENT_SCHEMA``. For an
    OpenAI-sourced domain it reports the model, domain name, accepted/dropped
    question counts and table names; otherwise it reports the fallback domain
    and the error text, truncated to 160 characters.
    """
    origin = CURRENT_INFO.get("source","")

    if origin not in ("openai","openai+fallback-questions"):
        reason = CURRENT_INFO.get("error","")
        if len(reason) > 160:
            reason = reason[:160] + "…"
        fallback_domain = CURRENT_SCHEMA.get('domain','?')
        return f"⚠️ **OpenAI randomization unavailable** → using fallback **{fallback_domain}**.\n\n> Reason: {reason}"

    suffix = " (LLM domain ok; used fallback questions)" if origin == "openai+fallback-questions" else ""
    n_accepted = CURRENT_INFO.get("accepted",0)
    n_dropped = CURRENT_INFO.get("dropped",0)
    model_name = CURRENT_INFO.get('model','?')
    domain_name = CURRENT_SCHEMA.get('domain','?')
    table_names = ', '.join(t['name'] for t in CURRENT_SCHEMA.get('tables', []))

    headline = f"✅ **Domain via OpenAI** `{model_name}` → **{domain_name}**{suffix}. "
    counts = f"Accepted questions: {n_accepted}, dropped: {n_dropped}. \n"
    listing = f"Tables: {table_names}."
    return headline + counts + listing
|
|
|
|
| 914 |
def regenerate_domain():
    """Replace the current domain and return an update showing its status.

    Passes the current domain name (if any) as ``prev_domain``, rebinds the
    three session globals to the new schema, questions and info, and returns
    a visible ``gr.update`` carrying the refreshed status Markdown.
    """
    global CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO

    if CURRENT_SCHEMA:
        last_domain = CURRENT_SCHEMA.get("domain")
    else:
        last_domain = None

    CURRENT_SCHEMA, CURRENT_QS, CURRENT_INFO = install_schema_and_prepare_questions(
        prev_domain=last_domain
    )

    status_md = _domain_status_md()
    return gr.update(value=status_md, visible=True)
|
| 919 |
|
| 920 |
def preview_table(tbl: str):
|