"""WebSocket EnvClient for CORP-ENV.""" from __future__ import annotations from typing import Any, Dict from openenv.core import EnvClient from openenv.core.client_types import StepResult from openenv.core.env_server.types import State from corp_env.models import CorpAction, CorpObservation class CorpEnvClient(EnvClient[CorpAction, CorpObservation, State]): """Client for a running corp-env server (persistent WebSocket session).""" def _step_payload(self, action: CorpAction) -> Dict[str, Any]: return action.model_dump(mode="json", exclude_none=True) def _parse_result(self, payload: Dict[str, Any]) -> StepResult[CorpObservation]: obs_data = dict(payload.get("observation", {})) meta = obs_data.pop("metadata", {}) or {} observation = CorpObservation( **obs_data, reward=payload.get("reward"), done=bool(payload.get("done", False)), metadata=meta, ) return StepResult( observation=observation, reward=payload.get("reward"), done=bool(payload.get("done", False)), ) def _parse_state(self, payload: Dict[str, Any]) -> State: return State( episode_id=payload.get("episode_id"), step_count=int(payload.get("step_count", 0)), )