Spaces:
Sleeping
Sleeping
# Adaptive SQL Trainer — Domain Randomized with OpenAI (Gradio + SQLite)
# - Randomizes a relational domain via OpenAI (bookstore, retail sales, wholesaler,
#   sales tax, oil & gas wells, marketing) OR falls back to a built-in dataset.
# - Builds 3–4 related tables (schema + seed rows) in SQLite.
# - Generates 8–12 randomized SQL questions with varied phrasings.
# - Validates answers by executing canonical SQL and comparing result sets.
# - Provides tailored feedback (SQLite dialect, cartesian products, aggregates, aliases).
# - Shows data results in the bottom pane for every run (SELECT results, or a preview for VIEW/CTAS).
#
# Hugging Face Spaces: set OPENAI_API_KEY as a secret to enable LLM randomization.
| import os | |
| import re | |
| import json | |
| import time | |
| import random | |
| import sqlite3 | |
| from dataclasses import dataclass, asdict | |
| from datetime import datetime, timezone | |
| from typing import List, Dict, Any, Tuple, Optional | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| # Matplotlib for ERD drawing (headless) | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from io import BytesIO | |
| from PIL import Image | |
# -------------------- OpenAI (optional) --------------------
# The OpenAI client is optional. If the package cannot be imported or the
# client cannot be constructed (it requires OPENAI_API_KEY), _client stays
# None and OPENAI_AVAILABLE is False, so the built-in dataset is used.
USE_RESPONSES_API = True
MODEL_ID = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
_client = None
try:
    from openai import OpenAI
    _client = OpenAI()  # requires OPENAI_API_KEY
except Exception:
    _client = None
OPENAI_AVAILABLE = _client is not None
# -------------------- Global settings --------------------
# Use the persistent "/data" directory when it exists, else the working directory.
_PERSISTENT_DIR = "/data"
DB_DIR = _PERSISTENT_DIR if os.path.exists(_PERSISTENT_DIR) else "."
DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
EXPORT_DIR = "."
# NOTE(review): weak default admin key; set ADMIN_KEY in the environment for real deployments.
ADMIN_KEY = os.environ.get("ADMIN_KEY", "demo")

# Seed the module-level `random` generator so question picks are reproducible;
# SYS_RAND is an OS-entropy generator that ignores the seed.
RANDOM_SEED = int(os.environ.get("RANDOM_SEED", "7"))
random.seed(RANDOM_SEED)
SYS_RAND = random.SystemRandom()

# Matplotlib figure geometry (inches), resolution and display height for plots.
PLOT_FIGSIZE, PLOT_DPI, PLOT_HEIGHT = (6.8, 3.4), 110, 300
# -------------------- ERD helpers --------------------
def _to_pil(fig) -> Image.Image:
    """Render a Matplotlib figure to PNG and return it as a PIL image.

    The PNG is written at PLOT_DPI with a tight bounding box. The figure is
    closed in a ``finally`` block so that a failure in ``tight_layout`` or
    ``savefig`` does not leave the figure open in pyplot's figure registry.
    ``Image.open`` reads lazily, and the returned image keeps the in-memory
    buffer referenced.
    """
    buf = BytesIO()
    try:
        fig.tight_layout()
        fig.savefig(buf, format="png", dpi=PLOT_DPI, bbox_inches="tight")
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def draw_dynamic_erd(schema: Dict[str, Any]) -> Image.Image:
    """Draw a simple entity-relationship diagram of *schema* and return it as a PIL image.

    Expected shape of *schema*::

        {
          "domain": "bookstore",
          "tables": [
            {"name": "authors",
             "columns": [{"name": "author_id", "type": "INTEGER"}, ...],
             "pk": ["author_id"],
             "fks": [{"columns": ["author_id"], "ref_table": "...", "ref_columns": ["..."]}],
             "rows": [...]},
          ],
        }

    Each table is drawn as a box, side by side, with its name in bold and one
    line per column (primary-key columns are suffixed " (PK)"). When a table
    has more columns than fit in its box, the overflow is summarised as
    "... (+k more)" instead of being drawn past the box. For every foreign key
    whose referenced table is also drawn, an arrow runs from the bottom of the
    referencing box to the top of the referenced box.
    """
    fig, ax = plt.subplots(figsize=PLOT_FIGSIZE)
    ax.axis("off")
    tables = schema.get("tables", []) or []
    n = max(1, len(tables))
    # Box geometry in axes-fraction coordinates.
    margin = 0.03
    width = (1 - margin * (n + 1)) / n
    height = 0.65
    y = 0.25
    line_step = 0.06
    first_line = y + height - 0.10
    # Number of column lines that fit; the last one sits one step above the bottom edge.
    max_lines = max(1, int((first_line - y - line_step) / line_step + 1e-9) + 1)

    boxes = {}
    for i, t in enumerate(tables):
        x = margin + i * (width + margin)
        name = str(t.get("name", f"table_{i + 1}"))
        boxes[name] = (x, y, width, height)
        ax.add_patch(plt.Rectangle((x, y), width, height, fill=False))
        # Matplotlib does not render Markdown, so bold comes from fontweight, not "**...**".
        ax.text(x + 0.01, y + height - 0.05, name, fontsize=10, fontweight="bold",
                ha="left", va="top")
        pk = set(t.get("pk") or [])
        labels = []
        for col in t.get("columns") or []:
            col_name = str(col["name"])
            labels.append(f"{col_name} (PK)" if col["name"] in pk else col_name)
        if len(labels) > max_lines:
            hidden = len(labels) - (max_lines - 1)
            labels = labels[: max_lines - 1] + [f"... (+{hidden} more)"]
        for k, label in enumerate(labels):
            ax.text(x + 0.02, first_line - k * line_step, label, fontsize=9, ha="left", va="top")

    # FK arrows: bottom-centre of the referencing table -> top-centre of the referenced table.
    for t in tables:
        src_tbl = str(t.get("name", ""))
        for fk in t.get("fks") or []:
            dst_tbl = fk.get("ref_table")
            if src_tbl in boxes and dst_tbl in boxes:
                (x1, y1, w1, h1) = boxes[src_tbl]
                (x2, y2, w2, h2) = boxes[dst_tbl]
                ax.annotate("", xy=(x2 + w2 / 2, y2 + h2), xytext=(x1 + w1 / 2, y1),
                            arrowprops=dict(arrowstyle="->", lw=1.1))
    ax.text(0.5, 0.06, f"Domain: {schema.get('domain', 'unknown')}", fontsize=9, ha="center")
    return _to_pil(fig)
# -------------------- SQLite helpers --------------------
def connect_db(path: Optional[str] = None) -> sqlite3.Connection:
    """Open the trainer database with SQLite foreign-key enforcement turned on.

    The connection is opened with ``check_same_thread=False``. This module
    keeps a single shared connection, and UI callbacks may run on threads
    other than the one that imported the module (Gradio typically runs
    synchronous handlers in a worker pool). With the sqlite3 default, any
    such use would raise ``ProgrammingError``.

    Args:
        path: database file to open; defaults to ``DB_PATH``.

    Returns:
        An open ``sqlite3.Connection``.
    """
    con = sqlite3.connect(DB_PATH if path is None else path, check_same_thread=False)
    con.execute("PRAGMA foreign_keys = ON;")
    return con
