| """ |
| data_engine.py — DuckDB lifecycle, schema introspection, and safe query execution. |
| |
| Handles: |
| - Creating per-request in-memory DuckDB connections (thread-safe) |
| - Seeding schema + data (from seed.sql or programmatically) |
| - Schema introspection for prompt context |
| - extract_sql(): JSON envelope → ```sql``` block → raw fallback |
| - validate_sql(): forbidden-token check + schema-aware column validation via EXPLAIN |
| - execute_safe(): extraction, validation, timeout, subquery wrapping, execution |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import re |
| import time |
| import threading |
| import duckdb |
| from pathlib import Path |
|
|
|
|
| |
|
|
# Keywords rejected by validate_sql(); each is matched as a whole word against
# the lower-cased query text.
FORBIDDEN_TOKENS = [
    "drop",
    "delete",
    "insert",
    "update",
    "alter",
    "truncate",
    "create",
    "attach",
    "detach",
    "pragma",
]

# Row cap applied by execute_safe() through the LIMIT of its wrapper query.
MAX_RESULT_ROWS = 1000

# Default value of execute_safe()'s ``timeout_sec`` argument, in seconds.
QUERY_TIMEOUT_SEC = 10

# Seed assets live in a "data" directory next to this module.
DATA_DIR = Path(__file__).parent.joinpath("data")
SEED_SQL_PATH = DATA_DIR.joinpath("seed.sql")

# Directories probed, in this order, by seed_database() for parquet files.
_PARQUET_DIRS = [Path("/data"), DATA_DIR]

# Parquet file names; not referenced by the code visible in this module.
_ENROLLMENT_PQ = "enrollment.parquet"
_ATTENDANCE_PQ = "attendance.parquet"
|
|
|
|
| |
|
|
def get_connection(read_only: bool = False) -> duckdb.DuckDBPyConnection:
    """
    Open a brand-new in-memory DuckDB connection and apply safety settings.

    A new connection is created on every call, so callers can give each
    request its own connection instead of sharing one across threads.

    Args:
        read_only: Accepted as part of the signature; this implementation
            does not consult it.

    Returns:
        A connection to a fresh ``:memory:`` database with the progress bar
        disabled, memory capped at 256MB and the thread count set to 2.
    """
    connection = duckdb.connect(database=":memory:")
    # Applied in this order, one statement per setting.
    startup_settings = (
        "SET enable_progress_bar = false;",
        "SET max_memory = '256MB';",
        "SET threads = 2;",
    )
    for setting in startup_settings:
        connection.execute(setting)
    return connection
|
|
|
|
| |
|
|
| |
# Tables that seed_database() loads from "<table>.parquet" files; a directory
# is only used when it contains a file for every entry.
_PARQUET_TABLES = [
    "enrollment",
    "attendance",
    "students",
    "discipline",
    "grades",
]
|
|
|
|
def _split_seed_statements(script: str) -> list[str]:
    """Split a seed script on ';' and return the non-empty statements.

    Leading blank lines and leading ``--`` comment lines are removed from each
    piece so that a statement preceded by a comment is still returned (instead
    of the whole piece being discarded because it starts with ``--``).

    Limitation: the split is textual, so a ';' inside a string literal or a
    comment still cuts the statement in two.
    """
    statements = []
    for chunk in script.split(";"):
        lines = chunk.strip().splitlines()
        while lines and (
            not lines[0].strip() or lines[0].lstrip().startswith("--")
        ):
            lines.pop(0)
        statement = "\n".join(lines).strip()
        if statement:
            statements.append(statement)
    return statements


def seed_database(
    conn: duckdb.DuckDBPyConnection,
    seed_sql_path: Path | None = None,
) -> None:
    """
    Create tables and load seed data into ``conn``. Tries, in order:

    1. Parquet files: the first directory in ``_PARQUET_DIRS`` that holds a
       ``<table>.parquet`` file for every name in ``_PARQUET_TABLES``
         → /data/*.parquet
         → data/*.parquet next to this module
    2. A SQL seed script (``seed_sql_path``, default ``SEED_SQL_PATH``)
    3. The Python generator ``data.generate_seed.generate_seed_data``

    Args:
        conn: Open DuckDB connection to populate.
        seed_sql_path: Optional override for the SQL seed script location.

    Raises:
        duckdb.Error: if any of the executed statements fails.
    """
    # 1. Parquet: only use a directory when it has a file for every table.
    for base in _PARQUET_DIRS:
        parquet_paths = {
            table: base / f"{table}.parquet" for table in _PARQUET_TABLES
        }
        if not all(path.exists() for path in parquet_paths.values()):
            continue
        for table, pq_path in parquet_paths.items():
            # Double single quotes so a path containing "'" remains a valid
            # SQL string literal.
            quoted_path = str(pq_path).replace("'", "''")
            conn.execute(
                f"CREATE TABLE {table} AS "
                f"SELECT * FROM read_parquet('{quoted_path}')"
            )
        return

    # 2. SQL seed script.
    if seed_sql_path is None:
        seed_sql_path = SEED_SQL_PATH

    if seed_sql_path.exists():
        script = seed_sql_path.read_text(encoding="utf-8")
        for statement in _split_seed_statements(script):
            conn.execute(statement)
        return

    # 3. Programmatic fallback; imported lazily so the generator module is
    #    only needed when neither parquet files nor a seed script exist.
    from data.generate_seed import generate_seed_data
    generate_seed_data(conn)
|
|
|
|
| |
|
|
def get_schema_info(conn: duckdb.DuckDBPyConnection) -> dict[str, list[tuple[str, str, str]]]:
    """
    Introspect the database schema for prompt context.

    Only objects in the ``main`` schema are reported, for both the table list
    and the per-table column lookup.

    Args:
        conn: Open DuckDB connection to inspect.

    Returns:
        dict: table_name -> [(column_name, data_type, "")], tables ordered by
        name and columns by ordinal position. The third tuple element (the
        description) is always the empty string.
    """
    table_rows = conn.execute(
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = 'main' ORDER BY table_name"
    ).fetchall()

    schema: dict[str, list[tuple[str, str, str]]] = {}
    for (table_name,) in table_rows:
        # Bind the table name as a parameter instead of formatting it into
        # the SQL text, and restrict to the same schema as the table list.
        column_rows = conn.execute(
            "SELECT column_name, data_type FROM information_schema.columns "
            "WHERE table_schema = 'main' AND table_name = ? "
            "ORDER BY ordinal_position",
            [table_name],
        ).fetchall()
        schema[table_name] = [(name, dtype, "") for name, dtype in column_rows]
    return schema
|
|
|
|
| |
|
|
def _try_parse_json_envelope(text: str) -> str | None:
    """
    Try to find a JSON envelope in the LLM output, e.g.:
        {"sql": "SELECT ...", "explanation": "..."}

    Every '{' in ``text`` is tried as the start of a JSON value. The first
    decoded object that has a non-blank string under the key "sql" wins.
    Decoding is non-strict so that literal newlines/tabs inside the SQL string
    (common in LLM output) are accepted, and a real JSON decoder is used so
    that braces inside string values or nested objects do not break matching.

    Returns:
        The SQL string if found, otherwise None.
    """
    decoder = json.JSONDecoder(strict=False)
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            candidate = obj.get("sql")
            if isinstance(candidate, str) and candidate.strip():
                return candidate
        # Step to the next brace (not past the decoded value) so an envelope
        # nested inside another object is still found.
        start = text.find("{", start + 1)
    return None


def extract_sql(raw_llm_output: str) -> str:
    """
    Extract SQL from LLM output. Tries, in order:
      1. JSON envelope: {"sql": "...", "explanation": "..."}
      2. ```sql ... ``` markdown block
      3. Generic ``` ... ``` code block
      4. Raw text fallback

    Surrounding whitespace and every trailing semicolon (including ones
    separated by whitespace, such as "SELECT 1 ;" or "SELECT 1;\\n;") are
    stripped, because a trailing ';' breaks subquery wrapping.

    Args:
        raw_llm_output: The raw model response.

    Returns:
        The cleaned SQL text (may be empty).
    """
    candidate = _try_parse_json_envelope(raw_llm_output)
    if not candidate:
        fenced = re.search(
            r"```sql\s*\n?(.*?)```", raw_llm_output, re.DOTALL | re.IGNORECASE
        ) or re.search(r"```\s*\n?(.*?)```", raw_llm_output, re.DOTALL)
        candidate = fenced.group(1) if fenced else raw_llm_output

    # rstrip with a character set removes any mix of trailing ';' and
    # whitespace in one pass.
    return candidate.strip().rstrip("; \t\r\n\f\v")
|
|
|
|
| |
|
|
# Spans in which a ';' cannot act as a statement separator. They are blanked
# out before validate_sql() looks for extra statements.
_SQL_OPAQUE_SPAN_RE = re.compile(
    r"'(?:[^']|'')*'"      # single-quoted string literal ('' escapes a quote)
    r'|"(?:[^"]|"")*"'     # double-quoted identifier
    r"|--[^\n]*"           # line comment
    r"|/\*.*?\*/",         # block comment
    re.DOTALL,
)


