from typing import Optional, Dict, Any
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

from .models import Observation, Action, Reward
from .tasks import TASKS, grade_action, get_task
from .reward import compute_reward


class SQLEnv(Environment):
    """SQL Query Optimizer Environment following the OpenEnv interface.

    An episode starts with ``reset(task_id=...)``, which loads the task and
    returns an observation containing its initial query, schema context and
    hint. Each ``step(action)`` grades ``action.rewritten_query`` with
    ``grade_action`` and returns a new observation carrying the reward from
    ``compute_reward``.

    An episode ends when the agent sets ``action.is_done`` or the task's
    ``max_steps`` budget is used up. After that, ``step`` raises
    ``RuntimeError`` until ``reset`` is called again.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Create an environment with no task loaded; call ``reset`` first."""
        self.current_task_id = None
        self.task = None
        # Index of the step the next action occupies: 0 before any reset,
        # set to 1 by reset().
        self.step_number = 0
        self.max_steps = 0
        # Log of the reset and of every step taken in the current episode.
        self.history = []
        # Sum of per-step rewards over the current episode.
        self.cumulative_score = 0.0
        # Grader score of the previous action; passed to compute_reward
        # as ``previous_score``.
        self.previous_grader_score = 0.0
        # Grader score of the action that ended the episode.
        self.final_grader_score = 0.0
        # True once the current episode has ended; further step() calls
        # are rejected until reset().
        self._episode_done = False
        self._state = State(episode_id=str(uuid4()), step_count=0)

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: int = 1,
        **kwargs: Any,
    ) -> Observation:
        """Start a new episode on ``task_id`` and return its first observation.

        Args:
            seed: Accepted for interface compatibility; not used here.
            episode_id: Identifier for the new episode. A random UUID4 string
                is used when omitted or empty.
            task_id: Task identifier looked up via ``get_task``.
            **kwargs: Extra keyword arguments; ignored.

        Returns:
            The initial observation (``reward=0.0``, ``done=False``,
            ``step_number=1``).

        Raises:
            ValueError: If ``get_task`` returns nothing for ``task_id``.
                Existing episode state is left untouched in that case.
        """
        task = get_task(task_id)
        if not task:
            raise ValueError(f"Task {task_id} not found.")

        self.current_task_id = task_id
        self.task = task
        self.step_number = 1
        self.max_steps = task["max_steps"]
        self.history = []
        self.cumulative_score = 0.0
        self.previous_grader_score = 0.0
        self.final_grader_score = 0.0
        self._episode_done = False
        self._state = State(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
        )

        obs = Observation(
            task_id=self.current_task_id,
            query=self.task["initial_query"],
            schema_context=self.task["schema_context"],
            hint=self.task["hint"],
            step_number=self.step_number,
            max_steps=self.max_steps,
            reward=0.0,
            done=False,
        )
        # The reset is logged as step 0 so agent steps are numbered from 1.
        self.history.append(
            {"step": 0, "type": "reset", "observation": obs.model_dump()}
        )
        return obs

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """Grade ``action`` and advance the episode by one step.

        Args:
            action: The agent's action. Its ``rewritten_query`` is graded,
                and its ``is_done`` flag lets the agent end the episode early.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Extra keyword arguments; ignored.

        Returns:
            An observation whose ``query`` is the submitted rewrite, whose
            ``reward`` is the step reward, and whose ``metadata`` holds
            ``cumulative_score``, ``grader_score``, ``breakdown`` and
            ``feedback``.

        Raises:
            RuntimeError: If ``reset`` has not been called, or if the
                current episode has already finished.
        """
        if not self.task:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        # Without this guard a finished episode could keep accumulating
        # reward, overwrite final_grader_score, and run past max_steps.
        if self._episode_done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")

        grader_score, breakdown, feedback = grade_action(
            self.current_task_id, action.rewritten_query
        )
        # A blank or whitespace-only rewrite is flagged as an invalid action.
        action_valid = len(action.rewritten_query.strip()) > 0
        done = action.is_done or self.step_number >= self.max_steps

        step_reward = compute_reward(
            grader_score=grader_score,
            previous_score=self.previous_grader_score,
            step_number=self.step_number,
            max_steps=self.max_steps,
            is_done=done,
            action_valid=action_valid,
        )
        self.cumulative_score += step_reward
        self.previous_grader_score = grader_score

        info = {
            "cumulative_score": self.cumulative_score,
            "grader_score": grader_score,
            "breakdown": breakdown,
            "feedback": feedback,
        }

        if done:
            self.final_grader_score = grader_score
            self._episode_done = True

        self._state.step_count += 1

        obs = Observation(
            task_id=self.current_task_id,
            query=action.rewritten_query,
            schema_context=self.task["schema_context"],
            hint=self.task["hint"],
            # Number of the step the next action would occupy.
            step_number=self.step_number + 1,
            max_steps=self.max_steps,
            reward=step_reward,
            done=done,
            metadata=info,
        )

        self.history.append(
            {
                "step": self.step_number,
                "type": "step",
                "action": action.model_dump(),
                "reward": step_reward,
                "done": done,
                "info": info,
            }
        )
        self.step_number += 1
        return obs

    @property
    def state(self) -> State:
        """Current episode state (episode id and number of steps taken)."""
        return self._state