sql-repair-env / sql_env /env_core.py
krishpotanwar's picture
feat: SQL Repair OpenEnv submission — Phase 1 validator passes
269f632
"""SQLite-backed environment state for SQL repair tasks.

The env exposes a minimal Gym-like API:

    reset(task_id) -> observation dict
    step(action)   -> {observation, reward, done, info}

Per-task state is held in a single EnvState instance for simplicity; the
validator only needs one run at a time, so no parallel sessions are kept.
"""
from __future__ import annotations
import sqlite3
from typing import Any, Dict, List, Optional
from .tasks import TASKS, TASK_IDS
# Step budget per episode: EnvState.step() reports done=True once this many
# steps have been taken since the last reset(), even if the task is unsolved.
MAX_STEPS = 6
def _new_db(task_id: str) -> sqlite3.Connection:
    """Build a fresh in-memory SQLite database for the given task.

    Every statement listed under ``TASKS[task_id]["schema"]`` is executed in
    order on a new ``:memory:`` connection and then committed.

    Args:
        task_id: Key into ``TASKS`` identifying which schema to load.

    Returns:
        An open connection to the populated database. The caller owns it and
        is responsible for closing it.

    Raises:
        KeyError: If ``task_id`` is not a known task.
        Exception: Whatever a failing schema statement raises; the
            connection is closed before the error propagates.
    """
    if task_id not in TASKS:
        raise KeyError(f"Unknown task_id: {task_id}")
    conn = sqlite3.connect(":memory:")
    try:
        cur = conn.cursor()
        for stmt in TASKS[task_id]["schema"]:
            cur.execute(stmt)
        conn.commit()
    except Exception:
        # Nobody else holds a reference yet, so release the connection here
        # instead of leaking it when the schema cannot be loaded.
        conn.close()
        raise
    return conn
def _run_query(task_id: str, query: str) -> Dict[str, Any]:
    """Run ``query`` on a freshly built task database and report the outcome.

    A new database is created for every call, so queries never observe the
    effects of earlier ones.

    Args:
        task_id: Task whose schema is loaded before running the query.
        query: SQL text to execute.

    Returns:
        A dict with keys ``ok``, ``rows``, ``columns`` and ``error``. On
        success ``rows`` holds all fetched rows and ``columns`` the result
        column names (empty when the statement yields no result set). On
        failure ``ok`` is False, ``rows`` is None and ``error`` carries the
        exception text.
    """
    db = _new_db(task_id)
    try:
        cursor = db.execute(query)
        fetched = cursor.fetchall()
        description = cursor.description
        if description:
            names = [column[0] for column in description]
        else:
            names = []
    except Exception as err:
        # Any failure is reported back as data rather than raised, so the
        # caller can show the message to the agent.
        outcome = {"ok": False, "rows": None, "columns": [], "error": str(err)}
    else:
        outcome = {"ok": True, "rows": fetched, "columns": names, "error": None}
    finally:
        db.close()
    return outcome
def _expected_rows(task_id: str) -> List[tuple]:
    """Return the rows produced by the task's canonical (reference) query.

    Args:
        task_id: Key into ``TASKS``; its ``canonical_query`` is executed on a
            fresh database.

    Returns:
        The full list of rows the canonical query yields.

    Raises:
        RuntimeError: If the canonical query itself fails to execute.
    """
    outcome = _run_query(task_id, TASKS[task_id]["canonical_query"])
    if outcome["ok"]:
        return outcome["rows"]
    # Should never happen — canonical queries are vetted in tests.
    raise RuntimeError(
        f"Canonical query for {task_id} failed: {outcome['error']}"
    )
