driftpa / client.py
rajv24's picture
Upload folder using huggingface_hub
2933389 verified
"""
DriftPA — client.py
====================
Typed client for connecting to a deployed DriftPA environment server.
Usage:
    from driftpa.client import DriftPAClient
    from driftpa.models import DriftPAAction

    with DriftPAClient(base_url="ws://localhost:8000") as env:
        result = env.reset(seed=0)
        while not result.done:
            action = DriftPAAction(tool_name="list_tools", payload={})
            result = env.step(action)
            print(result.observation)

    # Or connect to a deployed HuggingFace Space:
    env = DriftPAClient.from_env("rajashekarcs2023/driftpa")
"""
from typing import Any
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from driftpa.models import DriftPAAction, DriftPAObservation, DriftPAState
class DriftPAClient(EnvClient[DriftPAAction, DriftPAObservation, DriftPAState]):
    """
    Typed WebSocket client for DriftPA.

    Connects to a running DriftPA server (local or HuggingFace Spaces)
    and exposes reset() / step() with full type annotations.

    This class only supplies the three (de)serialization hooks below; the
    connection handling and the public reset() / step() methods are
    inherited from ``EnvClient``.
    """

    def _step_payload(self, action: DriftPAAction) -> dict[str, Any]:
        """Serialize a DriftPAAction for the server.

        Args:
            action: The action to send.

        Returns:
            The dictionary produced by ``action.model_dump()``.
        """
        return action.model_dump()

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[DriftPAObservation]:
        """Deserialize a server response into a typed StepResult.

        Args:
            payload: Decoded response. The keys read are ``"observation"``,
                ``"reward"`` and ``"done"``; each of them is optional.

        Returns:
            A ``StepResult`` wrapping a ``DriftPAObservation`` built from the
            ``"observation"`` mapping. ``reward`` is passed through unchanged
            (``None`` when absent) and ``done`` is always a real ``bool``.
        """
        # `.get("observation", {})` only covers a *missing* key; an explicit
        # null from the server would reach `**` as None and raise TypeError.
        obs_fields = payload.get("observation") or {}
        observation = DriftPAObservation(**obs_fields)
        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            # Coerce so that a null `done` becomes False rather than None.
            done=bool(payload.get("done", False)),
        )

    def _parse_state(self, payload: dict[str, Any]) -> DriftPAState:
        """Deserialize a server state response.

        Args:
            payload: Decoded state response; every key is forwarded to
                ``DriftPAState`` as a keyword argument.

        Returns:
            The ``DriftPAState`` constructed from ``payload``.
        """
        return DriftPAState(**payload)