Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,965 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adaptive SQL Trainer — Domain Randomized with OpenAI (Gradio + SQLite)
|
| 2 |
+
# - Randomizes a relational domain via OpenAI (bookstore, retail sales, wholesaler,
|
| 3 |
+
# sales tax, oil & gas wells, marketing) OR falls back to a built-in dataset.
|
| 4 |
+
# - Builds 3–4 related tables (schema + seed rows) in SQLite.
|
| 5 |
+
# - Generates 8–12 randomized SQL questions with varied phrasings.
|
| 6 |
+
# - Validates answers by executing canonical SQL and comparing result sets.
|
| 7 |
+
# - Provides tailored feedback (SQLite dialect, cartesian products, aggregates, aliases).
|
| 8 |
+
# - Shows data results at the bottom pane for every run (SELECT or preview for VIEW/CTAS).
|
| 9 |
+
#
|
| 10 |
+
# Hugging Face Spaces: set OPENAI_API_KEY as a secret to enable LLM randomization.
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import re
|
| 14 |
+
import json
|
| 15 |
+
import time
|
| 16 |
+
import random
|
| 17 |
+
import sqlite3
|
| 18 |
+
from dataclasses import dataclass, asdict
|
| 19 |
+
from datetime import datetime, timezone
|
| 20 |
+
from typing import List, Dict, Any, Tuple, Optional
|
| 21 |
+
|
| 22 |
+
import gradio as gr
|
| 23 |
+
import pandas as pd
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
# Matplotlib for ERD drawing (headless)
|
| 27 |
+
import matplotlib
|
| 28 |
+
matplotlib.use("Agg")
|
| 29 |
+
import matplotlib.pyplot as plt
|
| 30 |
+
from io import BytesIO
|
| 31 |
+
from PIL import Image
|
| 32 |
+
|
| 33 |
+
# -------------------- OpenAI (optional) --------------------
|
| 34 |
+
USE_RESPONSES_API = True
|
| 35 |
+
OPENAI_AVAILABLE = True
|
| 36 |
+
MODEL_ID = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
|
| 37 |
+
try:
|
| 38 |
+
from openai import OpenAI
|
| 39 |
+
_client = OpenAI() # requires OPENAI_API_KEY
|
| 40 |
+
except Exception:
|
| 41 |
+
OPENAI_AVAILABLE = False
|
| 42 |
+
_client = None
|
| 43 |
+
|
| 44 |
+
# -------------------- Global settings --------------------
|
| 45 |
+
DB_DIR = "/data" if os.path.exists("/data") else "."
|
| 46 |
+
DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
|
| 47 |
+
EXPORT_DIR = "."
|
| 48 |
+
ADMIN_KEY = os.getenv("ADMIN_KEY", "demo")
|
| 49 |
+
RANDOM_SEED = int(os.getenv("RANDOM_SEED", "7"))
|
| 50 |
+
random.seed(RANDOM_SEED)
|
| 51 |
+
SYS_RAND = random.SystemRandom()
|
| 52 |
+
|
| 53 |
+
PLOT_FIGSIZE = (6.8, 3.4)
|
| 54 |
+
PLOT_DPI = 110
|
| 55 |
+
PLOT_HEIGHT = 300
|
| 56 |
+
|
| 57 |
+
# -------------------- ERD helpers --------------------
|
| 58 |
+
def _to_pil(fig) -> Image.Image:
|
| 59 |
+
buf = BytesIO()
|
| 60 |
+
fig.tight_layout()
|
| 61 |
+
fig.savefig(buf, format="png", dpi=PLOT_DPI, bbox_inches="tight")
|
| 62 |
+
plt.close(fig)
|
| 63 |
+
buf.seek(0)
|
| 64 |
+
return Image.open(buf)
|
| 65 |
+
|
| 66 |
+
def draw_dynamic_erd(schema: Dict[str, Any]) -> Image.Image:
|
| 67 |
+
"""
|
| 68 |
+
schema = {
|
| 69 |
+
"domain": "bookstore",
|
| 70 |
+
"tables": [
|
| 71 |
+
{"name":"authors","columns":[{"name":"author_id","type":"INTEGER",...}, ...],
|
| 72 |
+
"pk":["author_id"], "fks":[{"columns":["author_id"],"ref_table":"...","ref_columns":["..."]}],
|
| 73 |
+
"rows":[{...}, {...}]}
|
| 74 |
+
]
|
| 75 |
+
}
|
| 76 |
+
"""
|
| 77 |
+
fig, ax = plt.subplots(figsize=PLOT_FIGSIZE)
|
| 78 |
+
ax.axis("off")
|
| 79 |
+
tables = schema.get("tables", [])
|
| 80 |
+
n = max(1, len(tables))
|
| 81 |
+
# Lay out boxes horizontally
|
| 82 |
+
margin = 0.03
|
| 83 |
+
width = (1 - margin*(n+1)) / n
|
| 84 |
+
height = 0.65
|
| 85 |
+
y = 0.25
|
| 86 |
+
boxes = {}
|
| 87 |
+
for i, t in enumerate(tables):
|
| 88 |
+
x = margin + i*(width + margin)
|
| 89 |
+
boxes[t["name"]] = (x, y, width, height)
|
| 90 |
+
ax.add_patch(plt.Rectangle((x, y), width, height, fill=False))
|
| 91 |
+
ax.text(x + 0.01, y + height - 0.05, f"**{t['name']}**", fontsize=10, ha="left", va="top")
|
| 92 |
+
yy = y + height - 0.10
|
| 93 |
+
pk = set(t.get("pk", []))
|
| 94 |
+
cols = t.get("columns", [])
|
| 95 |
+
for col in cols:
|
| 96 |
+
nm = col["name"]
|
| 97 |
+
mark = " (PK)" if nm in pk else ""
|
| 98 |
+
ax.text(x + 0.02, yy, f"{nm}{mark}", fontsize=9, ha="left", va="top")
|
| 99 |
+
yy -= 0.06
|
| 100 |
+
|
| 101 |
+
# Draw FK arrows
|
| 102 |
+
for t in tables:
|
| 103 |
+
for fk in t.get("fks", []):
|
| 104 |
+
src_tbl = t["name"]
|
| 105 |
+
dst_tbl = fk.get("ref_table")
|
| 106 |
+
if src_tbl in boxes and dst_tbl in boxes:
|
| 107 |
+
(x1, y1, w1, h1) = boxes[src_tbl]
|
| 108 |
+
(x2, y2, w2, h2) = boxes[dst_tbl]
|
| 109 |
+
ax.annotate("", xy=(x2 + w2/2, y2 + h2), xytext=(x1 + w1/2, y1),
|
| 110 |
+
arrowprops=dict(arrowstyle="->", lw=1.1))
|
| 111 |
+
ax.text(0.5, 0.06, f"Domain: {schema.get('domain','unknown')}", fontsize=9, ha="center")
|
| 112 |
+
return _to_pil(fig)
|
| 113 |
+
|
| 114 |
+
# -------------------- SQLite helpers --------------------
|
| 115 |
+
def connect_db():
|
| 116 |
+
con = sqlite3.connect(DB_PATH)
|
| 117 |
+
con.execute("PRAGMA foreign_keys = ON;")
|
| 118 |
+
return con
|
| 119 |
+
|
| 120 |
+
CONN = connect_db()
|
| 121 |
+
|
| 122 |
+
def init_progress_tables(con: sqlite3.Connection):
|
| 123 |
+
cur = con.cursor()
|
| 124 |
+
cur.execute("""
|
| 125 |
+
CREATE TABLE IF NOT EXISTS users (
|
| 126 |
+
user_id TEXT PRIMARY KEY,
|
| 127 |
+
name TEXT,
|
| 128 |
+
created_at TEXT
|
| 129 |
+
)
|
| 130 |
+
""")
|
| 131 |
+
cur.execute("""
|
| 132 |
+
CREATE TABLE IF NOT EXISTS attempts (
|
| 133 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 134 |
+
user_id TEXT,
|
| 135 |
+
question_id TEXT,
|
| 136 |
+
category TEXT,
|
| 137 |
+
correct INTEGER,
|
| 138 |
+
sql_text TEXT,
|
| 139 |
+
timestamp TEXT,
|
| 140 |
+
time_taken REAL,
|
| 141 |
+
difficulty INTEGER,
|
| 142 |
+
source TEXT,
|
| 143 |
+
notes TEXT
|
| 144 |
+
)
|
| 145 |
+
""")
|
| 146 |
+
cur.execute("""
|
| 147 |
+
CREATE TABLE IF NOT EXISTS session_meta (
|
| 148 |
+
id INTEGER PRIMARY KEY CHECK (id=1),
|
| 149 |
+
domain TEXT,
|
| 150 |
+
schema_json TEXT
|
| 151 |
+
)
|
| 152 |
+
""")
|
| 153 |
+
con.commit()
|
| 154 |
+
|
| 155 |
+
init_progress_tables(CONN)
|
| 156 |
+
|
| 157 |
+
# -------------------- Fallback dataset (if no OpenAI) --------------------
|
| 158 |
+
FALLBACK_SCHEMA = {
|
| 159 |
+
"domain": "bookstore",
|
| 160 |
+
"tables": [
|
| 161 |
+
{
|
| 162 |
+
"name": "authors",
|
| 163 |
+
"pk": ["author_id"],
|
| 164 |
+
"columns": [
|
| 165 |
+
{"name":"author_id","type":"INTEGER"},
|
| 166 |
+
{"name":"name","type":"TEXT"},
|
| 167 |
+
{"name":"country","type":"TEXT"},
|
| 168 |
+
{"name":"birth_year","type":"INTEGER"},
|
| 169 |
+
],
|
| 170 |
+
"fks": [],
|
| 171 |
+
"rows": [
|
| 172 |
+
{"author_id":1,"name":"Isaac Asimov","country":"USA","birth_year":1920},
|
| 173 |
+
{"author_id":2,"name":"Ursula K. Le Guin","country":"USA","birth_year":1929},
|
| 174 |
+
{"author_id":3,"name":"Haruki Murakami","country":"Japan","birth_year":1949},
|
| 175 |
+
{"author_id":4,"name":"Chinua Achebe","country":"Nigeria","birth_year":1930},
|
| 176 |
+
{"author_id":5,"name":"Jane Austen","country":"UK","birth_year":1775},
|
| 177 |
+
{"author_id":6,"name":"J.K. Rowling","country":"UK","birth_year":1965},
|
| 178 |
+
{"author_id":7,"name":"Yuval Noah Harari","country":"Israel","birth_year":1976},
|
| 179 |
+
{"author_id":8,"name":"New Author","country":"Nowhere","birth_year":1990},
|
| 180 |
+
],
|
| 181 |
+
},
|
| 182 |
+
{
|
| 183 |
+
"name": "bookstores",
|
| 184 |
+
"pk": ["store_id"],
|
| 185 |
+
"columns": [
|
| 186 |
+
{"name":"store_id","type":"INTEGER"},
|
| 187 |
+
{"name":"name","type":"TEXT"},
|
| 188 |
+
{"name":"city","type":"TEXT"},
|
| 189 |
+
{"name":"state","type":"TEXT"},
|
| 190 |
+
],
|
| 191 |
+
"fks": [],
|
| 192 |
+
"rows": [
|
| 193 |
+
{"store_id":1,"name":"Downtown Books","city":"Oklahoma City","state":"OK"},
|
| 194 |
+
{"store_id":2,"name":"Harbor Books","city":"Seattle","state":"WA"},
|
| 195 |
+
{"store_id":3,"name":"Desert Pages","city":"Phoenix","state":"AZ"},
|
| 196 |
+
],
|
| 197 |
+
},
|
| 198 |
+
{
|
| 199 |
+
"name": "books",
|
| 200 |
+
"pk": ["book_id"],
|
| 201 |
+
"columns": [
|
| 202 |
+
{"name":"book_id","type":"INTEGER"},
|
| 203 |
+
{"name":"title","type":"TEXT"},
|
| 204 |
+
{"name":"author_id","type":"INTEGER"},
|
| 205 |
+
{"name":"store_id","type":"INTEGER"},
|
| 206 |
+
{"name":"category","type":"TEXT"},
|
| 207 |
+
{"name":"price","type":"REAL"},
|
| 208 |
+
{"name":"published_year","type":"INTEGER"},
|
| 209 |
+
],
|
| 210 |
+
"fks": [
|
| 211 |
+
{"columns":["author_id"],"ref_table":"authors","ref_columns":["author_id"]},
|
| 212 |
+
{"columns":["store_id"],"ref_table":"bookstores","ref_columns":["store_id"]},
|
| 213 |
+
],
|
| 214 |
+
"rows": [
|
| 215 |
+
{"book_id":101,"title":"Foundation","author_id":1,"store_id":1,"category":"Sci-Fi","price":14.99,"published_year":1951},
|
| 216 |
+
{"book_id":102,"title":"I, Robot","author_id":1,"store_id":1,"category":"Sci-Fi","price":12.50,"published_year":1950},
|
| 217 |
+
{"book_id":103,"title":"The Left Hand of Darkness","author_id":2,"store_id":2,"category":"Sci-Fi","price":16.00,"published_year":1969},
|
| 218 |
+
{"book_id":104,"title":"A Wizard of Earthsea","author_id":2,"store_id":2,"category":"Fantasy","price":11.50,"published_year":1968},
|
| 219 |
+
{"book_id":105,"title":"Norwegian Wood","author_id":3,"store_id":3,"category":"Fiction","price":18.00,"published_year":1987},
|
| 220 |
+
{"book_id":106,"title":"Kafka on the Shore","author_id":3,"store_id":1,"category":"Fiction","price":21.00,"published_year":2002},
|
| 221 |
+
{"book_id":107,"title":"Things Fall Apart","author_id":4,"store_id":1,"category":"Fiction","price":10.00,"published_year":1958},
|
| 222 |
+
{"book_id":108,"title":"Pride and Prejudice","author_id":5,"store_id":2,"category":"Fiction","price":9.00,"published_year":1813},
|
| 223 |
+
{"book_id":109,"title":"Harry Potter and the Sorcerer's Stone","author_id":6,"store_id":3,"category":"Children","price":22.00,"published_year":1997},
|
| 224 |
+
{"book_id":110,"title":"Harry Potter and the Chamber of Secrets","author_id":6,"store_id":3,"category":"Children","price":23.00,"published_year":1998},
|
| 225 |
+
{"book_id":111,"title":"Sapiens","author_id":7,"store_id":1,"category":"History","price":26.00,"published_year":2011},
|
| 226 |
+
{"book_id":112,"title":"Homo Deus","author_id":7,"store_id":2,"category":"History","price":28.00,"published_year":2015},
|
| 227 |
+
],
|
| 228 |
+
},
|
| 229 |
+
]
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
FALLBACK_QUESTIONS = [
|
| 233 |
+
{
|
| 234 |
+
"id":"Q01","category":"SELECT *","difficulty":1,
|
| 235 |
+
"prompt_md":"Select all rows and columns from `authors`.",
|
| 236 |
+
"answer_sql":["SELECT * FROM authors;"],
|
| 237 |
+
"requires_aliases":False,"required_aliases":[]
|
| 238 |
+
},
|
| 239 |
+
{
|
| 240 |
+
"id":"Q02","category":"SELECT columns","difficulty":1,
|
| 241 |
+
"prompt_md":"Show `title` and `price` from `books`.",
|
| 242 |
+
"answer_sql":["SELECT title, price FROM books;"],
|
| 243 |
+
"requires_aliases":False,"required_aliases":[]
|
| 244 |
+
},
|
| 245 |
+
{
|
| 246 |
+
"id":"Q03","category":"WHERE","difficulty":1,
|
| 247 |
+
"prompt_md":"List Sci‑Fi books under $15 (show title, price).",
|
| 248 |
+
"answer_sql":["SELECT title, price FROM books WHERE category='Sci-Fi' AND price < 15;"],
|
| 249 |
+
"requires_aliases":False,"required_aliases":[]
|
| 250 |
+
},
|
| 251 |
+
{
|
| 252 |
+
"id":"Q04","category":"Aliases","difficulty":1,
|
| 253 |
+
"prompt_md":"Using aliases `b` and `a`, join `books` to `authors` and show `b.title` and `a.name` as `author_name`.",
|
| 254 |
+
"answer_sql":["SELECT b.title, a.name AS author_name FROM books b JOIN authors a ON b.author_id=a.author_id;"],
|
| 255 |
+
"requires_aliases":True,"required_aliases":["a","b"]
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"id":"Q05","category":"JOIN (INNER)","difficulty":2,
|
| 259 |
+
"prompt_md":"Inner join `books` and `bookstores`. Return `title`, `name` as `store`.",
|
| 260 |
+
"answer_sql":[
|
| 261 |
+
"SELECT b.title, s.name AS store FROM books b INNER JOIN bookstores s ON b.store_id=s.store_id;"
|
| 262 |
+
],
|
| 263 |
+
"requires_aliases":False,"required_aliases":[]
|
| 264 |
+
},
|
| 265 |
+
{
|
| 266 |
+
"id":"Q06","category":"JOIN (LEFT)","difficulty":2,
|
| 267 |
+
"prompt_md":"List each author and their number of books (include authors with zero): columns `name`, `book_count`.",
|
| 268 |
+
"answer_sql":[
|
| 269 |
+
"SELECT a.name, COUNT(b.book_id) AS book_count FROM authors a LEFT JOIN books b ON a.author_id=b.author_id GROUP BY a.name;"
|
| 270 |
+
],
|
| 271 |
+
"requires_aliases":False,"required_aliases":[]
|
| 272 |
+
},
|
| 273 |
+
{
|
| 274 |
+
"id":"Q07","category":"VIEW","difficulty":2,
|
| 275 |
+
"prompt_md":"Create a view `vw_pricy` with `title`, `price` for books priced > 25.",
|
| 276 |
+
"answer_sql":[
|
| 277 |
+
"CREATE VIEW vw_pricy AS SELECT title, price FROM books WHERE price > 25;"
|
| 278 |
+
],
|
| 279 |
+
"requires_aliases":False,"required_aliases":[]
|
| 280 |
+
},
|
| 281 |
+
{
|
| 282 |
+
"id":"Q08","category":"CTAS / SELECT INTO","difficulty":2,
|
| 283 |
+
"prompt_md":"Create a table `cheap_books` containing books priced < 12. Use CTAS or SELECT INTO.",
|
| 284 |
+
"answer_sql":[
|
| 285 |
+
"CREATE TABLE cheap_books AS SELECT * FROM books WHERE price < 12;",
|
| 286 |
+
"SELECT * INTO cheap_books FROM books WHERE price < 12;"
|
| 287 |
+
],
|
| 288 |
+
"requires_aliases":False,"required_aliases":[]
|
| 289 |
+
},
|
| 290 |
+
]
|
| 291 |
+
|
| 292 |
+
# -------------------- OpenAI prompts --------------------
|
| 293 |
+
DOMAIN_AND_QUESTIONS_SCHEMA = {
|
| 294 |
+
"name": "DomainSQLPack",
|
| 295 |
+
"schema": {
|
| 296 |
+
"type": "object",
|
| 297 |
+
"additionalProperties": False,
|
| 298 |
+
"properties": {
|
| 299 |
+
"domain": {"type":"string"},
|
| 300 |
+
"tables": {
|
| 301 |
+
"type":"array",
|
| 302 |
+
"items": {
|
| 303 |
+
"type":"object",
|
| 304 |
+
"additionalProperties": False,
|
| 305 |
+
"properties": {
|
| 306 |
+
"name": {"type":"string"},
|
| 307 |
+
"pk": {"type":"array","items":{"type":"string"}},
|
| 308 |
+
"columns": {
|
| 309 |
+
"type":"array",
|
| 310 |
+
"items": {
|
| 311 |
+
"type":"object",
|
| 312 |
+
"additionalProperties": False,
|
| 313 |
+
"properties": {
|
| 314 |
+
"name":{"type":"string"},
|
| 315 |
+
"type":{"type":"string"}
|
| 316 |
+
},
|
| 317 |
+
"required":["name","type"]
|
| 318 |
+
}
|
| 319 |
+
},
|
| 320 |
+
"fks": {
|
| 321 |
+
"type":"array",
|
| 322 |
+
"items": {
|
| 323 |
+
"type":"object",
|
| 324 |
+
"additionalProperties": False,
|
| 325 |
+
"properties": {
|
| 326 |
+
"columns":{"type":"array","items":{"type":"string"}},
|
| 327 |
+
"ref_table":{"type":"string"},
|
| 328 |
+
"ref_columns":{"type":"array","items":{"type":"string"}}
|
| 329 |
+
},
|
| 330 |
+
"required":["columns","ref_table","ref_columns"]
|
| 331 |
+
}
|
| 332 |
+
},
|
| 333 |
+
"rows": {"type":"array","items":{"type":["object","array"]}}
|
| 334 |
+
},
|
| 335 |
+
"required":["name","pk","columns","fks","rows"]
|
| 336 |
+
},
|
| 337 |
+
"minItems":3,"maxItems":4
|
| 338 |
+
},
|
| 339 |
+
"questions": {
|
| 340 |
+
"type":"array",
|
| 341 |
+
"items": {
|
| 342 |
+
"type":"object",
|
| 343 |
+
"additionalProperties": False,
|
| 344 |
+
"properties": {
|
| 345 |
+
"id":{"type":"string"},
|
| 346 |
+
"category":{"type":"string"},
|
| 347 |
+
"difficulty":{"type":"integer"},
|
| 348 |
+
"prompt_md":{"type":"string"},
|
| 349 |
+
"answer_sql":{"type":"array","items":{"type":"string"}},
|
| 350 |
+
"requires_aliases":{"type":"boolean"},
|
| 351 |
+
"required_aliases":{"type":"array","items":{"type":"string"}}
|
| 352 |
+
},
|
| 353 |
+
"required":["id","category","difficulty","prompt_md","answer_sql"]
|
| 354 |
+
},
|
| 355 |
+
"minItems":8,"maxItems":12
|
| 356 |
+
}
|
| 357 |
+
},
|
| 358 |
+
"required":["domain","tables","questions"]
|
| 359 |
+
},
|
| 360 |
+
"strict": True
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
DOMAIN_AND_QUESTIONS_PROMPT = """
|
| 364 |
+
You are designing a small relational dataset and training questions for SQL basics.
|
| 365 |
+
|
| 366 |
+
1) Choose ONE domain at random from:
|
| 367 |
+
- bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.
|
| 368 |
+
|
| 369 |
+
2) Produce exactly 3–4 tables that fit together (SQLite-friendly):
|
| 370 |
+
- Use snake_case, avoid reserved words.
|
| 371 |
+
- Types: INTEGER, REAL, TEXT, NUMERIC, DATE (but no advanced features).
|
| 372 |
+
- Primary keys (pk) and foreign keys (fks) must align.
|
| 373 |
+
- Provide 8–15 small, realistic seed rows per table (not huge).
|
| 374 |
+
|
| 375 |
+
3) Generate 8–12 SQL questions covering basics with varied, natural language:
|
| 376 |
+
- Categories from: "SELECT *", "SELECT columns", "WHERE", "Aliases",
|
| 377 |
+
"JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO".
|
| 378 |
+
- Include a few joins and at least one LEFT JOIN.
|
| 379 |
+
- Include one view creation.
|
| 380 |
+
- Include one table creation from SELECT (either CTAS or SELECT INTO).
|
| 381 |
+
- Prefer SQLite-compatible SQL. DO NOT use RIGHT/FULL OUTER JOIN.
|
| 382 |
+
- Offer 1–3 acceptable answer_sql variants per question.
|
| 383 |
+
- For 1–2 questions, require table aliases (set requires_aliases=true and list required_aliases).
|
| 384 |
+
|
| 385 |
+
Return JSON only.
|
| 386 |
+
"""
|
| 387 |
+
|
| 388 |
+
def llm_generate_domain_and_questions() -> Optional[Dict[str,Any]]:
|
| 389 |
+
if not OPENAI_AVAILABLE:
|
| 390 |
+
return None
|
| 391 |
+
try:
|
| 392 |
+
if USE_RESPONSES_API:
|
| 393 |
+
resp = _client.responses.create(
|
| 394 |
+
model=MODEL_ID,
|
| 395 |
+
response_format={"type":"json_schema","json_schema":DOMAIN_AND_QUESTIONS_SCHEMA},
|
| 396 |
+
input=[{"role":"user","content": DOMAIN_AND_QUESTIONS_PROMPT}],
|
| 397 |
+
temperature=0.6,
|
| 398 |
+
)
|
| 399 |
+
data_text = getattr(resp, "output_text", None)
|
| 400 |
+
else:
|
| 401 |
+
chat = _client.chat.completions.create(
|
| 402 |
+
model=MODEL_ID,
|
| 403 |
+
messages=[{"role":"user","content": DOMAIN_AND_QUESTIONS_PROMPT}],
|
| 404 |
+
temperature=0.6
|
| 405 |
+
)
|
| 406 |
+
data_text = chat.choices[0].message.content
|
| 407 |
+
obj = json.loads(data_text) if data_text else None
|
| 408 |
+
return obj
|
| 409 |
+
except Exception:
|
| 410 |
+
return None
|
| 411 |
+
|
| 412 |
+
# -------------------- Schema install & question handling --------------------
|
| 413 |
+
def drop_existing_domain_tables(con: sqlite3.Connection, keep_internal=True):
|
| 414 |
+
cur = con.cursor()
|
| 415 |
+
cur.execute("SELECT name, type FROM sqlite_master WHERE type IN ('table','view')")
|
| 416 |
+
items = cur.fetchall()
|
| 417 |
+
for name, typ in items:
|
| 418 |
+
if keep_internal and name in ("users","attempts","session_meta"):
|
| 419 |
+
continue
|
| 420 |
+
try:
|
| 421 |
+
cur.execute(f"DROP {typ.upper()} IF EXISTS {name}")
|
| 422 |
+
except Exception:
|
| 423 |
+
pass
|
| 424 |
+
con.commit()
|
| 425 |
+
|
| 426 |
+
def install_schema(con: sqlite3.Connection, schema: Dict[str,Any]):
|
| 427 |
+
drop_existing_domain_tables(con, keep_internal=True)
|
| 428 |
+
cur = con.cursor()
|
| 429 |
+
# Create tables first
|
| 430 |
+
for t in schema.get("tables", []):
|
| 431 |
+
cols_sql = []
|
| 432 |
+
pk = t.get("pk", [])
|
| 433 |
+
for c in t.get("columns", []):
|
| 434 |
+
cname = c["name"]
|
| 435 |
+
ctype = c.get("type","TEXT")
|
| 436 |
+
cols_sql.append(f"{cname} {ctype}")
|
| 437 |
+
if pk:
|
| 438 |
+
cols_sql.append(f"PRIMARY KEY ({', '.join(pk)})")
|
| 439 |
+
create_sql = f"CREATE TABLE {t['name']} ({', '.join(cols_sql)})"
|
| 440 |
+
cur.execute(create_sql)
|
| 441 |
+
# Add FKs (SQLite requires them inline on create; to be safe, we validate only)
|
| 442 |
+
# Insert rows
|
| 443 |
+
for t in schema.get("tables", []):
|
| 444 |
+
if not t.get("rows"):
|
| 445 |
+
continue
|
| 446 |
+
cols = [c["name"] for c in t.get("columns", [])]
|
| 447 |
+
qmarks = ",".join(["?"]*len(cols))
|
| 448 |
+
insert_sql = f"INSERT INTO {t['name']} ({', '.join(cols)}) VALUES ({qmarks})"
|
| 449 |
+
# rows can be objects or arrays
|
| 450 |
+
for r in t["rows"]:
|
| 451 |
+
if isinstance(r, dict):
|
| 452 |
+
vals = [r.get(col, None) for col in cols]
|
| 453 |
+
elif isinstance(r, list) or isinstance(r, tuple):
|
| 454 |
+
vals = list(r) + [None]*(len(cols)-len(r))
|
| 455 |
+
vals = vals[:len(cols)]
|
| 456 |
+
else:
|
| 457 |
+
continue
|
| 458 |
+
cur.execute(insert_sql, vals)
|
| 459 |
+
con.commit()
|
| 460 |
+
# Persist schema JSON
|
| 461 |
+
cur.execute("INSERT OR REPLACE INTO session_meta(id, domain, schema_json) VALUES (1, ?, ?)",
|
| 462 |
+
(schema.get("domain","unknown"), json.dumps(schema)))
|
| 463 |
+
con.commit()
|
| 464 |
+
|
| 465 |
+
def run_df(con: sqlite3.Connection, sql: str) -> pd.DataFrame:
|
| 466 |
+
return pd.read_sql_query(sql, con)
|
| 467 |
+
|
| 468 |
+
def rewrite_select_into(sql: str) -> Tuple[str, Optional[str]]:
|
| 469 |
+
s = sql.strip().strip(";")
|
| 470 |
+
if re.search(r"\bselect\b.+\binto\b.+\bfrom\b", s, flags=re.IGNORECASE|re.DOTALL):
|
| 471 |
+
m = re.match(r"(?is)^\s*select\s+(.*?)\s+into\s+([A-Za-z_][A-Za-z0-9_]*)\s+from\s+(.*)$", s)
|
| 472 |
+
if m:
|
| 473 |
+
cols, tbl, rest = m.groups()
|
| 474 |
+
return f"CREATE TABLE {tbl} AS SELECT {cols} FROM {rest}", tbl
|
| 475 |
+
return sql, None
|
| 476 |
+
|
| 477 |
+
def detect_unsupported_joins(sql: str) -> Optional[str]:
|
| 478 |
+
low = sql.lower()
|
| 479 |
+
if " right join " in low:
|
| 480 |
+
return "SQLite does not support RIGHT JOIN. Use LEFT JOIN in the opposite direction."
|
| 481 |
+
if " full join " in low or " full outer join " in low:
|
| 482 |
+
return "SQLite does not support FULL OUTER JOIN. Use LEFT JOIN plus UNION for the other side."
|
| 483 |
+
if " ilike " in low:
|
| 484 |
+
return "SQLite has no ILIKE. Use `LOWER(col) LIKE LOWER('%pattern%')`."
|
| 485 |
+
return None
|
| 486 |
+
|
| 487 |
+
def detect_cartesian(con: sqlite3.Connection, sql: str, df_result: pd.DataFrame) -> Optional[str]:
|
| 488 |
+
low = sql.lower()
|
| 489 |
+
if " cross join " in low:
|
| 490 |
+
return "Query uses CROSS JOIN (cartesian product). Ensure this is intended."
|
| 491 |
+
comma_from = re.search(r"\bfrom\b\s+([a-z_]\w*)\s*,\s*([a-z_]\w*)", low)
|
| 492 |
+
missing_on = (" join " in low) and (" on " not in low) and (" using " not in low) and (" natural " not in low)
|
| 493 |
+
if comma_from or missing_on:
|
| 494 |
+
try:
|
| 495 |
+
cur = con.cursor()
|
| 496 |
+
if comma_from:
|
| 497 |
+
t1, t2 = comma_from.groups()
|
| 498 |
+
else:
|
| 499 |
+
m = re.search(r"\bfrom\b\s+([a-z_]\w*)", low)
|
| 500 |
+
j = re.search(r"\bjoin\b\s+([a-z_]\w*)", low)
|
| 501 |
+
if not m or not j:
|
| 502 |
+
return "Possible cartesian product: no join condition detected."
|
| 503 |
+
t1, t2 = m.group(1), j.group(1)
|
| 504 |
+
cur.execute(f"SELECT COUNT(*) FROM {t1}")
|
| 505 |
+
n1 = cur.fetchone()[0]
|
| 506 |
+
cur.execute(f"SELECT COUNT(*) FROM {t2}")
|
| 507 |
+
n2 = cur.fetchone()[0]
|
| 508 |
+
prod = n1 * n2
|
| 509 |
+
if len(df_result) == prod and prod > 0:
|
| 510 |
+
return f"Result row count equals {n1}×{n2}={prod}. Likely cartesian product (missing join)."
|
| 511 |
+
except Exception:
|
| 512 |
+
return "Possible cartesian product: no join condition detected."
|
| 513 |
+
return None
|
| 514 |
+
|
| 515 |
+
def results_equal(df_a: pd.DataFrame, df_b: pd.DataFrame) -> bool:
|
| 516 |
+
if df_a.shape != df_b.shape:
|
| 517 |
+
return False
|
| 518 |
+
a = df_a.copy()
|
| 519 |
+
b = df_b.copy()
|
| 520 |
+
a.columns = [c.lower() for c in a.columns]
|
| 521 |
+
b.columns = [c.lower() for c in b.columns]
|
| 522 |
+
a = a.sort_values(list(a.columns)).reset_index(drop=True)
|
| 523 |
+
b = b.sort_values(list(b.columns)).reset_index(drop=True)
|
| 524 |
+
return a.equals(b)
|
| 525 |
+
|
| 526 |
+
def aliases_present(sql: str, required_aliases: List[str]) -> bool:
|
| 527 |
+
low = re.sub(r"\s+", " ", sql.lower())
|
| 528 |
+
for al in required_aliases:
|
| 529 |
+
if f" {al}." not in low and f" as {al} " not in low:
|
| 530 |
+
return False
|
| 531 |
+
return True
|
| 532 |
+
|
| 533 |
+
# -------------------- Question model --------------------
|
| 534 |
+
@dataclass
|
| 535 |
+
class SQLQuestion:
|
| 536 |
+
id: str
|
| 537 |
+
category: str
|
| 538 |
+
difficulty: int
|
| 539 |
+
prompt_md: str
|
| 540 |
+
answer_sql: List[str]
|
| 541 |
+
requires_aliases: bool = False
|
| 542 |
+
required_aliases: List[str] = None
|
| 543 |
+
|
| 544 |
+
def to_question_dict(q) -> Dict[str,Any]:
|
| 545 |
+
d = dict(q)
|
| 546 |
+
d.setdefault("requires_aliases", False)
|
| 547 |
+
d.setdefault("required_aliases", [])
|
| 548 |
+
return d
|
| 549 |
+
|
| 550 |
+
def load_questions(obj_list: List[Dict[str,Any]]) -> List[Dict[str,Any]]:
|
| 551 |
+
out = []
|
| 552 |
+
for o in obj_list:
|
| 553 |
+
out.append(to_question_dict(o))
|
| 554 |
+
return out
|
| 555 |
+
|
| 556 |
+
# -------------------- Domain bootstrap --------------------
|
| 557 |
+
def bootstrap_domain_with_llm_or_fallback() -> Tuple[Dict[str,Any], List[Dict[str,Any]]]:
|
| 558 |
+
obj = llm_generate_domain_and_questions()
|
| 559 |
+
if obj is None:
|
| 560 |
+
return FALLBACK_SCHEMA, FALLBACK_QUESTIONS
|
| 561 |
+
# Guardrails: strip RIGHT/FULL joins from answers
|
| 562 |
+
clean_qs = []
|
| 563 |
+
for q in obj.get("questions", []):
|
| 564 |
+
answers = [a for a in q.get("answer_sql", []) if " right join " not in a.lower() and " full " not in a.lower()]
|
| 565 |
+
if not answers:
|
| 566 |
+
continue
|
| 567 |
+
q["answer_sql"] = answers
|
| 568 |
+
q.setdefault("requires_aliases", False)
|
| 569 |
+
q.setdefault("required_aliases", [])
|
| 570 |
+
clean_qs.append(q)
|
| 571 |
+
obj["questions"] = clean_qs
|
| 572 |
+
return obj, clean_qs
|
| 573 |
+
|
| 574 |
+
def install_new_domain():
|
| 575 |
+
schema, questions = bootstrap_domain_with_llm_or_fallback()
|
| 576 |
+
install_schema(CONN, schema)
|
| 577 |
+
return schema, questions
|
| 578 |
+
|
| 579 |
+
# -------------------- Session state --------------------
|
| 580 |
+
CURRENT_SCHEMA, CURRENT_QS = install_new_domain()
|
| 581 |
+
|
| 582 |
+
# -------------------- Progress + mastery --------------------
|
| 583 |
+
def upsert_user(con: sqlite3.Connection, user_id: str, name: str):
|
| 584 |
+
cur = con.cursor()
|
| 585 |
+
cur.execute("SELECT user_id FROM users WHERE user_id = ?", (user_id,))
|
| 586 |
+
if cur.fetchone() is None:
|
| 587 |
+
cur.execute("INSERT INTO users (user_id, name, created_at) VALUES (?, ?, ?)",
|
| 588 |
+
(user_id, name, datetime.now(timezone.utc).isoformat()))
|
| 589 |
+
else:
|
| 590 |
+
cur.execute("UPDATE users SET name=? WHERE user_id=?", (name, user_id))
|
| 591 |
+
con.commit()
|
| 592 |
+
|
| 593 |
+
CATEGORIES_ORDER = [
|
| 594 |
+
"SELECT *", "SELECT columns", "WHERE", "Aliases",
|
| 595 |
+
"JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO"
|
| 596 |
+
]
|
| 597 |
+
|
| 598 |
+
def topic_stats(df_attempts: pd.DataFrame) -> pd.DataFrame:
|
| 599 |
+
rows = []
|
| 600 |
+
for cat in CATEGORIES_ORDER:
|
| 601 |
+
sub = df_attempts[df_attempts["category"] == cat] if not df_attempts.empty else pd.DataFrame()
|
| 602 |
+
att = int(sub.shape[0]) if not sub.empty else 0
|
| 603 |
+
cor = int(sub["correct"].sum()) if not sub.empty else 0
|
| 604 |
+
acc = float(cor / max(att, 1))
|
| 605 |
+
rows.append({"category":cat,"attempts":att,"correct":cor,"accuracy":acc})
|
| 606 |
+
return pd.DataFrame(rows)
|
| 607 |
+
|
| 608 |
+
def fetch_attempts(con: sqlite3.Connection, user_id: str) -> pd.DataFrame:
|
| 609 |
+
return pd.read_sql_query("SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC", con, params=(user_id,))
|
| 610 |
+
|
| 611 |
+
def pick_next_question(user_id: str) -> Dict[str,Any]:
|
| 612 |
+
df = fetch_attempts(CONN, user_id)
|
| 613 |
+
stats = topic_stats(df)
|
| 614 |
+
stats = stats.sort_values(by=["accuracy","attempts"], ascending=[True, True])
|
| 615 |
+
weakest = stats.iloc[0]["category"] if not stats.empty else CATEGORIES_ORDER[0]
|
| 616 |
+
cands = [q for q in CURRENT_QS if q["category"] == weakest] or CURRENT_QS
|
| 617 |
+
return dict(random.choice(cands))
|
| 618 |
+
|
| 619 |
+
# -------------------- Execution & feedback --------------------
|
| 620 |
+
def exec_student_sql(sql_text: str) -> Tuple[Optional[pd.DataFrame], Optional[str], Optional[str], Optional[str]]:
|
| 621 |
+
if not sql_text or not sql_text.strip():
|
| 622 |
+
return None, "Enter a SQL statement.", None, None
|
| 623 |
+
|
| 624 |
+
sql_raw = sql_text.strip().rstrip(";")
|
| 625 |
+
sql_rew, created_tbl = rewrite_select_into(sql_raw)
|
| 626 |
+
note = None
|
| 627 |
+
if sql_rew != sql_raw:
|
| 628 |
+
note = "Rewrote `SELECT ... INTO` to `CREATE TABLE ... AS SELECT ...` for SQLite."
|
| 629 |
+
|
| 630 |
+
unsup = detect_unsupported_joins(sql_rew)
|
| 631 |
+
if unsup:
|
| 632 |
+
return None, unsup, None, note
|
| 633 |
+
|
| 634 |
+
try:
|
| 635 |
+
low = sql_rew.lower()
|
| 636 |
+
if low.startswith("select"):
|
| 637 |
+
df = run_df(CONN, sql_rew)
|
| 638 |
+
warn = detect_cartesian(CONN, sql_rew, df)
|
| 639 |
+
return df, None, warn, note
|
| 640 |
+
else:
|
| 641 |
+
cur = CONN.cursor()
|
| 642 |
+
cur.execute(sql_rew)
|
| 643 |
+
CONN.commit()
|
| 644 |
+
# Preview newly created objects
|
| 645 |
+
if low.startswith("create view"):
|
| 646 |
+
m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+(select.*)$", low)
|
| 647 |
+
name = m.group(2) if m else None
|
| 648 |
+
if name:
|
| 649 |
+
try:
|
| 650 |
+
df = run_df(CONN, f"SELECT * FROM {name}")
|
| 651 |
+
return df, None, None, note
|
| 652 |
+
except Exception:
|
| 653 |
+
return None, "View created but could not be queried.", None, note
|
| 654 |
+
if low.startswith("create table"):
|
| 655 |
+
tbl = created_tbl
|
| 656 |
+
if not tbl:
|
| 657 |
+
m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
|
| 658 |
+
tbl = m.group(2) if m else None
|
| 659 |
+
if tbl:
|
| 660 |
+
try:
|
| 661 |
+
df = run_df(CONN, f"SELECT * FROM {tbl}")
|
| 662 |
+
return df, None, None, note
|
| 663 |
+
except Exception:
|
| 664 |
+
return None, "Table created but could not be queried.", None, note
|
| 665 |
+
return pd.DataFrame(), None, None, note
|
| 666 |
+
except Exception as e:
|
| 667 |
+
# Tailored messages
|
| 668 |
+
msg = str(e)
|
| 669 |
+
if "no such table" in msg.lower():
|
| 670 |
+
return None, f"{msg}. Check table names for this randomized domain.", None, note
|
| 671 |
+
if "no such column" in msg.lower():
|
| 672 |
+
return None, f"{msg}. Use correct column names or prefixes (alias.column).", None, note
|
| 673 |
+
if "ambiguous column name" in msg.lower():
|
| 674 |
+
return None, f"{msg}. Qualify the column with a table alias.", None, note
|
| 675 |
+
if "misuse of aggregate" in msg.lower() or "aggregate functions are not allowed in" in msg.lower():
|
| 676 |
+
return None, f"{msg}. You might need a GROUP BY for non-aggregated columns.", None, note
|
| 677 |
+
if "near \"into\"" in msg.lower() and "syntax error" in msg.lower():
|
| 678 |
+
return None, "SQLite doesn’t support `SELECT ... INTO`. I can rewrite it automatically—try again.", None, note
|
| 679 |
+
if "syntax error" in msg.lower():
|
| 680 |
+
return None, f"Syntax error. Check commas, keywords, and parentheses. Raw error: {msg}", None, note
|
| 681 |
+
return None, f"SQL error: {msg}", None, note
|
| 682 |
+
|
| 683 |
+
def answer_df(answer_sql: List[str]) -> Optional[pd.DataFrame]:
|
| 684 |
+
for sql in answer_sql:
|
| 685 |
+
try:
|
| 686 |
+
low = sql.strip().lower()
|
| 687 |
+
if low.startswith("select"):
|
| 688 |
+
return run_df(CONN, sql)
|
| 689 |
+
if low.startswith("create view"):
|
| 690 |
+
# temp preview
|
| 691 |
+
m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
|
| 692 |
+
view_name = m.group(2) if m else "vw_tmp"
|
| 693 |
+
cur = CONN.cursor()
|
| 694 |
+
cur.execute(f"DROP VIEW IF EXISTS {view_name}")
|
| 695 |
+
cur.execute(sql)
|
| 696 |
+
CONN.commit()
|
| 697 |
+
return run_df(CONN, f"SELECT * FROM {view_name}")
|
| 698 |
+
if low.startswith("create table"):
|
| 699 |
+
m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
|
| 700 |
+
tbl = m.group(2) if m else None
|
| 701 |
+
cur = CONN.cursor()
|
| 702 |
+
if tbl:
|
| 703 |
+
cur.execute(f"DROP TABLE IF EXISTS {tbl}")
|
| 704 |
+
cur.execute(sql)
|
| 705 |
+
CONN.commit()
|
| 706 |
+
if tbl:
|
| 707 |
+
return run_df(CONN, f"SELECT * FROM {tbl}")
|
| 708 |
+
except Exception:
|
| 709 |
+
continue
|
| 710 |
+
return None
|
| 711 |
+
|
| 712 |
+
def validate_answer(q: Dict[str,Any], student_sql: str, df_student: Optional[pd.DataFrame]) -> Tuple[bool, str]:
|
| 713 |
+
df_expected = answer_df(q["answer_sql"])
|
| 714 |
+
# If we can't build a canonical DF (e.g., DDL side effect), we accept any successful execution as correct
|
| 715 |
+
if df_expected is None:
|
| 716 |
+
return (df_student is not None), f"**Explanation:** Your statement executed successfully for this task."
|
| 717 |
+
if df_student is None:
|
| 718 |
+
return False, f"**Explanation:** Expected data result differs."
|
| 719 |
+
return results_equal(df_student, df_expected), f"**Explanation:** Compare your result to a canonical solution."
|
| 720 |
+
|
| 721 |
+
def log_attempt(user_id: str, qid: str, category: str, correct: bool, sql_text: str,
|
| 722 |
+
time_taken: float, difficulty: int, source: str, notes: str):
|
| 723 |
+
cur = CONN.cursor()
|
| 724 |
+
cur.execute("""
|
| 725 |
+
INSERT INTO attempts (user_id, question_id, category, correct, sql_text, timestamp, time_taken, difficulty, source, notes)
|
| 726 |
+
VALUES (?,?,?,?,?,?,?,?,?,?)
|
| 727 |
+
""", (user_id, qid, category, int(correct), sql_text, datetime.now(timezone.utc).isoformat(),
|
| 728 |
+
time_taken, difficulty, source, notes))
|
| 729 |
+
CONN.commit()
|
| 730 |
+
|
| 731 |
+
# -------------------- UI callbacks --------------------
|
| 732 |
+
def start_session(name: str, session: dict):
|
| 733 |
+
name = (name or "").strip()
|
| 734 |
+
if not name:
|
| 735 |
+
return session, gr.update(value="Please enter your name to begin.", visible=True), gr.update(visible=False), gr.update(visible=False), None, gr.update(visible=False), pd.DataFrame(), pd.DataFrame()
|
| 736 |
+
|
| 737 |
+
slug = "-".join(name.lower().split())
|
| 738 |
+
user_id = slug[:64] if slug else f"user-{int(time.time())}"
|
| 739 |
+
upsert_user(CONN, user_id, name)
|
| 740 |
+
q = pick_next_question(user_id)
|
| 741 |
+
session = {"user_id": user_id, "name": name, "qid": q["id"], "start_ts": time.time(), "q": q}
|
| 742 |
+
|
| 743 |
+
prompt = q["prompt_md"]
|
| 744 |
+
stats = topic_stats(fetch_attempts(CONN, user_id))
|
| 745 |
+
erd = draw_dynamic_erd(CURRENT_SCHEMA)
|
| 746 |
+
return (session,
|
| 747 |
+
gr.update(value=f"**Question {q['id']}**\n\n{prompt}", visible=True),
|
| 748 |
+
gr.update(visible=True), # show SQL input
|
| 749 |
+
gr.update(value="", visible=True), # preview block
|
| 750 |
+
erd,
|
| 751 |
+
gr.update(visible=False), # next btn hidden until submit
|
| 752 |
+
stats,
|
| 753 |
+
pd.DataFrame())
|
| 754 |
+
|
| 755 |
+
def render_preview_and_erd(sql_text: str, session: dict):
|
| 756 |
+
if not session or "q" not in session:
|
| 757 |
+
return gr.update(value="", visible=False), draw_dynamic_erd(CURRENT_SCHEMA)
|
| 758 |
+
s = (sql_text or "").strip()
|
| 759 |
+
if not s:
|
| 760 |
+
return gr.update(value="", visible=False), draw_dynamic_erd(CURRENT_SCHEMA)
|
| 761 |
+
return gr.update(value=f"**Preview:**\n\n```sql\n{s}\n```", visible=True), draw_dynamic_erd(CURRENT_SCHEMA)
|
| 762 |
+
|
| 763 |
+
def submit_answer(sql_text: str, session: dict):
|
| 764 |
+
if not session or "user_id" not in session or "q" not in session:
|
| 765 |
+
return gr.update(value="Start a session first.", visible=True), pd.DataFrame(), gr.update(visible=False), pd.DataFrame()
|
| 766 |
+
user_id = session["user_id"]
|
| 767 |
+
q = session["q"]
|
| 768 |
+
elapsed = max(0.0, time.time() - session.get("start_ts", time.time()))
|
| 769 |
+
|
| 770 |
+
df, err, warn, note = exec_student_sql(sql_text)
|
| 771 |
+
details = []
|
| 772 |
+
if note: details.append(f"ℹ️ {note}")
|
| 773 |
+
if err:
|
| 774 |
+
fb = f"❌ **Did not run**\n\n{err}"
|
| 775 |
+
if details: fb += "\n\n" + "\n".join(details)
|
| 776 |
+
log_attempt(user_id, q["id"], q["category"], False, sql_text, elapsed, int(q["difficulty"]), "bank", " | ".join([err] + details))
|
| 777 |
+
stats = topic_stats(fetch_attempts(CONN, user_id))
|
| 778 |
+
return gr.update(value=fb, visible=True), pd.DataFrame(), gr.update(visible=True), stats
|
| 779 |
+
|
| 780 |
+
# Validate correctness
|
| 781 |
+
alias_msg = None
|
| 782 |
+
if q.get("requires_aliases"):
|
| 783 |
+
if not aliases_present(sql_text, q.get("required_aliases", [])):
|
| 784 |
+
alias_msg = f"⚠️ This task asked for aliases {q.get('required_aliases', [])}. I didn’t detect them."
|
| 785 |
+
|
| 786 |
+
is_correct, explanation = validate_answer(q, sql_text, df)
|
| 787 |
+
if warn: details.append(f"⚠️ {warn}")
|
| 788 |
+
if alias_msg: details.append(alias_msg)
|
| 789 |
+
|
| 790 |
+
prefix = "✅ **Correct!**" if is_correct else "❌ **Not quite.**"
|
| 791 |
+
feedback = prefix
|
| 792 |
+
if details:
|
| 793 |
+
feedback += "\n\n" + "\n".join(details)
|
| 794 |
+
feedback += "\n\n" + explanation + "\n\n**One acceptable solution:**\n```sql\n" + q["answer_sql"][0].rstrip(";") + ";\n```"
|
| 795 |
+
|
| 796 |
+
log_attempt(user_id, q["id"], q["category"], bool(is_correct), sql_text, elapsed, int(q["difficulty"]), "bank", " | ".join(details))
|
| 797 |
+
stats = topic_stats(fetch_attempts(CONN, user_id))
|
| 798 |
+
return gr.update(value=feedback, visible=True), (df if df is not None else pd.DataFrame()), gr.update(visible=True), stats
|
| 799 |
+
|
| 800 |
+
def next_question(session: dict):
|
| 801 |
+
if not session or "user_id" not in session:
|
| 802 |
+
return session, gr.update(value="Start a session first.", visible=True), gr.update(visible=False), draw_dynamic_erd(CURRENT_SCHEMA), gr.update(visible=False)
|
| 803 |
+
user_id = session["user_id"]
|
| 804 |
+
q = pick_next_question(user_id)
|
| 805 |
+
session["qid"] = q["id"]
|
| 806 |
+
session["q"] = q
|
| 807 |
+
session["start_ts"] = time.time()
|
| 808 |
+
return session, gr.update(value=f"**Question {q['id']}**\n\n{q['prompt_md']}", visible=True), gr.update(value="", visible=True), draw_dynamic_erd(CURRENT_SCHEMA), gr.update(visible=False)
|
| 809 |
+
|
| 810 |
+
def show_hint(session: dict):
|
| 811 |
+
if not session or "q" not in session:
|
| 812 |
+
return gr.update(value="Start a session first.", visible=True)
|
| 813 |
+
# Lightweight hint policy: category-specific guidance
|
| 814 |
+
cat = session["q"]["category"]
|
| 815 |
+
hint = {
|
| 816 |
+
"SELECT *": "Use `SELECT * FROM table_name`.",
|
| 817 |
+
"SELECT columns": "List columns: `SELECT col1, col2 FROM table_name`.",
|
| 818 |
+
"WHERE": "Filter with `WHERE` and combine conditions using AND/OR.",
|
| 819 |
+
"Aliases": "Use `table_name t` and qualify: `t.col`.",
|
| 820 |
+
"JOIN (INNER)": "Join with `... INNER JOIN ... ON left.key = right.key`.",
|
| 821 |
+
"JOIN (LEFT)": "LEFT JOIN keeps all rows from the left table.",
|
| 822 |
+
"Aggregation": "Use aggregate functions and `GROUP BY` non-aggregated columns.",
|
| 823 |
+
"VIEW": "`CREATE VIEW view_name AS SELECT ...`.",
|
| 824 |
+
"CTAS / SELECT INTO": "SQLite uses `CREATE TABLE name AS SELECT ...`."
|
| 825 |
+
}.get(cat, "Read the ER diagram and identify keys to join on.")
|
| 826 |
+
return gr.update(value=f"**Hint:** {hint}", visible=True)
|
| 827 |
+
|
| 828 |
+
def export_progress(user_name: str):
|
| 829 |
+
slug = "-".join((user_name or "").lower().split())
|
| 830 |
+
if not slug:
|
| 831 |
+
return None
|
| 832 |
+
user_id = slug[:64]
|
| 833 |
+
df = fetch_attempts(CONN, user_id)
|
| 834 |
+
os.makedirs(EXPORT_DIR, exist_ok=True)
|
| 835 |
+
path = os.path.abspath(os.path.join(EXPORT_DIR, f"{user_id}_progress.csv"))
|
| 836 |
+
(pd.DataFrame([{"info":"No attempts yet."}]) if df.empty else df).to_csv(path, index=False)
|
| 837 |
+
return path
|
| 838 |
+
|
| 839 |
+
def regenerate_domain():
|
| 840 |
+
global CURRENT_SCHEMA, CURRENT_QS
|
| 841 |
+
CURRENT_SCHEMA, CURRENT_QS = install_new_domain()
|
| 842 |
+
erd = draw_dynamic_erd(CURRENT_SCHEMA)
|
| 843 |
+
return gr.update(value="✅ Domain regenerated.", visible=True), erd
|
| 844 |
+
|
| 845 |
+
def preview_table(tbl: str):
|
| 846 |
+
try:
|
| 847 |
+
return run_df(CONN, f"SELECT * FROM {tbl} LIMIT 20")
|
| 848 |
+
except Exception as e:
|
| 849 |
+
return pd.DataFrame([{"error": str(e)}])
|
| 850 |
+
|
| 851 |
+
def list_tables_for_preview():
|
| 852 |
+
df = run_df(CONN, "SELECT name, type FROM sqlite_master WHERE type in ('table','view') AND name NOT IN ('users','attempts','session_meta') ORDER BY type, name")
|
| 853 |
+
if df.empty:
|
| 854 |
+
return ["(no tables)"]
|
| 855 |
+
return df["name"].tolist()
|
| 856 |
+
|
| 857 |
+
# -------------------- UI --------------------
|
| 858 |
+
with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
|
| 859 |
+
gr.Markdown(
|
| 860 |
+
"""
|
| 861 |
+
# 🧪 Adaptive SQL Trainer — Randomized Domains (SQLite)
|
| 862 |
+
- Uses **OpenAI** (if configured) to randomize a domain (bookstore, retail sales, wholesaler,
|
| 863 |
+
sales tax, oil & gas wells, marketing), generate **3–4 tables** and **8–12** questions.
|
| 864 |
+
- Practice `SELECT`, `WHERE`, `JOIN` (INNER/LEFT), **aliases**, **views**, and **CTAS / SELECT INTO**.
|
| 865 |
+
- The app explains **SQLite quirks** (no RIGHT/FULL JOIN) and flags likely **cartesian products**.
|
| 866 |
+
|
| 867 |
+
> Set your `OPENAI_API_KEY` in the Space secrets to enable randomization.
|
| 868 |
+
"""
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
with gr.Row():
|
| 872 |
+
with gr.Column(scale=1):
|
| 873 |
+
name_box = gr.Textbox(label="Your Name", placeholder="e.g., Jordan Alvarez")
|
| 874 |
+
start_btn = gr.Button("Start / Resume Session", variant="primary")
|
| 875 |
+
session_state = gr.State({"user_id": None, "name": None, "qid": None, "start_ts": None, "q": None})
|
| 876 |
+
|
| 877 |
+
gr.Markdown("---")
|
| 878 |
+
gr.Markdown("### Dataset Controls")
|
| 879 |
+
regen_btn = gr.Button("🔀 Randomize Dataset (OpenAI)")
|
| 880 |
+
regen_fb = gr.Markdown(visible=False)
|
| 881 |
+
|
| 882 |
+
gr.Markdown("---")
|
| 883 |
+
gr.Markdown("### Instructor Tools")
|
| 884 |
+
export_name = gr.Textbox(label="Export a student's progress (enter name)")
|
| 885 |
+
export_btn = gr.Button("Export CSV")
|
| 886 |
+
export_file = gr.File(label="Download progress")
|
| 887 |
+
|
| 888 |
+
gr.Markdown("---")
|
| 889 |
+
gr.Markdown("### Quick Table/View Preview (top 20 rows)")
|
| 890 |
+
tbl_dd = gr.Dropdown(choices=list_tables_for_preview(), label="Pick table/view", interactive=True)
|
| 891 |
+
tbl_btn = gr.Button("Preview")
|
| 892 |
+
preview_df = gr.Dataframe(headers=[], interactive=False)
|
| 893 |
+
|
| 894 |
+
with gr.Column(scale=2):
|
| 895 |
+
prompt_md = gr.Markdown(visible=False)
|
| 896 |
+
sql_input = gr.Textbox(label="Your SQL", placeholder="Type SQL here (end ; optional).", lines=6, visible=False)
|
| 897 |
+
|
| 898 |
+
preview_md = gr.Markdown(visible=False)
|
| 899 |
+
er_image = gr.Image(label="Entity Diagram", value=draw_dynamic_erd(CURRENT_SCHEMA), height=PLOT_HEIGHT)
|
| 900 |
+
|
| 901 |
+
with gr.Row():
|
| 902 |
+
submit_btn = gr.Button("Run & Submit", variant="primary")
|
| 903 |
+
hint_btn = gr.Button("Hint")
|
| 904 |
+
next_btn = gr.Button("Next Question ▶", visible=False)
|
| 905 |
+
|
| 906 |
+
feedback_md = gr.Markdown("")
|
| 907 |
+
|
| 908 |
+
gr.Markdown("---")
|
| 909 |
+
gr.Markdown("### Your Progress by Category")
|
| 910 |
+
mastery_df = gr.Dataframe(headers=["category","attempts","correct","accuracy"], row_count=(0, "dynamic"), interactive=False)
|
| 911 |
+
|
| 912 |
+
gr.Markdown("---")
|
| 913 |
+
gr.Markdown("### Result Preview")
|
| 914 |
+
result_df = gr.Dataframe(headers=[], interactive=False)
|
| 915 |
+
|
| 916 |
+
# Wire events
|
| 917 |
+
start_btn.click(
|
| 918 |
+
start_session,
|
| 919 |
+
inputs=[name_box, session_state],
|
| 920 |
+
outputs=[session_state, prompt_md, sql_input, preview_md, er_image, next_btn, mastery_df, result_df],
|
| 921 |
+
)
|
| 922 |
+
sql_input.change(
|
| 923 |
+
render_preview_and_erd,
|
| 924 |
+
inputs=[sql_input, session_state],
|
| 925 |
+
outputs=[preview_md, er_image],
|
| 926 |
+
)
|
| 927 |
+
submit_btn.click(
|
| 928 |
+
submit_answer,
|
| 929 |
+
inputs=[sql_input, session_state],
|
| 930 |
+
outputs=[feedback_md, result_df, next_btn, mastery_df],
|
| 931 |
+
)
|
| 932 |
+
next_btn.click(
|
| 933 |
+
next_question,
|
| 934 |
+
inputs=[session_state],
|
| 935 |
+
outputs=[session_state, prompt_md, sql_input, er_image, next_btn],
|
| 936 |
+
)
|
| 937 |
+
hint_btn.click(
|
| 938 |
+
show_hint,
|
| 939 |
+
inputs=[session_state],
|
| 940 |
+
outputs=[feedback_md],
|
| 941 |
+
)
|
| 942 |
+
export_btn.click(
|
| 943 |
+
export_progress,
|
| 944 |
+
inputs=[export_name],
|
| 945 |
+
outputs=[export_file],
|
| 946 |
+
)
|
| 947 |
+
regen_btn.click(
|
| 948 |
+
regenerate_domain,
|
| 949 |
+
inputs=[],
|
| 950 |
+
outputs=[regen_fb, er_image],
|
| 951 |
+
)
|
| 952 |
+
tbl_btn.click(
|
| 953 |
+
lambda name: preview_table(name),
|
| 954 |
+
inputs=[tbl_dd],
|
| 955 |
+
outputs=[preview_df]
|
| 956 |
+
)
|
| 957 |
+
# Keep dropdown fresh after regeneration
|
| 958 |
+
regen_btn.click(
|
| 959 |
+
lambda: gr.update(choices=list_tables_for_preview()),
|
| 960 |
+
inputs=[],
|
| 961 |
+
outputs=[tbl_dd]
|
| 962 |
+
)
|
| 963 |
+
|
| 964 |
+
if __name__ == "__main__":
|
| 965 |
+
demo.launch()
|