Sepsis-OpenEnv / server /sepsis_environment.py
BAIBHAV1234's picture
Upload folder using huggingface_hub
4740eda verified
from __future__ import annotations
import json
import uuid
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from graders import grade_episode, summarize_episode
from models import SepsisAction, SepsisObservation, SepsisState
from openenv_compat import Environment
from tasks import TaskConfig, build_task_catalog
ROOT = Path(__file__).resolve().parent.parent
ENV_DATA_DIR = ROOT / "env_data"
DATASET_PATH = ENV_DATA_DIR / "processed_demo_dataset.pkl"
FEATURES_PATH = ENV_DATA_DIR / "selected_features.json"
LAB_FIELDS = {
"lactate": "Arterial_lactate",
"wbc": "WBC_count",
"creatinine": "Creatinine",
"bicarbonate": "Bicarbonate",
"platelets": "Platelets_count",
"bilirubin": "Total_bili",
}
LAB_OPTIONS = list(LAB_FIELDS.keys())
TREATMENT_OPTIONS = ["monitor", "fluids", "vasopressors", "combination"]
VITAL_FIELDS = ["HR", "MeanBP", "RR", "Temp_C", "SpO2", "Shock_Index"]
DEMOGRAPHIC_FIELDS = ["age", "is_male"]
def load_processed_assets() -> tuple[pd.DataFrame, list[str]]:
if DATASET_PATH.exists() and FEATURES_PATH.exists():
dataset = pd.read_pickle(DATASET_PATH)
with open(FEATURES_PATH, "r", encoding="utf-8") as handle:
selected_features = json.load(handle)
return dataset, selected_features
raise FileNotFoundError(
"Missing env_data assets. Expected env_data/processed_demo_dataset.pkl "
"and env_data/selected_features.json."
)
def build_summary(dataset: pd.DataFrame) -> pd.DataFrame:
grouped = (
dataset.groupby("icustay_id")
.agg(
length=("bin_idx", "count"),
mean_severity=("severity_proxy", "mean"),
max_severity=("severity_proxy", "max"),
mortality=("mortality", "max"),
)
.reset_index()
.sort_values("icustay_id")
.reset_index(drop=True)
)
return grouped
class SepsisTreatmentEnvironment(Environment):
def __init__(self, task_id: str = "easy"):
self.dataset, self.selected_features = load_processed_assets()
self.summary = build_summary(self.dataset)
self.task_catalog = build_task_catalog(self.summary)
self._task_cycle = {task_key: 0 for task_key in self.task_catalog}
self.default_task_id = task_id if task_id in self.task_catalog else "easy"
self._episode = pd.DataFrame()
self._cursor = 0
self._task: TaskConfig = self.task_catalog[self.default_task_id]
self._metrics: dict[str, Any] = {}
self._visited_state_actions: set[tuple[str, str]] = set()
self._state = SepsisState(
episode_id=str(uuid.uuid4()),
task_id=self.default_task_id,
current_stay_id=-1,
step_count=0,
max_steps=0,
cumulative_reward=0.0,
safety_violations=0,
terminal_outcome="ongoing",
requested_labs=[],
visible_labs={},
history=[],
)
def available_tasks(self) -> list[dict[str, Any]]:
return [
{
"id": task.task_id,
"title": task.title,
"description": task.description,
"min_steps": task.min_steps,
"max_steps": task.max_steps,
"num_episodes": len(task.preferred_stay_ids),
}
for task in self.task_catalog.values()
]
def metadata(self) -> dict[str, Any]:
return {
"name": "sepsi-gym",
"description": "A real-world clinical RL environment for sepsis diagnosis and treatment management using MIMIC-III trajectories.",
"tasks": self.available_tasks(),
"selected_features": self.selected_features,
"action_space": {
"action_type": ["request_lab", "request_treatment", "monitor"],
"lab_type": LAB_OPTIONS,
"treatment_type": TREATMENT_OPTIONS,
"suspect_sepsis": "boolean",
},
"observation_space": {
"vitals": VITAL_FIELDS,
"visible_labs": LAB_OPTIONS,
"demographics": DEMOGRAPHIC_FIELDS,
},
}
def _select_episode(self, task: TaskConfig) -> pd.DataFrame:
stay_ids = task.preferred_stay_ids
if not stay_ids:
stay_ids = tuple(int(stay_id) for stay_id in self.summary["icustay_id"].tolist())
index = self._task_cycle[task.task_id] % len(stay_ids)
self._task_cycle[task.task_id] += 1
stay_id = stay_ids[index]
episode = self.dataset[self.dataset["icustay_id"] == stay_id].sort_values("bin_idx").reset_index(drop=True)
episode = episode.head(task.max_steps).copy()
if len(episode) < task.min_steps:
episode = self.dataset[self.dataset["icustay_id"] == stay_id].sort_values("bin_idx").reset_index(drop=True)
return episode
def _row_float(self, row: pd.Series, key: str) -> float:
value = row.get(key, 0.0)
if pd.isna(value):
return 0.0
return float(value)
def _make_observation(self, row: pd.Series, *, reward: float, done: bool) -> SepsisObservation:
hidden_lab_columns = set(LAB_FIELDS.values())
vitals = {name: self._row_float(row, name) for name in VITAL_FIELDS}
demographics = {name: self._row_float(row, name) for name in DEMOGRAPHIC_FIELDS}
context_features = {
name: self._row_float(row, name)
for name in self.selected_features
if name not in hidden_lab_columns and name not in VITAL_FIELDS
}
return SepsisObservation(
episode_id=self._state.episode_id,
task_id=self._task.task_id,
task_description=self._task.description,
patient_id=int(row["icustay_id"]),
step_index=int(self._cursor),
max_steps=int(len(self._episode)),
severity_proxy=self._row_float(row, "severity_proxy"),
mortality_risk_flag=int(row["mortality"]),
demographics=demographics,
vitals=vitals,
context_features=context_features,
visible_labs=dict(self._state.visible_labs),
requested_labs=list(self._state.requested_labs),
available_lab_options=LAB_OPTIONS,
available_treatment_options=TREATMENT_OPTIONS,
cumulative_reward=float(self._state.cumulative_reward),
last_reward=float(reward),
done=done,
reward=float(reward),
)
def _priority_labs(self, row: pd.Series) -> set[str]:
priority: set[str] = set()
severity = self._row_float(row, "severity_proxy")
shock = self._row_float(row, "Shock_Index")
mean_bp = self._row_float(row, "MeanBP")
lactate = self._row_float(row, "Arterial_lactate")
wbc = self._row_float(row, "WBC_count")
creatinine = self._row_float(row, "Creatinine")
bicarbonate = self._row_float(row, "Bicarbonate")
platelets = self._row_float(row, "Platelets_count")
bilirubin = self._row_float(row, "Total_bili")
if severity >= 1.0 or shock > 0.1 or mean_bp < 0.0 or lactate > 0.25:
priority.update(["lactate", "wbc"])
if creatinine > 0.15 or self._row_float(row, "BUN_Creatinine_Ratio") > 0.2:
priority.add("creatinine")
if bicarbonate < -0.15 or self._row_float(row, "Arterial_pH") < -0.15:
priority.add("bicarbonate")
if platelets < -0.2:
priority.add("platelets")
if bilirubin > 0.15:
priority.add("bilirubin")
if not priority:
priority.add("lactate")
return priority
def _sepsis_signal(self, row: pd.Series) -> bool:
return bool(
self._row_float(row, "severity_proxy") >= 1.0
or self._row_float(row, "Shock_Index") > 0.1
or self._row_float(row, "MeanBP") < 0.0
or self._row_float(row, "Arterial_lactate") > 0.25
or self._row_float(row, "WBC_count") > 0.2
)
def _target_treatment(self, row: pd.Series) -> str:
clinician_fluid = int(row["fluid_bin"])
clinician_pressor = int(row["pressor_bin"])
severity = self._row_float(row, "severity_proxy")
if clinician_pressor >= 2 and clinician_fluid >= 2:
return "combination"
if clinician_pressor >= 2:
return "vasopressors"
if clinician_fluid >= 2 or severity >= 1.25 or self._row_float(row, "Shock_Index") > 0.15:
return "fluids"
return "monitor"
def _treatment_match(self, chosen: str, target: str) -> float:
if chosen == target:
return 1.0
close_pairs = {
("fluids", "combination"),
("combination", "fluids"),
("vasopressors", "combination"),
("combination", "vasopressors"),
("monitor", "fluids"),
("fluids", "monitor"),
}
if (chosen, target) in close_pairs:
return 0.5
return 0.0
def _compute_reward(
self,
row: pd.Series,
next_row: pd.Series | None,
action: SepsisAction,
is_terminal: bool,
) -> tuple[float, dict[str, Any]]:
reward = 0.0
unsafe = False
sepsis_signal = self._sepsis_signal(row)
priority_labs = self._priority_labs(row)
target_treatment = self._target_treatment(row)
severity_now = self._row_float(row, "severity_proxy")
severity_next = self._row_float(next_row, "severity_proxy") if next_row is not None else severity_now
progress = float(0.15 * np.tanh(severity_now - severity_next))
stability_score = 1.0 if severity_next <= severity_now else max(0.0, 1.0 - min(severity_next - severity_now, 1.0))
detection_credit = 0.0
lab_score = 0.0
treatment_score = 0.0
inefficiency_penalty = 0.0
safety_penalty = 0.0
novelty_bonus = 0.0
if action.suspect_sepsis:
detection_credit = 1.0 if sepsis_signal else 0.25
reward += 0.15 if sepsis_signal else -0.05
elif sepsis_signal and self._cursor <= 1:
reward -= 0.05
if action.action_type == "request_lab":
if action.lab_type in LAB_FIELDS:
duplicate_lab = action.lab_type in self._state.requested_labs
revealed_value = self._row_float(row, LAB_FIELDS[action.lab_type])
self._state.visible_labs[action.lab_type] = revealed_value
self._state.requested_labs.append(action.lab_type)
if duplicate_lab:
inefficiency_penalty += 0.08
reward -= 0.08
elif action.lab_type in priority_labs:
lab_score = 1.0
reward += 0.20
else:
lab_score = 0.4
reward += 0.05
if action.lab_type in {"lactate", "wbc"} and sepsis_signal:
detection_credit = max(detection_credit, 0.8)
else:
inefficiency_penalty += 0.05
reward -= 0.05
elif action.action_type == "request_treatment":
if action.treatment_type in TREATMENT_OPTIONS:
treatment_score = self._treatment_match(action.treatment_type, target_treatment)
reward += 0.30 * treatment_score
if severity_now <= 0.8 and action.treatment_type in {"vasopressors", "combination"}:
safety_penalty += 0.25
unsafe = True
if severity_now >= 2.0 and action.treatment_type == "monitor":
safety_penalty += 0.30
unsafe = True
if self._row_float(row, "MeanBP") < -0.2 and action.treatment_type == "monitor":
safety_penalty += 0.10
unsafe = True
else:
inefficiency_penalty += 0.05
reward -= 0.05
else:
if severity_now >= 1.75:
inefficiency_penalty += 0.05
reward -= 0.05
state_id = f"{int(row['icustay_id'])}_{self._cursor}"
action_id = f"{action.action_type}_{action.lab_type}_{action.treatment_type}"
state_action_key = (state_id, action_id)
if state_action_key not in self._visited_state_actions:
novelty_bonus = 0.03
reward += novelty_bonus
self._visited_state_actions.add(state_action_key)
else:
inefficiency_penalty += 0.01
reward -= 0.01
if sepsis_signal and self._cursor <= 2 and action.suspect_sepsis:
detection_credit = max(detection_credit, 1.0)
reward += 0.05
if severity_now > 1.5 and action.action_type == "monitor":
safety_penalty += 0.05
reward -= 0.05
reward += progress
reward -= safety_penalty
reward -= inefficiency_penalty
if is_terminal:
reward += 0.30 if int(row["mortality"]) == 0 else -0.30
reward = float(np.clip(reward, -1.0, 1.0))
return reward, {
"detection_credit": round(detection_credit, 4),
"lab_score": round(lab_score, 4),
"treatment_score": round(treatment_score, 4),
"stability_score": round(stability_score, 4),
"progress_score": round(progress, 4),
"novelty_bonus": round(novelty_bonus, 4),
"safety_penalty": round(safety_penalty, 4),
"inefficiency_penalty": round(inefficiency_penalty, 4),
"target_treatment": target_treatment,
"priority_labs": sorted(priority_labs),
"unsafe": unsafe,
}
def reset(self, task_id: str | None = None) -> SepsisObservation:
chosen_task = task_id or self.default_task_id
if chosen_task not in self.task_catalog:
chosen_task = self.default_task_id
self._task = self.task_catalog[chosen_task]
self._episode = self._select_episode(self._task)
self._cursor = 0
current_row = self._episode.iloc[self._cursor]
self._state = SepsisState(
episode_id=str(uuid.uuid4()),
task_id=self._task.task_id,
current_stay_id=int(current_row["icustay_id"]),
step_count=0,
max_steps=int(len(self._episode)),
cumulative_reward=0.0,
safety_violations=0,
terminal_outcome="ongoing",
requested_labs=[],
visible_labs={},
history=[],
)
self._metrics = {}
self._visited_state_actions = set()
return self._make_observation(current_row, reward=0.0, done=False)
def step(self, action: SepsisAction) -> SepsisObservation:
if self._episode.empty:
return self.reset(self.default_task_id)
row = self._episode.iloc[self._cursor]
next_index = self._cursor + 1
next_row = self._episode.iloc[next_index] if next_index < len(self._episode) else None
done = next_row is None
reward, details = self._compute_reward(row, next_row, action, done)
self._state.step_count += 1
self._state.cumulative_reward += reward
if details["unsafe"]:
self._state.safety_violations += 1
if done:
self._state.terminal_outcome = "survived" if int(row["mortality"]) == 0 else "died"
history_row = {
"step": int(self._cursor),
"action_index": action.action_index,
"action_type": action.action_type,
"suspect_sepsis": action.suspect_sepsis,
"lab_type": action.lab_type,
"treatment_type": action.treatment_type,
"detection_credit": details["detection_credit"],
"lab_score": details["lab_score"],
"treatment_score": details["treatment_score"],
"stability_score": details["stability_score"],
"progress_score": details["progress_score"],
"novelty_bonus": details["novelty_bonus"],
"safety_penalty": details["safety_penalty"],
"inefficiency_penalty": details["inefficiency_penalty"],
"unsafe": details["unsafe"],
"severity_proxy": round(self._row_float(row, "severity_proxy"), 4),
"target_treatment": details["target_treatment"],
"priority_labs": details["priority_labs"],
"reward": round(reward, 4),
}
self._state.history.append(history_row)
if not done:
self._cursor = next_index
observation_row = self._episode.iloc[self._cursor]
else:
observation_row = row
episode_metrics = summarize_episode(
total_reward=float(self._state.cumulative_reward),
state_history=self._state.history,
terminal_outcome=self._state.terminal_outcome,
)
episode_metrics["score"] = grade_episode(self._task, episode_metrics)
self._metrics = episode_metrics
return self._make_observation(observation_row, reward=reward, done=done)
def current_metrics(self) -> dict[str, Any]:
return {
**self._metrics,
"task_id": self._task.task_id,
"current_stay_id": self._state.current_stay_id,
"cumulative_reward": round(float(self._state.cumulative_reward), 4),
"requested_labs": list(self._state.requested_labs),
"safety_violations": int(self._state.safety_violations),
}
@property
def state(self) -> SepsisState:
return self._state