File size: 3,357 Bytes
ef41468 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | from __future__ import annotations
from typing import Any
import requests
from models import SepsisAction, SepsisObservation, SepsisState
from openenv_compat import EnvClient, OPENENV_AVAILABLE, StepResult
from server.sepsis_environment import SepsisTreatmentEnvironment
class SepsisTreatmentEnv(EnvClient):
def __init__(self, base_url: str | None = None, task_id: str = "easy"):
if base_url is not None:
super().__init__(base_url=base_url)
else:
self.base_url = None
if OPENENV_AVAILABLE:
self._provider = None
self._ws = None
self.task_id = task_id
self._local_env = None if base_url else SepsisTreatmentEnvironment(task_id=task_id)
def _step_payload(self, action: SepsisAction) -> dict[str, Any]:
return action.model_dump()
def _parse_result(self, payload: dict[str, Any]) -> StepResult[SepsisObservation]:
return StepResult(
observation=SepsisObservation(**payload["observation"]),
reward=payload.get("reward"),
done=payload.get("done", False),
info=payload.get("info", {}),
)
def _parse_state(self, payload: dict[str, Any]) -> SepsisState:
return SepsisState(**payload)
def reset(self) -> StepResult[SepsisObservation]:
if self._local_env is not None:
observation = self._local_env.reset(task_id=self.task_id)
return StepResult(
observation=observation,
reward=0.0,
done=False,
info={"tasks": self._local_env.available_tasks()},
)
response = requests.post(f"{self.base_url.rstrip('/')}/reset", json={"task_id": self.task_id}, timeout=30)
response.raise_for_status()
return self._parse_result(response.json())
def step(self, action: SepsisAction) -> StepResult[SepsisObservation]:
if self._local_env is not None:
observation = self._local_env.step(action)
return StepResult(
observation=observation,
reward=observation.reward,
done=observation.done,
info={"metrics": self._local_env.current_metrics()},
)
action_payload = self._step_payload(action)
response = requests.post(f"{self.base_url.rstrip('/')}/step", json=action_payload, timeout=30)
if response.status_code == 422:
response = requests.post(
f"{self.base_url.rstrip('/')}/step",
json={"action": action_payload},
timeout=30,
)
response.raise_for_status()
return self._parse_result(response.json())
def state(self) -> SepsisState:
if self._local_env is not None:
return self._local_env.state
response = requests.get(f"{self.base_url.rstrip('/')}/state", timeout=30)
response.raise_for_status()
return self._parse_state(response.json())
def metadata(self) -> dict[str, Any]:
if self._local_env is not None:
return self._local_env.metadata()
response = requests.get(f"{self.base_url.rstrip('/')}/metadata", timeout=30)
response.raise_for_status()
return response.json()
def close(self) -> None:
self._local_env = None
|