Spaces:
Running
Running
| """Central tool registry and validation helpers for Pulse-ER.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import re | |
| from typing import Any | |
@dataclass(frozen=True)
class ToolArgumentSpec:
    """Schema metadata for one structured tool argument.

    Instances are built with keyword arguments throughout the registry, so the
    class must be a dataclass; it is frozen so registry entries cannot be
    mutated after import.
    """

    # Key under which the argument must appear in a tool call payload.
    name: str
    # Human-readable description surfaced in the prompt catalog.
    description: str
    # When True, a missing (or None) value is a validation error.
    required: bool = False
    # When True, the value is coerced to float before validation.
    numeric: bool = False
    # When True, the value is coerced to bool before validation.
    boolean: bool = False
    # Inclusive lower bound for numeric arguments; None disables the check.
    minimum: float | None = None
    # Allowed string values; an empty tuple means the value is free-form.
    choices: tuple[str, ...] = ()
@dataclass(frozen=True)
class ToolSpec:
    """Single source of truth for a supported tool.

    Built with keyword arguments in the registry, hence a dataclass; frozen so
    the registry cannot be altered at runtime.
    """

    # Canonical tool name used as the registry key.
    tool_name: str
    # Tier label such as "tier_1", "tier_2" or "tier_3".
    tier: str
    # Human-readable description surfaced in the prompt catalog.
    description: str
    # Whether the tool only reads state; set independently of state_changing.
    read_only: bool
    # Whether invoking the tool may change simulation/patient state.
    state_changing: bool
    # Structured arguments the tool accepts; empty means no arguments.
    arguments: tuple[ToolArgumentSpec, ...] = ()
class ToolValidationError(ValueError):
    """Signal that a tool call payload breaks the frozen tool contract.

    Derives from ``ValueError`` so generic value-validation handlers still
    catch it.
    """
| _NUMERIC_PREFIX_RE = re.compile(r"^\s*([-+]?\d+(?:\.\d+)?)") | |
| _TRUE_TOKENS = {"1", "true", "yes", "y", "on"} | |
| _FALSE_TOKENS = {"0", "false", "no", "n", "off"} | |
def normalize_contract_token(value: Any) -> str:
    """Fold a contract string into a comparison-friendly token.

    The value is stringified, trimmed, lower-cased, and every hyphen or space
    becomes an underscore, so cosmetic formatting differences (``Get-Vitals``,
    ``get vitals``) compare equal instead of failing closed.
    """
    folded = str(value).strip().lower()
    # Map both separators to "_" in a single pass.
    return folded.translate(str.maketrans({"-": "_", " ": "_"}))
def coerce_numeric_argument(value: Any) -> float:
    """Coerce model-emitted numeric strings like ``2LPM`` into floats.

    The policy layer sometimes emits recoverable formatting artifacts such as
    ``"2LPM"`` or ``"500 ml"``. These should be treated as parsing noise
    rather than hard clinical failures, as long as a leading numeric value is
    still obvious and unambiguous.

    Args:
        value: An ``int``/``float`` or a string. Commas are stripped from
            strings before parsing (thousands separators).

    Returns:
        The parsed value as a finite ``float``.

    Raises:
        TypeError: If ``value`` is a ``bool``.
        ValueError: If no number can be extracted, or the result is NaN,
            infinite, or too large to represent as a float. Non-finite values
            are rejected because NaN compares False against any minimum and
            would otherwise slip past range checks.
    """
    import math

    if isinstance(value, bool):
        raise TypeError("boolean values are not valid numeric tool arguments")
    if isinstance(value, (int, float)):
        try:
            result = float(value)
        except OverflowError as exc:  # e.g. an int with hundreds of digits
            raise ValueError("numeric coercion failed") from exc
    elif isinstance(value, str):
        compact = value.strip().replace(",", "")
        try:
            result = float(compact)
        except ValueError:
            # Fall back to a leading number followed by unit noise ("2LPM").
            match = _NUMERIC_PREFIX_RE.match(compact)
            if match is None:
                raise ValueError("numeric coercion failed") from None
            result = float(match.group(1))
    else:
        raise ValueError("numeric coercion failed")
    # float() happily parses "nan"/"inf" and overflows long digit strings to
    # inf; none of these are meaningful tool arguments.
    if not math.isfinite(result):
        raise ValueError("numeric coercion failed")
    return result
def coerce_boolean_argument(value: Any) -> bool:
    """Turn a boolean-like model output into a real ``bool``.

    Accepts real booleans, the numbers 0 and 1 (int or float), and the string
    tokens in ``_TRUE_TOKENS`` / ``_FALSE_TOKENS`` (case- and
    whitespace-insensitive). Anything else raises ``ValueError``.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        cleaned = value.strip().lower()
        if cleaned in _TRUE_TOKENS:
            return True
        if cleaned in _FALSE_TOKENS:
            return False
    elif isinstance(value, (int, float)) and value in {0, 1}:
        return value == 1
    raise ValueError("boolean coercion failed")
