Spaces:
Sleeping
Sleeping
| from typing import Optional, Dict, Any | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from .models import Observation, Action, Reward | |
| from .tasks import TASKS, grade_action, get_task | |
| from .reward import compute_reward | |
class SQLEnv(Environment):
    """SQL Query Optimizer Environment following the OpenEnv interface.

    Each episode is bound to one task obtained from ``get_task``. The agent
    submits rewritten SQL queries through ``step``; each submission is graded
    with ``grade_action`` and converted into a per-step reward by
    ``compute_reward``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        # Let the Environment base class initialise whatever it owns before
        # this subclass adds its own attributes.
        super().__init__()
        self.current_task_id = None  # task id passed to reset(); None until then
        self.task = None  # task dict returned by get_task(); None until reset()
        self.step_number = 0
        self.max_steps = 0
        self._clear_episode()
        self._state = State(episode_id=str(uuid4()), step_count=0)

    def _clear_episode(self) -> None:
        """Reset the per-episode event history and score accumulators."""
        # Event log: one dict per reset/step call in the current episode.
        self.history: list = []
        # Running sum of the step rewards returned by compute_reward().
        self.cumulative_score = 0.0
        # Grader score of the previous step; passed to compute_reward().
        self.previous_grader_score = 0.0
        # Grader score recorded on the step where ``done`` became True.
        self.final_grader_score = 0.0

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: int = 1,
        **kwargs: Any,
    ) -> Observation:
        """Start a new episode on task ``task_id``.

        Args:
            seed: Accepted for interface compatibility; not used.
            episode_id: Id for the new episode's ``State``. A random UUID4
                string is generated when it is omitted or empty.
            task_id: Id of the task to load via ``get_task``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The initial observation. It carries the task's ``initial_query``,
            ``schema_context`` and ``hint``, with ``reward=0.0`` and
            ``done=False``.

        Raises:
            ValueError: If ``get_task(task_id)`` returns a falsy value.
        """
        task = get_task(task_id)
        if not task:
            raise ValueError(f"Task {task_id} not found.")
        self.current_task_id = task_id
        self.task = task
        # Agent steps are numbered from 1; the reset itself is logged as step 0.
        self.step_number = 1
        self.max_steps = task["max_steps"]
        self._clear_episode()
        self._state = State(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
        )
        obs = Observation(
            task_id=self.current_task_id,
            query=self.task["initial_query"],
            schema_context=self.task["schema_context"],
            hint=self.task["hint"],
            step_number=self.step_number,
            max_steps=self.max_steps,
            reward=0.0,
            done=False,
        )
        self.history.append({"step": 0, "type": "reset", "observation": obs.model_dump()})
        return obs

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """Grade one rewritten query and advance the episode by one step.

        Args:
            action: The agent's submission. Its ``rewritten_query`` is graded,
                and ``is_done`` lets the agent end the episode early.
            timeout_s: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            An observation whose ``query`` is the submitted query. It carries
            this step's reward, the ``done`` flag and an info dict in
            ``metadata`` with keys ``cumulative_score``, ``grader_score``,
            ``breakdown`` and ``feedback``.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        # NOTE(review): there is no check for an episode that is already done,
        # so calls after done=True keep grading and incrementing step_number.
        # Confirm whether callers rely on that before adding a guard.
        if not self.task:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        grader_score, breakdown, feedback = grade_action(
            self.current_task_id, action.rewritten_query
        )
        # A blank or whitespace-only submission is reported to
        # compute_reward() as invalid.
        action_valid = bool(action.rewritten_query.strip())
        # The episode ends when the agent says so or the step budget is used up.
        done = action.is_done or self.step_number >= self.max_steps
        step_reward = compute_reward(
            grader_score=grader_score,
            previous_score=self.previous_grader_score,
            step_number=self.step_number,
            max_steps=self.max_steps,
            is_done=done,
            action_valid=action_valid,
        )
        self.cumulative_score += step_reward
        self.previous_grader_score = grader_score
        info = {
            "cumulative_score": self.cumulative_score,
            "grader_score": grader_score,
            "breakdown": breakdown,
            "feedback": feedback,
        }
        if done:
            self.final_grader_score = grader_score
        self._state.step_count += 1
        # NOTE(review): this reports the *next* step number. When the episode
        # ends because the budget ran out, the value is max_steps + 1. Confirm
        # that consumers of Observation expect this.
        obs = Observation(
            task_id=self.current_task_id,
            query=action.rewritten_query,
            schema_context=self.task["schema_context"],
            hint=self.task["hint"],
            step_number=self.step_number + 1,
            max_steps=self.max_steps,
            reward=step_reward,
            done=done,
            metadata=info,
        )
        self.history.append({
            "step": self.step_number,
            "type": "step",
            "action": action.model_dump(),
            "reward": step_reward,
            "done": done,
            "info": info,
        })
        self.step_number += 1
        return obs

    def state(self) -> State:
        """Return the ``State`` (episode id and step count) of the current episode."""
        # NOTE(review): OpenEnv reference environments expose ``state`` as a
        # @property. Confirm whether the Environment base here expects
        # attribute access (``env.state``) rather than a call (``env.state()``).
        return self._state