File size: 11,849 Bytes
17674c2 0e9c140 17674c2 240383f 0e9c140 17674c2 0e9c140 17674c2 240383f 0e9c140 240383f 17674c2 0e9c140 240383f 0e9c140 17674c2 0e9c140 17674c2 4c667d9 17674c2 4c667d9 17674c2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 | """
data_engine.py — DuckDB lifecycle, schema introspection, and safe query execution.
Handles:
- Creating per-request in-memory DuckDB connections (thread-safe)
- Seeding schema + data (from seed.sql or programmatically)
- Schema introspection for prompt context
- extract_sql(): JSON envelope → ```sql``` block → raw fallback
- validate_sql(): forbidden-token check + schema-aware column validation via EXPLAIN
- execute_safe(): extraction, validation, timeout, subquery wrapping, execution
"""
from __future__ import annotations
import json
import re
import time
import threading
import duckdb
from pathlib import Path
# ── Forbidden SQL terms (case-insensitive) ─────────────────────────────
# validate_sql() rejects a query that contains any of these as a whole word
# (matched against the lower-cased query, so entries must be lower-case).
# Besides DDL/DML, the list covers DuckDB statements that move data in or
# out of files or load extensions (COPY, EXPORT, IMPORT, INSTALL, LOAD), and
# ANALYZE, because an EXPLAIN ANALYZE actually executes the statement.
FORBIDDEN_TOKENS = [
    "drop", "delete", "insert", "update", "alter", "truncate",
    "create", "attach", "detach", "pragma",
    "copy", "export", "import", "install", "load", "analyze",
]
# ── Execution limits ───────────────────────────────────────────────────
MAX_RESULT_ROWS = 1000   # row cap applied by execute_safe()'s LIMIT wrapper
QUERY_TIMEOUT_SEC = 10   # default timeout_sec for execute_safe()
DATA_DIR = Path(__file__).parent / "data"
SEED_SQL_PATH = DATA_DIR / "seed.sql"
# Parquet files (pre-generated once via data/export_parquet.py).
# On HF Spaces, place them in the persistent /data/ directory.
_PARQUET_DIRS = [
    Path("/data"),  # HF Space persistent storage
    DATA_DIR,       # local dev (data/)
]
# NOTE(review): these two names are not referenced elsewhere in this module
# (seeding iterates _PARQUET_TABLES); confirm no importer uses them before
# removing.
_ENROLLMENT_PQ = "enrollment.parquet"
_ATTENDANCE_PQ = "attendance.parquet"
# ── Connection factory ─────────────────────────────────────────────────
def get_connection(read_only: bool = False) -> duckdb.DuckDBPyConnection:
    """
    Return a fresh in-memory DuckDB connection with safety defaults.

    Each request gets its own connection for thread safety. Settings applied:
    progress bar off, ``max_memory = '256MB'``, ``threads = 2``.

    Args:
        read_only: Currently ignored — the connection is always a writable
            in-memory database (seed_database() has to create tables in it).

    Returns:
        The configured connection. If applying a setting fails, the
        connection is closed before the error propagates.
    """
    conn = duckdb.connect(database=":memory:")
    try:
        conn.execute("SET enable_progress_bar = false;")
        conn.execute("SET max_memory = '256MB';")
        conn.execute("SET threads = 2;")
    except Exception:
        # Don't leak a half-configured connection.
        conn.close()
        raise
    return conn
# ── Database seeding ───────────────────────────────────────────────────
# All 5 tables (expanded schema for Day 1+). A Parquet directory is used only
# when it contains a file for every one of them.
_PARQUET_TABLES = ["enrollment", "attendance", "students", "discipline", "grades"]


