File size: 5,249 Bytes
ef41468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from __future__ import annotations

from typing import Any, Literal

from pydantic import Field, model_validator

from openenv_compat import Action, Observation, State


class SepsisAction(Action):
    action_type: Literal["request_lab", "request_treatment", "monitor"] = Field(
        ..., description="High-level agent action for the current clinical step."
    )
    suspect_sepsis: bool = Field(
        default=False,
        description="Whether the agent believes the current presentation is consistent with sepsis.",
    )
    lab_type: str | None = Field(
        default=None,
        description="Lab to request when action_type=request_lab. Supported values: lactate, wbc, creatinine, bicarbonate, platelets, bilirubin.",
    )
    treatment_type: str | None = Field(
        default=None,
        description="Treatment plan when action_type=request_treatment. Supported values: monitor, fluids, vasopressors, combination.",
    )
    rationale: str = Field(default="", description="Optional reasoning for the clinical decision.")

    @property
    def action_index(self) -> int:
        action_map = {"request_lab": 0, "request_treatment": 1, "monitor": 2}
        lab_map = {
            None: 0,
            "lactate": 1,
            "wbc": 2,
            "creatinine": 3,
            "bicarbonate": 4,
            "platelets": 5,
            "bilirubin": 6,
        }
        treatment_map = {
            None: 0,
            "monitor": 1,
            "fluids": 2,
            "vasopressors": 3,
            "combination": 4,
        }
        return int((action_map[self.action_type] * 100) + (lab_map.get(self.lab_type, 0) * 10) + treatment_map.get(self.treatment_type, 0))

    @model_validator(mode="after")
    def validate_payload(self) -> "SepsisAction":
        if self.action_type == "request_lab" and not self.lab_type:
            raise ValueError("lab_type is required when action_type='request_lab'.")
        if self.action_type == "request_treatment" and not self.treatment_type:
            raise ValueError("treatment_type is required when action_type='request_treatment'.")
        if self.action_type != "request_lab":
            object.__setattr__(self, "lab_type", None)
        if self.action_type != "request_treatment":
            object.__setattr__(self, "treatment_type", None)
        return self


class SepsisObservation(Observation):
    episode_id: str = Field(..., description="Unique episode identifier.")
    task_id: str = Field(..., description="Current task id.")
    task_description: str = Field(..., description="Human-readable task goal.")
    patient_id: int = Field(..., description="ICU stay id for the current patient trajectory.")
    step_index: int = Field(..., description="Current step index within the episode.")
    max_steps: int = Field(..., description="Maximum number of steps in this task episode.")
    severity_proxy: float = Field(..., description="Proxy severity score used for reward shaping.")
    mortality_risk_flag: int = Field(..., description="Binary mortality flag from the logged stay.")
    demographics: dict[str, float] = Field(default_factory=dict, description="Static or slowly changing patient context.")
    vitals: dict[str, float] = Field(default_factory=dict, description="Always-visible bedside measurements.")
    context_features: dict[str, float] = Field(default_factory=dict, description="Visible non-lab features derived from the logged state.")
    visible_labs: dict[str, float | None] = Field(default_factory=dict, description="Labs revealed so far by explicit requests.")
    requested_labs: list[str] = Field(default_factory=list, description="Ordered list of labs the agent has already requested.")
    available_lab_options: list[str] = Field(default_factory=list, description="Supported lab request types.")
    available_treatment_options: list[str] = Field(default_factory=list, description="Supported treatment request types.")
    cumulative_reward: float = Field(default=0.0, description="Running reward total.")
    last_reward: float = Field(default=0.0, description="Reward from the previous step.")
    done: bool = Field(default=False, description="Whether the episode has ended.")
    reward: float | None = Field(default=None, description="Current reward for OpenEnv compatibility.")


class SepsisState(State):
    task_id: str = Field(..., description="Current task id.")
    current_stay_id: int = Field(..., description="Current ICU stay id.")
    step_count: int = Field(default=0, description="Number of actions taken in this episode.")
    max_steps: int = Field(default=0, description="Maximum steps in this episode.")
    cumulative_reward: float = Field(default=0.0, description="Running reward total.")
    safety_violations: int = Field(default=0, description="Count of unsafe actions.")
    terminal_outcome: str = Field(default="ongoing", description="survived, died, or ongoing")
    requested_labs: list[str] = Field(default_factory=list, description="Lab requests made in the episode.")
    visible_labs: dict[str, float | None] = Field(default_factory=dict, description="Latest revealed lab values.")
    history: list[dict[str, Any]] = Field(default_factory=list, description="Per-step action trace.")