Spaces:
Sleeping
Sleeping
File size: 4,337 Bytes
"""Env client adapters.
Two implementations of the same contract:
* :class:`HttpEnvClient` — talks to a running FastAPI server over HTTP
(``localhost:8000`` during training).
* :class:`InProcessEnvClient` — drives the same FastAPI app via
``fastapi.testclient.TestClient``, no socket required. Used by tests
and by single-process training jobs.
Both expose the same three operations: ``reset``, ``step``, ``close``. The
rollout code only depends on the protocol, so swapping transports doesn't
ripple through anything else.
"""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class EnvClient(Protocol):
"""Minimal client surface the rollout depends on."""
def reset(self, task_id: str | None = None, seed: int | None = None) -> dict[str, Any]: ...
def step(self, episode_id: str, action: dict[str, Any]) -> dict[str, Any]: ...
def close(self, episode_id: str) -> dict[str, Any]: ...
# ---- HTTP transport --------------------------------------------------
class HttpEnvClient:
"""Thin httpx wrapper. Use during training when the env server runs out-of-process."""
def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 30.0) -> None:
# Defer the import so the dep is optional for users who only do
# in-process drives in tests / notebooks.
import httpx
self._client = httpx.Client(base_url=base_url, timeout=timeout)
def reset(self, task_id: str | None = None, seed: int | None = None) -> dict[str, Any]:
body: dict[str, Any] = {}
if task_id is not None:
body["task_id"] = task_id
if seed is not None:
body["seed"] = seed
r = self._client.post("/reset", json=body)
r.raise_for_status()
return r.json()
def step(self, episode_id: str, action: dict[str, Any]) -> dict[str, Any]:
r = self._client.post("/step", json={"episode_id": episode_id, "action": action})
# 422 = malformed action payload; surface as a structured response
# rather than raising, because the agent's job is to learn from it.
if r.status_code == 422:
return {
"observation": {},
"reward": 0.0, # caller will overlay with MALFORMED scoring
"done": False,
"info": {"error": "schema_rejection", "detail": r.json()},
}
r.raise_for_status()
return r.json()
def close(self, episode_id: str) -> dict[str, Any]:
r = self._client.post("/close", json={"episode_id": episode_id})
r.raise_for_status()
return r.json()
def __enter__(self) -> "HttpEnvClient":
return self
def __exit__(self, *_exc: object) -> None:
self._client.close()
# ---- in-process transport -------------------------------------------
class InProcessEnvClient:
"""Drive the FastAPI app via ``TestClient`` without a real socket."""
def __init__(self, app: object | None = None) -> None:
from fastapi.testclient import TestClient
if app is None:
from graphforge.server.app import app as default_app
app = default_app
self._client = TestClient(app) # type: ignore[arg-type]
def reset(self, task_id: str | None = None, seed: int | None = None) -> dict[str, Any]:
body: dict[str, Any] = {}
if task_id is not None:
body["task_id"] = task_id
if seed is not None:
body["seed"] = seed
r = self._client.post("/reset", json=body)
r.raise_for_status()
return r.json()
def step(self, episode_id: str, action: dict[str, Any]) -> dict[str, Any]:
r = self._client.post(
"/step", json={"episode_id": episode_id, "action": action}
)
if r.status_code == 422:
return {
"observation": {},
"reward": 0.0,
"done": False,
"info": {"error": "schema_rejection", "detail": r.json()},
}
r.raise_for_status()
return r.json()
def close(self, episode_id: str) -> dict[str, Any]:
r = self._client.post("/close", json={"episode_id": episode_id})
r.raise_for_status()
return r.json()
|