class EnvState:
    """Mutable per-session env state. One instance handles all tasks.

    ``reset`` selects a task and returns its initial observation; ``step``
    grades one submitted query against the task's expected rows. The
    ``last_*`` attributes always describe the most recent ``step`` call.
    """

    def __init__(self) -> None:
        # Id of the active task; None until reset() has been called.
        self.task_id: Optional[str] = None
        # Number of step() calls processed since the last reset().
        self.step_count: int = 0
        # Most recent submission (whitespace-stripped) and its outcome.
        self.last_query: Optional[str] = None
        self.last_error: Optional[str] = None
        self.last_result: Optional[List[tuple]] = None
        # True once a submission returned exactly the expected rows.
        self.solved: bool = False
        # Rows of the canonical query, and the width of its first row
        # (0 when the canonical result is empty).
        self.expected_rows: List[tuple] = []
        self.expected_columns: int = 0

    # ------------------------------------------------------------------
    def reset(self, task_id: Optional[str] = None) -> Dict[str, Any]:
        """Start a new episode and return the initial observation.

        Args:
            task_id: Task to load. A missing, empty or unknown id falls back
                to ``"task_1"``.

        Returns:
            Observation dict describing the task (schema, broken query and
            what happens when it is run, hint, expected result shape) plus
            the step budget.

        Raises:
            RuntimeError: If the task's canonical query fails. In that case
                the previous episode state is left untouched.
        """
        tid = task_id or "task_1"
        if tid not in TASKS:
            tid = "task_1"
        task = TASKS[tid]

        # Do everything that can raise before touching instance state, so a
        # failure here never leaves the env half-reset.
        expected = _expected_rows(tid)
        # Surface what the broken query actually does, so the agent has
        # an error message and a canonical "what went wrong" hint.
        baseline = _run_query(tid, task["broken_query"])

        self.task_id = tid
        self.step_count = 0
        self.last_query = None
        self.last_error = None
        self.last_result = None
        self.solved = False
        self.expected_rows = expected
        self.expected_columns = len(expected[0]) if expected else 0

        return {
            "task_id": tid,
            "name": task["name"],
            "difficulty": task["difficulty"],
            "schema_sql": "\n".join(task["schema"]),
            "broken_query": task["broken_query"],
            "broken_query_error": baseline["error"],
            "broken_query_executes": baseline["ok"],
            "hint": task["hint"],
            "expected_row_count": len(self.expected_rows),
            "expected_column_count": self.expected_columns,
            "step_count": 0,
            "max_steps": MAX_STEPS,
            "remaining_steps": MAX_STEPS,
        }

    # ------------------------------------------------------------------
    def step(self, action: Dict[str, Any]) -> Dict[str, Any]:
        """Grade one action and advance the episode by one step.

        Args:
            action: Mapping with optional ``action_type`` (defaults to
                ``"submit_query"``) and ``query`` (SQL text). ``None`` is
                treated as an empty mapping.

        Returns:
            Dict with ``observation``, ``reward``, ``done`` and ``info``.
            Rewards: 1.0 for an exact row match, 0.4 for a query that runs
            but returns different rows, -0.10 for a query that fails to
            execute, -0.05 for an unsupported action type or empty query.
            ``done`` is True once the task is solved or ``MAX_STEPS`` steps
            have been taken. Without an active task the call is a no-op
            that returns ``done=True`` and reward 0.0.
        """
        if self.task_id is None:
            return {
                "observation": {"error": "No active task. Call /reset first."},
                "reward": 0.0,
                "done": True,
                "info": {"solved": False, "no_active_task": True},
            }

        self.step_count += 1
        payload = action or {}
        action_type = payload.get("action_type", "submit_query")
        query = (payload.get("query") or "").strip()
        self.last_query = query

        reward = 0.0
        result_rows: Optional[List[tuple]] = None
        error: Optional[str] = None

        if action_type != "submit_query":
            error = f"Unsupported action_type: {action_type}"
            reward = -0.05
        elif not query:
            error = "Empty query string."
            reward = -0.05
        else:
            res = _run_query(self.task_id, query)
            if res["ok"]:
                result_rows = res["rows"]
                if result_rows == self.expected_rows:
                    reward = 1.0
                    self.solved = True
                else:
                    # executed but wrong rows — small positive reward
                    reward = 0.4
            else:
                error = res["error"]
                reward = -0.10

        # Record the outcome on every path, including rejected actions, so
        # last_error/last_result never describe an older step than
        # last_query does.
        self.last_error = error
        self.last_result = result_rows

        matches = result_rows is not None and result_rows == self.expected_rows
        done = self.solved or self.step_count >= MAX_STEPS
        observation = {
            "task_id": self.task_id,
            "step_count": self.step_count,
            "submitted_query": query,
            "error": error,
            "executed": error is None and result_rows is not None,
            "matches_expected": matches,
            "result_row_count": len(result_rows) if result_rows is not None else 0,
            "expected_row_count": len(self.expected_rows),
            # Previews are capped at three rows; an empty or missing result
            # is reported as None.
            "result_preview": result_rows[:3] if result_rows else None,
            "expected_preview": self.expected_rows[:3],
            "remaining_steps": max(0, MAX_STEPS - self.step_count),
        }
        return {
            "observation": observation,
            "reward": float(reward),
            "done": bool(done),
            "info": {"solved": self.solved},
        }