def validate_sql(sql: str, conn: duckdb.DuckDBPyConnection | None = None) -> None:
    """
    Validate that the SQL is safe and refers to real columns.

    Layer 1 — static checks (always run):
      - Not empty (whitespace-only counts as empty)
      - No forbidden tokens (DROP, DELETE, INSERT, etc.), matched as whole
        words anywhere in the text, including inside string literals
      - Contains SELECT as a whole word
      - Exactly one statement: a ';' outside string literals, quoted
        identifiers and comments is rejected unless it is trailing

    Layer 2 — schema-aware validation (if conn provided):
      - Runs EXPLAIN against the actual schema to catch missing
        columns, unknown tables, and syntax errors before execution.

    The single-statement check matters for layer 2: the text is sent to
    ``conn.execute`` as ``EXPLAIN <sql>``, so the EXPLAIN prefix would only
    apply to the first statement of a multi-statement string.

    Args:
        sql: The SQL text to validate.
        conn: Optional connection used for the EXPLAIN check.

    Raises:
        ValueError: with a user-facing message on any failure.
    """
    if not sql or not sql.strip():
        raise ValueError("Empty SQL query — nothing to execute.")

    sql_lower = sql.lower()
    for token in FORBIDDEN_TOKENS:
        if re.search(rf"\b{re.escape(token)}\b", sql_lower):
            raise ValueError(
                f"Forbidden operation detected: '{token}'. Only SELECT queries are allowed."
            )

    # Whole-word match so identifiers such as "selection" do not count.
    if not re.search(r"\bselect\b", sql_lower):
        raise ValueError("Only SELECT queries are allowed. No SELECT found.")

    # Blank out literals/comments, tolerate trailing semicolons, then any
    # remaining ';' separates two statements.
    masked = _SQL_OPAQUE_SPAN_RE.sub(" ", sql)
    if ";" in masked.rstrip("; \t\r\n\f\v"):
        raise ValueError(
            "Multiple SQL statements are not allowed. Submit a single SELECT query."
        )

    if conn is not None:
        try:
            conn.execute(f"EXPLAIN {sql}")
        except duckdb.Error as e:
            msg = str(e).strip()
            # Drop DuckDB's error-class prefix to keep the message user-facing.
            for prefix in ["Parser Error: ", "Catalog Error: ", "Binder Error: "]:
                if msg.startswith(prefix):
                    msg = msg[len(prefix):]
            raise ValueError(f"SQL validation failed: {msg}") from e
