Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Any, Literal | |
| from pydantic import Field, model_validator | |
| from openenv_compat import Action, Observation, State | |
| class SepsisAction(Action): | |
| action_type: Literal["request_lab", "request_treatment", "monitor"] = Field( | |
| ..., description="High-level agent action for the current clinical step." | |
| ) | |
| suspect_sepsis: bool = Field( | |
| default=False, | |
| description="Whether the agent believes the current presentation is consistent with sepsis.", | |
| ) | |
| lab_type: str | None = Field( | |
| default=None, | |
| description="Lab to request when action_type=request_lab. Supported values: lactate, wbc, creatinine, bicarbonate, platelets, bilirubin.", | |
| ) | |
| treatment_type: str | None = Field( | |
| default=None, | |
| description="Treatment plan when action_type=request_treatment. Supported values: monitor, fluids, vasopressors, combination.", | |
| ) | |
| rationale: str = Field(default="", description="Optional reasoning for the clinical decision.") | |
| def action_index(self) -> int: | |
| action_map = {"request_lab": 0, "request_treatment": 1, "monitor": 2} | |
| lab_map = { | |
| None: 0, | |
| "lactate": 1, | |
| "wbc": 2, | |
| "creatinine": 3, | |
| "bicarbonate": 4, | |
| "platelets": 5, | |
| "bilirubin": 6, | |
| } | |
| treatment_map = { | |
| None: 0, | |
| "monitor": 1, | |
| "fluids": 2, | |
| "vasopressors": 3, | |
| "combination": 4, | |
| } | |
| return int((action_map[self.action_type] * 100) + (lab_map.get(self.lab_type, 0) * 10) + treatment_map.get(self.treatment_type, 0)) | |
| def validate_payload(self) -> "SepsisAction": | |
| if self.action_type == "request_lab" and not self.lab_type: | |
| raise ValueError("lab_type is required when action_type='request_lab'.") | |
| if self.action_type == "request_treatment" and not self.treatment_type: | |
| raise ValueError("treatment_type is required when action_type='request_treatment'.") | |
| if self.action_type != "request_lab": | |
| object.__setattr__(self, "lab_type", None) | |
| if self.action_type != "request_treatment": | |
| object.__setattr__(self, "treatment_type", None) | |
| return self | |
| class SepsisObservation(Observation): | |
| episode_id: str = Field(..., description="Unique episode identifier.") | |
| task_id: str = Field(..., description="Current task id.") | |
| task_description: str = Field(..., description="Human-readable task goal.") | |
| patient_id: int = Field(..., description="ICU stay id for the current patient trajectory.") | |
| step_index: int = Field(..., description="Current step index within the episode.") | |
| max_steps: int = Field(..., description="Maximum number of steps in this task episode.") | |
| severity_proxy: float = Field(..., description="Proxy severity score used for reward shaping.") | |
| mortality_risk_flag: int = Field(..., description="Binary mortality flag from the logged stay.") | |
| demographics: dict[str, float] = Field(default_factory=dict, description="Static or slowly changing patient context.") | |
| vitals: dict[str, float] = Field(default_factory=dict, description="Always-visible bedside measurements.") | |
| context_features: dict[str, float] = Field(default_factory=dict, description="Visible non-lab features derived from the logged state.") | |
| visible_labs: dict[str, float | None] = Field(default_factory=dict, description="Labs revealed so far by explicit requests.") | |
| requested_labs: list[str] = Field(default_factory=list, description="Ordered list of labs the agent has already requested.") | |
| available_lab_options: list[str] = Field(default_factory=list, description="Supported lab request types.") | |
| available_treatment_options: list[str] = Field(default_factory=list, description="Supported treatment request types.") | |
| cumulative_reward: float = Field(default=0.0, description="Running reward total.") | |
| last_reward: float = Field(default=0.0, description="Reward from the previous step.") | |
| done: bool = Field(default=False, description="Whether the episode has ended.") | |
| reward: float | None = Field(default=None, description="Current reward for OpenEnv compatibility.") | |
| class SepsisState(State): | |
| task_id: str = Field(..., description="Current task id.") | |
| current_stay_id: int = Field(..., description="Current ICU stay id.") | |
| step_count: int = Field(default=0, description="Number of actions taken in this episode.") | |
| max_steps: int = Field(default=0, description="Maximum steps in this episode.") | |
| cumulative_reward: float = Field(default=0.0, description="Running reward total.") | |
| safety_violations: int = Field(default=0, description="Count of unsafe actions.") | |
| terminal_outcome: str = Field(default="ongoing", description="survived, died, or ongoing") | |
| requested_labs: list[str] = Field(default_factory=list, description="Lab requests made in the episode.") | |
| visible_labs: dict[str, float | None] = Field(default_factory=dict, description="Latest revealed lab values.") | |
| history: list[dict[str, Any]] = Field(default_factory=list, description="Per-step action trace.") | |