Spaces:
Sleeping
Sleeping
| """ | |
| SQLite Schema Executor — sandboxed DDL execution engine. | |
| Processes schema actions against an in-memory SQLite database, | |
| validates constraints, and runs test queries. Zero external deps. | |
| """ | |
| from __future__ import annotations | |
| import sqlite3 | |
| import time | |
| from typing import Optional | |
| from .models import ( | |
| ActionType, | |
| ColumnDef, | |
| SchemaAction, | |
| TableInfo, | |
| QueryResult, | |
| TestQuery, | |
| ) | |
class SchemaExecutor:
    """Manage an in-memory SQLite database for a single episode.

    Schema actions are translated to DDL and executed against a private
    ``:memory:`` connection. Every statement that was executed is appended to
    ``_action_log``. SQLite cannot add a foreign key to an existing table, so
    foreign keys added after creation are recorded in ``_tracked_fks`` and
    merged with SQLite's own metadata when the schema is introspected.

    NOTE(review): identifiers are sanitized, but column types and default
    values are interpolated into the DDL verbatim. ``sqlite3`` executes only
    one statement per ``execute()`` call, which limits the damage, but these
    two fields are still caller-controlled SQL fragments.
    """

    def __init__(self):
        self.conn = sqlite3.connect(":memory:")
        self.conn.execute("PRAGMA foreign_keys = ON")
        self.conn.row_factory = sqlite3.Row
        self._action_log: list[str] = []
        # Foreign keys declared through ADD_FOREIGN_KEY; each entry is a dict
        # with the keys "table", "column", "ref_table" and "ref_column".
        self._tracked_fks: list[dict] = []

    def close(self):
        """Close the database connection; calling it again is a no-op."""
        if self.conn is not None:
            self.conn.close()
            self.conn = None

    # ── Action Dispatch ───────────────────────────────────────────────────

    def execute_action(self, action: SchemaAction) -> tuple[bool, str]:
        """Execute a schema action and return ``(success, message)``.

        Exceptions raised by a handler are never propagated: SQLite errors
        are reported as ``"SQL error: ..."`` and anything else as
        ``"Error: ..."``.
        """
        handlers = {
            ActionType.CREATE_TABLE: self._create_table,
            ActionType.ADD_COLUMN: self._add_column,
            ActionType.DROP_TABLE: self._drop_table,
            ActionType.DROP_COLUMN: self._drop_column,
            ActionType.ADD_PRIMARY_KEY: self._add_primary_key,
            ActionType.ADD_FOREIGN_KEY: self._add_foreign_key,
            ActionType.ADD_UNIQUE: self._add_unique,
            ActionType.CREATE_INDEX: self._create_index,
            ActionType.DROP_INDEX: self._drop_index,
            ActionType.SUBMIT: self._submit,
        }
        try:
            handler = handlers.get(action.type)
            if handler is None:
                return False, f"Unknown action type: {action.type}"
            return handler(action)
        except sqlite3.Error as e:
            return False, f"SQL error: {e}"
        except Exception as e:
            return False, f"Error: {e}"

    # ── Action Handlers ───────────────────────────────────────────────────

    def _run_ddl(self, sql: str) -> None:
        """Execute one DDL statement, commit it, and record it in the log."""
        self.conn.execute(sql)
        self.conn.commit()
        self._action_log.append(sql)

    def _create_table(self, action: SchemaAction) -> tuple[bool, str]:
        """Create a table from ``action.table_name`` and ``action.columns``.

        A single primary-key column gets an inline ``PRIMARY KEY``. When two
        or more columns are flagged, a table-level ``PRIMARY KEY (a, b)``
        constraint is emitted instead, because SQLite rejects more than one
        inline primary key.
        """
        if not action.table_name:
            return False, "table_name is required"
        if not action.columns:
            return False, "columns list is required and must be non-empty"

        name = self._sanitize(action.table_name)
        pk_names = [self._sanitize(c.name) for c in action.columns if c.primary_key]
        inline_pk = len(pk_names) == 1

        col_defs = []
        for col in action.columns:
            col_sql = f"{self._sanitize(col.name)} {col.col_type}"
            if col.primary_key:
                if inline_pk:
                    col_sql += " PRIMARY KEY"
            elif not col.nullable:
                col_sql += " NOT NULL"
            if col.default_value is not None:
                col_sql += f" DEFAULT {col.default_value}"
            col_defs.append(col_sql)
        if len(pk_names) > 1:
            col_defs.append(f"PRIMARY KEY ({', '.join(pk_names)})")

        self._run_ddl(f"CREATE TABLE {name} ({', '.join(col_defs)})")
        return True, f"Created table '{action.table_name}' with {len(action.columns)} columns"

    def _add_column(self, action: SchemaAction) -> tuple[bool, str]:
        """Add a column to an existing table.

        SQLite refuses to add a ``NOT NULL`` column without a non-NULL
        default, so an empty-string default is supplied when the caller asks
        for ``nullable=False`` without giving one.
        """
        if not action.table_name or not action.column_name or not action.column_type:
            return False, "table_name, column_name, and column_type are required"

        table = self._sanitize(action.table_name)
        col = self._sanitize(action.column_name)
        sql = f"ALTER TABLE {table} ADD COLUMN {col} {action.column_type}"
        not_null = action.nullable is False
        if not_null:
            sql += " NOT NULL"
        # Emit exactly one DEFAULT clause: the caller's value wins over the
        # placeholder that only exists to satisfy the NOT NULL requirement.
        if action.default_value is not None:
            sql += f" DEFAULT {action.default_value}"
        elif not_null:
            sql += " DEFAULT ''"

        self._run_ddl(sql)
        return True, f"Added column '{action.column_name}' to '{action.table_name}'"

    def _drop_table(self, action: SchemaAction) -> tuple[bool, str]:
        """Drop a table (no error if it does not exist) and forget its tracked FKs."""
        if not action.table_name:
            return False, "table_name is required"

        table = self._sanitize(action.table_name)
        self._run_ddl(f"DROP TABLE IF EXISTS {table}")
        # A tracked FK belongs to its source table; keeping it would make it
        # reappear if a table with the same name is created later.
        dropped = table.lower()
        self._tracked_fks = [
            fk for fk in self._tracked_fks if fk["table"].lower() != dropped
        ]
        return True, f"Dropped table '{action.table_name}'"

    def _drop_column(self, action: SchemaAction) -> tuple[bool, str]:
        """Drop a column and forget tracked FKs that used it as their source."""
        if not action.table_name or not action.column_name:
            return False, "table_name and column_name are required"

        table = self._sanitize(action.table_name)
        col = self._sanitize(action.column_name)
        self._run_ddl(f"ALTER TABLE {table} DROP COLUMN {col}")
        gone = (table.lower(), col.lower())
        self._tracked_fks = [
            fk
            for fk in self._tracked_fks
            if (fk["table"].lower(), fk["column"].lower()) != gone
        ]
        return True, f"Dropped column '{action.column_name}' from '{action.table_name}'"

    def _add_primary_key(self, action: SchemaAction) -> tuple[bool, str]:
        """Always fail: SQLite has no ``ALTER TABLE ... ADD PRIMARY KEY``."""
        # The agent is expected to declare primary keys at table creation.
        return False, "SQLite does not support adding primary keys after table creation. Define primary_key=true in create_table columns."

    def _add_foreign_key(self, action: SchemaAction) -> tuple[bool, str]:
        """Record a foreign key after checking that both endpoints exist.

        SQLite cannot add a foreign key to an existing table, so nothing is
        executed against the database: the relationship is stored in
        ``_tracked_fks`` and written to the action log as a SQL comment.
        """
        if not all([action.table_name, action.fk_column, action.ref_table, action.ref_column]):
            return False, "table_name, fk_column, ref_table, and ref_column are required"

        table = action.table_name
        fk_col = action.fk_column
        ref_table = action.ref_table
        ref_col = action.ref_column

        # Verify the referenced table and column exist.
        existing_tables = self._get_table_names()
        if ref_table not in existing_tables:
            return False, f"Referenced table '{ref_table}' does not exist"
        if ref_col not in [c["name"] for c in self._get_columns(ref_table)]:
            return False, f"Referenced column '{ref_col}' does not exist in '{ref_table}'"

        # Verify the source table and column exist.
        if table not in existing_tables:
            return False, f"Table '{table}' does not exist"
        if fk_col not in [c["name"] for c in self._get_columns(table)]:
            return False, f"Column '{fk_col}' does not exist in '{table}'"

        fk = {
            "table": table,
            "column": fk_col,
            "ref_table": ref_table,
            "ref_column": ref_col,
        }
        if fk not in self._tracked_fks:
            self._tracked_fks.append(fk)
        self._action_log.append(
            f"-- FK: {table}.{fk_col} -> {ref_table}.{ref_col}"
        )
        return True, f"Added foreign key {table}.{fk_col} -> {ref_table}.{ref_col}"

    def _add_unique(self, action: SchemaAction) -> tuple[bool, str]:
        """Add a uniqueness constraint by creating a ``uq_<table>_<cols>`` unique index."""
        if not action.table_name or not action.unique_columns:
            return False, "table_name and unique_columns are required"

        table = self._sanitize(action.table_name)
        cols = ", ".join(self._sanitize(c) for c in action.unique_columns)
        idx_name = f"uq_{action.table_name}_{'_'.join(action.unique_columns)}"
        self._run_ddl(
            f"CREATE UNIQUE INDEX {self._sanitize(idx_name)} ON {table} ({cols})"
        )
        return True, f"Added unique constraint on {action.table_name}({', '.join(action.unique_columns)})"

    def _create_index(self, action: SchemaAction) -> tuple[bool, str]:
        """Create an index; the name defaults to ``idx_<table>_<cols>``."""
        if not action.table_name or not action.index_columns:
            return False, "table_name and index_columns are required"

        table = self._sanitize(action.table_name)
        cols = ", ".join(self._sanitize(c) for c in action.index_columns)
        idx_name = action.index_name or f"idx_{action.table_name}_{'_'.join(action.index_columns)}"
        self._run_ddl(f"CREATE INDEX {self._sanitize(idx_name)} ON {table} ({cols})")
        return True, f"Created index '{idx_name}' on {action.table_name}({', '.join(action.index_columns)})"

    def _drop_index(self, action: SchemaAction) -> tuple[bool, str]:
        """Drop an index by name (no error if it does not exist)."""
        if not action.index_name:
            return False, "index_name is required"

        self._run_ddl(f"DROP INDEX IF EXISTS {self._sanitize(action.index_name)}")
        return True, f"Dropped index '{action.index_name}'"

    def _submit(self, action: SchemaAction) -> tuple[bool, str]:
        """Acknowledge a submit action; the database is not touched."""
        return True, "Schema submitted for evaluation"

    # ── Schema Introspection ──────────────────────────────────────────────

    def _get_table_names(self) -> list[str]:
        """Return the names of all tables, excluding SQLite's internal ``sqlite_*`` ones."""
        cursor = self.conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
        )
        return [row[0] for row in cursor.fetchall()]

    def _get_columns(self, table: str) -> list[dict]:
        """Return one dict per column of *table*, built from ``PRAGMA table_info``."""
        cursor = self.conn.execute(f"PRAGMA table_info({self._sanitize(table)})")
        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        return [
            {
                "name": row[1],
                "type": row[2],
                "nullable": not bool(row[3]),
                "default": row[4],
                "pk": bool(row[5]),
            }
            for row in cursor.fetchall()
        ]

    def _get_indexes(self, table: str) -> list[str]:
        """Return the names of the indexes that ``PRAGMA index_list`` reports for *table*."""
        cursor = self.conn.execute(f"PRAGMA index_list({self._sanitize(table)})")
        return [row[1] for row in cursor.fetchall()]

    def _get_foreign_keys(self, table: str) -> list[dict]:
        """Return the foreign keys of *table*: SQLite's own plus tracked ones."""
        try:
            rows = self.conn.execute(
                f"PRAGMA foreign_key_list({self._sanitize(table)})"
            ).fetchall()
        except (sqlite3.Error, ValueError):
            # Best effort: fall back to the tracked foreign keys only.
            rows = []

        # foreign_key_list rows are (id, seq, table, from, to, ...).
        fks = [
            {
                "table": table,
                "column": row[3],
                "ref_table": row[2],
                "ref_column": row[4],
            }
            for row in rows
        ]
        for fk in self._tracked_fks:
            if fk["table"] == table and fk not in fks:
                fks.append(fk)
        return fks

    def get_schema_info(self) -> list[TableInfo]:
        """Get full schema information for observation."""
        return [
            TableInfo(
                name=table_name,
                columns=self._get_columns(table_name),
                indexes=self._get_indexes(table_name),
                foreign_keys=self._get_foreign_keys(table_name),
            )
            for table_name in self._get_table_names()
        ]

    def get_all_foreign_keys(self) -> list[dict]:
        """Get all foreign keys across all tables."""
        fks = []
        for table_name in self._get_table_names():
            fks.extend(self._get_foreign_keys(table_name))
        return fks

    # ── Query Execution ───────────────────────────────────────────────────

    def run_test_queries(self, queries: list[TestQuery]) -> list[QueryResult]:
        """Run all test queries against the current schema."""
        return [self._run_single_query(q) for q in queries]

    def _run_single_query(self, query: TestQuery) -> QueryResult:
        """Run one test query and report whether it met its expectation.

        A query that raises a SQLite error passes only if it was expected to
        fail. A query that runs passes if it was expected to succeed and its
        result set exposes every name in ``query.expected_columns``.
        """
        try:
            start = time.perf_counter()
            cursor = self.conn.execute(query.sql)
            cursor.fetchall()
            elapsed = (time.perf_counter() - start) * 1000  # ms
        except sqlite3.Error as e:
            # The SQLite message is passed through unchanged.
            return QueryResult(
                query_id=query.id,
                description=query.description,
                passed=not query.should_succeed,
                error=str(e),
            )

        # cursor.description is available even when no rows come back, so the
        # column check also works on an empty table. It is None for
        # statements that produce no result set.
        if query.expected_columns and cursor.description is not None:
            result_cols = {desc[0] for desc in cursor.description}
            missing = sorted(set(query.expected_columns) - result_cols)
            if missing:
                return QueryResult(
                    query_id=query.id,
                    description=query.description,
                    passed=False,
                    error=f"Missing expected columns: {missing}",
                    execution_time_ms=elapsed,
                )

        return QueryResult(
            query_id=query.id,
            description=query.description,
            passed=query.should_succeed,
            execution_time_ms=elapsed,
        )

    # ── Utility ───────────────────────────────────────────────────────────

    @staticmethod
    def _sanitize(identifier: str) -> str:
        """Strip everything but alphanumerics and underscores from an identifier.

        Raises:
            ValueError: if nothing is left after stripping.
        """
        clean = "".join(c for c in identifier if c.isalnum() or c == "_")
        if not clean:
            raise ValueError(f"Invalid identifier: {identifier}")
        return clean