Pulse_ER_env / tool_catalog.py
KChad's picture
Add all docs_assets image assets to Hugging Face Space snapshot
9b1756a
"""Central tool registry and validation helpers for Pulse-ER."""
from __future__ import annotations

import math
import re
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class ToolArgumentSpec:
    """Schema metadata for one structured tool argument.

    An argument is at most one of: numeric, boolean, or restricted to
    ``choices``. With none of these set, it is a free-form value.
    ``minimum`` is only meaningful for numeric arguments.

    Inconsistent combinations are rejected at construction time.
    ``validate_tool_arguments`` checks numeric, then boolean, then choices,
    and stops at the first match, so any extra flag would be silently ignored.

    Raises:
        ValueError: if more than one of numeric/boolean/choices is set, or if
            ``minimum`` is given for a non-numeric argument.
    """

    # Argument key as it appears in the tool call payload.
    name: str
    # Human-readable description surfaced in the tool catalog.
    description: str
    # When True, validation fails if the argument is absent or None.
    required: bool = False
    # Coerce the value to float (accepting strings such as "2LPM").
    numeric: bool = False
    # Coerce boolean-like values ("yes", "0", ...) to bool.
    boolean: bool = False
    # Inclusive lower bound, applied to numeric arguments only.
    minimum: float | None = None
    # Allowed spellings; input is matched after contract-token normalization.
    choices: tuple[str, ...] = ()

    def __post_init__(self) -> None:
        kinds = [
            label
            for label, enabled in (
                ("numeric", self.numeric),
                ("boolean", self.boolean),
                ("choices", bool(self.choices)),
            )
            if enabled
        ]
        if len(kinds) > 1:
            raise ValueError(
                f"argument '{self.name}' mixes incompatible kinds: {', '.join(kinds)}"
            )
        if self.minimum is not None and not self.numeric:
            raise ValueError(
                f"argument '{self.name}' sets minimum but is not numeric"
            )
@dataclass(frozen=True)
class ToolSpec:
    """Single source of truth for a supported tool.

    Construction rejects two internally inconsistent specs:

    - a tool flagged as both ``read_only`` and ``state_changing``;
    - a tool that declares the same argument name twice. Argument lookups
      during validation are keyed by name, so a duplicate would be validated
      twice and listed twice in the catalog.

    Raises:
        ValueError: on either inconsistency above.
    """

    # Canonical tool name; used as the registry key.
    tool_name: str
    # Tier label such as "tier_1"; surfaced in the catalog.
    tier: str
    description: str
    read_only: bool
    state_changing: bool
    # Every argument the tool accepts; any other key is rejected as unsupported.
    arguments: tuple[ToolArgumentSpec, ...] = ()

    def __post_init__(self) -> None:
        if self.read_only and self.state_changing:
            raise ValueError(
                f"tool '{self.tool_name}' cannot be both read_only and state_changing"
            )
        seen: set[str] = set()
        for arg in self.arguments:
            if arg.name in seen:
                raise ValueError(
                    f"tool '{self.tool_name}' declares argument '{arg.name}' more than once"
                )
            seen.add(arg.name)
class ToolValidationError(ValueError):
    """Signal that a tool call payload breaks the frozen tool contract.

    Derives from ``ValueError``, so handlers written for ``ValueError`` also
    catch it.
    """
_NUMERIC_PREFIX_RE = re.compile(r"^\s*([-+]?\d+(?:\.\d+)?)")
_TRUE_TOKENS = {"1", "true", "yes", "y", "on"}
_FALSE_TOKENS = {"0", "false", "no", "n", "off"}
# Translation table turning each hyphen or space into an underscore.
_SEPARATOR_TO_UNDERSCORE = str.maketrans({"-": "_", " ": "_"})


def normalize_contract_token(value: Any) -> str:
    """Return a lowercase, underscore-separated form of ``value`` for lenient matching.

    ``value`` is stringified, stripped of surrounding whitespace, and
    lowercased. Every hyphen or space is then replaced by an underscore, so
    ``" Nasal-Cannula "`` and ``"nasal cannula"`` both become
    ``"nasal_cannula"``. Harmless formatting differences therefore do not
    fail closed.
    """
    lowered = str(value).strip().lower()
    return lowered.translate(_SEPARATOR_TO_UNDERSCORE)
