# Pulse_ER_env / server/adapters.py
# Last commit 9b1756a (KChad): "Add all docs_assets image assets to Hugging Face Space snapshot"
"""Backend adapters for swapping mock and real Pulse runtimes."""
from __future__ import annotations
from abc import ABC, abstractmethod
import random
from ..models import (
INITIAL_TOOL_NAMES,
EnvironmentResponse,
ObservationMetadata,
PatientState,
PulsePhysiologyObservation,
ToolAction,
ToolError,
ToolResult,
)
from ..runtime_effects import NoisyObservation, TimePressureMechanic
from ..rewards import compute_reward
from ..tool_catalog import (
KNOWN_TOOL_NAMES,
ToolValidationError,
canonicalize_tool_name,
coerce_boolean_argument,
validate_tool_arguments,
)
from .mock_scenarios import DEFAULT_MOCK_SCENARIO_ID, MOCK_SCENARIOS, MockScenarioDefinition
# Diagnostic tools whose results are not available immediately: tool name ->
# simulated seconds between ordering the test and its result becoming ready.
MOCK_DIAGNOSTIC_DELAYS = {
    "get_blood_gas": 120,
    "get_cbc": 240,
    "get_bmp": 300,
}
# Tools dispatched to the read-only handler instead of the intervention path.
# They never apply vital-sign deltas (ordering a diagnostic only records it in
# ``pending_diagnostics``). The diagnostic names are taken from
# MOCK_DIAGNOSTIC_DELAYS so the two tables cannot drift apart.
MOCK_READ_ONLY_TOOLS = {
    "get_vitals",
    "summarize_state",
    "check_deterioration",
    "recommend_next_step",
    "get_respiratory_status",
    *MOCK_DIAGNOSTIC_DELAYS,
}
# Fallback intervention effects, used when a scenario's own ``tool_effects``
# does not define the tool. Layout: tool name -> scenario_id -> {state field:
# delta}. Each delta is applied once per call, multiplied by the adapter's
# intervention scale. SpO2 deltas are fractions because SpO2 is stored on a
# 0-1 scale. An empty dict means the tool is modeled for that scenario but
# does not move any vital sign; its side effects still apply.
MOCK_EXTENDED_INTERVENTION_EFFECTS = {
    "give_oxygen": {
        "baseline_stable": {},
        "respiratory_distress": {"spo2": 0.04, "respiration_rate_bpm": -3.0, "heart_rate_bpm": -2.0},
        "hemorrhagic_shock": {"spo2": 0.02, "respiration_rate_bpm": -1.0, "heart_rate_bpm": -1.0},
    },
    "give_fluids": {
        "baseline_stable": {"heart_rate_bpm": 2.0},
        "respiratory_distress": {"systolic_bp_mmhg": 0.5, "heart_rate_bpm": 1.0},
        "hemorrhagic_shock": {"systolic_bp_mmhg": 8.0, "diastolic_bp_mmhg": 4.0, "heart_rate_bpm": -4.0, "blood_volume_ml": 300.0},
    },
    "control_bleeding": {
        "baseline_stable": {},
        "respiratory_distress": {},
        "hemorrhagic_shock": {"systolic_bp_mmhg": 5.0, "diastolic_bp_mmhg": 2.5, "heart_rate_bpm": -3.0, "blood_volume_ml": 150.0},
    },
    "position_patient": {
        "baseline_stable": {},
        "respiratory_distress": {"spo2": 0.01, "respiration_rate_bpm": -1.0},
        "hemorrhagic_shock": {"systolic_bp_mmhg": 2.0, "diastolic_bp_mmhg": 1.0, "respiration_rate_bpm": -1.0},
    },
    "airway_support": {
        "baseline_stable": {"heart_rate_bpm": 1.0},
        "respiratory_distress": {"spo2": 0.04, "respiration_rate_bpm": -4.0, "heart_rate_bpm": -2.0},
        "hemorrhagic_shock": {"spo2": 0.01, "respiration_rate_bpm": -1.0, "heart_rate_bpm": -0.5},
    },
    "give_pressor": {
        "baseline_stable": {"systolic_bp_mmhg": 0.5, "diastolic_bp_mmhg": 0.2, "heart_rate_bpm": 3.0},
        "respiratory_distress": {"systolic_bp_mmhg": 1.0, "diastolic_bp_mmhg": 0.5, "heart_rate_bpm": 2.0},
        "hemorrhagic_shock": {"systolic_bp_mmhg": 4.0, "diastolic_bp_mmhg": 2.0, "heart_rate_bpm": -1.0},
    },
    "needle_decompression": {
        "baseline_stable": {"systolic_bp_mmhg": -1.0, "heart_rate_bpm": 1.0},
        "respiratory_distress": {"spo2": 0.06, "respiration_rate_bpm": -6.0, "heart_rate_bpm": -4.0},
        "hemorrhagic_shock": {"systolic_bp_mmhg": -1.0, "heart_rate_bpm": 0.5},
    },
    "pericardiocentesis": {
        "baseline_stable": {"systolic_bp_mmhg": -2.0, "diastolic_bp_mmhg": -1.0, "heart_rate_bpm": 2.0},
        "respiratory_distress": {"systolic_bp_mmhg": -1.0, "diastolic_bp_mmhg": -0.5, "heart_rate_bpm": 1.0},
        "hemorrhagic_shock": {"systolic_bp_mmhg": -1.0, "diastolic_bp_mmhg": -0.5, "heart_rate_bpm": 1.0},
    },
}
class PatientBackend(ABC):
    """Common interface for patient runtimes.

    The mock adapter and a real Pulse-backed adapter both implement these
    three methods, so either can be plugged in behind the same caller code.
    """

    @abstractmethod
    def reset(self, scenario_id: str | None = None, **kwargs: object) -> EnvironmentResponse:
        """Begin a new episode and hand back its first response.

        Args:
            scenario_id: Scenario to load; ``None`` defers to the backend's
                own default.
            **kwargs: Backend-specific episode options.
        """

    @abstractmethod
    def step(self, action: ToolAction) -> EnvironmentResponse:
        """Execute one tool call and hand back the resulting response."""

    @abstractmethod
    def get_state(self) -> PatientState:
        """Hand back the most recent patient state known to the backend."""
