# Spaces: Sleeping | File size: 2,925 Bytes | 4904e85
"""REST/API client wrapper for the 911 dispatch environment."""
from __future__ import annotations
from typing import Any
import httpx
from src.models import Action, Observation, State
class APIError(Exception):
    """Error carrying the HTTP status and detail text of a failed API call.

    The request helpers in this module raise it whenever a response status
    is anything other than 200, passing the raw response text as ``detail``.

    Attributes:
        status_code: HTTP status code supplied by the raiser.
        detail: Detail text supplied by the raiser.
    """

    def __init__(self, status_code: int, detail: str) -> None:
        """Store both fields and build the ``APIError(<code>): <detail>`` message."""
        self.status_code = status_code
        self.detail = detail
        super().__init__(f"APIError({status_code}): {detail}")

    def __reduce__(self) -> tuple:
        """Describe how to rebuild this exception for ``copy`` and ``pickle``.

        ``self.args`` holds only the formatted message, so the default
        exception reduction would call ``APIError(message)`` and fail with a
        ``TypeError`` for the missing ``detail`` argument.  Returning the two
        real constructor arguments (plus the instance ``__dict__`` so extra
        attributes survive) makes copies and pickles round-trip.
        """
        return type(self), (self.status_code, self.detail), dict(vars(self))
class DispatchAPI:
    """Async HTTP client for the 911 dispatch environment API.

    The underlying ``httpx.AsyncClient`` is created lazily on the first
    request and reused until :meth:`close` is awaited; a later request then
    creates a fresh client.  The class is also an async context manager, so
    the client is released even when the body raises::

        async with DispatchAPI() as api:
            healthy = await api.health()
    """

    def __init__(self, base_url: str = "http://localhost:8000") -> None:
        """Remember ``base_url``; no network resources are opened yet."""
        self.base_url = base_url
        # Created on demand by _get_client(); None means "no open client".
        self._client: httpx.AsyncClient | None = None

    async def __aenter__(self) -> DispatchAPI:
        """Return ``self``; the HTTP client is still created lazily."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Close the HTTP client on leaving the ``async with`` block."""
        await self.close()

    def _get_client(self) -> httpx.AsyncClient:
        """Return the shared client, creating it on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0)
        return self._client

    async def close(self) -> None:
        """Close the HTTP client if one is open; safe to call repeatedly."""
        # Drop our reference first so a failing aclose() cannot leave a
        # half-closed client behind for the next request to reuse.
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

    async def _close(self) -> None:
        """Close the HTTP client (original private name, kept for callers)."""
        await self.close()

    @staticmethod
    def _raise_for_status(response: httpx.Response) -> None:
        """Raise ``APIError`` unless the response status is exactly 200."""
        if response.status_code != 200:
            raise APIError(status_code=response.status_code, detail=response.text)

    async def reset(self, task_id: str, seed: int | None = None) -> Observation:
        """Reset the environment and return the initial observation.

        Args:
            task_id: Sent as ``task_id`` in the JSON body of ``POST /reset``.
            seed: Sent as ``seed`` in the same body; ``None`` (the default)
                is transmitted as JSON ``null``.

        Returns:
            The response body validated as an ``Observation``.

        Raises:
            APIError: If the response status is not 200.
        """
        response = await self._get_client().post(
            "/reset",
            json={"task_id": task_id, "seed": seed},
        )
        self._raise_for_status(response)
        return Observation.model_validate(response.json())

    async def step(self, action: Action) -> tuple[Observation, float, bool]:
        """Execute an action and return ``(observation, reward, done)``.

        Args:
            action: Dumped and sent under the ``action`` key of
                ``POST /step``.

        Returns:
            The ``observation`` entry validated as an ``Observation``,
            followed by the ``reward`` and ``done`` entries of the response.

        Raises:
            APIError: If the response status is not 200.
            KeyError: If the response lacks one of the three expected keys.
        """
        response = await self._get_client().post(
            "/step",
            # mode="json" asks the model for JSON-compatible values only, so
            # the payload can always be encoded as the request body
            # (assumes Action is a pydantic v2 model, as model_dump suggests).
            json={"action": action.model_dump(mode="json")},
        )
        self._raise_for_status(response)
        data = response.json()
        observation = Observation.model_validate(data["observation"])
        reward: float = data["reward"]
        done: bool = data["done"]
        return observation, reward, done

    async def state(self) -> State:
        """Fetch ``GET /state`` and return the body validated as a ``State``.

        Raises:
            APIError: If the response status is not 200.
        """
        response = await self._get_client().get("/state")
        self._raise_for_status(response)
        return State.model_validate(response.json())

    async def health(self) -> bool:
        """Return ``True`` only if ``GET /health`` answers with status 200."""
        client = self._get_client()
        try:
            response = await client.get("/health")
        except Exception:
            # Deliberately broad: this is a best-effort probe, so any failure
            # to obtain a response is reported as "not healthy".
            return False
        return response.status_code == 200
# Backwards-compatible alias: keeps the old ATC-era name ``ATCAircraftAPI``
# available; it is bound to the very same class object as ``DispatchAPI``.
ATCAircraftAPI = DispatchAPI
|