# SepsisPilot/tests/test_env.py
# coral-cyber
# testing the environment
# 53d9f07
"""
SepsisPilot — Unit Tests
Run: pytest tests/ -v
"""
from __future__ import annotations
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import pytest
from environment import SepsisPilotEnv, AVAILABLE_TASKS
from environment.models import PatientVitals
class TestPatientVitals:
    """Checks on the ``PatientVitals`` helpers ``is_stable``, ``is_dead`` and ``to_list``."""

    @staticmethod
    def _build(**overrides):
        """Create a ``PatientVitals`` from the all-normal baseline plus *overrides*.

        The baseline is the reading used by the "all vitals normal" test;
        each test overrides only the fields it needs to differ.
        """
        fields = {
            "map_mmhg": 72,
            "lactate": 1.5,
            "wbc": 8.0,
            "temperature": 37.0,
            "heart_rate": 80,
            "creatinine": 1.0,
            "sofa_score": 0.0,
            "resistance": 0.0,
        }
        fields.update(overrides)
        return PatientVitals(**fields)

    def test_is_stable_when_all_vitals_normal(self):
        """The untouched baseline reading is reported as stable."""
        patient = self._build()
        assert patient.is_stable() is True

    def test_is_not_stable_when_map_low(self):
        """A MAP of 60 is not reported as stable."""
        patient = self._build(map_mmhg=60, sofa_score=2.0)
        assert patient.is_stable() is False

    def test_is_dead_when_map_critical(self):
        """A MAP of 30 is reported as dead."""
        patient = self._build(
            map_mmhg=30,
            lactate=3.0,
            wbc=10.0,
            heart_rate=90,
            creatinine=1.2,
            sofa_score=10.0,
        )
        assert patient.is_dead() is True

    def test_is_dead_when_lactate_extreme(self):
        """A lactate of 16.0 is reported as dead even with a MAP of 70."""
        patient = self._build(
            map_mmhg=70,
            lactate=16.0,
            wbc=10.0,
            heart_rate=90,
            creatinine=1.2,
            sofa_score=10.0,
        )
        assert patient.is_dead() is True

    def test_to_list_returns_8_elements(self):
        """``to_list`` yields one entry per vitals field (eight in total)."""
        patient = self._build(
            map_mmhg=70,
            lactate=2.0,
            wbc=10.0,
            sofa_score=2.0,
        )
        as_list = patient.to_list()
        assert len(as_list) == 8
