from __future__ import annotations from typing import Any, Literal from pydantic import Field, model_validator from openenv_compat import Action, Observation, State class SepsisAction(Action): action_type: Literal["request_lab", "request_treatment", "monitor"] = Field( ..., description="High-level agent action for the current clinical step." ) suspect_sepsis: bool = Field( default=False, description="Whether the agent believes the current presentation is consistent with sepsis.", ) lab_type: str | None = Field( default=None, description="Lab to request when action_type=request_lab. Supported values: lactate, wbc, creatinine, bicarbonate, platelets, bilirubin.", ) treatment_type: str | None = Field( default=None, description="Treatment plan when action_type=request_treatment. Supported values: monitor, fluids, vasopressors, combination.", ) rationale: str = Field(default="", description="Optional reasoning for the clinical decision.") @property def action_index(self) -> int: action_map = {"request_lab": 0, "request_treatment": 1, "monitor": 2} lab_map = { None: 0, "lactate": 1, "wbc": 2, "creatinine": 3, "bicarbonate": 4, "platelets": 5, "bilirubin": 6, } treatment_map = { None: 0, "monitor": 1, "fluids": 2, "vasopressors": 3, "combination": 4, } return int((action_map[self.action_type] * 100) + (lab_map.get(self.lab_type, 0) * 10) + treatment_map.get(self.treatment_type, 0)) @model_validator(mode="after") def validate_payload(self) -> "SepsisAction": if self.action_type == "request_lab" and not self.lab_type: raise ValueError("lab_type is required when action_type='request_lab'.") if self.action_type == "request_treatment" and not self.treatment_type: raise ValueError("treatment_type is required when action_type='request_treatment'.") if self.action_type != "request_lab": object.__setattr__(self, "lab_type", None) if self.action_type != "request_treatment": object.__setattr__(self, "treatment_type", None) return self class SepsisObservation(Observation): episode_id: str = Field(..., description="Unique episode identifier.") task_id: str = Field(..., description="Current task id.") task_description: str = Field(..., description="Human-readable task goal.") patient_id: int = Field(..., description="ICU stay id for the current patient trajectory.") step_index: int = Field(..., description="Current step index within the episode.") max_steps: int = Field(..., description="Maximum number of steps in this task episode.") severity_proxy: float = Field(..., description="Proxy severity score used for reward shaping.") mortality_risk_flag: int = Field(..., description="Binary mortality flag from the logged stay.") demographics: dict[str, float] = Field(default_factory=dict, description="Static or slowly changing patient context.") vitals: dict[str, float] = Field(default_factory=dict, description="Always-visible bedside measurements.") context_features: dict[str, float] = Field(default_factory=dict, description="Visible non-lab features derived from the logged state.") visible_labs: dict[str, float | None] = Field(default_factory=dict, description="Labs revealed so far by explicit requests.") requested_labs: list[str] = Field(default_factory=list, description="Ordered list of labs the agent has already requested.") available_lab_options: list[str] = Field(default_factory=list, description="Supported lab request types.") available_treatment_options: list[str] = Field(default_factory=list, description="Supported treatment request types.") cumulative_reward: float = Field(default=0.0, description="Running reward total.") last_reward: float = Field(default=0.0, description="Reward from the previous step.") done: bool = Field(default=False, description="Whether the episode has ended.") reward: float | None = Field(default=None, description="Current reward for OpenEnv compatibility.") class SepsisState(State): task_id: str = Field(..., description="Current task id.") current_stay_id: int = Field(..., description="Current ICU stay id.") step_count: int = Field(default=0, description="Number of actions taken in this episode.") max_steps: int = Field(default=0, description="Maximum steps in this episode.") cumulative_reward: float = Field(default=0.0, description="Running reward total.") safety_violations: int = Field(default=0, description="Count of unsafe actions.") terminal_outcome: str = Field(default="ongoing", description="survived, died, or ongoing") requested_labs: list[str] = Field(default_factory=list, description="Lab requests made in the episode.") visible_labs: dict[str, float | None] = Field(default_factory=dict, description="Latest revealed lab values.") history: list[dict[str, Any]] = Field(default_factory=list, description="Per-step action trace.")