"""
SQLite database management for SQLEnv.

Handles database initialization, query execution, and schema introspection.
All operations use an in-memory SQLite database that is re-created on each
environment reset, ensuring deterministic, isolated episodes.
"""

import os
import sqlite3
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple

DATA_DIR = Path(__file__).resolve().parent.parent / "data"
SCHEMA_PATH = DATA_DIR / "schema.sql"
SEED_PATH = DATA_DIR / "seed.sql"

# First keywords rejected up front by Database.execute_query() with a
# descriptive error message. This is only a fast, friendly pre-check; the
# SQLite authorizer (see Database._install_read_only_guard) is what actually
# enforces read-only access.
_BLOCKED_FIRST_WORDS = frozenset(
    {"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE", "REPLACE"}
)

# SQLite authorizer action codes that a read-only query may trigger while it
# is being prepared. SQLITE_RECURSIVE (33 in sqlite3.h) is needed for
# WITH RECURSIVE; fall back to the raw code if this Python build does not
# export the constant.
_READ_ONLY_ACTIONS = frozenset(
    {
        sqlite3.SQLITE_SELECT,
        sqlite3.SQLITE_READ,
        sqlite3.SQLITE_FUNCTION,
        getattr(sqlite3, "SQLITE_RECURSIVE", 33),
    }
)

# PRAGMAs that only report schema metadata and therefore remain usable through
# Database.execute_query(). Any other PRAGMA (including setters such as
# `PRAGMA foreign_keys = OFF`) is rejected by the authorizer.
_READ_ONLY_PRAGMAS = frozenset(
    {
        "table_info",
        "table_xinfo",
        "table_list",
        "index_list",
        "index_info",
        "index_xinfo",
        "foreign_key_list",
    }
)

# Number of SQLite virtual-machine instructions between deadline checks while
# a query runs under a timeout.
_PROGRESS_CHECK_INTERVAL = 1000


@dataclass
class QueryResult:
    """Result of executing a SQL query."""

    # Column names from the cursor description; empty on error.
    columns: List[str] = field(default_factory=list)
    # Fetched rows as tuples, one value per column.
    rows: List[Tuple] = field(default_factory=list)
    # Error message, or None if the query succeeded.
    error: Optional[str] = None
    # Number of rows fetched (len(rows) when built by Database.execute_query).
    row_count: int = 0

    @property
    def success(self) -> bool:
        """True when no error was recorded."""
        return self.error is None

    def to_display_string(self, max_rows: int = 20) -> str:
        """
        Format the result as a readable, pipe-separated table string.

        Args:
            max_rows: Maximum number of rows rendered; the remainder is
                summarised as "... (N more rows)".

        Returns:
            "ERROR: <message>" on error, "(no results)" when there are no
            columns, otherwise a header, a separator line, the rendered rows
            and a trailing "(N rows)" footer.
        """
        if self.error:
            return f"ERROR: {self.error}"
        if not self.columns:
            return "(no results)"

        # Each column is as wide as its widest rendered cell (header included),
        # considering only the rows that will actually be displayed.
        col_widths = [len(str(c)) for c in self.columns]
        display_rows = self.rows[:max_rows]
        for row in display_rows:
            for i, val in enumerate(row):
                col_widths[i] = max(col_widths[i], len(str(val)))

        header = " | ".join(
            str(c).ljust(col_widths[i]) for i, c in enumerate(self.columns)
        )
        separator = "-+-".join("-" * w for w in col_widths)
        lines = [header, separator]
        for row in display_rows:
            lines.append(
                " | ".join(str(val).ljust(col_widths[i]) for i, val in enumerate(row))
            )

        if len(self.rows) > max_rows:
            lines.append(f"... ({len(self.rows) - max_rows} more rows)")
        lines.append(f"\n({self.row_count} row{'s' if self.row_count != 1 else ''})")
        return "\n".join(lines)


