File size: 5,306 Bytes
53d9f07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
SepsisPilot — Typed Models (OpenEnv Spec)
All state, action, step, and grader contracts live here.
"""

from __future__ import annotations
from enum import IntEnum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field


# ──────────────────────────────────────────────
# Action Space  (discrete, 9 actions)
# ──────────────────────────────────────────────

class Action(IntEnum):
    """Discrete treatment choices, numbered 0 through 8.

    Values 1-2 are antibiotic-only, 3-4 are vasopressor-only, and 5-8 pair
    one antibiotic class with one vasopressor dose.
    """

    # Watchful waiting.
    NO_TREATMENT = 0

    # Antibiotic only.
    BROAD_ANTIBIOTICS = 1  # e.g. piperacillin-tazobactam (gram-negative coverage)
    NARROW_ANTIBIOTICS = 2  # e.g. vancomycin (gram-positive coverage)

    # Vasopressor only.
    LOW_VASOPRESSOR = 3  # norepinephrine 0.1 mcg/kg/min
    HIGH_VASOPRESSOR = 4  # norepinephrine 0.3 mcg/kg/min

    # Antibiotic combined with a vasopressor.
    BROAD_LOW_VASO = 5  # broad AB + low-dose vasopressor
    BROAD_HIGH_VASO = 6  # broad AB + high-dose vasopressor
    NARROW_LOW_VASO = 7  # narrow AB + low-dose vasopressor
    NARROW_HIGH_VASO = 8  # narrow AB + high-dose vasopressor

# Human-readable label for each action index (keys match the Action enum values 0-8).
# The label for 0 previously contained a mis-encoded em dash; it is restored here.
ACTION_DESCRIPTIONS: Dict[int, str] = {
    0: "No treatment — watchful waiting",
    1: "Broad-spectrum antibiotics (piperacillin-tazobactam)",
    2: "Narrow-spectrum antibiotics (vancomycin)",
    3: "Low-dose vasopressor (norepinephrine 0.1 mcg/kg/min)",
    4: "High-dose vasopressor (norepinephrine 0.3 mcg/kg/min)",
    5: "Broad-spectrum antibiotics + low-dose vasopressor",
    6: "Broad-spectrum antibiotics + high-dose vasopressor",
    7: "Narrow-spectrum antibiotics + low-dose vasopressor",
    8: "Narrow-spectrum antibiotics + high-dose vasopressor",
}

# ──────────────────────────────────────────────
# Patient State  (8 vitals; the agent observation appends step progress, shape=[9])
# ──────────────────────────────────────────────

class PatientVitals(BaseModel):
    """Continuous observation vector. Normal ranges noted inline."""

    # All eight fields are required floats. Declaration order matches the
    # order in which to_list() emits them.
    map_mmhg: float = Field(..., description="Mean Arterial Pressure mmHg. Normal 70-100; sepsis goal >65")
    lactate: float = Field(..., description="Serum lactate mmol/L. Normal 0.5-2.0; crisis >4")
    wbc: float = Field(..., description="White blood cell count k/uL. Normal 4-11; sepsis >12 or <4")
    # The degree sign was mis-encoded in this description; restored to "°C".
    temperature: float = Field(..., description="Core temp °C. Normal 36.5-37.5; sepsis >38 or <36")
    heart_rate: float = Field(..., description="Heart rate bpm. Normal 60-100; sepsis >90")
    creatinine: float = Field(..., description="Serum creatinine mg/dL. Normal 0.6-1.2; AKI >1.5")
    sofa_score: float = Field(..., description="SOFA score 0-24. >10 = high mortality")
    resistance: float = Field(..., description="Antibiotic resistance index 0-1 (hard task only)")

    def to_list(self) -> List[float]:
        """Return the eight vitals as a flat list, in field-declaration order."""
        return [
            self.map_mmhg, self.lactate, self.wbc, self.temperature,
            self.heart_rate, self.creatinine, self.sofa_score, self.resistance,
        ]

    def is_stable(self) -> bool:
        """All key vitals in target range.

        Requires MAP >= 65, lactate <= 2.0, WBC within 4.0-12.0, temperature
        within 36.0-38.0 and heart rate <= 100. Creatinine, SOFA score and
        resistance are not consulted.
        """
        return (
            self.map_mmhg >= 65
            and self.lactate <= 2.0
            and 4.0 <= self.wbc <= 12.0
            and 36.0 <= self.temperature <= 38.0
            and self.heart_rate <= 100
        )

    def is_dead(self) -> bool:
        """Return True when any single vital crosses a terminal threshold.

        Terminal conditions: MAP < 35, lactate > 15, heart rate > 165 or
        heart rate < 25.
        """
        return (
            self.map_mmhg < 35
            or self.lactate > 15
            or self.heart_rate > 165
            or self.heart_rate < 25
        )


class PatientState(BaseModel):
    """Full state exposed to the agent."""

    vitals: PatientVitals
    step: int
    max_steps: int
    done: bool
    alive: bool
    task: str
    stabilized_at: Optional[int] = None  # step when vitals first became stable
    episode_reward: float = 0.0

    def to_observation(self) -> List[float]:
        """Flat numeric vector for RL agents.

        Returns the vitals list followed by one extra element: the fraction
        ``step / max_steps``. When ``max_steps`` is not positive the fraction
        is reported as 0.0 instead of raising ZeroDivisionError.
        """
        # Guard the division: a zero horizon would otherwise crash observation building.
        if self.max_steps > 0:
            progress = self.step / self.max_steps
        else:
            progress = 0.0
        return self.vitals.to_list() + [progress]


# ──────────────────────────────────────────────
# API Request / Response models
# ──────────────────────────────────────────────

class ResetRequest(BaseModel):
    # Reset request model: names the task and optionally fixes the random seed.
    # Both fields have defaults, so an empty payload validates.
    task: str = Field(
        "mild_sepsis",
        description="Task name: mild_sepsis | septic_shock | severe_mods",
    )
    seed: Optional[int] = Field(
        None,
        description="Random seed for reproducibility",
    )

class ActionRequest(BaseModel):
    # Carries a single required action index; the ge/le bounds restrict it to 0..8.
    action: int = Field(
        ...,
        ge=0,
        le=8,
        description="Action index 0-8",
    )

class StepResult(BaseModel):
    # Step result model: a state, a scalar reward, a done flag and a free-form info mapping.
    state: PatientState
    reward: float
    done: bool
    info: Dict[str, Any]

class GraderResult(BaseModel):
    # Grader output. Only `score` is range-validated (0.0 to 1.0 inclusive).
    score: float = Field(
        ...,
        ge=0.0,
        le=1.0,
    )
    reason: str
    metrics: Dict[str, float]
    # Intended to mean score >= 0.5; this model does not enforce that relation.
    passed: bool

class TaskInfo(BaseModel):
    # Static description of one task.
    name: str
    difficulty: str
    description: str
    max_steps: int
    # Defaults to 9, the number of members in the Action enum (0-8).
    action_n: int = 9
    # Defaults to [9]: the eight vitals plus the step fraction that
    # PatientState.to_observation appends.
    obs_shape: List[int] = [9]