def coerce_numeric_argument(value: Any) -> float:
    """Coerce model-emitted numeric strings like ``2LPM`` into finite floats.

    The policy layer sometimes emits recoverable formatting artifacts such as
    ``"2LPM"`` or ``"500 ml"``. These should be treated as parsing noise
    rather than hard clinical failures, as long as a leading numeric value is
    still obvious and unambiguous.

    Args:
        value: An ``int``/``float`` or a string. Commas are removed from
            strings first, so ``"1,000"`` parses as ``1000.0``.

    Returns:
        The parsed value as a finite ``float``.

    Raises:
        TypeError: if ``value`` is a ``bool``.
        ValueError: if no number can be recovered, or if the result is NaN or
            infinite. This includes integers too large to fit in a float.
            Non-finite values are rejected because NaN compares false against
            any ``minimum`` bound and would otherwise slip through validation.
    """
    if isinstance(value, bool):
        raise TypeError("boolean values are not valid numeric tool arguments")
    if isinstance(value, (int, float)):
        try:
            number = float(value)
        except OverflowError as exc:
            # float() overflows on huge ints, e.g. a 400-digit JSON integer.
            raise ValueError("numeric coercion failed") from exc
    elif isinstance(value, str):
        compact = value.strip().replace(",", "")
        try:
            number = float(compact)
        except ValueError:
            match = _NUMERIC_PREFIX_RE.match(compact)
            if match is None:
                raise ValueError("numeric coercion failed") from None
            number = float(match.group(1))
    else:
        raise ValueError("numeric coercion failed")
    # float() happily accepts "nan", "inf", "infinity"; reject them here.
    if not math.isfinite(number):
        raise ValueError("numeric value must be finite")
    return number