# Shared module-level connection; the helpers and UI callbacks below use it directly.
CONN: sqlite3.Connection = connect_db()
def init_progress_tables(con: sqlite3.Connection):
    """Create the trainer's bookkeeping tables if they are missing, then commit.

    - ``users``: one row per learner (``user_id`` primary key).
    - ``attempts``: one row per submitted answer.
    - ``session_meta``: at most one row (``id`` is constrained to 1) holding the
      active domain name and its schema JSON.
    """
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS users (
            user_id TEXT PRIMARY KEY,
            name TEXT,
            created_at TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS attempts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id TEXT,
            question_id TEXT,
            category TEXT,
            correct INTEGER,
            sql_text TEXT,
            timestamp TEXT,
            time_taken REAL,
            difficulty INTEGER,
            source TEXT,
            notes TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS session_meta (
            id INTEGER PRIMARY KEY CHECK (id=1),
            domain TEXT,
            schema_json TEXT
        )
        """,
    )
    cursor = con.cursor()
    for ddl in ddl_statements:
        cursor.execute(ddl)
    con.commit()
# Make sure the progress-tracking tables exist on the shared connection.
init_progress_tables(con=CONN)
# -------------------- Fallback dataset (if no OpenAI) --------------------
def _seed_rows(columns: List[Dict[str, str]], records: List[tuple]) -> List[Dict[str, Any]]:
    """Build row dicts by pairing each positional record with the column names, in order."""
    names = [col["name"] for col in columns]
    return [dict(zip(names, record)) for record in records]


_AUTHOR_COLUMNS = [
    {"name": "author_id", "type": "INTEGER"},
    {"name": "name", "type": "TEXT"},
    {"name": "country", "type": "TEXT"},
    {"name": "birth_year", "type": "INTEGER"},
]
_BOOKSTORE_COLUMNS = [
    {"name": "store_id", "type": "INTEGER"},
    {"name": "name", "type": "TEXT"},
    {"name": "city", "type": "TEXT"},
    {"name": "state", "type": "TEXT"},
]
_BOOK_COLUMNS = [
    {"name": "book_id", "type": "INTEGER"},
    {"name": "title", "type": "TEXT"},
    {"name": "author_id", "type": "INTEGER"},
    {"name": "store_id", "type": "INTEGER"},
    {"name": "category", "type": "TEXT"},
    {"name": "price", "type": "REAL"},
    {"name": "published_year", "type": "INTEGER"},
]

# Built-in bookstore domain. Author 8 has no books, on purpose, so LEFT JOIN
# questions have a zero-count row.
FALLBACK_SCHEMA = {
    "domain": "bookstore",
    "tables": [
        {
            "name": "authors",
            "pk": ["author_id"],
            "columns": _AUTHOR_COLUMNS,
            "fks": [],
            "rows": _seed_rows(_AUTHOR_COLUMNS, [
                (1, "Isaac Asimov", "USA", 1920),
                (2, "Ursula K. Le Guin", "USA", 1929),
                (3, "Haruki Murakami", "Japan", 1949),
                (4, "Chinua Achebe", "Nigeria", 1930),
                (5, "Jane Austen", "UK", 1775),
                (6, "J.K. Rowling", "UK", 1965),
                (7, "Yuval Noah Harari", "Israel", 1976),
                (8, "New Author", "Nowhere", 1990),
            ]),
        },
        {
            "name": "bookstores",
            "pk": ["store_id"],
            "columns": _BOOKSTORE_COLUMNS,
            "fks": [],
            "rows": _seed_rows(_BOOKSTORE_COLUMNS, [
                (1, "Downtown Books", "Oklahoma City", "OK"),
                (2, "Harbor Books", "Seattle", "WA"),
                (3, "Desert Pages", "Phoenix", "AZ"),
            ]),
        },
        {
            "name": "books",
            "pk": ["book_id"],
            "columns": _BOOK_COLUMNS,
            "fks": [
                {"columns": ["author_id"], "ref_table": "authors", "ref_columns": ["author_id"]},
                {"columns": ["store_id"], "ref_table": "bookstores", "ref_columns": ["store_id"]},
            ],
            "rows": _seed_rows(_BOOK_COLUMNS, [
                (101, "Foundation", 1, 1, "Sci-Fi", 14.99, 1951),
                (102, "I, Robot", 1, 1, "Sci-Fi", 12.50, 1950),
                (103, "The Left Hand of Darkness", 2, 2, "Sci-Fi", 16.00, 1969),
                (104, "A Wizard of Earthsea", 2, 2, "Fantasy", 11.50, 1968),
                (105, "Norwegian Wood", 3, 3, "Fiction", 18.00, 1987),
                (106, "Kafka on the Shore", 3, 1, "Fiction", 21.00, 2002),
                (107, "Things Fall Apart", 4, 1, "Fiction", 10.00, 1958),
                (108, "Pride and Prejudice", 5, 2, "Fiction", 9.00, 1813),
                (109, "Harry Potter and the Sorcerer's Stone", 6, 3, "Children", 22.00, 1997),
                (110, "Harry Potter and the Chamber of Secrets", 6, 3, "Children", 23.00, 1998),
                (111, "Sapiens", 7, 1, "History", 26.00, 2011),
                (112, "Homo Deus", 7, 2, "History", 28.00, 2015),
            ]),
        },
    ],
}
# Built-in question bank for FALLBACK_SCHEMA. Every category the trainer tracks
# has at least one question, including "Aggregation" (Q09), so adaptive
# selection can target it. Q03 spells "Sci-Fi" with an ASCII hyphen, matching
# the stored category value.
FALLBACK_QUESTIONS = [
    {
        "id": "Q01", "category": "SELECT *", "difficulty": 1,
        "prompt_md": "Select all rows and columns from `authors`.",
        "answer_sql": ["SELECT * FROM authors;"],
        "requires_aliases": False, "required_aliases": [],
    },
    {
        "id": "Q02", "category": "SELECT columns", "difficulty": 1,
        "prompt_md": "Show `title` and `price` from `books`.",
        "answer_sql": ["SELECT title, price FROM books;"],
        "requires_aliases": False, "required_aliases": [],
    },
    {
        "id": "Q03", "category": "WHERE", "difficulty": 1,
        "prompt_md": "List Sci-Fi books under $15 (show title, price).",
        "answer_sql": ["SELECT title, price FROM books WHERE category='Sci-Fi' AND price < 15;"],
        "requires_aliases": False, "required_aliases": [],
    },
    {
        "id": "Q04", "category": "Aliases", "difficulty": 1,
        "prompt_md": "Using aliases `b` and `a`, join `books` to `authors` and show `b.title` and `a.name` as `author_name`.",
        "answer_sql": ["SELECT b.title, a.name AS author_name FROM books b JOIN authors a ON b.author_id=a.author_id;"],
        "requires_aliases": True, "required_aliases": ["a", "b"],
    },
    {
        "id": "Q05", "category": "JOIN (INNER)", "difficulty": 2,
        "prompt_md": "Inner join `books` and `bookstores`. Return `title`, `name` as `store`.",
        "answer_sql": [
            "SELECT b.title, s.name AS store FROM books b INNER JOIN bookstores s ON b.store_id=s.store_id;"
        ],
        "requires_aliases": False, "required_aliases": [],
    },
    {
        "id": "Q06", "category": "JOIN (LEFT)", "difficulty": 2,
        "prompt_md": "List each author and their number of books (include authors with zero): columns `name`, `book_count`.",
        "answer_sql": [
            "SELECT a.name, COUNT(b.book_id) AS book_count FROM authors a LEFT JOIN books b ON a.author_id=b.author_id GROUP BY a.name;"
        ],
        "requires_aliases": False, "required_aliases": [],
    },
    {
        "id": "Q07", "category": "VIEW", "difficulty": 2,
        "prompt_md": "Create a view `vw_pricy` with `title`, `price` for books priced > 25.",
        "answer_sql": [
            "CREATE VIEW vw_pricy AS SELECT title, price FROM books WHERE price > 25;"
        ],
        "requires_aliases": False, "required_aliases": [],
    },
    {
        "id": "Q08", "category": "CTAS / SELECT INTO", "difficulty": 2,
        "prompt_md": "Create a table `cheap_books` containing books priced < 12. Use CTAS or SELECT INTO.",
        "answer_sql": [
            "CREATE TABLE cheap_books AS SELECT * FROM books WHERE price < 12;",
            "SELECT * INTO cheap_books FROM books WHERE price < 12;"
        ],
        "requires_aliases": False, "required_aliases": [],
    },
    {
        "id": "Q09", "category": "Aggregation", "difficulty": 2,
        "prompt_md": "For each `category` in `books`, show the category and its average price as `avg_price`.",
        "answer_sql": [
            "SELECT category, AVG(price) AS avg_price FROM books GROUP BY category;"
        ],
        "requires_aliases": False, "required_aliases": [],
    },
]
# -------------------- OpenAI prompts --------------------
# JSON Schema for OpenAI structured output. With "strict": True, the API only
# accepts a subset of JSON Schema:
#   - every object must list all of its properties in "required" and set
#     "additionalProperties": False;
#   - every array needs a typed "items" schema.
# Seed rows are therefore arrays of scalar values in column order, which is one
# of the two row shapes install_schema accepts.
_ROW_VALUE_SCHEMA = {
    "anyOf": [
        {"type": "string"},
        {"type": "number"},
        {"type": "boolean"},
        {"type": "null"},
    ]
}

DOMAIN_AND_QUESTIONS_SCHEMA = {
    "name": "DomainSQLPack",
    "schema": {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "domain": {"type": "string"},
            "tables": {
                "type": "array",
                "minItems": 3,
                "maxItems": 4,
                "items": {
                    "type": "object",
                    "additionalProperties": False,
                    "properties": {
                        "name": {"type": "string"},
                        "pk": {"type": "array", "items": {"type": "string"}},
                        "columns": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "additionalProperties": False,
                                "properties": {
                                    "name": {"type": "string"},
                                    "type": {"type": "string"},
                                },
                                "required": ["name", "type"],
                            },
                        },
                        "fks": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "additionalProperties": False,
                                "properties": {
                                    "columns": {"type": "array", "items": {"type": "string"}},
                                    "ref_table": {"type": "string"},
                                    "ref_columns": {"type": "array", "items": {"type": "string"}},
                                },
                                "required": ["columns", "ref_table", "ref_columns"],
                            },
                        },
                        "rows": {
                            "type": "array",
                            "description": "Seed rows; each row lists its values in the same order as the table's columns.",
                            "items": {"type": "array", "items": _ROW_VALUE_SCHEMA},
                        },
                    },
                    "required": ["name", "pk", "columns", "fks", "rows"],
                },
            },
            "questions": {
                "type": "array",
                "minItems": 8,
                "maxItems": 12,
                "items": {
                    "type": "object",
                    "additionalProperties": False,
                    "properties": {
                        "id": {"type": "string"},
                        "category": {"type": "string"},
                        "difficulty": {"type": "integer"},
                        "prompt_md": {"type": "string"},
                        "answer_sql": {"type": "array", "items": {"type": "string"}},
                        "requires_aliases": {"type": "boolean"},
                        "required_aliases": {"type": "array", "items": {"type": "string"}},
                    },
                    "required": [
                        "id", "category", "difficulty", "prompt_md", "answer_sql",
                        "requires_aliases", "required_aliases",
                    ],
                },
            },
        },
        "required": ["domain", "tables", "questions"],
    },
    "strict": True,
}
# Prompt sent to the model. It must keep the word "JSON" (the last line),
# because OpenAI's JSON mode rejects prompts without it. Ranges use ASCII
# hyphens so they cannot be mis-decoded.
DOMAIN_AND_QUESTIONS_PROMPT = """
You are designing a small relational dataset and training questions for SQL basics.
1) Choose ONE domain at random from:
- bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.
2) Produce exactly 3-4 tables that fit together (SQLite-friendly):
- Use snake_case, avoid reserved words.
- Do not name any table users, attempts or session_meta (those names are reserved by the trainer).
- Types: INTEGER, REAL, TEXT, NUMERIC, DATE (but no advanced features).
- Primary keys (pk) and foreign keys (fks) must align.
- Provide 8-15 small, realistic seed rows per table (not huge).
- Write each row as an array of values in the same order as the table's columns.
3) Generate 8-12 SQL questions covering basics with varied, natural language:
- Categories from: "SELECT *", "SELECT columns", "WHERE", "Aliases",
  "JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO".