class Database:
    """
    Manages an in-memory SQLite database for one episode.

    Each call to `initialize()` creates a fresh database with the schema and
    seed data, ensuring deterministic state across episodes. Queries submitted
    through `execute_query()` are restricted to read-only statements by a
    SQLite authorizer, so an episode cannot modify its own data.
    """

    def __init__(self):
        self._conn: Optional[sqlite3.Connection] = None
        # True only while execute_query() prepares and runs caller SQL; the
        # authorizer installed by initialize() enforces read-only access only
        # then, so internal maintenance SQL is unaffected.
        self._read_only_guard: bool = False
        # Set by the authorizer when it rejects an action during the current
        # execute_query() call; used to report a clear error message.
        self._denied: bool = False

    def initialize(self) -> None:
        """Create a fresh in-memory database with schema and seed data."""
        self.close()
        self._conn = sqlite3.connect(":memory:")
        self._conn.execute("PRAGMA foreign_keys = ON")

        # Read as UTF-8 explicitly so loading does not depend on the locale.
        schema_sql = SCHEMA_PATH.read_text(encoding="utf-8")
        self._conn.executescript(schema_sql)
        seed_sql = SEED_PATH.read_text(encoding="utf-8")
        self._conn.executescript(seed_sql)
        self._conn.commit()

        # Installed only after seeding, so schema/seed scripts run unrestricted.
        self._install_read_only_guard()

    def _install_read_only_guard(self) -> None:
        """Register the authorizer that limits execute_query() to reads."""

        def authorize(action, arg1, arg2, db_name, trigger_or_view):
            # SQLite consults the authorizer while preparing each statement;
            # returning SQLITE_DENY makes the prepare fail with an error.
            if not self._read_only_guard:
                return sqlite3.SQLITE_OK
            if action in _READ_ONLY_ACTIONS:
                return sqlite3.SQLITE_OK
            if (
                action == sqlite3.SQLITE_PRAGMA
                and (arg1 or "").lower() in _READ_ONLY_PRAGMAS
            ):
                return sqlite3.SQLITE_OK
            self._denied = True
            return sqlite3.SQLITE_DENY

        self._conn.set_authorizer(authorize)

    def execute_query(
        self, sql: str, timeout_seconds: Optional[float] = 5.0
    ) -> QueryResult:
        """
        Execute a read-only SQL query and return the result.

        Statements whose first keyword is a modification (INSERT, UPDATE,
        DELETE, DROP, ALTER, CREATE, TRUNCATE, REPLACE) are rejected up front.
        Any other statement that would do more than read, such as
        `WITH ... DELETE`, a write hidden behind a leading comment, ATTACH or
        a PRAGMA setter, is refused by the SQLite authorizer before it runs.
        Schema-introspection PRAGMAs such as `PRAGMA table_info(...)` remain
        available.

        Args:
            sql: The SQL query string to execute. It must be a single
                statement; trailing semicolons are stripped.
            timeout_seconds: Wall-clock budget, in seconds, for executing the
                query and fetching its rows. `None` or a non-positive value
                disables the limit.

        Returns:
            QueryResult with columns and rows on success. On failure, `error`
            is set instead; SQLite errors are reported this way rather than
            raised.
        """
        if self._conn is None:
            return QueryResult(error="Database not initialized. Call reset() first.")

        stripped = sql.strip().rstrip(";").strip()
        if not stripped:
            return QueryResult(error="Empty query.")

        first_word = stripped.split()[0].upper()
        if first_word in _BLOCKED_FIRST_WORDS:
            return QueryResult(
                error=f"Only SELECT queries are allowed. Got: {first_word}"
            )

        deadline: Optional[float] = None
        if timeout_seconds is not None and timeout_seconds > 0:
            deadline = time.monotonic() + timeout_seconds
        timed_out = False

        def check_deadline() -> int:
            # A non-zero return value makes SQLite abort with "interrupted".
            nonlocal timed_out
            if time.monotonic() > deadline:
                timed_out = True
                return 1
            return 0

        self._denied = False
        self._read_only_guard = True
        if deadline is not None:
            self._conn.set_progress_handler(check_deadline, _PROGRESS_CHECK_INTERVAL)
        try:
            cursor = self._conn.execute(stripped)
            if cursor.description is None:
                return QueryResult(error="Query did not return results.")
            columns = [desc[0] for desc in cursor.description]
            # Rows are stepped here too, so the deadline also bounds fetching.
            rows = cursor.fetchall()
        except (sqlite3.Error, sqlite3.Warning) as e:
            # sqlite3.Warning covers "one statement at a time" on older Pythons.
            if self._denied:
                return QueryResult(error="Only read-only SELECT queries are allowed.")
            if timed_out:
                return QueryResult(
                    error=f"Query timed out after {timeout_seconds} seconds."
                )
            return QueryResult(error=str(e))
        finally:
            self._read_only_guard = False
            if deadline is not None:
                self._conn.set_progress_handler(None, 0)

        return QueryResult(columns=columns, rows=rows, row_count=len(rows))

    def get_schema_description(self) -> str:
        """
        Return a human-readable description of the database schema including
        table structures and sample data.

        For each known table this lists its columns (with PRIMARY KEY / NOT
        NULL markers), its foreign keys, up to three sample rows and the total
        row count. These are followed by a fixed summary of relationships and
        data conventions.
        """
        if self._conn is None:
            return "Database not initialized."

        tables = [
            ("customers", "Customer information"),
            ("products", "Product catalog"),
            ("orders", "Customer orders"),
            ("order_items", "Items within each order"),
            ("reviews", "Product reviews by customers"),
        ]

        schema_text = ["=== DATABASE SCHEMA ===\n"]
        for table_name, description in tables:
            schema_text.append(f"TABLE: {table_name} -- {description}")

            # col: (cid, name, type, notnull, default_value, pk)
            for col in self._conn.execute(f"PRAGMA table_info({table_name})").fetchall():
                col_name = col[1]
                col_type = col[2]
                is_pk = " PRIMARY KEY" if col[5] else ""
                is_nn = " NOT NULL" if col[3] else ""
                schema_text.append(f" {col_name} {col_type}{is_pk}{is_nn}")

            # fk: (id, seq, table, from, to, on_update, on_delete, match)
            for fk in self._conn.execute(
                f"PRAGMA foreign_key_list({table_name})"
            ).fetchall():
                schema_text.append(f" FOREIGN KEY ({fk[3]}) REFERENCES {fk[2]}({fk[4]})")

            # Sample data (first 3 rows).
            result = self.execute_query(f"SELECT * FROM {table_name} LIMIT 3")
            if result.success and result.rows:
                schema_text.append(f" Sample data ({result.row_count} rows shown):")
                for row in result.rows:
                    schema_text.append(f" {row}")

            count_result = self.execute_query(f"SELECT COUNT(*) FROM {table_name}")
            if count_result.success and count_result.rows:
                total = count_result.rows[0][0]
                schema_text.append(f" Total rows: {total}")
            schema_text.append("")

        schema_text.append("=== RELATIONSHIPS ===")
        schema_text.append("orders.customer_id -> customers.id")
        schema_text.append("order_items.order_id -> orders.id")
        schema_text.append("order_items.product_id -> products.id")
        schema_text.append("reviews.product_id -> products.id")
        schema_text.append("reviews.customer_id -> customers.id")
        schema_text.append("")
        schema_text.append("=== NOTES ===")
        schema_text.append("- All dates are in ISO format (YYYY-MM-DD)")
        schema_text.append("- Prices are in INR (Indian Rupees)")
        schema_text.append("- Order status: pending, shipped, delivered, cancelled")
        schema_text.append("- Product categories: Electronics, Clothing, Books, Home")
        schema_text.append("- Ratings are integers from 1 to 5")
        return "\n".join(schema_text)

    def close(self) -> None:
        """Close the database connection."""
        if self._conn is not None:
            self._conn.close()
            self._conn = None