SepsisPilot / environment /models.py
coral-cyber
testing the environment
53d9f07
"""
SepsisPilot — Typed Models (OpenEnv Spec)
All state, action, step, and grader contracts live here.
"""
from __future__ import annotations
from enum import IntEnum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# ──────────────────────────────────────────────
# Action Space (discrete, 9 actions)
# ──────────────────────────────────────────────
class Action(IntEnum):
    """Discrete treatment choices available to the agent (values 0-8).

    Value 0 is no intervention, 1-4 are single interventions (one antibiotic
    class or one vasopressor dose), and 5-8 pair one antibiotic class with
    one vasopressor dose.
    """

    # -- no intervention ---------------------------------------------------
    NO_TREATMENT = 0  # watchful waiting

    # -- single interventions ----------------------------------------------
    BROAD_ANTIBIOTICS = 1  # e.g. piperacillin-tazobactam (gram-negative coverage)
    NARROW_ANTIBIOTICS = 2  # e.g. vancomycin (gram-positive coverage)
    LOW_VASOPRESSOR = 3  # norepinephrine 0.1 mcg/kg/min
    HIGH_VASOPRESSOR = 4  # norepinephrine 0.3 mcg/kg/min

    # -- combinations: antibiotic + vasopressor ----------------------------
    BROAD_LOW_VASO = 5  # broad AB + low-dose vasopressor
    BROAD_HIGH_VASO = 6  # broad AB + high-dose vasopressor
    NARROW_LOW_VASO = 7  # narrow AB + low-dose vasopressor
    NARROW_HIGH_VASO = 8  # narrow AB + high-dose vasopressor
# Human-readable label for each discrete action, keyed by the integer value
# of the corresponding ``Action`` member (0-8). The em dash in entry 0 is
# written as the escape "\u2014" so the text cannot be mis-encoded again
# (it previously rendered as mojibake).
ACTION_DESCRIPTIONS: Dict[int, str] = {
    0: "No treatment \u2014 watchful waiting",
    1: "Broad-spectrum antibiotics (piperacillin-tazobactam)",
    2: "Narrow-spectrum antibiotics (vancomycin)",
    3: "Low-dose vasopressor (norepinephrine 0.1 mcg/kg/min)",
    4: "High-dose vasopressor (norepinephrine 0.3 mcg/kg/min)",
    5: "Broad-spectrum antibiotics + low-dose vasopressor",
    6: "Broad-spectrum antibiotics + high-dose vasopressor",
    7: "Narrow-spectrum antibiotics + low-dose vasopressor",
    8: "Narrow-spectrum antibiotics + high-dose vasopressor",
}
# ──────────────────────────────────────────────
# Patient State (vitals vector shape=[8]; observation appends step fraction -> shape=[9])
# ──────────────────────────────────────────────
class PatientVitals(BaseModel):
    """Continuous observation vector. Normal ranges noted inline."""

    # Field order is significant: ``to_list`` emits values in exactly this order.
    map_mmhg: float = Field(
        ..., description="Mean Arterial Pressure mmHg. Normal 70-100; sepsis goal >65"
    )
    lactate: float = Field(
        ..., description="Serum lactate mmol/L. Normal 0.5-2.0; crisis >4"
    )
    wbc: float = Field(
        ..., description="White blood cell count k/uL. Normal 4-11; sepsis >12 or <4"
    )
    # The degree sign is written as the escape "\u00b0" so the schema text is
    # encoding-independent (it was previously mis-encoded mojibake).
    temperature: float = Field(
        ..., description="Core temp \u00b0C. Normal 36.5-37.5; sepsis >38 or <36"
    )
    heart_rate: float = Field(
        ..., description="Heart rate bpm. Normal 60-100; sepsis >90"
    )
    creatinine: float = Field(
        ..., description="Serum creatinine mg/dL. Normal 0.6-1.2; AKI >1.5"
    )
    sofa_score: float = Field(
        ..., description="SOFA score 0-24. >10 = high mortality"
    )
    resistance: float = Field(
        ..., description="Antibiotic resistance index 0-1 (hard task only)"
    )

    def to_list(self) -> List[float]:
        """Return the 8 vitals as a flat list in canonical order.

        Order: MAP, lactate, WBC, temperature, heart rate, creatinine,
        SOFA score, resistance.
        """
        return [
            self.map_mmhg, self.lactate, self.wbc, self.temperature,
            self.heart_rate, self.creatinine, self.sofa_score, self.resistance,
        ]

    def is_stable(self) -> bool:
        """Return True when every key vital is inside its target range.

        Requires MAP >= 65 mmHg, lactate <= 2.0 mmol/L, 4.0 <= WBC <= 12.0,
        36.0 <= temperature <= 38.0 and heart rate <= 100 bpm. Creatinine,
        SOFA score and resistance are not part of this check.
        """
        return (
            self.map_mmhg >= 65
            and self.lactate <= 2.0
            and 4.0 <= self.wbc <= 12.0
            and 36.0 <= self.temperature <= 38.0
            and self.heart_rate <= 100
        )

    def is_dead(self) -> bool:
        """Return True when any vital crosses a terminal threshold.

        Terminal when MAP < 35 mmHg, lactate > 15 mmol/L, or heart rate is
        above 165 bpm or below 25 bpm.
        """
        return (
            self.map_mmhg < 35
            or self.lactate > 15
            or self.heart_rate > 165
            or self.heart_rate < 25
        )
