""" SepsisPilot — Unit Tests Run: pytest tests/ -v """ from __future__ import annotations import sys import os sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) import pytest from environment import SepsisPilotEnv, AVAILABLE_TASKS from environment.models import PatientVitals class TestPatientVitals: def test_is_stable_when_all_vitals_normal(self): v = PatientVitals( map_mmhg=72, lactate=1.5, wbc=8.0, temperature=37.0, heart_rate=80, creatinine=1.0, sofa_score=0.0, resistance=0.0, ) assert v.is_stable() is True def test_is_not_stable_when_map_low(self): v = PatientVitals( map_mmhg=60, lactate=1.5, wbc=8.0, temperature=37.0, heart_rate=80, creatinine=1.0, sofa_score=2.0, resistance=0.0, ) assert v.is_stable() is False def test_is_dead_when_map_critical(self): v = PatientVitals( map_mmhg=30, lactate=3.0, wbc=10.0, temperature=37.0, heart_rate=90, creatinine=1.2, sofa_score=10.0, resistance=0.0, ) assert v.is_dead() is True def test_is_dead_when_lactate_extreme(self): v = PatientVitals( map_mmhg=70, lactate=16.0, wbc=10.0, temperature=37.0, heart_rate=90, creatinine=1.2, sofa_score=10.0, resistance=0.0, ) assert v.is_dead() is True def test_to_list_returns_8_elements(self): v = PatientVitals( map_mmhg=70, lactate=2.0, wbc=10.0, temperature=37.0, heart_rate=80, creatinine=1.0, sofa_score=2.0, resistance=0.0, ) assert len(v.to_list()) == 8 class TestSepsisPilotEnv: def setup_method(self): self.env = SepsisPilotEnv() def test_available_tasks(self): assert set(AVAILABLE_TASKS) == {"mild_sepsis", "septic_shock", "severe_mods"} def test_reset_returns_valid_state(self): state = self.env.reset("mild_sepsis", seed=42) assert state.step == 0 assert state.done is False assert state.alive is True assert state.task == "mild_sepsis" assert state.vitals is not None @pytest.mark.parametrize("task", ["mild_sepsis", "septic_shock", "severe_mods"]) def test_reset_all_tasks(self, task): state = self.env.reset(task, seed=42) assert state.task == task assert state.step == 0 def test_step_increments_counter(self): self.env.reset("mild_sepsis", seed=42) result = self.env.step(5) assert result.state.step == 1 def test_step_returns_reward_and_done(self): self.env.reset("mild_sepsis", seed=42) result = self.env.step(1) assert isinstance(result.reward, float) assert isinstance(result.done, bool) assert isinstance(result.info, dict) def test_invalid_action_raises(self): self.env.reset("mild_sepsis", seed=42) with pytest.raises(ValueError): self.env.step(99) def test_step_before_reset_raises(self): fresh_env = SepsisPilotEnv() with pytest.raises(RuntimeError): fresh_env.step(0) def test_grade_before_done_raises(self): self.env.reset("mild_sepsis", seed=42) with pytest.raises(RuntimeError): self.env.grade() def test_full_episode_mild_sepsis(self): """Complete a full mild_sepsis episode and verify grade.""" state = self.env.reset("mild_sepsis", seed=42) steps = 0 while not state.done: result = self.env.step(5) # broad AB + low vaso state = result.state steps += 1 assert steps <= 30, "Episode exceeded max_steps guard" grade = self.env.grade() assert 0.0 <= grade.score <= 1.0 assert isinstance(grade.reason, str) assert isinstance(grade.passed, bool) def test_no_treatment_worsens_vitals(self): """No treatment should worsen MAP and lactate over time.""" state = self.env.reset("mild_sepsis", seed=42) initial_map = state.vitals.map_mmhg initial_lactate = state.vitals.lactate for _ in range(5): if state.done: break result = self.env.step(0) # no treatment state = result.state # MAP should trend down, lactate should trend up assert state.vitals.map_mmhg <= initial_map + 5 # allow noise assert state.vitals.lactate >= initial_lactate - 0.5 def test_reproducibility(self): """Same seed produces identical trajectories.""" def run(seed): env = SepsisPilotEnv() env.reset("septic_shock", seed=seed) rewards = [] for _ in range(6): r = env.step(7) rewards.append(r.reward) if r.done: break return rewards assert run(42) == run(42) def test_grader_score_varies_with_strategy(self): """Good strategy should score higher than bad strategy.""" def episode(task, actions_cycle, seed=42): env = SepsisPilotEnv() state = env.reset(task, seed=seed) i = 0 while not state.done: action = actions_cycle[i % len(actions_cycle)] result = env.step(action) state = result.state i += 1 return env.grade().score good = episode("mild_sepsis", [5, 1, 1, 5]) # broad AB + low vaso bad = episode("mild_sepsis", [0, 0, 0, 0]) # no treatment assert good > bad, f"Expected good > bad, got {good:.3f} <= {bad:.3f}" def test_task_list(self): tasks = SepsisPilotEnv.task_list() assert len(tasks) == 3 names = {t.name for t in tasks} assert names == {"mild_sepsis", "septic_shock", "severe_mods"} for t in tasks: assert t.difficulty in ("easy", "medium", "hard") assert t.max_steps > 0