def seed_database(
    conn: duckdb.DuckDBPyConnection,
    seed_sql_path: Path | None = None,
) -> None:
    """
    Create tables and load seed data. Tries, in order:
    1. Parquet files (fastest — pre-generated, ~260 KB total)
       → /data/*.parquet (HF Space persistent storage)
       → data/*.parquet (local dev)
       Each table is created with CREATE TABLE <name> AS read_parquet(...).
    2. A seed SQL script (``seed_sql_path``, default data/seed.sql)
    3. Python generator (slow fallback, ~20s)

    Args:
        conn: Connection to populate.
        seed_sql_path: Optional seed script location (str or Path); only
            consulted when no complete Parquet directory is found.

    Note:
        The seed script is split naively on ';', so a semicolon inside a
        string literal would break that statement. Whole-line ``--``
        comments are removed before each statement runs.
    """
    # ── 1. Parquet files ─────────────────────────────────────────
    for base in _PARQUET_DIRS:
        pq_paths = {table: base / f"{table}.parquet" for table in _PARQUET_TABLES}
        if not all(path.exists() for path in pq_paths.values()):
            continue
        for table, pq_path in pq_paths.items():
            # Double any single quote so the path stays a valid SQL literal.
            pq_literal = str(pq_path).replace("'", "''")
            conn.execute(
                f"CREATE TABLE {table} AS "
                f"SELECT * FROM read_parquet('{pq_literal}')"
            )
        return

    # ── 2. seed.sql ──────────────────────────────────────────────
    seed_path = SEED_SQL_PATH if seed_sql_path is None else Path(seed_sql_path)
    if seed_path.exists():
        script = seed_path.read_text(encoding="utf-8")
        for chunk in script.split(";"):
            # Remove whole-line comments instead of skipping any chunk that
            # merely *starts* with one: a header comment directly above a
            # CREATE/INSERT must not swallow that statement.
            statement = "\n".join(
                line
                for line in chunk.splitlines()
                if not line.lstrip().startswith("--")
            ).strip()
            if statement:
                conn.execute(statement)
        return

    # ── 3. Python generator (slow) ───────────────────────────────
    from data.generate_seed import generate_seed_data

    generate_seed_data(conn)
# ── Schema introspection ───────────────────────────────────────────────
def get_schema_info(conn: duckdb.DuckDBPyConnection) -> dict[str, list[tuple[str, str, str]]]:
    """
    Introspect the database schema for prompt context.

    Only objects in the ``main`` schema are listed. Tables are ordered by name
    and columns by ordinal position.

    Returns:
        dict: table_name -> [(column_name, type, "")]
        The description field is empty — we rely on the prompt's table docs.
    """
    tables = conn.execute(
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = 'main' ORDER BY table_name"
    ).fetchall()
    schema: dict[str, list[tuple[str, str, str]]] = {}
    for (table_name,) in tables:
        # Bind the name as a parameter (no quoting/injection issues). Also filter
        # on the same schema as the table list, so a same-named table elsewhere
        # cannot mix its columns in.
        cols = conn.execute(
            "SELECT column_name, data_type FROM information_schema.columns "
            "WHERE table_schema = 'main' AND table_name = ? "
            "ORDER BY ordinal_position",
            [table_name],
        ).fetchall()
        schema[table_name] = [(name, dtype, "") for name, dtype in cols]
    return schema
# ── JSON envelope parsing ──────────────────────────────────────────────
def _try_parse_json_envelope(text: str) -> str | None:
"""
Try to parse the LLM output as a JSON envelope like:
{"sql": "SELECT ...", "explanation": "..."}
Returns the SQL string if found, or None.
"""
# Try to find a JSON object anywhere in the text
json_match = re.search(r'\{[^{}]*"sql"\s*:\s*"[^"]+"[^{}]*\}', text, re.DOTALL)
if not json_match:
return None
try:
obj = json.loads(json_match.group(0))
if isinstance(obj, dict) and "sql" in obj:
return obj["sql"]
except (json.JSONDecodeError, KeyError):
pass
return None
# ── SQL extraction ─────────────────────────────────────────────────────
# Opening fence tagged "sql" (any case); the \b keeps "```sqlite" out.
_SQL_FENCE_RE = re.compile(r"```sql\b\s*(.*?)```", re.DOTALL | re.IGNORECASE)
# Any fence. An info string on the opening line (e.g. "duckdb", "postgresql")
# is skipped, unless it is really the start of a one-line query (SELECT/WITH).
_ANY_FENCE_RE = re.compile(
    r"```(?:(?!(?:select|with)\b)[a-z][\w+.-]*)?[ \t]*\n?(.*?)```",
    re.DOTALL | re.IGNORECASE,
)


