Spaces:
Sleeping
Sleeping
File size: 7,718 Bytes
"""
SQLite database management for SQLEnv.
Handles database initialization, query execution, and schema introspection.
All operations use an in-memory SQLite database that is re-created on each
environment reset, ensuring deterministic, isolated episodes.
"""
import os
import sqlite3
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple
# Directory named "data" located two levels above this file's resolved path.
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
# SQL script executed first by Database.initialize().
SCHEMA_PATH = DATA_DIR / "schema.sql"
# SQL script executed by Database.initialize() right after the schema script.
SEED_PATH = DATA_DIR / "seed.sql"
@dataclass
class QueryResult:
    """Result of executing a SQL query.

    A result is either successful (``error`` is ``None``) or failed
    (``error`` holds the message); ``success`` and ``to_display_string``
    both use that same ``is None`` test.
    """

    # Column names, in result order.
    columns: List[str] = field(default_factory=list)
    # Result rows; each tuple is rendered positionally against `columns`.
    rows: List[Tuple] = field(default_factory=list)
    # Error message, or None when the query succeeded.
    error: Optional[str] = None
    # Row count printed in the footer; supplied by whoever builds the result.
    row_count: int = 0

    @property
    def success(self) -> bool:
        """Return True when no error was recorded."""
        return self.error is None

    def to_display_string(self, max_rows: int = 20) -> str:
        """Format the result as a readable, column-aligned table string.

        Args:
            max_rows: Maximum number of rows to render. Negative values are
                treated as 0 so they cannot slice rows off the end.

        Returns:
            ``"ERROR: <message>"`` for a failed result, ``"(no results)"``
            when there are no columns, otherwise a table made of a header,
            a separator, the rendered rows, an optional
            ``"... (N more rows)"`` line and a ``"(N rows)"`` footer.
        """
        # Same test as `success`, so an empty error message is still an error.
        if self.error is not None:
            return f"ERROR: {self.error}"
        if not self.columns:
            return "(no results)"

        # A negative limit would otherwise act as a slice-from-the-end.
        limit = max(max_rows, 0)
        display_rows = self.rows[:limit]

        # Each column is as wide as its widest rendered cell (header included).
        col_widths = [len(str(c)) for c in self.columns]
        for row in display_rows:
            for i, val in enumerate(row):
                col_widths[i] = max(col_widths[i], len(str(val)))

        header = " | ".join(
            str(c).ljust(width) for c, width in zip(self.columns, col_widths)
        )
        separator = "-+-".join("-" * width for width in col_widths)
        lines = [header, separator]
        lines.extend(
            " | ".join(
                str(val).ljust(width) for val, width in zip(row, col_widths)
            )
            for row in display_rows
        )

        hidden = len(self.rows) - len(display_rows)
        if hidden > 0:
            lines.append(f"... ({hidden} more rows)")
        lines.append(f"\n({self.row_count} row{'s' if self.row_count != 1 else ''})")
        return "\n".join(lines)
