LFED / data_engine.py
Kasualdad's picture
Day 1 expanded schema: 5 tables, 14B fine-tuned model
240383f
Raw
History Blame Contribute Delete
11.8 kB
"""
data_engine.py — DuckDB lifecycle, schema introspection, and safe query execution.
Handles:
- Creating per-request in-memory DuckDB connections (thread-safe)
- Seeding schema + data (from seed.sql or programmatically)
- Schema introspection for prompt context
- extract_sql(): JSON envelope → ```sql``` block → raw fallback
- validate_sql(): forbidden-token check + schema-aware column validation via EXPLAIN
- execute_safe(): extraction, validation, timeout, subquery wrapping, execution
"""
from __future__ import annotations
import json
import re
import time
import threading
import duckdb
from pathlib import Path
# ── Forbidden SQL terms (case-insensitive) ─────────────────────────────
# validate_sql() rejects any SQL that contains one of these as a whole word.
# Besides DDL/DML, the list covers DuckDB statements whose effects reach
# beyond the query result: COPY ... TO (writes files), EXPORT/IMPORT
# DATABASE, and INSTALL/LOAD (extensions). Without "copy", a query such as
# COPY (SELECT ...) TO 'file' passes the static check because it contains
# SELECT. Whole-word matching also rejects these words inside string
# literals (e.g. WHERE note = 'update'); that false positive is accepted.
FORBIDDEN_TOKENS = [
    "drop", "delete", "insert", "update", "alter", "truncate",
    "create", "attach", "detach", "pragma",
    "copy", "export", "import", "install", "load",
]
# ── Execution limits ───────────────────────────────────────────────────
MAX_RESULT_ROWS = 1000  # row cap applied by execute_safe's LIMIT wrapper
QUERY_TIMEOUT_SEC = 10  # default execute_safe time budget, in seconds
DATA_DIR = Path(__file__).parent / "data"
SEED_SQL_PATH = DATA_DIR / "seed.sql"
# Parquet files (pre-generated once via data/export_parquet.py).
# On HF Spaces, place them in the persistent /data/ directory.
# seed_database() checks these directories in order and uses the first one
# that contains every table file.
_PARQUET_DIRS = [
    Path("/data"),  # HF Space persistent storage
    DATA_DIR,       # local dev (data/)
]
# NOTE(review): not referenced by the visible code in this module
# (seed_database builds names from _PARQUET_TABLES); presumably used
# elsewhere — verify before removing.
_ENROLLMENT_PQ = "enrollment.parquet"
_ATTENDANCE_PQ = "attendance.parquet"
# ── Connection factory ─────────────────────────────────────────────────
def get_connection(read_only: bool = False) -> duckdb.DuckDBPyConnection:
    """
    Return a fresh in-memory DuckDB connection with safety defaults.

    Each request gets its own connection for thread safety.

    Args:
        read_only: Currently ignored. The connection always opens a private
            ":memory:" database; the parameter is kept so existing callers
            keep working. NOTE(review): DuckDB may refuse to open an
            in-memory database read-only, so honoring this flag would need
            a different mechanism — confirm before wiring it up.

    Returns:
        A new connection with the progress bar disabled, memory capped at
        256MB, and two worker threads.
    """
    conn = duckdb.connect(database=":memory:")
    conn.execute("SET enable_progress_bar = false;")
    # Cap the resources this connection's queries may use.
    conn.execute("SET max_memory = '256MB';")
    conn.execute("SET threads = 2;")
    return conn
# ── Database seeding ───────────────────────────────────────────────────
# All 5 tables (expanded schema for Day 1+)
_PARQUET_TABLES = ["enrollment", "attendance", "students", "discipline", "grades"]


def _has_sql_content(script: str) -> bool:
    """Return True if *script* has a line that is not blank, ';' or a ``--`` comment."""
    for line in script.splitlines():
        text = line.strip().strip(";").strip()
        if text and not text.startswith("--"):
            return True
    return False