|
|
|
|
| |
|
|
class QueryTimeoutError(TimeoutError):
    """Signal that a query ran longer than its allotted time budget."""
|
|
|
|
def _execute_with_timeout(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    timeout_sec: int,
):
    """
    Run ``sql`` on ``conn`` in a worker thread, enforcing a deadline.

    The calling thread waits up to ``timeout_sec`` seconds for the worker to
    finish. If it does not, ``conn.interrupt()`` is called, the worker gets
    up to two more seconds to unwind, and QueryTimeoutError is raised.

    Args:
        conn: Connection on which the query is executed.
        sql: SQL text passed unchanged to ``conn.execute``.
        timeout_sec: Seconds to wait before interrupting the query.

    Returns:
        Whatever ``conn.execute(sql).fetchdf()`` returned.

    Raises:
        QueryTimeoutError: if the query did not finish within the deadline.
        Exception: any exception raised by the query is re-raised in the
            calling thread.
    """
    outcome = {"df": None, "error": None}
    finished = threading.Event()

    def _worker():
        # Store either the result or the failure; always signal completion.
        try:
            outcome["df"] = conn.execute(sql).fetchdf()
        except Exception as exc:
            outcome["error"] = exc
        finally:
            finished.set()

    worker = threading.Thread(target=_worker, daemon=True)
    worker.start()

    completed_in_time = finished.wait(timeout=timeout_sec)
    if not completed_in_time:
        conn.interrupt()
        worker.join(timeout=2)
        raise QueryTimeoutError(f"Query timed out after {timeout_sec}s.")

    failure = outcome["error"]
    if failure:
        raise failure

    return outcome["df"]
