File size: 7,718 Bytes
08b82d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""
SQLite database management for SQLEnv.

Handles database initialization, query execution, and schema introspection.
All operations use an in-memory SQLite database that is re-created on each
environment reset, ensuring deterministic, isolated episodes.
"""

import os
import re
import sqlite3
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple

# SQL fixtures live in the `data/` directory one level above this module's directory.
DATA_DIR = Path(__file__).resolve().parent.parent.joinpath("data")
# Schema script, executed first by Database.initialize().
SCHEMA_PATH = DATA_DIR.joinpath("schema.sql")
# Seed-data script, executed after the schema by Database.initialize().
SEED_PATH = DATA_DIR.joinpath("seed.sql")


@dataclass
class QueryResult:
    """Result of executing a SQL query.

    On success ``error`` is ``None`` and ``columns``/``rows`` hold the result
    set; on failure ``error`` carries a message and the other fields keep
    their defaults.
    """

    columns: List[str] = field(default_factory=list)  # column names, in result order
    rows: List[Tuple] = field(default_factory=list)  # one tuple per result row
    error: Optional[str] = None  # error message; None means success
    row_count: int = 0  # row total reported in the footer; set by the producer, not derived from `rows`

    @property
    def success(self) -> bool:
        """True when no error was recorded (``error is None``)."""
        return self.error is None

    def to_display_string(self, max_rows: int = 20) -> str:
        """Format the result as a readable, column-aligned table string.

        Args:
            max_rows: Maximum number of rows to render. Any further rows are
                summarised by a "... (N more rows)" line. Negative values are
                treated as 0.

        Returns:
            ``"ERROR: <message>"`` for a failed result, ``"(no results)"`` when
            there are no columns, otherwise a header, a separator, the shown
            rows, an optional truncation line and a row-count footer.
        """
        # Use the same test as `success`, so an empty-string error is still
        # reported as an error instead of being rendered as a table.
        if not self.success:
            return f"ERROR: {self.error}"
        if not self.columns:
            return "(no results)"

        # A negative limit would make the slice below drop rows from the end
        # and over-count the hidden ones; treat it as "show no rows".
        max_rows = max(0, max_rows)
        display_rows = self.rows[:max_rows]

        # Each column is as wide as its header or its widest displayed value.
        col_widths = [len(str(c)) for c in self.columns]
        for row in display_rows:
            for i, val in enumerate(row):
                col_widths[i] = max(col_widths[i], len(str(val)))

        # Build table
        header = " | ".join(
            str(c).ljust(col_widths[i]) for i, c in enumerate(self.columns)
        )
        separator = "-+-".join("-" * w for w in col_widths)
        lines = [header, separator]

        for row in display_rows:
            lines.append(
                " | ".join(str(val).ljust(col_widths[i]) for i, val in enumerate(row))
            )

        hidden = len(self.rows) - len(display_rows)
        if hidden > 0:
            lines.append(f"... ({hidden} more rows)")

        lines.append(f"\n({self.row_count} row{'s' if self.row_count != 1 else ''})")
        return "\n".join(lines)


class Database:
    """
    Manages an in-memory SQLite database for one episode.

    Each call to `initialize()` creates a fresh database with the schema
    and seed data, ensuring deterministic state across episodes.
    `execute_query()` runs caller-supplied SQL read-only and under a
    wall-clock limit, so a query can neither alter the seeded data nor
    stall the episode indefinitely.
    """

    # First keywords rejected up front with a descriptive error. VACUUM is
    # listed because `VACUUM INTO 'file'` writes a file and SQLite defines no
    # authorizer action for it; most of the others are also refused by the
    # authorizer below, but this check gives a clearer message.
    _BLOCKED_KEYWORDS = frozenset(
        {
            "INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE",
            "TRUNCATE", "REPLACE", "ATTACH", "DETACH", "VACUUM",
            "REINDEX", "ANALYZE",
        }
    )

    # Leading whitespace and SQL comments. SQLite ignores them, but they would
    # otherwise hide the real first keyword from the check above.
    _LEADING_TRIVIA_RE = re.compile(r"(?:\s+|--[^\n]*|/\*.*?(?:\*/|\Z))*", re.DOTALL)
    _KEYWORD_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")

    # Authorizer action codes that write data, change the schema, or attach
    # another database file. They are looked up by name, so a constant missing
    # from an older sqlite3 build is skipped rather than raising.
    _DENIED_AUTH_ACTIONS = frozenset(
        getattr(sqlite3, name)
        for name in (
            "SQLITE_INSERT", "SQLITE_UPDATE", "SQLITE_DELETE",
            "SQLITE_CREATE_INDEX", "SQLITE_CREATE_TABLE",
            "SQLITE_CREATE_TEMP_INDEX", "SQLITE_CREATE_TEMP_TABLE",
            "SQLITE_CREATE_TEMP_TRIGGER", "SQLITE_CREATE_TEMP_VIEW",
            "SQLITE_CREATE_TRIGGER", "SQLITE_CREATE_VIEW", "SQLITE_CREATE_VTABLE",
            "SQLITE_DROP_INDEX", "SQLITE_DROP_TABLE",
            "SQLITE_DROP_TEMP_INDEX", "SQLITE_DROP_TEMP_TABLE",
            "SQLITE_DROP_TEMP_TRIGGER", "SQLITE_DROP_TEMP_VIEW",
            "SQLITE_DROP_TRIGGER", "SQLITE_DROP_VIEW", "SQLITE_DROP_VTABLE",
            "SQLITE_ALTER_TABLE", "SQLITE_ATTACH", "SQLITE_DETACH",
            "SQLITE_REINDEX", "SQLITE_ANALYZE",
        )
        if hasattr(sqlite3, name)
    )

    # Number of SQLite VM instructions between wall-clock deadline checks.
    _PROGRESS_INTERVAL = 1000

    def __init__(self):
        self._conn: Optional[sqlite3.Connection] = None

    def initialize(self) -> None:
        """Create a fresh in-memory database with schema and seed data.

        Any previously open connection is closed first. If the schema or seed
        script fails to load, the partially built connection is closed and the
        error is re-raised, leaving the instance uninitialized.
        """
        self.close()
        conn = sqlite3.connect(":memory:")
        try:
            conn.execute("PRAGMA foreign_keys = ON")
            # Decode explicitly as UTF-8 so the loaded data does not depend on
            # the platform's default locale encoding.
            conn.executescript(SCHEMA_PATH.read_text(encoding="utf-8"))
            conn.executescript(SEED_PATH.read_text(encoding="utf-8"))
            conn.commit()
        except Exception:
            conn.close()
            raise
        self._conn = conn

    @classmethod
    def _first_keyword(cls, sql: str) -> str:
        """Return the upper-cased first keyword of `sql`, skipping leading comments.

        Returns "" when the text (after comments) does not start with a word.
        """
        body = sql[cls._LEADING_TRIVIA_RE.match(sql).end():]
        keyword = cls._KEYWORD_RE.match(body)
        return keyword.group(0).upper() if keyword else ""

    @classmethod
    def _read_only_authorizer(cls, action, arg1, arg2, db_name, source) -> int:
        """sqlite3 authorizer callback that denies every write/schema/attach action."""
        if action in cls._DENIED_AUTH_ACTIONS:
            return sqlite3.SQLITE_DENY
        return sqlite3.SQLITE_OK

    @staticmethod
    def _allow_all(*_args) -> int:
        """sqlite3 authorizer callback that permits everything.

        Installed instead of clearing the authorizer with None, because None
        is only accepted by `set_authorizer` on newer Python versions.
        """
        return sqlite3.SQLITE_OK

    def execute_query(self, sql: str, timeout_seconds: float = 5.0) -> QueryResult:
        """
        Execute a read-only SQL query and return the result.

        Statements whose first keyword, after any leading comments, is a
        modification or file-touching command (INSERT, UPDATE, DELETE, DROP,
        ALTER, CREATE, ATTACH, VACUUM, ...) are rejected up front. As a second
        line of defence, the statement is compiled under a SQLite authorizer
        that denies writes, schema changes and ATTACH/DETACH. Forms such as
        `WITH ... DELETE` are therefore refused before they run.

        Args:
            sql: The SQL query string to execute.
            timeout_seconds: Wall-clock limit for running and fetching the
                query. A non-positive value or None disables the limit.

        Returns:
            QueryResult with columns, rows, and potential error. A rejected
            query, a SQLite error or a timeout is reported through `error`
            rather than raised.
        """
        if self._conn is None:
            return QueryResult(error="Database not initialized. Call reset() first.")

        # Strip and normalize
        stripped = sql.strip().rstrip(";").strip()
        if not stripped:
            return QueryResult(error="Empty query.")

        first_word = self._first_keyword(stripped)
        if first_word in self._BLOCKED_KEYWORDS:
            return QueryResult(
                error=f"Only SELECT queries are allowed. Got: {first_word}"
            )

        deadline = None
        if timeout_seconds is not None and timeout_seconds > 0:
            deadline = time.monotonic() + timeout_seconds
        timed_out = False

        def _past_deadline() -> bool:
            # A truthy return makes SQLite abort the running statement.
            nonlocal timed_out
            timed_out = time.monotonic() > deadline
            return timed_out

        conn = self._conn
        conn.set_authorizer(self._read_only_authorizer)
        if deadline is not None:
            conn.set_progress_handler(_past_deadline, self._PROGRESS_INTERVAL)
        try:
            cursor = conn.execute(stripped)
            if cursor.description is None:
                return QueryResult(error="Query did not return results.")
            columns = [desc[0] for desc in cursor.description]
            rows = cursor.fetchall()
        except sqlite3.Error as e:
            if timed_out:
                return QueryResult(
                    error=f"Query exceeded the {timeout_seconds}s time limit."
                )
            return QueryResult(error=str(e))
        finally:
            # Lift the restrictions so internal introspection (PRAGMAs in
            # get_schema_description) keeps working.
            conn.set_authorizer(self._allow_all)
            if deadline is not None:
                conn.set_progress_handler(None, 0)

        return QueryResult(columns=columns, rows=rows, row_count=len(rows))

    def get_schema_description(self) -> str:
        """
        Return a human-readable description of the database schema
        including table structures and sample data.
        """
        if self._conn is None:
            return "Database not initialized."

        schema_text = ["=== DATABASE SCHEMA ===\n"]

        tables = [
            ("customers", "Customer information"),
            ("products", "Product catalog"),
            ("orders", "Customer orders"),
            ("order_items", "Items within each order"),
            ("reviews", "Product reviews by customers"),
        ]

        for table_name, description in tables:
            schema_text.append(f"TABLE: {table_name} -- {description}")

            # Column info rows: (cid, name, type, notnull, default_value, pk)
            for col in self._conn.execute(f"PRAGMA table_info({table_name})").fetchall():
                is_pk = " PRIMARY KEY" if col[5] else ""
                is_nn = " NOT NULL" if col[3] else ""
                schema_text.append(f"  {col[1]} {col[2]}{is_pk}{is_nn}")

            # Foreign key rows: index 2 = referenced table, 3 = from column, 4 = to column
            for fk in self._conn.execute(f"PRAGMA foreign_key_list({table_name})").fetchall():
                schema_text.append(f"  FOREIGN KEY ({fk[3]}) REFERENCES {fk[2]}({fk[4]})")

            # Show sample data (first 3 rows)
            result = self.execute_query(f"SELECT * FROM {table_name} LIMIT 3")
            if result.success and result.rows:
                schema_text.append(f"  Sample data ({result.row_count} rows shown):")
                for row in result.rows:
                    schema_text.append(f"    {row}")

            # Show total count
            count_result = self.execute_query(f"SELECT COUNT(*) FROM {table_name}")
            if count_result.success and count_result.rows:
                total = count_result.rows[0][0]
                schema_text.append(f"  Total rows: {total}")

            schema_text.append("")

        # Relationship summary and data notes are static text.
        schema_text.append("=== RELATIONSHIPS ===")
        schema_text.append("orders.customer_id -> customers.id")
        schema_text.append("order_items.order_id -> orders.id")
        schema_text.append("order_items.product_id -> products.id")
        schema_text.append("reviews.product_id -> products.id")
        schema_text.append("reviews.customer_id -> customers.id")
        schema_text.append("")
        schema_text.append("=== NOTES ===")
        schema_text.append("- All dates are in ISO format (YYYY-MM-DD)")
        schema_text.append("- Prices are in INR (Indian Rupees)")
        schema_text.append("- Order status: pending, shipped, delivered, cancelled")
        schema_text.append("- Product categories: Electronics, Clothing, Books, Home")
        schema_text.append("- Ratings are integers from 1 to 5")

        return "\n".join(schema_text)

    def close(self) -> None:
        """Close the database connection, if one is open (safe to call repeatedly)."""
        if self._conn is not None:
            self._conn.close()
            self._conn = None