- Include a few joins and at least one LEFT JOIN.
- Include one view creation.
- Include one table creation from SELECT (either CTAS or SELECT INTO).
- Prefer SQLite-compatible SQL. DO NOT use RIGHT/FULL OUTER JOIN.
- Offer 1-3 acceptable answer_sql variants per question.
- For 1-2 questions, require table aliases (set requires_aliases=true and list required_aliases).
- For every other question, set requires_aliases=false and required_aliases=[].
Return JSON only.
"""
def llm_generate_domain_and_questions() -> Optional[Dict[str,Any]]:
    """Ask the OpenAI API for a random domain pack and return it as a dict, or None.

    1. If USE_RESPONSES_API is set, call the Responses API with structured
       output. The JSON schema goes in ``text={"format": {...}}``;
       ``responses.create`` has no ``response_format`` parameter, so the old
       call raised on every run and the error was silently swallowed.
    2. If that path is disabled, fails, or returns no text, fall back to Chat
       Completions in JSON mode.

    Returns None when the client is unavailable, both API calls fail, the
    reply is empty, or the reply is not a JSON object.
    """
    if not OPENAI_AVAILABLE or _client is None:
        return None
    messages = [{"role": "user", "content": DOMAIN_AND_QUESTIONS_PROMPT}]
    data_text = None
    if USE_RESPONSES_API:
        try:
            resp = _client.responses.create(
                model=MODEL_ID,
                input=messages,
                text={"format": {"type": "json_schema", **DOMAIN_AND_QUESTIONS_SCHEMA}},
                temperature=0.6,
            )
            data_text = getattr(resp, "output_text", None)
        except Exception:
            # Fall through to the Chat Completions path below.
            data_text = None
    if not data_text:
        try:
            chat = _client.chat.completions.create(
                model=MODEL_ID,
                messages=messages,
                # JSON mode requires the word "JSON" in the prompt ("Return JSON only.").
                response_format={"type": "json_object"},
                temperature=0.6,
            )
            data_text = chat.choices[0].message.content
        except Exception:
            return None
    if not data_text:
        return None
    try:
        obj = json.loads(data_text)
    except (TypeError, ValueError):
        return None
    return obj if isinstance(obj, dict) else None
# -------------------- Schema install & question handling --------------------
def drop_existing_domain_tables(con: sqlite3.Connection, keep_internal=True):
    """Drop every user-created table and view in *con*, optionally sparing the trainer's tables.

    Why foreign keys are disabled: with ``PRAGMA foreign_keys = ON``, SQLite
    refuses to drop a parent table whose rows are still referenced by a child
    table. That error used to be swallowed, which left stale domain tables in
    place, so the next ``CREATE TABLE`` failed with "already exists". Foreign-key
    enforcement is therefore switched off while dropping and restored afterwards
    if it was on.

    Ordering and filtering:
    - views are dropped before tables;
    - SQLite's own ``sqlite_*`` tables are skipped, since they cannot be dropped;
    - a single failing DROP is ignored so the remaining objects are still removed.

    Args:
        con: open SQLite connection.
        keep_internal: when True, ``users``, ``attempts`` and ``session_meta`` are kept.
    """
    internal = {"users", "attempts", "session_meta"}
    con.commit()  # PRAGMA foreign_keys has no effect inside an open transaction.
    fk_row = con.execute("PRAGMA foreign_keys").fetchone()
    fk_was_on = bool(fk_row and fk_row[0])
    con.execute("PRAGMA foreign_keys = OFF")
    try:
        items = con.execute(
            "SELECT name, type FROM sqlite_master WHERE type IN ('table','view') "
            "ORDER BY CASE type WHEN 'view' THEN 0 ELSE 1 END"
        ).fetchall()
        for name, typ in items:
            if name.startswith("sqlite_"):
                continue
            if keep_internal and name in internal:
                continue
            quoted = '"' + name.replace('"', '""') + '"'
            try:
                con.execute(f"DROP {typ.upper()} IF EXISTS {quoted}")
            except sqlite3.Error:
                # Best effort: keep dropping the remaining objects.
                pass
        con.commit()
    finally:
        if fk_was_on:
            con.execute("PRAGMA foreign_keys = ON")
def install_schema(con: sqlite3.Connection, schema: Dict[str,Any]):
    """Replace the current domain tables with those described by *schema* and load its seed rows.

    Steps:
    1. Drop existing domain tables and views; the trainer's own tables are kept.
    2. Create every table that has at least one column. ``CREATE TABLE t ()``
       is invalid SQL, so column-less tables are skipped.
    3. Insert each table's rows. A row may be a dict keyed by column name, or a
       positional list/tuple that is padded with NULL or truncated to the
       column count. Dict/list cell values are stored as JSON text, because
       sqlite3 cannot bind them.
    4. Store the domain name and the schema JSON in ``session_meta`` (row id=1).

    Table and column identifiers are double-quoted because they may come from
    an LLM. Column types are used verbatim, defaulting to TEXT.
    """
    def quote(ident: Any) -> str:
        # Quote an identifier so unusual or reserved names cannot break the DDL.
        return '"' + str(ident).replace('"', '""') + '"'

    def cell(value: Any) -> Any:
        return json.dumps(value) if isinstance(value, (dict, list)) else value

    drop_existing_domain_tables(con, keep_internal=True)
    cur = con.cursor()
    tables = [t for t in schema.get("tables", []) if t.get("columns")]

    # Create tables first so that the inserts below can reference any of them.
    for t in tables:
        col_defs = [f"{quote(c['name'])} {c.get('type') or 'TEXT'}" for c in t["columns"]]
        pk = t.get("pk") or []
        if pk:
            col_defs.append(f"PRIMARY KEY ({', '.join(quote(p) for p in pk)})")
        cur.execute(f"CREATE TABLE {quote(t['name'])} ({', '.join(col_defs)})")

    for t in tables:
        cols = [c["name"] for c in t["columns"]]
        batch = []
        for r in t.get("rows") or []:
            if isinstance(r, dict):
                vals = [r.get(col) for col in cols]
            elif isinstance(r, (list, tuple)):
                vals = (list(r) + [None] * len(cols))[:len(cols)]
            else:
                continue
            batch.append([cell(v) for v in vals])
        if batch:
            insert_sql = (
                f"INSERT INTO {quote(t['name'])} ({', '.join(quote(c) for c in cols)}) "
                f"VALUES ({', '.join('?' * len(cols))})"
            )
            cur.executemany(insert_sql, batch)
    con.commit()

    # Persist the schema so the active domain can be inspected later.
    cur.execute("INSERT OR REPLACE INTO session_meta(id, domain, schema_json) VALUES (1, ?, ?)",
                (schema.get("domain", "unknown"), json.dumps(schema)))
    con.commit()
def run_df(con: sqlite3.Connection, sql: str) -> pd.DataFrame:
    """Execute *sql* on *con* and return the result set as a DataFrame."""
    frame = pd.read_sql_query(sql=sql, con=con)
    return frame
def rewrite_select_into(sql: str) -> Tuple[str, Optional[str]]:
    """Translate ``SELECT cols INTO new_table FROM ...`` into SQLite's CTAS form.

    Returns ``(rewritten_sql, new_table)`` when *sql* is a SELECT ... INTO
    statement; the rewritten SQL has outer whitespace and semicolons removed.
    For any other input, returns ``(sql, None)`` with *sql* unchanged.
    """
    body = sql.strip().strip(";")
    match = re.match(
        r"(?is)^\s*select\s+(.*?)\s+into\s+([A-Za-z_][A-Za-z0-9_]*)\s+from\s+(.*)$",
        body,
    )
    if match is None:
        return sql, None
    columns, table, remainder = match.groups()
    return f"CREATE TABLE {table} AS SELECT {columns} FROM {remainder}", table
def detect_unsupported_joins(sql: str) -> Optional[str]:
    """Return a hint when *sql* uses syntax this trainer treats as unavailable in SQLite, else None.

    Checks, in order: RIGHT [OUTER] JOIN, FULL [OUTER] JOIN, then ILIKE.
    Matching is case-insensitive and accepts any whitespace (newlines, tabs,
    repeated spaces) between keywords. The old space-delimited substring
    checks missed "RIGHT OUTER JOIN" and multi-line queries.
    NOTE(review): SQLite >= 3.39 does implement RIGHT/FULL JOIN; the trainer
    still steers learners away from them.
    """
    low = re.sub(r"\s+", " ", sql.lower())
    if re.search(r"\bright (?:outer )?join\b", low):
        return "SQLite does not support RIGHT JOIN. Use LEFT JOIN in the opposite direction."
    if re.search(r"\bfull (?:outer )?join\b", low):
        return "SQLite does not support FULL OUTER JOIN. Use LEFT JOIN plus UNION for the other side."
    if re.search(r"\bilike\b", low):
        return "SQLite has no ILIKE. Use `LOWER(col) LIKE LOWER('%pattern%')`."
    return None
def detect_cartesian(con: sqlite3.Connection, sql: str, df_result: pd.DataFrame) -> Optional[str]:
    """Heuristically warn when a query looks like an unintended cartesian product.

    Returns a warning string, or None when nothing suspicious is found:
    - an explicit CROSS JOIN always warns;
    - for ``FROM t1, t2`` (aliases allowed), or for a JOIN with no ON, USING
      or NATURAL, the row counts of the first two tables are read. If the
      result has exactly n1*n2 (> 0) rows, the query is reported as a likely
      cartesian product;
    - if the tables cannot be identified or counted, a generic warning is returned.

    All whitespace is normalised first so the checks also work on multi-line
    queries typed into the editor.
    """
    low = " " + re.sub(r"\s+", " ", sql.lower()) + " "
    if " cross join " in low:
        return "Query uses CROSS JOIN (cartesian product). Ensure this is intended."
    comma_from = re.search(r"\bfrom ([a-z_]\w*)(?: (?:as )?[a-z_]\w*)? ?, ?([a-z_]\w*)", low)
    has_join = re.search(r"\bjoin\b", low) is not None
    has_condition = re.search(r"\b(?:on|using|natural)\b", low) is not None
    missing_on = has_join and not has_condition
    if not (comma_from or missing_on):
        return None
    try:
        if comma_from:
            t1, t2 = comma_from.groups()
        else:
            m = re.search(r"\bfrom ([a-z_]\w*)", low)
            j = re.search(r"\bjoin ([a-z_]\w*)", low)
            if not m or not j:
                return "Possible cartesian product: no join condition detected."
            t1, t2 = m.group(1), j.group(1)
        cur = con.cursor()
        n1 = cur.execute(f'SELECT COUNT(*) FROM "{t1}"').fetchone()[0]
        n2 = cur.execute(f'SELECT COUNT(*) FROM "{t2}"').fetchone()[0]
    except Exception:
        return "Possible cartesian product: no join condition detected."
    prod = n1 * n2
    if prod > 0 and len(df_result) == prod:
        return f"Result row count equals {n1}×{n2}={prod}. Likely cartesian product (missing join)."
    return None
def _normalise_cell(value: Any) -> Any:
    """Map NumPy scalars to Python scalars and NULL/NaN to None, for comparison."""
    if isinstance(value, np.generic):
        value = value.item()
    if value is None or (isinstance(value, float) and value != value):
        return None
    return value


def _cell_sort_key(value: Any) -> Tuple[int, float, str]:
    """Key for ordering cells of mixed types: NULLs first, then numbers by value, then text."""
    if value is None:
        return (0, 0.0, "")
    if isinstance(value, (bool, int, float)):
        return (1, float(value), "")
    return (2, 0.0, str(value))


def _canonical_rows(df: pd.DataFrame) -> List[tuple]:
    """Return the frame's rows as normalised tuples in a deterministic, type-agnostic order."""
    rows = [tuple(_normalise_cell(v) for v in row) for row in df.itertuples(index=False, name=None)]
    rows.sort(key=lambda row: tuple(_cell_sort_key(v) for v in row))
    return rows