|
|
|
|
| |
|
|
def execute_safe(
    conn: duckdb.DuckDBPyConnection,
    raw_llm_output: str,
    timeout_sec: int = QUERY_TIMEOUT_SEC,
) -> tuple[str, "DataFrame"]:
    """
    Extract, validate, and execute LLM-generated SQL.

    Pipeline:
      1. extract_sql() — parse JSON / ```sql``` / raw
      2. validate_sql() — static checks + schema-aware EXPLAIN
      3. Wrap in SELECT * FROM (<query>) AS _safe LIMIT {MAX_RESULT_ROWS}
      4. Execute under the ``timeout_sec`` budget via _execute_with_timeout()
      5. Return (cleaned_sql, dataframe)

    Args:
        conn: Seeded DuckDB connection to run the query on.
        raw_llm_output: Raw model response containing the SQL.
        timeout_sec: Time budget in seconds. ``None`` or a non-positive value
            disables the watchdog and runs the query directly.

    Returns:
        (cleaned_sql, dataframe) where ``dataframe`` is the result of
        ``fetchdf()`` on the wrapped query.

    Raises:
        ValueError: if SQL is invalid or references unknown columns/tables.
        QueryTimeoutError: if the query exceeds ``timeout_sec``.
        duckdb.Error: on database-level failures.
    """
    sql = extract_sql(raw_llm_output)
    validate_sql(sql, conn=conn)

    # Newlines around the inner query keep a trailing "--" comment in the
    # generated SQL from swallowing the closing parenthesis.
    safe_sql = f"SELECT * FROM (\n{sql}\n) AS _safe LIMIT {MAX_RESULT_ROWS}"

    if timeout_sec is not None and timeout_sec > 0:
        df = _execute_with_timeout(conn, safe_sql, timeout_sec)
    else:
        df = conn.execute(safe_sql).fetchdf()

    return sql, df
|
|
|
|
| |
|
|
def create_session() -> duckdb.DuckDBPyConnection:
    """Build a new connection via get_connection(), seed it, and return it."""
    session = get_connection()
    seed_database(session)
    return session
|
|