Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Dispatch Triage Environment Client (synchronous-friendly wrapper).""" | |
| from typing import Dict | |
| from openenv.core import EnvClient | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_server.types import State | |
| # Support two run modes: | |
| # 1. Installed as a package → relative imports (.models) | |
| # 2. Imported directly from repo root (e.g. by inference.py) → absolute imports | |
| try: | |
| from .models import ( | |
| DispatchTriageAction, | |
| DispatchTriageObservation, | |
| DispatchTriageState, | |
| Incident, | |
| Unit, | |
| ) | |
| except ImportError: | |
| from models import ( # type: ignore[no-redef] | |
| DispatchTriageAction, | |
| DispatchTriageObservation, | |
| DispatchTriageState, | |
| Incident, | |
| Unit, | |
| ) | |
class DispatchTriageEnv(
    EnvClient[DispatchTriageAction, DispatchTriageObservation, DispatchTriageState]
):
    """
    Client for the Dispatch Triage Environment.

    This class only supplies the three (de)serialisation hooks below;
    connection handling, ``reset``/``step`` and the ``.sync()`` wrapper are
    inherited from ``EnvClient`` and are not defined in this file.

    Synchronous use (e.g. from inference.py)::

        env = DispatchTriageEnv(base_url="http://localhost:8000").sync()
        with env:
            result = env.reset(difficulty="medium")
            result = env.step(DispatchTriageAction(incident_id=0, unit_id=1))
            print(result.observation.message)

    Asynchronous use::

        async with DispatchTriageEnv(base_url="http://localhost:8000") as env:
            result = await env.reset(difficulty="hard")
            result = await env.step(DispatchTriageAction(incident_id=1, unit_id=0))
    """

    def _step_payload(self, action: DispatchTriageAction) -> Dict:
        """Return the wire representation of *action*.

        Only the two identifiers are sent: which incident to handle and which
        unit to dispatch to it.
        """
        return dict(incident_id=action.incident_id, unit_id=action.unit_id)

    def _parse_result(self, payload: Dict) -> StepResult[DispatchTriageObservation]:
        """
        Build a ``StepResult`` from a decoded server response.

        Two response shapes are accepted:
          * an envelope whose observation fields sit under ``"observation"``;
          * a flat dict whose observation fields sit at the top level.

        ``done`` and ``reward`` are taken from the top level when present,
        otherwise from the observation data, otherwise default to
        ``False`` / ``0.0``.
        """
        # A missing or empty "observation" entry means the payload is flat.
        nested = payload.get("observation")
        body: Dict = nested if nested else payload

        incidents = []
        for raw_incident in body.get("incidents", []):
            incidents.append(
                Incident(
                    id=raw_incident["id"],
                    description=raw_incident["description"],
                    location=raw_incident["location"],
                    assigned_unit_id=raw_incident.get("assigned_unit_id"),
                    dispatch_rank=raw_incident.get("dispatch_rank"),
                    resolved=raw_incident.get("resolved", False),
                    depends_on=raw_incident.get("depends_on", []),
                )
            )

        units = []
        for raw_unit in body.get("units", []):
            units.append(
                Unit(
                    id=raw_unit["id"],
                    type=raw_unit["type"],
                    available=raw_unit.get("available", True),
                )
            )

        # Top-level values win; the observation-level values are the fallback.
        fallback_done = body.get("done", False)
        is_done = payload.get("done", fallback_done)
        fallback_reward = body.get("reward", 0.0)
        step_reward = payload.get("reward", fallback_reward)

        parsed_observation = DispatchTriageObservation(
            done=is_done,
            reward=step_reward,
            incidents=incidents,
            units=units,
            dispatch_count=body.get("dispatch_count", 0),
            score_so_far=body.get("score_so_far", 0.0),
            message=body.get("message", ""),
        )

        return StepResult(
            observation=parsed_observation,
            reward=step_reward,
            done=is_done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """Build a ``DispatchTriageState`` from a decoded state response.

        Every field is optional in *payload*; absent keys fall back to the
        defaults listed below (``episode_id`` falls back to ``None``).
        """
        state_fields = {
            "episode_id": payload.get("episode_id"),
            "step_count": payload.get("step_count", 0),
            "difficulty": payload.get("difficulty", "easy"),
            "total_incidents": payload.get("total_incidents", 0),
            "total_units": payload.get("total_units", 0),
            "max_possible_score": payload.get("max_possible_score", 1.0),
        }
        return DispatchTriageState(**state_fields)