""" SQL Query Writing Environment. An AI agent receives a database schema and natural language question, then writes SQL queries to answer the question. The environment grades each query with partial-credit scoring and provides feedback. """ import json import os from pathlib import Path from uuid import uuid4 from openenv.core.env_server.interfaces import Environment from openenv.core.env_server.types import State try: from ..models import SQLAction, SQLObservation except ImportError: from models import SQLAction, SQLObservation from .database import Database from .graders import grade_query TASKS_DIR = Path(__file__).resolve().parent.parent / "data" / "tasks" # Default task can be overridden via environment variable DEFAULT_TASK = os.getenv("SQL_ENV_TASK", "basic_select") MAX_TOTAL_STEPS = int(os.getenv("SQL_ENV_MAX_STEPS", "15")) STEP_PENALTY = float(os.getenv("SQL_ENV_STEP_PENALTY", "0.02")) def _load_task(task_name: str) -> dict: """Load a task definition from JSON file.""" task_path = TASKS_DIR / f"{task_name}.json" if not task_path.exists(): available = [f.stem for f in TASKS_DIR.glob("*.json")] raise ValueError( f"Task '{task_name}' not found. Available: {available}" ) with open(task_path) as f: return json.load(f) class SQLEnvironment(Environment): """ SQL Query Writing Environment. The agent interacts with an e-commerce SQLite database by submitting SQL queries to answer natural language questions. Each query is graded with a multi-component reward function providing partial credit. Episode flow: 1. reset() → loads task, initializes DB, returns first question 2. step(SQLAction) → executes query, grades it, returns observation 3. Episode ends when all questions answered or max steps reached """ SUPPORTS_CONCURRENT_SESSIONS: bool = True def __init__(self): self._db = Database() self._state = State(episode_id=str(uuid4()), step_count=0) self._task: dict = {} self._questions: list = [] self._current_q_index: int = 0 self._q_steps_used: int = 0 self._max_steps_per_q: int = 3 self._total_steps: int = 0 self._rewards: list = [] self._schema_cache: str = "" self._done: bool = False self._last_feedback: str = "" def reset(self) -> SQLObservation: """ Reset the environment: initialize DB, load task, return first question. """ self._db.initialize() self._state = State(episode_id=str(uuid4()), step_count=0) task_name = os.getenv("SQL_ENV_TASK", DEFAULT_TASK) self._task = _load_task(task_name) self._questions = self._task["questions"] self._max_steps_per_q = self._task.get("max_steps_per_question", 3) self._current_q_index = 0 self._q_steps_used = 0 self._total_steps = 0 self._rewards = [] self._done = False self._last_feedback = "" self._schema_cache = self._db.get_schema_description() return self._make_observation( reward=0.0, query_result="", error="", ) def step(self, action: SQLAction) -> SQLObservation: # type: ignore[override] """ Execute the agent's SQL query, grade it, and return observation. """ # Auto-reset if step called before reset (HTTP stateless mode) if not self._questions: self.reset() if self._done or self._current_q_index >= len(self._questions): self._done = True return self._make_observation( reward=0.0, query_result="Episode is over. Call reset() to start a new episode.", error="", ) self._state.step_count += 1 self._total_steps += 1 self._q_steps_used += 1 # Get current question question = self._questions[self._current_q_index] # Grade the query grade_result = grade_query( db=self._db, agent_sql=action.query, expected_columns=question["expected_columns"], expected_rows=question["expected_rows"], order_matters=question.get("order_matters", True), ) raw_reward = grade_result["reward"] # Apply step penalty (not on first attempt) penalty = STEP_PENALTY * (self._q_steps_used - 1) reward = max(raw_reward - penalty, 0.0) reward = round(reward, 4) self._rewards.append(reward) self._last_feedback = grade_result["feedback"] # Format query result for observation query_result_str = grade_result["query_result"].to_display_string() error_str = grade_result["query_result"].error or "" # Check if we should move to next question perfect = grade_result["exact_score"] == 1.0 out_of_attempts = self._q_steps_used >= self._max_steps_per_q move_on = perfect or out_of_attempts if move_on: self._current_q_index += 1 self._q_steps_used = 0 # Check if episode is done if self._current_q_index >= len(self._questions): self._done = True if self._total_steps >= MAX_TOTAL_STEPS: self._done = True return self._make_observation( reward=reward, query_result=query_result_str, error=error_str, ) @property def state(self) -> State: return self._state def _make_observation( self, reward: float, query_result: str, error: str, ) -> SQLObservation: """Build an SQLObservation for the current state.""" if self._done or not self._questions or self._current_q_index >= len(self._questions): # Episode finished or not started return SQLObservation( task_name=self._task.get("task_name", ""), question="Episode complete. All questions answered.", schema_description="", query_result=query_result, error=error, steps_remaining=0, question_index=len(self._questions), total_questions=len(self._questions), done=True, reward=reward, metadata={ "feedback": self._last_feedback, "total_reward": round(sum(self._rewards), 4), "rewards": [round(r, 4) for r in self._rewards], }, ) question = self._questions[self._current_q_index] steps_remaining = self._max_steps_per_q - self._q_steps_used return SQLObservation( task_name=self._task.get("task_name", ""), question=question["question"], schema_description=self._schema_cache, query_result=query_result, error=error, steps_remaining=steps_remaining, question_index=self._current_q_index + 1, total_questions=len(self._questions), done=False, reward=reward, metadata={ "feedback": self._last_feedback, "question_id": question["id"], "difficulty": self._task.get("difficulty", ""), }, )