def results_equal(df_a: pd.DataFrame, df_b: pd.DataFrame) -> bool:
    """Return True when two result sets contain the same rows, ignoring row order.

    Rules:
    - shapes must match;
    - column names must match by position, compared case-insensitively;
    - values are compared after NULL/NaN become None and NumPy scalars become
      Python scalars, so an INTEGER result equals a REAL result with the same
      values.

    Unlike the previous ``sort_values`` + ``DataFrame.equals`` approach, this
    does not raise on duplicate column names (e.g. ``SELECT *`` over a join)
    or on columns that mix types.
    """
    if df_a.shape != df_b.shape:
        return False
    if [str(c).lower() for c in df_a.columns] != [str(c).lower() for c in df_b.columns]:
        return False
    return _canonical_rows(df_a) == _canonical_rows(df_b)
def aliases_present(sql: str, required_aliases: List[str]) -> bool:
    """Return True when every alias in *required_aliases* appears in *sql*.

    An alias counts as present when it is:
    - used as a qualifier (``b.title``, ``=a.id``, ``,a.name``);
    - declared with AS (``... AS b``); or
    - declared right after a table name in FROM/JOIN (``FROM books b``).

    Matching is case-insensitive and does not depend on surrounding
    whitespace or punctuation. An empty or None list is trivially satisfied.
    """
    low = sql.lower()
    for alias in required_aliases or []:
        a = re.escape(str(alias).lower())
        patterns = (
            rf"(?<![\w.]){a}\s*\.",                      # qualifier: a.col
            rf"\bas\s+{a}\b",                            # AS a
            rf"\b(?:from|join)\s+[a-z_]\w*\s+{a}\b",    # FROM table a / JOIN table a
        )
        if not any(re.search(p, low) for p in patterns):
            return False
    return True