def _strip_sql_terminators(sql: str) -> str:
    """Strip surrounding whitespace and every trailing ';' (e.g. 'SELECT 1 ; ;')."""
    cleaned = sql.strip()
    while cleaned.endswith(";"):
        cleaned = cleaned[:-1].rstrip()
    return cleaned


def extract_sql(raw_llm_output: str) -> str:
    """
    Extract SQL from LLM output. Tries, in order:
    1. JSON envelope: {"sql": "...", "explanation": "..."}
    2. ```sql ... ``` markdown block
    3. Generic ``` ... ``` code block (an info string such as ``duckdb`` on
       the opening fence is dropped, not treated as SQL)
    4. Raw text fallback
    Always strips trailing semicolons and whitespace (they break subquery
    wrapping).
    """
    # 1. JSON envelope first
    json_sql = _try_parse_json_envelope(raw_llm_output)
    if json_sql:
        return _strip_sql_terminators(json_sql)
    # 2./3. ```sql``` block, then any fenced block
    for pattern in (_SQL_FENCE_RE, _ANY_FENCE_RE):
        match = pattern.search(raw_llm_output)
        if match:
            return _strip_sql_terminators(match.group(1))
    # 4. Fallback: the raw text itself
    return _strip_sql_terminators(raw_llm_output)
# ── SQL validation ─────────────────────────────────────────────────────
def validate_sql(sql: str, conn: duckdb.DuckDBPyConnection | None = None) -> None:
    """
    Validate that the SQL is safe and refers to real columns.

    Layer 1 — static checks (always run):
      - Not empty / not whitespace-only
      - No forbidden tokens (DROP, DELETE, INSERT, etc.)
      - A single statement: any ';' is rejected (extract_sql already strips
        trailing ones). DuckDB's execute() accepts ';'-separated statement
        lists, so a second statement would otherwise run during the EXPLAIN
        below. Side effect: a ';' inside a string literal is rejected too.
      - Contains the SELECT keyword
    Layer 2 — schema-aware validation (if conn provided):
      - EXPLAINs the query wrapped as ``SELECT * FROM (<sql>) AS _safe``,
        the same shape execute_safe() runs. This catches missing columns,
        unknown tables, and syntax errors before execution. Wrapping also
        stops the query text from turning the EXPLAIN into an
        ``EXPLAIN ANALYZE``, which would execute it.

    Raises ValueError with a user-facing message on any failure.
    """
    if not sql or not sql.strip():
        raise ValueError("Empty SQL query — nothing to execute.")

    # Forbidden tokens FIRST (before the SELECT check): DROP/INSERT
    # statements may not contain SELECT but deserve the clearer message.
    sql_lower = sql.lower()
    for token in FORBIDDEN_TOKENS:
        if re.search(rf"\b{re.escape(token)}\b", sql_lower):
            raise ValueError(
                f"Forbidden operation detected: '{token}'. Only SELECT queries are allowed."
            )

    if ";" in sql:
        raise ValueError(
            "Multiple SQL statements are not allowed. Submit a single SELECT query."
        )

    if not re.search(r"\bselect\b", sql_lower):
        raise ValueError("Only SELECT queries are allowed. No SELECT found.")

    if conn is None:
        return

    # Schema-aware validation via DuckDB EXPLAIN (plans, does not execute).
    try:
        conn.execute(f"EXPLAIN SELECT * FROM (\n{sql}\n) AS _safe")
    except duckdb.Error as e:
        # Surface the DuckDB error (e.g., "column 'foo' does not exist")
        msg = str(e).strip()
        # Clean up common DuckDB error prefixes for user-friendliness
        for prefix in ("Parser Error: ", "Catalog Error: ", "Binder Error: "):
            if msg.startswith(prefix):
                msg = msg[len(prefix):]
        raise ValueError(f"SQL validation failed: {msg}") from e
