| |
| |
| |
| |
| |
|
|
| """Vector Borne Disease Control Environment Client.""" |
|
|
| from typing import Dict |
|
|
| from openenv.core import EnvClient |
| from openenv.core.client_types import StepResult |
| from openenv.core.env_server.types import State |
|
|
| from .models import ( |
| VectorBorneDiseaseControlAction, |
| VectorBorneDiseaseControlObservation, |
| ZoneState, |
| ) |
|
|
|
|
class VectorBorneDiseaseControlEnv(
    EnvClient[VectorBorneDiseaseControlAction, VectorBorneDiseaseControlObservation, State]
):
    """
    Client for the Vector Borne Disease Control Environment.

    This client maintains a persistent WebSocket connection to the environment server,
    enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.

    Note: the examples below refer to ``ActionType`` (``ActionType.SPRAY``,
    ``ActionType.WAIT``) as the value of ``VectorBorneDiseaseControlAction.action_type``.
    This module does not import it; import ``ActionType`` yourself from wherever the
    package defines it (NOTE(review): presumably the same ``models`` module as the
    action class — confirm).

    Example:
        >>> # Connect to a running server
        >>> with VectorBorneDiseaseControlEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     obs = result.observation
        ...     print(f"Zones: {len(obs.zones)}")
        ...
        ...     action = VectorBorneDiseaseControlAction(action_type=ActionType.SPRAY, zone_id=0)
        ...     result = client.step(action)
        ...     print(f"Reward: {result.reward}")

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = VectorBorneDiseaseControlEnv.from_docker_image("vector_borne_disease_control-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     action = VectorBorneDiseaseControlAction(action_type=ActionType.WAIT)
        ...     result = client.step(action)
        ... finally:
        ...     client.close()
    """

    def _step_payload(self, action: VectorBorneDiseaseControlAction) -> Dict:
        """
        Convert a VectorBorneDiseaseControlAction into the payload of a step message.

        Args:
            action: The action to send. ``zone_id`` is forwarded unchanged.

        Returns:
            Dict with ``action_type`` and ``zone_id`` keys, suitable for JSON encoding.
        """
        from enum import Enum

        action_type = action.action_type
        # A plain (non-str) Enum member cannot be JSON-encoded, so send its
        # underlying value instead. For a str-mixin Enum the value is exactly
        # what a JSON encoder would emit anyway, and plain strings pass through.
        if isinstance(action_type, Enum):
            action_type = action_type.value
        return {
            "action_type": action_type,
            "zone_id": action.zone_id,
        }

    def _parse_result(self, payload: Dict) -> StepResult[VectorBorneDiseaseControlObservation]:
        """
        Parse a server response into StepResult[VectorBorneDiseaseControlObservation].

        Missing (or JSON ``null``) ``observation``, ``zones`` and ``adjacency``
        entries are treated as empty. Missing scalar fields fall back to 0 / 0.0.

        Args:
            payload: Decoded JSON response from the server. It is expected to hold
                ``observation``, ``reward`` and ``done`` keys.

        Returns:
            StepResult wrapping the parsed observation. ``reward`` is ``None`` and
            ``done`` is ``False`` when the server omits them.
        """
        # ``x.get(key) or {}`` also covers a key that is present but JSON null,
        # which ``x.get(key, {})`` would hand back as None.
        obs_data = payload.get("observation") or {}

        # JSON object keys are always strings, so restore the integer zone ids.
        zones = {
            int(zone_id): ZoneState(**zone_data)
            for zone_id, zone_data in (obs_data.get("zones") or {}).items()
        }
        adjacency = {
            int(zone_id): neighbors
            for zone_id, neighbors in (obs_data.get("adjacency") or {}).items()
        }

        # NOTE(review): reward/done are set only on the StepResult below, not on
        # the observation. If VectorBorneDiseaseControlObservation also declares
        # ``done``/``reward`` fields, they keep their model defaults here — confirm
        # against the models module.
        observation = VectorBorneDiseaseControlObservation(
            zones=zones,
            adjacency=adjacency,
            remaining_spray=obs_data.get("remaining_spray", 0),
            remaining_traps=obs_data.get("remaining_traps", 0),
            step=obs_data.get("step", 0),
            cumulative_avg_infestation_rate=obs_data.get("cumulative_avg_infestation_rate", 0.0),
        )

        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse a state-request response into a State object.

        Args:
            payload: Decoded JSON response from the state request.

        Returns:
            State with ``episode_id`` (``None`` if absent) and ``step_count``
            (0 if absent).
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
|
|