# -------------------- Question model --------------------
@dataclass
class SQLQuestion:
    """Typed record for one training question, with the same keys as the question dicts.

    The ``@dataclass`` decorator was missing, so the annotated fields
    generated no ``__init__``. ``required_aliases`` defaults to a fresh empty
    list, and an explicit None is normalised to ``[]``.
    """
    id: str
    category: str
    difficulty: int
    prompt_md: str
    answer_sql: List[str]
    requires_aliases: bool = False
    required_aliases: Optional[List[str]] = None

    def __post_init__(self):
        if self.required_aliases is None:
            self.required_aliases = []
def to_question_dict(q) -> Dict[str,Any]:
    """Return a shallow copy of question mapping *q* with the alias fields filled in.

    A missing ``requires_aliases`` becomes False and a missing
    ``required_aliases`` becomes a new empty list. Keys already present are
    left untouched, and added keys go at the end.
    """
    out = dict(q)
    for key, default in (("requires_aliases", False), ("required_aliases", [])):
        if key not in out:
            out[key] = default
    return out
def load_questions(obj_list: List[Dict[str,Any]]) -> List[Dict[str,Any]]:
    """Normalise raw question mappings with ``to_question_dict``, preserving their order."""
    return [to_question_dict(raw) for raw in obj_list]
# -------------------- Domain bootstrap --------------------
# Answers using these joins are discarded. This tolerates OUTER and any
# whitespace, which the old " right join " / " full " substring tests did not.
_UNSUPPORTED_ANSWER_JOIN_RE = re.compile(r"\b(?:right|full)\s+(?:outer\s+)?join\b", re.IGNORECASE)


def bootstrap_domain_with_llm_or_fallback() -> Tuple[Dict[str,Any], List[Dict[str,Any]]]:
    """Return ``(schema, questions)`` from the LLM, or the built-in bookstore pack.

    LLM output is sanitised question by question:
    - questions missing an id, category or prompt are skipped;
    - answer variants that are not non-empty strings, or that use
      RIGHT/FULL [OUTER] JOIN, are dropped, and a question left with no
      answers is skipped;
    - ``difficulty`` is coerced to int (1 if it cannot be converted),
      ``requires_aliases`` to bool, and a missing or non-list
      ``required_aliases`` becomes ``[]``.

    The fallback pack is returned when the LLM is unavailable, returns no
    tables, or leaves no usable questions. An empty question bank would
    otherwise make question selection fail later.
    """
    obj = llm_generate_domain_and_questions()
    if not isinstance(obj, dict) or not obj.get("tables"):
        return FALLBACK_SCHEMA, FALLBACK_QUESTIONS
    clean_qs = []
    for q in obj.get("questions") or []:
        if not isinstance(q, dict) or not q.get("id") or not q.get("category") or not q.get("prompt_md"):
            continue
        answers = [
            a for a in (q.get("answer_sql") or [])
            if isinstance(a, str) and a.strip() and not _UNSUPPORTED_ANSWER_JOIN_RE.search(a)
        ]
        if not answers:
            continue
        q["answer_sql"] = answers
        try:
            q["difficulty"] = int(q.get("difficulty", 1))
        except (TypeError, ValueError):
            q["difficulty"] = 1
        q["requires_aliases"] = bool(q.get("requires_aliases", False))
        if not isinstance(q.get("required_aliases"), list):
            q["required_aliases"] = []
        clean_qs.append(q)
    if not clean_qs:
        return FALLBACK_SCHEMA, FALLBACK_QUESTIONS
    obj["questions"] = clean_qs
    return obj, clean_qs
def install_new_domain():
    """Choose a domain pack (LLM or fallback), install it into CONN and return ``(schema, questions)``.

    Generated data can be rejected by SQLite or can be malformed, for example:
    - an invalid column type;
    - seed rows that violate a constraint;
    - a table name that clashes with the trainer's own tables;
    - missing keys.

    When installing an LLM-generated schema fails for any of these reasons,
    the failed transaction is rolled back and the built-in fallback pack is
    installed instead, so module start-up does not crash. A failure while
    installing the fallback itself is re-raised.
    """
    schema, questions = bootstrap_domain_with_llm_or_fallback()
    try:
        install_schema(CONN, schema)
    except (sqlite3.Error, KeyError, TypeError, ValueError, AttributeError):
        if schema is FALLBACK_SCHEMA:
            raise
        CONN.rollback()
        schema, questions = FALLBACK_SCHEMA, FALLBACK_QUESTIONS
        install_schema(CONN, schema)
    return schema, questions
# -------------------- Session state --------------------
# One domain pack is installed at import time and shared by every learner session.
(CURRENT_SCHEMA, CURRENT_QS) = install_new_domain()
# -------------------- Progress + mastery --------------------
def upsert_user(con: sqlite3.Connection, user_id: str, name: str):
    """Ensure a ``users`` row exists for *user_id* and that its name is *name*; commit.

    A new row gets ``created_at`` set to the current UTC time (ISO-8601).
    For an existing row, only the name is updated and ``created_at`` is kept.
    """
    created_at = datetime.now(timezone.utc).isoformat()
    con.execute(
        "INSERT OR IGNORE INTO users (user_id, name, created_at) VALUES (?, ?, ?)",
        (user_id, name, created_at),
    )
    con.execute("UPDATE users SET name=? WHERE user_id=?", (name, user_id))
    con.commit()
# Category labels in display order; topic_stats reports one row per entry.
CATEGORIES_ORDER = [
    "SELECT *",
    "SELECT columns",
    "WHERE",
    "Aliases",
    "JOIN (INNER)",
    "JOIN (LEFT)",
    "Aggregation",
    "VIEW",
    "CTAS / SELECT INTO",
]
def topic_stats(df_attempts: pd.DataFrame) -> pd.DataFrame:
    """Summarise attempts per category, one row per entry of CATEGORIES_ORDER.

    Columns:
    - ``category``;
    - ``attempts``: number of attempts;
    - ``correct``: sum of the ``correct`` flags;
    - ``accuracy``: correct / attempts, or 0.0 when there are no attempts.
    """
    records = []
    for cat in CATEGORIES_ORDER:
        subset = pd.DataFrame() if df_attempts.empty else df_attempts[df_attempts["category"] == cat]
        if subset.empty:
            attempts, correct = 0, 0
        else:
            attempts = int(len(subset))
            correct = int(subset["correct"].sum())
        records.append({
            "category": cat,
            "attempts": attempts,
            "correct": correct,
            "accuracy": float(correct / max(attempts, 1)),
        })
    return pd.DataFrame(records)
def fetch_attempts(con: sqlite3.Connection, user_id: str) -> pd.DataFrame:
    """Return every attempt recorded for *user_id*, newest first (highest id first)."""
    query = "SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC"
    return pd.read_sql_query(query, con, params=(user_id,))
def pick_next_question(user_id: str) -> Dict[str,Any]:
    """Choose the next question for *user_id* from their weakest category that has questions.

    Categories are ranked by lowest accuracy, then fewest attempts, using
    topic_stats over the user's attempts. Only categories that actually occur
    in CURRENT_QS are ranked. Previously a category with no questions in the
    bank stayed at 0 attempts / 0% forever, always ranked weakest, and turned
    selection into a uniform random pick. If no ranked category remains (e.g.
    the bank uses labels outside CATEGORIES_ORDER), the whole bank is used.

    Returns a shallow copy of the chosen question dict.
    """
    available = {q["category"] for q in CURRENT_QS}
    stats = topic_stats(fetch_attempts(CONN, user_id))
    stats = stats[stats["category"].isin(available)]
    stats = stats.sort_values(by=["accuracy", "attempts"], ascending=[True, True])
    weakest = stats.iloc[0]["category"] if not stats.empty else None
    cands = [q for q in CURRENT_QS if q["category"] == weakest] or CURRENT_QS
    return dict(random.choice(cands))
