Spaces:
Sleeping
Sleeping
| """Pyre Environment Client.""" | |
| from typing import Dict | |
| from openenv.core import EnvClient | |
| from openenv.core.client_types import StepResult | |
| from .models import PyreAction, PyreObservation, PyreState | |
class PyreEnv(EnvClient[PyreAction, PyreObservation, PyreState]):
    """Client-side wrapper around the Pyre Environment.

    The client is asynchronous by default; call ``.sync()`` to obtain a
    synchronous wrapper:

        with PyreEnv(base_url="http://localhost:8000").sync() as env:
            result = env.reset()
            print(result.observation.narrative)
            result = env.step(PyreAction(action="move", direction="north"))
            print(f"Health: {result.observation.agent_health}")

    Asynchronous usage:

        async with PyreEnv(base_url="http://localhost:8000") as env:
            result = await env.reset()
    """

    def _step_payload(self, action: PyreAction) -> Dict:
        """Serialize ``action`` into the dict sent with a step request.

        Fields whose value is ``None`` are left out of the result.
        """
        serialized = action.model_dump(exclude_none=True)
        return serialized

    def _parse_result(self, payload: Dict) -> StepResult[PyreObservation]:
        """Build a ``StepResult`` from a raw response payload.

        Observation fields are read from ``payload["observation"]`` when that
        key exists; otherwise the payload itself is treated as the
        observation. ``done``, ``reward`` and ``metadata`` are always read
        from the top-level payload.

        Note the two reward defaults: the observation's ``reward`` falls back
        to ``0.0`` when the key is missing, whereas the ``StepResult`` reward
        is ``None`` in that case.
        """
        source = payload.get("observation", payload)
        is_done = payload.get("done", False)

        # Observation field name -> fallback used when the key is missing.
        # The table is rebuilt on every call so that the list fallbacks are
        # fresh objects and are never shared between parsed results.
        fallbacks = {
            "narrative": "",
            "agent_evacuated": False,
            "location_label": "",
            "smoke_level": "none",
            "fire_visible": False,
            "fire_direction": None,
            "agent_health": 100.0,
            "health_status": "Good",
            "wind_dir": "CALM",
            "visible_objects": [],
            "blocked_exit_ids": [],
            "audible_signals": [],
            "elapsed_steps": 0,
            "last_action_feedback": "",
            "available_actions_hint": [],
        }
        field_values = {
            name: source.get(name, fallback)
            for name, fallback in fallbacks.items()
        }

        observation = PyreObservation(
            **field_values,
            done=is_done,
            reward=payload.get("reward", 0.0),
            metadata=payload.get("metadata", {}),
        )
        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=is_done,
        )

    def _parse_state(self, payload: Dict) -> PyreState:
        """Build a ``PyreState`` by passing the payload's items as keyword arguments."""
        state = PyreState(**payload)
        return state