"""
nl2sql-bench/server/environment.py
====================================
NL2SQL-Bench core environment — implements the OpenEnv Environment interface.
Episode flow
------------
1. reset(task_name?) → picks a task + question, returns the initial observation
2. step(action) → executes the SQL, grades it, returns an observation (with reward)
3. state (property) → returns episode metadata
4. Episode ends when: exact_match OR step count reaches max_steps
The environment manages its own SQLite connection (in-memory, seeded
deterministically). One connection per Environment instance; the FastAPI
server creates one Environment per WebSocket session.
"""
from __future__ import annotations
import os
import sqlite3
import uuid
from pathlib import Path
from typing import Optional
from openenv.core.env_server import Environment
# Directory containing this file (server/); used to locate db/schema.sql and,
# via its parent, the project root.
_HERE = Path(__file__).parent
# Project-local imports. These are eager (they run at module load).
# NOTE(review): they execute BEFORE the sys.path.insert further down, so
# `tasks` and `grader` must already be importable from the existing sys.path
# (presumably server/ is the working directory or already on sys.path) —
# confirm against how the server is launched.
from tasks import get_task, all_task_names, BaseTask
from tasks.base import TaskExample
from grader import (
GradeResult,
compute_ground_truth,
execute_query,
grade,
has_order_by,
)
# models.py lives one level up (project root), so the parent of this
# directory is put at the front of sys.path before importing it.
import sys
sys.path.insert(0, str(_HERE.parent))
from models import NL2SQLAction, NL2SQLObservation, NL2SQLState
# -- Constants ---------------------------------------------------------------
# Task used when reset() receives no task name or an unknown one.
DEFAULT_TASK: str = os.environ.get("NL2SQL_DEFAULT_TASK", "simple-filter")
# Step cap per episode: step() reports done once this many steps are taken.
MAX_STEPS: int = int(os.environ.get("NL2SQL_MAX_STEPS", "5"))
# Max rows shown to the agent per step.
RESULT_LIMIT: int = 10
class NL2SQLEnvironment(Environment):
    """
    OpenEnv-compliant environment for NL-to-SQL query generation.

    One instance per WebSocket session (created by create_fastapi_app).
    Each instance builds its own in-memory SQLite database in ``__init__``;
    ``reset()`` picks a task and question, and ``step()`` runs and grades the
    agent's SQL against that database.
    """

    def __init__(self) -> None:
        self._conn: Optional[sqlite3.Connection] = None
        self._task: Optional[BaseTask] = None
        self._example: Optional[TaskExample] = None
        # Rows produced by the current example's reference SQL (set in reset()).
        self._ground_truth: list = []
        # Whether the reference SQL has an ORDER BY; forwarded to grade() as
        # ``order_sensitive``.
        self._order_sensitive: bool = False
        self._state = NL2SQLState(
            episode_id=None,
            step_count=0,
            task_name="",
            task_difficulty="",
            question="",
            best_reward=0.0,
            cumulative_reward=0.0,
            solved=False,
        )
        # Placeholder observation until the first reset(); uses the same
        # step cap as every later observation.
        self._last_obs = NL2SQLObservation(
            question="",
            schema_context="",
            task_name="",
            last_query="",
            last_result=[],
            last_error=None,
            result_columns=[],
            step=0,
            max_steps=MAX_STEPS,
            done=False,
            reward=None,
            score=0.0,
        )
        # Per-step rewards of the current episode (used for the mean score).
        self._episode_rewards: list = []
        self._setup_db()

    # -- DB lifecycle --------------------------------------------------------
    def _setup_db(self) -> None:
        """Create the in-memory SQLite database, apply the schema and seed it.

        The connection is opened with ``check_same_thread=False`` (so it may be
        used from threads other than the creating one) and with
        ``sqlite3.Row`` as row factory, then stored on ``self._conn``. If the
        schema or seeding fails, the connection is closed and the error
        re-raised.
        """
        schema_path = _HERE / "db" / "schema.sql"
        from db.seed import seed_database  # local import after sys.path setup

        conn = sqlite3.connect(":memory:", check_same_thread=False)
        try:
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA foreign_keys = ON")
            conn.executescript(schema_path.read_text(encoding="utf-8"))
            seed_database(conn)
        except Exception:
            # Don't leave a half-initialised connection open.
            conn.close()
            raise
        self._conn = conn

    # -- OpenEnv interface ---------------------------------------------------
    def reset(self, task_name: Optional[str] = None) -> NL2SQLObservation:
        """
        Start a new episode.

        Args:
            task_name: one of 'simple-filter', 'join-aggregation',
                'analytics-window'. ``None`` or a name not returned by
                ``all_task_names()`` falls back to ``DEFAULT_TASK`` (the
                NL2SQL_DEFAULT_TASK env-var, or 'simple-filter' if unset).

        Returns:
            The initial observation: question and schema context, no query
            result yet, ``step=0``, ``done=False`` and ``reward=None``.
        """
        task_name = task_name or DEFAULT_TASK
        if task_name not in all_task_names():
            task_name = DEFAULT_TASK

        self._task = get_task(task_name)
        self._example = self._task.next_example()
        self._order_sensitive = has_order_by(self._example.sql)
        # Pre-compute ground truth once per episode.
        self._ground_truth = compute_ground_truth(self._conn, self._example.sql)
        self._episode_rewards = []

        self._state = NL2SQLState(
            episode_id=str(uuid.uuid4()),
            step_count=0,
            task_name=self._task.name,
            task_difficulty=self._task.difficulty,
            question=self._example.question,
            best_reward=0.0,
            cumulative_reward=0.0,
            solved=False,
        )
        obs = NL2SQLObservation(
            question=self._example.question,
            schema_context=self._task.schema_context(),
            task_name=self._task.name,
            last_query="",
            last_result=[],
            last_error=None,
            result_columns=[],
            step=0,
            max_steps=MAX_STEPS,
            done=False,
            reward=None,
            score=0.0,
        )
        self._last_obs = obs
        return obs

    def step(self, action: NL2SQLAction) -> NL2SQLObservation:
        """
        Execute the agent's SQL, grade it and return the resulting observation.

        If called before ``reset()``, an episode with the default task is
        started first. The episode is done once a query is an exact match or
        ``MAX_STEPS`` steps have been taken. Once solved, it stays done on any
        further steps.

        Args:
            action: its ``query`` attribute holds the SQL to execute.

        Returns:
            An observation with at most ``RESULT_LIMIT`` result rows (as
            dicts), their column names, any execution error, this step's
            ``reward``, and ``score``. ``score`` is the mean per-step reward
            of the episode so far, clamped to [0, 1] and rounded to 4 places.
        """
        if self._task is None or self._example is None:
            # Called before reset — auto-reset.
            self.reset()

        self._state.step_count += 1
        current_step = self._state.step_count

        # Execute the query, then grade it against the cached ground truth.
        rows, error = execute_query(self._conn, action.query)
        result: GradeResult = grade(
            actual_rows=rows,
            ground_truth_rows=self._ground_truth,
            error=error,
            step=current_step,
            order_sensitive=self._order_sensitive,
        )

        reward = result.reward
        self._episode_rewards.append(reward)
        self._state.cumulative_reward += reward
        self._state.best_reward = max(self._state.best_reward, reward)
        if result.exact_match:
            self._state.solved = True
        # Use the sticky ``solved`` flag, not just this step's match, so an
        # already-solved episode keeps reporting done if step() is called again.
        done = self._state.solved or current_step >= MAX_STEPS

        # Truncate rows for agent readability. Column names are taken from the
        # raw rows, before the dict conversion below (which cannot hold
        # duplicate column names).
        display_rows = (rows or [])[:RESULT_LIMIT]
        result_columns = list(display_rows[0].keys()) if display_rows else []
        # Convert row objects (presumably sqlite3.Row) into plain dicts.
        display_rows = [dict(r) for r in display_rows]

        # Mean reward over this episode's steps, clamped to [0, 1].
        # n >= 1 here because this step's reward was appended above.
        n = len(self._episode_rewards)
        score = round(min(max(self._state.cumulative_reward / n, 0.0), 1.0), 4)

        obs = NL2SQLObservation(
            question=self._example.question,
            schema_context=self._task.schema_context(),
            task_name=self._task.name,
            last_query=action.query,
            last_result=display_rows,
            last_error=error,
            result_columns=result_columns,
            step=current_step,
            max_steps=MAX_STEPS,
            done=done,
            reward=reward,
            score=score,
        )
        self._last_obs = obs
        # openenv-core expects ONLY the observation returned from step().
        # The framework reads obs.reward and obs.done itself — do NOT return a tuple.
        return obs

    @property
    def state(self) -> NL2SQLState:
        """Metadata of the current episode (replaced on every reset())."""
        return self._state

    # -- Helpers -------------------------------------------------------------
    def available_tasks(self) -> list:
        """Return the task names that reset() accepts."""
        return all_task_names()