"""Pydantic schemas exchanged with the Pulse environment.

Contents:
- the agent-facing action model;
- the per-step observation model;
- the structured tool-result and tool-error payloads;
- the step envelope shared by the mock and real adapters.
"""

from __future__ import annotations

from typing import Any

from pydantic import BaseModel, ConfigDict, Field

try:
    from openenv.core.env_server.types import Action, Observation
except ImportError:  # pragma: no cover - allows local mock development without openenv

    class Action(BaseModel):
        """Minimal stand-in for openenv's ``Action`` when openenv is absent."""

    class Observation(BaseModel):
        """Minimal stand-in for openenv's ``Observation`` when openenv is absent."""


from .patient_state import (
    ArterialBloodGasResult,
    BasicMetabolicPanelResult,
    CompleteBloodCountResult,
    LactateTrend,
    MentalStatus,
    PatientState,
    ScenarioDifficulty,
)
from .tool_catalog import EXTENDED_TOOL_NAMES, INITIAL_TOOL_NAMES, KNOWN_TOOL_NAMES


class PulsePhysiologyAction(Action):
    """One tool call issued by the agent or a client.

    ``tool_name`` is required. ``arguments`` defaults to an empty dict and
    ``reasoning`` to ``None``.
    """

    tool_name: str = Field(description="Name of the tool to execute.")
    arguments: dict[str, Any] = Field(
        default_factory=dict,
        description="Structured arguments for the named tool.",
    )
    reasoning: str | None = Field(
        default=None,
        description="Optional human-readable rationale for the action.",
    )


# Alternate name bound to the same action model.
ToolAction = PulsePhysiologyAction


class ToolResult(BaseModel):
    """Outcome record produced for each tool call the environment handled."""

    # Unknown keys are rejected at validation time.
    model_config = ConfigDict(extra="forbid")

    tool_name: str
    success: bool
    message: str
    state_changed: bool
    changed_fields: list[str] = Field(default_factory=list)


class ToolError(BaseModel):
    """Error payload attached to a tool call that failed."""

    # Unknown keys are rejected at validation time.
    model_config = ConfigDict(extra="forbid")

    code: str
    message: str
    retryable: bool


class ObservationMetadata(BaseModel):
    """Per-step bookkeeping that travels next to, not inside, the observation."""

    step_count: int = Field(default=0, ge=0, description="Episode step count")
    available_tools: list[str] = Field(
        # Build a new list per instance so instances never share one object.
        default_factory=lambda: list(KNOWN_TOOL_NAMES),
        description="Tools the environment currently exposes",
    )


class PulsePhysiologyObservation(Observation):
    """Snapshot of the simulated patient plus step-level results.

    Most fields mirror patient-state values. The trailing fields carry the
    outcome of the current step: tools, tool result, error, reward, done
    flag and metadata.
    """

    # Scenario / patient identity and simulation clock.
    scenario_id: str = "baseline"
    scenario_difficulty: ScenarioDifficulty = "medium"
    patient_id: str = "standard_male"
    sim_time_s: float = 0.0

    # Numeric readings default to None.
    heart_rate_bpm: float | None = None
    systolic_bp_mmhg: float | None = None
    diastolic_bp_mmhg: float | None = None
    mean_arterial_pressure_mmhg: float | None = None
    cardiac_output_l_per_min: float | None = None
    spo2: float | None = None
    respiration_rate_bpm: float | None = None
    blood_volume_ml: float | None = None
    mental_status: MentalStatus = "alert"
    active_alerts: list[str] = Field(default_factory=list)
    etco2_mmhg: float | None = None
    tidal_volume_ml: float | None = None
    breath_sounds: str = "present bilateral"
    core_temperature_c: float | None = None
    shock_index: float | None = None
    lactate_trend: LactateTrend = "stable"

    # Positioning and respiratory support.
    position: str = "supine"
    oxygen_device: str | None = None
    oxygen_flow_lpm: float | None = None
    airway_support: str | None = None
    intubated: bool = False

    # Lab panels: each default is a freshly constructed result model.
    abg_result: ArterialBloodGasResult = Field(default_factory=ArterialBloodGasResult)
    cbc_result: CompleteBloodCountResult = Field(default_factory=CompleteBloodCountResult)
    bmp_result: BasicMetabolicPanelResult = Field(default_factory=BasicMetabolicPanelResult)

    # Diagnostics, infusions and hemorrhages.
    pending_diagnostics: dict[str, int] = Field(default_factory=dict)
    ready_diagnostics: list[str] = Field(default_factory=list)
    active_infusions: dict[str, float] = Field(default_factory=dict)
    active_hemorrhages: dict[str, float] = Field(default_factory=dict)

    # Step-level outputs.
    available_tools: list[str] = Field(default_factory=list)
    tool_result: ToolResult | None = None
    error: ToolError | None = None
    reward: float | None = None
    done: bool = False
    metadata: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def from_patient_state(
        cls,
        state: PatientState,
        *,
        reward: float | None = None,
        available_tools: list[str] | None = None,
        tool_result: ToolResult | None = None,
        error: ToolError | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> PulsePhysiologyObservation:
        """Create an observation from a patient-state snapshot.

        Every key from ``state.model_dump()`` is forwarded to the
        constructor. The step-level keys below are then overwritten, and
        ``done`` is always taken from ``state.done``.

        Args:
            state: Patient state to snapshot.
            reward: Step reward, or ``None``.
            available_tools: Tools to advertise. A falsy value yields ``[]``.
            tool_result: Result of the tool call handled this step, if any.
            error: Error for a failed tool call, if any.
            metadata: Extra metadata. A falsy value yields ``{}``.

        Returns:
            A newly validated ``PulsePhysiologyObservation``.
        """
        # NOTE(review): keys in model_dump() that this model does not declare
        # are handled according to the base model's ``extra`` setting.
        # Confirm that PatientState and this model stay in sync.
        snapshot = state.model_dump()
        snapshot |= {
            "reward": reward,
            "done": state.done,
            "available_tools": available_tools or [],
            "tool_result": tool_result,
            "error": error,
            "metadata": metadata or {},
        }
        return cls(**snapshot)


class EnvironmentResponse(BaseModel):
    """Step envelope returned by both the mock and the real adapters.

    It pairs an observation with the step's required ``reward`` and ``done``
    values, the bookkeeping metadata, and an optional tool result or error.
    """

    observation: PulsePhysiologyObservation
    reward: float
    done: bool
    metadata: ObservationMetadata = Field(default_factory=ObservationMetadata)
    tool_result: ToolResult | None = None
    error: ToolError | None = None