Spaces:
Sleeping
Sleeping
"""Env client adapters.

Two implementations of the same contract:

* :class:`HttpEnvClient` — talks to a running FastAPI server over HTTP
  (``localhost:8000`` during training).
* :class:`InProcessEnvClient` — drives the same FastAPI app via
  ``fastapi.testclient.TestClient``, no socket required. Used by tests
  and by single-process training jobs.

Both expose the same three operations: ``reset``, ``step``, ``close``. The
rollout code only depends on the protocol, so swapping transports doesn't
ripple through anything else.
"""
| from __future__ import annotations | |
| from typing import Any, Protocol, runtime_checkable | |
@runtime_checkable
class EnvClient(Protocol):
    """Minimal client surface the rollout depends on.

    This is a structural protocol: any object that provides ``reset``,
    ``step`` and ``close`` satisfies it, no inheritance required.

    The class is decorated with :func:`typing.runtime_checkable` so that
    ``isinstance(client, EnvClient)`` works instead of raising
    ``TypeError``. Such a check only verifies that the three methods exist;
    it does not validate their signatures or return types.
    """

    def reset(self, task_id: str | None = None, seed: int | None = None) -> dict[str, Any]:
        """Reset the environment, optionally for a given task and seed.

        Returns the response payload as a dict.
        """
        ...

    def step(self, episode_id: str, action: dict[str, Any]) -> dict[str, Any]:
        """Submit ``action`` for ``episode_id`` and return the response payload."""
        ...

    def close(self, episode_id: str) -> dict[str, Any]:
        """Close ``episode_id`` and return the response payload."""
        ...
| # ---- HTTP transport -------------------------------------------------- | |
class HttpEnvClient:
    """HTTP transport for the env server, backed by ``httpx.Client``.

    Meant for training runs where the env server lives in another process.
    It can be used as a context manager; leaving the ``with`` block closes
    the underlying HTTP client.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 30.0) -> None:
        """Create the underlying ``httpx.Client`` for ``base_url``."""
        # Lazy import keeps httpx optional for callers that only ever use
        # the in-process transport (tests, notebooks).
        import httpx

        self._client = httpx.Client(base_url=base_url, timeout=timeout)

    def reset(self, task_id: str | None = None, seed: int | None = None) -> dict[str, Any]:
        """POST ``/reset`` and return the decoded JSON response.

        Only arguments that are not ``None`` are included in the request
        body. Raises via ``raise_for_status()`` on an error status.
        """
        optional_fields = (("task_id", task_id), ("seed", seed))
        payload: dict[str, Any] = {
            key: value for key, value in optional_fields if value is not None
        }
        response = self._client.post("/reset", json=payload)
        response.raise_for_status()
        return response.json()

    def step(self, episode_id: str, action: dict[str, Any]) -> dict[str, Any]:
        """POST ``/step`` and return the decoded JSON response.

        A 422 status is not raised; it is converted into a structured
        ``schema_rejection`` result. Any other error status raises via
        ``raise_for_status()``.
        """
        response = self._client.post(
            "/step", json={"episode_id": episode_id, "action": action}
        )
        if response.status_code != 422:
            response.raise_for_status()
            return response.json()

        # 422 means the action payload was malformed. Hand it back as data
        # instead of an exception so the agent can learn from the rejection.
        rejection = {"error": "schema_rejection", "detail": response.json()}
        return {
            "observation": {},
            "reward": 0.0,  # caller will overlay with MALFORMED scoring
            "done": False,
            "info": rejection,
        }

    def close(self, episode_id: str) -> dict[str, Any]:
        """POST ``/close`` for ``episode_id`` and return the decoded JSON response.

        This ends the episode only; the HTTP client itself stays open.
        """
        response = self._client.post("/close", json={"episode_id": episode_id})
        response.raise_for_status()
        return response.json()

    def __enter__(self) -> "HttpEnvClient":
        """Return this client for use in a ``with`` statement."""
        return self

    def __exit__(self, *_exc: object) -> None:
        """Close the underlying HTTP client; exceptions are not suppressed."""
        self._client.close()
| # ---- in-process transport ------------------------------------------- | |
class InProcessEnvClient:
    """Drive the FastAPI app via ``TestClient`` without a real socket.

    It offers the same ``reset`` / ``step`` / ``close`` operations as
    :class:`HttpEnvClient`. Like that class it can be used as a context
    manager, so ``with client:`` code works with either transport; leaving
    the ``with`` block closes the underlying ``TestClient``.
    """

    def __init__(self, app: object | None = None) -> None:
        """Wrap ``app`` in a ``TestClient``.

        When ``app`` is ``None`` the default application is imported from
        ``graphforge.server.app``.
        """
        # Deferred imports: fastapi and the server package are only needed
        # when this transport is actually instantiated.
        from fastapi.testclient import TestClient

        if app is None:
            from graphforge.server.app import app as default_app

            app = default_app
        self._client = TestClient(app)  # type: ignore[arg-type]

    def reset(self, task_id: str | None = None, seed: int | None = None) -> dict[str, Any]:
        """POST ``/reset`` and return the decoded JSON response.

        Only arguments that are not ``None`` are included in the request
        body. Raises via ``raise_for_status()`` on an error status.
        """
        body: dict[str, Any] = {}
        if task_id is not None:
            body["task_id"] = task_id
        if seed is not None:
            body["seed"] = seed
        r = self._client.post("/reset", json=body)
        r.raise_for_status()
        return r.json()

    def step(self, episode_id: str, action: dict[str, Any]) -> dict[str, Any]:
        """POST ``/step`` and return the decoded JSON response.

        A 422 status is returned as a structured ``schema_rejection``
        result instead of raising; any other error status raises via
        ``raise_for_status()``.
        """
        r = self._client.post(
            "/step", json={"episode_id": episode_id, "action": action}
        )
        # 422 = malformed action payload; report it as data, matching the
        # HTTP transport, so the caller can score it rather than crash.
        if r.status_code == 422:
            return {
                "observation": {},
                "reward": 0.0,
                "done": False,
                "info": {"error": "schema_rejection", "detail": r.json()},
            }
        r.raise_for_status()
        return r.json()

    def close(self, episode_id: str) -> dict[str, Any]:
        """POST ``/close`` for ``episode_id`` and return the decoded JSON response.

        This ends the episode only; the underlying ``TestClient`` stays open.
        """
        r = self._client.post("/close", json={"episode_id": episode_id})
        r.raise_for_status()
        return r.json()

    def __enter__(self) -> "InProcessEnvClient":
        """Return this client for use in a ``with`` statement."""
        return self

    def __exit__(self, *_exc: object) -> None:
        """Close the underlying ``TestClient``; exceptions are not suppressed."""
        self._client.close()