Spaces:
Sleeping
Sleeping
File size: 2,071 Bytes
"""Async HTTP client mirroring the OpenEnv env interface.

Used by ``inference.py`` to talk to the running FastAPI server, whether it is
served locally or from a Hugging Face Space deployment.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Optional
import httpx
from .models import Action, Observation
@dataclass
class StepResponse:
    """Parsed reply of the server's ``/step`` endpoint.

    Built by ``ESCHttpClient.step`` from the decoded JSON body.
    """

    # Observation reconstructed from the reply's "observation" object.
    observation: Observation
    # The reply's "reward" value, coerced with float().
    reward: float
    # The reply's "reward_detail" mapping (empty dict when the key is absent).
    reward_detail: dict[str, Any]
    # The reply's "done" value, coerced with bool().
    done: bool
    # The reply's "info" mapping (empty dict when the key is absent).
    info: dict[str, Any]
@dataclass
class ResetResponse:
    """Parsed reply of the server's ``/reset`` endpoint.

    Built by ``ESCHttpClient.reset`` from the decoded JSON body.
    """

    # Initial observation reconstructed from the reply's "observation" object.
    observation: Observation
    # The reply's "info" mapping (empty dict when the key is absent).
    info: dict[str, Any]
class ESCHttpClient:
    """Thin async client for the ESC OpenEnv server.

    Wraps one ``httpx.AsyncClient`` bound to ``base_url`` and exposes the
    server's ``/reset``, ``/step`` and ``/state`` endpoints as coroutines.
    The pooled connections are only released by :meth:`close`; prefer using
    the client as an async context manager so that happens automatically::

        async with ESCHttpClient.from_url("http://localhost:8000") as env:
            first = await env.reset()
    """

    def __init__(self, base_url: str, timeout: float = 30.0):
        """Create the client.

        Args:
            base_url: Server root URL; trailing slashes are stripped.
            timeout: Request timeout in seconds, passed to ``httpx.AsyncClient``.
        """
        self.base_url = base_url.rstrip("/")
        self._client = httpx.AsyncClient(base_url=self.base_url, timeout=timeout)

    @classmethod
    def from_url(cls, base_url: str, timeout: float = 30.0) -> "ESCHttpClient":
        """Alternate constructor; equivalent to ``ESCHttpClient(base_url, timeout)``."""
        return cls(base_url=base_url, timeout=timeout)

    async def __aenter__(self) -> "ESCHttpClient":
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        # Always release the connection pool; returning None lets any
        # exception raised inside the ``async with`` body propagate.
        await self.close()

    async def _request_json(self, method: str, path: str, **kwargs: Any) -> Any:
        """Send one request relative to ``base_url`` and return its decoded JSON body.

        Raises:
            httpx.HTTPStatusError: if the server answers with a 4xx/5xx status.
        """
        response = await self._client.request(method, path, **kwargs)
        response.raise_for_status()
        return response.json()

    async def reset(self, task_id: Optional[str] = None) -> ResetResponse:
        """Call ``POST /reset`` and parse the reply.

        Args:
            task_id: Sent as ``{"task_id": ...}`` when given; when ``None`` the
                request body is an empty JSON object.

        Returns:
            A ``ResetResponse``; its ``info`` is ``{}`` when the server omits
            the key or sends ``null``.

        Raises:
            httpx.HTTPStatusError: on a 4xx/5xx response.
            KeyError: if the reply has no ``"observation"`` key.
        """
        payload: Dict[str, Any] = {}
        if task_id is not None:
            payload["task_id"] = task_id
        data = await self._request_json("POST", "/reset", json=payload)
        return ResetResponse(
            observation=Observation(**data["observation"]),
            # ``or {}`` also covers an explicit JSON null, keeping the Dict type.
            info=data.get("info") or {},
        )

    async def step(self, action: Action) -> StepResponse:
        """Call ``POST /step`` with ``{"action": <serialized action>}`` and parse the reply.

        Returns:
            A ``StepResponse``; ``reward`` is coerced with ``float()``, ``done``
            with ``bool()``, and ``reward_detail`` / ``info`` default to ``{}``
            when missing or ``null``.

        Raises:
            httpx.HTTPStatusError: on a 4xx/5xx response.
            KeyError: if the reply lacks ``"observation"``, ``"reward"`` or ``"done"``.
        """
        # mode="json" converts Enum/datetime/UUID/... fields into JSON-native
        # values; a plain model_dump() can leave objects the stdlib json
        # encoder used by httpx rejects.
        body = {"action": action.model_dump(mode="json")}
        data = await self._request_json("POST", "/step", json=body)
        return StepResponse(
            observation=Observation(**data["observation"]),
            reward=float(data["reward"]),
            reward_detail=data.get("reward_detail") or {},
            done=bool(data["done"]),
            info=data.get("info") or {},
        )

    async def state(self) -> Dict[str, Any]:
        """Call ``GET /state`` and return the decoded JSON body unchanged.

        Raises:
            httpx.HTTPStatusError: on a 4xx/5xx response.
        """
        return await self._request_json("GET", "/state")

    async def close(self) -> None:
        """Close the underlying ``httpx.AsyncClient`` and its pooled connections."""
        await self._client.aclose()
|