from typing import Optional, Dict, Any
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from .models import Observation, Action, Reward
from .tasks import TASKS, grade_action, get_task
from .reward import compute_reward
class SQLEnv(Environment):
    """SQL Query Optimizer Environment following the OpenEnv interface.

    An episode starts with :meth:`reset`, which loads a task through
    ``get_task``. Each :meth:`step` grades the submitted rewritten query with
    ``grade_action`` and turns that grade into a step reward with
    ``compute_reward``. Every reset and step is appended to ``history``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Create an environment with no task loaded.

        :meth:`reset` must be called before :meth:`step`.
        """
        # Run the base-class initializer so that any attributes it defines
        # exist on this instance.
        super().__init__()

        # Task selection; populated by reset().
        self.current_task_id: Optional[int] = None
        self.task: Optional[Dict[str, Any]] = None

        # Per-episode progress counters.
        self.step_number = 0
        self.max_steps = 0

        # Chronological log of the reset observation and every step taken.
        self.history: list = []

        # Score bookkeeping.
        self.cumulative_score = 0.0
        self.previous_grader_score = 0.0
        self.final_grader_score = 0.0

        self._state = State(episode_id=str(uuid4()), step_count=0)

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: int = 1,
        **kwargs: Any,
    ) -> Observation:
        """Start a new episode for ``task_id`` and return its first observation.

        Args:
            seed: Accepted for interface compatibility; not used here.
            episode_id: Identifier for the new episode. A fresh UUID4 string
                is generated when this is omitted or empty.
            task_id: Key passed to ``get_task`` to select the task.
            **kwargs: Accepted and ignored.

        Returns:
            The initial observation, carrying the task's initial query,
            schema context and hint, with ``reward=0.0`` and ``done=False``.

        Raises:
            ValueError: If ``get_task`` returns a falsy value for ``task_id``.
        """
        task = get_task(task_id)
        if not task:
            raise ValueError(f"Task {task_id} not found.")

        self.current_task_id = task_id
        self.task = task

        # Steps are numbered from 1; the reset itself is logged as step 0.
        self.step_number = 1
        self.max_steps = task["max_steps"]

        self.history = []
        self.cumulative_score = 0.0
        self.previous_grader_score = 0.0
        self.final_grader_score = 0.0

        self._state = State(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
        )

        obs = Observation(
            task_id=self.current_task_id,
            query=self.task["initial_query"],
            schema_context=self.task["schema_context"],
            hint=self.task["hint"],
            step_number=self.step_number,
            max_steps=self.max_steps,
            reward=0.0,
            done=False,
        )
        self.history.append(
            {"step": 0, "type": "reset", "observation": obs.model_dump()}
        )
        return obs

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """Grade ``action`` and advance the episode by one step.

        Args:
            action: The agent's action. ``action.rewritten_query`` is graded
                and ``action.is_done`` lets the agent end the episode early.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted and ignored.

        Returns:
            An observation echoing the submitted query, with the step reward,
            the ``done`` flag and a ``metadata`` dict holding the cumulative
            score, grader score, breakdown and feedback.

        Raises:
            RuntimeError: If no task has been loaded via :meth:`reset`.
        """
        if not self.task:
            raise RuntimeError("Environment not initialized. Call reset() first.")

        grader_score, breakdown, feedback = grade_action(
            self.current_task_id, action.rewritten_query
        )

        # A query that is empty or whitespace-only counts as an invalid action.
        action_valid = bool(action.rewritten_query.strip())

        # The episode ends when the agent says so or the step budget is spent.
        done = action.is_done or self.step_number >= self.max_steps

        step_reward = compute_reward(
            grader_score=grader_score,
            previous_score=self.previous_grader_score,
            step_number=self.step_number,
            max_steps=self.max_steps,
            is_done=done,
            action_valid=action_valid,
        )

        self.cumulative_score += step_reward
        self.previous_grader_score = grader_score

        info = {
            "cumulative_score": self.cumulative_score,
            "grader_score": grader_score,
            "breakdown": breakdown,
            "feedback": feedback,
        }

        if done:
            self.final_grader_score = grader_score

        # Count every step() call in the State object.
        self._state.step_count += 1

        obs = Observation(
            task_id=self.current_task_id,
            query=action.rewritten_query,
            schema_context=self.task["schema_context"],
            hint=self.task["hint"],
            # The observation reports the number of the *next* step.
            step_number=self.step_number + 1,
            max_steps=self.max_steps,
            reward=step_reward,
            done=done,
            metadata=info,
        )

        self.history.append({
            "step": self.step_number,
            "type": "step",
            "action": action.model_dump(),
            "reward": step_reward,
            "done": done,
            "info": info,
        })

        self.step_number += 1
        return obs

    @property
    def state(self) -> State:
        """Return the current episode ``State`` (episode id and step count)."""
        return self._state
|