# coral-cyber: testing the environment (53d9f07)
"""
SepsisPilot — OpenEnv Environment
Implements: reset() / step() / state() / grade()
This class is the single source of truth for episode state.
"""
from __future__ import annotations
from typing import Optional, List
from .models import (
Action, PatientState, PatientVitals, StepResult, GraderResult, TaskInfo, ResetRequest,
)
from .patient_sim import PatientSimulator, TASK_PROFILES
from .graders import grade_mild_sepsis, grade_septic_shock, grade_severe_mods
# Names of every task the environment can run, in the iteration order of
# TASK_PROFILES. Used in the error message raised by reset() for unknown tasks.
AVAILABLE_TASKS = list(TASK_PROFILES)
class SepsisPilotEnv:
    """
    OpenEnv-compliant environment for sepsis treatment sequencing.

    The environment wraps a ``PatientSimulator`` and keeps the per-episode
    bookkeeping (step counter, cumulative reward, vitals trajectory and the
    treatment metadata the graders consume).

    Action encoding, as interpreted by ``_update_grader_metadata``:
        1, 5, 6 -> broad-spectrum antibiotic
        2, 7, 8 -> narrow-spectrum antibiotic
        3, 5, 7 -> low-dose vasopressor
        4, 6, 8 -> high-dose vasopressor
    Action 0 carries none of these treatments.

    Usage:
        env = SepsisPilotEnv()
        state = env.reset("mild_sepsis")
        while not state.done:
            result = env.step(action_int)
            state = result.state
        grade = env.grade()
    """

    # Action groups used for grader bookkeeping (see class docstring).
    _BROAD_AB_ACTIONS = frozenset({1, 5, 6})
    _NARROW_AB_ACTIONS = frozenset({2, 7, 8})
    _LOW_VP_ACTIONS = frozenset({3, 5, 7})
    _HIGH_VP_ACTIONS = frozenset({4, 6, 8})

    def __init__(self):
        """Create an environment with no active episode; call reset() next."""
        self._sim: Optional[PatientSimulator] = None
        self._task: Optional[str] = None
        self._step_count: int = 0
        self._alive: bool = True
        self._done: bool = False
        self._episode_reward: float = 0.0
        self._stabilized_at: Optional[int] = None
        self._trajectory: List[PatientVitals] = []
        self._current_vitals: Optional[PatientVitals] = None

        # Grader tracking metadata
        self._used_narrow_ab: bool = False
        self._used_vasopressor: bool = False
        self._used_broad_first: bool = False
        self._switched_to_narrow: bool = False
        self._peak_resistance: float = 0.0
        self._min_vp_dose: str = "none"
        self._first_ab_step: Optional[int] = None
        # NOTE: initialised and reset, but not updated or read anywhere in
        # this class; kept so the attribute set stays unchanged.
        self._narrow_after_broad: bool = False

    # ─────────────────────────────────────────────
    # OpenEnv API
    # ─────────────────────────────────────────────

    def reset(self, task: str = "mild_sepsis", seed: Optional[int] = None) -> PatientState:
        """Reset the environment and start a new episode.

        Args:
            task: Key into ``TASK_PROFILES`` selecting the patient profile.
            seed: Optional seed forwarded to the simulator.

        Returns:
            The initial ``PatientState`` of the new episode.

        Raises:
            ValueError: If ``task`` is not a key of ``TASK_PROFILES``.
        """
        if task not in TASK_PROFILES:
            raise ValueError(f"Unknown task '{task}'. Available: {AVAILABLE_TASKS}")

        profile = TASK_PROFILES[task]
        self._sim = PatientSimulator(profile, seed=seed)
        self._task = task
        self._step_count = 0
        self._alive = True
        self._done = False
        self._episode_reward = 0.0
        self._stabilized_at = None
        self._trajectory = []

        # The trajectory always starts with the initial vitals.
        self._current_vitals = self._sim.reset(seed=seed)
        self._trajectory.append(self._current_vitals)

        # Reset grader metadata
        self._used_narrow_ab = False
        self._used_vasopressor = False
        self._used_broad_first = False
        self._switched_to_narrow = False
        self._peak_resistance = self._current_vitals.resistance
        self._min_vp_dose = "none"
        self._first_ab_step = None
        self._narrow_after_broad = False

        return self._make_state()

    def step(self, action: int) -> StepResult:
        """Apply ``action``, advance one timestep and return the result.

        Args:
            action: Integer action in the inclusive range 0-8.

        Returns:
            A ``StepResult`` carrying the new state, the step reward, the
            done flag and the simulator's info payload.

        Raises:
            RuntimeError: If reset() has not been called, or the episode is
                already finished.
            ValueError: If ``action`` is outside 0-8.
        """
        if self._sim is None or self._task is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode done. Call reset() to start a new episode.")
        if not (0 <= action <= 8):
            raise ValueError(f"Invalid action {action}. Must be 0-8.")

        profile = TASK_PROFILES[self._task]
        # Increment first: the metadata update records the current step number.
        self._step_count += 1

        # Track grader metadata before sim step
        self._update_grader_metadata(action)

        # Advance simulation
        vitals, reward, sim_done, info = self._sim.step(action)
        self._current_vitals = vitals
        self._trajectory.append(vitals)
        self._episode_reward += reward

        # Determine episode termination
        self._alive = not vitals.is_dead()
        if vitals.is_stable() and self._stabilized_at is None:
            # Only the first stable step is recorded.
            self._stabilized_at = self._step_count
        self._done = sim_done or self._step_count >= profile.max_steps

        # Update resistance peak
        self._peak_resistance = max(self._peak_resistance, vitals.resistance)

        state = self._make_state()
        return StepResult(state=state, reward=reward, done=self._done, info=info)

    def state(self) -> PatientState:
        """Return the current state without advancing the simulation.

        Raises:
            RuntimeError: If reset() has not been called yet.
        """
        if self._sim is None:
            raise RuntimeError("Call reset() first.")
        return self._make_state()

    def grade(self) -> GraderResult:
        """Grade the completed episode with the grader matching the task.

        The graders are expected to return a score in [0.0, 1.0].

        Raises:
            RuntimeError: If the episode is not done yet.
            ValueError: If the current task has no grader.
        """
        if not self._done:
            raise RuntimeError("Episode not done yet. Cannot grade.")

        profile = TASK_PROFILES[self._task]
        if self._task == "mild_sepsis":
            return grade_mild_sepsis(
                trajectory=self._trajectory,
                alive=self._alive,
                max_steps=profile.max_steps,
                stabilized_at=self._stabilized_at,
            )
        elif self._task == "septic_shock":
            return grade_septic_shock(
                trajectory=self._trajectory,
                alive=self._alive,
                max_steps=profile.max_steps,
                stabilized_at=self._stabilized_at,
                used_narrow_ab=self._used_narrow_ab,
                used_vasopressor=self._used_vasopressor,
            )
        elif self._task == "severe_mods":
            return grade_severe_mods(
                trajectory=self._trajectory,
                alive=self._alive,
                max_steps=profile.max_steps,
                stabilized_at=self._stabilized_at,
                used_broad_first=self._used_broad_first,
                switched_to_narrow=self._switched_to_narrow,
                peak_resistance=self._peak_resistance,
                min_vasopressor_dose=self._min_vp_dose,
            )
        else:
            raise ValueError(f"No grader for task '{self._task}'")

    # ─────────────────────────────────────────────
    # Helpers
    # ─────────────────────────────────────────────

    def _make_state(self) -> PatientState:
        """Build a ``PatientState`` snapshot from the current episode fields."""
        profile = TASK_PROFILES[self._task]
        return PatientState(
            vitals=self._current_vitals,
            step=self._step_count,
            max_steps=profile.max_steps,
            done=self._done,
            alive=self._alive,
            task=self._task,
            stabilized_at=self._stabilized_at,
            # Rounded for presentation; the internal accumulator keeps full precision.
            episode_reward=round(self._episode_reward, 4),
        )

    def _update_grader_metadata(self, action: int) -> None:
        """Record which treatments ``action`` contains for later grading.

        Must be called after ``_step_count`` has been incremented for the
        current step, because the first-antibiotic step is taken from it.
        """
        has_broad = action in self._BROAD_AB_ACTIONS
        has_narrow = action in self._NARROW_AB_ACTIONS
        has_low_vp = action in self._LOW_VP_ACTIONS
        has_high_vp = action in self._HIGH_VP_ACTIONS

        if has_narrow:
            self._used_narrow_ab = True
        if has_low_vp or has_high_vp:
            self._used_vasopressor = True

        # Vasopressor dose tracking (prefer lowest dose used): a low dose
        # always wins, a high dose only counts while nothing was recorded yet.
        if has_low_vp:
            self._min_vp_dose = "low"
        elif has_high_vp and self._min_vp_dose == "none":
            self._min_vp_dose = "high"

        # Antibiotic sequencing (broad → narrow is optimal for severe MODS).
        # The first antibiotic of either class fixes the "broad first" flag,
        # so a later broad dose cannot retroactively claim to have been first.
        if (has_broad or has_narrow) and self._first_ab_step is None:
            self._first_ab_step = self._step_count
            self._used_broad_first = has_broad
        if has_narrow and self._used_broad_first:
            self._switched_to_narrow = True

    @staticmethod
    def task_list() -> List[TaskInfo]:
        """Return a ``TaskInfo`` summary for every profile in ``TASK_PROFILES``."""
        return [
            TaskInfo(
                name=p.name,
                difficulty=p.difficulty,
                description=p.description,
                max_steps=p.max_steps,
            )
            for p in TASK_PROFILES.values()
        ]