Add prompt templates and utility functions for Med-I-C multi-agent system
Browse files- app.py +449 -219
- src/agents.py +572 -14
- src/agents/__init__.py +0 -0
- src/graph.py +300 -0
- src/prompts.py +355 -0
- src/rag.py +482 -0
- src/utils.py +505 -0
app.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
| 1 |
"""
|
| 2 |
Med-I-C: AMR-Guard Demo Application
|
| 3 |
Infection Lifecycle Orchestrator - Streamlit Interface
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import streamlit as st
|
| 7 |
import sys
|
|
|
|
| 8 |
from pathlib import Path
|
| 9 |
|
| 10 |
# Add project root to path
|
|
@@ -12,19 +15,14 @@ PROJECT_ROOT = Path(__file__).parent
|
|
| 12 |
sys.path.insert(0, str(PROJECT_ROOT))
|
| 13 |
|
| 14 |
from src.tools import (
|
| 15 |
-
query_antibiotic_info,
|
| 16 |
-
get_antibiotics_by_category,
|
| 17 |
interpret_mic_value,
|
| 18 |
-
get_breakpoints_for_pathogen,
|
| 19 |
-
query_resistance_pattern,
|
| 20 |
get_most_effective_antibiotics,
|
| 21 |
calculate_mic_trend,
|
| 22 |
-
check_drug_interactions,
|
| 23 |
screen_antibiotic_safety,
|
| 24 |
search_clinical_guidelines,
|
| 25 |
-
get_treatment_recommendation,
|
| 26 |
get_empirical_therapy_guidance,
|
| 27 |
)
|
|
|
|
| 28 |
|
| 29 |
# Page configuration
|
| 30 |
st.set_page_config(
|
|
@@ -48,6 +46,21 @@ st.markdown("""
|
|
| 48 |
color: #666;
|
| 49 |
margin-top: 0;
|
| 50 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
.risk-high {
|
| 52 |
background-color: #FFCDD2;
|
| 53 |
padding: 10px;
|
|
@@ -66,6 +79,13 @@ st.markdown("""
|
|
| 66 |
border-radius: 5px;
|
| 67 |
border-left: 4px solid #388E3C;
|
| 68 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
.info-box {
|
| 70 |
background-color: #E3F2FD;
|
| 71 |
padding: 15px;
|
|
@@ -79,7 +99,7 @@ st.markdown("""
|
|
| 79 |
def main():
|
| 80 |
# Header
|
| 81 |
st.markdown('<p class="main-header">🦠 Med-I-C: AMR-Guard</p>', unsafe_allow_html=True)
|
| 82 |
-
st.markdown('<p class="sub-header">Infection Lifecycle Orchestrator
|
| 83 |
|
| 84 |
# Sidebar navigation
|
| 85 |
st.sidebar.title("Navigation")
|
|
@@ -87,85 +107,416 @@ def main():
|
|
| 87 |
"Select Module",
|
| 88 |
[
|
| 89 |
"🏠 Overview",
|
| 90 |
-
"
|
| 91 |
-
"
|
|
|
|
| 92 |
"📊 MIC Trend Analysis",
|
| 93 |
"⚠️ Drug Safety Check",
|
| 94 |
-
"📚 Clinical Guidelines
|
| 95 |
]
|
| 96 |
)
|
| 97 |
|
| 98 |
if page == "🏠 Overview":
|
| 99 |
show_overview()
|
| 100 |
-
elif page == "
|
|
|
|
|
|
|
| 101 |
show_empirical_advisor()
|
| 102 |
-
elif page == "🔬
|
| 103 |
show_lab_interpretation()
|
| 104 |
elif page == "📊 MIC Trend Analysis":
|
| 105 |
show_mic_trend_analysis()
|
| 106 |
elif page == "⚠️ Drug Safety Check":
|
| 107 |
show_drug_safety()
|
| 108 |
-
elif page == "📚 Clinical Guidelines
|
| 109 |
show_guidelines_search()
|
| 110 |
|
| 111 |
|
| 112 |
def show_overview():
|
| 113 |
st.header("System Overview")
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
col1, col2 = st.columns(2)
|
| 116 |
|
| 117 |
with col1:
|
| 118 |
-
st.subheader("Stage 1: Empirical Phase")
|
| 119 |
st.markdown("""
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
""")
|
| 130 |
|
| 131 |
with col2:
|
| 132 |
-
st.subheader("Stage 2: Targeted Phase")
|
| 133 |
st.markdown("""
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
""")
|
| 142 |
|
| 143 |
st.divider()
|
| 144 |
|
|
|
|
| 145 |
st.subheader("Knowledge Sources")
|
| 146 |
|
| 147 |
col1, col2, col3, col4 = st.columns(4)
|
| 148 |
|
| 149 |
with col1:
|
| 150 |
-
st.metric("WHO
|
| 151 |
with col2:
|
| 152 |
-
st.metric("
|
| 153 |
with col3:
|
| 154 |
-
st.metric("
|
| 155 |
with col4:
|
| 156 |
-
st.metric("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
|
| 159 |
def show_empirical_advisor():
|
| 160 |
-
st.header("💊
|
| 161 |
-
st.markdown("*
|
| 162 |
|
| 163 |
col1, col2 = st.columns([2, 1])
|
| 164 |
|
| 165 |
with col1:
|
| 166 |
infection_type = st.selectbox(
|
| 167 |
"Infection Type",
|
| 168 |
-
["Urinary Tract Infection
|
| 169 |
"Skin/Soft Tissue", "Intra-abdominal", "Meningitis"]
|
| 170 |
)
|
| 171 |
|
|
@@ -182,269 +533,148 @@ def show_empirical_advisor():
|
|
| 182 |
)
|
| 183 |
|
| 184 |
with col2:
|
| 185 |
-
st.markdown("**WHO
|
| 186 |
st.markdown("""
|
| 187 |
- **ACCESS**: First-line, low resistance
|
| 188 |
- **WATCH**: Higher resistance potential
|
| 189 |
- **RESERVE**: Last resort antibiotics
|
| 190 |
""")
|
| 191 |
|
| 192 |
-
if st.button("Get
|
| 193 |
-
with st.spinner("Searching guidelines
|
| 194 |
-
# Get recommendations from guidelines
|
| 195 |
guidance = get_empirical_therapy_guidance(
|
| 196 |
-
infection_type
|
| 197 |
risk_factors
|
| 198 |
)
|
| 199 |
|
| 200 |
-
st.subheader("Recommendations")
|
| 201 |
|
| 202 |
if guidance.get("recommendations"):
|
| 203 |
for i, rec in enumerate(guidance["recommendations"][:3], 1):
|
| 204 |
-
with st.expander(f"
|
| 205 |
st.markdown(rec.get("content", ""))
|
| 206 |
st.caption(f"Source: {rec.get('source', 'IDSA Guidelines')}")
|
| 207 |
|
| 208 |
-
# If pathogen specified, show resistance patterns
|
| 209 |
if suspected_pathogen:
|
| 210 |
-
st.subheader(f"Resistance
|
| 211 |
-
|
| 212 |
effective = get_most_effective_antibiotics(suspected_pathogen, min_susceptibility=70)
|
| 213 |
|
| 214 |
if effective:
|
| 215 |
-
st.markdown("**Most Effective Antibiotics (>70% susceptibility)**")
|
| 216 |
for ab in effective[:5]:
|
| 217 |
st.write(f"- **{ab.get('antibiotic')}**: {ab.get('avg_susceptibility', 0):.1f}% susceptible")
|
| 218 |
else:
|
| 219 |
-
st.info("No resistance data found
|
| 220 |
|
| 221 |
|
| 222 |
def show_lab_interpretation():
|
| 223 |
-
st.header("🔬
|
| 224 |
st.markdown("*Interpret antibiogram MIC values*")
|
| 225 |
|
| 226 |
col1, col2 = st.columns(2)
|
| 227 |
|
| 228 |
with col1:
|
| 229 |
-
pathogen = st.text_input(
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
)
|
| 233 |
-
|
| 234 |
-
antibiotic = st.text_input(
|
| 235 |
-
"Antibiotic",
|
| 236 |
-
placeholder="e.g., Ciprofloxacin, Meropenem"
|
| 237 |
-
)
|
| 238 |
-
|
| 239 |
-
mic_value = st.number_input(
|
| 240 |
-
"MIC Value (mg/L)",
|
| 241 |
-
min_value=0.001,
|
| 242 |
-
max_value=1024.0,
|
| 243 |
-
value=1.0,
|
| 244 |
-
step=0.5
|
| 245 |
-
)
|
| 246 |
|
| 247 |
with col2:
|
| 248 |
-
st.markdown("**
|
| 249 |
st.markdown("""
|
| 250 |
-
- **S
|
| 251 |
-
- **I
|
| 252 |
-
- **R
|
| 253 |
""")
|
| 254 |
|
| 255 |
-
if st.button("Interpret
|
| 256 |
if pathogen and antibiotic:
|
| 257 |
-
|
| 258 |
-
|
| 259 |
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
-
|
| 263 |
-
st.success(f"✅ **{interpretation}**")
|
| 264 |
-
elif interpretation == "RESISTANT":
|
| 265 |
-
st.error(f"❌ **{interpretation}**")
|
| 266 |
-
elif interpretation == "INTERMEDIATE":
|
| 267 |
-
st.warning(f"⚠️ **{interpretation}**")
|
| 268 |
-
else:
|
| 269 |
-
st.info(f"❓ **{interpretation}**")
|
| 270 |
-
|
| 271 |
-
st.markdown(f"**Details:** {result.get('message', '')}")
|
| 272 |
-
|
| 273 |
-
if result.get("breakpoints"):
|
| 274 |
-
bp = result["breakpoints"]
|
| 275 |
-
st.markdown(f"""
|
| 276 |
-
**Breakpoints:**
|
| 277 |
-
- S ≤ {bp.get('susceptible', 'N/A')} mg/L
|
| 278 |
-
- R > {bp.get('resistant', 'N/A')} mg/L
|
| 279 |
-
""")
|
| 280 |
-
|
| 281 |
-
if result.get("notes"):
|
| 282 |
-
st.info(f"**Note:** {result.get('notes')}")
|
| 283 |
-
else:
|
| 284 |
-
st.warning("Please enter both pathogen and antibiotic names.")
|
| 285 |
|
| 286 |
|
| 287 |
def show_mic_trend_analysis():
|
| 288 |
st.header("📊 MIC Trend Analysis")
|
| 289 |
st.markdown("*Detect MIC creep over time*")
|
| 290 |
|
| 291 |
-
st.
|
| 292 |
-
Enter historical MIC values to detect resistance velocity.
|
| 293 |
-
**MIC Creep**: A gradual increase in MIC that may predict treatment failure
|
| 294 |
-
even when the organism is still classified as "Susceptible".
|
| 295 |
-
""")
|
| 296 |
-
|
| 297 |
-
# Input for historical MICs
|
| 298 |
-
num_readings = st.slider("Number of historical readings", 2, 6, 3)
|
| 299 |
|
| 300 |
mic_values = []
|
| 301 |
cols = st.columns(num_readings)
|
| 302 |
|
| 303 |
for i, col in enumerate(cols):
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
f"MIC {i+1}",
|
| 307 |
-
min_value=0.001,
|
| 308 |
-
max_value=256.0,
|
| 309 |
-
value=float(2 ** i), # Default: 1, 2, 4, ...
|
| 310 |
-
key=f"mic_{i}"
|
| 311 |
-
)
|
| 312 |
-
mic_values.append({"date": f"T{i}", "mic_value": mic})
|
| 313 |
|
| 314 |
-
if st.button("Analyze
|
| 315 |
result = calculate_mic_trend(mic_values)
|
| 316 |
-
|
| 317 |
risk_level = result.get("risk_level", "UNKNOWN")
|
| 318 |
|
| 319 |
if risk_level == "HIGH":
|
| 320 |
-
st.markdown(f'<div class="risk-high">
|
| 321 |
-
unsafe_allow_html=True)
|
| 322 |
elif risk_level == "MODERATE":
|
| 323 |
-
st.markdown(f'<div class="risk-moderate">
|
| 324 |
-
unsafe_allow_html=True)
|
| 325 |
else:
|
| 326 |
-
st.markdown(f'<div class="risk-low">
|
| 327 |
-
unsafe_allow_html=True)
|
| 328 |
-
|
| 329 |
-
st.divider()
|
| 330 |
|
| 331 |
col1, col2, col3 = st.columns(3)
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
with col2:
|
| 336 |
-
st.metric("Current MIC", f"{result.get('current_mic', 'N/A')} mg/L")
|
| 337 |
-
with col3:
|
| 338 |
-
st.metric("Fold Change", f"{result.get('ratio', 'N/A')}x")
|
| 339 |
-
|
| 340 |
-
st.markdown(f"**Trend:** {result.get('trend', 'N/A')}")
|
| 341 |
-
st.markdown(f"**Resistance Velocity:** {result.get('velocity', 'N/A')}x per time point")
|
| 342 |
|
| 343 |
|
| 344 |
def show_drug_safety():
|
| 345 |
st.header("⚠️ Drug Safety Check")
|
| 346 |
-
st.markdown("*Screen for drug interactions*")
|
| 347 |
|
| 348 |
col1, col2 = st.columns(2)
|
| 349 |
|
| 350 |
with col1:
|
| 351 |
-
antibiotic = st.text_input(
|
| 352 |
-
|
| 353 |
-
placeholder="e.g., Ciprofloxacin"
|
| 354 |
-
)
|
| 355 |
-
|
| 356 |
-
current_meds = st.text_area(
|
| 357 |
-
"Current Medications (one per line)",
|
| 358 |
-
placeholder="Warfarin\nMetformin\nAmlodipine",
|
| 359 |
-
height=150
|
| 360 |
-
)
|
| 361 |
|
| 362 |
with col2:
|
| 363 |
-
allergies = st.text_area(
|
| 364 |
-
"Known Allergies (one per line)",
|
| 365 |
-
placeholder="Penicillin\nSulfa",
|
| 366 |
-
height=100
|
| 367 |
-
)
|
| 368 |
|
| 369 |
if st.button("Check Safety", type="primary"):
|
| 370 |
if antibiotic:
|
| 371 |
medications = [m.strip() for m in current_meds.split("\n") if m.strip()]
|
| 372 |
allergy_list = [a.strip() for a in allergies.split("\n") if a.strip()]
|
| 373 |
|
| 374 |
-
|
| 375 |
-
result = screen_antibiotic_safety(antibiotic, medications, allergy_list)
|
| 376 |
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
st.subheader("Alerts")
|
| 385 |
-
for alert in result["alerts"]:
|
| 386 |
-
level = alert.get("level", "WARNING")
|
| 387 |
-
if level == "CRITICAL":
|
| 388 |
-
st.error(f"🚨 {alert.get('message', '')}")
|
| 389 |
-
else:
|
| 390 |
-
st.warning(f"⚠️ {alert.get('message', '')}")
|
| 391 |
-
|
| 392 |
-
# Show allergy warnings
|
| 393 |
-
if result.get("allergy_warnings"):
|
| 394 |
-
st.subheader("Allergy Warnings")
|
| 395 |
-
for warn in result["allergy_warnings"]:
|
| 396 |
-
st.error(f"🚫 {warn.get('message', '')}")
|
| 397 |
-
|
| 398 |
-
# Show interactions
|
| 399 |
-
if result.get("interactions"):
|
| 400 |
-
st.subheader("Drug Interactions Found")
|
| 401 |
-
for interaction in result["interactions"][:5]:
|
| 402 |
-
severity = interaction.get("severity", "unknown")
|
| 403 |
-
icon = "🔴" if severity == "major" else "🟡" if severity == "moderate" else "🟢"
|
| 404 |
-
st.markdown(f"""
|
| 405 |
-
{icon} **{interaction.get('drug_1')}** ↔ **{interaction.get('drug_2')}**
|
| 406 |
-
- Severity: {severity.upper()}
|
| 407 |
-
- {interaction.get('interaction_description', '')}
|
| 408 |
-
""")
|
| 409 |
-
else:
|
| 410 |
-
st.warning("Please enter an antibiotic name.")
|
| 411 |
|
| 412 |
|
| 413 |
def show_guidelines_search():
|
| 414 |
-
st.header("📚 Clinical Guidelines
|
| 415 |
-
st.markdown("*Search IDSA treatment guidelines*")
|
| 416 |
|
| 417 |
-
query = st.text_input(
|
| 418 |
-
|
| 419 |
-
placeholder="e.g., treatment for ESBL E. coli UTI"
|
| 420 |
-
)
|
| 421 |
|
| 422 |
-
|
| 423 |
-
"Filter by Pathogen Type (optional)",
|
| 424 |
-
["All", "ESBL-E", "CRE", "CRAB", "DTR-PA", "S.maltophilia", "AmpC-E"]
|
| 425 |
-
)
|
| 426 |
-
|
| 427 |
-
if st.button("Search Guidelines", type="primary"):
|
| 428 |
if query:
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
with st.expander(
|
| 439 |
-
f"Result {i} - {result.get('pathogen_type', 'General')} "
|
| 440 |
-
f"(Relevance: {result.get('relevance_score', 0):.2f})"
|
| 441 |
-
):
|
| 442 |
-
st.markdown(result.get("content", ""))
|
| 443 |
-
st.caption(f"Source: {result.get('source', 'IDSA Guidelines')}")
|
| 444 |
-
else:
|
| 445 |
-
st.info("No results found. Try a different query or remove the filter.")
|
| 446 |
-
else:
|
| 447 |
-
st.warning("Please enter a search query.")
|
| 448 |
|
| 449 |
|
| 450 |
if __name__ == "__main__":
|
|
|
|
| 1 |
"""
|
| 2 |
Med-I-C: AMR-Guard Demo Application
|
| 3 |
Infection Lifecycle Orchestrator - Streamlit Interface
|
| 4 |
+
|
| 5 |
+
Multi-Agent Architecture powered by MedGemma via LangGraph
|
| 6 |
"""
|
| 7 |
|
| 8 |
import streamlit as st
|
| 9 |
import sys
|
| 10 |
+
import json
|
| 11 |
from pathlib import Path
|
| 12 |
|
| 13 |
# Add project root to path
|
|
|
|
| 15 |
sys.path.insert(0, str(PROJECT_ROOT))
|
| 16 |
|
| 17 |
from src.tools import (
|
|
|
|
|
|
|
| 18 |
interpret_mic_value,
|
|
|
|
|
|
|
| 19 |
get_most_effective_antibiotics,
|
| 20 |
calculate_mic_trend,
|
|
|
|
| 21 |
screen_antibiotic_safety,
|
| 22 |
search_clinical_guidelines,
|
|
|
|
| 23 |
get_empirical_therapy_guidance,
|
| 24 |
)
|
| 25 |
+
from src.utils import format_prescription_card
|
| 26 |
|
| 27 |
# Page configuration
|
| 28 |
st.set_page_config(
|
|
|
|
| 46 |
color: #666;
|
| 47 |
margin-top: 0;
|
| 48 |
}
|
| 49 |
+
.agent-card {
|
| 50 |
+
background-color: #F5F5F5;
|
| 51 |
+
padding: 15px;
|
| 52 |
+
border-radius: 8px;
|
| 53 |
+
margin: 10px 0;
|
| 54 |
+
border-left: 4px solid #1E88E5;
|
| 55 |
+
}
|
| 56 |
+
.agent-active {
|
| 57 |
+
border-left-color: #4CAF50;
|
| 58 |
+
background-color: #E8F5E9;
|
| 59 |
+
}
|
| 60 |
+
.agent-complete {
|
| 61 |
+
border-left-color: #9E9E9E;
|
| 62 |
+
background-color: #FAFAFA;
|
| 63 |
+
}
|
| 64 |
.risk-high {
|
| 65 |
background-color: #FFCDD2;
|
| 66 |
padding: 10px;
|
|
|
|
| 79 |
border-radius: 5px;
|
| 80 |
border-left: 4px solid #388E3C;
|
| 81 |
}
|
| 82 |
+
.prescription-card {
|
| 83 |
+
background-color: #E3F2FD;
|
| 84 |
+
padding: 20px;
|
| 85 |
+
border-radius: 10px;
|
| 86 |
+
font-family: monospace;
|
| 87 |
+
white-space: pre-wrap;
|
| 88 |
+
}
|
| 89 |
.info-box {
|
| 90 |
background-color: #E3F2FD;
|
| 91 |
padding: 15px;
|
|
|
|
| 99 |
def main():
|
| 100 |
# Header
|
| 101 |
st.markdown('<p class="main-header">🦠 Med-I-C: AMR-Guard</p>', unsafe_allow_html=True)
|
| 102 |
+
st.markdown('<p class="sub-header">Infection Lifecycle Orchestrator - Multi-Agent System</p>', unsafe_allow_html=True)
|
| 103 |
|
| 104 |
# Sidebar navigation
|
| 105 |
st.sidebar.title("Navigation")
|
|
|
|
| 107 |
"Select Module",
|
| 108 |
[
|
| 109 |
"🏠 Overview",
|
| 110 |
+
"🤖 Agent Pipeline",
|
| 111 |
+
"💊 Empirical Advisor",
|
| 112 |
+
"🔬 Lab Interpretation",
|
| 113 |
"📊 MIC Trend Analysis",
|
| 114 |
"⚠️ Drug Safety Check",
|
| 115 |
+
"📚 Clinical Guidelines"
|
| 116 |
]
|
| 117 |
)
|
| 118 |
|
| 119 |
if page == "🏠 Overview":
|
| 120 |
show_overview()
|
| 121 |
+
elif page == "🤖 Agent Pipeline":
|
| 122 |
+
show_agent_pipeline()
|
| 123 |
+
elif page == "💊 Empirical Advisor":
|
| 124 |
show_empirical_advisor()
|
| 125 |
+
elif page == "🔬 Lab Interpretation":
|
| 126 |
show_lab_interpretation()
|
| 127 |
elif page == "📊 MIC Trend Analysis":
|
| 128 |
show_mic_trend_analysis()
|
| 129 |
elif page == "⚠️ Drug Safety Check":
|
| 130 |
show_drug_safety()
|
| 131 |
+
elif page == "📚 Clinical Guidelines":
|
| 132 |
show_guidelines_search()
|
| 133 |
|
| 134 |
|
| 135 |
def show_overview():
|
| 136 |
st.header("System Overview")
|
| 137 |
|
| 138 |
+
st.markdown("""
|
| 139 |
+
**AMR-Guard** is a multi-agent AI system that orchestrates the complete infection treatment lifecycle,
|
| 140 |
+
from initial empirical therapy to targeted treatment based on lab results.
|
| 141 |
+
""")
|
| 142 |
+
|
| 143 |
+
# Architecture diagram
|
| 144 |
+
st.subheader("Multi-Agent Architecture")
|
| 145 |
+
|
| 146 |
col1, col2 = st.columns(2)
|
| 147 |
|
| 148 |
with col1:
|
|
|
|
| 149 |
st.markdown("""
|
| 150 |
+
### Stage 1: Empirical Phase
|
| 151 |
+
**Path:** Agent 1 → Agent 4
|
| 152 |
+
|
| 153 |
+
*Before lab results are available*
|
| 154 |
+
|
| 155 |
+
1. **Intake Historian** (Agent 1)
|
| 156 |
+
- Parses patient demographics & history
|
| 157 |
+
- Calculates CrCl for renal dosing
|
| 158 |
+
- Identifies risk factors for MDR
|
| 159 |
+
|
| 160 |
+
2. **Clinical Pharmacologist** (Agent 4)
|
| 161 |
+
- Recommends empirical antibiotics
|
| 162 |
+
- Applies WHO AWaRe principles
|
| 163 |
+
- Performs safety checks
|
| 164 |
""")
|
| 165 |
|
| 166 |
with col2:
|
|
|
|
| 167 |
st.markdown("""
|
| 168 |
+
### Stage 2: Targeted Phase
|
| 169 |
+
**Path:** Agent 1 → Agent 2 → Agent 3 → Agent 4
|
| 170 |
+
|
| 171 |
+
*When lab/culture results are available*
|
| 172 |
+
|
| 173 |
+
1. **Intake Historian** (Agent 1)
|
| 174 |
+
2. **Vision Specialist** (Agent 2)
|
| 175 |
+
- Extracts data from lab reports
|
| 176 |
+
- Supports any language/format
|
| 177 |
+
3. **Trend Analyst** (Agent 3)
|
| 178 |
+
- Detects MIC creep patterns
|
| 179 |
+
- Calculates resistance velocity
|
| 180 |
+
4. **Clinical Pharmacologist** (Agent 4)
|
| 181 |
""")
|
| 182 |
|
| 183 |
st.divider()
|
| 184 |
|
| 185 |
+
# Knowledge sources
|
| 186 |
st.subheader("Knowledge Sources")
|
| 187 |
|
| 188 |
col1, col2, col3, col4 = st.columns(4)
|
| 189 |
|
| 190 |
with col1:
|
| 191 |
+
st.metric("WHO AWaRe", "264", "antibiotics classified")
|
| 192 |
with col2:
|
| 193 |
+
st.metric("EUCAST", "v16.0", "breakpoint tables")
|
| 194 |
with col3:
|
| 195 |
+
st.metric("IDSA", "2024", "treatment guidelines")
|
| 196 |
with col4:
|
| 197 |
+
st.metric("DDInter", "191K+", "drug interactions")
|
| 198 |
+
|
| 199 |
+
# Model info
|
| 200 |
+
st.subheader("AI Models")
|
| 201 |
+
st.markdown("""
|
| 202 |
+
| Agent | Primary Model | Fallback |
|
| 203 |
+
|-------|---------------|----------|
|
| 204 |
+
| Intake Historian | MedGemma 4B IT | Vertex AI API |
|
| 205 |
+
| Vision Specialist | MedGemma 4B IT (multimodal) | Vertex AI API |
|
| 206 |
+
| Trend Analyst | MedGemma 4B IT | Vertex AI API |
|
| 207 |
+
| Clinical Pharmacologist | MedGemma 4B + TxGemma 2B (safety) | Vertex AI API |
|
| 208 |
+
""")
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def show_agent_pipeline():
|
| 212 |
+
st.header("🤖 Multi-Agent Pipeline")
|
| 213 |
+
st.markdown("*Run the complete infection lifecycle workflow*")
|
| 214 |
+
|
| 215 |
+
# Initialize session state
|
| 216 |
+
if "pipeline_result" not in st.session_state:
|
| 217 |
+
st.session_state.pipeline_result = None
|
| 218 |
+
|
| 219 |
+
# Patient Information Form
|
| 220 |
+
with st.expander("Patient Information", expanded=True):
|
| 221 |
+
col1, col2, col3 = st.columns(3)
|
| 222 |
+
|
| 223 |
+
with col1:
|
| 224 |
+
age = st.number_input("Age (years)", min_value=0, max_value=120, value=65)
|
| 225 |
+
weight = st.number_input("Weight (kg)", min_value=1.0, max_value=300.0, value=70.0)
|
| 226 |
+
height = st.number_input("Height (cm)", min_value=50.0, max_value=250.0, value=170.0)
|
| 227 |
+
|
| 228 |
+
with col2:
|
| 229 |
+
sex = st.selectbox("Sex", ["male", "female"])
|
| 230 |
+
creatinine = st.number_input("Serum Creatinine (mg/dL)", min_value=0.1, max_value=20.0, value=1.2)
|
| 231 |
+
|
| 232 |
+
with col3:
|
| 233 |
+
infection_site = st.selectbox(
|
| 234 |
+
"Infection Site",
|
| 235 |
+
["urinary", "respiratory", "bloodstream", "skin", "intra-abdominal", "CNS", "other"]
|
| 236 |
+
)
|
| 237 |
+
suspected_source = st.text_input(
|
| 238 |
+
"Suspected Source",
|
| 239 |
+
placeholder="e.g., community UTI, hospital-acquired pneumonia"
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
with st.expander("Medical History"):
|
| 243 |
+
col1, col2 = st.columns(2)
|
| 244 |
+
|
| 245 |
+
with col1:
|
| 246 |
+
medications = st.text_area(
|
| 247 |
+
"Current Medications (one per line)",
|
| 248 |
+
placeholder="Metformin\nLisinopril\nAspirin",
|
| 249 |
+
height=100
|
| 250 |
+
)
|
| 251 |
+
allergies = st.text_area(
|
| 252 |
+
"Allergies (one per line)",
|
| 253 |
+
placeholder="Penicillin\nSulfa",
|
| 254 |
+
height=100
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
with col2:
|
| 258 |
+
comorbidities = st.multiselect(
|
| 259 |
+
"Comorbidities",
|
| 260 |
+
["Diabetes", "CKD", "Heart Failure", "COPD", "Immunocompromised",
|
| 261 |
+
"Recent Surgery", "Malignancy", "Liver Disease"]
|
| 262 |
+
)
|
| 263 |
+
risk_factors = st.multiselect(
|
| 264 |
+
"MDR Risk Factors",
|
| 265 |
+
["Prior MRSA infection", "Recent antibiotic use (<90 days)",
|
| 266 |
+
"Healthcare-associated", "Recent hospitalization",
|
| 267 |
+
"Nursing home resident", "Prior MDR infection"]
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
# Lab Data (Optional - triggers Stage 2)
|
| 271 |
+
with st.expander("Lab Results (Optional - triggers targeted pathway)"):
|
| 272 |
+
lab_input_method = st.radio(
|
| 273 |
+
"Input Method",
|
| 274 |
+
["None (Empirical only)", "Paste Lab Text", "Upload File"],
|
| 275 |
+
horizontal=True
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
labs_raw_text = None
|
| 279 |
+
|
| 280 |
+
if lab_input_method == "Paste Lab Text":
|
| 281 |
+
labs_raw_text = st.text_area(
|
| 282 |
+
"Lab Report Text",
|
| 283 |
+
placeholder="""Example:
|
| 284 |
+
Culture: Urine
|
| 285 |
+
Organism: Escherichia coli
|
| 286 |
+
Colony Count: >100,000 CFU/mL
|
| 287 |
+
|
| 288 |
+
Susceptibility:
|
| 289 |
+
Ampicillin: R (MIC >32)
|
| 290 |
+
Ciprofloxacin: S (MIC 0.25)
|
| 291 |
+
Nitrofurantoin: S (MIC 16)
|
| 292 |
+
Trimethoprim-Sulfamethoxazole: R (MIC >4)""",
|
| 293 |
+
height=200
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
elif lab_input_method == "Upload File":
|
| 297 |
+
uploaded_file = st.file_uploader(
|
| 298 |
+
"Upload Lab Report (PDF or Image)",
|
| 299 |
+
type=["pdf", "png", "jpg", "jpeg"]
|
| 300 |
+
)
|
| 301 |
+
if uploaded_file:
|
| 302 |
+
st.info("File uploaded. Text extraction will be performed by the Vision Specialist agent.")
|
| 303 |
+
# In production, would extract text here
|
| 304 |
+
labs_raw_text = f"[Uploaded file: {uploaded_file.name}]"
|
| 305 |
+
|
| 306 |
+
# Run Pipeline Button
|
| 307 |
+
st.divider()
|
| 308 |
+
|
| 309 |
+
col1, col2, col3 = st.columns([1, 2, 1])
|
| 310 |
+
with col2:
|
| 311 |
+
run_pipeline_btn = st.button(
|
| 312 |
+
"🚀 Run Agent Pipeline",
|
| 313 |
+
type="primary",
|
| 314 |
+
use_container_width=True
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
if run_pipeline_btn:
|
| 318 |
+
# Build patient data
|
| 319 |
+
patient_data = {
|
| 320 |
+
"age_years": age,
|
| 321 |
+
"weight_kg": weight,
|
| 322 |
+
"height_cm": height,
|
| 323 |
+
"sex": sex,
|
| 324 |
+
"serum_creatinine_mg_dl": creatinine,
|
| 325 |
+
"infection_site": infection_site,
|
| 326 |
+
"suspected_source": suspected_source or f"{infection_site} infection",
|
| 327 |
+
"medications": [m.strip() for m in medications.split("\n") if m.strip()],
|
| 328 |
+
"allergies": [a.strip() for a in allergies.split("\n") if a.strip()],
|
| 329 |
+
"comorbidities": list(comorbidities) + list(risk_factors),
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
# Show pipeline progress
|
| 333 |
+
st.subheader("Pipeline Execution")
|
| 334 |
+
|
| 335 |
+
# Agent progress indicators
|
| 336 |
+
agents = [
|
| 337 |
+
("Intake Historian", "Analyzing patient data..."),
|
| 338 |
+
("Vision Specialist", "Processing lab results...") if labs_raw_text else None,
|
| 339 |
+
("Trend Analyst", "Analyzing MIC trends...") if labs_raw_text else None,
|
| 340 |
+
("Clinical Pharmacologist", "Generating recommendations..."),
|
| 341 |
+
]
|
| 342 |
+
agents = [a for a in agents if a is not None]
|
| 343 |
+
|
| 344 |
+
progress_bar = st.progress(0)
|
| 345 |
+
status_text = st.empty()
|
| 346 |
+
|
| 347 |
+
# Simulate pipeline execution (in production, would call actual pipeline)
|
| 348 |
+
try:
|
| 349 |
+
# Try to import and run the actual pipeline
|
| 350 |
+
from src.graph import run_pipeline
|
| 351 |
+
|
| 352 |
+
for i, (agent_name, status_msg) in enumerate(agents):
|
| 353 |
+
status_text.text(f"Agent {i+1}/{len(agents)}: {agent_name} - {status_msg}")
|
| 354 |
+
progress_bar.progress((i + 1) / len(agents))
|
| 355 |
+
|
| 356 |
+
# Run the actual pipeline
|
| 357 |
+
result = run_pipeline(patient_data, labs_raw_text)
|
| 358 |
+
st.session_state.pipeline_result = result
|
| 359 |
+
|
| 360 |
+
except Exception as e:
|
| 361 |
+
st.error(f"Pipeline execution error: {e}")
|
| 362 |
+
st.info("Running in demo mode with simulated output...")
|
| 363 |
+
|
| 364 |
+
# Demo mode - simulate results
|
| 365 |
+
st.session_state.pipeline_result = _generate_demo_result(patient_data, labs_raw_text)
|
| 366 |
+
|
| 367 |
+
progress_bar.progress(100)
|
| 368 |
+
status_text.text("Pipeline complete!")
|
| 369 |
+
|
| 370 |
+
# Display Results
|
| 371 |
+
if st.session_state.pipeline_result:
|
| 372 |
+
result = st.session_state.pipeline_result
|
| 373 |
+
|
| 374 |
+
st.divider()
|
| 375 |
+
st.subheader("Pipeline Results")
|
| 376 |
+
|
| 377 |
+
# Tabs for different result sections
|
| 378 |
+
tab1, tab2, tab3, tab4 = st.tabs([
|
| 379 |
+
"📋 Recommendation",
|
| 380 |
+
"👤 Patient Summary",
|
| 381 |
+
"🔬 Lab Analysis",
|
| 382 |
+
"⚠️ Safety Alerts"
|
| 383 |
+
])
|
| 384 |
+
|
| 385 |
+
with tab1:
|
| 386 |
+
rec = result.get("recommendation", {})
|
| 387 |
+
if rec:
|
| 388 |
+
st.markdown("### Antibiotic Recommendation")
|
| 389 |
+
|
| 390 |
+
col1, col2 = st.columns(2)
|
| 391 |
+
|
| 392 |
+
with col1:
|
| 393 |
+
st.markdown(f"**Primary:** {rec.get('primary_antibiotic', 'N/A')}")
|
| 394 |
+
st.markdown(f"**Dose:** {rec.get('dose', 'N/A')}")
|
| 395 |
+
st.markdown(f"**Route:** {rec.get('route', 'N/A')}")
|
| 396 |
+
st.markdown(f"**Frequency:** {rec.get('frequency', 'N/A')}")
|
| 397 |
+
st.markdown(f"**Duration:** {rec.get('duration', 'N/A')}")
|
| 398 |
+
|
| 399 |
+
with col2:
|
| 400 |
+
if rec.get("backup_antibiotic"):
|
| 401 |
+
st.markdown(f"**Alternative:** {rec.get('backup_antibiotic')}")
|
| 402 |
+
|
| 403 |
+
st.markdown("---")
|
| 404 |
+
st.markdown("**Rationale:**")
|
| 405 |
+
st.markdown(rec.get("rationale", "No rationale provided"))
|
| 406 |
+
|
| 407 |
+
if rec.get("references"):
|
| 408 |
+
st.markdown("**References:**")
|
| 409 |
+
for ref in rec["references"]:
|
| 410 |
+
st.markdown(f"- {ref}")
|
| 411 |
+
|
| 412 |
+
with tab2:
|
| 413 |
+
st.markdown("### Patient Assessment")
|
| 414 |
+
intake_notes = result.get("intake_notes", "")
|
| 415 |
+
if intake_notes:
|
| 416 |
+
try:
|
| 417 |
+
intake_data = json.loads(intake_notes) if isinstance(intake_notes, str) else intake_notes
|
| 418 |
+
st.json(intake_data)
|
| 419 |
+
except:
|
| 420 |
+
st.text(intake_notes)
|
| 421 |
+
|
| 422 |
+
if result.get("creatinine_clearance_ml_min"):
|
| 423 |
+
st.metric("Calculated CrCl", f"{result['creatinine_clearance_ml_min']} mL/min")
|
| 424 |
+
|
| 425 |
+
with tab3:
|
| 426 |
+
st.markdown("### Laboratory Analysis")
|
| 427 |
+
|
| 428 |
+
vision_notes = result.get("vision_notes", "No lab data processed")
|
| 429 |
+
if vision_notes and vision_notes != "No lab data provided":
|
| 430 |
+
try:
|
| 431 |
+
vision_data = json.loads(vision_notes) if isinstance(vision_notes, str) else vision_notes
|
| 432 |
+
st.json(vision_data)
|
| 433 |
+
except:
|
| 434 |
+
st.text(vision_notes)
|
| 435 |
+
|
| 436 |
+
trend_notes = result.get("trend_notes", "")
|
| 437 |
+
if trend_notes and trend_notes != "No MIC data available for trend analysis":
|
| 438 |
+
st.markdown("#### MIC Trend Analysis")
|
| 439 |
+
try:
|
| 440 |
+
trend_data = json.loads(trend_notes) if isinstance(trend_notes, str) else trend_notes
|
| 441 |
+
st.json(trend_data)
|
| 442 |
+
except:
|
| 443 |
+
st.text(trend_notes)
|
| 444 |
+
|
| 445 |
+
with tab4:
|
| 446 |
+
st.markdown("### Safety Alerts")
|
| 447 |
+
|
| 448 |
+
warnings = result.get("safety_warnings", [])
|
| 449 |
+
if warnings:
|
| 450 |
+
for warning in warnings:
|
| 451 |
+
st.warning(f"⚠️ {warning}")
|
| 452 |
+
else:
|
| 453 |
+
st.success("No safety concerns identified")
|
| 454 |
+
|
| 455 |
+
errors = result.get("errors", [])
|
| 456 |
+
if errors:
|
| 457 |
+
st.markdown("#### Errors")
|
| 458 |
+
for error in errors:
|
| 459 |
+
st.error(error)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def _generate_demo_result(patient_data: dict, labs_raw_text: str | None) -> dict:
|
| 463 |
+
"""Generate demo result when actual pipeline is not available."""
|
| 464 |
+
result = {
|
| 465 |
+
"stage": "targeted" if labs_raw_text else "empirical",
|
| 466 |
+
"creatinine_clearance_ml_min": 58.3,
|
| 467 |
+
"intake_notes": json.dumps({
|
| 468 |
+
"patient_summary": f"65-year-old male with {patient_data.get('suspected_source', 'infection')}",
|
| 469 |
+
"creatinine_clearance_ml_min": 58.3,
|
| 470 |
+
"renal_dose_adjustment_needed": True,
|
| 471 |
+
"identified_risk_factors": patient_data.get("comorbidities", []),
|
| 472 |
+
"infection_severity": "moderate",
|
| 473 |
+
"recommended_stage": "targeted" if labs_raw_text else "empirical",
|
| 474 |
+
}),
|
| 475 |
+
"recommendation": {
|
| 476 |
+
"primary_antibiotic": "Ciprofloxacin",
|
| 477 |
+
"dose": "500mg",
|
| 478 |
+
"route": "PO",
|
| 479 |
+
"frequency": "Every 12 hours",
|
| 480 |
+
"duration": "7 days",
|
| 481 |
+
"backup_antibiotic": "Nitrofurantoin",
|
| 482 |
+
"rationale": "Based on suspected community-acquired UTI with moderate renal impairment. Ciprofloxacin provides good coverage for common uropathogens. Dose adjusted for CrCl 58 mL/min.",
|
| 483 |
+
"references": ["IDSA UTI Guidelines 2024", "EUCAST Breakpoint Tables v16.0"],
|
| 484 |
+
"safety_alerts": [],
|
| 485 |
+
},
|
| 486 |
+
"safety_warnings": [],
|
| 487 |
+
"errors": [],
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
if labs_raw_text:
|
| 491 |
+
result["vision_notes"] = json.dumps({
|
| 492 |
+
"specimen_type": "urine",
|
| 493 |
+
"identified_organisms": [{"organism_name": "Escherichia coli", "significance": "pathogen"}],
|
| 494 |
+
"susceptibility_results": [
|
| 495 |
+
{"organism": "E. coli", "antibiotic": "Ciprofloxacin", "mic_value": 0.25, "interpretation": "S"},
|
| 496 |
+
{"organism": "E. coli", "antibiotic": "Nitrofurantoin", "mic_value": 16, "interpretation": "S"},
|
| 497 |
+
],
|
| 498 |
+
"extraction_confidence": 0.95,
|
| 499 |
+
})
|
| 500 |
+
result["trend_notes"] = json.dumps([{
|
| 501 |
+
"organism": "E. coli",
|
| 502 |
+
"antibiotic": "Ciprofloxacin",
|
| 503 |
+
"risk_level": "LOW",
|
| 504 |
+
"recommendation": "Continue current therapy",
|
| 505 |
+
}])
|
| 506 |
+
|
| 507 |
+
return result
|
| 508 |
|
| 509 |
|
| 510 |
def show_empirical_advisor():
|
| 511 |
+
st.header("💊 Empirical Advisor")
|
| 512 |
+
st.markdown("*Get empirical therapy recommendations before lab results*")
|
| 513 |
|
| 514 |
col1, col2 = st.columns([2, 1])
|
| 515 |
|
| 516 |
with col1:
|
| 517 |
infection_type = st.selectbox(
|
| 518 |
"Infection Type",
|
| 519 |
+
["Urinary Tract Infection", "Pneumonia", "Sepsis",
|
| 520 |
"Skin/Soft Tissue", "Intra-abdominal", "Meningitis"]
|
| 521 |
)
|
| 522 |
|
|
|
|
| 533 |
)
|
| 534 |
|
| 535 |
with col2:
|
| 536 |
+
st.markdown("**WHO AWaRe Categories**")
|
| 537 |
st.markdown("""
|
| 538 |
- **ACCESS**: First-line, low resistance
|
| 539 |
- **WATCH**: Higher resistance potential
|
| 540 |
- **RESERVE**: Last resort antibiotics
|
| 541 |
""")
|
| 542 |
|
| 543 |
+
if st.button("Get Recommendation", type="primary"):
|
| 544 |
+
with st.spinner("Searching guidelines..."):
|
|
|
|
| 545 |
guidance = get_empirical_therapy_guidance(
|
| 546 |
+
infection_type,
|
| 547 |
risk_factors
|
| 548 |
)
|
| 549 |
|
| 550 |
+
st.subheader("Guideline Recommendations")
|
| 551 |
|
| 552 |
if guidance.get("recommendations"):
|
| 553 |
for i, rec in enumerate(guidance["recommendations"][:3], 1):
|
| 554 |
+
with st.expander(f"Excerpt {i} (Relevance: {rec.get('relevance_score', 0):.2f})"):
|
| 555 |
st.markdown(rec.get("content", ""))
|
| 556 |
st.caption(f"Source: {rec.get('source', 'IDSA Guidelines')}")
|
| 557 |
|
|
|
|
| 558 |
if suspected_pathogen:
|
| 559 |
+
st.subheader(f"Resistance Data: {suspected_pathogen}")
|
|
|
|
| 560 |
effective = get_most_effective_antibiotics(suspected_pathogen, min_susceptibility=70)
|
| 561 |
|
| 562 |
if effective:
|
|
|
|
| 563 |
for ab in effective[:5]:
|
| 564 |
st.write(f"- **{ab.get('antibiotic')}**: {ab.get('avg_susceptibility', 0):.1f}% susceptible")
|
| 565 |
else:
|
| 566 |
+
st.info("No resistance data found.")
|
| 567 |
|
| 568 |
|
| 569 |
def show_lab_interpretation():
|
| 570 |
+
st.header("🔬 Lab Interpretation")
|
| 571 |
st.markdown("*Interpret antibiogram MIC values*")
|
| 572 |
|
| 573 |
col1, col2 = st.columns(2)
|
| 574 |
|
| 575 |
with col1:
|
| 576 |
+
pathogen = st.text_input("Pathogen", placeholder="e.g., Escherichia coli")
|
| 577 |
+
antibiotic = st.text_input("Antibiotic", placeholder="e.g., Ciprofloxacin")
|
| 578 |
+
mic_value = st.number_input("MIC (mg/L)", min_value=0.001, max_value=1024.0, value=1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
|
| 580 |
with col2:
|
| 581 |
+
st.markdown("**Interpretation Guide**")
|
| 582 |
st.markdown("""
|
| 583 |
+
- **S**: Susceptible - antibiotic effective
|
| 584 |
+
- **I**: Intermediate - may work at higher doses
|
| 585 |
+
- **R**: Resistant - do not use
|
| 586 |
""")
|
| 587 |
|
| 588 |
+
if st.button("Interpret", type="primary"):
|
| 589 |
if pathogen and antibiotic:
|
| 590 |
+
result = interpret_mic_value(pathogen, antibiotic, mic_value)
|
| 591 |
+
interpretation = result.get("interpretation", "UNKNOWN")
|
| 592 |
|
| 593 |
+
if interpretation == "SUSCEPTIBLE":
|
| 594 |
+
st.success(f"✅ {interpretation}")
|
| 595 |
+
elif interpretation == "RESISTANT":
|
| 596 |
+
st.error(f"❌ {interpretation}")
|
| 597 |
+
else:
|
| 598 |
+
st.warning(f"⚠️ {interpretation}")
|
| 599 |
|
| 600 |
+
st.markdown(f"**Details:** {result.get('message', '')}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
|
| 602 |
|
| 603 |
def show_mic_trend_analysis():
|
| 604 |
st.header("📊 MIC Trend Analysis")
|
| 605 |
st.markdown("*Detect MIC creep over time*")
|
| 606 |
|
| 607 |
+
num_readings = st.slider("Historical readings", 2, 6, 3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
|
| 609 |
mic_values = []
|
| 610 |
cols = st.columns(num_readings)
|
| 611 |
|
| 612 |
for i, col in enumerate(cols):
|
| 613 |
+
mic = col.number_input(f"MIC {i+1}", min_value=0.001, max_value=256.0, value=float(2 ** i), key=f"mic_{i}")
|
| 614 |
+
mic_values.append({"date": f"T{i}", "mic_value": mic})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 615 |
|
| 616 |
+
if st.button("Analyze", type="primary"):
|
| 617 |
result = calculate_mic_trend(mic_values)
|
|
|
|
| 618 |
risk_level = result.get("risk_level", "UNKNOWN")
|
| 619 |
|
| 620 |
if risk_level == "HIGH":
|
| 621 |
+
st.markdown(f'<div class="risk-high">🚨 HIGH RISK: {result.get("alert", "")}</div>', unsafe_allow_html=True)
|
|
|
|
| 622 |
elif risk_level == "MODERATE":
|
| 623 |
+
st.markdown(f'<div class="risk-moderate">⚠️ MODERATE: {result.get("alert", "")}</div>', unsafe_allow_html=True)
|
|
|
|
| 624 |
else:
|
| 625 |
+
st.markdown(f'<div class="risk-low">✅ LOW RISK: {result.get("alert", "")}</div>', unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
| 626 |
|
| 627 |
col1, col2, col3 = st.columns(3)
|
| 628 |
+
col1.metric("Baseline", f"{result.get('baseline_mic', 'N/A')} mg/L")
|
| 629 |
+
col2.metric("Current", f"{result.get('current_mic', 'N/A')} mg/L")
|
| 630 |
+
col3.metric("Fold Change", f"{result.get('ratio', 'N/A')}x")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
|
| 632 |
|
| 633 |
def show_drug_safety():
|
| 634 |
st.header("⚠️ Drug Safety Check")
|
|
|
|
| 635 |
|
| 636 |
col1, col2 = st.columns(2)
|
| 637 |
|
| 638 |
with col1:
|
| 639 |
+
antibiotic = st.text_input("Antibiotic", placeholder="e.g., Ciprofloxacin")
|
| 640 |
+
current_meds = st.text_area("Current Medications", placeholder="Warfarin\nMetformin", height=150)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 641 |
|
| 642 |
with col2:
|
| 643 |
+
allergies = st.text_area("Allergies", placeholder="Penicillin\nSulfa", height=100)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 644 |
|
| 645 |
if st.button("Check Safety", type="primary"):
|
| 646 |
if antibiotic:
|
| 647 |
medications = [m.strip() for m in current_meds.split("\n") if m.strip()]
|
| 648 |
allergy_list = [a.strip() for a in allergies.split("\n") if a.strip()]
|
| 649 |
|
| 650 |
+
result = screen_antibiotic_safety(antibiotic, medications, allergy_list)
|
|
|
|
| 651 |
|
| 652 |
+
if result.get("safe_to_use"):
|
| 653 |
+
st.success("✅ No critical safety concerns")
|
| 654 |
+
else:
|
| 655 |
+
st.error("❌ Safety concerns identified")
|
| 656 |
+
|
| 657 |
+
for alert in result.get("alerts", []):
|
| 658 |
+
st.warning(f"⚠️ {alert.get('message', '')}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
|
| 660 |
|
| 661 |
def show_guidelines_search():
|
| 662 |
+
st.header("📚 Clinical Guidelines")
|
|
|
|
| 663 |
|
| 664 |
+
query = st.text_input("Search", placeholder="e.g., ESBL E. coli UTI treatment")
|
| 665 |
+
pathogen_filter = st.selectbox("Pathogen Filter", ["All", "ESBL-E", "CRE", "CRAB", "DTR-PA"])
|
|
|
|
|
|
|
| 666 |
|
| 667 |
+
if st.button("Search", type="primary"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
if query:
|
| 669 |
+
filter_val = None if pathogen_filter == "All" else pathogen_filter
|
| 670 |
+
results = search_clinical_guidelines(query, pathogen_filter=filter_val, n_results=5)
|
| 671 |
+
|
| 672 |
+
if results:
|
| 673 |
+
for i, r in enumerate(results, 1):
|
| 674 |
+
with st.expander(f"Result {i} (Relevance: {r.get('relevance_score', 0):.2f})"):
|
| 675 |
+
st.markdown(r.get("content", ""))
|
| 676 |
+
else:
|
| 677 |
+
st.info("No results found.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 678 |
|
| 679 |
|
| 680 |
if __name__ == "__main__":
|
src/agents.py
CHANGED
|
@@ -1,16 +1,574 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Med-I-C Multi-Agent System.
|
| 3 |
+
|
| 4 |
+
Implements the 4 specialized agents for the infection lifecycle workflow:
|
| 5 |
+
- Agent 1: Intake Historian - Parse patient data, risk factors, calculate CrCl
|
| 6 |
+
- Agent 2: Vision Specialist - Extract structured data from lab reports
|
| 7 |
+
- Agent 3: Trend Analyst - Detect MIC creep and resistance velocity
|
| 8 |
+
- Agent 4: Clinical Pharmacologist - Final Rx recommendations + safety checks
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
from typing import Any, Dict, Optional
|
| 16 |
+
|
| 17 |
+
from .config import get_settings
|
| 18 |
+
from .loader import run_inference, TextModelName
|
| 19 |
+
from .prompts import (
|
| 20 |
+
INTAKE_HISTORIAN_SYSTEM,
|
| 21 |
+
INTAKE_HISTORIAN_PROMPT,
|
| 22 |
+
VISION_SPECIALIST_SYSTEM,
|
| 23 |
+
VISION_SPECIALIST_PROMPT,
|
| 24 |
+
TREND_ANALYST_SYSTEM,
|
| 25 |
+
TREND_ANALYST_PROMPT,
|
| 26 |
+
CLINICAL_PHARMACOLOGIST_SYSTEM,
|
| 27 |
+
CLINICAL_PHARMACOLOGIST_PROMPT,
|
| 28 |
+
TXGEMMA_SAFETY_PROMPT,
|
| 29 |
)
|
| 30 |
+
from .rag import get_context_for_agent
|
| 31 |
+
from .state import InfectionState
|
| 32 |
+
from .utils import (
|
| 33 |
+
calculate_crcl,
|
| 34 |
+
get_renal_dose_category,
|
| 35 |
+
safe_json_parse,
|
| 36 |
+
normalize_organism_name,
|
| 37 |
+
normalize_antibiotic_name,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
logger = logging.getLogger(__name__)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# =============================================================================
|
| 44 |
+
# AGENT 1: INTAKE HISTORIAN
|
| 45 |
+
# =============================================================================
|
| 46 |
+
|
| 47 |
+
def run_intake_historian(state: InfectionState) -> InfectionState:
|
| 48 |
+
"""
|
| 49 |
+
Agent 1: Parse patient data, calculate CrCl, identify risk factors.
|
| 50 |
+
|
| 51 |
+
Input state fields used:
|
| 52 |
+
- age_years, weight_kg, height_cm, sex
|
| 53 |
+
- serum_creatinine_mg_dl
|
| 54 |
+
- medications, allergies, comorbidities
|
| 55 |
+
- suspected_source, infection_site
|
| 56 |
+
|
| 57 |
+
Output state fields updated:
|
| 58 |
+
- creatinine_clearance_ml_min
|
| 59 |
+
- intake_notes
|
| 60 |
+
- stage (empirical/targeted)
|
| 61 |
+
- route_to_vision
|
| 62 |
+
"""
|
| 63 |
+
logger.info("Running Intake Historian agent...")
|
| 64 |
+
|
| 65 |
+
# Calculate CrCl if we have the required data
|
| 66 |
+
crcl = None
|
| 67 |
+
if all([
|
| 68 |
+
state.get("age_years"),
|
| 69 |
+
state.get("weight_kg"),
|
| 70 |
+
state.get("serum_creatinine_mg_dl"),
|
| 71 |
+
state.get("sex"),
|
| 72 |
+
]):
|
| 73 |
+
try:
|
| 74 |
+
crcl = calculate_crcl(
|
| 75 |
+
age_years=state["age_years"],
|
| 76 |
+
weight_kg=state["weight_kg"],
|
| 77 |
+
serum_creatinine_mg_dl=state["serum_creatinine_mg_dl"],
|
| 78 |
+
sex=state["sex"],
|
| 79 |
+
use_ibw=True,
|
| 80 |
+
height_cm=state.get("height_cm"),
|
| 81 |
+
)
|
| 82 |
+
state["creatinine_clearance_ml_min"] = crcl
|
| 83 |
+
logger.info(f"Calculated CrCl: {crcl} mL/min")
|
| 84 |
+
except Exception as e:
|
| 85 |
+
logger.warning(f"Could not calculate CrCl: {e}")
|
| 86 |
+
state.setdefault("errors", []).append(f"CrCl calculation error: {e}")
|
| 87 |
+
|
| 88 |
+
# Build patient data string for prompt
|
| 89 |
+
patient_data = _format_patient_data(state)
|
| 90 |
+
|
| 91 |
+
# Get RAG context
|
| 92 |
+
query = f"treatment {state.get('suspected_source', '')} {state.get('infection_site', '')}"
|
| 93 |
+
rag_context = get_context_for_agent(
|
| 94 |
+
agent_name="intake_historian",
|
| 95 |
+
query=query,
|
| 96 |
+
patient_context={
|
| 97 |
+
"pathogen_type": state.get("suspected_source"),
|
| 98 |
+
},
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Format the prompt
|
| 102 |
+
prompt = f"{INTAKE_HISTORIAN_SYSTEM}\n\n{INTAKE_HISTORIAN_PROMPT.format(
|
| 103 |
+
patient_data=patient_data,
|
| 104 |
+
medications=', '.join(state.get('medications', [])) or 'None reported',
|
| 105 |
+
allergies=', '.join(state.get('allergies', [])) or 'No known allergies',
|
| 106 |
+
infection_site=state.get('infection_site', 'Unknown'),
|
| 107 |
+
suspected_source=state.get('suspected_source', 'Unknown'),
|
| 108 |
+
rag_context=rag_context,
|
| 109 |
+
)}"
|
| 110 |
+
|
| 111 |
+
# Run inference
|
| 112 |
+
try:
|
| 113 |
+
response = run_inference(
|
| 114 |
+
prompt=prompt,
|
| 115 |
+
model_name="medgemma_4b",
|
| 116 |
+
max_new_tokens=1024,
|
| 117 |
+
temperature=0.2,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Parse response
|
| 121 |
+
parsed = safe_json_parse(response)
|
| 122 |
+
if parsed:
|
| 123 |
+
state["intake_notes"] = json.dumps(parsed, indent=2)
|
| 124 |
+
|
| 125 |
+
# Update state from parsed response
|
| 126 |
+
if parsed.get("creatinine_clearance_ml_min") and crcl is None:
|
| 127 |
+
state["creatinine_clearance_ml_min"] = parsed["creatinine_clearance_ml_min"]
|
| 128 |
+
|
| 129 |
+
# Determine stage
|
| 130 |
+
recommended_stage = parsed.get("recommended_stage", "empirical")
|
| 131 |
+
state["stage"] = recommended_stage
|
| 132 |
+
|
| 133 |
+
# Route to vision if we have lab data to process
|
| 134 |
+
state["route_to_vision"] = bool(state.get("labs_raw_text"))
|
| 135 |
+
else:
|
| 136 |
+
state["intake_notes"] = response
|
| 137 |
+
state["stage"] = "empirical"
|
| 138 |
+
state["route_to_vision"] = bool(state.get("labs_raw_text"))
|
| 139 |
+
|
| 140 |
+
except Exception as e:
|
| 141 |
+
logger.error(f"Intake Historian error: {e}")
|
| 142 |
+
state.setdefault("errors", []).append(f"Intake Historian error: {e}")
|
| 143 |
+
state["intake_notes"] = f"Error: {e}"
|
| 144 |
+
state["stage"] = "empirical"
|
| 145 |
+
|
| 146 |
+
logger.info(f"Intake Historian complete. Stage: {state.get('stage')}")
|
| 147 |
+
return state
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# =============================================================================
|
| 151 |
+
# AGENT 2: VISION SPECIALIST
|
| 152 |
+
# =============================================================================
|
| 153 |
+
|
| 154 |
+
def run_vision_specialist(state: InfectionState) -> InfectionState:
|
| 155 |
+
"""
|
| 156 |
+
Agent 2: Extract structured data from lab reports (text, images, PDFs).
|
| 157 |
+
|
| 158 |
+
Input state fields used:
|
| 159 |
+
- labs_raw_text (extracted text from lab report)
|
| 160 |
+
|
| 161 |
+
Output state fields updated:
|
| 162 |
+
- labs_parsed
|
| 163 |
+
- mic_data
|
| 164 |
+
- vision_notes
|
| 165 |
+
- route_to_trend_analyst
|
| 166 |
+
"""
|
| 167 |
+
logger.info("Running Vision Specialist agent...")
|
| 168 |
+
|
| 169 |
+
labs_raw = state.get("labs_raw_text", "")
|
| 170 |
+
if not labs_raw:
|
| 171 |
+
logger.info("No lab data to process, skipping Vision Specialist")
|
| 172 |
+
state["vision_notes"] = "No lab data provided"
|
| 173 |
+
state["route_to_trend_analyst"] = False
|
| 174 |
+
return state
|
| 175 |
+
|
| 176 |
+
# Detect language (simplified - in production would use langdetect)
|
| 177 |
+
language = "English (assumed)"
|
| 178 |
+
|
| 179 |
+
# Get RAG context for lab interpretation
|
| 180 |
+
rag_context = get_context_for_agent(
|
| 181 |
+
agent_name="vision_specialist",
|
| 182 |
+
query="culture sensitivity susceptibility interpretation",
|
| 183 |
+
patient_context={},
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# Format the prompt
|
| 187 |
+
prompt = f"{VISION_SPECIALIST_SYSTEM}\n\n{VISION_SPECIALIST_PROMPT.format(
|
| 188 |
+
report_content=labs_raw,
|
| 189 |
+
source_format='text',
|
| 190 |
+
language=language,
|
| 191 |
+
)}"
|
| 192 |
+
|
| 193 |
+
# Run inference
|
| 194 |
+
try:
|
| 195 |
+
response = run_inference(
|
| 196 |
+
prompt=prompt,
|
| 197 |
+
model_name="medgemma_4b",
|
| 198 |
+
max_new_tokens=2048,
|
| 199 |
+
temperature=0.1,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# Parse response
|
| 203 |
+
parsed = safe_json_parse(response)
|
| 204 |
+
if parsed:
|
| 205 |
+
state["vision_notes"] = json.dumps(parsed, indent=2)
|
| 206 |
+
|
| 207 |
+
# Extract organisms and susceptibility data
|
| 208 |
+
organisms = parsed.get("identified_organisms", [])
|
| 209 |
+
susceptibility = parsed.get("susceptibility_results", [])
|
| 210 |
+
|
| 211 |
+
# Convert to MICDatum format
|
| 212 |
+
mic_data = []
|
| 213 |
+
for result in susceptibility:
|
| 214 |
+
mic_datum = {
|
| 215 |
+
"organism": normalize_organism_name(result.get("organism", "")),
|
| 216 |
+
"antibiotic": normalize_antibiotic_name(result.get("antibiotic", "")),
|
| 217 |
+
"mic_value": str(result.get("mic_value", "")),
|
| 218 |
+
"mic_unit": result.get("mic_unit", "mg/L"),
|
| 219 |
+
"interpretation": result.get("interpretation"),
|
| 220 |
+
}
|
| 221 |
+
mic_data.append(mic_datum)
|
| 222 |
+
|
| 223 |
+
state["mic_data"] = mic_data
|
| 224 |
+
state["labs_parsed"] = [{
|
| 225 |
+
"name": org.get("organism_name", "Unknown"),
|
| 226 |
+
"value": org.get("colony_count", ""),
|
| 227 |
+
"flag": "pathogen" if org.get("significance") == "pathogen" else None,
|
| 228 |
+
} for org in organisms]
|
| 229 |
+
|
| 230 |
+
# Route to trend analyst if we have MIC data
|
| 231 |
+
state["route_to_trend_analyst"] = len(mic_data) > 0
|
| 232 |
+
|
| 233 |
+
# Check for critical findings
|
| 234 |
+
critical = parsed.get("critical_findings", [])
|
| 235 |
+
if critical:
|
| 236 |
+
state.setdefault("safety_warnings", []).extend(critical)
|
| 237 |
+
|
| 238 |
+
else:
|
| 239 |
+
state["vision_notes"] = response
|
| 240 |
+
state["route_to_trend_analyst"] = False
|
| 241 |
+
|
| 242 |
+
except Exception as e:
|
| 243 |
+
logger.error(f"Vision Specialist error: {e}")
|
| 244 |
+
state.setdefault("errors", []).append(f"Vision Specialist error: {e}")
|
| 245 |
+
state["vision_notes"] = f"Error: {e}"
|
| 246 |
+
state["route_to_trend_analyst"] = False
|
| 247 |
+
|
| 248 |
+
logger.info(f"Vision Specialist complete. MIC data points: {len(state.get('mic_data', []))}")
|
| 249 |
+
return state
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# =============================================================================
|
| 253 |
+
# AGENT 3: TREND ANALYST
|
| 254 |
+
# =============================================================================
|
| 255 |
+
|
| 256 |
+
def run_trend_analyst(state: InfectionState) -> InfectionState:
|
| 257 |
+
"""
|
| 258 |
+
Agent 3: Analyze MIC trends and detect resistance velocity.
|
| 259 |
+
|
| 260 |
+
Input state fields used:
|
| 261 |
+
- mic_data (current MIC readings)
|
| 262 |
+
- Historical MIC data (if available)
|
| 263 |
+
|
| 264 |
+
Output state fields updated:
|
| 265 |
+
- mic_trend_summary
|
| 266 |
+
- trend_notes
|
| 267 |
+
- safety_warnings (if high risk detected)
|
| 268 |
+
"""
|
| 269 |
+
logger.info("Running Trend Analyst agent...")
|
| 270 |
+
|
| 271 |
+
mic_data = state.get("mic_data", [])
|
| 272 |
+
if not mic_data:
|
| 273 |
+
logger.info("No MIC data to analyze, skipping Trend Analyst")
|
| 274 |
+
state["trend_notes"] = "No MIC data available for trend analysis"
|
| 275 |
+
return state
|
| 276 |
+
|
| 277 |
+
# For each organism-antibiotic pair, analyze trends
|
| 278 |
+
trend_results = []
|
| 279 |
+
|
| 280 |
+
for mic in mic_data:
|
| 281 |
+
organism = mic.get("organism", "Unknown")
|
| 282 |
+
antibiotic = mic.get("antibiotic", "Unknown")
|
| 283 |
+
|
| 284 |
+
# Get RAG context for breakpoints
|
| 285 |
+
rag_context = get_context_for_agent(
|
| 286 |
+
agent_name="trend_analyst",
|
| 287 |
+
query=f"breakpoint {organism} {antibiotic}",
|
| 288 |
+
patient_context={
|
| 289 |
+
"organism": organism,
|
| 290 |
+
"antibiotic": antibiotic,
|
| 291 |
+
"region": state.get("country_or_region"),
|
| 292 |
+
},
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
# Format MIC history (in production, would pull from database)
|
| 296 |
+
mic_history = [{"date": "current", "mic_value": mic.get("mic_value", "0")}]
|
| 297 |
+
|
| 298 |
+
# Format prompt
|
| 299 |
+
prompt = f"{TREND_ANALYST_SYSTEM}\n\n{TREND_ANALYST_PROMPT.format(
|
| 300 |
+
organism=organism,
|
| 301 |
+
antibiotic=antibiotic,
|
| 302 |
+
mic_history=json.dumps(mic_history, indent=2),
|
| 303 |
+
breakpoint_data=rag_context,
|
| 304 |
+
resistance_context='Regional data not available',
|
| 305 |
+
)}"
|
| 306 |
+
|
| 307 |
+
try:
|
| 308 |
+
response = run_inference(
|
| 309 |
+
prompt=prompt,
|
| 310 |
+
model_name="medgemma_4b",
|
| 311 |
+
max_new_tokens=1024,
|
| 312 |
+
temperature=0.2,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
parsed = safe_json_parse(response)
|
| 316 |
+
if parsed:
|
| 317 |
+
trend_results.append(parsed)
|
| 318 |
+
|
| 319 |
+
# Add safety warning if high/critical risk
|
| 320 |
+
risk_level = parsed.get("risk_level", "LOW")
|
| 321 |
+
if risk_level in ["HIGH", "CRITICAL"]:
|
| 322 |
+
warning = f"MIC trend alert for {organism}/{antibiotic}: {parsed.get('recommendation', 'Review needed')}"
|
| 323 |
+
state.setdefault("safety_warnings", []).append(warning)
|
| 324 |
+
else:
|
| 325 |
+
trend_results.append({"raw_response": response})
|
| 326 |
+
|
| 327 |
+
except Exception as e:
|
| 328 |
+
logger.error(f"Trend analysis error for {organism}/{antibiotic}: {e}")
|
| 329 |
+
trend_results.append({"error": str(e)})
|
| 330 |
+
|
| 331 |
+
# Summarize trends
|
| 332 |
+
state["trend_notes"] = json.dumps(trend_results, indent=2)
|
| 333 |
+
|
| 334 |
+
# Create summary
|
| 335 |
+
high_risk_count = sum(1 for t in trend_results if t.get("risk_level") in ["HIGH", "CRITICAL"])
|
| 336 |
+
state["mic_trend_summary"] = f"Analyzed {len(trend_results)} organism-antibiotic pairs. High-risk findings: {high_risk_count}"
|
| 337 |
+
|
| 338 |
+
logger.info(f"Trend Analyst complete. {state['mic_trend_summary']}")
|
| 339 |
+
return state
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# =============================================================================
|
| 343 |
+
# AGENT 4: CLINICAL PHARMACOLOGIST
|
| 344 |
+
# =============================================================================
|
| 345 |
+
|
| 346 |
+
def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
|
| 347 |
+
"""
|
| 348 |
+
Agent 4: Generate final antibiotic recommendation with safety checks.
|
| 349 |
+
|
| 350 |
+
Input state fields used:
|
| 351 |
+
- intake_notes, vision_notes, trend_notes
|
| 352 |
+
- age_years, weight_kg, creatinine_clearance_ml_min
|
| 353 |
+
- allergies, medications
|
| 354 |
+
- infection_site, suspected_source
|
| 355 |
+
|
| 356 |
+
Output state fields updated:
|
| 357 |
+
- recommendation
|
| 358 |
+
- pharmacology_notes
|
| 359 |
+
- safety_warnings (additional alerts)
|
| 360 |
+
"""
|
| 361 |
+
logger.info("Running Clinical Pharmacologist agent...")
|
| 362 |
+
|
| 363 |
+
# Gather all previous agent outputs
|
| 364 |
+
intake_summary = state.get("intake_notes", "No intake data")
|
| 365 |
+
lab_results = state.get("vision_notes", "No lab data")
|
| 366 |
+
trend_analysis = state.get("trend_notes", "No trend data")
|
| 367 |
+
|
| 368 |
+
# Get RAG context
|
| 369 |
+
query = f"treatment {state.get('suspected_source', '')} antibiotic recommendation"
|
| 370 |
+
rag_context = get_context_for_agent(
|
| 371 |
+
agent_name="clinical_pharmacologist",
|
| 372 |
+
query=query,
|
| 373 |
+
patient_context={
|
| 374 |
+
"proposed_antibiotic": None, # Will be determined by agent
|
| 375 |
+
},
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# Format prompt
|
| 379 |
+
prompt = f"{CLINICAL_PHARMACOLOGIST_SYSTEM}\n\n{CLINICAL_PHARMACOLOGIST_PROMPT.format(
|
| 380 |
+
intake_summary=intake_summary,
|
| 381 |
+
lab_results=lab_results,
|
| 382 |
+
trend_analysis=trend_analysis,
|
| 383 |
+
age=state.get('age_years', 'Unknown'),
|
| 384 |
+
weight=state.get('weight_kg', 'Unknown'),
|
| 385 |
+
crcl=state.get('creatinine_clearance_ml_min', 'Unknown'),
|
| 386 |
+
allergies=', '.join(state.get('allergies', [])) or 'No known allergies',
|
| 387 |
+
current_medications=', '.join(state.get('medications', [])) or 'None reported',
|
| 388 |
+
infection_site=state.get('infection_site', 'Unknown'),
|
| 389 |
+
suspected_source=state.get('suspected_source', 'Unknown'),
|
| 390 |
+
severity=state.get('intake_notes', {}).get('infection_severity', 'Unknown') if isinstance(state.get('intake_notes'), dict) else 'Unknown',
|
| 391 |
+
rag_context=rag_context,
|
| 392 |
+
)}"
|
| 393 |
+
|
| 394 |
+
try:
|
| 395 |
+
response = run_inference(
|
| 396 |
+
prompt=prompt,
|
| 397 |
+
model_name="medgemma_4b",
|
| 398 |
+
max_new_tokens=2048,
|
| 399 |
+
temperature=0.2,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
parsed = safe_json_parse(response)
|
| 403 |
+
if parsed:
|
| 404 |
+
state["pharmacology_notes"] = json.dumps(parsed, indent=2)
|
| 405 |
+
|
| 406 |
+
# Build recommendation
|
| 407 |
+
primary = parsed.get("primary_recommendation", {})
|
| 408 |
+
recommendation = {
|
| 409 |
+
"primary_antibiotic": primary.get("antibiotic"),
|
| 410 |
+
"dose": primary.get("dose"),
|
| 411 |
+
"route": primary.get("route"),
|
| 412 |
+
"frequency": primary.get("frequency"),
|
| 413 |
+
"duration": primary.get("duration"),
|
| 414 |
+
"rationale": parsed.get("rationale"),
|
| 415 |
+
"references": parsed.get("guideline_references", []),
|
| 416 |
+
"safety_alerts": [a.get("message") for a in parsed.get("safety_alerts", [])],
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
# Add alternative if provided
|
| 420 |
+
alt = parsed.get("alternative_recommendation", {})
|
| 421 |
+
if alt.get("antibiotic"):
|
| 422 |
+
recommendation["backup_antibiotic"] = alt.get("antibiotic")
|
| 423 |
+
|
| 424 |
+
state["recommendation"] = recommendation
|
| 425 |
+
|
| 426 |
+
# Add safety alerts to state
|
| 427 |
+
for alert in parsed.get("safety_alerts", []):
|
| 428 |
+
if alert.get("level") in ["WARNING", "CRITICAL"]:
|
| 429 |
+
state.setdefault("safety_warnings", []).append(alert.get("message"))
|
| 430 |
+
|
| 431 |
+
# Run TxGemma safety check (optional)
|
| 432 |
+
if primary.get("antibiotic"):
|
| 433 |
+
safety_result = _run_txgemma_safety_check(
|
| 434 |
+
antibiotic=primary.get("antibiotic"),
|
| 435 |
+
dose=primary.get("dose"),
|
| 436 |
+
route=primary.get("route"),
|
| 437 |
+
duration=primary.get("duration"),
|
| 438 |
+
age=state.get("age_years"),
|
| 439 |
+
crcl=state.get("creatinine_clearance_ml_min"),
|
| 440 |
+
medications=state.get("medications", []),
|
| 441 |
+
)
|
| 442 |
+
if safety_result:
|
| 443 |
+
state.setdefault("debug_log", []).append(f"TxGemma safety: {safety_result}")
|
| 444 |
+
|
| 445 |
+
else:
|
| 446 |
+
state["pharmacology_notes"] = response
|
| 447 |
+
state["recommendation"] = {"rationale": response}
|
| 448 |
+
|
| 449 |
+
except Exception as e:
|
| 450 |
+
logger.error(f"Clinical Pharmacologist error: {e}")
|
| 451 |
+
state.setdefault("errors", []).append(f"Clinical Pharmacologist error: {e}")
|
| 452 |
+
state["pharmacology_notes"] = f"Error: {e}"
|
| 453 |
+
|
| 454 |
+
logger.info("Clinical Pharmacologist complete.")
|
| 455 |
+
return state
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
# =============================================================================
|
| 459 |
+
# HELPER FUNCTIONS
|
| 460 |
+
# =============================================================================
|
| 461 |
+
|
| 462 |
+
def _format_patient_data(state: InfectionState) -> str:
|
| 463 |
+
"""Format patient data for prompt injection."""
|
| 464 |
+
lines = []
|
| 465 |
+
|
| 466 |
+
if state.get("patient_id"):
|
| 467 |
+
lines.append(f"Patient ID: {state['patient_id']}")
|
| 468 |
+
|
| 469 |
+
demographics = []
|
| 470 |
+
if state.get("age_years"):
|
| 471 |
+
demographics.append(f"{state['age_years']} years old")
|
| 472 |
+
if state.get("sex"):
|
| 473 |
+
demographics.append(state["sex"])
|
| 474 |
+
if demographics:
|
| 475 |
+
lines.append(f"Demographics: {', '.join(demographics)}")
|
| 476 |
+
|
| 477 |
+
if state.get("weight_kg"):
|
| 478 |
+
lines.append(f"Weight: {state['weight_kg']} kg")
|
| 479 |
+
if state.get("height_cm"):
|
| 480 |
+
lines.append(f"Height: {state['height_cm']} cm")
|
| 481 |
+
|
| 482 |
+
if state.get("serum_creatinine_mg_dl"):
|
| 483 |
+
lines.append(f"Serum Creatinine: {state['serum_creatinine_mg_dl']} mg/dL")
|
| 484 |
+
if state.get("creatinine_clearance_ml_min"):
|
| 485 |
+
crcl = state["creatinine_clearance_ml_min"]
|
| 486 |
+
category = get_renal_dose_category(crcl)
|
| 487 |
+
lines.append(f"CrCl: {crcl} mL/min ({category})")
|
| 488 |
+
|
| 489 |
+
if state.get("comorbidities"):
|
| 490 |
+
lines.append(f"Comorbidities: {', '.join(state['comorbidities'])}")
|
| 491 |
+
|
| 492 |
+
if state.get("vitals"):
|
| 493 |
+
vitals_str = ", ".join(f"{k}: {v}" for k, v in state["vitals"].items())
|
| 494 |
+
lines.append(f"Vitals: {vitals_str}")
|
| 495 |
+
|
| 496 |
+
return "\n".join(lines) if lines else "No patient data available"
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def _run_txgemma_safety_check(
|
| 500 |
+
antibiotic: str,
|
| 501 |
+
dose: Optional[str],
|
| 502 |
+
route: Optional[str],
|
| 503 |
+
duration: Optional[str],
|
| 504 |
+
age: Optional[float],
|
| 505 |
+
crcl: Optional[float],
|
| 506 |
+
medications: list,
|
| 507 |
+
) -> Optional[str]:
|
| 508 |
+
"""
|
| 509 |
+
Run TxGemma safety check (supplementary).
|
| 510 |
+
|
| 511 |
+
TxGemma is used only for safety validation, not primary recommendations.
|
| 512 |
+
"""
|
| 513 |
+
try:
|
| 514 |
+
prompt = TXGEMMA_SAFETY_PROMPT.format(
|
| 515 |
+
antibiotic=antibiotic,
|
| 516 |
+
dose=dose or "Not specified",
|
| 517 |
+
route=route or "Not specified",
|
| 518 |
+
duration=duration or "Not specified",
|
| 519 |
+
age=age or "Unknown",
|
| 520 |
+
crcl=crcl or "Unknown",
|
| 521 |
+
medications=", ".join(medications) if medications else "None",
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
response = run_inference(
|
| 525 |
+
prompt=prompt,
|
| 526 |
+
model_name="txgemma_2b", # Use smaller TxGemma for safety check
|
| 527 |
+
max_new_tokens=256,
|
| 528 |
+
temperature=0.1,
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
return response
|
| 532 |
+
|
| 533 |
+
except Exception as e:
|
| 534 |
+
logger.warning(f"TxGemma safety check failed: {e}")
|
| 535 |
+
return None
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
# =============================================================================
|
| 539 |
+
# AGENT REGISTRY
|
| 540 |
+
# =============================================================================
|
| 541 |
+
|
| 542 |
+
AGENTS = {
|
| 543 |
+
"intake_historian": run_intake_historian,
|
| 544 |
+
"vision_specialist": run_vision_specialist,
|
| 545 |
+
"trend_analyst": run_trend_analyst,
|
| 546 |
+
"clinical_pharmacologist": run_clinical_pharmacologist,
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def run_agent(agent_name: str, state: InfectionState) -> InfectionState:
|
| 551 |
+
"""
|
| 552 |
+
Run a specific agent by name.
|
| 553 |
+
|
| 554 |
+
Args:
|
| 555 |
+
agent_name: Name of the agent to run
|
| 556 |
+
state: Current infection state
|
| 557 |
+
|
| 558 |
+
Returns:
|
| 559 |
+
Updated infection state
|
| 560 |
+
"""
|
| 561 |
+
if agent_name not in AGENTS:
|
| 562 |
+
raise ValueError(f"Unknown agent: {agent_name}")
|
| 563 |
+
|
| 564 |
+
return AGENTS[agent_name](state)
|
| 565 |
+
|
| 566 |
|
| 567 |
+
__all__ = [
|
| 568 |
+
"run_intake_historian",
|
| 569 |
+
"run_vision_specialist",
|
| 570 |
+
"run_trend_analyst",
|
| 571 |
+
"run_clinical_pharmacologist",
|
| 572 |
+
"run_agent",
|
| 573 |
+
"AGENTS",
|
| 574 |
+
]
|
src/agents/__init__.py
DELETED
|
File without changes
|
src/graph.py
CHANGED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LangGraph Orchestrator for Med-I-C Multi-Agent System.
|
| 3 |
+
|
| 4 |
+
Implements the infection lifecycle workflow with conditional routing:
|
| 5 |
+
|
| 6 |
+
Stage 1 (Empirical - no lab results):
|
| 7 |
+
Intake Historian -> Clinical Pharmacologist
|
| 8 |
+
|
| 9 |
+
Stage 2 (Targeted - lab results available):
|
| 10 |
+
Intake Historian -> Vision Specialist -> Trend Analyst -> Clinical Pharmacologist
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
from typing import Literal
|
| 17 |
+
|
| 18 |
+
from langgraph.graph import StateGraph, END
|
| 19 |
+
|
| 20 |
+
from .agents import (
|
| 21 |
+
run_intake_historian,
|
| 22 |
+
run_vision_specialist,
|
| 23 |
+
run_trend_analyst,
|
| 24 |
+
run_clinical_pharmacologist,
|
| 25 |
+
)
|
| 26 |
+
from .state import InfectionState
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# =============================================================================
|
| 32 |
+
# NODE FUNCTIONS (Wrapper for agents)
|
| 33 |
+
# =============================================================================
|
| 34 |
+
|
| 35 |
+
def intake_historian_node(state: InfectionState) -> InfectionState:
|
| 36 |
+
"""Node 1: Run Intake Historian agent."""
|
| 37 |
+
logger.info("Graph: Executing Intake Historian node")
|
| 38 |
+
return run_intake_historian(state)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def vision_specialist_node(state: InfectionState) -> InfectionState:
|
| 42 |
+
"""Node 2: Run Vision Specialist agent."""
|
| 43 |
+
logger.info("Graph: Executing Vision Specialist node")
|
| 44 |
+
return run_vision_specialist(state)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def trend_analyst_node(state: InfectionState) -> InfectionState:
|
| 48 |
+
"""Node 3: Run Trend Analyst agent."""
|
| 49 |
+
logger.info("Graph: Executing Trend Analyst node")
|
| 50 |
+
return run_trend_analyst(state)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def clinical_pharmacologist_node(state: InfectionState) -> InfectionState:
|
| 54 |
+
"""Node 4: Run Clinical Pharmacologist agent."""
|
| 55 |
+
logger.info("Graph: Executing Clinical Pharmacologist node")
|
| 56 |
+
return run_clinical_pharmacologist(state)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# =============================================================================
|
| 60 |
+
# CONDITIONAL ROUTING FUNCTIONS
|
| 61 |
+
# =============================================================================
|
| 62 |
+
|
| 63 |
+
def route_after_intake(state: InfectionState) -> Literal["vision_specialist", "clinical_pharmacologist"]:
|
| 64 |
+
"""
|
| 65 |
+
Determine routing after Intake Historian.
|
| 66 |
+
|
| 67 |
+
Routes to Vision Specialist if:
|
| 68 |
+
- stage is "targeted" AND
|
| 69 |
+
- route_to_vision is True (i.e., we have lab data to process)
|
| 70 |
+
|
| 71 |
+
Otherwise routes directly to Clinical Pharmacologist (empirical path).
|
| 72 |
+
"""
|
| 73 |
+
stage = state.get("stage", "empirical")
|
| 74 |
+
has_lab_data = state.get("route_to_vision", False)
|
| 75 |
+
|
| 76 |
+
if stage == "targeted" and has_lab_data:
|
| 77 |
+
logger.info("Graph: Routing to Vision Specialist (targeted path)")
|
| 78 |
+
return "vision_specialist"
|
| 79 |
+
else:
|
| 80 |
+
logger.info("Graph: Routing to Clinical Pharmacologist (empirical path)")
|
| 81 |
+
return "clinical_pharmacologist"
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def route_after_vision(state: InfectionState) -> Literal["trend_analyst", "clinical_pharmacologist"]:
|
| 85 |
+
"""
|
| 86 |
+
Determine routing after Vision Specialist.
|
| 87 |
+
|
| 88 |
+
Routes to Trend Analyst if:
|
| 89 |
+
- route_to_trend_analyst is True (i.e., we have MIC data to analyze)
|
| 90 |
+
|
| 91 |
+
Otherwise skips to Clinical Pharmacologist.
|
| 92 |
+
"""
|
| 93 |
+
should_analyze_trends = state.get("route_to_trend_analyst", False)
|
| 94 |
+
|
| 95 |
+
if should_analyze_trends:
|
| 96 |
+
logger.info("Graph: Routing to Trend Analyst")
|
| 97 |
+
return "trend_analyst"
|
| 98 |
+
else:
|
| 99 |
+
logger.info("Graph: Skipping Trend Analyst, routing to Clinical Pharmacologist")
|
| 100 |
+
return "clinical_pharmacologist"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# =============================================================================
|
| 104 |
+
# GRAPH CONSTRUCTION
|
| 105 |
+
# =============================================================================
|
| 106 |
+
|
| 107 |
+
def build_infection_graph() -> StateGraph:
|
| 108 |
+
"""
|
| 109 |
+
Build the LangGraph StateGraph for the infection lifecycle workflow.
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
Compiled StateGraph ready for execution
|
| 113 |
+
"""
|
| 114 |
+
# Create the graph with InfectionState as the state schema
|
| 115 |
+
graph = StateGraph(InfectionState)
|
| 116 |
+
|
| 117 |
+
# Add nodes
|
| 118 |
+
graph.add_node("intake_historian", intake_historian_node)
|
| 119 |
+
graph.add_node("vision_specialist", vision_specialist_node)
|
| 120 |
+
graph.add_node("trend_analyst", trend_analyst_node)
|
| 121 |
+
graph.add_node("clinical_pharmacologist", clinical_pharmacologist_node)
|
| 122 |
+
|
| 123 |
+
# Set entry point
|
| 124 |
+
graph.set_entry_point("intake_historian")
|
| 125 |
+
|
| 126 |
+
# Add conditional edges from intake_historian
|
| 127 |
+
graph.add_conditional_edges(
|
| 128 |
+
"intake_historian",
|
| 129 |
+
route_after_intake,
|
| 130 |
+
{
|
| 131 |
+
"vision_specialist": "vision_specialist",
|
| 132 |
+
"clinical_pharmacologist": "clinical_pharmacologist",
|
| 133 |
+
}
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# Add conditional edges from vision_specialist
|
| 137 |
+
graph.add_conditional_edges(
|
| 138 |
+
"vision_specialist",
|
| 139 |
+
route_after_vision,
|
| 140 |
+
{
|
| 141 |
+
"trend_analyst": "trend_analyst",
|
| 142 |
+
"clinical_pharmacologist": "clinical_pharmacologist",
|
| 143 |
+
}
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# Add edge from trend_analyst to clinical_pharmacologist
|
| 147 |
+
graph.add_edge("trend_analyst", "clinical_pharmacologist")
|
| 148 |
+
|
| 149 |
+
# Add edge from clinical_pharmacologist to END
|
| 150 |
+
graph.add_edge("clinical_pharmacologist", END)
|
| 151 |
+
|
| 152 |
+
return graph
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def compile_graph():
|
| 156 |
+
"""
|
| 157 |
+
Build and compile the graph for execution.
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
Compiled graph that can be invoked with .invoke(state)
|
| 161 |
+
"""
|
| 162 |
+
graph = build_infection_graph()
|
| 163 |
+
return graph.compile()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# =============================================================================
|
| 167 |
+
# EXECUTION HELPERS
|
| 168 |
+
# =============================================================================
|
| 169 |
+
|
| 170 |
+
def run_pipeline(
|
| 171 |
+
patient_data: dict,
|
| 172 |
+
labs_raw_text: str | None = None,
|
| 173 |
+
) -> InfectionState:
|
| 174 |
+
"""
|
| 175 |
+
Run the full infection lifecycle pipeline.
|
| 176 |
+
|
| 177 |
+
This is the main entry point for executing the multi-agent workflow.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
patient_data: Dict containing patient information:
|
| 181 |
+
- age_years: Patient age
|
| 182 |
+
- weight_kg: Patient weight
|
| 183 |
+
- sex: "male" or "female"
|
| 184 |
+
- serum_creatinine_mg_dl: Serum creatinine (optional)
|
| 185 |
+
- medications: List of current medications
|
| 186 |
+
- allergies: List of allergies
|
| 187 |
+
- comorbidities: List of comorbidities
|
| 188 |
+
- infection_site: Site of infection
|
| 189 |
+
- suspected_source: Suspected pathogen/source
|
| 190 |
+
|
| 191 |
+
labs_raw_text: Raw text from lab report (if available).
|
| 192 |
+
If provided, triggers targeted (Stage 2) pathway.
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
Final InfectionState with recommendation
|
| 196 |
+
|
| 197 |
+
Example:
|
| 198 |
+
>>> state = run_pipeline(
|
| 199 |
+
... patient_data={
|
| 200 |
+
... "age_years": 65,
|
| 201 |
+
... "weight_kg": 70,
|
| 202 |
+
... "sex": "male",
|
| 203 |
+
... "serum_creatinine_mg_dl": 1.2,
|
| 204 |
+
... "medications": ["metformin", "lisinopril"],
|
| 205 |
+
... "allergies": ["penicillin"],
|
| 206 |
+
... "infection_site": "urinary",
|
| 207 |
+
... "suspected_source": "community UTI",
|
| 208 |
+
... },
|
| 209 |
+
... labs_raw_text="E. coli isolated. Ciprofloxacin MIC: 0.5 mg/L (S)"
|
| 210 |
+
... )
|
| 211 |
+
>>> print(state["recommendation"]["primary_antibiotic"])
|
| 212 |
+
"""
|
| 213 |
+
# Build initial state from patient data
|
| 214 |
+
initial_state: InfectionState = {
|
| 215 |
+
"age_years": patient_data.get("age_years"),
|
| 216 |
+
"weight_kg": patient_data.get("weight_kg"),
|
| 217 |
+
"height_cm": patient_data.get("height_cm"),
|
| 218 |
+
"sex": patient_data.get("sex"),
|
| 219 |
+
"serum_creatinine_mg_dl": patient_data.get("serum_creatinine_mg_dl"),
|
| 220 |
+
"medications": patient_data.get("medications", []),
|
| 221 |
+
"allergies": patient_data.get("allergies", []),
|
| 222 |
+
"comorbidities": patient_data.get("comorbidities", []),
|
| 223 |
+
"infection_site": patient_data.get("infection_site"),
|
| 224 |
+
"suspected_source": patient_data.get("suspected_source"),
|
| 225 |
+
"country_or_region": patient_data.get("country_or_region"),
|
| 226 |
+
"vitals": patient_data.get("vitals", {}),
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
# Add lab data if provided
|
| 230 |
+
if labs_raw_text:
|
| 231 |
+
initial_state["labs_raw_text"] = labs_raw_text
|
| 232 |
+
initial_state["stage"] = "targeted"
|
| 233 |
+
else:
|
| 234 |
+
initial_state["stage"] = "empirical"
|
| 235 |
+
|
| 236 |
+
# Compile and run the graph
|
| 237 |
+
logger.info(f"Starting pipeline execution (stage: {initial_state['stage']})")
|
| 238 |
+
|
| 239 |
+
compiled_graph = compile_graph()
|
| 240 |
+
final_state = compiled_graph.invoke(initial_state)
|
| 241 |
+
|
| 242 |
+
logger.info("Pipeline execution complete")
|
| 243 |
+
|
| 244 |
+
return final_state
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def run_empirical_pipeline(patient_data: dict) -> InfectionState:
|
| 248 |
+
"""
|
| 249 |
+
Run Stage 1 (Empirical) pipeline only.
|
| 250 |
+
|
| 251 |
+
Shorthand for run_pipeline without lab data.
|
| 252 |
+
"""
|
| 253 |
+
return run_pipeline(patient_data, labs_raw_text=None)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def run_targeted_pipeline(patient_data: dict, labs_raw_text: str) -> InfectionState:
|
| 257 |
+
"""
|
| 258 |
+
Run Stage 2 (Targeted) pipeline with lab data.
|
| 259 |
+
|
| 260 |
+
Shorthand for run_pipeline with lab data.
|
| 261 |
+
"""
|
| 262 |
+
return run_pipeline(patient_data, labs_raw_text=labs_raw_text)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# =============================================================================
|
| 266 |
+
# VISUALIZATION (for debugging)
|
| 267 |
+
# =============================================================================
|
| 268 |
+
|
| 269 |
+
def get_graph_mermaid() -> str:
|
| 270 |
+
"""
|
| 271 |
+
Get Mermaid diagram representation of the graph.
|
| 272 |
+
|
| 273 |
+
Useful for documentation and debugging.
|
| 274 |
+
"""
|
| 275 |
+
graph = build_infection_graph()
|
| 276 |
+
try:
|
| 277 |
+
return graph.compile().get_graph().draw_mermaid()
|
| 278 |
+
except Exception:
|
| 279 |
+
# Fallback: return manual diagram
|
| 280 |
+
return """
|
| 281 |
+
graph TD
|
| 282 |
+
A[intake_historian] --> B{route_after_intake}
|
| 283 |
+
B -->|targeted + lab data| C[vision_specialist]
|
| 284 |
+
B -->|empirical| E[clinical_pharmacologist]
|
| 285 |
+
C --> D{route_after_vision}
|
| 286 |
+
D -->|has MIC data| F[trend_analyst]
|
| 287 |
+
D -->|no MIC data| E
|
| 288 |
+
F --> E
|
| 289 |
+
E --> G[END]
|
| 290 |
+
"""
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
__all__ = [
|
| 294 |
+
"build_infection_graph",
|
| 295 |
+
"compile_graph",
|
| 296 |
+
"run_pipeline",
|
| 297 |
+
"run_empirical_pipeline",
|
| 298 |
+
"run_targeted_pipeline",
|
| 299 |
+
"get_graph_mermaid",
|
| 300 |
+
]
|
src/prompts.py
CHANGED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prompt templates for Med-I-C multi-agent system.
|
| 3 |
+
|
| 4 |
+
Each agent has a specific role in the infection lifecycle workflow:
|
| 5 |
+
- Agent 1: Intake Historian - Parse patient data, risk factors, calculate CrCl
|
| 6 |
+
- Agent 2: Vision Specialist - Extract structured data from lab reports (images/PDFs)
|
| 7 |
+
- Agent 3: Trend Analyst - Detect MIC creep and resistance velocity
|
| 8 |
+
- Agent 4: Clinical Pharmacologist - Final Rx recommendations + safety checks
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
# =============================================================================
|
| 14 |
+
# AGENT 1: INTAKE HISTORIAN
|
| 15 |
+
# =============================================================================
|
| 16 |
+
|
| 17 |
+
INTAKE_HISTORIAN_SYSTEM = """You are an expert clinical intake specialist. Your role is to:
|
| 18 |
+
|
| 19 |
+
1. Parse and structure patient demographics and clinical history
|
| 20 |
+
2. Calculate Creatinine Clearance (CrCl) using the Cockcroft-Gault equation when data is available
|
| 21 |
+
3. Identify key risk factors for antimicrobial-resistant infections
|
| 22 |
+
4. Determine the appropriate treatment stage (empirical vs targeted)
|
| 23 |
+
|
| 24 |
+
RISK FACTORS TO IDENTIFY:
|
| 25 |
+
- Prior MRSA or MDR infection history
|
| 26 |
+
- Recent antibiotic use (within 90 days)
|
| 27 |
+
- Healthcare-associated vs community-acquired infection
|
| 28 |
+
- Immunocompromised status
|
| 29 |
+
- Recent hospitalization or ICU stay
|
| 30 |
+
- Presence of medical devices (catheters, lines)
|
| 31 |
+
- Travel history to high-resistance regions
|
| 32 |
+
- Renal or hepatic impairment
|
| 33 |
+
|
| 34 |
+
OUTPUT FORMAT:
|
| 35 |
+
Provide a structured JSON response with the following fields:
|
| 36 |
+
{
|
| 37 |
+
"patient_summary": "Brief clinical summary",
|
| 38 |
+
"creatinine_clearance_ml_min": <number or null>,
|
| 39 |
+
"renal_dose_adjustment_needed": <boolean>,
|
| 40 |
+
"identified_risk_factors": ["list", "of", "factors"],
|
| 41 |
+
"suspected_pathogens": ["list", "of", "likely", "organisms"],
|
| 42 |
+
"infection_severity": "mild|moderate|severe|critical",
|
| 43 |
+
"recommended_stage": "empirical|targeted",
|
| 44 |
+
"notes": "Any additional clinical observations"
|
| 45 |
+
}
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
INTAKE_HISTORIAN_PROMPT = """Analyze the following patient information and provide a structured clinical assessment.
|
| 49 |
+
|
| 50 |
+
PATIENT DATA:
|
| 51 |
+
{patient_data}
|
| 52 |
+
|
| 53 |
+
CURRENT MEDICATIONS:
|
| 54 |
+
{medications}
|
| 55 |
+
|
| 56 |
+
KNOWN ALLERGIES:
|
| 57 |
+
{allergies}
|
| 58 |
+
|
| 59 |
+
CLINICAL CONTEXT:
|
| 60 |
+
- Suspected infection site: {infection_site}
|
| 61 |
+
- Suspected source: {suspected_source}
|
| 62 |
+
|
| 63 |
+
RAG CONTEXT (Relevant Guidelines):
|
| 64 |
+
{rag_context}
|
| 65 |
+
|
| 66 |
+
Provide your structured assessment following the system instructions."""
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# =============================================================================
|
| 70 |
+
# AGENT 2: VISION SPECIALIST
|
| 71 |
+
# =============================================================================
|
| 72 |
+
|
| 73 |
+
VISION_SPECIALIST_SYSTEM = """You are an expert medical laboratory data extraction specialist. Your role is to:
|
| 74 |
+
|
| 75 |
+
1. Extract structured data from laboratory reports (culture & sensitivity, antibiograms)
|
| 76 |
+
2. Handle reports in ANY language - always output in English
|
| 77 |
+
3. Identify pathogens, antibiotics tested, MIC values, and S/I/R interpretations
|
| 78 |
+
4. Flag any critical or unusual findings
|
| 79 |
+
|
| 80 |
+
SUPPORTED REPORT TYPES:
|
| 81 |
+
- Culture & Sensitivity reports
|
| 82 |
+
- Antibiogram reports
|
| 83 |
+
- Blood culture reports
|
| 84 |
+
- Urine culture reports
|
| 85 |
+
- Wound culture reports
|
| 86 |
+
- Respiratory culture reports
|
| 87 |
+
|
| 88 |
+
OUTPUT FORMAT:
|
| 89 |
+
Provide a structured JSON response:
|
| 90 |
+
{
|
| 91 |
+
"specimen_type": "blood|urine|wound|respiratory|other",
|
| 92 |
+
"collection_date": "YYYY-MM-DD or null",
|
| 93 |
+
"identified_organisms": [
|
| 94 |
+
{
|
| 95 |
+
"organism_name": "Standardized English name",
|
| 96 |
+
"original_name": "Name as written in report",
|
| 97 |
+
"colony_count": "if available",
|
| 98 |
+
"significance": "pathogen|colonizer|contaminant"
|
| 99 |
+
}
|
| 100 |
+
],
|
| 101 |
+
"susceptibility_results": [
|
| 102 |
+
{
|
| 103 |
+
"organism": "Organism name",
|
| 104 |
+
"antibiotic": "Standardized antibiotic name",
|
| 105 |
+
"mic_value": <number or null>,
|
| 106 |
+
"mic_unit": "mg/L",
|
| 107 |
+
"interpretation": "S|I|R",
|
| 108 |
+
"method": "disk diffusion|MIC|E-test"
|
| 109 |
+
}
|
| 110 |
+
],
|
| 111 |
+
"critical_findings": ["List of urgent findings requiring immediate attention"],
|
| 112 |
+
"report_quality": "complete|partial|poor",
|
| 113 |
+
"extraction_confidence": 0.0-1.0,
|
| 114 |
+
"notes": "Any relevant observations about the report"
|
| 115 |
+
}
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
VISION_SPECIALIST_PROMPT = """Extract structured laboratory data from the following report.
|
| 119 |
+
|
| 120 |
+
REPORT CONTENT:
|
| 121 |
+
{report_content}
|
| 122 |
+
|
| 123 |
+
REPORT METADATA:
|
| 124 |
+
- Source format: {source_format}
|
| 125 |
+
- Language detected: {language}
|
| 126 |
+
|
| 127 |
+
Extract all pathogen identifications, susceptibility results, and MIC values.
|
| 128 |
+
Always standardize to English medical terminology.
|
| 129 |
+
Flag any critical findings that require urgent attention.
|
| 130 |
+
|
| 131 |
+
Provide your structured extraction following the system instructions."""
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# =============================================================================
|
| 135 |
+
# AGENT 3: TREND ANALYST
|
| 136 |
+
# =============================================================================
|
| 137 |
+
|
| 138 |
+
TREND_ANALYST_SYSTEM = """You are an expert antimicrobial resistance trend analyst. Your role is to:
|
| 139 |
+
|
| 140 |
+
1. Analyze MIC trends over time to detect "MIC Creep"
|
| 141 |
+
2. Calculate resistance velocity and predict treatment failure risk
|
| 142 |
+
3. Compare current MICs against EUCAST/CLSI breakpoints
|
| 143 |
+
4. Identify emerging resistance patterns
|
| 144 |
+
|
| 145 |
+
MIC CREEP DEFINITION:
|
| 146 |
+
MIC creep is a gradual increase in MIC values over time, even while remaining
|
| 147 |
+
technically "Susceptible". This can predict treatment failure before formal
|
| 148 |
+
resistance develops.
|
| 149 |
+
|
| 150 |
+
RISK STRATIFICATION:
|
| 151 |
+
- LOW: Stable MIC, well below breakpoint (>4x margin)
|
| 152 |
+
- MODERATE: Rising trend but still 2-4x below breakpoint
|
| 153 |
+
- HIGH: Approaching breakpoint (<2x margin) or rapid increase
|
| 154 |
+
- CRITICAL: At or above breakpoint, or >4-fold increase over baseline
|
| 155 |
+
|
| 156 |
+
OUTPUT FORMAT:
|
| 157 |
+
Provide a structured JSON response:
|
| 158 |
+
{
|
| 159 |
+
"organism": "Pathogen name",
|
| 160 |
+
"antibiotic": "Antibiotic name",
|
| 161 |
+
"mic_history": [
|
| 162 |
+
{"date": "YYYY-MM-DD", "mic_value": <number>, "interpretation": "S|I|R"}
|
| 163 |
+
],
|
| 164 |
+
"baseline_mic": <number>,
|
| 165 |
+
"current_mic": <number>,
|
| 166 |
+
"fold_change": <number>,
|
| 167 |
+
"trend": "stable|increasing|decreasing|fluctuating",
|
| 168 |
+
"resistance_velocity": <number per time unit>,
|
| 169 |
+
"breakpoint_susceptible": <number>,
|
| 170 |
+
"breakpoint_resistant": <number>,
|
| 171 |
+
"margin_to_breakpoint": <number>,
|
| 172 |
+
"risk_level": "LOW|MODERATE|HIGH|CRITICAL",
|
| 173 |
+
"predicted_time_to_resistance": "estimate or N/A",
|
| 174 |
+
"recommendation": "Continue current therapy|Consider alternatives|Urgent switch needed",
|
| 175 |
+
"alternative_antibiotics": ["list", "if", "applicable"],
|
| 176 |
+
"rationale": "Detailed explanation of risk assessment"
|
| 177 |
+
}
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
TREND_ANALYST_PROMPT = """Analyze the MIC trend data and assess resistance risk.
|
| 181 |
+
|
| 182 |
+
ORGANISM: {organism}
|
| 183 |
+
ANTIBIOTIC: {antibiotic}
|
| 184 |
+
|
| 185 |
+
HISTORICAL MIC DATA:
|
| 186 |
+
{mic_history}
|
| 187 |
+
|
| 188 |
+
CURRENT BREAKPOINTS (EUCAST v16.0):
|
| 189 |
+
{breakpoint_data}
|
| 190 |
+
|
| 191 |
+
REGIONAL RESISTANCE DATA:
|
| 192 |
+
{resistance_context}
|
| 193 |
+
|
| 194 |
+
Analyze the trend, calculate risk level, and provide recommendations.
|
| 195 |
+
Follow the system instructions for output format."""
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# =============================================================================
|
| 199 |
+
# AGENT 4: CLINICAL PHARMACOLOGIST
|
| 200 |
+
# =============================================================================
|
| 201 |
+
|
| 202 |
+
CLINICAL_PHARMACOLOGIST_SYSTEM = """You are an expert clinical pharmacologist specializing in infectious diseases and antimicrobial stewardship. Your role is to:
|
| 203 |
+
|
| 204 |
+
1. Synthesize all available clinical data into a final antibiotic recommendation
|
| 205 |
+
2. Apply WHO AWaRe classification principles (ACCESS -> WATCH -> RESERVE)
|
| 206 |
+
3. Perform comprehensive drug safety checks
|
| 207 |
+
4. Adjust dosing for renal function
|
| 208 |
+
5. Consider local resistance patterns and guideline recommendations
|
| 209 |
+
|
| 210 |
+
PRESCRIBING PRINCIPLES:
|
| 211 |
+
1. Start narrow, escalate only when justified
|
| 212 |
+
2. De-escalate when culture results allow
|
| 213 |
+
3. Prefer ACCESS category antibiotics when appropriate
|
| 214 |
+
4. Consider pharmacokinetic/pharmacodynamic (PK/PD) optimization
|
| 215 |
+
5. Document rationale for WATCH/RESERVE antibiotic use
|
| 216 |
+
|
| 217 |
+
SAFETY CHECKS:
|
| 218 |
+
- Drug-drug interactions (especially warfarin, methotrexate, immunosuppressants)
|
| 219 |
+
- Drug-allergy cross-reactivity (especially beta-lactam allergies)
|
| 220 |
+
- Renal dose adjustments (use CrCl)
|
| 221 |
+
- QT prolongation risk (fluoroquinolones, azithromycin)
|
| 222 |
+
- Pregnancy/lactation considerations
|
| 223 |
+
- Age-related considerations (pediatric/geriatric)
|
| 224 |
+
|
| 225 |
+
OUTPUT FORMAT:
|
| 226 |
+
Provide a structured JSON response:
|
| 227 |
+
{
|
| 228 |
+
"primary_recommendation": {
|
| 229 |
+
"antibiotic": "Drug name",
|
| 230 |
+
"dose": "Amount and unit",
|
| 231 |
+
"route": "IV|PO|IM",
|
| 232 |
+
"frequency": "Dosing interval",
|
| 233 |
+
"duration": "Expected treatment duration",
|
| 234 |
+
"aware_category": "ACCESS|WATCH|RESERVE"
|
| 235 |
+
},
|
| 236 |
+
"alternative_recommendation": {
|
| 237 |
+
"antibiotic": "Alternative drug",
|
| 238 |
+
"dose": "Amount and unit",
|
| 239 |
+
"route": "IV|PO|IM",
|
| 240 |
+
"frequency": "Dosing interval",
|
| 241 |
+
"indication": "When to use alternative"
|
| 242 |
+
},
|
| 243 |
+
"dose_adjustments": {
|
| 244 |
+
"renal": "Adjustment details or 'None needed'",
|
| 245 |
+
"hepatic": "Adjustment details or 'None needed'"
|
| 246 |
+
},
|
| 247 |
+
"safety_alerts": [
|
| 248 |
+
{
|
| 249 |
+
"level": "INFO|WARNING|CRITICAL",
|
| 250 |
+
"type": "interaction|allergy|contraindication|monitoring",
|
| 251 |
+
"message": "Detailed alert message",
|
| 252 |
+
"action_required": "What to do"
|
| 253 |
+
}
|
| 254 |
+
],
|
| 255 |
+
"monitoring_parameters": ["List of labs/vitals to monitor"],
|
| 256 |
+
"de_escalation_plan": "When and how to de-escalate",
|
| 257 |
+
"rationale": "Clinical reasoning for recommendation",
|
| 258 |
+
"guideline_references": ["Supporting guideline citations"],
|
| 259 |
+
"confidence_level": "high|moderate|low",
|
| 260 |
+
"requires_id_consult": <boolean>
|
| 261 |
+
}
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
CLINICAL_PHARMACOLOGIST_PROMPT = """Synthesize all clinical data and provide a final antibiotic recommendation.
|
| 265 |
+
|
| 266 |
+
PATIENT SUMMARY (from Intake Historian):
|
| 267 |
+
{intake_summary}
|
| 268 |
+
|
| 269 |
+
LAB RESULTS (from Vision Specialist):
|
| 270 |
+
{lab_results}
|
| 271 |
+
|
| 272 |
+
MIC TREND ANALYSIS (from Trend Analyst):
|
| 273 |
+
{trend_analysis}
|
| 274 |
+
|
| 275 |
+
PATIENT PARAMETERS:
|
| 276 |
+
- Age: {age} years
|
| 277 |
+
- Weight: {weight} kg
|
| 278 |
+
- CrCl: {crcl} mL/min
|
| 279 |
+
- Allergies: {allergies}
|
| 280 |
+
- Current medications: {current_medications}
|
| 281 |
+
|
| 282 |
+
INFECTION CONTEXT:
|
| 283 |
+
- Site: {infection_site}
|
| 284 |
+
- Source: {suspected_source}
|
| 285 |
+
- Severity: {severity}
|
| 286 |
+
|
| 287 |
+
RAG CONTEXT (Guidelines & Safety Data):
|
| 288 |
+
{rag_context}
|
| 289 |
+
|
| 290 |
+
Provide your final recommendation following the system instructions.
|
| 291 |
+
Ensure all safety checks are performed and documented."""
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# =============================================================================
|
| 295 |
+
# TXGEMMA SAFETY CHECKER (Supplementary)
|
| 296 |
+
# =============================================================================
|
| 297 |
+
|
| 298 |
+
TXGEMMA_SAFETY_PROMPT = """Evaluate the safety profile of the following antibiotic prescription:
|
| 299 |
+
|
| 300 |
+
PROPOSED ANTIBIOTIC: {antibiotic}
|
| 301 |
+
DOSE: {dose}
|
| 302 |
+
ROUTE: {route}
|
| 303 |
+
DURATION: {duration}
|
| 304 |
+
|
| 305 |
+
PATIENT CONTEXT:
|
| 306 |
+
- Age: {age}
|
| 307 |
+
- Renal function (CrCl): {crcl} mL/min
|
| 308 |
+
- Current medications: {medications}
|
| 309 |
+
|
| 310 |
+
Evaluate for:
|
| 311 |
+
1. Known toxicity concerns
|
| 312 |
+
2. Drug-drug interaction potential
|
| 313 |
+
3. Dose appropriateness for renal function
|
| 314 |
+
|
| 315 |
+
Provide a brief safety assessment (2-3 sentences) and a risk rating (LOW/MODERATE/HIGH)."""
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
# =============================================================================
|
| 319 |
+
# HELPER TEMPLATES
|
| 320 |
+
# =============================================================================
|
| 321 |
+
|
| 322 |
+
ERROR_RECOVERY_PROMPT = """The previous agent encountered an error or produced invalid output.
|
| 323 |
+
|
| 324 |
+
ERROR DETAILS:
|
| 325 |
+
{error_details}
|
| 326 |
+
|
| 327 |
+
ORIGINAL INPUT:
|
| 328 |
+
{original_input}
|
| 329 |
+
|
| 330 |
+
Please attempt to recover by providing a valid response or indicating what additional information is needed."""
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
FALLBACK_EMPIRICAL_PROMPT = """No culture data is available. Based on the clinical presentation, provide empirical antibiotic recommendations.
|
| 334 |
+
|
| 335 |
+
CLINICAL SCENARIO:
|
| 336 |
+
- Infection site: {infection_site}
|
| 337 |
+
- Patient risk factors: {risk_factors}
|
| 338 |
+
- Local resistance patterns: {local_resistance}
|
| 339 |
+
|
| 340 |
+
Recommend appropriate empirical therapy following WHO AWaRe principles."""
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
__all__ = [
|
| 344 |
+
"INTAKE_HISTORIAN_SYSTEM",
|
| 345 |
+
"INTAKE_HISTORIAN_PROMPT",
|
| 346 |
+
"VISION_SPECIALIST_SYSTEM",
|
| 347 |
+
"VISION_SPECIALIST_PROMPT",
|
| 348 |
+
"TREND_ANALYST_SYSTEM",
|
| 349 |
+
"TREND_ANALYST_PROMPT",
|
| 350 |
+
"CLINICAL_PHARMACOLOGIST_SYSTEM",
|
| 351 |
+
"CLINICAL_PHARMACOLOGIST_PROMPT",
|
| 352 |
+
"TXGEMMA_SAFETY_PROMPT",
|
| 353 |
+
"ERROR_RECOVERY_PROMPT",
|
| 354 |
+
"FALLBACK_EMPIRICAL_PROMPT",
|
| 355 |
+
]
|
src/rag.py
CHANGED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RAG (Retrieval Augmented Generation) module for Med-I-C.
|
| 3 |
+
|
| 4 |
+
Provides unified retrieval across multiple knowledge collections:
|
| 5 |
+
- antibiotic_guidelines: WHO/IDSA treatment guidelines
|
| 6 |
+
- mic_breakpoints: EUCAST/CLSI breakpoint tables
|
| 7 |
+
- drug_safety: Drug interactions, warnings, contraindications
|
| 8 |
+
- pathogen_resistance: Regional resistance patterns
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Any, Dict, List, Optional
|
| 16 |
+
|
| 17 |
+
from .config import get_settings
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# =============================================================================
|
| 23 |
+
# CHROMA CLIENT & EMBEDDING SETUP
|
| 24 |
+
# =============================================================================
|
| 25 |
+
|
| 26 |
+
_chroma_client = None
|
| 27 |
+
_embedding_function = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_chroma_client():
|
| 31 |
+
"""Get or create ChromaDB persistent client."""
|
| 32 |
+
global _chroma_client
|
| 33 |
+
if _chroma_client is None:
|
| 34 |
+
import chromadb
|
| 35 |
+
|
| 36 |
+
settings = get_settings()
|
| 37 |
+
chroma_path = settings.chroma_db_dir
|
| 38 |
+
chroma_path.mkdir(parents=True, exist_ok=True)
|
| 39 |
+
_chroma_client = chromadb.PersistentClient(path=str(chroma_path))
|
| 40 |
+
return _chroma_client
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_embedding_function():
|
| 44 |
+
"""Get or create the embedding function."""
|
| 45 |
+
global _embedding_function
|
| 46 |
+
if _embedding_function is None:
|
| 47 |
+
from chromadb.utils import embedding_functions
|
| 48 |
+
|
| 49 |
+
settings = get_settings()
|
| 50 |
+
_embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
|
| 51 |
+
model_name=settings.embedding_model_name.split("/")[-1]
|
| 52 |
+
)
|
| 53 |
+
return _embedding_function
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_collection(name: str):
|
| 57 |
+
"""
|
| 58 |
+
Get a ChromaDB collection by name.
|
| 59 |
+
|
| 60 |
+
Returns None if collection doesn't exist.
|
| 61 |
+
"""
|
| 62 |
+
client = get_chroma_client()
|
| 63 |
+
ef = get_embedding_function()
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
return client.get_collection(name=name, embedding_function=ef)
|
| 67 |
+
except Exception:
|
| 68 |
+
logger.warning(f"Collection '{name}' not found")
|
| 69 |
+
return None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# =============================================================================
|
| 73 |
+
# COLLECTION-SPECIFIC RETRIEVERS
|
| 74 |
+
# =============================================================================
|
| 75 |
+
|
| 76 |
+
def search_antibiotic_guidelines(
|
| 77 |
+
query: str,
|
| 78 |
+
n_results: int = 5,
|
| 79 |
+
pathogen_filter: Optional[str] = None,
|
| 80 |
+
) -> List[Dict[str, Any]]:
|
| 81 |
+
"""
|
| 82 |
+
Search antibiotic treatment guidelines.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
query: Search query
|
| 86 |
+
n_results: Number of results to return
|
| 87 |
+
pathogen_filter: Optional pathogen type filter (e.g., "ESBL-E", "CRE")
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
List of relevant guideline excerpts with metadata
|
| 91 |
+
"""
|
| 92 |
+
collection = get_collection("idsa_treatment_guidelines")
|
| 93 |
+
if collection is None:
|
| 94 |
+
logger.warning("idsa_treatment_guidelines collection not available")
|
| 95 |
+
return []
|
| 96 |
+
|
| 97 |
+
where_filter = None
|
| 98 |
+
if pathogen_filter:
|
| 99 |
+
where_filter = {"pathogen_type": pathogen_filter}
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
results = collection.query(
|
| 103 |
+
query_texts=[query],
|
| 104 |
+
n_results=n_results,
|
| 105 |
+
where=where_filter,
|
| 106 |
+
include=["documents", "metadatas", "distances"],
|
| 107 |
+
)
|
| 108 |
+
except Exception as e:
|
| 109 |
+
logger.error(f"Error querying guidelines: {e}")
|
| 110 |
+
return []
|
| 111 |
+
|
| 112 |
+
return _format_results(results)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def search_mic_breakpoints(
|
| 116 |
+
query: str,
|
| 117 |
+
n_results: int = 5,
|
| 118 |
+
organism: Optional[str] = None,
|
| 119 |
+
antibiotic: Optional[str] = None,
|
| 120 |
+
) -> List[Dict[str, Any]]:
|
| 121 |
+
"""
|
| 122 |
+
Search MIC breakpoint reference documentation.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
query: Search query
|
| 126 |
+
n_results: Number of results
|
| 127 |
+
organism: Optional organism name filter
|
| 128 |
+
antibiotic: Optional antibiotic name filter
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
List of relevant breakpoint information
|
| 132 |
+
"""
|
| 133 |
+
collection = get_collection("mic_reference_docs")
|
| 134 |
+
if collection is None:
|
| 135 |
+
logger.warning("mic_reference_docs collection not available")
|
| 136 |
+
return []
|
| 137 |
+
|
| 138 |
+
# Build query with organism/antibiotic context if provided
|
| 139 |
+
enhanced_query = query
|
| 140 |
+
if organism:
|
| 141 |
+
enhanced_query = f"{organism} {enhanced_query}"
|
| 142 |
+
if antibiotic:
|
| 143 |
+
enhanced_query = f"{antibiotic} {enhanced_query}"
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
results = collection.query(
|
| 147 |
+
query_texts=[enhanced_query],
|
| 148 |
+
n_results=n_results,
|
| 149 |
+
include=["documents", "metadatas", "distances"],
|
| 150 |
+
)
|
| 151 |
+
except Exception as e:
|
| 152 |
+
logger.error(f"Error querying breakpoints: {e}")
|
| 153 |
+
return []
|
| 154 |
+
|
| 155 |
+
return _format_results(results)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def search_drug_safety(
|
| 159 |
+
query: str,
|
| 160 |
+
n_results: int = 5,
|
| 161 |
+
drug_name: Optional[str] = None,
|
| 162 |
+
) -> List[Dict[str, Any]]:
|
| 163 |
+
"""
|
| 164 |
+
Search drug safety information (interactions, warnings, contraindications).
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
query: Search query
|
| 168 |
+
n_results: Number of results
|
| 169 |
+
drug_name: Optional drug name to focus search
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
List of relevant safety information
|
| 173 |
+
"""
|
| 174 |
+
collection = get_collection("drug_safety")
|
| 175 |
+
if collection is None:
|
| 176 |
+
# Fallback: try existing collections
|
| 177 |
+
logger.warning("drug_safety collection not available")
|
| 178 |
+
return []
|
| 179 |
+
|
| 180 |
+
enhanced_query = f"{drug_name} {query}" if drug_name else query
|
| 181 |
+
|
| 182 |
+
try:
|
| 183 |
+
results = collection.query(
|
| 184 |
+
query_texts=[enhanced_query],
|
| 185 |
+
n_results=n_results,
|
| 186 |
+
include=["documents", "metadatas", "distances"],
|
| 187 |
+
)
|
| 188 |
+
except Exception as e:
|
| 189 |
+
logger.error(f"Error querying drug safety: {e}")
|
| 190 |
+
return []
|
| 191 |
+
|
| 192 |
+
return _format_results(results)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def search_resistance_patterns(
|
| 196 |
+
query: str,
|
| 197 |
+
n_results: int = 5,
|
| 198 |
+
organism: Optional[str] = None,
|
| 199 |
+
region: Optional[str] = None,
|
| 200 |
+
) -> List[Dict[str, Any]]:
|
| 201 |
+
"""
|
| 202 |
+
Search pathogen resistance pattern data.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
query: Search query
|
| 206 |
+
n_results: Number of results
|
| 207 |
+
organism: Optional organism filter
|
| 208 |
+
region: Optional geographic region filter
|
| 209 |
+
|
| 210 |
+
Returns:
|
| 211 |
+
List of relevant resistance data
|
| 212 |
+
"""
|
| 213 |
+
collection = get_collection("pathogen_resistance")
|
| 214 |
+
if collection is None:
|
| 215 |
+
logger.warning("pathogen_resistance collection not available")
|
| 216 |
+
return []
|
| 217 |
+
|
| 218 |
+
enhanced_query = query
|
| 219 |
+
if organism:
|
| 220 |
+
enhanced_query = f"{organism} {enhanced_query}"
|
| 221 |
+
if region:
|
| 222 |
+
enhanced_query = f"{region} {enhanced_query}"
|
| 223 |
+
|
| 224 |
+
try:
|
| 225 |
+
results = collection.query(
|
| 226 |
+
query_texts=[enhanced_query],
|
| 227 |
+
n_results=n_results,
|
| 228 |
+
include=["documents", "metadatas", "distances"],
|
| 229 |
+
)
|
| 230 |
+
except Exception as e:
|
| 231 |
+
logger.error(f"Error querying resistance patterns: {e}")
|
| 232 |
+
return []
|
| 233 |
+
|
| 234 |
+
return _format_results(results)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# =============================================================================
|
| 238 |
+
# UNIFIED CONTEXT RETRIEVER
|
| 239 |
+
# =============================================================================
|
| 240 |
+
|
| 241 |
+
def get_context_for_agent(
|
| 242 |
+
agent_name: str,
|
| 243 |
+
query: str,
|
| 244 |
+
patient_context: Optional[Dict[str, Any]] = None,
|
| 245 |
+
n_results: int = 3,
|
| 246 |
+
) -> str:
|
| 247 |
+
"""
|
| 248 |
+
Get formatted RAG context string for a specific agent.
|
| 249 |
+
|
| 250 |
+
This is the main entry point for agents to retrieve context.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
agent_name: Name of the requesting agent
|
| 254 |
+
query: The primary search query
|
| 255 |
+
patient_context: Optional dict with patient-specific info
|
| 256 |
+
n_results: Number of results per collection
|
| 257 |
+
|
| 258 |
+
Returns:
|
| 259 |
+
Formatted context string for injection into prompts
|
| 260 |
+
"""
|
| 261 |
+
context_parts = []
|
| 262 |
+
patient_context = patient_context or {}
|
| 263 |
+
|
| 264 |
+
if agent_name == "intake_historian":
|
| 265 |
+
# Get empirical therapy guidelines
|
| 266 |
+
guidelines = search_antibiotic_guidelines(
|
| 267 |
+
query=query,
|
| 268 |
+
n_results=n_results,
|
| 269 |
+
pathogen_filter=patient_context.get("pathogen_type"),
|
| 270 |
+
)
|
| 271 |
+
if guidelines:
|
| 272 |
+
context_parts.append("RELEVANT TREATMENT GUIDELINES:")
|
| 273 |
+
for g in guidelines:
|
| 274 |
+
context_parts.append(f"- {g['content'][:500]}...")
|
| 275 |
+
context_parts.append(f" [Source: {g.get('source', 'IDSA Guidelines')}]")
|
| 276 |
+
|
| 277 |
+
elif agent_name == "vision_specialist":
|
| 278 |
+
# Get MIC reference info for lab interpretation
|
| 279 |
+
breakpoints = search_mic_breakpoints(
|
| 280 |
+
query=query,
|
| 281 |
+
n_results=n_results,
|
| 282 |
+
organism=patient_context.get("organism"),
|
| 283 |
+
antibiotic=patient_context.get("antibiotic"),
|
| 284 |
+
)
|
| 285 |
+
if breakpoints:
|
| 286 |
+
context_parts.append("RELEVANT BREAKPOINT INFORMATION:")
|
| 287 |
+
for b in breakpoints:
|
| 288 |
+
context_parts.append(f"- {b['content'][:400]}...")
|
| 289 |
+
|
| 290 |
+
elif agent_name == "trend_analyst":
|
| 291 |
+
# Get breakpoints and resistance trends
|
| 292 |
+
breakpoints = search_mic_breakpoints(
|
| 293 |
+
query=f"breakpoint {patient_context.get('organism', '')} {patient_context.get('antibiotic', '')}",
|
| 294 |
+
n_results=n_results,
|
| 295 |
+
)
|
| 296 |
+
resistance = search_resistance_patterns(
|
| 297 |
+
query=query,
|
| 298 |
+
n_results=n_results,
|
| 299 |
+
organism=patient_context.get("organism"),
|
| 300 |
+
region=patient_context.get("region"),
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
if breakpoints:
|
| 304 |
+
context_parts.append("EUCAST BREAKPOINT DATA:")
|
| 305 |
+
for b in breakpoints:
|
| 306 |
+
context_parts.append(f"- {b['content'][:400]}...")
|
| 307 |
+
|
| 308 |
+
if resistance:
|
| 309 |
+
context_parts.append("\nRESISTANCE PATTERN DATA:")
|
| 310 |
+
for r in resistance:
|
| 311 |
+
context_parts.append(f"- {r['content'][:400]}...")
|
| 312 |
+
|
| 313 |
+
elif agent_name == "clinical_pharmacologist":
|
| 314 |
+
# Get comprehensive context for final recommendation
|
| 315 |
+
guidelines = search_antibiotic_guidelines(
|
| 316 |
+
query=query,
|
| 317 |
+
n_results=n_results,
|
| 318 |
+
)
|
| 319 |
+
safety = search_drug_safety(
|
| 320 |
+
query=query,
|
| 321 |
+
n_results=n_results,
|
| 322 |
+
drug_name=patient_context.get("proposed_antibiotic"),
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
if guidelines:
|
| 326 |
+
context_parts.append("TREATMENT GUIDELINES:")
|
| 327 |
+
for g in guidelines:
|
| 328 |
+
context_parts.append(f"- {g['content'][:400]}...")
|
| 329 |
+
|
| 330 |
+
if safety:
|
| 331 |
+
context_parts.append("\nDRUG SAFETY INFORMATION:")
|
| 332 |
+
for s in safety:
|
| 333 |
+
context_parts.append(f"- {s['content'][:400]}...")
|
| 334 |
+
|
| 335 |
+
else:
|
| 336 |
+
# Generic retrieval
|
| 337 |
+
guidelines = search_antibiotic_guidelines(query, n_results=n_results)
|
| 338 |
+
if guidelines:
|
| 339 |
+
for g in guidelines:
|
| 340 |
+
context_parts.append(f"- {g['content'][:500]}...")
|
| 341 |
+
|
| 342 |
+
if not context_parts:
|
| 343 |
+
return "No relevant context found in knowledge base."
|
| 344 |
+
|
| 345 |
+
return "\n".join(context_parts)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def get_context_string(
|
| 349 |
+
query: str,
|
| 350 |
+
collections: Optional[List[str]] = None,
|
| 351 |
+
n_results_per_collection: int = 3,
|
| 352 |
+
**filters,
|
| 353 |
+
) -> str:
|
| 354 |
+
"""
|
| 355 |
+
Get a combined context string from multiple collections.
|
| 356 |
+
|
| 357 |
+
This is a simpler interface for general-purpose RAG retrieval.
|
| 358 |
+
|
| 359 |
+
Args:
|
| 360 |
+
query: Search query
|
| 361 |
+
collections: List of collection names to search (defaults to all)
|
| 362 |
+
n_results_per_collection: Results per collection
|
| 363 |
+
**filters: Additional filters (organism, antibiotic, region, etc.)
|
| 364 |
+
|
| 365 |
+
Returns:
|
| 366 |
+
Combined context string
|
| 367 |
+
"""
|
| 368 |
+
default_collections = [
|
| 369 |
+
"idsa_treatment_guidelines",
|
| 370 |
+
"mic_reference_docs",
|
| 371 |
+
]
|
| 372 |
+
collections = collections or default_collections
|
| 373 |
+
|
| 374 |
+
context_parts = []
|
| 375 |
+
|
| 376 |
+
for collection_name in collections:
|
| 377 |
+
if collection_name == "idsa_treatment_guidelines":
|
| 378 |
+
results = search_antibiotic_guidelines(
|
| 379 |
+
query,
|
| 380 |
+
n_results=n_results_per_collection,
|
| 381 |
+
pathogen_filter=filters.get("pathogen_type"),
|
| 382 |
+
)
|
| 383 |
+
elif collection_name == "mic_reference_docs":
|
| 384 |
+
results = search_mic_breakpoints(
|
| 385 |
+
query,
|
| 386 |
+
n_results=n_results_per_collection,
|
| 387 |
+
organism=filters.get("organism"),
|
| 388 |
+
antibiotic=filters.get("antibiotic"),
|
| 389 |
+
)
|
| 390 |
+
elif collection_name == "drug_safety":
|
| 391 |
+
results = search_drug_safety(
|
| 392 |
+
query,
|
| 393 |
+
n_results=n_results_per_collection,
|
| 394 |
+
drug_name=filters.get("drug_name"),
|
| 395 |
+
)
|
| 396 |
+
elif collection_name == "pathogen_resistance":
|
| 397 |
+
results = search_resistance_patterns(
|
| 398 |
+
query,
|
| 399 |
+
n_results=n_results_per_collection,
|
| 400 |
+
organism=filters.get("organism"),
|
| 401 |
+
region=filters.get("region"),
|
| 402 |
+
)
|
| 403 |
+
else:
|
| 404 |
+
continue
|
| 405 |
+
|
| 406 |
+
if results:
|
| 407 |
+
context_parts.append(f"=== {collection_name.upper()} ===")
|
| 408 |
+
for r in results:
|
| 409 |
+
context_parts.append(r["content"])
|
| 410 |
+
context_parts.append(f"[Relevance: {1 - r.get('distance', 0):.2f}]")
|
| 411 |
+
context_parts.append("")
|
| 412 |
+
|
| 413 |
+
return "\n".join(context_parts) if context_parts else "No relevant context found."
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
# =============================================================================
|
| 417 |
+
# HELPER FUNCTIONS
|
| 418 |
+
# =============================================================================
|
| 419 |
+
|
| 420 |
+
def _format_results(results: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 421 |
+
"""Format ChromaDB query results into a standard format."""
|
| 422 |
+
if not results or not results.get("documents"):
|
| 423 |
+
return []
|
| 424 |
+
|
| 425 |
+
formatted = []
|
| 426 |
+
documents = results["documents"][0] if results["documents"] else []
|
| 427 |
+
metadatas = results.get("metadatas", [[]])[0]
|
| 428 |
+
distances = results.get("distances", [[]])[0]
|
| 429 |
+
|
| 430 |
+
for i, doc in enumerate(documents):
|
| 431 |
+
formatted.append({
|
| 432 |
+
"content": doc,
|
| 433 |
+
"metadata": metadatas[i] if i < len(metadatas) else {},
|
| 434 |
+
"distance": distances[i] if i < len(distances) else None,
|
| 435 |
+
"source": metadatas[i].get("source", "Unknown") if i < len(metadatas) else "Unknown",
|
| 436 |
+
"relevance_score": 1 - (distances[i] if i < len(distances) else 0),
|
| 437 |
+
})
|
| 438 |
+
|
| 439 |
+
return formatted
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def list_available_collections() -> List[str]:
|
| 443 |
+
"""List all available ChromaDB collections."""
|
| 444 |
+
client = get_chroma_client()
|
| 445 |
+
try:
|
| 446 |
+
collections = client.list_collections()
|
| 447 |
+
return [c.name for c in collections]
|
| 448 |
+
except Exception as e:
|
| 449 |
+
logger.error(f"Error listing collections: {e}")
|
| 450 |
+
return []
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def get_collection_info(name: str) -> Optional[Dict[str, Any]]:
|
| 454 |
+
"""Get information about a specific collection."""
|
| 455 |
+
collection = get_collection(name)
|
| 456 |
+
if collection is None:
|
| 457 |
+
return None
|
| 458 |
+
|
| 459 |
+
try:
|
| 460 |
+
return {
|
| 461 |
+
"name": collection.name,
|
| 462 |
+
"count": collection.count(),
|
| 463 |
+
"metadata": collection.metadata,
|
| 464 |
+
}
|
| 465 |
+
except Exception as e:
|
| 466 |
+
logger.error(f"Error getting collection info: {e}")
|
| 467 |
+
return None
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
__all__ = [
|
| 471 |
+
"get_chroma_client",
|
| 472 |
+
"get_embedding_function",
|
| 473 |
+
"get_collection",
|
| 474 |
+
"search_antibiotic_guidelines",
|
| 475 |
+
"search_mic_breakpoints",
|
| 476 |
+
"search_drug_safety",
|
| 477 |
+
"search_resistance_patterns",
|
| 478 |
+
"get_context_for_agent",
|
| 479 |
+
"get_context_string",
|
| 480 |
+
"list_available_collections",
|
| 481 |
+
"get_collection_info",
|
| 482 |
+
]
|
src/utils.py
CHANGED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for Med-I-C multi-agent system.
|
| 3 |
+
|
| 4 |
+
Includes:
|
| 5 |
+
- Creatinine Clearance (CrCl) calculator
|
| 6 |
+
- MIC trend analysis and creep detection
|
| 7 |
+
- Prescription card formatter
|
| 8 |
+
- Data validation helpers
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import math
|
| 15 |
+
from typing import Any, Dict, List, Literal, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# =============================================================================
|
| 19 |
+
# CREATININE CLEARANCE CALCULATOR
|
| 20 |
+
# =============================================================================
|
| 21 |
+
|
| 22 |
+
def calculate_crcl(
|
| 23 |
+
age_years: float,
|
| 24 |
+
weight_kg: float,
|
| 25 |
+
serum_creatinine_mg_dl: float,
|
| 26 |
+
sex: Literal["male", "female"],
|
| 27 |
+
use_ibw: bool = False,
|
| 28 |
+
height_cm: Optional[float] = None,
|
| 29 |
+
) -> float:
|
| 30 |
+
"""
|
| 31 |
+
Calculate Creatinine Clearance using the Cockcroft-Gault equation.
|
| 32 |
+
|
| 33 |
+
Formula:
|
| 34 |
+
CrCl = [(140 - age) × weight × (0.85 if female)] / (72 × SCr)
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
age_years: Patient age in years
|
| 38 |
+
weight_kg: Actual body weight in kg
|
| 39 |
+
serum_creatinine_mg_dl: Serum creatinine in mg/dL
|
| 40 |
+
sex: Patient sex ("male" or "female")
|
| 41 |
+
use_ibw: If True, use Ideal Body Weight instead of actual weight
|
| 42 |
+
height_cm: Height in cm (required if use_ibw=True)
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Estimated CrCl in mL/min
|
| 46 |
+
"""
|
| 47 |
+
if serum_creatinine_mg_dl <= 0:
|
| 48 |
+
raise ValueError("Serum creatinine must be positive")
|
| 49 |
+
|
| 50 |
+
if age_years <= 0 or weight_kg <= 0:
|
| 51 |
+
raise ValueError("Age and weight must be positive")
|
| 52 |
+
|
| 53 |
+
# Calculate weight to use
|
| 54 |
+
weight = weight_kg
|
| 55 |
+
if use_ibw and height_cm:
|
| 56 |
+
weight = calculate_ibw(height_cm, sex)
|
| 57 |
+
# Use adjusted body weight if actual weight > IBW
|
| 58 |
+
if weight_kg > weight * 1.3:
|
| 59 |
+
weight = calculate_adjusted_bw(weight, weight_kg)
|
| 60 |
+
|
| 61 |
+
# Cockcroft-Gault equation
|
| 62 |
+
crcl = ((140 - age_years) * weight) / (72 * serum_creatinine_mg_dl)
|
| 63 |
+
|
| 64 |
+
# Apply sex factor
|
| 65 |
+
if sex == "female":
|
| 66 |
+
crcl *= 0.85
|
| 67 |
+
|
| 68 |
+
return round(crcl, 1)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def calculate_ibw(height_cm: float, sex: Literal["male", "female"]) -> float:
|
| 72 |
+
"""
|
| 73 |
+
Calculate Ideal Body Weight using the Devine formula.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
height_cm: Height in centimeters
|
| 77 |
+
sex: Patient sex
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
Ideal body weight in kg
|
| 81 |
+
"""
|
| 82 |
+
height_inches = height_cm / 2.54
|
| 83 |
+
height_over_60 = max(0, height_inches - 60)
|
| 84 |
+
|
| 85 |
+
if sex == "male":
|
| 86 |
+
ibw = 50 + 2.3 * height_over_60
|
| 87 |
+
else:
|
| 88 |
+
ibw = 45.5 + 2.3 * height_over_60
|
| 89 |
+
|
| 90 |
+
return round(ibw, 1)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def calculate_adjusted_bw(ibw: float, actual_weight: float) -> float:
|
| 94 |
+
"""
|
| 95 |
+
Calculate Adjusted Body Weight for obese patients.
|
| 96 |
+
|
| 97 |
+
Formula: AdjBW = IBW + 0.4 × (Actual - IBW)
|
| 98 |
+
"""
|
| 99 |
+
return round(ibw + 0.4 * (actual_weight - ibw), 1)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def get_renal_dose_category(crcl: float) -> str:
|
| 103 |
+
"""
|
| 104 |
+
Categorize renal function for dosing purposes.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
Renal function category
|
| 108 |
+
"""
|
| 109 |
+
if crcl >= 90:
|
| 110 |
+
return "normal"
|
| 111 |
+
elif crcl >= 60:
|
| 112 |
+
return "mild_impairment"
|
| 113 |
+
elif crcl >= 30:
|
| 114 |
+
return "moderate_impairment"
|
| 115 |
+
elif crcl >= 15:
|
| 116 |
+
return "severe_impairment"
|
| 117 |
+
else:
|
| 118 |
+
return "esrd"
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# =============================================================================
|
| 122 |
+
# MIC TREND ANALYSIS
|
| 123 |
+
# =============================================================================
|
| 124 |
+
|
| 125 |
+
def calculate_mic_trend(
|
| 126 |
+
mic_values: List[Dict[str, Any]],
|
| 127 |
+
susceptible_breakpoint: Optional[float] = None,
|
| 128 |
+
resistant_breakpoint: Optional[float] = None,
|
| 129 |
+
) -> Dict[str, Any]:
|
| 130 |
+
"""
|
| 131 |
+
Analyze MIC trend over time and detect MIC creep.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
mic_values: List of dicts with 'date' and 'mic_value' keys
|
| 135 |
+
susceptible_breakpoint: S breakpoint (optional)
|
| 136 |
+
resistant_breakpoint: R breakpoint (optional)
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
Dict with trend analysis results
|
| 140 |
+
"""
|
| 141 |
+
if len(mic_values) < 2:
|
| 142 |
+
return {
|
| 143 |
+
"trend": "insufficient_data",
|
| 144 |
+
"risk_level": "UNKNOWN",
|
| 145 |
+
"alert": "Need at least 2 MIC values for trend analysis",
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
# Extract MIC values
|
| 149 |
+
mics = [float(v["mic_value"]) for v in mic_values]
|
| 150 |
+
|
| 151 |
+
baseline_mic = mics[0]
|
| 152 |
+
current_mic = mics[-1]
|
| 153 |
+
|
| 154 |
+
# Calculate fold change
|
| 155 |
+
if baseline_mic > 0:
|
| 156 |
+
fold_change = current_mic / baseline_mic
|
| 157 |
+
else:
|
| 158 |
+
fold_change = float("inf")
|
| 159 |
+
|
| 160 |
+
# Calculate trend
|
| 161 |
+
if len(mics) >= 3:
|
| 162 |
+
# Linear regression slope
|
| 163 |
+
n = len(mics)
|
| 164 |
+
x_mean = (n - 1) / 2
|
| 165 |
+
y_mean = sum(mics) / n
|
| 166 |
+
numerator = sum((i - x_mean) * (mics[i] - y_mean) for i in range(n))
|
| 167 |
+
denominator = sum((i - x_mean) ** 2 for i in range(n))
|
| 168 |
+
slope = numerator / denominator if denominator != 0 else 0
|
| 169 |
+
|
| 170 |
+
if slope > 0.5:
|
| 171 |
+
trend = "increasing"
|
| 172 |
+
elif slope < -0.5:
|
| 173 |
+
trend = "decreasing"
|
| 174 |
+
else:
|
| 175 |
+
trend = "stable"
|
| 176 |
+
else:
|
| 177 |
+
if current_mic > baseline_mic * 1.5:
|
| 178 |
+
trend = "increasing"
|
| 179 |
+
elif current_mic < baseline_mic * 0.67:
|
| 180 |
+
trend = "decreasing"
|
| 181 |
+
else:
|
| 182 |
+
trend = "stable"
|
| 183 |
+
|
| 184 |
+
# Calculate resistance velocity (fold change per time point)
|
| 185 |
+
velocity = fold_change ** (1 / (len(mics) - 1)) if len(mics) > 1 else 1.0
|
| 186 |
+
|
| 187 |
+
# Determine risk level
|
| 188 |
+
risk_level, alert = _assess_mic_risk(
|
| 189 |
+
current_mic, baseline_mic, fold_change, trend,
|
| 190 |
+
susceptible_breakpoint, resistant_breakpoint
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
return {
|
| 194 |
+
"baseline_mic": baseline_mic,
|
| 195 |
+
"current_mic": current_mic,
|
| 196 |
+
"ratio": round(fold_change, 2),
|
| 197 |
+
"trend": trend,
|
| 198 |
+
"velocity": round(velocity, 3),
|
| 199 |
+
"risk_level": risk_level,
|
| 200 |
+
"alert": alert,
|
| 201 |
+
"n_readings": len(mics),
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _assess_mic_risk(
|
| 206 |
+
current_mic: float,
|
| 207 |
+
baseline_mic: float,
|
| 208 |
+
fold_change: float,
|
| 209 |
+
trend: str,
|
| 210 |
+
s_breakpoint: Optional[float],
|
| 211 |
+
r_breakpoint: Optional[float],
|
| 212 |
+
) -> Tuple[str, str]:
|
| 213 |
+
"""
|
| 214 |
+
Assess risk level based on MIC trends and breakpoints.
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
Tuple of (risk_level, alert_message)
|
| 218 |
+
"""
|
| 219 |
+
# If we have breakpoints, use them for risk assessment
|
| 220 |
+
if s_breakpoint is not None and r_breakpoint is not None:
|
| 221 |
+
margin = s_breakpoint / current_mic if current_mic > 0 else float("inf")
|
| 222 |
+
|
| 223 |
+
if current_mic > r_breakpoint:
|
| 224 |
+
return "CRITICAL", f"MIC ({current_mic}) exceeds resistant breakpoint ({r_breakpoint}). Organism is RESISTANT."
|
| 225 |
+
|
| 226 |
+
if current_mic > s_breakpoint:
|
| 227 |
+
return "HIGH", f"MIC ({current_mic}) exceeds susceptible breakpoint ({s_breakpoint}). Consider alternative therapy."
|
| 228 |
+
|
| 229 |
+
if margin < 2:
|
| 230 |
+
if trend == "increasing":
|
| 231 |
+
return "HIGH", f"MIC approaching breakpoint (margin: {margin:.1f}x) with increasing trend. High risk of resistance emergence."
|
| 232 |
+
else:
|
| 233 |
+
return "MODERATE", f"MIC close to breakpoint (margin: {margin:.1f}x). Monitor closely."
|
| 234 |
+
|
| 235 |
+
if margin < 4:
|
| 236 |
+
if trend == "increasing":
|
| 237 |
+
return "MODERATE", f"MIC rising with {margin:.1f}x margin to breakpoint. Consider enhanced monitoring."
|
| 238 |
+
else:
|
| 239 |
+
return "LOW", "MIC stable with adequate margin to breakpoint."
|
| 240 |
+
|
| 241 |
+
return "LOW", "MIC well below breakpoint with good safety margin."
|
| 242 |
+
|
| 243 |
+
# Without breakpoints, use fold change and trend
|
| 244 |
+
if fold_change >= 8:
|
| 245 |
+
return "CRITICAL", f"MIC increased {fold_change:.1f}-fold from baseline. Urgent review needed."
|
| 246 |
+
|
| 247 |
+
if fold_change >= 4:
|
| 248 |
+
return "HIGH", f"MIC increased {fold_change:.1f}-fold from baseline. High risk of treatment failure."
|
| 249 |
+
|
| 250 |
+
if fold_change >= 2:
|
| 251 |
+
if trend == "increasing":
|
| 252 |
+
return "MODERATE", f"MIC increased {fold_change:.1f}-fold with rising trend. Enhanced monitoring recommended."
|
| 253 |
+
else:
|
| 254 |
+
return "LOW", f"MIC increased {fold_change:.1f}-fold but trend is {trend}."
|
| 255 |
+
|
| 256 |
+
if trend == "increasing":
|
| 257 |
+
return "MODERATE", "MIC showing upward trend. Continue monitoring."
|
| 258 |
+
|
| 259 |
+
return "LOW", "MIC stable or decreasing. Current therapy appropriate."
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def detect_mic_creep(
|
| 263 |
+
organism: str,
|
| 264 |
+
antibiotic: str,
|
| 265 |
+
mic_history: List[Dict[str, Any]],
|
| 266 |
+
breakpoints: Dict[str, float],
|
| 267 |
+
) -> Dict[str, Any]:
|
| 268 |
+
"""
|
| 269 |
+
Detect MIC creep for a specific organism-antibiotic pair.
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
organism: Pathogen name
|
| 273 |
+
antibiotic: Antibiotic name
|
| 274 |
+
mic_history: Historical MIC values with dates
|
| 275 |
+
breakpoints: Dict with 'susceptible' and 'resistant' keys
|
| 276 |
+
|
| 277 |
+
Returns:
|
| 278 |
+
Comprehensive MIC creep analysis
|
| 279 |
+
"""
|
| 280 |
+
trend_analysis = calculate_mic_trend(
|
| 281 |
+
mic_history,
|
| 282 |
+
susceptible_breakpoint=breakpoints.get("susceptible"),
|
| 283 |
+
resistant_breakpoint=breakpoints.get("resistant"),
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Add organism/antibiotic context
|
| 287 |
+
trend_analysis["organism"] = organism
|
| 288 |
+
trend_analysis["antibiotic"] = antibiotic
|
| 289 |
+
trend_analysis["breakpoint_susceptible"] = breakpoints.get("susceptible")
|
| 290 |
+
trend_analysis["breakpoint_resistant"] = breakpoints.get("resistant")
|
| 291 |
+
|
| 292 |
+
# Calculate time to resistance estimate
|
| 293 |
+
if trend_analysis["trend"] == "increasing" and trend_analysis["velocity"] > 1.0:
|
| 294 |
+
current = trend_analysis["current_mic"]
|
| 295 |
+
s_bp = breakpoints.get("susceptible")
|
| 296 |
+
if s_bp and current < s_bp:
|
| 297 |
+
# Estimate doublings needed to reach breakpoint
|
| 298 |
+
doublings_needed = math.log2(s_bp / current) if current > 0 else 0
|
| 299 |
+
# Estimate time based on velocity
|
| 300 |
+
if trend_analysis["velocity"] > 1.0:
|
| 301 |
+
log_velocity = math.log(trend_analysis["velocity"]) / math.log(2)
|
| 302 |
+
if log_velocity > 0:
|
| 303 |
+
time_estimate = doublings_needed / log_velocity
|
| 304 |
+
trend_analysis["estimated_readings_to_resistance"] = round(time_estimate, 1)
|
| 305 |
+
|
| 306 |
+
return trend_analysis
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# =============================================================================
|
| 310 |
+
# PRESCRIPTION FORMATTER
|
| 311 |
+
# =============================================================================
|
| 312 |
+
|
| 313 |
+
def format_prescription_card(recommendation: Dict[str, Any]) -> str:
|
| 314 |
+
"""
|
| 315 |
+
Format a recommendation into a readable prescription card.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
recommendation: Dict with recommendation details
|
| 319 |
+
|
| 320 |
+
Returns:
|
| 321 |
+
Formatted prescription card as string
|
| 322 |
+
"""
|
| 323 |
+
lines = []
|
| 324 |
+
lines.append("=" * 50)
|
| 325 |
+
lines.append("ANTIBIOTIC PRESCRIPTION")
|
| 326 |
+
lines.append("=" * 50)
|
| 327 |
+
|
| 328 |
+
primary = recommendation.get("primary_recommendation", recommendation)
|
| 329 |
+
|
| 330 |
+
lines.append(f"\nDRUG: {primary.get('antibiotic', 'N/A')}")
|
| 331 |
+
lines.append(f"DOSE: {primary.get('dose', 'N/A')}")
|
| 332 |
+
lines.append(f"ROUTE: {primary.get('route', 'N/A')}")
|
| 333 |
+
lines.append(f"FREQUENCY: {primary.get('frequency', 'N/A')}")
|
| 334 |
+
lines.append(f"DURATION: {primary.get('duration', 'N/A')}")
|
| 335 |
+
|
| 336 |
+
if primary.get("aware_category"):
|
| 337 |
+
lines.append(f"WHO AWaRe: {primary.get('aware_category')}")
|
| 338 |
+
|
| 339 |
+
# Dose adjustments
|
| 340 |
+
adjustments = recommendation.get("dose_adjustments", {})
|
| 341 |
+
if adjustments.get("renal") and adjustments["renal"] != "None needed":
|
| 342 |
+
lines.append(f"\nRENAL ADJUSTMENT: {adjustments['renal']}")
|
| 343 |
+
if adjustments.get("hepatic") and adjustments["hepatic"] != "None needed":
|
| 344 |
+
lines.append(f"HEPATIC ADJUSTMENT: {adjustments['hepatic']}")
|
| 345 |
+
|
| 346 |
+
# Safety alerts
|
| 347 |
+
alerts = recommendation.get("safety_alerts", [])
|
| 348 |
+
if alerts:
|
| 349 |
+
lines.append("\n" + "-" * 50)
|
| 350 |
+
lines.append("SAFETY ALERTS:")
|
| 351 |
+
for alert in alerts:
|
| 352 |
+
level = alert.get("level", "INFO")
|
| 353 |
+
marker = {"CRITICAL": "[!!!]", "WARNING": "[!!]", "INFO": "[i]"}.get(level, "[?]")
|
| 354 |
+
lines.append(f" {marker} {alert.get('message', '')}")
|
| 355 |
+
|
| 356 |
+
# Monitoring
|
| 357 |
+
monitoring = recommendation.get("monitoring_parameters", [])
|
| 358 |
+
if monitoring:
|
| 359 |
+
lines.append("\n" + "-" * 50)
|
| 360 |
+
lines.append("MONITORING:")
|
| 361 |
+
for param in monitoring:
|
| 362 |
+
lines.append(f" - {param}")
|
| 363 |
+
|
| 364 |
+
# Rationale
|
| 365 |
+
if recommendation.get("rationale"):
|
| 366 |
+
lines.append("\n" + "-" * 50)
|
| 367 |
+
lines.append("RATIONALE:")
|
| 368 |
+
lines.append(f" {recommendation['rationale']}")
|
| 369 |
+
|
| 370 |
+
lines.append("\n" + "=" * 50)
|
| 371 |
+
|
| 372 |
+
return "\n".join(lines)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
# =============================================================================
|
| 376 |
+
# JSON PARSING HELPERS
|
| 377 |
+
# =============================================================================
|
| 378 |
+
|
| 379 |
+
def safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
|
| 380 |
+
"""
|
| 381 |
+
Safely parse JSON from agent output, handling common issues.
|
| 382 |
+
|
| 383 |
+
Attempts to extract JSON from text that may contain markdown code blocks
|
| 384 |
+
or other formatting.
|
| 385 |
+
"""
|
| 386 |
+
if not text:
|
| 387 |
+
return None
|
| 388 |
+
|
| 389 |
+
# Try direct parse first
|
| 390 |
+
try:
|
| 391 |
+
return json.loads(text)
|
| 392 |
+
except json.JSONDecodeError:
|
| 393 |
+
pass
|
| 394 |
+
|
| 395 |
+
# Try to extract JSON from markdown code block
|
| 396 |
+
import re
|
| 397 |
+
|
| 398 |
+
json_patterns = [
|
| 399 |
+
r"```json\s*\n?(.*?)\n?```", # ```json ... ```
|
| 400 |
+
r"```\s*\n?(.*?)\n?```", # ``` ... ```
|
| 401 |
+
r"\{[\s\S]*\}", # Raw JSON object
|
| 402 |
+
]
|
| 403 |
+
|
| 404 |
+
for pattern in json_patterns:
|
| 405 |
+
match = re.search(pattern, text, re.DOTALL)
|
| 406 |
+
if match:
|
| 407 |
+
try:
|
| 408 |
+
json_str = match.group(1) if match.lastindex else match.group(0)
|
| 409 |
+
return json.loads(json_str)
|
| 410 |
+
except (json.JSONDecodeError, IndexError):
|
| 411 |
+
continue
|
| 412 |
+
|
| 413 |
+
return None
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def validate_agent_output(output: Dict[str, Any], required_fields: List[str]) -> Tuple[bool, List[str]]:
|
| 417 |
+
"""
|
| 418 |
+
Validate that agent output contains required fields.
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
output: Agent output dict
|
| 422 |
+
required_fields: List of required field names
|
| 423 |
+
|
| 424 |
+
Returns:
|
| 425 |
+
Tuple of (is_valid, list_of_missing_fields)
|
| 426 |
+
"""
|
| 427 |
+
missing = [field for field in required_fields if field not in output]
|
| 428 |
+
return len(missing) == 0, missing
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
# =============================================================================
|
| 432 |
+
# DATA NORMALIZATION
|
| 433 |
+
# =============================================================================
|
| 434 |
+
|
| 435 |
+
def normalize_antibiotic_name(name: str) -> str:
|
| 436 |
+
"""
|
| 437 |
+
Normalize antibiotic name to standard format.
|
| 438 |
+
"""
|
| 439 |
+
# Common name mappings
|
| 440 |
+
mappings = {
|
| 441 |
+
"amox": "amoxicillin",
|
| 442 |
+
"amox/clav": "amoxicillin-clavulanate",
|
| 443 |
+
"augmentin": "amoxicillin-clavulanate",
|
| 444 |
+
"pip/tazo": "piperacillin-tazobactam",
|
| 445 |
+
"zosyn": "piperacillin-tazobactam",
|
| 446 |
+
"tmp/smx": "trimethoprim-sulfamethoxazole",
|
| 447 |
+
"bactrim": "trimethoprim-sulfamethoxazole",
|
| 448 |
+
"cipro": "ciprofloxacin",
|
| 449 |
+
"levo": "levofloxacin",
|
| 450 |
+
"moxi": "moxifloxacin",
|
| 451 |
+
"vanc": "vancomycin",
|
| 452 |
+
"vanco": "vancomycin",
|
| 453 |
+
"mero": "meropenem",
|
| 454 |
+
"imi": "imipenem",
|
| 455 |
+
"gent": "gentamicin",
|
| 456 |
+
"tobra": "tobramycin",
|
| 457 |
+
"ceftriax": "ceftriaxone",
|
| 458 |
+
"rocephin": "ceftriaxone",
|
| 459 |
+
"cefepime": "cefepime",
|
| 460 |
+
"maxipime": "cefepime",
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
normalized = name.lower().strip()
|
| 464 |
+
return mappings.get(normalized, normalized)
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def normalize_organism_name(name: str) -> str:
|
| 468 |
+
"""
|
| 469 |
+
Normalize organism name to standard format.
|
| 470 |
+
"""
|
| 471 |
+
name = name.strip()
|
| 472 |
+
|
| 473 |
+
# Common abbreviations
|
| 474 |
+
abbreviations = {
|
| 475 |
+
"e. coli": "Escherichia coli",
|
| 476 |
+
"e.coli": "Escherichia coli",
|
| 477 |
+
"k. pneumoniae": "Klebsiella pneumoniae",
|
| 478 |
+
"k.pneumoniae": "Klebsiella pneumoniae",
|
| 479 |
+
"p. aeruginosa": "Pseudomonas aeruginosa",
|
| 480 |
+
"p.aeruginosa": "Pseudomonas aeruginosa",
|
| 481 |
+
"s. aureus": "Staphylococcus aureus",
|
| 482 |
+
"s.aureus": "Staphylococcus aureus",
|
| 483 |
+
"mrsa": "Staphylococcus aureus (MRSA)",
|
| 484 |
+
"mssa": "Staphylococcus aureus (MSSA)",
|
| 485 |
+
"enterococcus": "Enterococcus species",
|
| 486 |
+
"vre": "Enterococcus (VRE)",
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
lower_name = name.lower()
|
| 490 |
+
return abbreviations.get(lower_name, name)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
__all__ = [
|
| 494 |
+
"calculate_crcl",
|
| 495 |
+
"calculate_ibw",
|
| 496 |
+
"calculate_adjusted_bw",
|
| 497 |
+
"get_renal_dose_category",
|
| 498 |
+
"calculate_mic_trend",
|
| 499 |
+
"detect_mic_creep",
|
| 500 |
+
"format_prescription_card",
|
| 501 |
+
"safe_json_parse",
|
| 502 |
+
"validate_agent_output",
|
| 503 |
+
"normalize_antibiotic_name",
|
| 504 |
+
"normalize_organism_name",
|
| 505 |
+
]
|