def seed_database(
    conn: duckdb.DuckDBPyConnection,
    seed_sql_path: Path | None = None,
) -> None:
    """
    Create tables and load seed data. Tries, in order:
      1. Parquet files (fastest — pre-generated, ~260 KB total)
         → /data/*.parquet  (HF Space persistent storage)
         → data/*.parquet   (local dev)
         A directory is used only if it holds a file for every table in
         _PARQUET_TABLES.
      2. data/seed.sql (custom overrides), or ``seed_sql_path`` when given
      3. Python generator (slow fallback, ~20s)

    Args:
        conn: Connection to create the tables on.
        seed_sql_path: Optional override for the seed.sql location; a plain
            string path is accepted as well as a Path.
    """
    # ── 1. Parquet files ─────────────────────────────────────────
    for base in _PARQUET_DIRS:
        if all((base / f"{t}.parquet").exists() for t in _PARQUET_TABLES):
            for table in _PARQUET_TABLES:
                # Double any single quote so unusual paths can't break the
                # SQL string literal.
                pq_literal = str(base / f"{table}.parquet").replace("'", "''")
                conn.execute(
                    f"CREATE TABLE {table} AS "
                    f"SELECT * FROM read_parquet('{pq_literal}')"
                )
            return

    # ── 2. seed.sql ──────────────────────────────────────────────
    seed_path = SEED_SQL_PATH if seed_sql_path is None else Path(seed_sql_path)
    if seed_path.exists():
        script = seed_path.read_text(encoding="utf-8")
        # Hand the whole script to DuckDB and let its parser split the
        # statements. The previous split(";") + startswith("--") approach
        # silently skipped any statement preceded by a comment line, and it
        # broke on semicolons inside string literals.
        if _has_sql_content(script):
            conn.execute(script)
        return

    # ── 3. Python generator (slow) ───────────────────────────────
    from data.generate_seed import generate_seed_data
    generate_seed_data(conn)
# ── Schema introspection ───────────────────────────────────────────────
def get_schema_info(conn: duckdb.DuckDBPyConnection) -> dict[str, list[tuple[str, str, str]]]:
    """
    Introspect the database schema for prompt context.

    Args:
        conn: Connection whose ``main`` schema is described.

    Returns:
        dict: table_name -> [(column_name, data_type, "")], with tables in
        name order and columns in ordinal_position order.
        The description field is empty — we rely on the prompt's table docs.
    """
    table_rows = conn.execute(
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = 'main' ORDER BY table_name"
    ).fetchall()
    schema: dict[str, list[tuple[str, str, str]]] = {}
    for (table_name,) in table_rows:
        # Bind the table name as a parameter instead of splicing it into the
        # SQL, so a name containing a quote can't break the query. Also
        # restrict to 'main' so the columns match the table list above.
        columns = conn.execute(
            "SELECT column_name, data_type FROM information_schema.columns "
            "WHERE table_schema = 'main' AND table_name = ? "
            "ORDER BY ordinal_position",
            [table_name],
        ).fetchall()
        schema[table_name] = [(name, dtype, "") for name, dtype in columns]
    return schema
# ── JSON envelope parsing ──────────────────────────────────────────────
# Lenient decoder: strict=False accepts raw newlines/tabs inside JSON
# strings, which LLMs often emit when the "sql" value spans several lines.
_JSON_DECODER = json.JSONDecoder(strict=False)


def _try_parse_json_envelope(text: str) -> str | None:
    """
    Try to parse the LLM output as a JSON envelope like:
        {"sql": "SELECT ...", "explanation": "..."}

    Every ``{`` in *text* is tried as the start of a JSON value. The envelope
    may therefore be surrounded by prose or a code fence, may contain nested
    objects, and may itself be nested inside another object.

    Returns:
        The first non-empty string stored under an "sql" key, or None.
    """
    for brace in re.finditer(r"\{", text):
        try:
            obj, _end = _JSON_DECODER.raw_decode(text, brace.start())
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if isinstance(obj, dict):
            sql = obj.get("sql")
            if isinstance(sql, str) and sql:
                return sql
    return None


