Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import re | |
| import threading | |
| from typing import Any | |
| import duckdb | |
| BLOCKED = re.compile( | |
| r"\b(DROP|TRUNCATE|DELETE|INSERT|UPDATE|CREATE|ALTER|ATTACH|COPY|EXPORT|IMPORT)\b", | |
| re.IGNORECASE, | |
| ) | |
| MAX_ROWS = 100 | |
| _lock = threading.Lock() | |
| class SQLEngine: | |
| def __init__(self) -> None: | |
| self.conn = duckdb.connect(":memory:") | |
| def load_tables(self, tables: dict[str, Any]) -> None: | |
| with _lock: | |
| for name, df in tables.items(): | |
| self.conn.register(name, df) | |
| self.conn.execute(f"CREATE OR REPLACE TABLE {name} AS SELECT * FROM {name}") | |
| self.conn.unregister(name) | |
| def execute(self, sql: str) -> list[dict] | str: | |
| s = (sql or "").strip() | |
| if BLOCKED.search(s): | |
| return "ERROR: Destructive SQL (DROP/DELETE/UPDATE/etc.) is not permitted." | |
| with _lock: | |
| try: | |
| rel = self.conn.execute(s) | |
| cols = [d[0] for d in rel.description] | |
| rows = rel.fetchmany(MAX_ROWS) | |
| return [dict(zip(cols, row)) for row in rows] | |
| except Exception as e: | |
| return f"ERROR: {e}" | |
| def run_fix_sql(self, sql: str, gold_clean: dict[str, Any] | None = None) -> float: | |
| s = (sql or "").strip() | |
| # Only allow UPDATE during fix phase. | |
| if re.search(r"\b(DROP|TRUNCATE|DELETE|INSERT|CREATE|ALTER|ATTACH|COPY|EXPORT|IMPORT)\b", s, re.IGNORECASE): | |
| return 0.0 | |
| if not re.search(r"\bUPDATE\b", s, re.IGNORECASE): | |
| return 0.0 | |
| with _lock: | |
| try: | |
| self.conn.execute(s) | |
| # Lightweight deterministic scoring placeholder. | |
| return 0.5 | |
| except Exception: | |
| return 0.0 | |
| def get_table_schemas(self, tables: list[str]) -> dict[str, dict[str, str]]: | |
| out: dict[str, dict[str, str]] = {} | |
| with _lock: | |
| for t in tables: | |
| rows = self.conn.execute(f"PRAGMA table_info('{t}')").fetchall() | |
| out[t] = {r[1]: str(r[2]) for r in rows} | |
| return out | |
| def get_row_counts(self, tables: list[str]) -> dict[str, int]: | |
| out: dict[str, int] = {} | |
| with _lock: | |
| for t in tables: | |
| out[t] = int(self.conn.execute(f"SELECT COUNT(*) FROM {t}").fetchone()[0]) | |
| return out | |
| def close(self) -> None: | |
| self.conn.close() | |