class Database:
    """
    Manages an in-memory SQLite database for one episode.

    Each call to `initialize()` creates a fresh database with the schema
    and seed data, ensuring deterministic state across episodes.
    """

    # Leading keywords rejected before the statement reaches SQLite. This
    # only provides a clear error message; writes hidden behind other
    # syntax (e.g. ``WITH ... DELETE``) are stopped by ``PRAGMA query_only``
    # in `execute_query`.
    _BLOCKED_KEYWORDS = frozenset(
        {
            "INSERT",
            "UPDATE",
            "DELETE",
            "DROP",
            "ALTER",
            "CREATE",
            "TRUNCATE",
            "REPLACE",
            "ATTACH",
            "DETACH",
            "VACUUM",
        }
    )
    # Number of SQLite VM instructions between two deadline checks.
    _PROGRESS_INTERVAL = 1000

    def __init__(self) -> None:
        self._conn: Optional[sqlite3.Connection] = None

    def initialize(self) -> None:
        """Create a fresh in-memory database with schema and seed data.

        Any previous connection is closed first. The new connection is only
        stored once both scripts have run, so a failure never leaves a
        half-built database behind.

        Raises:
            OSError: If the schema or seed file cannot be read.
            sqlite3.Error: If either script fails to execute.
        """
        self.close()
        conn = sqlite3.connect(":memory:")
        try:
            conn.execute("PRAGMA foreign_keys = ON")
            # Explicit encoding: the platform default is not always UTF-8.
            conn.executescript(SCHEMA_PATH.read_text(encoding="utf-8"))
            conn.executescript(SEED_PATH.read_text(encoding="utf-8"))
            conn.commit()
        except Exception:
            conn.close()
            raise
        self._conn = conn

    @staticmethod
    def _leading_keyword(sql: str) -> str:
        """Return the first keyword of *sql*, upper-cased.

        Leading whitespace, ``--`` line comments and ``/* */`` block
        comments are skipped. Returns ``""`` when the statement does not
        start with a letter or underscore.
        """
        pos, end = 0, len(sql)
        while pos < end:
            if sql[pos].isspace():
                pos += 1
            elif sql.startswith("--", pos):
                newline = sql.find("\n", pos)
                pos = end if newline == -1 else newline + 1
            elif sql.startswith("/*", pos):
                close = sql.find("*/", pos + 2)
                pos = end if close == -1 else close + 2
            else:
                break
        start = pos
        while pos < end and (sql[pos].isalpha() or sql[pos] == "_"):
            pos += 1
        return sql[start:pos].upper()

    def execute_query(self, sql: str, timeout_seconds: float = 5.0) -> QueryResult:
        """
        Execute a SQL query and return the result.

        Only read statements are allowed. Statements starting with a
        modification keyword (INSERT, UPDATE, DELETE, DROP, ALTER, CREATE,
        ...) are rejected up front, and the statement runs with
        ``PRAGMA query_only`` enabled so writes hidden behind other syntax
        fail inside SQLite instead of mutating the database.

        Args:
            sql: The SQL query string to execute.
            timeout_seconds: Max execution time. Enforced through a SQLite
                progress handler; ``None`` or a non-positive value disables
                the limit.

        Returns:
            QueryResult with columns, rows, and potential error. Database
            errors are reported through ``QueryResult.error``.
        """
        if self._conn is None:
            return QueryResult(error="Database not initialized. Call reset() first.")

        # Strip and normalize
        stripped = sql.strip().rstrip(";").strip()
        if not stripped:
            return QueryResult(error="Empty query.")

        # Block modification statements
        keyword = self._leading_keyword(stripped)
        if keyword in self._BLOCKED_KEYWORDS:
            return QueryResult(
                error=f"Only SELECT queries are allowed. Got: {keyword}"
            )

        conn = self._conn
        expired: List[bool] = []
        enforce_timeout = timeout_seconds is not None and timeout_seconds > 0
        if enforce_timeout:
            deadline = time.monotonic() + timeout_seconds

            def _deadline_passed() -> int:
                # A non-zero return value makes SQLite abort the statement.
                if time.monotonic() > deadline:
                    expired.append(True)
                    return 1
                return 0

            conn.set_progress_handler(_deadline_passed, self._PROGRESS_INTERVAL)

        try:
            # Re-asserted on every call, so an earlier user statement cannot
            # have switched the guard off for this one.
            conn.execute("PRAGMA query_only = ON")
            cursor = conn.execute(stripped)
            if cursor.description is None:
                return QueryResult(error="Query did not return results.")
            columns = [desc[0] for desc in cursor.description]
            rows = cursor.fetchall()
        except (sqlite3.Error, sqlite3.Warning) as exc:
            # sqlite3.Warning is not an sqlite3.Error subclass, so it is
            # listed explicitly.
            if expired:
                return QueryResult(
                    error=f"Query timed out after {timeout_seconds} seconds."
                )
            return QueryResult(error=str(exc))
        finally:
            if enforce_timeout:
                conn.set_progress_handler(None, self._PROGRESS_INTERVAL)
            # The guard is scoped to the user statement only.
            conn.execute("PRAGMA query_only = OFF")

        return QueryResult(
            columns=columns,
            rows=rows,
            row_count=len(rows),
        )

    def get_schema_description(self) -> str:
        """
        Return a human-readable description of the database schema
        including table structures and sample data.

        Returns:
            The multi-line description, or ``"Database not initialized."``
            when no connection is open.
        """
        if self._conn is None:
            return "Database not initialized."

        tables = [
            ("customers", "Customer information"),
            ("products", "Product catalog"),
            ("orders", "Customer orders"),
            ("order_items", "Items within each order"),
            ("reviews", "Product reviews by customers"),
        ]

        schema_text = ["=== DATABASE SCHEMA ===\n"]
        for table_name, description in tables:
            schema_text.append(f"TABLE: {table_name} -- {description}")

            # Get column info
            cursor = self._conn.execute(f"PRAGMA table_info({table_name})")
            for col in cursor.fetchall():
                # col: (cid, name, type, notnull, default_value, pk)
                col_name = col[1]
                col_type = col[2]
                is_pk = " PRIMARY KEY" if col[5] else ""
                is_nn = " NOT NULL" if col[3] else ""
                schema_text.append(f" {col_name} {col_type}{is_pk}{is_nn}")

            # Get foreign keys
            cursor = self._conn.execute(f"PRAGMA foreign_key_list({table_name})")
            for fk in cursor.fetchall():
                # fk[2] is the referenced table, fk[3] the local column and
                # fk[4] the referenced column.
                schema_text.append(
                    f" FOREIGN KEY ({fk[3]}) REFERENCES {fk[2]}({fk[4]})"
                )

            # Show sample data (first 3 rows)
            result = self.execute_query(f"SELECT * FROM {table_name} LIMIT 3")
            if result.success and result.rows:
                schema_text.append(f" Sample data ({result.row_count} rows shown):")
                schema_text.extend(f" {row}" for row in result.rows)

            # Show total count
            count_result = self.execute_query(
                f"SELECT COUNT(*) FROM {table_name}"
            )
            if count_result.success and count_result.rows:
                total = count_result.rows[0][0]
                schema_text.append(f" Total rows: {total}")
            schema_text.append("")

        # Add relationship summary and usage notes
        schema_text.extend(
            [
                "=== RELATIONSHIPS ===",
                "orders.customer_id -> customers.id",
                "order_items.order_id -> orders.id",
                "order_items.product_id -> products.id",
                "reviews.product_id -> products.id",
                "reviews.customer_id -> customers.id",
                "",
                "=== NOTES ===",
                "- All dates are in ISO format (YYYY-MM-DD)",
                "- Prices are in INR (Indian Rupees)",
                "- Order status: pending, shipped, delivered, cancelled",
                "- Product categories: Electronics, Clothing, Books, Home",
                "- Ratings are integers from 1 to 5",
            ]
        )
        return "\n".join(schema_text)

    def close(self) -> None:
        """Close the database connection, if one is open."""
        if self._conn is not None:
            self._conn.close()
            self._conn = None
|