class MockPulseAdapter(PatientBackend):
    """Scenario-table-driven backend used before the real Pulse integration exists.

    Physiology comes entirely from ``MOCK_SCENARIOS`` and
    ``MOCK_EXTENDED_INTERVENTION_EFFECTS``:

    * ``advance_time`` applies the scenario's per-30-second deterioration. The
      deltas are damped by active supports and scaled by the time-pressure
      mechanic.
    * Each intervention nudges vitals by fixed deltas, scaled by how well the
      tool matches the current alerts.
    * Read-only tools report on the state. Diagnostics are ordered, then become
      ready after a simulated delay.

    Optional runtime effects are observation noise (``observation_noise_level``)
    and time pressure (``time_pressure_*``). Noise is applied to the state
    returned in observations; the adapter's internal ``_state`` is never
    replaced by the noisy copy.
    """

    # Confirmation text for each modeled intervention. Tools not listed here
    # fall back to "<tool_name> executed." in _tool_message().
    _INTERVENTION_MESSAGES: dict[str, str] = {
        "give_oxygen": "Supplemental oxygen started.",
        "give_fluids": "Fluid resuscitation initiated.",
        "control_bleeding": "Bleeding control measures applied.",
        "position_patient": "Patient repositioned for support.",
        "airway_support": "Airway support applied.",
        "give_pressor": "Vasopressor support initiated.",
        "needle_decompression": "Needle decompression performed.",
        "pericardiocentesis": "Pericardiocentesis performed.",
    }

    def __init__(
        self,
        default_scenario_id: str = DEFAULT_MOCK_SCENARIO_ID,
        *,
        observation_noise_level: float = 0.0,
        time_pressure_enabled: bool = False,
        time_pressure_onset_s: float = 180.0,
        time_pressure_escalation_per_minute: float = 0.15,
        seed: int | None = None,
    ):
        """Create an adapter with no episode loaded; call ``reset()`` first.

        Args:
            default_scenario_id: Scenario loaded when ``reset()`` receives no id.
            observation_noise_level: Initial level handed to ``NoisyObservation``.
            time_pressure_enabled: Whether time pressure starts enabled.
            time_pressure_onset_s: Passed as ``onset_s`` to ``TimePressureMechanic``.
            time_pressure_escalation_per_minute: Passed as
                ``escalation_per_minute`` to ``TimePressureMechanic``.
            seed: Seed for the adapter RNG. ``None`` seeds from system entropy.

        Each of the runtime-effect settings can be overridden per episode
        through ``reset()`` kwargs.
        """
        self._default_scenario_id = default_scenario_id
        self._scenario: MockScenarioDefinition | None = None
        self._state: PatientState | None = None
        self._step_count = 0
        # Interventions applied at least once this episode. They damp
        # deterioration and shrink repeat-intervention effects.
        self._active_supports: set[str] = set()
        self._tool_counts: dict[str, int] = {}
        self._last_tool_name: str | None = None
        self._same_tool_called_consecutively = 0
        self._rng = random.Random(seed)
        self._observation_noise = NoisyObservation(observation_noise_level)
        self._time_pressure = TimePressureMechanic(
            enabled=time_pressure_enabled,
            onset_s=time_pressure_onset_s,
            escalation_per_minute=time_pressure_escalation_per_minute,
        )
        # Per-episode noise RNG derived from the adapter RNG. It is re-derived
        # on every reset().
        self._episode_observation_rng = random.Random(self._rng.random())

    def reset(self, scenario_id: str | None = None, **kwargs: object) -> EnvironmentResponse:
        """Load a scenario and start a new episode.

        Args:
            scenario_id: Key into ``MOCK_SCENARIOS``. Falls back to the
                adapter's default scenario when falsy.
            **kwargs: Optional runtime-effect overrides:
                ``observation_noise_level``, ``time_pressure_enabled``,
                ``time_pressure_onset_s`` and
                ``time_pressure_escalation_per_minute``. An override stays in
                effect for later resets that do not pass it again.

        Returns:
            The initial response: zero reward and a ``load_scenario`` tool
            result listing every state field as changed.

        Raises:
            ValueError: If the selected scenario id is not in ``MOCK_SCENARIOS``.
        """
        selected_scenario_id = scenario_id or self._default_scenario_id
        if selected_scenario_id not in MOCK_SCENARIOS:
            valid = ", ".join(sorted(MOCK_SCENARIOS))
            raise ValueError(
                f"Unknown mock scenario_id '{selected_scenario_id}'. Expected one of: {valid}"
            )
        scenario = MOCK_SCENARIOS[selected_scenario_id]
        self._scenario = scenario
        self._configure_runtime_effects(kwargs)
        # Clear per-episode bookkeeping BEFORE deriving the initial state.
        # _refresh_state() reads self._active_supports (breath sounds,
        # hemorrhage flow, lactate trend), so supports left over from the
        # previous episode must not leak into the new episode's first state.
        self._step_count = 0
        self._active_supports = set()
        self._tool_counts = {}
        self._last_tool_name = None
        self._same_tool_called_consecutively = 0
        self._state = self._refresh_state(scenario.initial_state.model_copy(deep=True))
        self._episode_observation_rng = random.Random(self._rng.random())
        return self._build_response(
            reward=0.0,
            tool_result=ToolResult(
                tool_name="load_scenario",
                success=True,
                message=f"Scenario '{scenario.scenario_id}' loaded.",
                state_changed=True,
                changed_fields=list(self._state.model_dump().keys()),
            ),
        )

    def step(self, action: ToolAction) -> EnvironmentResponse:
        """Apply one tool call and return the resulting response.

        Processing order:

        1. The tool name is canonicalized and its arguments are validated.
           Failures here, and calls made before ``reset()``, return error
           responses with reward ``-1.0`` instead of raising.
        2. Valid calls dispatch to ``advance_time``, a read-only tool, or an
           intervention.
        3. Derived fields and alerts are recomputed.
        4. The reward comes from ``compute_reward``, using per-tool usage
           counts and a consecutive-repeat counter.
        """
        if self._state is None or self._scenario is None:
            return self._error_response(
                code="NOT_INITIALIZED",
                message="Call reset() before step().",
                retryable=True,
                tool_name=action.tool_name,
            )
        canonical_tool_name = canonicalize_tool_name(action.tool_name, allowed_tools=list(KNOWN_TOOL_NAMES))
        action = action.model_copy(update={"tool_name": canonical_tool_name})
        if action.tool_name not in KNOWN_TOOL_NAMES:
            return self._error_response(
                code="UNKNOWN_TOOL",
                message=f"Unsupported tool '{action.tool_name}'.",
                retryable=False,
                tool_name=action.tool_name,
            )
        try:
            normalized_arguments = validate_tool_arguments(
                action.tool_name,
                action.arguments,
                allowed_tools=KNOWN_TOOL_NAMES,
            )
        except ToolValidationError as exc:
            return self._error_response(
                code="INVALID_ARGUMENT",
                message=str(exc),
                retryable=False,
                tool_name=action.tool_name,
            )
        action = action.model_copy(update={"arguments": normalized_arguments})
        previous_state = self._state.model_copy(deep=True)
        # Counted once the action has passed validation. Handler-level
        # rejections below (e.g. UNSUPPORTED_IN_SCENARIO) still count as a step.
        self._step_count += 1
        # NOTE: each handler builds and returns its own response, and that
        # build may draw from the observation-noise RNG. Only the handler's
        # tool_result/error are used; a second response is built below. This
        # is kept as-is so seeded noise sequences stay stable.
        if action.tool_name == "advance_time":
            result = self._advance_time(action)
        elif action.tool_name in MOCK_READ_ONLY_TOOLS:
            result = self._read_only_tool(action.tool_name)
        else:
            result = self._apply_intervention(action)
        if result.error is not None:
            return result
        self._state = self._refresh_state(self._state)
        changed_fields = self._changed_fields(previous_state, self._state)
        tool_usage_count = self._tool_counts.get(action.tool_name, 0) + 1
        if self._last_tool_name == action.tool_name:
            self._same_tool_called_consecutively += 1
        else:
            self._last_tool_name = action.tool_name
            self._same_tool_called_consecutively = 1
        self._tool_counts[action.tool_name] = tool_usage_count
        reward = compute_reward(
            previous_state,
            self._state,
            action.tool_name,
            self._scenario.recommended_actions,
            tool_usage_count=tool_usage_count,
            same_tool_called_consecutively=self._same_tool_called_consecutively,
            state_changed=bool(changed_fields),
            time_pressure_multiplier=self._current_time_pressure_multiplier(self._state),
        ).total
        tool_result = result.tool_result or ToolResult(
            tool_name=action.tool_name,
            success=True,
            message=f"{action.tool_name} executed.",
            state_changed=bool(changed_fields),
            changed_fields=changed_fields,
        )
        # Handlers report placeholder change info. Overwrite it with the real
        # diff between the pre-step and refreshed states.
        tool_result.changed_fields = changed_fields
        tool_result.state_changed = bool(changed_fields)
        return self._build_response(reward=reward, tool_result=tool_result)

    def get_state(self) -> PatientState:
        """Return a deep copy of the current internal (noise-free) state.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("MockPulseAdapter has not been reset yet.")
        return self._state.model_copy(deep=True)

    def _advance_time(self, action: ToolAction) -> EnvironmentResponse:
        """Advance simulated time by ``action.arguments["seconds"]`` (default 30).

        The scenario's deterioration deltas are defined per 30 s. Each one is
        damped by active supports (``_deterioration_delta``), multiplied by the
        time-pressure factor sampled before time moves, and scaled linearly by
        ``seconds / 30``. Pending diagnostics count down and move to
        ``ready_diagnostics`` when they reach zero.

        Derived fields are not recomputed here; ``step`` refreshes the state
        afterwards. Non-positive ``seconds`` returns an ``INVALID_ARGUMENT``
        error response.
        """
        assert self._state is not None
        assert self._scenario is not None
        seconds = float(action.arguments.get("seconds", 30))
        if seconds <= 0:
            return self._error_response(
                code="INVALID_ARGUMENT",
                message="seconds must be greater than 0",
                retryable=False,
                tool_name=action.tool_name,
            )
        scale = seconds / 30.0
        updates = self._state.model_dump()
        deterioration_multiplier = self._current_time_pressure_multiplier(self._state)
        for field_name, delta in self._scenario.deterioration_per_30s.items():
            adjusted_delta = self._deterioration_delta(field_name, delta) * deterioration_multiplier
            current_value = updates.get(field_name)
            if current_value is None:
                continue
            updates[field_name] = current_value + adjusted_delta * scale
        updates["sim_time_s"] = self._state.sim_time_s + seconds
        pending_diagnostics = dict(self._state.pending_diagnostics)
        ready_diagnostics = list(self._state.ready_diagnostics)
        for tool_name, remaining_seconds in list(pending_diagnostics.items()):
            # Remaining time is kept in whole seconds (int() truncates).
            remaining_after_step = max(0, int(remaining_seconds - seconds))
            if remaining_after_step <= 0:
                pending_diagnostics.pop(tool_name, None)
                if tool_name not in ready_diagnostics:
                    ready_diagnostics.append(tool_name)
            else:
                pending_diagnostics[tool_name] = remaining_after_step
        updates["pending_diagnostics"] = pending_diagnostics
        updates["ready_diagnostics"] = ready_diagnostics
        self._state = PatientState(**updates)
        return self._build_response(
            reward=0.0,
            tool_result=ToolResult(
                tool_name=action.tool_name,
                success=True,
                message=f"Advanced simulation by {seconds:.0f} seconds.",
                state_changed=True,
                changed_fields=[],
            ),
        )

    def _apply_intervention(self, action: ToolAction) -> EnvironmentResponse:
        """Apply an intervention's vital-sign deltas and side effects.

        Deltas come from the scenario's own ``tool_effects``, falling back to
        ``MOCK_EXTENDED_INTERVENTION_EFFECTS``. A tool found in neither table
        yields an ``UNSUPPORTED_IN_SCENARIO`` error response. The tool is then
        recorded as an active support.

        ``give_pressor`` with ``stop=True`` is handled separately: it only
        removes the infusion and the active support, with no vital-sign change.
        """
        assert self._state is not None
        assert self._scenario is not None
        if action.tool_name == "give_pressor" and action.arguments.get("stop") is True:
            updates = self._state.model_dump()
            self._apply_tool_side_effects(action, updates, effect_scale=0.0)
            self._active_supports.discard("give_pressor")
            self._state = PatientState(**updates)
            return self._build_response(
                reward=0.0,
                tool_result=ToolResult(
                    tool_name=action.tool_name,
                    success=True,
                    message="Vasopressor support stopped.",
                    state_changed=True,
                    changed_fields=[],
                ),
            )
        effects = self._scenario.tool_effects.get(action.tool_name)
        if effects is None:
            effects = MOCK_EXTENDED_INTERVENTION_EFFECTS.get(action.tool_name, {}).get(self._scenario.scenario_id)
        if effects is None:
            return self._error_response(
                code="UNSUPPORTED_IN_SCENARIO",
                message=f"{action.tool_name} is not modeled for scenario '{self._scenario.scenario_id}'.",
                retryable=False,
                tool_name=action.tool_name,
            )
        updates = self._state.model_dump()
        # Computed before the tool is added to _active_supports, so the repeat
        # penalty only applies from the second application onward.
        effect_scale = self._intervention_scale(action.tool_name)
        for field_name, delta in effects.items():
            current_value = updates.get(field_name)
            if current_value is None:
                continue
            updates[field_name] = current_value + (delta * effect_scale)
        self._apply_tool_side_effects(action, updates, effect_scale)
        self._active_supports.add(action.tool_name)
        self._state = PatientState(**updates)
        return self._build_response(
            reward=0.0,
            tool_result=ToolResult(
                tool_name=action.tool_name,
                success=True,
                message=self._tool_message(action.tool_name),
                state_changed=True,
                changed_fields=[],
            ),
        )

    def _intervention_scale(self, tool_name: str) -> float:
        """Return the multiplier applied to an intervention's deltas.

        The multiplier is the product of three factors:

        * an alert-matching base scale: 1.0 when the tool targets a current
          alert, a smaller value otherwise, and 1.0 for tools without a rule;
        * 0.7 if the tool is already an active support;
        * the time-pressure intervention-effectiveness multiplier.
        """
        assert self._state is not None
        assert self._scenario is not None
        alerts = set(self._state.active_alerts)
        scale = 1.0
        intervention_multiplier = self._time_pressure.intervention_effectiveness_multiplier(
            sim_time_s=self._state.sim_time_s,
            injury_severity=self._scenario.injury_severity,
            unstable=self._is_state_unstable(self._state),
        )
        if tool_name == "control_bleeding":
            if "blood_loss" in alerts:
                scale = 1.0
            elif "hypotension" in alerts:
                scale = 0.5
            else:
                scale = 0.15
        elif tool_name == "give_fluids":
            if {"hypotension", "blood_loss"} & alerts:
                scale = 1.0
            elif "tachycardia" in alerts:
                scale = 0.6
            else:
                scale = 0.2
        elif tool_name == "give_oxygen":
            if {"hypoxemia", "tachypnea"} & alerts:
                scale = 1.0
            elif "tachycardia" in alerts:
                scale = 0.4
            else:
                scale = 0.15
        elif tool_name == "position_patient":
            if {"tachypnea", "hypotension"} & alerts:
                scale = 1.0
            else:
                scale = 0.25
        elif tool_name == "airway_support":
            if {"hypoxemia", "tachypnea"} & alerts:
                scale = 1.0
            else:
                scale = 0.2
        if tool_name in self._active_supports:
            scale *= 0.7
        return scale * intervention_multiplier

    def _read_only_tool(self, tool_name: str) -> EnvironmentResponse:
        """Answer a read-only tool with a text message.

        Diagnostics may record a newly ordered test in ``pending_diagnostics``
        or store a ready result. Vitals are never changed.
        """
        assert self._state is not None
        assert self._scenario is not None
        if tool_name == "summarize_state":
            message = (
                f"{self._scenario.description} HR {self._state.heart_rate_bpm:.0f}, "
                f"BP {self._state.systolic_bp_mmhg:.0f}/{self._state.diastolic_bp_mmhg:.0f}, "
                f"SpO2 {self._state.spo2:.2f}."
            )
        elif tool_name == "check_deterioration":
            message = "Deterioration ongoing." if self._state.active_alerts else "Patient currently stable."
        elif tool_name == "recommend_next_step":
            recommended_actions = self._scenario.recommended_actions
            # Guard against scenarios with no recommendations, which would
            # otherwise raise IndexError.
            if recommended_actions:
                message = f"Recommended next step: {recommended_actions[0]}."
            else:
                message = "No recommended next step is defined for this scenario."
        elif tool_name == "get_respiratory_status":
            message = (
                f"Breath sounds {self._state.breath_sounds}, SpO2 {self._state.spo2:.2f}, "
                f"RR {self._state.respiration_rate_bpm:.0f}, airway support {self._state.airway_support or 'none'}."
            )
        elif tool_name in MOCK_DIAGNOSTIC_DELAYS:
            message = self._handle_diagnostic_read(tool_name)
        else:
            message = "Current vitals retrieved."
        return self._build_response(
            reward=0.0,
            tool_result=ToolResult(
                tool_name=tool_name,
                success=True,
                message=message,
                state_changed=False,
                changed_fields=[],
            ),
        )

    def _build_response(
        self,
        reward: float,
        tool_result: ToolResult | None = None,
        error: ToolError | None = None,
    ) -> EnvironmentResponse:
        """Wrap the current state in an ``EnvironmentResponse``.

        The observation is built from the observed state returned by
        ``_build_observed_state``, with noise and time-pressure metadata merged
        into its metadata. ``done`` mirrors the observed state's ``done``.
        """
        assert self._state is not None
        available_tools = self._available_tools()
        observed_state, runtime_metadata = self._build_observed_state()
        return EnvironmentResponse(
            observation=PulsePhysiologyObservation.from_patient_state(
                observed_state,
                reward=reward,
                available_tools=available_tools,
                tool_result=tool_result,
                error=error,
                metadata={
                    "step_count": self._step_count,
                    **runtime_metadata,
                },
            ),
            reward=reward,
            done=observed_state.done,
            metadata=ObservationMetadata(
                step_count=self._step_count,
                available_tools=available_tools,
            ),
            tool_result=tool_result,
            error=error,
        )

    def _error_response(
        self,
        code: str,
        message: str,
        retryable: bool,
        tool_name: str,
    ) -> EnvironmentResponse:
        """Build a failed response with reward ``-1.0`` carrying ``ToolError(code)``.

        This also works before ``reset()``. In that case it reports the default
        scenario's initial state. If the configured default id is unknown, it
        uses ``DEFAULT_MOCK_SCENARIO_ID`` instead of raising KeyError. The noise
        metadata then lists no masked or perturbed fields.
        """
        available_tools = self._available_tools()
        if self._state is not None:
            observed_state, runtime_metadata = self._build_observed_state()
        else:
            fallback_scenario = MOCK_SCENARIOS.get(self._default_scenario_id)
            if fallback_scenario is None:
                fallback_scenario = MOCK_SCENARIOS[DEFAULT_MOCK_SCENARIO_ID]
            state = fallback_scenario.initial_state
            observed_state = state
            runtime_metadata = {
                "observation_noise": {
                    "enabled": self._observation_noise.config.enabled,
                    "noise_level": round(self._observation_noise.config.normalized_level, 3),
                    "masked_fields": [],
                    "perturbed_fields": [],
                },
                "time_pressure": self._time_pressure.as_metadata(
                    sim_time_s=float(state.sim_time_s or 0.0),
                    injury_severity=self._scenario.injury_severity if self._scenario is not None else 0.0,
                    unstable=self._is_state_unstable(state),
                ),
            }
        return EnvironmentResponse(
            observation=PulsePhysiologyObservation.from_patient_state(
                observed_state,
                reward=-1.0,
                available_tools=available_tools,
                error=ToolError(code=code, message=message, retryable=retryable),
                metadata={
                    "step_count": self._step_count,
                    **runtime_metadata,
                },
            ),
            reward=-1.0,
            done=observed_state.done,
            metadata=ObservationMetadata(
                step_count=self._step_count,
                available_tools=available_tools,
            ),
            tool_result=ToolResult(
                tool_name=tool_name,
                success=False,
                message=message,
                state_changed=False,
                changed_fields=[],
            ),
            error=ToolError(code=code, message=message, retryable=retryable),
        )

    def _deterioration_delta(self, field_name: str, delta: float) -> float:
        """Return the per-30 s deterioration ``delta`` for ``field_name``, damped by supports.

        Damping factors are scenario-specific and depend on which interventions
        are in ``self._active_supports``. Unmatched fields return ``delta``
        unchanged. Rule order matters: the first matching rule wins.
        """
        if self._scenario is None:
            return delta
        if self._scenario.scenario_id == "respiratory_distress":
            if field_name == "spo2" and "give_oxygen" in self._active_supports:
                return delta * 0.25
            if field_name == "respiration_rate_bpm" and "airway_support" in self._active_supports:
                return delta * 0.4
        if self._scenario.scenario_id == "hemorrhagic_shock":
            if field_name == "blood_volume_ml" and "control_bleeding" in self._active_supports:
                return delta * 0.1
            if field_name in {"systolic_bp_mmhg", "diastolic_bp_mmhg"} and "give_fluids" in self._active_supports:
                if "control_bleeding" in self._active_supports or "position_patient" in self._active_supports:
                    return delta * 0.15
                return delta * 0.4
            if field_name in {"systolic_bp_mmhg", "diastolic_bp_mmhg"} and "control_bleeding" in self._active_supports:
                return delta * 0.6
            if field_name == "heart_rate_bpm":
                if {"control_bleeding", "give_fluids"} <= self._active_supports:
                    return delta * 0.25
                if "control_bleeding" in self._active_supports or "give_fluids" in self._active_supports:
                    return delta * 0.5
        # Scenario-independent respiratory damping.
        if field_name == "respiration_rate_bpm":
            if "give_oxygen" in self._active_supports and "position_patient" in self._active_supports:
                return delta * 0.35
            if "give_oxygen" in self._active_supports or "position_patient" in self._active_supports:
                return delta * 0.6
        if field_name == "spo2" and "give_oxygen" in self._active_supports:
            return delta * 0.3
        return delta

    def _apply_tool_side_effects(
        self,
        action: ToolAction,
        updates: dict,
        effect_scale: float,
    ) -> None:
        """Write a tool's non-vital bookkeeping into ``updates`` in place.

        The fields written are oxygen device and flow, position, airway
        support mode, active infusions, breath sounds, and the tamponade alert.
        In respiratory_distress, needle decompression also adds a SpO2 bump
        scaled by ``effect_scale``.
        """
        tool_name = action.tool_name
        arguments = action.arguments
        if tool_name == "give_oxygen":
            updates["oxygen_device"] = str(arguments.get("device") or "nasal_cannula")
            updates["oxygen_flow_lpm"] = float(arguments.get("flow_lpm", 15))
        elif tool_name == "position_patient":
            updates["position"] = str(arguments.get("position") or updates.get("position") or "upright")
        elif tool_name == "airway_support":
            requested_mode = str(arguments.get("mode") or arguments.get("support_type") or "auto")
            normalized_mode = requested_mode.strip().lower().replace("-", "_").replace(" ", "_")
            # Generic requests let the adapter pick a device from the current state.
            if normalized_mode in {"auto", "basic", "default", "standard", "support", "airway_support"}:
                updates["airway_support"] = self._suggest_airway_support_mode()
            else:
                updates["airway_support"] = normalized_mode
        elif tool_name == "give_fluids":
            active_infusions = dict(updates.get("active_infusions") or {})
            fluid_name = str(arguments.get("fluid_type") or arguments.get("fluid") or "saline")
            active_infusions[fluid_name] = float(arguments.get("rate_ml_per_min", 100))
            updates["active_infusions"] = active_infusions
        elif tool_name == "give_pressor":
            active_infusions = dict(updates.get("active_infusions") or {})
            pressor_name = str(arguments.get("pressor") or arguments.get("agent") or "norepinephrine")
            if arguments.get("stop") is True:
                active_infusions.pop(pressor_name, None)
            else:
                active_infusions[pressor_name] = float(arguments.get("rate_ml_per_min", 5))
            updates["active_infusions"] = active_infusions
        elif tool_name == "needle_decompression":
            updates["breath_sounds"] = "present bilateral"
        elif tool_name == "pericardiocentesis":
            updates["active_alerts"] = [
                alert
                for alert in updates.get("active_alerts", [])
                if alert != "tamponade"
            ]
        if tool_name == "needle_decompression" and self._scenario is not None:
            if self._scenario.scenario_id == "respiratory_distress":
                updates["spo2"] = min(1.0, float(updates.get("spo2") or 0.0) + (0.02 * effect_scale))

    def _handle_diagnostic_read(self, tool_name: str) -> str:
        """Return a lab's result, its remaining wait, or an order confirmation.

        A lab not yet ordered is added to ``pending_diagnostics`` with its
        delay from ``MOCK_DIAGNOSTIC_DELAYS``.
        """
        assert self._state is not None
        if tool_name in self._state.ready_diagnostics:
            return self._diagnostic_result_message(tool_name)
        if tool_name in self._state.pending_diagnostics:
            remaining = self._state.pending_diagnostics[tool_name]
            return f"{tool_name} is pending. {remaining} simulated seconds remaining before results are ready."
        updates = self._state.model_dump()
        pending_diagnostics = dict(self._state.pending_diagnostics)
        pending_diagnostics[tool_name] = MOCK_DIAGNOSTIC_DELAYS[tool_name]
        updates["pending_diagnostics"] = pending_diagnostics
        self._state = PatientState(**updates)
        return (
            f"Ordered {tool_name}. Results will be ready after about "
            f"{MOCK_DIAGNOSTIC_DELAYS[tool_name]} simulated seconds."
        )

    def _diagnostic_result_message(self, tool_name: str) -> str:
        """Build a ready lab result from the current state, store it, and format it.

        The result is stored on the state as ``abg_result``, ``cbc_result`` or
        ``bmp_result``. Any tool other than blood gas or CBC is treated as a BMP.
        """
        assert self._state is not None
        updates = self._state.model_dump()
        if tool_name == "get_blood_gas":
            abg_result = self._build_abg_result()
            updates["abg_result"] = abg_result.model_dump()
            self._state = PatientState(**updates)
            return (
                f"ABG pH {abg_result.ph:.3f}, PaO2 {abg_result.partial_pressure_of_oxygen_mmhg:.1f} mmHg, "
                f"PaCO2 {abg_result.partial_pressure_of_carbon_dioxide_mmhg:.1f} mmHg, "
                f"lactate {abg_result.lactate_mg_per_dl:.1f} mg/dL."
            )
        if tool_name == "get_cbc":
            cbc_result = self._build_cbc_result()
            updates["cbc_result"] = cbc_result.model_dump()
            self._state = PatientState(**updates)
            return (
                f"CBC hemoglobin {cbc_result.hemoglobin_g_per_dl:.1f} g/dL, "
                f"hematocrit {cbc_result.hematocrit_fraction:.3f}, "
                f"WBC {cbc_result.white_blood_cell_count_per_u_l:.0f} /uL."
            )
        bmp_result = self._build_bmp_result()
        updates["bmp_result"] = bmp_result.model_dump()
        self._state = PatientState(**updates)
        return (
            f"BMP sodium {bmp_result.sodium_mmol_per_l:.1f} mmol/L, "
            f"potassium {bmp_result.potassium_mmol_per_l:.1f} mmol/L, "
            f"creatinine {bmp_result.creatinine_mg_per_dl:.1f} mg/dL, "
            f"glucose {bmp_result.glucose_mg_per_dl:.1f} mg/dL."
        )

    def _build_abg_result(self):
        """Synthesize a mock arterial blood gas.

        Uses clamped linear heuristics on SpO2, systolic BP and respiratory
        rate. Missing vitals fall back to fixed defaults.
        """
        assert self._state is not None
        from ..patient_state import ArterialBloodGasResult
        spo2 = float(self._state.spo2 or 0.95)
        systolic = float(self._state.systolic_bp_mmhg or 110.0)
        ph = max(7.10, min(7.45, 7.40 - max(0.0, (95.0 - systolic) / 200.0)))
        pao2 = max(45.0, min(110.0, 40.0 + spo2 * 60.0))
        paco2 = max(28.0, min(60.0, 40.0 + max(0.0, (24.0 - float(self._state.respiration_rate_bpm or 16.0)) * 0.8)))
        lactate = max(8.0, min(40.0, 10.0 + max(0.0, (100.0 - systolic) * 0.15)))
        return ArterialBloodGasResult(
            ph=round(ph, 3),
            partial_pressure_of_oxygen_mmhg=round(pao2, 1),
            partial_pressure_of_carbon_dioxide_mmhg=round(paco2, 1),
            oxygen_saturation=round(spo2, 3),
            bicarbonate_meq_per_l=24.0,
            lactate_mg_per_dl=round(lactate, 1),
            base_excess_meq_per_l=-2.0 if systolic < 95.0 else 0.0,
            base_deficit_meq_per_l=2.0 if systolic < 95.0 else 0.0,
        )

    def _build_cbc_result(self):
        """Synthesize a mock CBC.

        Hemoglobin falls with blood volume below 5400 mL, and hematocrit is
        derived from hemoglobin. WBC is elevated while any alert is active.
        """
        assert self._state is not None
        from ..patient_state import CompleteBloodCountResult
        blood_volume = float(self._state.blood_volume_ml or 5400.0)
        hemoglobin = max(7.5, min(15.0, 15.0 - max(0.0, (5400.0 - blood_volume) / 350.0)))
        hematocrit = max(0.24, min(0.45, hemoglobin / 33.0))
        return CompleteBloodCountResult(
            hemoglobin_g_per_dl=round(hemoglobin, 1),
            hematocrit_fraction=round(hematocrit, 3),
            white_blood_cell_count_per_u_l=9000.0 if self._state.active_alerts else 7000.0,
            platelet_count_per_u_l=250000.0,
            red_blood_cell_count_per_u_l=4800000.0,
        )

    def _build_bmp_result(self):
        """Synthesize a mock BMP.

        Values switch between two fixed sets on a systolic BP threshold of
        95 mmHg. Glucose depends on whether any alert is active.
        """
        assert self._state is not None
        from ..patient_state import BasicMetabolicPanelResult
        systolic = float(self._state.systolic_bp_mmhg or 110.0)
        return BasicMetabolicPanelResult(
            sodium_mmol_per_l=142.0 if systolic >= 95.0 else 145.0,
            potassium_mmol_per_l=3.8 if systolic >= 95.0 else 4.2,
            calcium_mmol_per_l=2.2,
            creatinine_mg_per_dl=1.0 if systolic >= 95.0 else 1.4,
            glucose_mg_per_dl=105.0 if not self._state.active_alerts else 128.0,
        )

    def _changed_fields(self, previous_state: PatientState, new_state: PatientState) -> list[str]:
        """Return the names of model fields whose values differ between the two states."""
        # Read model_fields from the class; instance access is deprecated in Pydantic 2.11.
        return [
            field_name
            for field_name in type(new_state).model_fields
            if getattr(previous_state, field_name) != getattr(new_state, field_name)
        ]

    def _refresh_state(self, state: PatientState) -> PatientState:
        """Clamp vitals and recompute every derived field; return a new state.

        ``state`` itself is not modified. Recomputed fields:

        * mean arterial pressure and shock index;
        * breath sounds, active hemorrhages and lactate trend;
        * active alerts and mental status;
        * ``done``, which is set once ``cardiovascular_collapse`` is alerted.

        Several of these read ``self._active_supports``, so it must be
        current before this is called.
        """
        updates = state.model_dump()
        # Clamp vitals to fixed floors, and SpO2 to [0.5, 1.0].
        updates["spo2"] = max(0.5, min(1.0, updates["spo2"]))
        updates["heart_rate_bpm"] = max(20.0, updates["heart_rate_bpm"])
        updates["systolic_bp_mmhg"] = max(40.0, updates["systolic_bp_mmhg"])
        updates["diastolic_bp_mmhg"] = max(20.0, updates["diastolic_bp_mmhg"])
        updates["respiration_rate_bpm"] = max(4.0, updates["respiration_rate_bpm"])
        blood_volume_ml = updates["blood_volume_ml"]
        if blood_volume_ml is not None:
            blood_volume_ml = max(2500.0, blood_volume_ml)
            updates["blood_volume_ml"] = blood_volume_ml
        if updates["systolic_bp_mmhg"] is not None and updates["diastolic_bp_mmhg"] is not None:
            updates["mean_arterial_pressure_mmhg"] = (
                updates["systolic_bp_mmhg"] + 2 * updates["diastolic_bp_mmhg"]
            ) / 3.0
        if updates["heart_rate_bpm"] is not None and updates["systolic_bp_mmhg"] not in (None, 0):
            updates["shock_index"] = updates["heart_rate_bpm"] / updates["systolic_bp_mmhg"]
        if self._scenario is not None and self._scenario.scenario_id == "respiratory_distress":
            updates["breath_sounds"] = (
                "present bilateral" if "needle_decompression" in self._active_supports else "diminished bilateral"
            )
        elif updates.get("breath_sounds") in (None, ""):
            updates["breath_sounds"] = "present bilateral"
        if self._scenario is not None and self._scenario.scenario_id == "hemorrhagic_shock":
            if "control_bleeding" in self._active_supports:
                flow_rate = 25.0
            else:
                blood_volume = float(updates["blood_volume_ml"] or 4700.0)
                flow_rate = max(0.0, min(180.0, (5200.0 - blood_volume) * 0.4 + 60.0))
            updates["active_hemorrhages"] = {"right_leg": round(flow_rate, 1)} if flow_rate > 5.0 else {}
        else:
            updates["active_hemorrhages"] = {}
        # blood_volume_ml is optional. A missing value is "not depleted",
        # matching the blood_loss alert below; the old `.get(..., 5500.0)`
        # default did not cover an explicit None and raised TypeError.
        volume_depleted = blood_volume_ml is not None and blood_volume_ml < 5000.0
        if updates["systolic_bp_mmhg"] < 95.0 or volume_depleted:
            updates["lactate_trend"] = "worsening"
        elif self._active_supports & {"give_fluids", "control_bleeding", "give_oxygen", "needle_decompression"}:
            updates["lactate_trend"] = "improving"
        else:
            updates["lactate_trend"] = "stable"
        alerts: list[str] = []
        if updates["spo2"] < 0.92:
            alerts.append("hypoxemia")
        if updates["heart_rate_bpm"] > 110:
            alerts.append("tachycardia")
        if updates["systolic_bp_mmhg"] < 95:
            alerts.append("hypotension")
        if volume_depleted:
            alerts.append("blood_loss")
        if updates["respiration_rate_bpm"] >= 24:
            alerts.append("tachypnea")
        if updates.get("shock_index") is not None and updates["shock_index"] >= 0.9:
            alerts.append("shock_index_elevated")
        if updates["systolic_bp_mmhg"] < 70 or updates["spo2"] < 0.75:
            alerts.append("cardiovascular_collapse")
        updates["active_alerts"] = alerts
        updates["mental_status"] = self._derive_mental_status(updates["spo2"], updates["systolic_bp_mmhg"])
        updates["done"] = "cardiovascular_collapse" in alerts
        return PatientState(**updates)

    def _available_tools(self) -> list[str]:
        """Return every known tool name; the mock never restricts the tool set."""
        return list(KNOWN_TOOL_NAMES)

    def _suggest_airway_support_mode(self) -> str:
        """Pick an airway device, checking SpO2 thresholds first and then mental status."""
        assert self._state is not None
        if self._state.spo2 is not None and self._state.spo2 < 0.85:
            return "bag_valve_mask"
        if self._state.spo2 is not None and self._state.spo2 < 0.9:
            return "cpap"
        if self._state.mental_status in {"pain", "unresponsive"}:
            return "tracheal"
        if self._state.mental_status == "verbal":
            return "oropharyngeal"
        return "nasopharyngeal"

    def _derive_mental_status(self, spo2: float, systolic_bp_mmhg: float) -> str:
        """Map SpO2 and systolic BP onto a mental-status level.

        The levels, from worst to best, are unresponsive, pain, verbal and
        alert. The worst applicable level wins.
        """
        if spo2 < 0.75 or systolic_bp_mmhg < 60:
            return "unresponsive"
        if spo2 < 0.82 or systolic_bp_mmhg < 70:
            return "pain"
        if spo2 < 0.88 or systolic_bp_mmhg < 85:
            return "verbal"
        return "alert"

    def _tool_message(self, tool_name: str) -> str:
        """Return the confirmation message for a successfully applied intervention."""
        return self._INTERVENTION_MESSAGES.get(tool_name, f"{tool_name} executed.")

    def _configure_runtime_effects(self, kwargs: dict[str, object]) -> None:
        """Rebuild the noise and time-pressure mechanics from ``reset()`` kwargs.

        Each setting defaults to its current value, so earlier overrides
        persist. A string ``time_pressure_enabled`` is parsed with
        ``coerce_boolean_argument``; any other value goes through ``bool()``.
        """
        observation_noise_level = float(kwargs.get("observation_noise_level", self._observation_noise.config.noise_level))
        raw_time_pressure_enabled = kwargs.get("time_pressure_enabled", self._time_pressure.config.enabled)
        if isinstance(raw_time_pressure_enabled, str):
            time_pressure_enabled = coerce_boolean_argument(raw_time_pressure_enabled)
        else:
            time_pressure_enabled = bool(raw_time_pressure_enabled)
        time_pressure_onset_s = float(kwargs.get("time_pressure_onset_s", self._time_pressure.config.onset_s))
        time_pressure_escalation_per_minute = float(
            kwargs.get(
                "time_pressure_escalation_per_minute",
                self._time_pressure.config.escalation_per_minute,
            )
        )
        self._observation_noise = NoisyObservation(observation_noise_level)
        self._time_pressure = TimePressureMechanic(
            enabled=time_pressure_enabled,
            onset_s=time_pressure_onset_s,
            escalation_per_minute=time_pressure_escalation_per_minute,
        )

    def _build_observed_state(self) -> tuple[PatientState, dict[str, object]]:
        """Return the state as observed by the agent, plus runtime-effect metadata.

        The observed state is whatever ``NoisyObservation.apply`` produces from
        ``self._state`` using the episode noise RNG. The metadata dict carries
        ``observation_noise`` and ``time_pressure`` entries.
        """
        assert self._state is not None
        observed_state, noise_metadata = self._observation_noise.apply(
            self._state,
            rng=self._episode_observation_rng,
        )
        time_pressure_metadata = self._time_pressure.as_metadata(
            sim_time_s=self._state.sim_time_s,
            injury_severity=self._scenario.injury_severity if self._scenario is not None else 0.0,
            unstable=self._is_state_unstable(self._state),
        )
        return observed_state, {
            "observation_noise": noise_metadata,
            "time_pressure": time_pressure_metadata,
        }

    def _current_time_pressure_multiplier(self, state: PatientState) -> float:
        """Return the time-pressure deterioration multiplier for ``state``."""
        assert self._scenario is not None
        return self._time_pressure.deterioration_multiplier(
            sim_time_s=state.sim_time_s,
            injury_severity=self._scenario.injury_severity,
            unstable=self._is_state_unstable(state),
        )

    @staticmethod
    def _is_state_unstable(state: PatientState) -> bool:
        """Return True if the state shows any sign of instability.

        That means any active alert, systolic BP below 95, SpO2 below 0.92, or
        a mental status other than ``"alert"``. A missing systolic BP or SpO2
        is treated as normal (120 and 1.0).
        """
        systolic = state.systolic_bp_mmhg if state.systolic_bp_mmhg is not None else 120.0
        spo2 = state.spo2 if state.spo2 is not None else 1.0
        return bool(state.active_alerts) or systolic < 95.0 or spo2 < 0.92 or state.mental_status != "alert"