Spaces:
Running
Running
File size: 3,947 Bytes
30cf758 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 | """
SQLite in-memory database management.
Creates fresh DB instances per episode with deterministic seed data.
"""
import sqlite3
import time
from typing import Dict, Any, List
class EpisodeDatabase:
    """
    Manages a single SQLite in-memory database for one episode.
    Seeded with deterministic data per task.

    The connection is created in ``__init__``, populated from the supplied
    schema and seed scripts, and then queried through ``execute_query``.
    """

    # Wall-clock budget, in seconds, for a single execute_query() call.
    QUERY_TIMEOUT_SECONDS = 5.0

    # Number of SQLite VM instructions between two timeout checks.
    _PROGRESS_CHECK_INTERVAL = 1000

    # Keywords rejected up front with a "BLOCKED" message. The order is
    # significant: the first match is the one reported to the caller.
    _BLOCKED_KEYWORDS = (
        "DROP", "DELETE", "UPDATE", "INSERT", "CREATE", "ALTER",
        "TRUNCATE", "REPLACE", "ATTACH", "DETACH",
    )

    def __init__(self, task_id: str, schema_sql: str, seed_data_sql: str):
        """
        Open a fresh in-memory database and populate it.

        Args:
            task_id: Identifier stored on the instance as ``self.task_id``.
            schema_sql: SQL script (``;``-separated statements) creating the schema.
            seed_data_sql: SQL script (``;``-separated statements) inserting seed rows.

        Raises:
            sqlite3.Error: If any schema or seed statement fails.
        """
        self.task_id = task_id
        # check_same_thread=False lets the connection be used from a thread
        # other than the one that created it.
        self.conn = sqlite3.connect(":memory:", check_same_thread=False)
        # Rows come back as sqlite3.Row so they can be converted to dicts.
        self.conn.row_factory = sqlite3.Row
        self.conn.execute("PRAGMA foreign_keys = ON")
        self._setup(schema_sql, seed_data_sql)

    @staticmethod
    def _split_statements(script: str) -> List[str]:
        """
        Split an SQL script into individual statements.

        A naive ``split(";")`` breaks statements that contain a semicolon
        inside a string literal or a trigger body, so fragments are glued
        back together until ``sqlite3.complete_statement`` accepts them.
        Statements are returned stripped and without the trailing ``;``.
        """
        statements: List[str] = []
        pending = ""
        for fragment in script.split(";"):
            candidate = pending + fragment + ";"
            if sqlite3.complete_statement(candidate):
                stmt = candidate[:-1].strip()
                if stmt:
                    statements.append(stmt)
                pending = ""
            else:
                pending = candidate
        # Whatever is left never became a complete statement. Hand it over
        # anyway (minus the ";" appended above) so SQLite reports the problem
        # instead of the text being dropped silently.
        leftover = pending[:-1].strip()
        if leftover:
            statements.append(leftover)
        return statements

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier."""
        return '"' + str(name).replace('"', '""') + '"'

    @staticmethod
    def _failure(message: str, elapsed_ms: float) -> Dict[str, Any]:
        """Build the result dict returned for a rejected or failed query."""
        return {
            "success": False,
            "rows": None,
            "row_count": None,
            "error_message": message,
            "execution_time_ms": elapsed_ms
        }

    def _setup(self, schema_sql: str, seed_data_sql: str) -> None:
        """Create schema and insert seed data, then commit."""
        cursor = self.conn.cursor()
        for script in (schema_sql, seed_data_sql):
            for stmt in self._split_statements(script):
                cursor.execute(stmt)
        self.conn.commit()

    def _restore_connection(self) -> None:
        """Undo the per-query guards installed by ``execute_query``."""
        try:
            self.conn.set_progress_handler(None, 0)
            self.conn.execute("PRAGMA query_only = OFF")
        except sqlite3.Error:
            # The connection is unusable (for example already closed), so
            # there is nothing left to restore; the failure of the query
            # itself has already been reported to the caller.
            pass

    def execute_query(self, query: str) -> Dict[str, Any]:
        """
        Execute a read-only SQL query safely.
        Returns rows or error. Enforces SELECT-only.
        Execution timeout: ``QUERY_TIMEOUT_SECONDS`` (5 seconds by default).

        Args:
            query: A single SQL statement.

        Returns:
            A dict with the keys ``success``, ``rows`` (list of dicts, or
            ``None`` on failure), ``row_count``, ``error_message`` and
            ``execution_time_ms``. SQLite errors are reported through the
            dict rather than raised.
        """
        # Collapse every whitespace run to one space so that newlines or tabs
        # around a keyword cannot hide it from the check below.
        normalized = " ".join(query.upper().split())

        # Block dangerous operations
        for kw in self._BLOCKED_KEYWORDS:
            if normalized.startswith(kw) or f" {kw} " in normalized:
                return self._failure(
                    f"BLOCKED: Only SELECT queries are allowed. '{kw}' is not permitted.",
                    0.0,
                )

        start = time.perf_counter()
        deadline = start + self.QUERY_TIMEOUT_SECONDS
        timed_out = False

        def _abort_if_overdue() -> bool:
            # Called periodically by SQLite while the statement runs; a true
            # return value interrupts the statement.
            nonlocal timed_out
            timed_out = time.perf_counter() > deadline
            return timed_out

        try:
            # The keyword check is only a heuristic; query_only makes SQLite
            # itself refuse any statement that would write.
            self.conn.execute("PRAGMA query_only = ON")
            self.conn.set_progress_handler(
                _abort_if_overdue, self._PROGRESS_CHECK_INTERVAL
            )
            cursor = self.conn.cursor()
            cursor.execute(query)
            # Convert Row objects to dicts
            result_rows = [dict(row) for row in cursor.fetchall()]
            elapsed = (time.perf_counter() - start) * 1000
            return {
                "success": True,
                "rows": result_rows,
                "row_count": len(result_rows),
                "error_message": None,
                "execution_time_ms": round(elapsed, 2)
            }
        except sqlite3.Error as e:
            elapsed = (time.perf_counter() - start) * 1000
            if timed_out:
                message = (
                    f"Query exceeded the {self.QUERY_TIMEOUT_SECONDS:g} second "
                    "execution timeout."
                )
            else:
                message = str(e)
            return self._failure(message, round(elapsed, 2))
        finally:
            self._restore_connection()

    def get_schema(self) -> Dict[str, List[Dict[str, str]]]:
        """
        Return schema info: tables and their columns.

        Returns:
            A dict mapping each table name listed in ``sqlite_master`` to a
            list of column dicts with the keys ``name``, ``type``,
            ``nullable`` and ``primary_key`` (the last two are "YES"/"NO").
        """
        cursor = self.conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        tables = [row[0] for row in cursor.fetchall()]

        schema: Dict[str, List[Dict[str, str]]] = {}
        for table in tables:
            # Quote the name so reserved words and names with spaces work.
            cursor.execute(f"PRAGMA table_info({self._quote_identifier(table)})")
            # table_info columns: cid, name, type, notnull, dflt_value, pk
            schema[table] = [
                {
                    "name": col[1],
                    "type": col[2],
                    "nullable": "YES" if col[3] == 0 else "NO",
                    "primary_key": "YES" if col[5] > 0 else "NO"
                }
                for col in cursor.fetchall()
            ]
        return schema

    def get_sample_rows(self, table_name: str, limit: int = 3) -> List[Dict[str, Any]]:
        """
        Get sample rows from a table.

        Args:
            table_name: Bare table name; it is quoted before being used.
            limit: Maximum number of rows to return.

        Returns:
            A list of row dicts, or an empty list when the query fails or
            ``limit`` is not an integer value.
        """
        try:
            row_limit = int(limit)
        except (TypeError, ValueError):
            return []
        result = self.execute_query(
            f"SELECT * FROM {self._quote_identifier(table_name)} LIMIT {row_limit}"
        )
        return result["rows"] or []

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()
|