# -------------------- Execution & feedback --------------------
_INTERNAL_TABLES_RE = re.compile(r"\b(?:users|attempts|session_meta)\b", re.IGNORECASE)
# Write keywords; REPLACE followed by "(" is the string function, not a statement.
_WRITE_KEYWORD_RE = re.compile(r"\b(?:drop|delete|update|insert|alter)\b|\breplace\b(?!\s*\()", re.IGNORECASE)
_LEADING_COMMENTS_RE = re.compile(r"^(?:\s*(?:--[^\n]*|/\*.*?\*/))*\s*", re.DOTALL)
_CREATE_VIEW_RE = re.compile(
    r"^create\s+view\s+(?:if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+(?:select|with)\b",
    re.IGNORECASE | re.DOTALL,
)
_CREATE_TABLE_AS_RE = re.compile(
    r"^create\s+table\s+(?:if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+(?:select|with)\b",
    re.IGNORECASE | re.DOTALL,
)


def _drop_previous_object(name: str, kind: str) -> None:
    """Drop an earlier learner-created VIEW/TABLE *name* so that a CREATE task can be retried.

    Never touches the trainer's tables or the current domain's base tables.
    Errors (e.g. *name* belongs to an object of the other kind) are ignored;
    the learner's CREATE statement then reports the problem itself.
    """
    protected = {"users", "attempts", "session_meta"}
    protected |= {str(t.get("name", "")).lower() for t in CURRENT_SCHEMA.get("tables", [])}
    if name.lower() in protected:
        return
    try:
        CONN.execute(f'DROP {kind} IF EXISTS "{name}"')
        CONN.commit()
    except sqlite3.Error:
        pass


def _preview_object(name: str, failure_msg: str) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """Return ``(rows of *name*, None)``, or ``(None, failure_msg)`` if it cannot be queried."""
    try:
        return run_df(CONN, f"SELECT * FROM {name}"), None
    except Exception:
        return None, failure_msg


def _explain_sql_error(msg: str) -> str:
    """Turn a raw SQLite/pandas error message into learner-facing feedback."""
    low = msg.lower()
    if "no such table" in low:
        return f"{msg}. Check table names for this randomized domain."
    if "no such column" in low:
        return f"{msg}. Use correct column names or prefixes (alias.column)."
    if "ambiguous column name" in low:
        return f"{msg}. Qualify the column with a table alias."
    if "misuse of aggregate" in low or "aggregate functions are not allowed in" in low:
        return f"{msg}. You might need a GROUP BY for non-aggregated columns."
    if 'near "into"' in low and "syntax error" in low:
        # Reached only when the automatic SELECT ... INTO rewrite did not apply.
        return "SQLite doesn’t support `SELECT ... INTO`. Use `CREATE TABLE new_table AS SELECT ...` instead."
    if "syntax error" in low:
        return f"Syntax error. Check commas, keywords, and parentheses. Raw error: {msg}"
    return f"SQL error: {msg}"


def exec_student_sql(sql_text: str) -> Tuple[Optional[pd.DataFrame], Optional[str], Optional[str], Optional[str]]:
    """Run a learner's SQL against the current domain and package the outcome.

    Returns ``(df, error, warning, note)``:
    - ``df``: the result rows for SELECT/WITH queries; a preview of the new
      object for CREATE VIEW / CREATE TABLE ... AS; an empty DataFrame for
      other successful statements; None on error.
    - ``error``: a learner-facing message when the statement was rejected or failed.
    - ``warning``: a cartesian-product hint for queries, from detect_cartesian.
    - ``note``: set when ``SELECT ... INTO`` was rewritten to CTAS.

    Safeguards:
    - Statements that write to the trainer's progress tables (users,
      attempts, session_meta) are refused.
    - Re-running a CREATE VIEW / CREATE TABLE ... AS task first drops the
      learner's previous object of the same name (domain base tables are
      never dropped), so the task can be retried instead of failing with
      "already exists".
    - Leading SQL comments are ignored when classifying the statement.
    """
    if not sql_text or not sql_text.strip():
        return None, "Enter a SQL statement.", None, None
    sql_raw = sql_text.strip().rstrip(";")
    sql_rew, created_tbl = rewrite_select_into(sql_raw)
    note = None
    if sql_rew != sql_raw:
        note = "Rewrote `SELECT ... INTO` to `CREATE TABLE ... AS SELECT ...` for SQLite."
    unsup = detect_unsupported_joins(sql_rew)
    if unsup:
        return None, unsup, None, note
    if _WRITE_KEYWORD_RE.search(sql_rew) and _INTERNAL_TABLES_RE.search(sql_rew):
        return (None,
                "This statement would change the trainer's progress tables (users, attempts, "
                "session_meta), which is not allowed. Work with the practice domain tables instead.",
                None, note)
    head = _LEADING_COMMENTS_RE.sub("", sql_rew, count=1)
    low = head.lower()
    try:
        if low.startswith(("select", "with")):
            df = run_df(CONN, sql_rew)
            return df, None, detect_cartesian(CONN, sql_rew, df), note
        view_m = _CREATE_VIEW_RE.match(head)
        table_m = None if view_m else _CREATE_TABLE_AS_RE.match(head)
        if view_m:
            _drop_previous_object(view_m.group(1), "VIEW")
        elif table_m:
            _drop_previous_object(table_m.group(1), "TABLE")
        cur = CONN.cursor()
        cur.execute(sql_rew)
        CONN.commit()
        # Preview newly created objects so the learner sees the data.
        if view_m:
            df, err = _preview_object(view_m.group(1), "View created but could not be queried.")
            return df, err, None, note
        tbl = created_tbl or (table_m.group(1) if table_m else None)
        if tbl:
            df, err = _preview_object(tbl, "Table created but could not be queried.")
            return df, err, None, note
        return pd.DataFrame(), None, None, note
    except Exception as e:
        return None, _explain_sql_error(str(e)), None, note
_CANONICAL_DDL_RE = re.compile(
    r"^\s*create\s+(?:view|table)\s+(?:if\s+not\s+exists\s+)?[A-Za-z_]\w*\s+as\s+((?:select|with)\b.*)$",
    re.IGNORECASE | re.DOTALL,
)


def answer_df(answer_sql: List[str]) -> Optional[pd.DataFrame]:
    """Return the canonical result set for a question, or None if no answer variant can be evaluated.

    Variants are tried in order, and the first one that runs wins:
    - SELECT / WITH queries are run directly;
    - ``SELECT ... INTO t FROM ...`` is first rewritten to CTAS;
    - for ``CREATE VIEW v AS <query>`` and ``CREATE TABLE t AS <query>``,
      only ``<query>`` is run.

    Nothing is created or dropped. The previous implementation dropped and
    recreated the named object, which replaced the learner's own object and
    would destroy a domain table if an answer reused its name. Other
    statements, and variants that fail to run, are skipped.
    """
    for sql in answer_sql or []:
        if not isinstance(sql, str):
            continue
        stmt, _ = rewrite_select_into(sql.strip().rstrip(";"))
        stmt = stmt.strip().rstrip(";")
        low = stmt.lower()
        try:
            if low.startswith(("select", "with")):
                return run_df(CONN, stmt)
            m = _CANONICAL_DDL_RE.match(stmt)
            if m:
                return run_df(CONN, m.group(1))
        except Exception:
            continue
    return None
def validate_answer(q: Dict[str,Any], student_sql: str, df_student: Optional[pd.DataFrame]) -> Tuple[bool, str]:
    """Judge a learner's result against the question's canonical answer.

    Returns ``(is_correct, explanation_markdown)``. When no canonical result
    can be computed (answer_df returns None), any successful execution
    (``df_student`` is not None) is accepted. ``student_sql`` is part of the
    interface but is not used here.
    """
    expected = answer_df(q["answer_sql"])
    if expected is None:
        return df_student is not None, "**Explanation:** Your statement executed successfully for this task."
    if df_student is None:
        return False, "**Explanation:** Expected data result differs."
    verdict = results_equal(df_student, expected)
    return verdict, "**Explanation:** Compare your result to a canonical solution."
def log_attempt(user_id: str, qid: str, category: str, correct: bool, sql_text: str,
                time_taken: float, difficulty: int, source: str, notes: str):
    """Append one row to ``attempts``, timestamped now in UTC (ISO-8601), and commit."""
    row = (
        user_id, qid, category, int(correct), sql_text,
        datetime.now(timezone.utc).isoformat(),
        time_taken, difficulty, source, notes,
    )
    CONN.execute("""
        INSERT INTO attempts (user_id, question_id, category, correct, sql_text, timestamp, time_taken, difficulty, source, notes)
        VALUES (?,?,?,?,?,?,?,?,?,?)
    """, row)
    CONN.commit()
