| """ |
| SepsisPilot — OpenEnv Environment |
| Implements: reset() / step() / state() / grade() |
| This class is the single source of truth for episode state. |
| """ |
|
|
| from __future__ import annotations |
| from typing import Optional, List |
|
|
| from .models import ( |
| Action, PatientState, PatientVitals, StepResult, GraderResult, TaskInfo, ResetRequest, |
| ) |
| from .patient_sim import PatientSimulator, TASK_PROFILES |
| from .graders import grade_mild_sepsis, grade_septic_shock, grade_severe_mods |
|
|
|
|
# Names of every task that reset() accepts, in TASK_PROFILES order.
AVAILABLE_TASKS = list(TASK_PROFILES)
|
|
|
|
class SepsisPilotEnv:
    """
    OpenEnv-compliant environment for sepsis treatment sequencing.

    Wraps a ``PatientSimulator`` and keeps the per-episode bookkeeping
    (step count, survival, stabilization step, cumulative reward, vitals
    trajectory) together with the treatment-history flags that are handed
    to the task graders.

    Usage:
        env = SepsisPilotEnv()
        state = env.reset("mild_sepsis")
        while not state.done:
            result = env.step(action_int)
            state = result.state
        grade = env.grade()
    """

    # Action ids grouped by the treatment component they contain.
    # A combined action appears in one antibiotic group and one
    # vasopressor group (e.g. 5 = broad antibiotic + low-dose vasopressor).
    _BROAD_AB_ACTIONS = (1, 5, 6)
    _NARROW_AB_ACTIONS = (2, 7, 8)
    _LOW_VP_ACTIONS = (3, 5, 7)
    _HIGH_VP_ACTIONS = (4, 6, 8)

    def __init__(self):
        """Create an environment with no active episode; call reset() next."""
        # Core episode state.
        self._sim: Optional[PatientSimulator] = None
        self._task: Optional[str] = None
        self._step_count: int = 0
        self._alive: bool = True
        self._done: bool = False
        self._episode_reward: float = 0.0
        self._stabilized_at: Optional[int] = None
        self._trajectory: List[PatientVitals] = []
        self._current_vitals: Optional[PatientVitals] = None

        # Treatment history consumed by the graders.
        self._used_narrow_ab: bool = False
        self._used_vasopressor: bool = False
        self._used_broad_first: bool = False
        self._switched_to_narrow: bool = False
        self._peak_resistance: float = 0.0
        self._min_vp_dose: str = "none"
        self._first_ab_step: Optional[int] = None
        # Initialised and reset here but never updated or read by this class.
        self._narrow_after_broad: bool = False

    def reset(self, task: str = "mild_sepsis", seed: Optional[int] = None) -> PatientState:
        """Reset environment to start a new episode.

        Args:
            task: Key into ``TASK_PROFILES`` selecting the scenario.
            seed: Optional seed forwarded to the simulator.

        Returns:
            The initial ``PatientState`` of the new episode.

        Raises:
            ValueError: If ``task`` is not a known task name. The current
                episode state is left untouched in that case.
        """
        if task not in TASK_PROFILES:
            raise ValueError(f"Unknown task '{task}'. Available: {AVAILABLE_TASKS}")

        profile = TASK_PROFILES[task]
        self._sim = PatientSimulator(profile, seed=seed)
        self._task = task
        self._step_count = 0
        self._alive = True
        self._done = False
        self._episode_reward = 0.0
        self._stabilized_at = None
        self._trajectory = []
        self._current_vitals = self._sim.reset(seed=seed)
        # The trajectory includes the initial vitals, so it ends up with
        # one more entry than the number of steps taken.
        self._trajectory.append(self._current_vitals)

        # Clear the grader metadata; peak resistance starts at the
        # patient's initial resistance rather than at zero.
        self._used_narrow_ab = False
        self._used_vasopressor = False
        self._used_broad_first = False
        self._switched_to_narrow = False
        self._peak_resistance = self._current_vitals.resistance
        self._min_vp_dose = "none"
        self._first_ab_step = None
        self._narrow_after_broad = False

        return self._make_state()

    def step(self, action: int) -> StepResult:
        """Apply action, advance one timestep, return result.

        Args:
            action: Integer action id in the inclusive range 0-8.

        Returns:
            ``StepResult`` carrying the new state, this step's reward, the
            done flag and the simulator's info object.

        Raises:
            RuntimeError: If reset() has not been called, or the episode
                has already finished.
            ValueError: If ``action`` is outside 0-8.
        """
        if self._sim is None or self._task is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode done. Call reset() to start a new episode.")
        if not (0 <= action <= 8):
            raise ValueError(f"Invalid action {action}. Must be 0-8.")

        profile = TASK_PROFILES[self._task]
        self._step_count += 1

        # Record the treatment history first so that the metadata uses the
        # already-incremented step number of this action.
        self._update_grader_metadata(action)

        # Advance the underlying patient simulation.
        vitals, reward, sim_done, info = self._sim.step(action)
        self._current_vitals = vitals
        self._trajectory.append(vitals)
        self._episode_reward += reward

        # Survival and the first step at which the patient was stable.
        self._alive = not vitals.is_dead()
        if vitals.is_stable() and self._stabilized_at is None:
            self._stabilized_at = self._step_count

        # The episode ends when the simulator says so or the step budget
        # of the task profile is exhausted.
        self._done = (
            sim_done
            or self._step_count >= profile.max_steps
        )

        self._peak_resistance = max(self._peak_resistance, vitals.resistance)

        state = self._make_state()
        return StepResult(state=state, reward=reward, done=self._done, info=info)

    def state(self) -> PatientState:
        """Return current state without advancing the simulation.

        Raises:
            RuntimeError: If reset() has not been called yet.
        """
        if self._sim is None:
            raise RuntimeError("Call reset() first.")
        return self._make_state()

    def grade(self) -> GraderResult:
        """Grade the completed episode. Returns score in [0.0, 1.0].

        Dispatches to the grader that matches the current task and passes
        it the trajectory plus the task-specific treatment metadata.

        Raises:
            RuntimeError: If the episode is not finished (this includes the
                case where reset() was never called).
            ValueError: If the current task has no grader.
        """
        if not self._done:
            raise RuntimeError("Episode not done yet. Cannot grade.")

        profile = TASK_PROFILES[self._task]

        if self._task == "mild_sepsis":
            return grade_mild_sepsis(
                trajectory=self._trajectory,
                alive=self._alive,
                max_steps=profile.max_steps,
                stabilized_at=self._stabilized_at,
            )
        elif self._task == "septic_shock":
            return grade_septic_shock(
                trajectory=self._trajectory,
                alive=self._alive,
                max_steps=profile.max_steps,
                stabilized_at=self._stabilized_at,
                used_narrow_ab=self._used_narrow_ab,
                used_vasopressor=self._used_vasopressor,
            )
        elif self._task == "severe_mods":
            return grade_severe_mods(
                trajectory=self._trajectory,
                alive=self._alive,
                max_steps=profile.max_steps,
                stabilized_at=self._stabilized_at,
                used_broad_first=self._used_broad_first,
                switched_to_narrow=self._switched_to_narrow,
                peak_resistance=self._peak_resistance,
                min_vasopressor_dose=self._min_vp_dose,
            )
        else:
            raise ValueError(f"No grader for task '{self._task}'")

    def _make_state(self) -> PatientState:
        """Build a ``PatientState`` snapshot from the current episode fields."""
        profile = TASK_PROFILES[self._task]
        return PatientState(
            vitals=self._current_vitals,
            step=self._step_count,
            max_steps=profile.max_steps,
            done=self._done,
            alive=self._alive,
            task=self._task,
            stabilized_at=self._stabilized_at,
            # Rounded only for reporting; the running total stays exact.
            episode_reward=round(self._episode_reward, 4),
        )

    def _update_grader_metadata(self, action: int):
        """Update the treatment-history flags for one action.

        Must be called after ``_step_count`` has been incremented for the
        action, because the first-antibiotic step is taken from it.
        """
        has_broad = action in self._BROAD_AB_ACTIONS
        has_narrow = action in self._NARROW_AB_ACTIONS
        has_low_vp = action in self._LOW_VP_ACTIONS
        has_high_vp = action in self._HIGH_VP_ACTIONS

        if has_narrow:
            self._used_narrow_ab = True
        if has_low_vp or has_high_vp:
            self._used_vasopressor = True

        # NOTE(review): despite its name, this stores the dose level of the
        # FIRST vasopressor action and never changes afterwards (a later
        # low dose does not replace an earlier "high"). Behaviour kept
        # as-is; confirm against what grade_severe_mods expects.
        if self._min_vp_dose == "none":
            if has_low_vp:
                self._min_vp_dose = "low"
            elif has_high_vp:
                self._min_vp_dose = "high"

        # The first antibiotic of the episode decides "broad first". It is
        # recorded for ANY antibiotic so that an episode which opens with a
        # narrow-spectrum drug can never be flagged as broad-first by a
        # later broad-spectrum action.
        if (has_broad or has_narrow) and self._first_ab_step is None:
            self._first_ab_step = self._step_count
            self._used_broad_first = has_broad

        # A narrow-spectrum action after a broad-first start counts as the
        # switch to narrow.
        if has_narrow and self._used_broad_first:
            self._switched_to_narrow = True

    @staticmethod
    def task_list() -> List[TaskInfo]:
        """Return a ``TaskInfo`` summary for every profile in ``TASK_PROFILES``."""
        return [
            TaskInfo(
                name=p.name,
                difficulty=p.difficulty,
                description=p.description,
                max_steps=p.max_steps,
            )
            for p in TASK_PROFILES.values()
        ]
|
|