# Pulse_ER_env / episode_runner.py
# KChad's picture
# Add all docs_assets image assets to Hugging Face Space snapshot
# 9b1756a
"""Episode runner and trajectory logging for Pulse-ER policies."""
from __future__ import annotations
import time
from dataclasses import dataclass
from enum import Enum
from .models import EnvironmentResponse, PulsePhysiologyObservation, ToolAction
from .policies import Policy
from .server.adapters import PatientBackend
from .tool_availability import ToolAvailabilityError, validate_tool_availability
class EpisodeTerminationReason(str, Enum):
    """Reason recorded by the consumer-side runner for ending an episode.

    Members are also ``str`` instances, so they compare equal to their raw
    string values.
    """

    # The runner saw a terminal observation (``done`` set or a
    # ``cardiac_arrest`` alert).
    PATIENT_DEATH = "patient_death"

    # The configured number of decision steps was used up.
    MAX_TIMESTEPS = "max_timesteps"

    # A non-retryable backend error or a tool-availability failure occurred.
    FATAL_BACKEND_ERROR = "fatal_backend_error"

    # The optional action budget reached zero before the next decision.
    BUDGET_EXHAUSTED = "budget_exhausted"
@dataclass(frozen=True)
class EpisodeStep:
    """Immutable record pairing one executed action with its backend result."""

    # Position of this step within the episode (the runner's loop index).
    step_index: int
    # The action that was sent to the backend for this step.
    action: ToolAction
    # Reward and completion flag reported with the step result.
    reward: float
    done: bool
    # Observation returned together with the step result.
    observation: PulsePhysiologyObservation
    # Dumped tool result / error payloads, or ``None`` when absent.
    tool_result: dict | None
    error: dict | None
@dataclass(frozen=True)
class EpisodeTrace:
    """Immutable record of a whole episode, from reset to termination."""

    # Which scenario was run and which policy drove it.
    scenario_id: str
    policy_name: str
    # Observation returned by the backend reset.
    initial_observation: PulsePhysiologyObservation
    # Ordered trajectory of executed steps.
    steps: tuple[EpisodeStep, ...]
    # Accumulated reward for the episode.
    total_reward: float
    # Last observation seen by the runner.
    final_observation: PulsePhysiologyObservation
    # Why the runner stopped.
    termination_reason: EpisodeTerminationReason
    # Unspent action budget, or ``None`` when no budget was configured.
    action_budget_remaining: int | None = None
    # Human-readable notes collected by the runner during the episode.
    events: tuple[str, ...] = ()

    @property
    def num_steps(self) -> int:
        """Return how many steps the trajectory contains."""
        return len(self.steps)

    def summary(self) -> dict:
        """Build a compact summary dict for CLI tools and logging.

        Returns:
            A dict with episode metadata followed by the vitals of the final
            observation. ``spo2_percent`` is ``spo2 * 100`` rounded to one
            decimal, or ``None`` when ``spo2`` is ``None``.
        """
        final = self.final_observation

        # Enum-like statuses are reduced to their ``value``; anything else is
        # passed through untouched.
        status = final.mental_status
        status = getattr(status, "value", status)

        report = {
            "scenario_id": self.scenario_id,
            "policy_name": self.policy_name,
            "num_steps": self.num_steps,
            "total_reward": round(self.total_reward, 3),
            "done": final.done,
            "termination_reason": self.termination_reason.value,
            "action_budget_remaining": self.action_budget_remaining,
        }

        # Vitals are copied straight from the final observation, in a fixed
        # order so the summary layout stays stable.
        for vital in (
            "sim_time_s",
            "heart_rate_bpm",
            "systolic_bp_mmhg",
            "diastolic_bp_mmhg",
            "spo2",
        ):
            report[vital] = getattr(final, vital)

        spo2_fraction = report["spo2"]
        if spo2_fraction is None:
            report["spo2_percent"] = None
        else:
            report["spo2_percent"] = round(spo2_fraction * 100, 1)

        for vital in ("respiration_rate_bpm", "blood_volume_ml"):
            report[vital] = getattr(final, vital)

        report["mental_status"] = status
        report["active_alerts"] = final.active_alerts
        return report