class PatientState(BaseModel):
    """Full state exposed to the agent."""

    vitals: PatientVitals
    step: int  # step counter; normalised by max_steps in to_observation
    max_steps: int  # denominator for the progress term of the observation
    done: bool
    alive: bool
    task: str
    stabilized_at: Optional[int] = None  # step when vitals first became stable
    episode_reward: float = 0.0

    def to_observation(self) -> List[float]:
        """Return the flat numeric observation vector for RL agents.

        The 8 vitals from ``PatientVitals.to_list`` are followed by episode
        progress ``step / max_steps``, for 9 values in total. If
        ``max_steps`` is not positive, the progress term is 0.0 rather than
        raising ``ZeroDivisionError``.
        """
        progress = self.step / self.max_steps if self.max_steps > 0 else 0.0
        return self.vitals.to_list() + [progress]
# ──────────────────────────────────────────────
# API Request / Response models
# ──────────────────────────────────────────────
class ResetRequest(BaseModel):
    # Request model carrying a task name and an optional seed; both fields
    # have defaults, so an empty body is valid.
    task: str = Field(
        default="mild_sepsis",
        description="Task name: mild_sepsis | septic_shock | severe_mods",
    )
    seed: Optional[int] = Field(
        default=None,
        description="Random seed for reproducibility",
    )
class ActionRequest(BaseModel):
    # Request model for a single required action index, validated to the
    # inclusive range 0..8 (the nine discrete actions).
    action: int = Field(
        ...,
        description="Action index 0-8",
        ge=0,
        le=8,
    )
class StepResult(BaseModel):
    # Bundles a PatientState with a scalar reward, a done flag and a
    # free-form info mapping. All four fields are required.
    state: PatientState  # patient state to report
    reward: float  # scalar reward value
    done: bool  # termination flag
    info: Dict[str, Any]  # free-form auxiliary data, string-keyed
class GraderResult(BaseModel):
    # Grader output: a bounded score plus an explanation, named metrics and
    # a pass flag.
    score: float = Field(..., ge=0.0, le=1.0)  # validated to lie in [0, 1]
    reason: str  # human-readable explanation of the score
    metrics: Dict[str, float]  # named numeric metrics
    # Intended to mirror ``score >= 0.5``; this model does not enforce that.
    passed: bool
class TaskInfo(BaseModel):
    # Metadata for a task together with the sizes of its action and
    # observation spaces.
    name: str  # task identifier
    difficulty: str  # difficulty label
    description: str  # free-text task description
    max_steps: int  # episode step limit for this task
    action_n: int = 9  # number of discrete actions (indices 0-8)
    obs_shape: List[int] = [9]  # 8 vitals + normalised step fraction