# ── Timeout helper ─────────────────────────────────────────────────────
class QueryTimeoutError(TimeoutError):
    """Signal that a query ran longer than its allotted time budget.

    Derives from the built-in ``TimeoutError``, so handlers written for
    generic timeouts catch it as well.
    """
def _execute_with_timeout(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    timeout_sec: float,
) -> "DataFrame":
    """
    Execute SQL with a Python-level timeout via conn.interrupt().

    DuckDB doesn't have a built-in SET query_timeout. Instead, the query runs
    in a daemon worker thread while this thread waits. After the deadline,
    this thread calls conn.interrupt() and gives the worker up to 2 s to stop.

    Args:
        conn: Connection to run the query on.
        sql: Statement to execute.
        timeout_sec: Time budget in seconds (None waits indefinitely).

    Returns:
        The result of ``conn.execute(sql).fetchdf()``.

    Raises:
        QueryTimeoutError: the query had not finished at the deadline.
        Exception: whatever the query itself raised, re-raised here.

    NOTE: if the worker does not stop within the 2 s grace period, it may
    still be using ``conn`` in the background after this function raises.
    """
    outcome = {"df": None, "error": None}
    finished = threading.Event()

    def worker():
        try:
            outcome["df"] = conn.execute(sql).fetchdf()
        except Exception as exc:
            outcome["error"] = exc
        finally:
            finished.set()

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()
    if not finished.wait(timeout=timeout_sec):
        # Timed out — interrupt the DuckDB connection
        conn.interrupt()
        thread.join(timeout=2)
        # The query may have completed successfully between the deadline and
        # the interrupt; only report a timeout if no result was produced.
        if not (finished.is_set() and outcome["error"] is None):
            raise QueryTimeoutError(f"Query timed out after {timeout_sec}s.")
    if outcome["error"] is not None:
        raise outcome["error"]
    return outcome["df"]
# ── Safe SQL execution ─────────────────────────────────────────────────
def execute_safe(
    conn: duckdb.DuckDBPyConnection,
    raw_llm_output: str,
    timeout_sec: float | None = QUERY_TIMEOUT_SEC,
) -> tuple[str, "DataFrame"]:
    """
    Extract, validate, and execute LLM-generated SQL.

    Pipeline:
    1. extract_sql()  — parse JSON / ```sql``` / raw
    2. validate_sql() — static checks + schema-aware EXPLAIN
    3. Wrap in SELECT * FROM (<query>) AS _safe LIMIT {MAX_RESULT_ROWS}
    4. Execute under a ``timeout_sec`` watchdog (_execute_with_timeout).
       Pass None or a value <= 0 to execute directly with no time limit.
    5. Return (cleaned_sql, dataframe)

    Returns:
        (cleaned_sql, pandas DataFrame produced by ``fetchdf()``)
    Raises:
        ValueError: if SQL is invalid or references unknown columns/tables.
        QueryTimeoutError: if execution exceeds ``timeout_sec``.
        duckdb.Error: on database-level failures.
    """
    sql = extract_sql(raw_llm_output)
    validate_sql(sql, conn=conn)
    # Safety wrap. The user query sits on its own lines so that a trailing
    # "-- comment" in it cannot swallow the closing parenthesis.
    safe_sql = f"SELECT * FROM (\n{sql}\n) AS _safe LIMIT {MAX_RESULT_ROWS}"
    if timeout_sec is None or timeout_sec <= 0:
        df = conn.execute(safe_sql).fetchdf()
    else:
        df = _execute_with_timeout(conn, safe_sql, timeout_sec)
    return sql, df
# ── Full pipeline (for use in app.py) ──────────────────────────────────
def create_session() -> duckdb.DuckDBPyConnection:
    """
    Create a seeded DuckDB connection ready for queries.

    Combines get_connection() and seed_database(). If seeding fails, the
    half-initialised connection is closed before the error propagates.
    """
    conn = get_connection()
    try:
        seed_database(conn)
    except BaseException:
        conn.close()
        raise
    return conn
|