# ── SQL extraction ─────────────────────────────────────────────────────
# ```sql fence. The \b stops ```sqlite from matching with "ite" glued onto
# the query.
_SQL_FENCE_RE = re.compile(r"```sql\b\s*(.*?)```", re.DOTALL | re.IGNORECASE)
# Any fence. An info string on the opening line (```postgresql, ```duckdb)
# is skipped rather than returned as SQL. The lookahead keeps a query that
# starts with a lone SELECT/WITH line from being mistaken for a tag.
_ANY_FENCE_RE = re.compile(
    r"```(?:(?!(?i:select|with)\b)[\w+-]*[ \t]*\n)?(.*?)```", re.DOTALL
)


def _clean_sql(text: str) -> str:
    """Strip surrounding whitespace and any run of trailing ';' / whitespace."""
    return text.strip().rstrip("; \t\r\n")


def extract_sql(raw_llm_output: str) -> str:
    """
    Extract SQL from LLM output. Tries, in order:
      1. JSON envelope: {"sql": "...", "explanation": "..."}
      2. ```sql ... ``` markdown block
      3. Generic ``` ... ``` code block (an info string such as
         ```postgresql on the opening line is dropped)
      4. Raw text fallback
    Always strips surrounding whitespace and trailing semicolons, because
    they break subquery wrapping.
    """
    json_sql = _try_parse_json_envelope(raw_llm_output)
    if json_sql:
        return _clean_sql(json_sql)
    for pattern in (_SQL_FENCE_RE, _ANY_FENCE_RE):
        fenced = pattern.search(raw_llm_output)
        if fenced:
            return _clean_sql(fenced.group(1))
    return _clean_sql(raw_llm_output)
# ── SQL validation ─────────────────────────────────────────────────────
def validate_sql(sql: str, conn: duckdb.DuckDBPyConnection | None = None) -> None:
    """
    Validate that the SQL is safe and refers to real columns.

    Layer 1 — static checks (always run):
      - Not empty and not just whitespace
      - No forbidden tokens (DROP, DELETE, INSERT, COPY, ...)
      - A single statement: a ';' anywhere except at the very end is rejected
      - Contains SELECT
    Layer 2 — schema-aware validation (if conn provided):
      - Runs EXPLAIN against the actual schema to catch missing
        columns, unknown tables, and syntax errors before execution.

    Raises ValueError with a user-facing message on any failure.
    """
    if not sql or not sql.strip():
        raise ValueError("Empty SQL query — nothing to execute.")

    # Check forbidden tokens FIRST (before SELECT check — DROP/INSERT
    # statements don't contain SELECT but are more dangerous)
    sql_lower = sql.lower()
    for token in FORBIDDEN_TOKENS:
        if re.search(rf"\b{token}\b", sql_lower):
            raise ValueError(
                f"Forbidden operation detected: '{token}'. Only SELECT queries are allowed."
            )

    # Reject statement separators. DuckDB's Python execute() runs every
    # statement in a multi-statement string, so "SELECT 1; <stmt>" would
    # execute <stmt> during the EXPLAIN below. Trailing semicolons are
    # harmless and tolerated. A ';' inside a string literal is also
    # rejected; that conservative trade-off is deliberate.
    if ";" in sql.rstrip("; \t\r\n"):
        raise ValueError(
            "Multiple SQL statements are not allowed. Submit a single SELECT query."
        )

    if "SELECT" not in sql.upper():
        raise ValueError("Only SELECT queries are allowed. No SELECT found.")

    # Schema-aware validation via DuckDB EXPLAIN
    if conn is not None:
        try:
            conn.execute(f"EXPLAIN {sql}")
        except duckdb.Error as e:
            # Surface the DuckDB error (e.g., "column 'foo' does not exist")
            msg = str(e).strip()
            # Clean up common DuckDB error prefixes for user-friendliness
            for prefix in ("Parser Error: ", "Catalog Error: ", "Binder Error: "):
                if msg.startswith(prefix):
                    msg = msg[len(prefix):]
            raise ValueError(f"SQL validation failed: {msg}") from e