class TestSepsisPilotEnv:
    """Behavioural tests for ``SepsisPilotEnv``: reset, step, grading and task listing."""

    # Upper bound on the number of steps any episode loop in this class may
    # take. It turns a non-terminating episode into a test failure instead of
    # a hung test run. 30 is the bound this file already applied to
    # mild_sepsis episodes; raise it if a longer task is ever looped over.
    _STEP_GUARD = 30

    def setup_method(self) -> None:
        """Give every test its own fresh environment instance."""
        self.env = SepsisPilotEnv()

    def test_available_tasks(self) -> None:
        """``AVAILABLE_TASKS`` exposes exactly the three expected task names."""
        assert set(AVAILABLE_TASKS) == {"mild_sepsis", "septic_shock", "severe_mods"}

    def test_reset_returns_valid_state(self) -> None:
        """``reset`` returns a step-0, alive, not-done state for the requested task."""
        state = self.env.reset("mild_sepsis", seed=42)
        assert state.step == 0
        assert state.done is False
        assert state.alive is True
        assert state.task == "mild_sepsis"
        assert state.vitals is not None

    @pytest.mark.parametrize("task", ["mild_sepsis", "septic_shock", "severe_mods"])
    def test_reset_all_tasks(self, task) -> None:
        """Every task can be reset and starts at step 0."""
        state = self.env.reset(task, seed=42)
        assert state.task == task
        assert state.step == 0

    def test_step_increments_counter(self) -> None:
        """A single ``step`` advances the state's step counter to 1."""
        self.env.reset("mild_sepsis", seed=42)
        result = self.env.step(5)
        assert result.state.step == 1

    def test_step_returns_reward_and_done(self) -> None:
        """``step`` returns a float reward, a bool done flag and a dict of info."""
        self.env.reset("mild_sepsis", seed=42)
        result = self.env.step(1)
        assert isinstance(result.reward, float)
        assert isinstance(result.done, bool)
        assert isinstance(result.info, dict)

    def test_invalid_action_raises(self) -> None:
        """An out-of-range action id is rejected with ``ValueError``."""
        self.env.reset("mild_sepsis", seed=42)
        with pytest.raises(ValueError):
            self.env.step(99)

    def test_step_before_reset_raises(self) -> None:
        """Stepping an environment that was never reset raises ``RuntimeError``."""
        fresh_env = SepsisPilotEnv()
        with pytest.raises(RuntimeError):
            fresh_env.step(0)

    def test_grade_before_done_raises(self) -> None:
        """Grading an unfinished episode raises ``RuntimeError``."""
        self.env.reset("mild_sepsis", seed=42)
        with pytest.raises(RuntimeError):
            self.env.grade()

    def test_full_episode_mild_sepsis(self) -> None:
        """Complete a full mild_sepsis episode and verify grade."""
        state = self.env.reset("mild_sepsis", seed=42)
        steps = 0
        while not state.done:
            result = self.env.step(5)  # broad AB + low vaso
            state = result.state
            steps += 1
            # Fail fast rather than loop forever if the episode never ends.
            assert steps <= self._STEP_GUARD, "Episode exceeded max_steps guard"
        grade = self.env.grade()
        assert 0.0 <= grade.score <= 1.0
        assert isinstance(grade.reason, str)
        assert isinstance(grade.passed, bool)

    def test_no_treatment_worsens_vitals(self) -> None:
        """No treatment should worsen MAP and lactate over time."""
        state = self.env.reset("mild_sepsis", seed=42)
        initial_map = state.vitals.map_mmhg
        initial_lactate = state.vitals.lactate
        for _ in range(5):
            if state.done:
                break
            result = self.env.step(0)  # no treatment
            state = result.state
        # MAP should trend down, lactate should trend up. The tolerances
        # (+5 on MAP, -0.5 on lactate) leave room for simulation noise.
        assert state.vitals.map_mmhg <= initial_map + 5  # allow noise
        assert state.vitals.lactate >= initial_lactate - 0.5

    def test_reproducibility(self) -> None:
        """Same seed produces identical trajectories."""

        def run(seed):
            """Return the rewards of up to six steps of action 7 on septic_shock."""
            env = SepsisPilotEnv()
            env.reset("septic_shock", seed=seed)
            rewards = []
            for _ in range(6):
                r = env.step(7)
                rewards.append(r.reward)
                if r.done:
                    break
            return rewards

        assert run(42) == run(42)

    def test_grader_score_varies_with_strategy(self) -> None:
        """Good strategy should score higher than bad strategy."""

        def episode(task, actions_cycle, seed=42):
            """Play *task* to completion cycling through *actions_cycle*; return the score."""
            env = SepsisPilotEnv()
            state = env.reset(task, seed=seed)
            i = 0
            while not state.done:
                action = actions_cycle[i % len(actions_cycle)]
                result = env.step(action)
                state = result.state
                i += 1
                # Same guard as test_full_episode_mild_sepsis: a termination
                # regression must fail this test, not hang the suite.
                assert i <= self._STEP_GUARD, "Episode exceeded max_steps guard"
            return env.grade().score

        good = episode("mild_sepsis", [5, 1, 1, 5])  # broad AB + low vaso
        bad = episode("mild_sepsis", [0, 0, 0, 0])  # no treatment
        assert good > bad, f"Expected good > bad, got {good:.3f} <= {bad:.3f}"

    def test_task_list(self) -> None:
        """``task_list`` describes all three tasks with a valid difficulty and step budget."""
        tasks = SepsisPilotEnv.task_list()
        assert len(tasks) == 3
        names = {t.name for t in tasks}
        assert names == {"mild_sepsis", "septic_shock", "severe_mods"}
        for t in tasks:
            assert t.difficulty in ("easy", "medium", "hard")
            assert t.max_steps > 0