Spaces:
Sleeping
Sleeping
File size: 7,434 Bytes
08b82d0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 | """
SQL Query Writing Environment.
An AI agent receives a database schema and natural language question,
then writes SQL queries to answer the question. The environment grades
each query with partial-credit scoring and provides feedback.
"""
import json
import os
from pathlib import Path
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import SQLAction, SQLObservation
except ImportError:
from models import SQLAction, SQLObservation
from .database import Database
from .graders import grade_query
# Task definitions live in <two levels above this module>/data/tasks/*.json.
TASKS_DIR = Path(__file__).resolve().parent.parent.joinpath("data", "tasks")

# Default task can be overridden via environment variable
DEFAULT_TASK = os.environ.get("SQL_ENV_TASK", "basic_select")

# Upper bound on graded steps across all questions of one episode.
MAX_TOTAL_STEPS = int(os.environ.get("SQL_ENV_MAX_STEPS", "15"))

# Reward deducted per extra attempt on the same question (first attempt is free).
STEP_PENALTY = float(os.environ.get("SQL_ENV_STEP_PENALTY", "0.02"))
def _load_task(task_name: str) -> dict:
    """Load a task definition from ``TASKS_DIR/<task_name>.json``.

    Args:
        task_name: File stem of the task JSON file (e.g. ``"basic_select"``).

    Returns:
        The parsed JSON document.

    Raises:
        ValueError: If no task file with that name exists (the message lists
            the available task names, sorted), or if the file is not valid
            JSON (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    task_path = TASKS_DIR / f"{task_name}.json"
    # is_file() rather than exists(): a directory called "<x>.json" is not a
    # task and would otherwise surface as IsADirectoryError from open().
    if not task_path.is_file():
        # Sorted so the error message is deterministic across filesystems.
        available = sorted(f.stem for f in TASKS_DIR.glob("*.json") if f.is_file())
        raise ValueError(
            f"Task '{task_name}' not found. Available: {available}"
        )
    # JSON is UTF-8 by specification; don't depend on the locale encoding.
    with task_path.open(encoding="utf-8") as f:
        return json.load(f)
class SQLEnvironment(Environment):
    """
    SQL Query Writing Environment.

    The agent interacts with an e-commerce SQLite database by submitting
    SQL queries to answer natural language questions. Each query is graded
    with a multi-component reward function providing partial credit.

    Episode flow:
        1. reset() → loads task, initializes DB, returns first question
        2. step(SQLAction) → executes query, grades it, returns observation
        3. Episode ends when all questions answered or max steps reached
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        # NOTE(review): super().__init__() is not called here; confirm the
        # Environment base class requires no initialization of its own.
        self._db = Database()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._task: dict = {}
        self._questions: list = []          # question dicts from the task JSON
        self._current_q_index: int = 0      # index of the question being attempted
        self._q_steps_used: int = 0         # attempts spent on the current question
        self._max_steps_per_q: int = 3      # overwritten from the task in reset()
        self._total_steps: int = 0          # attempts across the whole episode
        self._rewards: list = []            # per-step rewards (penalized, rounded)
        self._schema_cache: str = ""        # schema text shown with each question
        self._done: bool = False
        self._last_feedback: str = ""       # grader feedback from the latest step

    def reset(self) -> SQLObservation:
        """
        Reset the environment: initialize DB, load task, return first question.

        The task name is re-read from ``SQL_ENV_TASK`` on every reset (falling
        back to ``DEFAULT_TASK``), so changing the variable between episodes
        takes effect without re-importing the module.

        Raises:
            ValueError: If the configured task file cannot be found or parsed.
        """
        self._db.initialize()
        self._state = State(episode_id=str(uuid4()), step_count=0)

        task_name = os.getenv("SQL_ENV_TASK", DEFAULT_TASK)
        self._task = _load_task(task_name)
        self._questions = self._task["questions"]
        self._max_steps_per_q = self._task.get("max_steps_per_question", 3)

        self._current_q_index = 0
        self._q_steps_used = 0
        self._total_steps = 0
        self._rewards = []
        self._done = False
        self._last_feedback = ""
        self._schema_cache = self._db.get_schema_description()

        return self._make_observation(
            reward=0.0,
            query_result="",
            error="",
        )

    def step(self, action: SQLAction) -> SQLObservation:  # type: ignore[override]
        """
        Execute the agent's SQL query, grade it, and return observation.

        Reward is the grader's reward minus ``STEP_PENALTY`` for every prior
        attempt on the same question, floored at 0 and rounded to 4 places.
        The episode advances to the next question on a perfect exact score or
        when the per-question attempt budget is used up, and ends when every
        question has been attempted or ``MAX_TOTAL_STEPS`` is reached.
        """
        # Auto-reset if step called before reset (HTTP stateless mode)
        if not self._questions:
            self.reset()

        if self._done or self._current_q_index >= len(self._questions):
            self._done = True
            return self._make_observation(
                reward=0.0,
                query_result="Episode is over. Call reset() to start a new episode.",
                error="",
            )

        self._state.step_count += 1
        self._total_steps += 1
        self._q_steps_used += 1

        question = self._questions[self._current_q_index]

        grade_result = grade_query(
            db=self._db,
            agent_sql=action.query,
            expected_columns=question["expected_columns"],
            expected_rows=question["expected_rows"],
            order_matters=question.get("order_matters", True),
        )
        raw_reward = grade_result["reward"]

        # Penalize retries only: _q_steps_used is 1 on the first attempt.
        penalty = STEP_PENALTY * (self._q_steps_used - 1)
        reward = round(max(raw_reward - penalty, 0.0), 4)
        self._rewards.append(reward)
        self._last_feedback = grade_result["feedback"]

        query_result_str = grade_result["query_result"].to_display_string()
        error_str = grade_result["query_result"].error or ""

        # Advance on a perfect answer or once this question's attempts run out.
        perfect = grade_result["exact_score"] == 1.0
        out_of_attempts = self._q_steps_used >= self._max_steps_per_q
        if perfect or out_of_attempts:
            self._current_q_index += 1
            self._q_steps_used = 0

        if self._current_q_index >= len(self._questions):
            self._done = True
        # Global step budget can end the episode with questions still pending.
        if self._total_steps >= MAX_TOTAL_STEPS:
            self._done = True

        return self._make_observation(
            reward=reward,
            query_result=query_result_str,
            error=error_str,
        )

    @property
    def state(self) -> State:
        """Current episode state (episode id and step count)."""
        return self._state

    def _make_observation(
        self,
        reward: float,
        query_result: str,
        error: str,
    ) -> SQLObservation:
        """Build an SQLObservation for the current state.

        When the episode is over (or there is no question to show), returns a
        terminal observation with ``done=True`` and episode reward totals in
        ``metadata``; otherwise returns the current question with the schema
        and remaining attempts for that question.
        """
        finished = (
            self._done
            or not self._questions
            or self._current_q_index >= len(self._questions)
        )
        if finished:
            if self._current_q_index < len(self._questions):
                # Ended by MAX_TOTAL_STEPS with questions left unattempted:
                # don't tell the agent that everything was answered.
                final_message = (
                    "Episode complete. Step limit reached before all "
                    "questions were answered."
                )
            else:
                final_message = "Episode complete. All questions answered."
            return SQLObservation(
                task_name=self._task.get("task_name", ""),
                question=final_message,
                schema_description="",
                query_result=query_result,
                error=error,
                steps_remaining=0,
                question_index=len(self._questions),
                total_questions=len(self._questions),
                done=True,
                reward=reward,
                metadata={
                    "feedback": self._last_feedback,
                    "total_reward": round(sum(self._rewards), 4),
                    "rewards": [round(r, 4) for r in self._rewards],
                },
            )

        question = self._questions[self._current_q_index]
        steps_remaining = self._max_steps_per_q - self._q_steps_used
        return SQLObservation(
            task_name=self._task.get("task_name", ""),
            question=question["question"],
            schema_description=self._schema_cache,
            query_result=query_result,
            error=error,
            steps_remaining=steps_remaining,
            question_index=self._current_q_index + 1,  # 1-based for display
            total_questions=len(self._questions),
            done=False,
            reward=reward,
            metadata={
                "feedback": self._last_feedback,
                "question_id": question["id"],
                "difficulty": self._task.get("difficulty", ""),
            },
        )
|