"""Client for the sepsis treatment environment.

The client either drives an in-process ``SepsisTreatmentEnvironment`` or
talks to a remote server over HTTP, depending on whether a base URL is given.
"""

from __future__ import annotations

from typing import Any

import requests

from models import SepsisAction, SepsisObservation, SepsisState
from openenv_compat import EnvClient, OPENENV_AVAILABLE, StepResult
from server.sepsis_environment import SepsisTreatmentEnvironment

# Timeout, in seconds, applied to every HTTP request sent to a remote server.
_REQUEST_TIMEOUT_SECONDS = 30


class SepsisTreatmentEnv(EnvClient):
    """Sepsis treatment environment client with a local and a remote mode.

    With a non-empty ``base_url`` every call is forwarded over HTTP to
    ``<base_url>/reset``, ``/step``, ``/state`` and ``/metadata``. Without
    one, a ``SepsisTreatmentEnvironment`` is created in-process and used
    directly.
    """

    def __init__(self, base_url: str | None = None, task_id: str = "easy"):
        """Create the client.

        Args:
            base_url: Root URL of a remote environment server. ``None`` or an
                empty string selects the in-process environment.
            task_id: Task identifier passed to the environment on creation
                and on every reset.
        """
        # One decision drives both the base-class setup and the local-env
        # creation, so an empty string can no longer trigger both at once.
        use_remote = bool(base_url)
        if use_remote:
            super().__init__(base_url=base_url)
        else:
            self.base_url = None
            if OPENENV_AVAILABLE:
                # EnvClient.__init__ is skipped in local mode; presumably the
                # real OpenEnv base class expects these attributes to exist.
                # TODO confirm against openenv_compat.
                self._provider = None
                self._ws = None
        self.task_id = task_id
        self._local_env = (
            None if use_remote else SepsisTreatmentEnvironment(task_id=task_id)
        )

    def _remote_url(self, route: str) -> str:
        """Return the full URL for ``route`` on the remote server.

        Raises:
            RuntimeError: If there is no base URL, which happens when a
                local-mode client is used after ``close()``.
        """
        if self.base_url is None:
            raise RuntimeError(
                "SepsisTreatmentEnv has neither a local environment nor a "
                "base_url; a local-mode client cannot be used after close()."
            )
        return f"{self.base_url.rstrip('/')}/{route}"

    def _step_payload(self, action: SepsisAction) -> dict[str, Any]:
        """Serialize ``action`` into the JSON body sent to ``/step``."""
        return action.model_dump()

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[SepsisObservation]:
        """Build a ``StepResult`` from a server response body.

        ``payload`` must contain an ``"observation"`` mapping. ``"reward"``
        defaults to ``None``, ``"done"`` to ``False`` and ``"info"`` to ``{}``
        when absent.
        """
        return StepResult(
            observation=SepsisObservation(**payload["observation"]),
            reward=payload.get("reward"),
            done=payload.get("done", False),
            info=payload.get("info", {}),
        )

    def _parse_state(self, payload: dict[str, Any]) -> SepsisState:
        """Build a ``SepsisState`` from a server response body."""
        return SepsisState(**payload)

    def reset(self) -> StepResult[SepsisObservation]:
        """Start a new episode for ``self.task_id``.

        In local mode the result carries a reward of ``0.0`` and lists the
        available tasks under ``info["tasks"]``.

        Raises:
            requests.HTTPError: If the remote server answers with an error
                status.
            RuntimeError: If a local-mode client is used after ``close()``.
        """
        if self._local_env is not None:
            observation = self._local_env.reset(task_id=self.task_id)
            return StepResult(
                observation=observation,
                reward=0.0,
                done=False,
                info={"tasks": self._local_env.available_tasks()},
            )

        response = requests.post(
            self._remote_url("reset"),
            json={"task_id": self.task_id},
            timeout=_REQUEST_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        return self._parse_result(response.json())

    def step(self, action: SepsisAction) -> StepResult[SepsisObservation]:
        """Apply ``action`` and return the resulting step.

        In local mode reward and done flag are read from the returned
        observation, and the current metrics go under ``info["metrics"]``.

        Raises:
            requests.HTTPError: If the remote server answers with an error
                status (after the one retry described below).
            RuntimeError: If a local-mode client is used after ``close()``.
        """
        if self._local_env is not None:
            observation = self._local_env.step(action)
            return StepResult(
                observation=observation,
                reward=observation.reward,
                done=observation.done,
                info={"metrics": self._local_env.current_metrics()},
            )

        url = self._remote_url("step")
        action_payload = self._step_payload(action)
        response = requests.post(
            url, json=action_payload, timeout=_REQUEST_TIMEOUT_SECONDS
        )
        if response.status_code == 422:
            # The server rejected the bare action body; retry once with the
            # action nested under an "action" key.
            response = requests.post(
                url,
                json={"action": action_payload},
                timeout=_REQUEST_TIMEOUT_SECONDS,
            )
        response.raise_for_status()
        return self._parse_result(response.json())

    def state(self) -> SepsisState:
        """Return the current environment state.

        Raises:
            requests.HTTPError: If the remote server answers with an error
                status.
            RuntimeError: If a local-mode client is used after ``close()``.
        """
        if self._local_env is not None:
            return self._local_env.state

        response = requests.get(
            self._remote_url("state"), timeout=_REQUEST_TIMEOUT_SECONDS
        )
        response.raise_for_status()
        return self._parse_state(response.json())

    def metadata(self) -> dict[str, Any]:
        """Return the environment metadata.

        Raises:
            requests.HTTPError: If the remote server answers with an error
                status.
            RuntimeError: If a local-mode client is used after ``close()``.
        """
        if self._local_env is not None:
            return self._local_env.metadata()

        response = requests.get(
            self._remote_url("metadata"), timeout=_REQUEST_TIMEOUT_SECONDS
        )
        response.raise_for_status()
        return response.json()

    def close(self) -> None:
        """Drop the reference to the in-process environment, if any.

        A local-mode client raises ``RuntimeError`` on any later call. In
        remote mode there is no local environment, so this changes nothing.
        """
        self._local_env = None