# -------------------- UI callbacks --------------------
def start_session(name: str, session: dict):
    """Start (or resume) a learner session and show its first question.

    Returns an 8-tuple of Gradio outputs, in order:
    session state, question markdown, SQL input update, preview block update,
    ERD image, "next" button update, per-topic stats, and an empty results frame.
    With a blank name, only the "enter your name" prompt is shown, and the
    session is returned unchanged.
    """
    display_name = (name or "").strip()
    if not display_name:
        return (
            session,
            gr.update(value="Please enter your name to begin.", visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            None,
            gr.update(visible=False),
            pd.DataFrame(),
            pd.DataFrame(),
        )
    slug = "-".join(display_name.lower().split())
    user_id = slug[:64] if slug else f"user-{int(time.time())}"
    upsert_user(CONN, user_id, display_name)
    question = pick_next_question(user_id)
    new_session = {
        "user_id": user_id,
        "name": display_name,
        "qid": question["id"],
        "start_ts": time.time(),
        "q": question,
    }
    stats = topic_stats(fetch_attempts(CONN, user_id))
    erd = draw_dynamic_erd(CURRENT_SCHEMA)
    question_md = f"**Question {question['id']}**\n\n{question['prompt_md']}"
    return (
        new_session,
        gr.update(value=question_md, visible=True),
        gr.update(visible=True),             # show SQL input
        gr.update(value="", visible=True),   # preview block
        erd,
        gr.update(visible=False),            # next button stays hidden until submit
        stats,
        pd.DataFrame(),
    )
def render_preview_and_erd(sql_text: str, session: dict):
    """Return ``(preview markdown update, ERD image)`` for the live preview pane.

    The preview stays hidden until the session has a current question and the
    SQL input is non-blank. The ERD is always redrawn from CURRENT_SCHEMA.
    """
    erd = draw_dynamic_erd(CURRENT_SCHEMA)
    has_question = bool(session) and "q" in session
    statement = (sql_text or "").strip() if has_question else ""
    if not statement:
        return gr.update(value="", visible=False), erd
    return gr.update(value=f"**Preview:**\n\n```sql\n{statement}\n```", visible=True), erd
def submit_answer(sql_text: str, session: dict):
    """Run, grade and log the learner's answer to the current question.

    Returns a 4-tuple of Gradio outputs: feedback markdown update, result
    DataFrame, "next" button update, and refreshed per-topic stats.

    - Statements that fail to run are logged as incorrect, with the error and
      any notes joined into the ``notes`` column.
    - The alias requirement is advisory: a missing alias adds a warning but
      does not change correctness.

    Feedback uses real emoji and punctuation (✅ ❌ ⚠️ ℹ️ ’); the previous
    literals were mis-encoded. A ``None`` ``required_aliases`` is treated as
    an empty list.
    """
    if not session or "user_id" not in session or "q" not in session:
        return gr.update(value="Start a session first.", visible=True), pd.DataFrame(), gr.update(visible=False), pd.DataFrame()
    user_id = session["user_id"]
    q = session["q"]
    elapsed = max(0.0, time.time() - session.get("start_ts", time.time()))
    df, err, warn, note = exec_student_sql(sql_text)
    details = []
    if note:
        details.append(f"ℹ️ {note}")
    if err:
        fb = f"❌ **Did not run**\n\n{err}"
        if details:
            fb += "\n\n" + "\n".join(details)
        log_attempt(user_id, q["id"], q["category"], False, sql_text, elapsed,
                    int(q["difficulty"]), "bank", " | ".join([err] + details))
        stats = topic_stats(fetch_attempts(CONN, user_id))
        return gr.update(value=fb, visible=True), pd.DataFrame(), gr.update(visible=True), stats

    # Validate correctness; the alias check only contributes a warning.
    required_aliases = q.get("required_aliases") or []
    alias_msg = None
    if q.get("requires_aliases") and not aliases_present(sql_text, required_aliases):
        alias_msg = f"⚠️ This task asked for aliases {required_aliases}. I didn’t detect them."
    is_correct, explanation = validate_answer(q, sql_text, df)
    if warn:
        details.append(f"⚠️ {warn}")
    if alias_msg:
        details.append(alias_msg)
    feedback = "✅ **Correct!**" if is_correct else "❌ **Not quite.**"
    if details:
        feedback += "\n\n" + "\n".join(details)
    solution = q["answer_sql"][0].rstrip(";")
    feedback += "\n\n" + explanation + "\n\n**One acceptable solution:**\n```sql\n" + solution + ";\n```"
    log_attempt(user_id, q["id"], q["category"], bool(is_correct), sql_text, elapsed,
                int(q["difficulty"]), "bank", " | ".join(details))
    stats = topic_stats(fetch_attempts(CONN, user_id))
    return gr.update(value=feedback, visible=True), (df if df is not None else pd.DataFrame()), gr.update(visible=True), stats
def next_question(session: dict):
    """Move the session to a newly picked question and reset the answer box.

    Args:
        session: Per-browser session dict held in ``gr.State``. It is updated in
            place with the new ``qid``, ``q`` and ``start_ts``.

    Returns:
        A 5-tuple matching the ``next_btn`` outputs: the session, the prompt
        Markdown update, the SQL textbox update, the ER diagram and the Next
        button update (hidden again until the next submit).
    """
    # The initial gr.State value has "user_id": None, so check the value rather
    # than key membership; otherwise pick_next_question(None) would be called.
    if not session or not session.get("user_id"):
        return session, gr.update(value="Start a session first.", visible=True), gr.update(visible=False), draw_dynamic_erd(CURRENT_SCHEMA), gr.update(visible=False)
    user_id = session["user_id"]
    q = pick_next_question(user_id)
    session["qid"] = q["id"]
    session["q"] = q
    session["start_ts"] = time.time()
    return session, gr.update(value=f"**Question {q['id']}**\n\n{q['prompt_md']}", visible=True), gr.update(value="", visible=True), draw_dynamic_erd(CURRENT_SCHEMA), gr.update(visible=False)
def show_hint(session: dict):
    """Return a short, category-specific hint for the active question.

    Args:
        session: Per-browser session dict held in ``gr.State``.

    Returns:
        A Markdown update carrying the hint, or a "start first" message when no
        question is active.
    """
    # The Hint button is visible before any session starts, and the initial state
    # holds "q": None. Testing key membership would let that through and raise
    # TypeError on None["category"].
    if not session or not session.get("q"):
        return gr.update(value="Start a session first.", visible=True)
    # Lightweight hint policy: category-specific guidance with a generic fallback.
    cat = session["q"].get("category")
    hint = {
        "SELECT *": "Use `SELECT * FROM table_name`.",
        "SELECT columns": "List columns: `SELECT col1, col2 FROM table_name`.",
        "WHERE": "Filter with `WHERE` and combine conditions using AND/OR.",
        "Aliases": "Use `table_name t` and qualify: `t.col`.",
        "JOIN (INNER)": "Join with `... INNER JOIN ... ON left.key = right.key`.",
        "JOIN (LEFT)": "LEFT JOIN keeps all rows from the left table.",
        "Aggregation": "Use aggregate functions and `GROUP BY` non-aggregated columns.",
        "VIEW": "`CREATE VIEW view_name AS SELECT ...`.",
        "CTAS / SELECT INTO": "SQLite uses `CREATE TABLE name AS SELECT ...`.",
    }.get(cat, "Read the ER diagram and identify keys to join on.")
    return gr.update(value=f"**Hint:** {hint}", visible=True)
def export_progress(user_name: str):
    """Export a student's attempt history to a CSV file and return its path.

    Args:
        user_name: Free-text student name typed into the instructor panel.

    Returns:
        Absolute path of the written CSV, or None when the name is blank.

    The DB user id is built by lower-casing the name, joining its whitespace-
    separated words with "-" and truncating to 64 characters. Presumably this
    mirrors how start_session derives ids; verify against that function.
    """
    slug = "-".join((user_name or "").lower().split())
    if not slug:
        return None
    user_id = slug[:64]
    df = fetch_attempts(CONN, user_id)
    os.makedirs(EXPORT_DIR, exist_ok=True)
    # The name is untrusted text. Replace path separators and other unsafe
    # characters so the file cannot be written outside EXPORT_DIR (e.g. "../x").
    # The raw user_id is still used for the DB lookup above.
    stem = re.sub(r"[^\w.-]", "_", user_id)
    path = os.path.abspath(os.path.join(EXPORT_DIR, f"{stem}_progress.csv"))
    (pd.DataFrame([{"info": "No attempts yet."}]) if df.empty else df).to_csv(path, index=False)
    return path
def regenerate_domain():
    """Install a newly randomized domain and redraw the ER diagram.

    Rebinds the module-level ``CURRENT_SCHEMA`` and ``CURRENT_QS`` to the pair
    returned by ``install_new_domain()``.

    Returns:
        ``(status_update, erd_image)`` for the regen feedback Markdown and the
        ER diagram image.
    """
    global CURRENT_SCHEMA, CURRENT_QS
    CURRENT_SCHEMA, CURRENT_QS = install_new_domain()
    erd = draw_dynamic_erd(CURRENT_SCHEMA)
    # The status text previously showed a mis-decoded emoji ("β").
    return gr.update(value="✅ Domain regenerated.", visible=True), erd
def preview_table(tbl: str):
    """Return up to 20 rows from the named table or view.

    Args:
        tbl: Table/view name chosen in the preview dropdown (may be None, or the
            "(no tables)" placeholder produced by the table lister).

    Returns:
        A DataFrame of rows. On failure, or when nothing is selected, a one-row
        frame with an ``error`` column.
    """
    if not tbl or tbl == "(no tables)":
        return pd.DataFrame([{"error": "No table or view selected."}])
    # Quote the name as an SQL identifier (doubling embedded quotes) instead of
    # splicing raw text into the statement.
    quoted = '"' + str(tbl).replace('"', '""') + '"'
    try:
        return run_df(CONN, f"SELECT * FROM {quoted} LIMIT 20")
    except Exception as e:
        return pd.DataFrame([{"error": str(e)}])
def list_tables_for_preview():
    """Return names of the tables and views a student may preview.

    Excludes the app's bookkeeping tables (users, attempts, session_meta) and
    SQLite's reserved ``sqlite_*`` internal objects such as ``sqlite_sequence``.
    Tables are listed before views, each group sorted by name.

    Returns:
        A list of names, or ``["(no tables)"]`` so the dropdown is never empty.
    """
    df = run_df(
        CONN,
        "SELECT name, type FROM sqlite_master "
        "WHERE type IN ('table','view') "
        "AND name NOT IN ('users','attempts','session_meta') "
        # "_" is a LIKE wildcard, so escape it to match the literal "sqlite_" prefix.
        r"AND name NOT LIKE 'sqlite\_%' ESCAPE '\' "
        "ORDER BY type, name",
    )
    if df.empty:
        return ["(no tables)"]
    return df["name"].tolist()
# -------------------- UI --------------------
def _refresh_after_regen(session: dict):
    """Refresh UI pieces that depend on the domain just installed by regenerate_domain.

    Returns updates for (session_state, prompt_md, sql_input, er_image, next_btn, tbl_dd).
    The preview dropdown always receives the new table list, and its old selection is
    cleared because that table may no longer exist. An active session is moved to a
    fresh question, because the one it holds was drawn from the previous domain.
    """
    tables = gr.update(choices=list_tables_for_preview(), value=None)
    if session and session.get("user_id") and session.get("q"):
        new_session, prompt_upd, sql_upd, erd, next_upd = next_question(session)
        return new_session, prompt_upd, sql_upd, erd, next_upd, tables
    # No active question: leave the session and question widgets untouched.
    return session, gr.update(), gr.update(), gr.update(), gr.update(), tables


with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
    gr.Markdown(
        """
# 🧪 Adaptive SQL Trainer — Randomized Domains (SQLite)
- Uses **OpenAI** (if configured) to randomize a domain (bookstore, retail sales, wholesaler,
sales tax, oil & gas wells, marketing), generate **3–4 tables** and **8–12** questions.
- Practice `SELECT`, `WHERE`, `JOIN` (INNER/LEFT), **aliases**, **views**, and **CTAS / SELECT INTO**.
- The app explains **SQLite quirks** (no RIGHT/FULL JOIN) and flags likely **cartesian products**.
> Set your `OPENAI_API_KEY` in the Space secrets to enable randomization.
"""
    )
    with gr.Row():
        # -------- Left column: controls + quick preview ----------
        with gr.Column(scale=1):
            name_box = gr.Textbox(label="Your Name", placeholder="e.g., Jordan Alvarez")
            start_btn = gr.Button("Start / Resume Session", variant="primary")
            session_state = gr.State({"user_id": None, "name": None, "qid": None, "start_ts": None, "q": None})
            gr.Markdown("---")
            gr.Markdown("### Dataset Controls")
            regen_btn = gr.Button("🔄 Randomize Dataset (OpenAI)")
            regen_fb = gr.Markdown(visible=False)
            gr.Markdown("---")
            gr.Markdown("### Instructor Tools")
            export_name = gr.Textbox(label="Export a student's progress (enter name)")
            export_btn = gr.Button("Export CSV")
            export_file = gr.File(label="Download progress")
            gr.Markdown("---")
            gr.Markdown("### Quick Table/View Preview (top 20 rows)")
            tbl_dd = gr.Dropdown(choices=list_tables_for_preview(), label="Pick table/view", interactive=True)
            tbl_btn = gr.Button("Preview")
            preview_df = gr.Dataframe(value=pd.DataFrame(), interactive=False)
        # -------- Right column: task + feedback + mastery + results ----------
        with gr.Column(scale=2):
            prompt_md = gr.Markdown(visible=False)
            sql_input = gr.Textbox(label="Your SQL", placeholder="Type SQL here (end ; optional).", lines=6, visible=False)
            preview_md = gr.Markdown(visible=False)
            er_image = gr.Image(label="Entity Diagram", value=draw_dynamic_erd(CURRENT_SCHEMA), height=PLOT_HEIGHT)
            with gr.Row():
                submit_btn = gr.Button("Run & Submit", variant="primary")
                hint_btn = gr.Button("Hint")
                next_btn = gr.Button("Next Question ▶", visible=False)
            feedback_md = gr.Markdown("")
            gr.Markdown("---")
            gr.Markdown("### Your Progress by Category")
            mastery_df = gr.Dataframe(
                headers=["category", "attempts", "correct", "accuracy"],
                col_count=(4, "dynamic"),
                row_count=(0, "dynamic"),
                interactive=False,
            )
            gr.Markdown("---")
            gr.Markdown("### Result Preview")
            result_df = gr.Dataframe(value=pd.DataFrame(), interactive=False)

    # Wire events
    start_btn.click(
        start_session,
        inputs=[name_box, session_state],
        outputs=[session_state, prompt_md, sql_input, preview_md, er_image, next_btn, mastery_df, result_df],
    )
    sql_input.change(
        render_preview_and_erd,
        inputs=[sql_input, session_state],
        outputs=[preview_md, er_image],
    )
    submit_btn.click(
        submit_answer,
        inputs=[sql_input, session_state],
        outputs=[feedback_md, result_df, next_btn, mastery_df],
    )
    next_btn.click(
        next_question,
        inputs=[session_state],
        outputs=[session_state, prompt_md, sql_input, er_image, next_btn],
    )
    hint_btn.click(
        show_hint,
        inputs=[session_state],
        outputs=[feedback_md],
    )
    export_btn.click(
        export_progress,
        inputs=[export_name],
        outputs=[export_file],
    )
    # Chain the refresh with .then() so it runs only after regenerate_domain has
    # installed the new schema. Two independent listeners on the same click can run
    # concurrently, and the dropdown could then list the old domain's tables.
    regen_btn.click(
        regenerate_domain,
        inputs=[],
        outputs=[regen_fb, er_image],
    ).then(
        _refresh_after_regen,
        inputs=[session_state],
        outputs=[session_state, prompt_md, sql_input, er_image, next_btn, tbl_dd],
    )
    tbl_btn.click(
        preview_table,
        inputs=[tbl_dd],
        outputs=[preview_df],
    )

if __name__ == "__main__":
    demo.launch()