# sql-query-optimizer/env/environment.py
from typing import Optional, Dict, Any
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from .models import Observation, Action, Reward
from .tasks import TASKS, grade_action, get_task
from .reward import compute_reward
class SQLEnv(Environment):
    """SQL Query Optimizer environment implementing the OpenEnv interface.

    An episode presents an initial SQL query plus schema context; on every
    step the agent submits a rewritten query, which is graded and converted
    into a shaped per-step reward.
    """

    # Each instance keeps all per-episode fields on itself, so separate
    # sessions can be served by separate instances concurrently.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        # Per-episode bookkeeping; real values are installed by reset().
        self.current_task_id = None        # id of the active task, or None
        self.task = None                   # task dict from get_task(), or None
        self.step_number = 0               # 1-based index of the next step
        self.max_steps = 0                 # step budget for the active task
        self.history = []                  # chronological reset/step records
        self.cumulative_score = 0.0        # running sum of step rewards
        self.previous_grader_score = 0.0   # grader score from the prior step
        self.final_grader_score = 0.0      # grader score at episode end
        self._state = State(episode_id=str(uuid4()), step_count=0)

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: int = 1,
        **kwargs: Any,
    ) -> Observation:
        """Begin a new episode on ``task_id`` and return the first observation.

        Raises:
            ValueError: if ``task_id`` does not name a known task.
        """
        selected = get_task(task_id)
        if not selected:
            raise ValueError(f"Task {task_id} not found.")

        # Install the task and zero out all episode-scoped counters.
        self.current_task_id = task_id
        self.task = selected
        self.step_number = 1
        self.max_steps = selected["max_steps"]
        self.history = []
        self.cumulative_score = 0.0
        self.previous_grader_score = 0.0
        self.final_grader_score = 0.0
        self._state = State(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
        )

        first_obs = Observation(
            task_id=self.current_task_id,
            query=self.task["initial_query"],
            schema_context=self.task["schema_context"],
            hint=self.task["hint"],
            step_number=self.step_number,
            max_steps=self.max_steps,
            reward=0.0,
            done=False,
        )
        self.history.append(
            {"step": 0, "type": "reset", "observation": first_obs.model_dump()}
        )
        return first_obs

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """Grade one rewritten query, update episode state, and observe.

        Raises:
            RuntimeError: if called before reset().
        """
        if not self.task:
            raise RuntimeError("Environment not initialized. Call reset() first.")

        grader_score, breakdown, feedback = grade_action(
            self.current_task_id, action.rewritten_query
        )
        # An action counts as valid when the submitted query is non-blank.
        is_valid = len(action.rewritten_query.strip()) > 0
        # The episode ends when the agent declares done or the budget is spent.
        episode_over = action.is_done or self.step_number >= self.max_steps

        step_reward = compute_reward(
            grader_score=grader_score,
            previous_score=self.previous_grader_score,
            step_number=self.step_number,
            max_steps=self.max_steps,
            is_done=episode_over,
            action_valid=is_valid,
        )
        self.cumulative_score += step_reward
        # Remember this step's grade so the next step can be scored on delta.
        self.previous_grader_score = grader_score

        info = {
            "cumulative_score": self.cumulative_score,
            "grader_score": grader_score,
            "breakdown": breakdown,
            "feedback": feedback,
        }
        if episode_over:
            self.final_grader_score = grader_score

        self._state.step_count += 1
        result = Observation(
            task_id=self.current_task_id,
            query=action.rewritten_query,
            schema_context=self.task["schema_context"],
            hint=self.task["hint"],
            step_number=self.step_number + 1,
            max_steps=self.max_steps,
            reward=step_reward,
            done=episode_over,
            metadata=info,
        )
        self.history.append(
            {
                "step": self.step_number,
                "type": "step",
                "action": action.model_dump(),
                "reward": step_reward,
                "done": episode_over,
                "info": info,
            }
        )
        self.step_number += 1
        return result

    @property
    def state(self) -> State:
        """Transport-level episode state (episode id and step count)."""
        return self._state