@dataclass
class EpisodeRunner:
    """Reusable runner for executing policies against a backend.

    Attributes:
        backend: Patient backend; reset once per episode and stepped once per
            attempt.
        max_steps: Upper bound on policy decisions per episode.
        action_budget: Optional cap on executed actions; ``None`` disables the
            budget check.
        max_retry_attempts: Attempts allowed per action while the backend
            keeps reporting a retryable error. Values below 1 are treated as
            1 so that every selected action is executed at least once.
        retry_backoff_s: Sleep between retry attempts, in seconds; values
            ``<= 0`` disable sleeping.
    """

    backend: PatientBackend
    max_steps: int = 8
    action_budget: int | None = None
    max_retry_attempts: int = 3
    retry_backoff_s: float = 0.01

    @staticmethod
    def _has_cardiac_arrest_alert(observation: PulsePhysiologyObservation) -> bool:
        """Return True when ``cardiac_arrest`` is among the active alerts."""
        return "cardiac_arrest" in set(observation.active_alerts or [])

    @staticmethod
    def _is_terminal_observation(observation: PulsePhysiologyObservation) -> bool:
        """Detect terminal physiology from the observation even before `done` is set.

        The real Pulse runtime may surface cardiac-arrest state through alerts
        before a subsequent tool call flips `done=True`. Detecting that here
        avoids wasting retry attempts or extra advance-time calls in obviously
        terminal states.
        """
        if observation.done:
            return True
        active_alerts = set(observation.active_alerts or [])
        return "cardiac_arrest" in active_alerts

    def _attempt_limit(self) -> int:
        """Return the attempts allowed per action, never fewer than one."""
        return max(1, self.max_retry_attempts)

    def run(self, policy: Policy, scenario_id: str) -> EpisodeTrace:
        """Execute one episode and capture its trajectory.

        Args:
            policy: Policy that is reset, then asked for one action per step.
                If it exposes a callable ``observe_outcome`` attribute, that
                hook receives every ``(action, result)`` pair.
            scenario_id: Scenario passed to both the backend and policy reset.

        Returns:
            The trace of the episode, including steps, accumulated reward
            (rounded to three decimals), termination reason and event notes.
        """
        reset_result = self.backend.reset(scenario_id)
        policy.reset(scenario_id)
        current_observation = reset_result.observation
        total_reward = reset_result.reward
        steps: list[EpisodeStep] = []
        events: list[str] = []
        remaining_action_budget = self.action_budget
        attempt_limit = self._attempt_limit()

        # Reason used when the loop runs out of steps or is never entered.
        # Every other reason is assigned immediately before a ``break``, so no
        # post-loop override is needed; in particular a patient who is already
        # terminal at reset stays PATIENT_DEATH even when ``max_steps`` is 0.
        termination_reason = EpisodeTerminationReason.MAX_TIMESTEPS
        if self._is_terminal_observation(current_observation):
            termination_reason = EpisodeTerminationReason.PATIENT_DEATH

        for step_index in range(self.max_steps):
            if remaining_action_budget is not None and remaining_action_budget <= 0:
                events.append("Action budget exhausted before the next decision could be made.")
                termination_reason = EpisodeTerminationReason.BUDGET_EXHAUSTED
                break

            if self._is_terminal_observation(current_observation):
                # Only the alert-based detection gets an event note; a plain
                # ``done`` flag ends the episode silently.
                if self._has_cardiac_arrest_alert(current_observation):
                    events.append(
                        "Terminal physiology detected from observation state; ending episode without another backend step."
                    )
                termination_reason = EpisodeTerminationReason.PATIENT_DEATH
                break

            try:
                validate_tool_availability(current_observation.available_tools)
                action = policy.select_action(current_observation)
            except ToolAvailabilityError as exc:
                events.append(f"Fatal backend error: {exc}")
                termination_reason = EpisodeTerminationReason.FATAL_BACKEND_ERROR
                break

            result = self._step_with_retry(action, events)
            # One budget unit per selected action, regardless of retries.
            if remaining_action_budget is not None:
                remaining_action_budget -= 1
            total_reward += result.reward

            # Optional feedback hook; policies without it are left alone.
            observe_outcome = getattr(policy, "observe_outcome", None)
            if callable(observe_outcome):
                observe_outcome(action, result)

            steps.append(self._to_step(step_index, action, result))
            current_observation = result.observation

            if self._is_terminal_observation(current_observation):
                if self._has_cardiac_arrest_alert(current_observation):
                    events.append(
                        "Terminal physiology detected after step result; ending episode without further retries."
                    )
                termination_reason = EpisodeTerminationReason.PATIENT_DEATH
                break

            if result.error is not None:
                if result.error.retryable:
                    # Retries were exhausted; skip this action and let the
                    # policy decide again on the next step.
                    events.append(
                        f"Retryable error persisted for {action.tool_name}; action skipped after "
                        f"{attempt_limit} attempts."
                    )
                    continue
                events.append(
                    f"Fatal backend error from {action.tool_name}: {result.error.code} - {result.error.message}"
                )
                termination_reason = EpisodeTerminationReason.FATAL_BACKEND_ERROR
                break

        return EpisodeTrace(
            scenario_id=reset_result.observation.scenario_id,
            policy_name=policy.name,
            initial_observation=reset_result.observation,
            steps=tuple(steps),
            total_reward=round(total_reward, 3),
            final_observation=current_observation,
            termination_reason=termination_reason,
            action_budget_remaining=remaining_action_budget,
            events=tuple(events),
        )

    def _step_with_retry(
        self,
        action: ToolAction,
        events: list[str],
    ) -> EnvironmentResponse:
        """Execute one backend step with bounded retries for transient failures.

        The backend is always stepped at least once, even when
        ``max_retry_attempts`` is configured below 1.

        Args:
            action: Action forwarded to ``backend.step`` on every attempt.
            events: Mutable event log; retry and terminal notes are appended.

        Returns:
            The first result that is terminal, error-free or non-retryable,
            otherwise the result of the last attempt.
        """
        attempt_limit = self._attempt_limit()
        for attempt_index in range(1, attempt_limit + 1):
            result = self.backend.step(action)
            if self._is_terminal_observation(result.observation):
                events.append(
                    f"Terminal physiology detected after {action.tool_name}; skipping further retries."
                )
                return result
            if result.error is None or not result.error.retryable:
                return result
            events.append(
                f"Retryable error on {action.tool_name} attempt {attempt_index}/{attempt_limit}: "
                f"{result.error.code} - {result.error.message}"
            )
            # No sleep after the final attempt: nothing is left to wait for.
            if attempt_index < attempt_limit and self.retry_backoff_s > 0:
                time.sleep(self.retry_backoff_s)
        # ``attempt_limit`` is at least 1, so the loop ran and ``result`` is set.
        return result

    def _to_step(
        self,
        step_index: int,
        action: ToolAction,
        result: EnvironmentResponse,
    ) -> EpisodeStep:
        """Convert a backend result into an immutable trajectory entry.

        The action is deep-copied; the observation is stored by reference, and
        the tool result and error are dumped to plain data when present.
        """
        return EpisodeStep(
            step_index=step_index,
            action=action.model_copy(deep=True),
            reward=result.reward,
            done=result.done,
            observation=result.observation,
            tool_result=result.tool_result.model_dump() if result.tool_result else None,
            error=result.error.model_dump() if result.error else None,
        )