def coerce_boolean_argument(value: Any) -> bool:
    """Turn a boolean-like model output into a real ``bool``.

    Accepted inputs:

    - real booleans, returned unchanged;
    - the numbers 0 and 1 (int or float);
    - strings whose stripped, lowercased form is one of the recognized
      true/false tokens.

    Raises:
        ValueError: for anything else.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        if value in {0, 1}:
            return bool(value)
    elif isinstance(value, str):
        token = value.strip().lower()
        if token in _TRUE_TOKENS:
            return True
        if token in _FALSE_TOKENS:
            return False
    raise ValueError("boolean coercion failed")
def normalize_semantic_argument(tool_name: str, arg_name: str, value: Any) -> Any:
    """Map known non-clinical aliases onto canonical argument values.

    Only string values are inspected; every other value is returned unchanged.
    The single rule today applies to ``airway_support``. A ``mode`` or
    ``support_type`` whose normalized token is a generic word (``basic``,
    ``default``, ``standard``, ``support``, ``airway_support``) becomes
    ``"auto"``. Any unmatched string is returned exactly as given, not
    normalized.
    """
    if not isinstance(value, str):
        return value
    generic_airway_tokens = {"basic", "default", "standard", "support", "airway_support"}
    is_airway_mode = tool_name == "airway_support" and arg_name in {"mode", "support_type"}
    if is_airway_mode and normalize_contract_token(value) in generic_airway_tokens:
        return "auto"
    return value
# Declarative registry of every supported tool, in declaration order.
# Each ToolSpec lists the complete set of keyword arguments its tool accepts.
# Alias keys (e.g. ``fluid`` for ``fluid_type``, ``peep`` for ``peep_cmh2o``)
# are declared as separate optional arguments. validate_tool_arguments checks
# each key independently and does not reconcile an alias with its primary key
# when both are supplied.
TOOL_SPECS: tuple[ToolSpec, ...] = (
    ToolSpec(
        tool_name="get_vitals",
        tier="tier_1",
        description="Read the current core vitals without changing patient state.",
        read_only=True,
        state_changing=False,
    ),
    ToolSpec(
        tool_name="advance_time",
        tier="tier_2",
        description="Advance the simulation clock to observe ongoing physiology.",
        read_only=False,
        state_changing=True,
        arguments=(
            # The only argument in this registry declared with required=True.
            ToolArgumentSpec(
                name="seconds",
                description="Number of seconds to advance the simulation.",
                required=True,
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="give_oxygen",
        tier="tier_2",
        description="Provide supplemental oxygen to improve oxygenation.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="flow_lpm",
                description="Oxygen flow rate in liters per minute.",
                numeric=True,
                minimum=0.1,
            ),
            # Free-form string (no choices), passed through as given.
            ToolArgumentSpec(
                name="device",
                description="Optional oxygen delivery device such as nasal_cannula or non_rebreather_mask.",
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after oxygen is applied.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="give_fluids",
        tier="tier_2",
        description="Administer IV fluids to support perfusion and blood pressure.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="volume_ml",
                description="Fluid bolus volume in milliliters.",
                numeric=True,
                minimum=1.0,
            ),
            ToolArgumentSpec(
                name="fluid_type",
                description="Optional fluid selection such as saline, blood, or packed_rbc.",
            ),
            ToolArgumentSpec(
                name="fluid",
                description="Alias for fluid_type when a shorter key is used.",
            ),
            ToolArgumentSpec(
                name="bag_volume_ml",
                description="Alias for total infused volume in milliliters.",
                numeric=True,
                minimum=1.0,
            ),
            ToolArgumentSpec(
                name="rate_ml_per_min",
                description="Infusion rate in milliliters per minute.",
                numeric=True,
                minimum=0.1,
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after the infusion starts.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="control_bleeding",
        tier="tier_2",
        description="Apply bleeding control measures for active hemorrhage.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="site",
                description="Optional hemorrhage site when more than one bleed is active.",
            ),
            ToolArgumentSpec(
                name="compartment",
                description="Alias for the hemorrhage site or compartment.",
            ),
            # Closed vocabulary: input is matched after contract-token normalization.
            ToolArgumentSpec(
                name="method",
                description="Bleeding control technique to use.",
                choices=("tourniquet", "pressure", "hemostatic_dressing"),
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after control is applied.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="position_patient",
        tier="tier_2",
        description="Reposition the patient to support breathing or perfusion.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="position",
                description="Target position when explicitly specified, such as supine or upright.",
            ),
        ),
    ),
    # Largest argument surface: most ventilation settings have a short alias.
    # NOTE(review): the cmH2O pressure fields (peep*, *pressure*) declare no
    # minimum, so negative values currently pass validation — confirm intended.
    ToolSpec(
        tool_name="airway_support",
        tier="tier_2",
        description="Provide airway support to improve ventilation.",
        read_only=False,
        state_changing=True,
        arguments=(
            # Generic values such as "standard" are rewritten to "auto" by
            # normalize_semantic_argument before validation.
            ToolArgumentSpec(
                name="mode",
                description="Optional airway support mode such as cpap, bag_valve_mask, or pressure_control_ventilation.",
            ),
            ToolArgumentSpec(
                name="support_type",
                description="Alias for the airway support mode or airway intervention type.",
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after airway support is applied.",
                numeric=True,
                minimum=0.1,
            ),
            # NOTE(review): described as a 0-1 value but only the lower bound is
            # enforced (ToolArgumentSpec has no maximum field).
            ToolArgumentSpec(
                name="fio2",
                description="Fraction of inspired oxygen as a 0-1 value.",
                numeric=True,
                minimum=0.0,
            ),
            ToolArgumentSpec(
                name="fraction_inspired_oxygen",
                description="Alias for fio2.",
                numeric=True,
                minimum=0.0,
            ),
            ToolArgumentSpec(
                name="peep_cmh2o",
                description="Positive end-expiratory pressure in cmH2O.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="peep",
                description="Alias for peep_cmh2o.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="respiration_rate_bpm",
                description="Ventilatory rate in breaths per minute.",
                numeric=True,
                minimum=1.0,
            ),
            ToolArgumentSpec(
                name="rate_bpm",
                description="Alias for respiration_rate_bpm.",
                numeric=True,
                minimum=1.0,
            ),
            ToolArgumentSpec(
                name="ie_ratio",
                description="Inspiratory-to-expiratory ratio.",
                numeric=True,
                minimum=0.1,
            ),
            ToolArgumentSpec(
                name="inspiratory_expiratory_ratio",
                description="Alias for ie_ratio.",
                numeric=True,
                minimum=0.1,
            ),
            ToolArgumentSpec(
                name="squeeze_pressure_cmh2o",
                description="Bag-valve-mask squeeze pressure in cmH2O.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="pressure_cmh2o",
                description="Alias for squeeze or inspiratory pressure in cmH2O.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="squeeze_volume_ml",
                description="Bag-valve-mask squeeze volume in milliliters.",
                numeric=True,
                minimum=1.0,
            ),
            ToolArgumentSpec(
                name="tidal_volume_ml",
                description="Alias for squeeze_volume_ml.",
                numeric=True,
                minimum=1.0,
            ),
            ToolArgumentSpec(
                name="airway_adjunct",
                description="Optional airway adjunct such as oropharyngeal or nasopharyngeal.",
            ),
            ToolArgumentSpec(
                name="pressure_support_cmh2o",
                description="Pressure support level in cmH2O.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="pressure_support",
                description="Alias for pressure_support_cmh2o.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="inspiratory_pressure_cmh2o",
                description="Inspiratory pressure target in cmH2O.",
                numeric=True,
            ),
            ToolArgumentSpec(
                name="inspiratory_period_s",
                description="Inspiratory period in seconds.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="summarize_state",
        tier="tier_1",
        description="Summarize the current patient state in concise clinical language.",
        read_only=True,
        state_changing=False,
    ),
    ToolSpec(
        tool_name="check_deterioration",
        tier="tier_1",
        description="Assess whether the patient is currently worsening.",
        read_only=True,
        state_changing=False,
    ),
    ToolSpec(
        tool_name="recommend_next_step",
        tier="tier_3",
        description="Recommend the most appropriate next intervention or assessment.",
        read_only=True,
        state_changing=False,
    ),
    ToolSpec(
        tool_name="give_pressor",
        tier="tier_2",
        description="Start, titrate, or stop a vasopressor infusion for refractory hypotension.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="pressor",
                description="Pressor agent name such as norepinephrine or phenylephrine.",
            ),
            ToolArgumentSpec(
                name="agent",
                description="Alias for the pressor agent name.",
            ),
            # minimum=0.0 (unlike the 0.1 used elsewhere) so a zero rate is accepted.
            ToolArgumentSpec(
                name="rate_ml_per_min",
                description="Infusion rate in milliliters per minute.",
                numeric=True,
                minimum=0.0,
            ),
            ToolArgumentSpec(
                name="concentration_ug_per_ml",
                description="Pressor concentration in micrograms per milliliter.",
                numeric=True,
                minimum=0.1,
            ),
            # The only boolean argument in this registry.
            ToolArgumentSpec(
                name="stop",
                description="Set to true to stop the current infusion.",
                boolean=True,
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after the pressor change.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="needle_decompression",
        tier="tier_2",
        description="Perform needle decompression when a tension pneumothorax is suspected.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="side",
                description="Target side for decompression.",
                choices=("left", "right"),
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after decompression.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="pericardiocentesis",
        tier="tier_2",
        description="Perform pericardiocentesis when tamponade physiology is suspected.",
        read_only=False,
        state_changing=True,
        arguments=(
            ToolArgumentSpec(
                name="drain_rate_ml_per_min",
                description="Drainage rate in milliliters per minute.",
                numeric=True,
                minimum=0.1,
            ),
            ToolArgumentSpec(
                name="rate_ml_per_min",
                description="Alias for drain_rate_ml_per_min.",
                numeric=True,
                minimum=0.1,
            ),
            ToolArgumentSpec(
                name="monitor_seconds",
                description="Optional observation window after pericardiocentesis.",
                numeric=True,
                minimum=0.1,
            ),
        ),
    ),
    ToolSpec(
        tool_name="get_respiratory_status",
        tier="tier_1",
        description="Read a respiratory-focused bedside summary including breath sounds and EtCO2.",
        read_only=True,
        state_changing=False,
    ),
    # Lab tools: tier_1, yet read_only=False / state_changing=True. Their
    # descriptions cover ordering as well as reviewing results.
    ToolSpec(
        tool_name="get_blood_gas",
        tier="tier_1",
        description="Order or review arterial blood gas results.",
        read_only=False,
        state_changing=True,
    ),
    ToolSpec(
        tool_name="get_cbc",
        tier="tier_1",
        description="Order or review complete blood count results.",
        read_only=False,
        state_changing=True,
    ),
    ToolSpec(
        tool_name="get_bmp",
        tier="tier_1",
        description="Order or review basic metabolic panel results.",
        read_only=False,
        state_changing=True,
    ),
)
# Registry lookup keyed by canonical tool name.
TOOL_SPEC_BY_NAME: dict[str, ToolSpec] = {
    spec.tool_name: spec for spec in TOOL_SPECS
}
# Initial tool subset. Listed by name rather than sliced positionally
# (formerly TOOL_SPECS[:10]) so that inserting or reordering registry entries
# cannot silently change which tools belong to this subset.
INITIAL_TOOL_NAMES = [
    "get_vitals",
    "advance_time",
    "give_oxygen",
    "give_fluids",
    "control_bleeding",
    "position_patient",
    "airway_support",
    "summarize_state",
    "check_deterioration",
    "recommend_next_step",
]
# Every registered tool, in declaration order.
EXTENDED_TOOL_NAMES = [spec.tool_name for spec in TOOL_SPECS]
# Same list object as EXTENDED_TOOL_NAMES; the default universe of tool names
# used by canonicalize_tool_name and build_tool_catalog.
KNOWN_TOOL_NAMES = EXTENDED_TOOL_NAMES
def canonicalize_tool_name(tool_name: str, allowed_tools: list[str] | None = None) -> str:
    """Resolve harmless casing/separator variants of ``tool_name`` to a canonical name.

    Candidates are ``allowed_tools``, or every known tool when that is
    ``None`` or empty. An exact match is returned immediately. Otherwise the
    name is compared via ``normalize_contract_token``, and the candidate is
    returned only if exactly one matches. In every other case the input is
    returned unchanged, so validation can reject it.
    """
    candidates = allowed_tools if allowed_tools else list(KNOWN_TOOL_NAMES)
    if tool_name in candidates:
        return tool_name
    target = normalize_contract_token(tool_name)
    hits = [name for name in candidates if normalize_contract_token(name) == target]
    return hits[0] if len(hits) == 1 else tool_name
def get_tool_spec(tool_name: str) -> ToolSpec:
    """Look up the registry entry for ``tool_name`` after canonicalizing it.

    Raises:
        ToolValidationError: if the canonicalized name is not registered.
    """
    canonical_name = canonicalize_tool_name(tool_name)
    try:
        spec = TOOL_SPEC_BY_NAME[canonical_name]
    except KeyError as missing:
        raise ToolValidationError(f"Unsupported tool_name '{canonical_name}'.") from missing
    return spec
def _describe_tool_argument(arg: ToolArgumentSpec) -> dict[str, Any]:
    """Flatten one argument spec into a plain, JSON-friendly dict."""
    return {
        "name": arg.name,
        "description": arg.description,
        "required": arg.required,
        "numeric": arg.numeric,
        "boolean": arg.boolean,
        "minimum": arg.minimum,
        "choices": list(arg.choices),
    }


def build_tool_catalog(available_tools: list[str] | None = None) -> list[dict[str, Any]]:
    """Describe the requested tools as plain dicts suitable for embedding in a prompt.

    ``available_tools`` falls back to every known tool when it is ``None`` or
    empty. Each name is resolved through ``get_tool_spec``, so casing and
    separator variants map to canonical names. Unknown names raise
    ``ToolValidationError``.
    """
    requested_names = available_tools if available_tools else KNOWN_TOOL_NAMES
    specs = [get_tool_spec(name) for name in requested_names]
    return [
        {
            "tool_name": spec.tool_name,
            "tier": spec.tier,
            "description": spec.description,
            "read_only": spec.read_only,
            "state_changing": spec.state_changing,
            "arguments": [_describe_tool_argument(arg) for arg in spec.arguments],
        }
        for spec in specs
    ]
def _coerce_argument_value(tool_name: str, arg_spec: ToolArgumentSpec, value: Any) -> Any:
    """Coerce one non-None argument value according to its spec, or raise ToolValidationError."""
    label = f"{tool_name}.{arg_spec.name}"
    if arg_spec.numeric:
        try:
            numeric_value = coerce_numeric_argument(value)
        except (TypeError, ValueError, OverflowError) as exc:
            # OverflowError: float() of a huge int must still fail as a validation error.
            raise ToolValidationError(f"{label} must be numeric.") from exc
        # Written as ``not >=`` so NaN (which compares false to everything) fails closed.
        if arg_spec.minimum is not None and not numeric_value >= arg_spec.minimum:
            raise ToolValidationError(f"{label} must be >= {arg_spec.minimum}.")
        return numeric_value
    if arg_spec.boolean:
        try:
            return coerce_boolean_argument(value)
        except ValueError as exc:
            raise ToolValidationError(f"{label} must be boolean.") from exc
    if arg_spec.choices:
        if not isinstance(value, str):
            raise ToolValidationError(f"{label} must be a string.")
        canonical_choices = {
            normalize_contract_token(choice): choice for choice in arg_spec.choices
        }
        normalized_choice = normalize_contract_token(value)
        if normalized_choice not in canonical_choices:
            raise ToolValidationError(
                f"{label} must be one of: {', '.join(arg_spec.choices)}"
            )
        return canonical_choices[normalized_choice]
    # Free-form argument: passed through unchanged.
    return value


def validate_tool_arguments(
    tool_name: str,
    arguments: dict[str, Any],
    *,
    allowed_tools: list[str] | None = None,
) -> dict[str, Any]:
    """Validate and normalize structured arguments for one tool call.

    Args:
        tool_name: Requested tool. Harmless casing/separator variants are
            canonicalized first.
        arguments: Mapping of argument name to raw value, expected to be a
            JSON object. ``None`` values are treated as absent.
        allowed_tools: Optional allow-list. When given, the canonical tool
            name must appear in it.

    Returns:
        A new dict holding only the supplied arguments:

        - numeric values coerced to ``float``;
        - boolean values coerced to ``bool``;
        - choice values mapped to their canonical spelling;
        - free-form values passed through unchanged.

    Raises:
        ToolValidationError: for any of the following:

        - an unknown or disallowed tool;
        - a non-dict payload;
        - unsupported argument names;
        - a missing required argument;
        - a value that fails coercion or its minimum bound.
    """
    tool_name = canonicalize_tool_name(tool_name, allowed_tools=allowed_tools)
    if allowed_tools is not None and tool_name not in allowed_tools:
        raise ToolValidationError(
            f"Unsupported tool_name '{tool_name}'. Expected one of: {', '.join(allowed_tools)}"
        )
    if not isinstance(arguments, dict):
        raise ToolValidationError("arguments must be a JSON object.")
    spec = get_tool_spec(tool_name)
    supported_names = {arg.name for arg in spec.arguments}
    # str() keeps the report sortable and joinable even if a non-JSON caller
    # passes non-string keys (mixed key types would make sorted()/join() raise).
    unknown_args = sorted(str(key) for key in arguments if key not in supported_names)
    if unknown_args:
        raise ToolValidationError(
            f"{tool_name} received unsupported arguments: {', '.join(unknown_args)}"
        )
    normalized: dict[str, Any] = {}
    for arg_spec in spec.arguments:
        value = normalize_semantic_argument(
            tool_name,
            arg_spec.name,
            arguments.get(arg_spec.name),
        )
        if value is None:
            if arg_spec.required:
                raise ToolValidationError(
                    f"{tool_name} requires argument '{arg_spec.name}'."
                )
            continue
        normalized[arg_spec.name] = _coerce_argument_value(tool_name, arg_spec, value)
    return normalized