from __future__ import annotations

from typing import Any, Optional
from uuid import uuid4

from openenv.core.env_server import Environment
from openenv.core.env_server.types import State

from zero960.runtime.episode import EpisodeConfig, Zero960EpisodeRuntime
from zero960.runtime.types import RuntimeAction
from zero960_env.models import Zero960Action, Zero960Observation


class Zero960Environment(Environment[Zero960Action, Zero960Observation, State]):
    """Environment wrapper that delegates episodes to ``Zero960EpisodeRuntime``.

    Incoming ``Zero960Action`` objects are translated to ``RuntimeAction``,
    and the runtime's observations are copied field by field into
    ``Zero960Observation``. Episode bookkeeping (id and step count) lives in
    a ``State`` object exposed through the ``state`` property.
    """

    def __init__(self) -> None:
        """Create a runtime with a default ``EpisodeConfig`` and a fresh state."""
        super().__init__()
        self.runtime = Zero960EpisodeRuntime(EpisodeConfig())
        self._state = State(episode_id=str(uuid4()), step_count=0)

    @staticmethod
    def _shared_fields(runtime_obs: Any) -> dict[str, Any]:
        """Collect the observation fields common to ``reset`` and ``step``.

        Keeping the field list in one place prevents the two call sites from
        drifting apart when a field is added or renamed.
        """
        return {
            "task": runtime_obs.task,
            "status_message": runtime_obs.status_message,
            "file_contents": runtime_obs.file_contents,
            "start_position": runtime_obs.start_position,
            "history": runtime_obs.history,
            "remaining_steps": runtime_obs.remaining_steps,
            "last_match_score": runtime_obs.last_match_score,
            "invalid_edit_count": runtime_obs.invalid_edit_count,
            "workflow_hint": runtime_obs.workflow_hint,
            "suggested_actions": runtime_obs.suggested_actions,
            "has_valid_edit": runtime_obs.has_valid_edit,
            "has_run_match": runtime_obs.has_run_match,
        }

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Zero960Observation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Forwarded to the runtime as ``chess960_index``.
            episode_id: Identifier for the new episode; a random UUID string
                is generated when this is ``None`` or empty.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The initial observation. ``reward`` and ``done`` are not set
            explicitly, so they keep whatever defaults the model declares.
        """
        eid = episode_id or str(uuid4())
        self._state = State(episode_id=eid, step_count=0)
        runtime_obs = self.runtime.reset(chess960_index=seed)
        return Zero960Observation(**self._shared_fields(runtime_obs))

    def step(
        self,
        action: Zero960Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Zero960Observation:
        """Apply one action to the runtime and return the resulting observation.

        Args:
            action: The action to apply; its ``action_type``, ``path`` and
                ``content`` are forwarded as a ``RuntimeAction``.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The post-step observation, including ``reward`` and ``done``.
        """
        runtime_action = RuntimeAction(
            action_type=action.action_type,
            path=action.path,
            content=action.content,
        )
        result = self.runtime.step(runtime_action)
        # Count the step only after the runtime call has returned, so a
        # raising runtime leaves the counter untouched.
        self._state.step_count += 1
        runtime_obs = result.observation
        return Zero960Observation(
            **self._shared_fields(runtime_obs),
            reward=runtime_obs.reward,
            done=runtime_obs.done,
        )

    @property
    def state(self) -> State:
        """Return the current episode state (episode id and step count)."""
        return self._state