Spaces:
Sleeping
Sleeping
| """Typed WebSocket client for CacheInvalidationEnvironment.""" | |
| from __future__ import annotations | |
| from typing import Any, Dict | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| from env.models import CacheAction, CacheObservation, CacheState | |
class CacheInvalidationEnvClient(EnvClient[CacheAction, CacheObservation, CacheState]):
    """Client that converts between the cache-environment models and raw payload dicts.

    Provides the three payload hooks (presumably invoked by the ``EnvClient``
    base class — confirm against openenv): serialising an action for a step,
    parsing a step response into a ``StepResult``, and parsing environment
    state into a ``CacheState``.
    """

    def _step_payload(self, action: CacheAction | Dict[str, Any]) -> Dict[str, Any]:
        """Return the dict form of *action* to send for a step.

        Args:
            action: A ``CacheAction`` instance, or a plain dict which is first
                validated through ``CacheAction.model_validate`` so malformed
                input is rejected client-side.

        Returns:
            The result of ``CacheAction.model_dump()`` for the (validated) action.

        Raises:
            Whatever ``CacheAction.model_validate`` raises for an invalid dict
            (presumably ``pydantic.ValidationError``).
        """
        if not isinstance(action, CacheAction):
            action = CacheAction.model_validate(action)
        return action.model_dump()

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[CacheObservation]:
        """Build a ``StepResult`` from a server response payload.

        ``reward`` and ``done`` are read from the top level of *payload* and
        injected into the observation dict before validation, overriding any
        same-named keys inside ``payload["observation"]``.

        Args:
            payload: Response dict with optional ``"observation"`` (a dict),
                ``"reward"`` and ``"done"`` keys.

        Returns:
            A ``StepResult`` whose observation, reward and done all carry the
            same reward/done values.
        """
        # ``or {}`` also covers an explicit ``"observation": null``; with
        # ``.get(..., {})`` that would yield None and the ``**`` unpack below
        # would raise TypeError.
        obs_inner = payload.get("observation") or {}
        reward = payload.get("reward")
        done = payload.get("done")
        if done is None:
            # A missing or null ``done`` is treated as "episode not finished".
            done = False
        observation = CacheObservation.model_validate(
            {**obs_inner, "reward": reward, "done": done}
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> CacheState:
        """Validate a state payload into a ``CacheState`` model.

        Args:
            payload: State dict as received from the server.

        Returns:
            The result of ``CacheState.model_validate(payload)``.
        """
        return CacheState.model_validate(payload)