def normalize_semantic_argument(tool_name: str, arg_name: str, value: Any) -> Any:
    """Rewrite known non-clinical aliases before strict validation runs.

    Currently the only rewrite maps generic airway-support mode words
    (``basic``, ``default``, ``standard``, ``support``, ``airway_support``) on
    ``airway_support.mode`` / ``airway_support.support_type`` to ``"auto"``.
    Every other value, including all non-strings, is returned unchanged.
    """
    if isinstance(value, str):
        targets_airway_mode = (
            tool_name == "airway_support" and arg_name in {"mode", "support_type"}
        )
        generic_mode_words = {"basic", "default", "standard", "support", "airway_support"}
        if targets_airway_mode and normalize_contract_token(value) in generic_mode_words:
            return "auto"
    return value
# Ordered registry of every supported tool.
#
# Order is significant: INITIAL_TOOL_NAMES is taken from the first ten entries
# of this tuple (get_vitals through recommend_next_step), and the extended /
# known name lists follow the same order.
#
# Several tools declare alias keys (e.g. "fluid" for "fluid_type", "peep" for
# "peep_cmh2o") as separate ToolArgumentSpec entries. validate_tool_arguments
# validates each key independently and does not merge or reconcile aliases.
# NOTE(review): presumably the tool implementation resolves alias precedence —
# confirm.
TOOL_SPECS: tuple[ToolSpec, ...] = (
    # --- Initial tool set (first ten entries) ---
    ToolSpec(
        tool_name="get_vitals",
        tier="tier_1",
        description="Read the current core vitals without changing patient state.",
        read_only=True,
        state_changing=False,
    ),
    ToolSpec(
        tool_name="advance_time",
        tier="tier_2",
        description="Advance the simulation clock to observe ongoing physiology.",
        read_only=False,
        state_changing=True,
        arguments=(
            # The only required argument in the registry.
            ToolArgumentSpec(
                name="seconds",
                description="Number of seconds to advance the simulation.",
                required=True,
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="give_oxygen",
        tier="tier_2",
        description="Provide supplemental oxygen to improve oxygenation.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="flow_lpm",
                description="Oxygen flow rate in liters per minute.",
                numeric=True,
                minimum=0.1,
            ),
            # Free-form string: no choices declared, so any value passes.
            ToolArgumentSpec(
                name="device",
                description="Optional oxygen delivery device such as nasal_cannula or non_rebreather_mask.",
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after oxygen is applied.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="give_fluids",
        tier="tier_2",
        description="Administer IV fluids to support perfusion and blood pressure.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="volume_ml",
                description="Fluid bolus volume in milliliters.",
                numeric=True,
                minimum=1.0,
            ),
            ToolArgumentSpec(
                name="fluid_type",
                description="Optional fluid selection such as saline, blood, or packed_rbc.",
            ),
            ToolArgumentSpec(
                name="fluid",
                description="Alias for fluid_type when a shorter key is used.",
            ),
            ToolArgumentSpec(
                name="bag_volume_ml",
                description="Alias for total infused volume in milliliters.",
                numeric=True,
                minimum=1.0,
            ),
            ToolArgumentSpec(
                name="rate_ml_per_min",
                description="Infusion rate in milliliters per minute.",
                numeric=True,
                minimum=0.1,
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after the infusion starts.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="control_bleeding",
        tier="tier_2",
        description="Apply bleeding control measures for active hemorrhage.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="site",
                description="Optional hemorrhage site when more than one bleed is active.",
            ),
            ToolArgumentSpec(
                name="compartment",
                description="Alias for the hemorrhage site or compartment.",
            ),
            # Closed set: values are matched after normalize_contract_token,
            # so "Hemostatic Dressing" resolves to "hemostatic_dressing".
            ToolArgumentSpec(
                name="method",
                description="Bleeding control technique to use.",
                choices=("tourniquet", "pressure", "hemostatic_dressing"),
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after control is applied.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="position_patient",
        tier="tier_2",
        description="Reposition the patient to support breathing or perfusion.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="position",
                description="Target position when explicitly specified, such as supine or upright.",
            ),
        ),
    ),
    ToolSpec(
        tool_name="airway_support",
        tier="tier_2",
        description="Provide airway support to improve ventilation.",
        read_only=False,
        state_changing=True,
        arguments=(
            # Generic words on mode/support_type are rewritten to "auto" by
            # normalize_semantic_argument before validation.
            ToolArgumentSpec(
                name="mode",
                description="Optional airway support mode such as cpap, bag_valve_mask, or pressure_control_ventilation.",
            ),
            ToolArgumentSpec(
                name="support_type",
                description="Alias for the airway support mode or airway intervention type.",
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after airway support is applied.",
                numeric=True,
                minimum=0.1,
            ),
            # NOTE(review): described as a 0-1 fraction, but only a lower
            # bound is enforced; ToolArgumentSpec has no maximum field.
            ToolArgumentSpec(
                name="fio2",
                description="Fraction of inspired oxygen as a 0-1 value.",
                numeric=True,
                minimum=0.0,
            ),
            ToolArgumentSpec(
                name="fraction_inspired_oxygen",
                description="Alias for fio2.",
                numeric=True,
                minimum=0.0,
            ),
            # Pressure arguments below declare no minimum, so any finite
            # number is accepted.
            ToolArgumentSpec(
                name="peep_cmh2o",
                description="Positive end-expiratory pressure in cmH2O.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="peep",
                description="Alias for peep_cmh2o.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="respiration_rate_bpm",
                description="Ventilatory rate in breaths per minute.",
                numeric=True,
                minimum=1.0,
            ),
            ToolArgumentSpec(
                name="rate_bpm",
                description="Alias for respiration_rate_bpm.",
                numeric=True,
                minimum=1.0,
            ),
            ToolArgumentSpec(
                name="ie_ratio",
                description="Inspiratory-to-expiratory ratio.",
                numeric=True,
                minimum=0.1,
            ),
            ToolArgumentSpec(
                name="inspiratory_expiratory_ratio",
                description="Alias for ie_ratio.",
                numeric=True,
                minimum=0.1,
            ),
            ToolArgumentSpec(
                name="squeeze_pressure_cmh2o",
                description="Bag-valve-mask squeeze pressure in cmH2O.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="pressure_cmh2o",
                description="Alias for squeeze or inspiratory pressure in cmH2O.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="squeeze_volume_ml",
                description="Bag-valve-mask squeeze volume in milliliters.",
                numeric=True,
                minimum=1.0,
            ),
            ToolArgumentSpec(
                name="tidal_volume_ml",
                description="Alias for squeeze_volume_ml.",
                numeric=True,
                minimum=1.0,
            ),
            ToolArgumentSpec(
                name="airway_adjunct",
                description="Optional airway adjunct such as oropharyngeal or nasopharyngeal.",
            ),
            ToolArgumentSpec(
                name="pressure_support_cmh2o",
                description="Pressure support level in cmH2O.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="pressure_support",
                description="Alias for pressure_support_cmh2o.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="inspiratory_pressure_cmh2o",
                description="Inspiratory pressure target in cmH2O.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="inspiratory_period_s",
                description="Inspiratory period in seconds.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="summarize_state",
        tier="tier_1",
        description="Summarize the current patient state in concise clinical language.",
        read_only=True,
        state_changing=False,
    ),
    ToolSpec(
        tool_name="check_deterioration",
        tier="tier_1",
        description="Assess whether the patient is currently worsening.",
        read_only=True,
        state_changing=False,
    ),
    ToolSpec(
        tool_name="recommend_next_step",
        tier="tier_3",
        description="Recommend the most appropriate next intervention or assessment.",
        read_only=True,
        state_changing=False,
    ),
    # --- Extended tool set (entries after the first ten) ---
    ToolSpec(
        tool_name="give_pressor",
        tier="tier_2",
        description="Start, titrate, or stop a vasopressor infusion for refractory hypotension.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="pressor",
                description="Pressor agent name such as norepinephrine or phenylephrine.",
            ),
            ToolArgumentSpec(
                name="agent",
                description="Alias for the pressor agent name.",
            ),
            # Unlike the other rate arguments, zero is allowed here.
            # NOTE(review): presumably so a rate can be titrated to zero — confirm.
            ToolArgumentSpec(
                name="rate_ml_per_min",
                description="Infusion rate in milliliters per minute.",
                numeric=True,
                minimum=0.0,
            ),
            ToolArgumentSpec(
                name="concentration_ug_per_ml",
                description="Pressor concentration in micrograms per milliliter.",
                numeric=True,
                minimum=0.1,
            ),
            # The only boolean argument in the registry.
            ToolArgumentSpec(
                name="stop",
                description="Set to true to stop the current infusion.",
                boolean=True,
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after the pressor change.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="needle_decompression",
        tier="tier_2",
        description="Perform needle decompression when a tension pneumothorax is suspected.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="side",
                description="Target side for decompression.",
                choices=("left", "right"),
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after decompression.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="pericardiocentesis",
        tier="tier_2",
        description="Perform pericardiocentesis when tamponade physiology is suspected.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="drain_rate_ml_per_min",
                description="Drainage rate in milliliters per minute.",
                numeric=True,
                minimum=0.1,
            ),
            ToolArgumentSpec(
                name="rate_ml_per_min",
                description="Alias for drain_rate_ml_per_min.",
                numeric=True,
                minimum=0.1,
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after pericardiocentesis.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="get_respiratory_status",
        tier="tier_1",
        description="Read a respiratory-focused bedside summary including breath sounds and EtCO2.",
        read_only=True,
        state_changing=False,
    ),
    # NOTE(review): the lab tools below are tier_1 but flagged
    # read_only=False / state_changing=True. Presumably ordering a lab
    # records the order or advances simulation state — confirm.
    ToolSpec(
        tool_name="get_blood_gas",
        tier="tier_1",
        description="Order or review arterial blood gas results.",
        read_only=False,
        state_changing=True,
    ),
    ToolSpec(
        tool_name="get_cbc",
        tier="tier_1",
        description="Order or review complete blood count results.",
        read_only=False,
        state_changing=True,
    ),
    ToolSpec(
        tool_name="get_bmp",
        tier="tier_1",
        description="Order or review basic metabolic panel results.",
        read_only=False,
        state_changing=True,
    ),
)
# Every registered tool name, in registry order.
EXTENDED_TOOL_NAMES = [spec.tool_name for spec in TOOL_SPECS]
# Lookup table from canonical tool name to its spec. If a name were
# duplicated, the later entry would win.
TOOL_SPEC_BY_NAME: dict[str, ToolSpec] = dict(zip(EXTENDED_TOOL_NAMES, TOOL_SPECS))
# The initial tool subset: the first ten registry entries.
INITIAL_TOOL_NAMES = EXTENDED_TOOL_NAMES[:10]
# Default tool universe used when no explicit allow-list is given. This is the
# same list object as EXTENDED_TOOL_NAMES, not a copy.
KNOWN_TOOL_NAMES = EXTENDED_TOOL_NAMES
| def canonicalize_tool_name(tool_name: str, allowed_tools: list[str] | None = None) -> str: | |
| """Map harmless casing or separator differences onto the canonical tool name.""" | |
| valid_tools = allowed_tools or list(KNOWN_TOOL_NAMES) | |
| if tool_name in valid_tools: | |
| return tool_name | |
| normalized = normalize_contract_token(tool_name) | |
| matches = [candidate for candidate in valid_tools if normalize_contract_token(candidate) == normalized] | |
| if len(matches) == 1: | |
| return matches[0] | |
| return tool_name | |
def get_tool_spec(tool_name: str) -> ToolSpec:
    """Look up the registry entry for a tool name.

    The name is first canonicalized against all known tools, so harmless
    variants such as ``Get-Vitals`` resolve.

    Raises:
        ToolValidationError: If the canonicalized name is not registered.
    """
    canonical = canonicalize_tool_name(tool_name)
    try:
        found = TOOL_SPEC_BY_NAME[canonical]
    except KeyError as missing:
        raise ToolValidationError(f"Unsupported tool_name '{canonical}'.") from missing
    return found
def build_tool_catalog(available_tools: list[str] | None = None) -> list[dict[str, Any]]:
    """Describe tools and their arguments as plain dicts for prompting.

    Args:
        available_tools: Tool names to include, in order. ``None`` or an empty
            list includes every known tool.

    Returns:
        One dict per tool with keys ``tool_name``, ``tier``, ``description``,
        ``read_only``, ``state_changing`` and ``arguments``. Each argument is
        itself a dict of its spec fields, with ``choices`` as a list.

    Raises:
        ToolValidationError: If a requested name is not a registered tool.
    """

    def describe_argument(arg: ToolArgumentSpec) -> dict[str, Any]:
        # Flatten one argument spec into JSON-friendly primitives.
        return {
            "name": arg.name,
            "description": arg.description,
            "required": arg.required,
            "numeric": arg.numeric,
            "boolean": arg.boolean,
            "minimum": arg.minimum,
            "choices": list(arg.choices),
        }

    requested = available_tools if available_tools else KNOWN_TOOL_NAMES
    specs = [get_tool_spec(name) for name in requested]
    return [
        {
            "tool_name": spec.tool_name,
            "tier": spec.tier,
            "description": spec.description,
            "read_only": spec.read_only,
            "state_changing": spec.state_changing,
            "arguments": [describe_argument(arg) for arg in spec.arguments],
        }
        for spec in specs
    ]
def _normalize_argument_value(tool_name: str, arg_spec: ToolArgumentSpec, value: Any) -> Any:
    """Coerce and check one present (non-None) argument value against its spec."""
    import math

    label = f"{tool_name}.{arg_spec.name}"
    if arg_spec.numeric:
        try:
            numeric_value = coerce_numeric_argument(value)
        except (TypeError, ValueError, OverflowError) as exc:
            # OverflowError: float() of a huge int is not a ValueError.
            raise ToolValidationError(f"{label} must be numeric.") from exc
        # NaN compares False against any minimum and would slip through.
        if not math.isfinite(numeric_value):
            raise ToolValidationError(f"{label} must be numeric.")
        if arg_spec.minimum is not None and numeric_value < arg_spec.minimum:
            raise ToolValidationError(f"{label} must be >= {arg_spec.minimum}.")
        return numeric_value
    if arg_spec.boolean:
        try:
            return coerce_boolean_argument(value)
        except ValueError as exc:
            raise ToolValidationError(f"{label} must be boolean.") from exc
    if arg_spec.choices:
        if not isinstance(value, str):
            raise ToolValidationError(f"{label} must be a string.")
        # Match on normalized tokens but always return the canonical spelling.
        canonical_choices = {
            normalize_contract_token(choice): choice for choice in arg_spec.choices
        }
        try:
            return canonical_choices[normalize_contract_token(value)]
        except KeyError:
            raise ToolValidationError(
                f"{label} must be one of: {', '.join(arg_spec.choices)}"
            ) from None
    # Free-form argument: passed through unchanged.
    return value


def validate_tool_arguments(
    tool_name: str,
    arguments: dict[str, Any],
    *,
    allowed_tools: list[str] | None = None,
) -> dict[str, Any]:
    """Validate and normalize structured arguments for one tool call.

    Args:
        tool_name: Tool name as emitted by the model; cosmetic variants are
            canonicalized first.
        arguments: The call's argument mapping.
        allowed_tools: Optional allow-list. When given, the canonical name
            must be in it.

    Returns:
        A new dict containing only the arguments that were present (not None),
        with numeric values as floats, booleans as ``bool``, and choice values
        in their canonical spelling.

    Raises:
        ToolValidationError: The tool is not allowed or not registered,
            ``arguments`` is not a dict, an unsupported key is present, a
            required argument is missing, or a value fails its type, range or
            choice check.
    """
    tool_name = canonicalize_tool_name(tool_name, allowed_tools=allowed_tools)
    if allowed_tools is not None and tool_name not in allowed_tools:
        raise ToolValidationError(
            f"Unsupported tool_name '{tool_name}'. Expected one of: {', '.join(allowed_tools)}"
        )
    if not isinstance(arguments, dict):
        raise ToolValidationError("arguments must be a JSON object.")
    spec = get_tool_spec(tool_name)
    supported_names = {arg.name for arg in spec.arguments}
    # Stringify keys so non-str keys neither break sorting nor the join below.
    unknown_args = sorted(str(key) for key in set(arguments) - supported_names)
    if unknown_args:
        raise ToolValidationError(
            f"{tool_name} received unsupported arguments: {', '.join(unknown_args)}"
        )
    normalized: dict[str, Any] = {}
    for arg_spec in spec.arguments:
        value = normalize_semantic_argument(
            tool_name,
            arg_spec.name,
            arguments.get(arg_spec.name),
        )
        if value is None:
            # Absent and explicit-null values are treated the same way.
            if arg_spec.required:
                raise ToolValidationError(
                    f"{tool_name} requires argument '{arg_spec.name}'."
                )
            continue
        normalized[arg_spec.name] = _normalize_argument_value(tool_name, arg_spec, value)
    return normalized