# dispatch_arena_v0 / client.py
# Uploaded by Freakdivi using huggingface_hub (commit c71bf62, verified).
"""Typed client for the Dispatch Arena server API."""
from __future__ import annotations

import json
from dataclasses import dataclass
from typing import Any, Dict, Optional
from urllib.error import HTTPError
from urllib.parse import quote, urlencode
from urllib.request import Request, urlopen

from dispatch_arena.models import Action, Config, Observation, State
class EnvClientError(RuntimeError):
    """Error raised by :class:`DispatchArenaClient`.

    Raised in two situations:

    * the server answers with an HTTP error status (surfaced by ``urllib`` as
      ``HTTPError``); the message is ``"HTTP <code>: <body or reason>"``;
    * a session-bound call (``step()``, ``fetch_state()``, ...) is made before
      a session has been created with ``create_session()`` or ``reset()``.
    """
@dataclass
class DispatchArenaClient:
    """Small typed wrapper around the Dispatch Arena HTTP API.

    Covers session creation, reset, step, state, summary, replay, and the
    health/readiness endpoints. Every request goes through ``urllib`` with a
    per-request timeout of ``timeout_seconds``. HTTP error statuses are
    re-raised as ``EnvClientError``.

    Attributes:
        base_url: Server root URL; a trailing slash is tolerated.
        session_id: Id of the active server-side session. It is set from the
            server response by ``create_session()`` and ``reset()``.
        timeout_seconds: Timeout passed to ``urlopen`` for each request.
    """

    base_url: str = "http://127.0.0.1:8080"
    session_id: Optional[str] = None
    timeout_seconds: int = 10

    def create_session(
        self,
        mode: str = "mini",
        seed: Optional[int] = None,
        config: Optional[Config | Dict[str, Any]] = None,
    ) -> Observation:
        """Create a new session via ``POST /api/sessions``.

        Args:
            mode: Session mode string forwarded to the server.
            seed: Optional seed; sent as JSON ``null`` when omitted.
            config: Optional configuration, either a ``Config`` (serialized
                with ``to_dict()``) or a plain dict; ``None`` sends ``{}``.

        Returns:
            The initial observation returned by the server.

        Side effects:
            Stores the server-assigned ``session_id`` on this client.
        """
        data = self._post(
            "/api/sessions",
            {"mode": mode, "seed": seed, "config": self._config_payload(config)},
        )
        self.session_id = data["session_id"]
        return Observation.from_dict(data["observation"])

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        config: Optional[Config | Dict[str, Any]] = None,
    ) -> Observation:
        """Reset (or start) an episode via ``POST /reset``.

        The current ``session_id`` (possibly ``None``) is sent along, and the
        id returned by the server replaces it.

        Args:
            seed: Optional seed forwarded to the server.
            episode_id: Optional episode id forwarded to the server.
            config: ``Config`` or plain dict; ``None`` sends ``{}``.

        Returns:
            The observation returned by the server after the reset.
        """
        payload: Dict[str, Any] = {
            "seed": seed,
            "episode_id": episode_id,
            "session_id": self.session_id,
            "config": self._config_payload(config),
        }
        data = self._post("/reset", payload)
        self.session_id = data["session_id"]
        return Observation.from_dict(data["observation"])

    def step(self, action: Action | str | Dict[str, Any]) -> Observation:
        """Send one action via ``POST /step`` and return the next observation.

        Raises:
            EnvClientError: If no session exists yet, or on an HTTP error.
            TypeError: If ``action`` is not an ``Action``, ``str`` or ``dict``.
        """
        session_id = self._require_session()
        data = self._post(
            "/step",
            {"session_id": session_id, "action": self._action_payload(action)},
        )
        return Observation.from_dict(data["observation"])

    def fetch_state(self) -> State:
        """Fetch the session state via ``GET /state`` and validate it as ``State``."""
        session_id = self._require_session()
        data = self._get("/state", {"session_id": session_id})
        return State.model_validate(data["state"])

    def fetch_summary(self) -> Dict[str, Any]:
        """Fetch ``episode_summary`` via ``GET /summary`` as a shallow dict copy."""
        session_id = self._require_session()
        data = self._get("/summary", {"session_id": session_id})
        return dict(data["episode_summary"])

    def fetch_replay(self) -> list[dict]:
        """Fetch the replay records of the current session as a list."""
        session_id = self._require_session()
        # Escape the id so characters such as "/" or "?" cannot change the path.
        data = self._get(f"/api/sessions/{quote(session_id, safe='')}/replay")
        return list(data["records"])

    def health(self) -> Dict[str, Any]:
        """Return the decoded JSON body of ``GET /healthz``."""
        return self._get("/healthz")

    def ready(self) -> Dict[str, Any]:
        """Return the decoded JSON body of ``GET /ready``."""
        return self._get("/ready")

    def state(self) -> State:
        """Alias for :meth:`fetch_state`."""
        return self.fetch_state()

    def _require_session(self) -> str:
        """Return the active session id, raising if no session exists yet."""
        if not self.session_id:
            raise EnvClientError("Session not initialized. Call reset() first.")
        return self.session_id

    @staticmethod
    def _config_payload(config: Optional[Config | Dict[str, Any]]) -> Dict[str, Any]:
        """Turn a ``Config``, dict or ``None`` into a JSON-serializable dict."""
        if isinstance(config, Config):
            return config.to_dict()
        return config or {}

    def _action_payload(self, action: Action | str | Dict[str, Any]) -> Any:
        """Serialize an ``Action``; ``str`` and ``dict`` are passed through as-is."""
        if isinstance(action, Action):
            return action.to_dict()
        if isinstance(action, (str, dict)):
            return action
        raise TypeError("action must be Action, str, or dict")

    def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as a JSON body to ``base_url + path``; return decoded JSON."""
        req = Request(
            self.base_url.rstrip("/") + path,
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        return self._request_json(req)

    def _get(self, path: str, query: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """GET ``base_url + path`` with ``None``-valued query params dropped.

        No ``?`` is appended when no query parameters remain after filtering.
        """
        url = self.base_url.rstrip("/") + path
        cleaned = {key: value for key, value in (query or {}).items() if value is not None}
        if cleaned:
            url += "?" + urlencode(cleaned)
        return self._request_json(Request(url, method="GET"))

    def _request_json(self, req: Request) -> Dict[str, Any]:
        """Send ``req`` and decode the UTF-8 JSON response body.

        Raises:
            EnvClientError: When the server replies with an HTTP error status.
        """
        try:
            with urlopen(req, timeout=self.timeout_seconds) as resp:
                return json.loads(resp.read().decode("utf-8"))
        except HTTPError as exc:
            # Error bodies are not guaranteed to be UTF-8; a decode failure must
            # not mask the original HTTP error.
            message = exc.read().decode("utf-8", errors="replace") if exc.fp else str(exc)
            raise EnvClientError(f"HTTP {exc.code}: {message}") from exc
# Alias exposing the client under a generic name; presumably kept so callers
# that import `EnvClient` keep working — verify usages before removing.
EnvClient = DispatchArenaClient