| """ |
| SepsisPilot — Typed Models (OpenEnv Spec) |
| All state, action, step, and grader contracts live here. |
| """ |
|
|
| from __future__ import annotations |
| from enum import IntEnum |
| from typing import Any, Dict, List, Optional |
| from pydantic import BaseModel, Field |
|
|
|
|
| |
| |
| |
|
|
class Action(IntEnum):
    """Discrete treatment choices, numbered 0 through 8.

    The members cover doing nothing, a single antibiotic, a single
    vasopressor dose, and each antibiotic + vasopressor pairing.
    """

    # No intervention.
    NO_TREATMENT = 0

    # Antibiotic only.
    BROAD_ANTIBIOTICS = 1
    NARROW_ANTIBIOTICS = 2

    # Vasopressor only.
    LOW_VASOPRESSOR = 3
    HIGH_VASOPRESSOR = 4

    # Antibiotic combined with a vasopressor.
    BROAD_LOW_VASO = 5
    BROAD_HIGH_VASO = 6
    NARROW_LOW_VASO = 7
    NARROW_HIGH_VASO = 8
|
|
# Human-readable label for each action index; keys are the integer
# values 0-8 used by the Action enum.
ACTION_DESCRIPTIONS: Dict[int, str] = {
    # The separator below is an em dash; it had been corrupted to a stray
    # Greek beta by an encoding round-trip.
    0: "No treatment — watchful waiting",
    1: "Broad-spectrum antibiotics (piperacillin-tazobactam)",
    2: "Narrow-spectrum antibiotics (vancomycin)",
    3: "Low-dose vasopressor (norepinephrine 0.1 mcg/kg/min)",
    4: "High-dose vasopressor (norepinephrine 0.3 mcg/kg/min)",
    5: "Broad-spectrum antibiotics + low-dose vasopressor",
    6: "Broad-spectrum antibiotics + high-dose vasopressor",
    7: "Narrow-spectrum antibiotics + low-dose vasopressor",
    8: "Narrow-spectrum antibiotics + high-dose vasopressor",
}
|
|
| |
| |
| |
|
|
class PatientVitals(BaseModel):
    """Continuous observation vector. Normal ranges are noted per field.

    All eight fields are required floats. The description strings are
    part of the model's field metadata, so they are kept free of
    encoding artefacts.
    """

    map_mmhg: float = Field(
        ...,
        description="Mean Arterial Pressure mmHg. Normal 70-100; sepsis goal >65",
    )
    lactate: float = Field(
        ...,
        description="Serum lactate mmol/L. Normal 0.5-2.0; crisis >4",
    )
    wbc: float = Field(
        ...,
        description="White blood cell count k/uL. Normal 4-11; sepsis >12 or <4",
    )
    # The degree sign had been corrupted to "Β°" by an encoding round-trip.
    temperature: float = Field(
        ...,
        description="Core temp °C. Normal 36.5-37.5; sepsis >38 or <36",
    )
    heart_rate: float = Field(
        ...,
        description="Heart rate bpm. Normal 60-100; sepsis >90",
    )
    creatinine: float = Field(
        ...,
        description="Serum creatinine mg/dL. Normal 0.6-1.2; AKI >1.5",
    )
    sofa_score: float = Field(
        ...,
        description="SOFA score 0-24. >10 = high mortality",
    )
    resistance: float = Field(
        ...,
        description="Antibiotic resistance index 0-1 (hard task only)",
    )

    def to_list(self) -> List[float]:
        """Return the eight vitals as a flat list.

        Order: MAP, lactate, WBC, temperature, heart rate, creatinine,
        SOFA score, resistance (the same order the fields are declared in).
        """
        return [
            self.map_mmhg,
            self.lactate,
            self.wbc,
            self.temperature,
            self.heart_rate,
            self.creatinine,
            self.sofa_score,
            self.resistance,
        ]

    def is_stable(self) -> bool:
        """Return True when all key vitals are in their target range.

        Checked: MAP >= 65, lactate <= 2.0, WBC within 4-12,
        temperature within 36-38 and heart rate <= 100. Creatinine,
        SOFA score and resistance are not part of this check.
        """
        return (
            self.map_mmhg >= 65
            and self.lactate <= 2.0
            and 4.0 <= self.wbc <= 12.0
            and 36.0 <= self.temperature <= 38.0
            and self.heart_rate <= 100
        )

    def is_dead(self) -> bool:
        """Return True when any vital crosses a death threshold.

        Thresholds: MAP < 35, lactate > 15, heart rate > 165 or < 25.
        """
        return (
            self.map_mmhg < 35
            or self.lactate > 15
            or self.heart_rate > 165
            or self.heart_rate < 25
        )
|
|
|
|
class PatientState(BaseModel):
    """Full state exposed to the agent."""

    vitals: PatientVitals
    step: int  # current step counter; numerator of the progress feature
    max_steps: int  # step budget; denominator of the progress feature
    done: bool
    alive: bool
    task: str
    # Presumably the step at which the patient first stabilised, or None
    # if that has not happened — confirm against the environment code.
    stabilized_at: Optional[int] = None
    # Presumably the reward accumulated over the episode — confirm.
    episode_reward: float = 0.0

    def to_observation(self) -> List[float]:
        """Return the flat numeric vector for RL agents.

        The vector is the vitals list followed by one extra element,
        ``step / max_steps``. A ``max_steps`` of zero raises
        ``ZeroDivisionError``.
        """
        progress = self.step / self.max_steps
        return [*self.vitals.to_list(), progress]
|
|
|
|
| |
| |
| |
|
|
class ResetRequest(BaseModel):
    """Request body for resetting the environment.

    Both fields are optional: ``task`` defaults to ``"mild_sepsis"`` and
    ``seed`` defaults to ``None``.
    """

    task: str = Field(
        default="mild_sepsis",
        description="Task name: mild_sepsis | septic_shock | severe_mods",
    )
    seed: Optional[int] = Field(
        default=None,
        description="Random seed for reproducibility",
    )
|
|
class ActionRequest(BaseModel):
    """Request body carrying one action index.

    ``action`` is required and validated to lie in the inclusive
    range 0-8.
    """

    action: int = Field(
        ...,
        ge=0,
        le=8,
        description="Action index 0-8",
    )
|
|
class StepResult(BaseModel):
    """Result of a single environment step.

    Bundles the patient state with a reward, a done flag and a
    free-form info mapping. All four fields are required.
    """

    state: PatientState
    reward: float
    done: bool
    info: Dict[str, Any]  # arbitrary string-keyed extras
|
|
class GraderResult(BaseModel):
    """Outcome of grading.

    ``score`` is required and validated to lie in the inclusive range
    0.0-1.0; the remaining fields are required and unconstrained.
    """

    score: float = Field(
        ...,
        ge=0.0,
        le=1.0,
    )
    reason: str
    metrics: Dict[str, float]  # named numeric metrics
    passed: bool
|
|
class TaskInfo(BaseModel):
    """Static description of a task.

    ``action_n`` and ``obs_shape`` have defaults; the other fields are
    required.
    """

    name: str
    difficulty: str
    description: str
    max_steps: int
    action_n: int = 9  # size of the action space; defaults to 9
    obs_shape: List[int] = [9]  # observation shape; defaults to [9]
|
|