0x960 / src /zero960_env /server /environment.py
qtzx06's picture
feat: finalize swarm tooling and submission artifacts
eac9d9f
from __future__ import annotations
from typing import Any, Optional
from uuid import uuid4
from openenv.core.env_server import Environment
from openenv.core.env_server.types import State
from zero960.runtime.episode import EpisodeConfig, Zero960EpisodeRuntime
from zero960.runtime.types import RuntimeAction
from zero960_env.models import Zero960Action, Zero960Observation
class Zero960Environment(Environment[Zero960Action, Zero960Observation, State]):
    """OpenEnv environment that delegates episodes to a ``Zero960EpisodeRuntime``.

    ``reset`` and ``step`` forward to the runtime and copy the runtime's
    observation, field by field, into a ``Zero960Observation``. Episode
    bookkeeping (episode id and step counter) is kept in a ``State`` object
    exposed through the ``state`` property.
    """

    def __init__(self) -> None:
        """Create the runtime with a default ``EpisodeConfig`` and a fresh state."""
        super().__init__()
        self.runtime = Zero960EpisodeRuntime(EpisodeConfig())
        self._state = State(episode_id=str(uuid4()), step_count=0)

    @staticmethod
    def _shared_fields(source: Any) -> dict[str, Any]:
        """Return the observation fields that reset and step both copy.

        Keeping the mapping in one place means a field added to the runtime
        observation only has to be wired up once.

        Args:
            source: The runtime observation to read attributes from.

        Returns:
            Keyword arguments suitable for ``Zero960Observation(**...)``.
        """
        return {
            "task": source.task,
            "status_message": source.status_message,
            "file_contents": source.file_contents,
            "start_position": source.start_position,
            "history": source.history,
            "remaining_steps": source.remaining_steps,
            "last_match_score": source.last_match_score,
            "invalid_edit_count": source.invalid_edit_count,
            "workflow_hint": source.workflow_hint,
            "suggested_actions": source.suggested_actions,
            "has_valid_edit": source.has_valid_edit,
            "has_run_match": source.has_run_match,
        }

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Zero960Observation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Forwarded to the runtime as ``chess960_index``.
                NOTE(review): the runtime presumably expects a valid Chess960
                position index; no range check is done here -- confirm.
            episode_id: Identifier for the new episode. Any falsy value
                (``None`` or an empty string) is replaced by a fresh UUID4.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The initial observation. ``reward`` and ``done`` are not set, so
            whatever defaults ``Zero960Observation`` declares apply.
        """
        new_episode_id = episode_id or str(uuid4())
        self._state = State(episode_id=new_episode_id, step_count=0)
        initial = self.runtime.reset(chess960_index=seed)
        return Zero960Observation(**self._shared_fields(initial))

    def step(
        self,
        action: Zero960Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Zero960Observation:
        """Apply one action to the runtime and return the resulting observation.

        Args:
            action: The action to execute; its ``action_type``, ``path`` and
                ``content`` are repackaged as a ``RuntimeAction``.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The observation after the step, including ``reward`` and ``done``.
        """
        runtime_action = RuntimeAction(
            action_type=action.action_type,
            path=action.path,
            content=action.content,
        )
        result = self.runtime.step(runtime_action)
        # Count the step only after the runtime call returned without raising.
        self._state.step_count += 1
        obs = result.observation
        return Zero960Observation(
            **self._shared_fields(obs),
            reward=obs.reward,
            done=obs.done,
        )

    @property
    def state(self) -> State:
        """Return the current episode state (episode id and step count)."""
        return self._state