# ── Timeout helper ─────────────────────────────────────────────────────
class QueryTimeoutError(TimeoutError):
    """Raised when a query exceeds the time budget."""


def _execute_with_timeout(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    timeout_sec: float,
):
    """
    Execute SQL with a Python-level timeout via conn.interrupt().

    DuckDB doesn't have a built-in SET query_timeout. Instead, the query runs
    on a daemon worker thread while the caller waits. If the deadline passes,
    conn.interrupt() is called to abort it.

    Args:
        conn: Connection to run the query on.
        sql: Statement to execute.
        timeout_sec: Time budget in seconds (fractions allowed).

    Returns:
        The result of ``conn.execute(sql).fetchdf()``.

    Raises:
        QueryTimeoutError: if the query has not finished by the deadline.
            If the worker reported an error within the 2-second grace period
            after the interrupt, that error is chained as ``__cause__``.
        Exception: whatever the query itself raised, re-raised unchanged.
    """
    outcome = {"df": None, "error": None}
    finished = threading.Event()

    def worker():
        try:
            outcome["df"] = conn.execute(sql).fetchdf()
        except Exception as exc:  # handed back to the calling thread below
            outcome["error"] = exc
        finally:
            finished.set()

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()

    if not finished.wait(timeout=timeout_sec):
        # Timed out — interrupt the DuckDB connection, then give the worker
        # a short grace period to observe the interrupt and exit.
        conn.interrupt()
        thread.join(timeout=2)
        raise QueryTimeoutError(
            f"Query timed out after {timeout_sec}s."
        ) from outcome["error"]

    if outcome["error"] is not None:
        raise outcome["error"]
    return outcome["df"]
# ── Safe SQL execution ─────────────────────────────────────────────────
def execute_safe(
    conn: duckdb.DuckDBPyConnection,
    raw_llm_output: str,
    timeout_sec: int | float | None = QUERY_TIMEOUT_SEC,
) -> tuple[str, "DataFrame"]:
    """
    Extract, validate, and execute LLM-generated SQL.

    Pipeline:
      1. extract_sql()  — parse JSON / ```sql``` / raw
      2. validate_sql() — static checks + schema-aware EXPLAIN
      3. Wrap in SELECT * FROM (<query>) AS _safe LIMIT {MAX_RESULT_ROWS}
      4. Execute under a budget of ``timeout_sec`` seconds via
         _execute_with_timeout(). Pass None or a value <= 0 to run
         without a timeout.
      5. Return (cleaned_sql, dataframe)

    Returns:
        (cleaned_sql, DataFrame returned by fetchdf())

    Raises:
        ValueError: if SQL is invalid or references unknown columns/tables.
        QueryTimeoutError: if execution exceeds ``timeout_sec``.
        duckdb.Error: on database-level failures.
    """
    sql = extract_sql(raw_llm_output)
    validate_sql(sql, conn=conn)

    # Safety wrap: the query must parse as a subquery, and the LIMIT caps the
    # result size. The newlines stop a trailing "-- comment" in the query
    # from commenting out the closing parenthesis.
    safe_sql = f"SELECT * FROM (\n{sql}\n) AS _safe LIMIT {MAX_RESULT_ROWS}"

    if timeout_sec is None or timeout_sec <= 0:
        df = conn.execute(safe_sql).fetchdf()
    else:
        df = _execute_with_timeout(conn, safe_sql, timeout_sec)
    return sql, df
# ── Full pipeline (for use in app.py) ──────────────────────────────────
def create_session() -> duckdb.DuckDBPyConnection:
    """
    Create a seeded DuckDB connection ready for queries.

    If seeding raises, the connection is closed before the error propagates,
    so a half-initialized connection is never returned or left open.

    NOTE(review): queries run on this connection are LLM-generated, and
    table functions such as read_parquet (used by seed_database) can read
    host files from inside a SELECT. Consider disabling external access
    once seeding is done. First confirm that app.py needs no file access
    on this connection.
    """
    conn = get_connection()
    try:
        seed_database(conn)
    except BaseException:
        conn.close()
        raise
    return conn