"""DispatchPulse client.
Subclasses ``openenv.core.env_client.EnvClient`` and implements the three
required hooks: ``_step_payload``, ``_parse_result``, ``_parse_state``.
The client speaks WebSocket to the env server (a FastAPI app created via
``create_fastapi_app``). Use ``DispatchPulseEnv.from_docker_image(image)``
to spin up a local container, or ``DispatchPulseEnv(base_url=...)`` to
connect to an already-running server (e.g. a Hugging Face Space).
"""
from __future__ import annotations
from typing import Any, Dict
from openenv.core.env_client import EnvClient, StepResult
from models import DispatchPulseAction, DispatchPulseObservation, DispatchPulseState
class DispatchPulseEnv(
    EnvClient[DispatchPulseAction, DispatchPulseObservation, DispatchPulseState]
):
    """Async client for the DispatchPulse OpenEnv environment.

    Example (Docker image)::

        env = await DispatchPulseEnv.from_docker_image("dispatchpulse:latest")
        result = await env.reset(task_name="easy", seed=42)
        while not result.done:
            action = DispatchPulseAction(action_type="wait", minutes=1)
            result = await env.step(action)
        await env.close()

    Example (remote URL)::

        async with DispatchPulseEnv(base_url="https://...hf.space") as env:
            result = await env.reset(task_name="easy", seed=42)
    """

    @staticmethod
    def _known_fields(model_cls: Any, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the entries of *data* whose keys are declared fields of *model_cls*.

        Unknown keys are dropped defensively: the models are described as using
        ``extra="forbid"``, so any key the server adds that the local model does
        not declare would otherwise make construction fail.
        """
        allowed = model_cls.model_fields
        return {k: v for k, v in data.items() if k in allowed}

    def _step_payload(self, action: DispatchPulseAction) -> Dict[str, Any]:
        """Serialize *action* into a JSON-compatible dict for the step request."""
        return action.model_dump(mode="json")

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[DispatchPulseObservation]:
        """Build a ``StepResult`` from a reset/step response.

        *payload* is either an envelope of the form
        ``{"observation": {...}, "reward": ..., "done": ...}`` or a bare
        observation dict. When the nested observation dict omits ``reward`` or
        ``done`` but the envelope carries them, they are copied into the
        observation. This keeps ``result.observation.done`` and
        ``result.done`` in agreement.
        """
        obs_data = payload.get("observation", payload) or {}
        obs_clean = self._known_fields(DispatchPulseObservation, obs_data)

        # Envelope-style payloads may carry reward/done only at the top level.
        # Without this, the observation would keep its model defaults and
        # contradict the StepResult built below.
        for key in ("reward", "done"):
            if (
                key not in obs_clean
                and key in DispatchPulseObservation.model_fields
                and payload.get(key) is not None
            ):
                obs_clean[key] = payload[key]

        observation = DispatchPulseObservation(**obs_clean)

        reward = payload.get("reward", observation.reward)
        # An explicit ``"done": null`` must not override the observation's flag.
        done = payload.get("done")
        if done is None:
            done = observation.done
        return StepResult(
            observation=observation,
            reward=reward,
            done=bool(done),
        )

    def _parse_state(self, payload: Dict[str, Any]) -> DispatchPulseState:
        """Build a ``DispatchPulseState`` from a state response, ignoring unknown keys."""
        return DispatchPulseState(**self._known_fields(DispatchPulseState, payload))
|