Spaces:
Sleeping
Sleeping
| """Typed client for the Dispatch Arena server API.""" | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, Optional | |
| from urllib.error import HTTPError | |
| from urllib.parse import urlencode | |
| from urllib.request import Request, urlopen | |
| from dispatch_arena.models import Action, Config, Observation, State | |
class EnvClientError(RuntimeError):
    """Error raised by the Dispatch Arena client.

    Signals either a non-success HTTP response from the server, or a call
    that needs an active session before one has been created.
    """
@dataclass(eq=False)
class DispatchArenaClient:
    """Small typed wrapper around the Dispatch Arena HTTP endpoints.

    Covers session creation, reset, step, state, summary, replay, health and
    readiness. The client remembers the ``session_id`` returned by
    :meth:`create_session` / :meth:`reset` and sends it with later calls.

    Attributes:
        base_url: Server root URL; a trailing slash is tolerated.
        session_id: Active session id, or ``None`` until ``create_session()``
            or ``reset()`` has returned one.
        timeout_seconds: Timeout passed to ``urlopen`` for every request.
    """

    # eq=False keeps identity equality/hashing: the client is a mutable handle
    # (session_id changes on reset), so value-based __eq__ would be misleading
    # and would also make instances unhashable.
    base_url: str = "http://127.0.0.1:8080"
    session_id: Optional[str] = None
    timeout_seconds: int = 10

    def create_session(
        self,
        mode: str = "mini",
        seed: Optional[int] = None,
        config: Optional[Config | Dict[str, Any]] = None,
    ) -> Observation:
        """Create a new session via ``POST /api/sessions``.

        Args:
            mode: Session mode sent to the server (default ``"mini"``).
            seed: Optional seed; sent as JSON ``null`` when omitted.
            config: Optional ``Config`` instance or plain dict of settings;
                ``None`` is sent as an empty dict.

        Returns:
            The initial observation. ``self.session_id`` is replaced with the
            id from the response.

        Raises:
            EnvClientError: If the server responds with an HTTP error status.
        """
        payload = {"mode": mode, "seed": seed, "config": self._config_payload(config)}
        data = self._post("/api/sessions", payload)
        self.session_id = data["session_id"]
        return Observation.from_dict(data["observation"])

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        config: Optional[Config | Dict[str, Any]] = None,
    ) -> Observation:
        """Reset the episode via ``POST /reset``.

        The current ``session_id`` (possibly ``None``) is sent along, and the
        id returned by the server replaces it.

        Args:
            seed: Optional seed for the new episode.
            episode_id: Optional episode identifier.
            config: Optional ``Config`` instance or plain dict of settings.

        Returns:
            The first observation of the new episode.

        Raises:
            EnvClientError: If the server responds with an HTTP error status.
        """
        payload: Dict[str, Any] = {
            "seed": seed,
            "episode_id": episode_id,
            "session_id": self.session_id,
            "config": self._config_payload(config),
        }
        data = self._post("/reset", payload)
        self.session_id = data["session_id"]
        return Observation.from_dict(data["observation"])

    def step(self, action: Action | str | Dict[str, Any]) -> Observation:
        """Send one action via ``POST /step`` and return the next observation.

        Raises:
            EnvClientError: If no session exists yet, or on an HTTP error.
            TypeError: If ``action`` is not an ``Action``, ``str`` or ``dict``.
        """
        session_id = self._require_session()
        payload = {"session_id": session_id, "action": self._action_payload(action)}
        data = self._post("/step", payload)
        return Observation.from_dict(data["observation"])

    def fetch_state(self) -> State:
        """Fetch the session state via ``GET /state``.

        The ``state`` field is parsed with ``State.model_validate`` (presumably
        a pydantic model; confirm in ``dispatch_arena.models``).

        Raises:
            EnvClientError: If no session exists yet, or on an HTTP error.
        """
        session_id = self._require_session()
        data = self._get("/state", {"session_id": session_id})
        return State.model_validate(data["state"])

    def fetch_summary(self) -> Dict[str, Any]:
        """Return a copy of the ``episode_summary`` from ``GET /summary``.

        Raises:
            EnvClientError: If no session exists yet, or on an HTTP error.
        """
        session_id = self._require_session()
        data = self._get("/summary", {"session_id": session_id})
        return dict(data["episode_summary"])

    def fetch_replay(self) -> list[dict]:
        """Return the replay ``records`` for the active session.

        Raises:
            EnvClientError: If no session exists yet, or on an HTTP error.
        """
        from urllib.parse import quote  # local: the module header imports only urlencode

        session_id = self._require_session()
        # Quote the id so characters such as "/" or "?" cannot alter the path.
        data = self._get(f"/api/sessions/{quote(session_id, safe='')}/replay")
        return list(data["records"])

    def health(self) -> Dict[str, Any]:
        """Return the decoded JSON body of ``GET /healthz``."""
        return self._get("/healthz")

    def ready(self) -> Dict[str, Any]:
        """Return the decoded JSON body of ``GET /ready``."""
        return self._get("/ready")

    def state(self) -> State:
        """Alias for :meth:`fetch_state`."""
        return self.fetch_state()

    def _require_session(self) -> str:
        """Return the active session id, raising if none has been created."""
        if not self.session_id:
            raise EnvClientError("Session not initialized. Call reset() first.")
        return self.session_id

    @staticmethod
    def _config_payload(config: Optional[Config | Dict[str, Any]]) -> Dict[str, Any]:
        """Convert a ``Config``, dict, or ``None`` into a JSON-ready dict."""
        if isinstance(config, Config):
            return config.to_dict()
        return config or {}

    def _action_payload(self, action: Action | str | Dict[str, Any]) -> Any:
        """Convert ``action`` into the JSON value sent to ``/step``.

        ``Action`` objects are serialized with ``to_dict()``; strings and dicts
        are passed through unchanged.

        Raises:
            TypeError: For any other type.
        """
        if isinstance(action, Action):
            return action.to_dict()
        if isinstance(action, (str, dict)):
            return action
        raise TypeError("action must be Action, str, or dict")

    def _url(self, path: str, query: Optional[Dict[str, Any]] = None) -> str:
        """Join ``base_url`` and ``path``, appending non-``None`` query params.

        No ``?`` is appended when every query value is ``None``.
        """
        url = self.base_url.rstrip("/") + path
        cleaned = {key: value for key, value in (query or {}).items() if value is not None}
        if cleaned:
            url += "?" + urlencode(cleaned)
        return url

    def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``path`` and return the decoded response."""
        req = Request(
            self._url(path),
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        return self._request_json(req)

    def _get(self, path: str, query: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """GET ``path`` with optional query params and return the decoded response."""
        req = Request(self._url(path, query), method="GET")
        return self._request_json(req)

    def _request_json(self, req: Request) -> Dict[str, Any]:
        """Send ``req`` and decode the JSON response body.

        Only ``HTTPError`` is translated. Other failures, such as connection
        errors from ``urlopen`` or invalid JSON, propagate unchanged.

        Raises:
            EnvClientError: On an HTTP error status. The message includes the
                status code and the response body, decoded on a best-effort
                basis.
        """
        try:
            with urlopen(req, timeout=self.timeout_seconds) as resp:
                return json.loads(resp.read().decode("utf-8"))
        except HTTPError as exc:
            # errors="replace": a non-UTF-8 error body must not mask the HTTP
            # error behind a UnicodeDecodeError.
            message = exc.read().decode("utf-8", errors="replace") if exc.fp else str(exc)
            raise EnvClientError(f"HTTP {exc.code}: {message}") from exc
# Short alias kept so callers can import the client as ``EnvClient``.
EnvClient = DispatchArenaClient