AMR-Guard / src /utils.py
ghitaben's picture
Refactor prompt templates and RAG module
ad9e267
"""
Utility functions for clinical calculations and data parsing.
- Creatinine Clearance (CrCl) via Cockcroft-Gault
- MIC trend analysis and creep detection
- Prescription card formatter
- JSON parsing and data normalization helpers
"""
import json
import math
import re
from typing import Any, Dict, List, Literal, Optional, Tuple
# --- CrCl calculator ---
def calculate_crcl(
age_years: float,
weight_kg: float,
serum_creatinine_mg_dl: float,
sex: Literal["male", "female"],
use_ibw: bool = False,
height_cm: Optional[float] = None,
) -> float:
"""
Cockcroft-Gault equation.
CrCl = [(140 - age) × weight × (0.85 if female)] / (72 × SCr)
When use_ibw=True and height is given, uses Ideal Body Weight.
For obese patients (actual > 1.3 × IBW), switches to Adjusted Body Weight.
Returns CrCl in mL/min.
"""
if serum_creatinine_mg_dl <= 0:
raise ValueError("Serum creatinine must be positive")
if age_years <= 0 or weight_kg <= 0:
raise ValueError("Age and weight must be positive")
weight = weight_kg
if use_ibw and height_cm:
ibw = calculate_ibw(height_cm, sex)
weight = calculate_adjusted_bw(ibw, weight_kg) if weight_kg > ibw * 1.3 else ibw
crcl = ((140 - age_years) * weight) / (72 * serum_creatinine_mg_dl)
if sex == "female":
crcl *= 0.85
return round(crcl, 1)
def calculate_ibw(height_cm: float, sex: Literal["male", "female"]) -> float:
"""
Devine formula for Ideal Body Weight.
Male: 50 kg + 2.3 kg per inch over 5 feet
Female: 45.5 kg + 2.3 kg per inch over 5 feet
"""
height_over_60_inches = max(0, height_cm / 2.54 - 60)
base = 50 if sex == "male" else 45.5
return round(base + 2.3 * height_over_60_inches, 1)
def calculate_adjusted_bw(ibw: float, actual_weight: float) -> float:
"""
Adjusted Body Weight for obese patients.
AdjBW = IBW + 0.4 × (Actual - IBW)
"""
return round(ibw + 0.4 * (actual_weight - ibw), 1)
def get_renal_dose_category(crcl: float) -> str:
"""Map CrCl value to a dosing category string."""
if crcl >= 90:
return "normal"
elif crcl >= 60:
return "mild_impairment"
elif crcl >= 30:
return "moderate_impairment"
elif crcl >= 15:
return "severe_impairment"
else:
return "esrd"
# --- MIC trend analysis ---
def calculate_mic_trend(
mic_values: List[Dict[str, Any]],
susceptible_breakpoint: Optional[float] = None,
resistant_breakpoint: Optional[float] = None,
) -> Dict[str, Any]:
"""
Analyze a list of MIC readings over time.
Requires at least 2 readings. Uses linear regression slope for trend
direction when >= 3 points are available; falls back to ratio comparison
for exactly 2 points.
"""
if len(mic_values) < 2:
return {
"trend": "insufficient_data",
"risk_level": "UNKNOWN",
"alert": "Need at least 2 MIC values for trend analysis",
}
mics = [float(v["mic_value"]) for v in mic_values]
baseline_mic = mics[0]
current_mic = mics[-1]
fold_change = (current_mic / baseline_mic) if baseline_mic > 0 else float("inf")
if len(mics) >= 3:
n = len(mics)
x_mean = (n - 1) / 2
y_mean = sum(mics) / n
numerator = sum((i - x_mean) * (mics[i] - y_mean) for i in range(n))
denominator = sum((i - x_mean) ** 2 for i in range(n))
slope = numerator / denominator if denominator != 0 else 0
trend = "increasing" if slope > 0.5 else "decreasing" if slope < -0.5 else "stable"
else:
trend = "increasing" if current_mic > baseline_mic * 1.5 else "decreasing" if current_mic < baseline_mic * 0.67 else "stable"
# Fold change per time step (geometric rate of change)
velocity = fold_change ** (1 / (len(mics) - 1)) if len(mics) > 1 else 1.0
risk_level, alert = _assess_mic_risk(
current_mic, baseline_mic, fold_change, trend,
susceptible_breakpoint, resistant_breakpoint,
)
return {
"baseline_mic": baseline_mic,
"current_mic": current_mic,
"ratio": round(fold_change, 2),
"trend": trend,
"velocity": round(velocity, 3),
"risk_level": risk_level,
"alert": alert,
"n_readings": len(mics),
}
def _assess_mic_risk(
current_mic: float,
baseline_mic: float,
fold_change: float,
trend: str,
s_breakpoint: Optional[float],
r_breakpoint: Optional[float],
) -> Tuple[str, str]:
"""
Assign a risk level (LOW/MODERATE/HIGH/CRITICAL) based on breakpoints and fold change.
Prefers breakpoint-based assessment when breakpoints are available.
Falls back to fold-change thresholds otherwise.
"""
if s_breakpoint is not None and r_breakpoint is not None:
margin = s_breakpoint / current_mic if current_mic > 0 else float("inf")
if current_mic > r_breakpoint:
return "CRITICAL", f"MIC ({current_mic}) exceeds resistant breakpoint ({r_breakpoint}). Organism is RESISTANT."
if current_mic > s_breakpoint:
return "HIGH", f"MIC ({current_mic}) exceeds susceptible breakpoint ({s_breakpoint}). Consider alternative therapy."
if margin < 2:
if trend == "increasing":
return "HIGH", f"MIC approaching breakpoint (margin: {margin:.1f}x) with increasing trend. High risk of resistance emergence."
return "MODERATE", f"MIC close to breakpoint (margin: {margin:.1f}x). Monitor closely."
if margin < 4:
if trend == "increasing":
return "MODERATE", f"MIC rising with {margin:.1f}x margin to breakpoint. Consider enhanced monitoring."
return "LOW", "MIC stable with adequate margin to breakpoint."
return "LOW", "MIC well below breakpoint with good safety margin."
# No breakpoints — use fold change thresholds from EUCAST MIC creep criteria
if fold_change >= 8:
return "CRITICAL", f"MIC increased {fold_change:.1f}-fold from baseline. Urgent review needed."
if fold_change >= 4:
return "HIGH", f"MIC increased {fold_change:.1f}-fold from baseline. High risk of treatment failure."
if fold_change >= 2:
if trend == "increasing":
return "MODERATE", f"MIC increased {fold_change:.1f}-fold with rising trend. Enhanced monitoring recommended."
return "LOW", f"MIC increased {fold_change:.1f}-fold but trend is {trend}."
if trend == "increasing":
return "MODERATE", "MIC showing upward trend. Continue monitoring."
return "LOW", "MIC stable or decreasing. Current therapy appropriate."
def detect_mic_creep(
organism: str,
antibiotic: str,
mic_history: List[Dict[str, Any]],
breakpoints: Dict[str, float],
) -> Dict[str, Any]:
"""
Detect MIC creep for a specific organism-antibiotic pair.
Augments calculate_mic_trend with a time-to-resistance estimate
when the MIC is rising and a susceptible breakpoint is available.
"""
result = calculate_mic_trend(
mic_history,
susceptible_breakpoint=breakpoints.get("susceptible"),
resistant_breakpoint=breakpoints.get("resistant"),
)
result["organism"] = organism
result["antibiotic"] = antibiotic
result["breakpoint_susceptible"] = breakpoints.get("susceptible")
result["breakpoint_resistant"] = breakpoints.get("resistant")
# Estimate how many more time-points until MIC reaches the susceptible breakpoint
if result["trend"] == "increasing" and result["velocity"] > 1.0:
current = result["current_mic"]
s_bp = breakpoints.get("susceptible")
if s_bp and current < s_bp:
doublings_needed = math.log2(s_bp / current) if current > 0 else 0
log_velocity = math.log(result["velocity"]) / math.log(2)
if log_velocity > 0:
result["estimated_readings_to_resistance"] = round(doublings_needed / log_velocity, 1)
return result
# --- Prescription formatter ---
def format_prescription_card(recommendation: Dict[str, Any]) -> str:
"""Format a recommendation dict as a plain-text prescription card."""
lines = []
lines.append("=" * 50)
lines.append("ANTIBIOTIC PRESCRIPTION")
lines.append("=" * 50)
primary = recommendation.get("primary_recommendation", recommendation)
lines.append(f"\nDRUG: {primary.get('antibiotic', 'N/A')}")
lines.append(f"DOSE: {primary.get('dose', 'N/A')}")
lines.append(f"ROUTE: {primary.get('route', 'N/A')}")
lines.append(f"FREQUENCY: {primary.get('frequency', 'N/A')}")
lines.append(f"DURATION: {primary.get('duration', 'N/A')}")
if primary.get("aware_category"):
lines.append(f"WHO AWaRe: {primary.get('aware_category')}")
adjustments = recommendation.get("dose_adjustments", {})
if adjustments.get("renal") and adjustments["renal"] != "None needed":
lines.append(f"\nRENAL ADJUSTMENT: {adjustments['renal']}")
if adjustments.get("hepatic") and adjustments["hepatic"] != "None needed":
lines.append(f"HEPATIC ADJUSTMENT: {adjustments['hepatic']}")
alerts = recommendation.get("safety_alerts", [])
if alerts:
lines.append("\n" + "-" * 50)
lines.append("SAFETY ALERTS:")
for alert in alerts:
level = alert.get("level", "INFO")
marker = {"CRITICAL": "[!!!]", "WARNING": "[!!]", "INFO": "[i]"}.get(level, "[?]")
lines.append(f" {marker} {alert.get('message', '')}")
monitoring = recommendation.get("monitoring_parameters", [])
if monitoring:
lines.append("\n" + "-" * 50)
lines.append("MONITORING:")
for param in monitoring:
lines.append(f" - {param}")
if recommendation.get("rationale"):
lines.append("\n" + "-" * 50)
lines.append("RATIONALE:")
lines.append(f" {recommendation['rationale']}")
lines.append("\n" + "=" * 50)
return "\n".join(lines)
# --- JSON parsing ---
def safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
"""
Extract and parse the first JSON object from a string.
Handles model output that may wrap JSON in markdown code fences.
Returns None if no valid JSON is found.
"""
if not text:
return None
try:
return json.loads(text)
except json.JSONDecodeError:
pass
for pattern in [r"```json\s*\n?(.*?)\n?```", r"```\s*\n?(.*?)\n?```", r"\{[\s\S]*\}"]:
match = re.search(pattern, text, re.DOTALL)
if match:
try:
json_str = match.group(1) if match.lastindex else match.group(0)
return json.loads(json_str)
except (json.JSONDecodeError, IndexError):
continue
return None
def validate_agent_output(output: Dict[str, Any], required_fields: List[str]) -> Tuple[bool, List[str]]:
"""Return (is_valid, missing_fields) for an agent output dict."""
missing = [f for f in required_fields if f not in output]
return len(missing) == 0, missing
# --- Name normalization ---
def normalize_antibiotic_name(name: str) -> str:
"""Map common abbreviations and brand names to standard antibiotic names."""
mappings = {
"amox": "amoxicillin",
"amox/clav": "amoxicillin-clavulanate",
"augmentin": "amoxicillin-clavulanate",
"pip/tazo": "piperacillin-tazobactam",
"zosyn": "piperacillin-tazobactam",
"tmp/smx": "trimethoprim-sulfamethoxazole",
"bactrim": "trimethoprim-sulfamethoxazole",
"cipro": "ciprofloxacin",
"levo": "levofloxacin",
"moxi": "moxifloxacin",
"vanc": "vancomycin",
"vanco": "vancomycin",
"mero": "meropenem",
"imi": "imipenem",
"gent": "gentamicin",
"tobra": "tobramycin",
"ceftriax": "ceftriaxone",
"rocephin": "ceftriaxone",
"cefepime": "cefepime",
"maxipime": "cefepime",
}
return mappings.get(name.lower().strip(), name.lower().strip())
def normalize_organism_name(name: str) -> str:
"""Map common abbreviations to full organism names."""
abbreviations = {
"e. coli": "Escherichia coli",
"e.coli": "Escherichia coli",
"k. pneumoniae": "Klebsiella pneumoniae",
"k.pneumoniae": "Klebsiella pneumoniae",
"p. aeruginosa": "Pseudomonas aeruginosa",
"p.aeruginosa": "Pseudomonas aeruginosa",
"s. aureus": "Staphylococcus aureus",
"s.aureus": "Staphylococcus aureus",
"mrsa": "Staphylococcus aureus (MRSA)",
"mssa": "Staphylococcus aureus (MSSA)",
"enterococcus": "Enterococcus species",
"vre": "Enterococcus (VRE)",
}
return abbreviations.get(name.strip().lower(), name.strip())