ghitaben commited on
Commit
91b591f
·
1 Parent(s): 793d027

Add prompt templates and utility functions for Med-I-C multi-agent system

Browse files
Files changed (7) hide show
  1. app.py +449 -219
  2. src/agents.py +572 -14
  3. src/agents/__init__.py +0 -0
  4. src/graph.py +300 -0
  5. src/prompts.py +355 -0
  6. src/rag.py +482 -0
  7. src/utils.py +505 -0
app.py CHANGED
@@ -1,10 +1,13 @@
1
  """
2
  Med-I-C: AMR-Guard Demo Application
3
  Infection Lifecycle Orchestrator - Streamlit Interface
 
 
4
  """
5
 
6
  import streamlit as st
7
  import sys
 
8
  from pathlib import Path
9
 
10
  # Add project root to path
@@ -12,19 +15,14 @@ PROJECT_ROOT = Path(__file__).parent
12
  sys.path.insert(0, str(PROJECT_ROOT))
13
 
14
  from src.tools import (
15
- query_antibiotic_info,
16
- get_antibiotics_by_category,
17
  interpret_mic_value,
18
- get_breakpoints_for_pathogen,
19
- query_resistance_pattern,
20
  get_most_effective_antibiotics,
21
  calculate_mic_trend,
22
- check_drug_interactions,
23
  screen_antibiotic_safety,
24
  search_clinical_guidelines,
25
- get_treatment_recommendation,
26
  get_empirical_therapy_guidance,
27
  )
 
28
 
29
  # Page configuration
30
  st.set_page_config(
@@ -48,6 +46,21 @@ st.markdown("""
48
  color: #666;
49
  margin-top: 0;
50
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  .risk-high {
52
  background-color: #FFCDD2;
53
  padding: 10px;
@@ -66,6 +79,13 @@ st.markdown("""
66
  border-radius: 5px;
67
  border-left: 4px solid #388E3C;
68
  }
 
 
 
 
 
 
 
69
  .info-box {
70
  background-color: #E3F2FD;
71
  padding: 15px;
@@ -79,7 +99,7 @@ st.markdown("""
79
  def main():
80
  # Header
81
  st.markdown('<p class="main-header">🦠 Med-I-C: AMR-Guard</p>', unsafe_allow_html=True)
82
- st.markdown('<p class="sub-header">Infection Lifecycle Orchestrator Demo</p>', unsafe_allow_html=True)
83
 
84
  # Sidebar navigation
85
  st.sidebar.title("Navigation")
@@ -87,85 +107,416 @@ def main():
87
  "Select Module",
88
  [
89
  "🏠 Overview",
90
- "💊 Stage 1: Empirical Advisor",
91
- "🔬 Stage 2: Lab Interpretation",
 
92
  "📊 MIC Trend Analysis",
93
  "⚠️ Drug Safety Check",
94
- "📚 Clinical Guidelines Search"
95
  ]
96
  )
97
 
98
  if page == "🏠 Overview":
99
  show_overview()
100
- elif page == "💊 Stage 1: Empirical Advisor":
 
 
101
  show_empirical_advisor()
102
- elif page == "🔬 Stage 2: Lab Interpretation":
103
  show_lab_interpretation()
104
  elif page == "📊 MIC Trend Analysis":
105
  show_mic_trend_analysis()
106
  elif page == "⚠️ Drug Safety Check":
107
  show_drug_safety()
108
- elif page == "📚 Clinical Guidelines Search":
109
  show_guidelines_search()
110
 
111
 
112
  def show_overview():
113
  st.header("System Overview")
114
 
 
 
 
 
 
 
 
 
115
  col1, col2 = st.columns(2)
116
 
117
  with col1:
118
- st.subheader("Stage 1: Empirical Phase")
119
  st.markdown("""
120
- **The "First 24 Hours"**
121
-
122
- Before lab results are available, the system:
123
- - Analyzes patient history and risk factors
124
- - Suggests empirical antibiotics based on:
125
- - Suspected pathogen
126
- - Local resistance patterns
127
- - WHO stewardship guidelines (ACCESS → WATCH → RESERVE)
128
- - Checks drug interactions with current medications
 
 
 
 
 
129
  """)
130
 
131
  with col2:
132
- st.subheader("Stage 2: Targeted Phase")
133
  st.markdown("""
134
- **The "Lab Interpretation"**
135
-
136
- Once antibiogram is available, the system:
137
- - Interprets MIC values against EUCAST breakpoints
138
- - Detects "MIC Creep" from historical data
139
- - Refines antibiotic selection
140
- - Provides evidence-based treatment recommendations
 
 
 
 
 
 
141
  """)
142
 
143
  st.divider()
144
 
 
145
  st.subheader("Knowledge Sources")
146
 
147
  col1, col2, col3, col4 = st.columns(4)
148
 
149
  with col1:
150
- st.metric("WHO EML", "264", "antibiotics classified")
151
  with col2:
152
- st.metric("ATLAS Data", "10K+", "susceptibility records")
153
  with col3:
154
- st.metric("Breakpoints", "41", "pathogen groups")
155
  with col4:
156
- st.metric("Interactions", "191K+", "drug pairs")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
 
159
  def show_empirical_advisor():
160
- st.header("💊 Stage 1: Empirical Advisor")
161
- st.markdown("*Recommend empirical therapy before lab results*")
162
 
163
  col1, col2 = st.columns([2, 1])
164
 
165
  with col1:
166
  infection_type = st.selectbox(
167
  "Infection Type",
168
- ["Urinary Tract Infection (UTI)", "Pneumonia", "Sepsis",
169
  "Skin/Soft Tissue", "Intra-abdominal", "Meningitis"]
170
  )
171
 
@@ -182,269 +533,148 @@ def show_empirical_advisor():
182
  )
183
 
184
  with col2:
185
- st.markdown("**WHO Stewardship Categories**")
186
  st.markdown("""
187
  - **ACCESS**: First-line, low resistance
188
  - **WATCH**: Higher resistance potential
189
  - **RESERVE**: Last resort antibiotics
190
  """)
191
 
192
- if st.button("Get Empirical Recommendation", type="primary"):
193
- with st.spinner("Searching guidelines and resistance data..."):
194
- # Get recommendations from guidelines
195
  guidance = get_empirical_therapy_guidance(
196
- infection_type.split("(")[0].strip(),
197
  risk_factors
198
  )
199
 
200
- st.subheader("Recommendations")
201
 
202
  if guidance.get("recommendations"):
203
  for i, rec in enumerate(guidance["recommendations"][:3], 1):
204
- with st.expander(f"Guideline Excerpt {i} (Relevance: {rec.get('relevance_score', 0):.2f})"):
205
  st.markdown(rec.get("content", ""))
206
  st.caption(f"Source: {rec.get('source', 'IDSA Guidelines')}")
207
 
208
- # If pathogen specified, show resistance patterns
209
  if suspected_pathogen:
210
- st.subheader(f"Resistance Patterns for {suspected_pathogen}")
211
-
212
  effective = get_most_effective_antibiotics(suspected_pathogen, min_susceptibility=70)
213
 
214
  if effective:
215
- st.markdown("**Most Effective Antibiotics (>70% susceptibility)**")
216
  for ab in effective[:5]:
217
  st.write(f"- **{ab.get('antibiotic')}**: {ab.get('avg_susceptibility', 0):.1f}% susceptible")
218
  else:
219
- st.info("No resistance data found for this pathogen.")
220
 
221
 
222
  def show_lab_interpretation():
223
- st.header("🔬 Stage 2: Lab Interpretation")
224
  st.markdown("*Interpret antibiogram MIC values*")
225
 
226
  col1, col2 = st.columns(2)
227
 
228
  with col1:
229
- pathogen = st.text_input(
230
- "Identified Pathogen",
231
- placeholder="e.g., Escherichia coli, Pseudomonas aeruginosa"
232
- )
233
-
234
- antibiotic = st.text_input(
235
- "Antibiotic",
236
- placeholder="e.g., Ciprofloxacin, Meropenem"
237
- )
238
-
239
- mic_value = st.number_input(
240
- "MIC Value (mg/L)",
241
- min_value=0.001,
242
- max_value=1024.0,
243
- value=1.0,
244
- step=0.5
245
- )
246
 
247
  with col2:
248
- st.markdown("**How to Read Results**")
249
  st.markdown("""
250
- - **S (Susceptible)**: MIC ≤ breakpoint - antibiotic likely effective
251
- - **I (Intermediate)**: May work with higher doses
252
- - **R (Resistant)**: MIC > breakpoint - do not use
253
  """)
254
 
255
- if st.button("Interpret MIC", type="primary"):
256
  if pathogen and antibiotic:
257
- with st.spinner("Checking breakpoints..."):
258
- result = interpret_mic_value(pathogen, antibiotic, mic_value)
259
 
260
- interpretation = result.get("interpretation", "UNKNOWN")
 
 
 
 
 
261
 
262
- if interpretation == "SUSCEPTIBLE":
263
- st.success(f"✅ **{interpretation}**")
264
- elif interpretation == "RESISTANT":
265
- st.error(f"❌ **{interpretation}**")
266
- elif interpretation == "INTERMEDIATE":
267
- st.warning(f"⚠️ **{interpretation}**")
268
- else:
269
- st.info(f"❓ **{interpretation}**")
270
-
271
- st.markdown(f"**Details:** {result.get('message', '')}")
272
-
273
- if result.get("breakpoints"):
274
- bp = result["breakpoints"]
275
- st.markdown(f"""
276
- **Breakpoints:**
277
- - S ≤ {bp.get('susceptible', 'N/A')} mg/L
278
- - R > {bp.get('resistant', 'N/A')} mg/L
279
- """)
280
-
281
- if result.get("notes"):
282
- st.info(f"**Note:** {result.get('notes')}")
283
- else:
284
- st.warning("Please enter both pathogen and antibiotic names.")
285
 
286
 
287
  def show_mic_trend_analysis():
288
  st.header("📊 MIC Trend Analysis")
289
  st.markdown("*Detect MIC creep over time*")
290
 
291
- st.markdown("""
292
- Enter historical MIC values to detect resistance velocity.
293
- **MIC Creep**: A gradual increase in MIC that may predict treatment failure
294
- even when the organism is still classified as "Susceptible".
295
- """)
296
-
297
- # Input for historical MICs
298
- num_readings = st.slider("Number of historical readings", 2, 6, 3)
299
 
300
  mic_values = []
301
  cols = st.columns(num_readings)
302
 
303
  for i, col in enumerate(cols):
304
- with col:
305
- mic = col.number_input(
306
- f"MIC {i+1}",
307
- min_value=0.001,
308
- max_value=256.0,
309
- value=float(2 ** i), # Default: 1, 2, 4, ...
310
- key=f"mic_{i}"
311
- )
312
- mic_values.append({"date": f"T{i}", "mic_value": mic})
313
 
314
- if st.button("Analyze Trend", type="primary"):
315
  result = calculate_mic_trend(mic_values)
316
-
317
  risk_level = result.get("risk_level", "UNKNOWN")
318
 
319
  if risk_level == "HIGH":
320
- st.markdown(f'<div class="risk-high"><strong>🚨 HIGH RISK</strong><br>{result.get("alert", "")}</div>',
321
- unsafe_allow_html=True)
322
  elif risk_level == "MODERATE":
323
- st.markdown(f'<div class="risk-moderate"><strong>⚠️ MODERATE RISK</strong><br>{result.get("alert", "")}</div>',
324
- unsafe_allow_html=True)
325
  else:
326
- st.markdown(f'<div class="risk-low"><strong>✅ LOW RISK</strong><br>{result.get("alert", "")}</div>',
327
- unsafe_allow_html=True)
328
-
329
- st.divider()
330
 
331
  col1, col2, col3 = st.columns(3)
332
-
333
- with col1:
334
- st.metric("Baseline MIC", f"{result.get('baseline_mic', 'N/A')} mg/L")
335
- with col2:
336
- st.metric("Current MIC", f"{result.get('current_mic', 'N/A')} mg/L")
337
- with col3:
338
- st.metric("Fold Change", f"{result.get('ratio', 'N/A')}x")
339
-
340
- st.markdown(f"**Trend:** {result.get('trend', 'N/A')}")
341
- st.markdown(f"**Resistance Velocity:** {result.get('velocity', 'N/A')}x per time point")
342
 
343
 
344
  def show_drug_safety():
345
  st.header("⚠️ Drug Safety Check")
346
- st.markdown("*Screen for drug interactions*")
347
 
348
  col1, col2 = st.columns(2)
349
 
350
  with col1:
351
- antibiotic = st.text_input(
352
- "Proposed Antibiotic",
353
- placeholder="e.g., Ciprofloxacin"
354
- )
355
-
356
- current_meds = st.text_area(
357
- "Current Medications (one per line)",
358
- placeholder="Warfarin\nMetformin\nAmlodipine",
359
- height=150
360
- )
361
 
362
  with col2:
363
- allergies = st.text_area(
364
- "Known Allergies (one per line)",
365
- placeholder="Penicillin\nSulfa",
366
- height=100
367
- )
368
 
369
  if st.button("Check Safety", type="primary"):
370
  if antibiotic:
371
  medications = [m.strip() for m in current_meds.split("\n") if m.strip()]
372
  allergy_list = [a.strip() for a in allergies.split("\n") if a.strip()]
373
 
374
- with st.spinner("Checking interactions..."):
375
- result = screen_antibiotic_safety(antibiotic, medications, allergy_list)
376
 
377
- if result.get("safe_to_use"):
378
- st.success("✅ No critical safety concerns identified")
379
- else:
380
- st.error("❌ SAFETY CONCERNS IDENTIFIED")
381
-
382
- # Show alerts
383
- if result.get("alerts"):
384
- st.subheader("Alerts")
385
- for alert in result["alerts"]:
386
- level = alert.get("level", "WARNING")
387
- if level == "CRITICAL":
388
- st.error(f"🚨 {alert.get('message', '')}")
389
- else:
390
- st.warning(f"⚠️ {alert.get('message', '')}")
391
-
392
- # Show allergy warnings
393
- if result.get("allergy_warnings"):
394
- st.subheader("Allergy Warnings")
395
- for warn in result["allergy_warnings"]:
396
- st.error(f"🚫 {warn.get('message', '')}")
397
-
398
- # Show interactions
399
- if result.get("interactions"):
400
- st.subheader("Drug Interactions Found")
401
- for interaction in result["interactions"][:5]:
402
- severity = interaction.get("severity", "unknown")
403
- icon = "🔴" if severity == "major" else "🟡" if severity == "moderate" else "🟢"
404
- st.markdown(f"""
405
- {icon} **{interaction.get('drug_1')}** ↔ **{interaction.get('drug_2')}**
406
- - Severity: {severity.upper()}
407
- - {interaction.get('interaction_description', '')}
408
- """)
409
- else:
410
- st.warning("Please enter an antibiotic name.")
411
 
412
 
413
  def show_guidelines_search():
414
- st.header("📚 Clinical Guidelines Search")
415
- st.markdown("*Search IDSA treatment guidelines*")
416
 
417
- query = st.text_input(
418
- "Search Query",
419
- placeholder="e.g., treatment for ESBL E. coli UTI"
420
- )
421
 
422
- pathogen_filter = st.selectbox(
423
- "Filter by Pathogen Type (optional)",
424
- ["All", "ESBL-E", "CRE", "CRAB", "DTR-PA", "S.maltophilia", "AmpC-E"]
425
- )
426
-
427
- if st.button("Search Guidelines", type="primary"):
428
  if query:
429
- with st.spinner("Searching clinical guidelines..."):
430
- filter_value = None if pathogen_filter == "All" else pathogen_filter
431
-
432
- results = search_clinical_guidelines(query, pathogen_filter=filter_value, n_results=5)
433
-
434
- if results:
435
- st.subheader(f"Found {len(results)} relevant excerpts")
436
-
437
- for i, result in enumerate(results, 1):
438
- with st.expander(
439
- f"Result {i} - {result.get('pathogen_type', 'General')} "
440
- f"(Relevance: {result.get('relevance_score', 0):.2f})"
441
- ):
442
- st.markdown(result.get("content", ""))
443
- st.caption(f"Source: {result.get('source', 'IDSA Guidelines')}")
444
- else:
445
- st.info("No results found. Try a different query or remove the filter.")
446
- else:
447
- st.warning("Please enter a search query.")
448
 
449
 
450
  if __name__ == "__main__":
 
1
  """
2
  Med-I-C: AMR-Guard Demo Application
3
  Infection Lifecycle Orchestrator - Streamlit Interface
4
+
5
+ Multi-Agent Architecture powered by MedGemma via LangGraph
6
  """
7
 
8
  import streamlit as st
9
  import sys
10
+ import json
11
  from pathlib import Path
12
 
13
  # Add project root to path
 
15
  sys.path.insert(0, str(PROJECT_ROOT))
16
 
17
  from src.tools import (
 
 
18
  interpret_mic_value,
 
 
19
  get_most_effective_antibiotics,
20
  calculate_mic_trend,
 
21
  screen_antibiotic_safety,
22
  search_clinical_guidelines,
 
23
  get_empirical_therapy_guidance,
24
  )
25
+ from src.utils import format_prescription_card
26
 
27
  # Page configuration
28
  st.set_page_config(
 
46
  color: #666;
47
  margin-top: 0;
48
  }
49
+ .agent-card {
50
+ background-color: #F5F5F5;
51
+ padding: 15px;
52
+ border-radius: 8px;
53
+ margin: 10px 0;
54
+ border-left: 4px solid #1E88E5;
55
+ }
56
+ .agent-active {
57
+ border-left-color: #4CAF50;
58
+ background-color: #E8F5E9;
59
+ }
60
+ .agent-complete {
61
+ border-left-color: #9E9E9E;
62
+ background-color: #FAFAFA;
63
+ }
64
  .risk-high {
65
  background-color: #FFCDD2;
66
  padding: 10px;
 
79
  border-radius: 5px;
80
  border-left: 4px solid #388E3C;
81
  }
82
+ .prescription-card {
83
+ background-color: #E3F2FD;
84
+ padding: 20px;
85
+ border-radius: 10px;
86
+ font-family: monospace;
87
+ white-space: pre-wrap;
88
+ }
89
  .info-box {
90
  background-color: #E3F2FD;
91
  padding: 15px;
 
99
  def main():
100
  # Header
101
  st.markdown('<p class="main-header">🦠 Med-I-C: AMR-Guard</p>', unsafe_allow_html=True)
102
+ st.markdown('<p class="sub-header">Infection Lifecycle Orchestrator - Multi-Agent System</p>', unsafe_allow_html=True)
103
 
104
  # Sidebar navigation
105
  st.sidebar.title("Navigation")
 
107
  "Select Module",
108
  [
109
  "🏠 Overview",
110
+ "🤖 Agent Pipeline",
111
+ "💊 Empirical Advisor",
112
+ "🔬 Lab Interpretation",
113
  "📊 MIC Trend Analysis",
114
  "⚠️ Drug Safety Check",
115
+ "📚 Clinical Guidelines"
116
  ]
117
  )
118
 
119
  if page == "🏠 Overview":
120
  show_overview()
121
+ elif page == "🤖 Agent Pipeline":
122
+ show_agent_pipeline()
123
+ elif page == "💊 Empirical Advisor":
124
  show_empirical_advisor()
125
+ elif page == "🔬 Lab Interpretation":
126
  show_lab_interpretation()
127
  elif page == "📊 MIC Trend Analysis":
128
  show_mic_trend_analysis()
129
  elif page == "⚠️ Drug Safety Check":
130
  show_drug_safety()
131
+ elif page == "📚 Clinical Guidelines":
132
  show_guidelines_search()
133
 
134
 
135
  def show_overview():
136
  st.header("System Overview")
137
 
138
+ st.markdown("""
139
+ **AMR-Guard** is a multi-agent AI system that orchestrates the complete infection treatment lifecycle,
140
+ from initial empirical therapy to targeted treatment based on lab results.
141
+ """)
142
+
143
+ # Architecture diagram
144
+ st.subheader("Multi-Agent Architecture")
145
+
146
  col1, col2 = st.columns(2)
147
 
148
  with col1:
 
149
  st.markdown("""
150
+ ### Stage 1: Empirical Phase
151
+ **Path:** Agent 1 → Agent 4
152
+
153
+ *Before lab results are available*
154
+
155
+ 1. **Intake Historian** (Agent 1)
156
+ - Parses patient demographics & history
157
+ - Calculates CrCl for renal dosing
158
+ - Identifies risk factors for MDR
159
+
160
+ 2. **Clinical Pharmacologist** (Agent 4)
161
+ - Recommends empirical antibiotics
162
+ - Applies WHO AWaRe principles
163
+ - Performs safety checks
164
  """)
165
 
166
  with col2:
 
167
  st.markdown("""
168
+ ### Stage 2: Targeted Phase
169
+ **Path:** Agent 1 → Agent 2 → Agent 3 → Agent 4
170
+
171
+ *When lab/culture results are available*
172
+
173
+ 1. **Intake Historian** (Agent 1)
174
+ 2. **Vision Specialist** (Agent 2)
175
+ - Extracts data from lab reports
176
+ - Supports any language/format
177
+ 3. **Trend Analyst** (Agent 3)
178
+ - Detects MIC creep patterns
179
+ - Calculates resistance velocity
180
+ 4. **Clinical Pharmacologist** (Agent 4)
181
  """)
182
 
183
  st.divider()
184
 
185
+ # Knowledge sources
186
  st.subheader("Knowledge Sources")
187
 
188
  col1, col2, col3, col4 = st.columns(4)
189
 
190
  with col1:
191
+ st.metric("WHO AWaRe", "264", "antibiotics classified")
192
  with col2:
193
+ st.metric("EUCAST", "v16.0", "breakpoint tables")
194
  with col3:
195
+ st.metric("IDSA", "2024", "treatment guidelines")
196
  with col4:
197
+ st.metric("DDInter", "191K+", "drug interactions")
198
+
199
+ # Model info
200
+ st.subheader("AI Models")
201
+ st.markdown("""
202
+ | Agent | Primary Model | Fallback |
203
+ |-------|---------------|----------|
204
+ | Intake Historian | MedGemma 4B IT | Vertex AI API |
205
+ | Vision Specialist | MedGemma 4B IT (multimodal) | Vertex AI API |
206
+ | Trend Analyst | MedGemma 4B IT | Vertex AI API |
207
+ | Clinical Pharmacologist | MedGemma 4B + TxGemma 2B (safety) | Vertex AI API |
208
+ """)
209
+
210
+
211
+ def show_agent_pipeline():
212
+ st.header("🤖 Multi-Agent Pipeline")
213
+ st.markdown("*Run the complete infection lifecycle workflow*")
214
+
215
+ # Initialize session state
216
+ if "pipeline_result" not in st.session_state:
217
+ st.session_state.pipeline_result = None
218
+
219
+ # Patient Information Form
220
+ with st.expander("Patient Information", expanded=True):
221
+ col1, col2, col3 = st.columns(3)
222
+
223
+ with col1:
224
+ age = st.number_input("Age (years)", min_value=0, max_value=120, value=65)
225
+ weight = st.number_input("Weight (kg)", min_value=1.0, max_value=300.0, value=70.0)
226
+ height = st.number_input("Height (cm)", min_value=50.0, max_value=250.0, value=170.0)
227
+
228
+ with col2:
229
+ sex = st.selectbox("Sex", ["male", "female"])
230
+ creatinine = st.number_input("Serum Creatinine (mg/dL)", min_value=0.1, max_value=20.0, value=1.2)
231
+
232
+ with col3:
233
+ infection_site = st.selectbox(
234
+ "Infection Site",
235
+ ["urinary", "respiratory", "bloodstream", "skin", "intra-abdominal", "CNS", "other"]
236
+ )
237
+ suspected_source = st.text_input(
238
+ "Suspected Source",
239
+ placeholder="e.g., community UTI, hospital-acquired pneumonia"
240
+ )
241
+
242
+ with st.expander("Medical History"):
243
+ col1, col2 = st.columns(2)
244
+
245
+ with col1:
246
+ medications = st.text_area(
247
+ "Current Medications (one per line)",
248
+ placeholder="Metformin\nLisinopril\nAspirin",
249
+ height=100
250
+ )
251
+ allergies = st.text_area(
252
+ "Allergies (one per line)",
253
+ placeholder="Penicillin\nSulfa",
254
+ height=100
255
+ )
256
+
257
+ with col2:
258
+ comorbidities = st.multiselect(
259
+ "Comorbidities",
260
+ ["Diabetes", "CKD", "Heart Failure", "COPD", "Immunocompromised",
261
+ "Recent Surgery", "Malignancy", "Liver Disease"]
262
+ )
263
+ risk_factors = st.multiselect(
264
+ "MDR Risk Factors",
265
+ ["Prior MRSA infection", "Recent antibiotic use (<90 days)",
266
+ "Healthcare-associated", "Recent hospitalization",
267
+ "Nursing home resident", "Prior MDR infection"]
268
+ )
269
+
270
+ # Lab Data (Optional - triggers Stage 2)
271
+ with st.expander("Lab Results (Optional - triggers targeted pathway)"):
272
+ lab_input_method = st.radio(
273
+ "Input Method",
274
+ ["None (Empirical only)", "Paste Lab Text", "Upload File"],
275
+ horizontal=True
276
+ )
277
+
278
+ labs_raw_text = None
279
+
280
+ if lab_input_method == "Paste Lab Text":
281
+ labs_raw_text = st.text_area(
282
+ "Lab Report Text",
283
+ placeholder="""Example:
284
+ Culture: Urine
285
+ Organism: Escherichia coli
286
+ Colony Count: >100,000 CFU/mL
287
+
288
+ Susceptibility:
289
+ Ampicillin: R (MIC >32)
290
+ Ciprofloxacin: S (MIC 0.25)
291
+ Nitrofurantoin: S (MIC 16)
292
+ Trimethoprim-Sulfamethoxazole: R (MIC >4)""",
293
+ height=200
294
+ )
295
+
296
+ elif lab_input_method == "Upload File":
297
+ uploaded_file = st.file_uploader(
298
+ "Upload Lab Report (PDF or Image)",
299
+ type=["pdf", "png", "jpg", "jpeg"]
300
+ )
301
+ if uploaded_file:
302
+ st.info("File uploaded. Text extraction will be performed by the Vision Specialist agent.")
303
+ # In production, would extract text here
304
+ labs_raw_text = f"[Uploaded file: {uploaded_file.name}]"
305
+
306
+ # Run Pipeline Button
307
+ st.divider()
308
+
309
+ col1, col2, col3 = st.columns([1, 2, 1])
310
+ with col2:
311
+ run_pipeline_btn = st.button(
312
+ "🚀 Run Agent Pipeline",
313
+ type="primary",
314
+ use_container_width=True
315
+ )
316
+
317
+ if run_pipeline_btn:
318
+ # Build patient data
319
+ patient_data = {
320
+ "age_years": age,
321
+ "weight_kg": weight,
322
+ "height_cm": height,
323
+ "sex": sex,
324
+ "serum_creatinine_mg_dl": creatinine,
325
+ "infection_site": infection_site,
326
+ "suspected_source": suspected_source or f"{infection_site} infection",
327
+ "medications": [m.strip() for m in medications.split("\n") if m.strip()],
328
+ "allergies": [a.strip() for a in allergies.split("\n") if a.strip()],
329
+ "comorbidities": list(comorbidities) + list(risk_factors),
330
+ }
331
+
332
+ # Show pipeline progress
333
+ st.subheader("Pipeline Execution")
334
+
335
+ # Agent progress indicators
336
+ agents = [
337
+ ("Intake Historian", "Analyzing patient data..."),
338
+ ("Vision Specialist", "Processing lab results...") if labs_raw_text else None,
339
+ ("Trend Analyst", "Analyzing MIC trends...") if labs_raw_text else None,
340
+ ("Clinical Pharmacologist", "Generating recommendations..."),
341
+ ]
342
+ agents = [a for a in agents if a is not None]
343
+
344
+ progress_bar = st.progress(0)
345
+ status_text = st.empty()
346
+
347
+ # Simulate pipeline execution (in production, would call actual pipeline)
348
+ try:
349
+ # Try to import and run the actual pipeline
350
+ from src.graph import run_pipeline
351
+
352
+ for i, (agent_name, status_msg) in enumerate(agents):
353
+ status_text.text(f"Agent {i+1}/{len(agents)}: {agent_name} - {status_msg}")
354
+ progress_bar.progress((i + 1) / len(agents))
355
+
356
+ # Run the actual pipeline
357
+ result = run_pipeline(patient_data, labs_raw_text)
358
+ st.session_state.pipeline_result = result
359
+
360
+ except Exception as e:
361
+ st.error(f"Pipeline execution error: {e}")
362
+ st.info("Running in demo mode with simulated output...")
363
+
364
+ # Demo mode - simulate results
365
+ st.session_state.pipeline_result = _generate_demo_result(patient_data, labs_raw_text)
366
+
367
+ progress_bar.progress(100)
368
+ status_text.text("Pipeline complete!")
369
+
370
+ # Display Results
371
+ if st.session_state.pipeline_result:
372
+ result = st.session_state.pipeline_result
373
+
374
+ st.divider()
375
+ st.subheader("Pipeline Results")
376
+
377
+ # Tabs for different result sections
378
+ tab1, tab2, tab3, tab4 = st.tabs([
379
+ "📋 Recommendation",
380
+ "👤 Patient Summary",
381
+ "🔬 Lab Analysis",
382
+ "⚠️ Safety Alerts"
383
+ ])
384
+
385
+ with tab1:
386
+ rec = result.get("recommendation", {})
387
+ if rec:
388
+ st.markdown("### Antibiotic Recommendation")
389
+
390
+ col1, col2 = st.columns(2)
391
+
392
+ with col1:
393
+ st.markdown(f"**Primary:** {rec.get('primary_antibiotic', 'N/A')}")
394
+ st.markdown(f"**Dose:** {rec.get('dose', 'N/A')}")
395
+ st.markdown(f"**Route:** {rec.get('route', 'N/A')}")
396
+ st.markdown(f"**Frequency:** {rec.get('frequency', 'N/A')}")
397
+ st.markdown(f"**Duration:** {rec.get('duration', 'N/A')}")
398
+
399
+ with col2:
400
+ if rec.get("backup_antibiotic"):
401
+ st.markdown(f"**Alternative:** {rec.get('backup_antibiotic')}")
402
+
403
+ st.markdown("---")
404
+ st.markdown("**Rationale:**")
405
+ st.markdown(rec.get("rationale", "No rationale provided"))
406
+
407
+ if rec.get("references"):
408
+ st.markdown("**References:**")
409
+ for ref in rec["references"]:
410
+ st.markdown(f"- {ref}")
411
+
412
+ with tab2:
413
+ st.markdown("### Patient Assessment")
414
+ intake_notes = result.get("intake_notes", "")
415
+ if intake_notes:
416
+ try:
417
+ intake_data = json.loads(intake_notes) if isinstance(intake_notes, str) else intake_notes
418
+ st.json(intake_data)
419
+ except:
420
+ st.text(intake_notes)
421
+
422
+ if result.get("creatinine_clearance_ml_min"):
423
+ st.metric("Calculated CrCl", f"{result['creatinine_clearance_ml_min']} mL/min")
424
+
425
+ with tab3:
426
+ st.markdown("### Laboratory Analysis")
427
+
428
+ vision_notes = result.get("vision_notes", "No lab data processed")
429
+ if vision_notes and vision_notes != "No lab data provided":
430
+ try:
431
+ vision_data = json.loads(vision_notes) if isinstance(vision_notes, str) else vision_notes
432
+ st.json(vision_data)
433
+ except:
434
+ st.text(vision_notes)
435
+
436
+ trend_notes = result.get("trend_notes", "")
437
+ if trend_notes and trend_notes != "No MIC data available for trend analysis":
438
+ st.markdown("#### MIC Trend Analysis")
439
+ try:
440
+ trend_data = json.loads(trend_notes) if isinstance(trend_notes, str) else trend_notes
441
+ st.json(trend_data)
442
+ except:
443
+ st.text(trend_notes)
444
+
445
+ with tab4:
446
+ st.markdown("### Safety Alerts")
447
+
448
+ warnings = result.get("safety_warnings", [])
449
+ if warnings:
450
+ for warning in warnings:
451
+ st.warning(f"⚠️ {warning}")
452
+ else:
453
+ st.success("No safety concerns identified")
454
+
455
+ errors = result.get("errors", [])
456
+ if errors:
457
+ st.markdown("#### Errors")
458
+ for error in errors:
459
+ st.error(error)
460
+
461
+
462
+ def _generate_demo_result(patient_data: dict, labs_raw_text: str | None) -> dict:
463
+ """Generate demo result when actual pipeline is not available."""
464
+ result = {
465
+ "stage": "targeted" if labs_raw_text else "empirical",
466
+ "creatinine_clearance_ml_min": 58.3,
467
+ "intake_notes": json.dumps({
468
+ "patient_summary": f"65-year-old male with {patient_data.get('suspected_source', 'infection')}",
469
+ "creatinine_clearance_ml_min": 58.3,
470
+ "renal_dose_adjustment_needed": True,
471
+ "identified_risk_factors": patient_data.get("comorbidities", []),
472
+ "infection_severity": "moderate",
473
+ "recommended_stage": "targeted" if labs_raw_text else "empirical",
474
+ }),
475
+ "recommendation": {
476
+ "primary_antibiotic": "Ciprofloxacin",
477
+ "dose": "500mg",
478
+ "route": "PO",
479
+ "frequency": "Every 12 hours",
480
+ "duration": "7 days",
481
+ "backup_antibiotic": "Nitrofurantoin",
482
+ "rationale": "Based on suspected community-acquired UTI with moderate renal impairment. Ciprofloxacin provides good coverage for common uropathogens. Dose adjusted for CrCl 58 mL/min.",
483
+ "references": ["IDSA UTI Guidelines 2024", "EUCAST Breakpoint Tables v16.0"],
484
+ "safety_alerts": [],
485
+ },
486
+ "safety_warnings": [],
487
+ "errors": [],
488
+ }
489
+
490
+ if labs_raw_text:
491
+ result["vision_notes"] = json.dumps({
492
+ "specimen_type": "urine",
493
+ "identified_organisms": [{"organism_name": "Escherichia coli", "significance": "pathogen"}],
494
+ "susceptibility_results": [
495
+ {"organism": "E. coli", "antibiotic": "Ciprofloxacin", "mic_value": 0.25, "interpretation": "S"},
496
+ {"organism": "E. coli", "antibiotic": "Nitrofurantoin", "mic_value": 16, "interpretation": "S"},
497
+ ],
498
+ "extraction_confidence": 0.95,
499
+ })
500
+ result["trend_notes"] = json.dumps([{
501
+ "organism": "E. coli",
502
+ "antibiotic": "Ciprofloxacin",
503
+ "risk_level": "LOW",
504
+ "recommendation": "Continue current therapy",
505
+ }])
506
+
507
+ return result
508
 
509
 
510
  def show_empirical_advisor():
511
+ st.header("💊 Empirical Advisor")
512
+ st.markdown("*Get empirical therapy recommendations before lab results*")
513
 
514
  col1, col2 = st.columns([2, 1])
515
 
516
  with col1:
517
  infection_type = st.selectbox(
518
  "Infection Type",
519
+ ["Urinary Tract Infection", "Pneumonia", "Sepsis",
520
  "Skin/Soft Tissue", "Intra-abdominal", "Meningitis"]
521
  )
522
 
 
533
  )
534
 
535
  with col2:
536
+ st.markdown("**WHO AWaRe Categories**")
537
  st.markdown("""
538
  - **ACCESS**: First-line, low resistance
539
  - **WATCH**: Higher resistance potential
540
  - **RESERVE**: Last resort antibiotics
541
  """)
542
 
543
+ if st.button("Get Recommendation", type="primary"):
544
+ with st.spinner("Searching guidelines..."):
 
545
  guidance = get_empirical_therapy_guidance(
546
+ infection_type,
547
  risk_factors
548
  )
549
 
550
+ st.subheader("Guideline Recommendations")
551
 
552
  if guidance.get("recommendations"):
553
  for i, rec in enumerate(guidance["recommendations"][:3], 1):
554
+ with st.expander(f"Excerpt {i} (Relevance: {rec.get('relevance_score', 0):.2f})"):
555
  st.markdown(rec.get("content", ""))
556
  st.caption(f"Source: {rec.get('source', 'IDSA Guidelines')}")
557
 
 
558
  if suspected_pathogen:
559
+ st.subheader(f"Resistance Data: {suspected_pathogen}")
 
560
  effective = get_most_effective_antibiotics(suspected_pathogen, min_susceptibility=70)
561
 
562
  if effective:
 
563
  for ab in effective[:5]:
564
  st.write(f"- **{ab.get('antibiotic')}**: {ab.get('avg_susceptibility', 0):.1f}% susceptible")
565
  else:
566
+ st.info("No resistance data found.")
567
 
568
 
569
  def show_lab_interpretation():
570
+ st.header("🔬 Lab Interpretation")
571
  st.markdown("*Interpret antibiogram MIC values*")
572
 
573
  col1, col2 = st.columns(2)
574
 
575
  with col1:
576
+ pathogen = st.text_input("Pathogen", placeholder="e.g., Escherichia coli")
577
+ antibiotic = st.text_input("Antibiotic", placeholder="e.g., Ciprofloxacin")
578
+ mic_value = st.number_input("MIC (mg/L)", min_value=0.001, max_value=1024.0, value=1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
579
 
580
  with col2:
581
+ st.markdown("**Interpretation Guide**")
582
  st.markdown("""
583
+ - **S**: Susceptible - antibiotic effective
584
+ - **I**: Intermediate - may work at higher doses
585
+ - **R**: Resistant - do not use
586
  """)
587
 
588
+ if st.button("Interpret", type="primary"):
589
  if pathogen and antibiotic:
590
+ result = interpret_mic_value(pathogen, antibiotic, mic_value)
591
+ interpretation = result.get("interpretation", "UNKNOWN")
592
 
593
+ if interpretation == "SUSCEPTIBLE":
594
+ st.success(f"✅ {interpretation}")
595
+ elif interpretation == "RESISTANT":
596
+ st.error(f"❌ {interpretation}")
597
+ else:
598
+ st.warning(f"⚠️ {interpretation}")
599
 
600
+ st.markdown(f"**Details:** {result.get('message', '')}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601
 
602
 
603
  def show_mic_trend_analysis():
604
  st.header("📊 MIC Trend Analysis")
605
  st.markdown("*Detect MIC creep over time*")
606
 
607
+ num_readings = st.slider("Historical readings", 2, 6, 3)
 
 
 
 
 
 
 
608
 
609
  mic_values = []
610
  cols = st.columns(num_readings)
611
 
612
  for i, col in enumerate(cols):
613
+ mic = col.number_input(f"MIC {i+1}", min_value=0.001, max_value=256.0, value=float(2 ** i), key=f"mic_{i}")
614
+ mic_values.append({"date": f"T{i}", "mic_value": mic})
 
 
 
 
 
 
 
615
 
616
+ if st.button("Analyze", type="primary"):
617
  result = calculate_mic_trend(mic_values)
 
618
  risk_level = result.get("risk_level", "UNKNOWN")
619
 
620
  if risk_level == "HIGH":
621
+ st.markdown(f'<div class="risk-high">🚨 HIGH RISK: {result.get("alert", "")}</div>', unsafe_allow_html=True)
 
622
  elif risk_level == "MODERATE":
623
+ st.markdown(f'<div class="risk-moderate">⚠️ MODERATE: {result.get("alert", "")}</div>', unsafe_allow_html=True)
 
624
  else:
625
+ st.markdown(f'<div class="risk-low">✅ LOW RISK: {result.get("alert", "")}</div>', unsafe_allow_html=True)
 
 
 
626
 
627
  col1, col2, col3 = st.columns(3)
628
+ col1.metric("Baseline", f"{result.get('baseline_mic', 'N/A')} mg/L")
629
+ col2.metric("Current", f"{result.get('current_mic', 'N/A')} mg/L")
630
+ col3.metric("Fold Change", f"{result.get('ratio', 'N/A')}x")
 
 
 
 
 
 
 
631
 
632
 
633
  def show_drug_safety():
634
  st.header("⚠️ Drug Safety Check")
 
635
 
636
  col1, col2 = st.columns(2)
637
 
638
  with col1:
639
+ antibiotic = st.text_input("Antibiotic", placeholder="e.g., Ciprofloxacin")
640
+ current_meds = st.text_area("Current Medications", placeholder="Warfarin\nMetformin", height=150)
 
 
 
 
 
 
 
 
641
 
642
  with col2:
643
+ allergies = st.text_area("Allergies", placeholder="Penicillin\nSulfa", height=100)
 
 
 
 
644
 
645
  if st.button("Check Safety", type="primary"):
646
  if antibiotic:
647
  medications = [m.strip() for m in current_meds.split("\n") if m.strip()]
648
  allergy_list = [a.strip() for a in allergies.split("\n") if a.strip()]
649
 
650
+ result = screen_antibiotic_safety(antibiotic, medications, allergy_list)
 
651
 
652
+ if result.get("safe_to_use"):
653
+ st.success("✅ No critical safety concerns")
654
+ else:
655
+ st.error("❌ Safety concerns identified")
656
+
657
+ for alert in result.get("alerts", []):
658
+ st.warning(f"⚠️ {alert.get('message', '')}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659
 
660
 
661
  def show_guidelines_search():
662
+ st.header("📚 Clinical Guidelines")
 
663
 
664
+ query = st.text_input("Search", placeholder="e.g., ESBL E. coli UTI treatment")
665
+ pathogen_filter = st.selectbox("Pathogen Filter", ["All", "ESBL-E", "CRE", "CRAB", "DTR-PA"])
 
 
666
 
667
+ if st.button("Search", type="primary"):
 
 
 
 
 
668
  if query:
669
+ filter_val = None if pathogen_filter == "All" else pathogen_filter
670
+ results = search_clinical_guidelines(query, pathogen_filter=filter_val, n_results=5)
671
+
672
+ if results:
673
+ for i, r in enumerate(results, 1):
674
+ with st.expander(f"Result {i} (Relevance: {r.get('relevance_score', 0):.2f})"):
675
+ st.markdown(r.get("content", ""))
676
+ else:
677
+ st.info("No results found.")
 
 
 
 
 
 
 
 
 
 
678
 
679
 
680
  if __name__ == "__main__":
src/agents.py CHANGED
@@ -1,16 +1,574 @@
1
- import os
2
- from langchain.agents import create_agent
3
- from langchain.chat_models import init_chat_model
4
- from dotenv import load_dotenv
5
-
6
- os.environ["GOOGLE_API_KEY"] = load_dotenv().get("GOOGLE_API_KEY")
7
-
8
- model = init_chat_model(
9
- "google_genai:gemini-2.5-flash-lite",
10
- # Kwargs passed to the model:
11
- temperature=0.7,
12
- timeout=30,
13
- max_tokens=1000,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- Intake_Historian = create_agent(model=model, tools=["google_search"], verbose=True)
 
 
 
 
 
 
 
 
1
+ """
2
+ Med-I-C Multi-Agent System.
3
+
4
+ Implements the 4 specialized agents for the infection lifecycle workflow:
5
+ - Agent 1: Intake Historian - Parse patient data, risk factors, calculate CrCl
6
+ - Agent 2: Vision Specialist - Extract structured data from lab reports
7
+ - Agent 3: Trend Analyst - Detect MIC creep and resistance velocity
8
+ - Agent 4: Clinical Pharmacologist - Final Rx recommendations + safety checks
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import logging
15
+ from typing import Any, Dict, Optional
16
+
17
+ from .config import get_settings
18
+ from .loader import run_inference, TextModelName
19
+ from .prompts import (
20
+ INTAKE_HISTORIAN_SYSTEM,
21
+ INTAKE_HISTORIAN_PROMPT,
22
+ VISION_SPECIALIST_SYSTEM,
23
+ VISION_SPECIALIST_PROMPT,
24
+ TREND_ANALYST_SYSTEM,
25
+ TREND_ANALYST_PROMPT,
26
+ CLINICAL_PHARMACOLOGIST_SYSTEM,
27
+ CLINICAL_PHARMACOLOGIST_PROMPT,
28
+ TXGEMMA_SAFETY_PROMPT,
29
  )
30
+ from .rag import get_context_for_agent
31
+ from .state import InfectionState
32
+ from .utils import (
33
+ calculate_crcl,
34
+ get_renal_dose_category,
35
+ safe_json_parse,
36
+ normalize_organism_name,
37
+ normalize_antibiotic_name,
38
+ )
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
+ # =============================================================================
44
+ # AGENT 1: INTAKE HISTORIAN
45
+ # =============================================================================
46
+
47
+ def run_intake_historian(state: InfectionState) -> InfectionState:
48
+ """
49
+ Agent 1: Parse patient data, calculate CrCl, identify risk factors.
50
+
51
+ Input state fields used:
52
+ - age_years, weight_kg, height_cm, sex
53
+ - serum_creatinine_mg_dl
54
+ - medications, allergies, comorbidities
55
+ - suspected_source, infection_site
56
+
57
+ Output state fields updated:
58
+ - creatinine_clearance_ml_min
59
+ - intake_notes
60
+ - stage (empirical/targeted)
61
+ - route_to_vision
62
+ """
63
+ logger.info("Running Intake Historian agent...")
64
+
65
+ # Calculate CrCl if we have the required data
66
+ crcl = None
67
+ if all([
68
+ state.get("age_years"),
69
+ state.get("weight_kg"),
70
+ state.get("serum_creatinine_mg_dl"),
71
+ state.get("sex"),
72
+ ]):
73
+ try:
74
+ crcl = calculate_crcl(
75
+ age_years=state["age_years"],
76
+ weight_kg=state["weight_kg"],
77
+ serum_creatinine_mg_dl=state["serum_creatinine_mg_dl"],
78
+ sex=state["sex"],
79
+ use_ibw=True,
80
+ height_cm=state.get("height_cm"),
81
+ )
82
+ state["creatinine_clearance_ml_min"] = crcl
83
+ logger.info(f"Calculated CrCl: {crcl} mL/min")
84
+ except Exception as e:
85
+ logger.warning(f"Could not calculate CrCl: {e}")
86
+ state.setdefault("errors", []).append(f"CrCl calculation error: {e}")
87
+
88
+ # Build patient data string for prompt
89
+ patient_data = _format_patient_data(state)
90
+
91
+ # Get RAG context
92
+ query = f"treatment {state.get('suspected_source', '')} {state.get('infection_site', '')}"
93
+ rag_context = get_context_for_agent(
94
+ agent_name="intake_historian",
95
+ query=query,
96
+ patient_context={
97
+ "pathogen_type": state.get("suspected_source"),
98
+ },
99
+ )
100
+
101
+ # Format the prompt
102
+ prompt = f"{INTAKE_HISTORIAN_SYSTEM}\n\n{INTAKE_HISTORIAN_PROMPT.format(
103
+ patient_data=patient_data,
104
+ medications=', '.join(state.get('medications', [])) or 'None reported',
105
+ allergies=', '.join(state.get('allergies', [])) or 'No known allergies',
106
+ infection_site=state.get('infection_site', 'Unknown'),
107
+ suspected_source=state.get('suspected_source', 'Unknown'),
108
+ rag_context=rag_context,
109
+ )}"
110
+
111
+ # Run inference
112
+ try:
113
+ response = run_inference(
114
+ prompt=prompt,
115
+ model_name="medgemma_4b",
116
+ max_new_tokens=1024,
117
+ temperature=0.2,
118
+ )
119
+
120
+ # Parse response
121
+ parsed = safe_json_parse(response)
122
+ if parsed:
123
+ state["intake_notes"] = json.dumps(parsed, indent=2)
124
+
125
+ # Update state from parsed response
126
+ if parsed.get("creatinine_clearance_ml_min") and crcl is None:
127
+ state["creatinine_clearance_ml_min"] = parsed["creatinine_clearance_ml_min"]
128
+
129
+ # Determine stage
130
+ recommended_stage = parsed.get("recommended_stage", "empirical")
131
+ state["stage"] = recommended_stage
132
+
133
+ # Route to vision if we have lab data to process
134
+ state["route_to_vision"] = bool(state.get("labs_raw_text"))
135
+ else:
136
+ state["intake_notes"] = response
137
+ state["stage"] = "empirical"
138
+ state["route_to_vision"] = bool(state.get("labs_raw_text"))
139
+
140
+ except Exception as e:
141
+ logger.error(f"Intake Historian error: {e}")
142
+ state.setdefault("errors", []).append(f"Intake Historian error: {e}")
143
+ state["intake_notes"] = f"Error: {e}"
144
+ state["stage"] = "empirical"
145
+
146
+ logger.info(f"Intake Historian complete. Stage: {state.get('stage')}")
147
+ return state
148
+
149
+
150
+ # =============================================================================
151
+ # AGENT 2: VISION SPECIALIST
152
+ # =============================================================================
153
+
154
+ def run_vision_specialist(state: InfectionState) -> InfectionState:
155
+ """
156
+ Agent 2: Extract structured data from lab reports (text, images, PDFs).
157
+
158
+ Input state fields used:
159
+ - labs_raw_text (extracted text from lab report)
160
+
161
+ Output state fields updated:
162
+ - labs_parsed
163
+ - mic_data
164
+ - vision_notes
165
+ - route_to_trend_analyst
166
+ """
167
+ logger.info("Running Vision Specialist agent...")
168
+
169
+ labs_raw = state.get("labs_raw_text", "")
170
+ if not labs_raw:
171
+ logger.info("No lab data to process, skipping Vision Specialist")
172
+ state["vision_notes"] = "No lab data provided"
173
+ state["route_to_trend_analyst"] = False
174
+ return state
175
+
176
+ # Detect language (simplified - in production would use langdetect)
177
+ language = "English (assumed)"
178
+
179
+ # Get RAG context for lab interpretation
180
+ rag_context = get_context_for_agent(
181
+ agent_name="vision_specialist",
182
+ query="culture sensitivity susceptibility interpretation",
183
+ patient_context={},
184
+ )
185
+
186
+ # Format the prompt
187
+ prompt = f"{VISION_SPECIALIST_SYSTEM}\n\n{VISION_SPECIALIST_PROMPT.format(
188
+ report_content=labs_raw,
189
+ source_format='text',
190
+ language=language,
191
+ )}"
192
+
193
+ # Run inference
194
+ try:
195
+ response = run_inference(
196
+ prompt=prompt,
197
+ model_name="medgemma_4b",
198
+ max_new_tokens=2048,
199
+ temperature=0.1,
200
+ )
201
+
202
+ # Parse response
203
+ parsed = safe_json_parse(response)
204
+ if parsed:
205
+ state["vision_notes"] = json.dumps(parsed, indent=2)
206
+
207
+ # Extract organisms and susceptibility data
208
+ organisms = parsed.get("identified_organisms", [])
209
+ susceptibility = parsed.get("susceptibility_results", [])
210
+
211
+ # Convert to MICDatum format
212
+ mic_data = []
213
+ for result in susceptibility:
214
+ mic_datum = {
215
+ "organism": normalize_organism_name(result.get("organism", "")),
216
+ "antibiotic": normalize_antibiotic_name(result.get("antibiotic", "")),
217
+ "mic_value": str(result.get("mic_value", "")),
218
+ "mic_unit": result.get("mic_unit", "mg/L"),
219
+ "interpretation": result.get("interpretation"),
220
+ }
221
+ mic_data.append(mic_datum)
222
+
223
+ state["mic_data"] = mic_data
224
+ state["labs_parsed"] = [{
225
+ "name": org.get("organism_name", "Unknown"),
226
+ "value": org.get("colony_count", ""),
227
+ "flag": "pathogen" if org.get("significance") == "pathogen" else None,
228
+ } for org in organisms]
229
+
230
+ # Route to trend analyst if we have MIC data
231
+ state["route_to_trend_analyst"] = len(mic_data) > 0
232
+
233
+ # Check for critical findings
234
+ critical = parsed.get("critical_findings", [])
235
+ if critical:
236
+ state.setdefault("safety_warnings", []).extend(critical)
237
+
238
+ else:
239
+ state["vision_notes"] = response
240
+ state["route_to_trend_analyst"] = False
241
+
242
+ except Exception as e:
243
+ logger.error(f"Vision Specialist error: {e}")
244
+ state.setdefault("errors", []).append(f"Vision Specialist error: {e}")
245
+ state["vision_notes"] = f"Error: {e}"
246
+ state["route_to_trend_analyst"] = False
247
+
248
+ logger.info(f"Vision Specialist complete. MIC data points: {len(state.get('mic_data', []))}")
249
+ return state
250
+
251
+
252
+ # =============================================================================
253
+ # AGENT 3: TREND ANALYST
254
+ # =============================================================================
255
+
256
+ def run_trend_analyst(state: InfectionState) -> InfectionState:
257
+ """
258
+ Agent 3: Analyze MIC trends and detect resistance velocity.
259
+
260
+ Input state fields used:
261
+ - mic_data (current MIC readings)
262
+ - Historical MIC data (if available)
263
+
264
+ Output state fields updated:
265
+ - mic_trend_summary
266
+ - trend_notes
267
+ - safety_warnings (if high risk detected)
268
+ """
269
+ logger.info("Running Trend Analyst agent...")
270
+
271
+ mic_data = state.get("mic_data", [])
272
+ if not mic_data:
273
+ logger.info("No MIC data to analyze, skipping Trend Analyst")
274
+ state["trend_notes"] = "No MIC data available for trend analysis"
275
+ return state
276
+
277
+ # For each organism-antibiotic pair, analyze trends
278
+ trend_results = []
279
+
280
+ for mic in mic_data:
281
+ organism = mic.get("organism", "Unknown")
282
+ antibiotic = mic.get("antibiotic", "Unknown")
283
+
284
+ # Get RAG context for breakpoints
285
+ rag_context = get_context_for_agent(
286
+ agent_name="trend_analyst",
287
+ query=f"breakpoint {organism} {antibiotic}",
288
+ patient_context={
289
+ "organism": organism,
290
+ "antibiotic": antibiotic,
291
+ "region": state.get("country_or_region"),
292
+ },
293
+ )
294
+
295
+ # Format MIC history (in production, would pull from database)
296
+ mic_history = [{"date": "current", "mic_value": mic.get("mic_value", "0")}]
297
+
298
+ # Format prompt
299
+ prompt = f"{TREND_ANALYST_SYSTEM}\n\n{TREND_ANALYST_PROMPT.format(
300
+ organism=organism,
301
+ antibiotic=antibiotic,
302
+ mic_history=json.dumps(mic_history, indent=2),
303
+ breakpoint_data=rag_context,
304
+ resistance_context='Regional data not available',
305
+ )}"
306
+
307
+ try:
308
+ response = run_inference(
309
+ prompt=prompt,
310
+ model_name="medgemma_4b",
311
+ max_new_tokens=1024,
312
+ temperature=0.2,
313
+ )
314
+
315
+ parsed = safe_json_parse(response)
316
+ if parsed:
317
+ trend_results.append(parsed)
318
+
319
+ # Add safety warning if high/critical risk
320
+ risk_level = parsed.get("risk_level", "LOW")
321
+ if risk_level in ["HIGH", "CRITICAL"]:
322
+ warning = f"MIC trend alert for {organism}/{antibiotic}: {parsed.get('recommendation', 'Review needed')}"
323
+ state.setdefault("safety_warnings", []).append(warning)
324
+ else:
325
+ trend_results.append({"raw_response": response})
326
+
327
+ except Exception as e:
328
+ logger.error(f"Trend analysis error for {organism}/{antibiotic}: {e}")
329
+ trend_results.append({"error": str(e)})
330
+
331
+ # Summarize trends
332
+ state["trend_notes"] = json.dumps(trend_results, indent=2)
333
+
334
+ # Create summary
335
+ high_risk_count = sum(1 for t in trend_results if t.get("risk_level") in ["HIGH", "CRITICAL"])
336
+ state["mic_trend_summary"] = f"Analyzed {len(trend_results)} organism-antibiotic pairs. High-risk findings: {high_risk_count}"
337
+
338
+ logger.info(f"Trend Analyst complete. {state['mic_trend_summary']}")
339
+ return state
340
+
341
+
342
+ # =============================================================================
343
+ # AGENT 4: CLINICAL PHARMACOLOGIST
344
+ # =============================================================================
345
+
346
+ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
347
+ """
348
+ Agent 4: Generate final antibiotic recommendation with safety checks.
349
+
350
+ Input state fields used:
351
+ - intake_notes, vision_notes, trend_notes
352
+ - age_years, weight_kg, creatinine_clearance_ml_min
353
+ - allergies, medications
354
+ - infection_site, suspected_source
355
+
356
+ Output state fields updated:
357
+ - recommendation
358
+ - pharmacology_notes
359
+ - safety_warnings (additional alerts)
360
+ """
361
+ logger.info("Running Clinical Pharmacologist agent...")
362
+
363
+ # Gather all previous agent outputs
364
+ intake_summary = state.get("intake_notes", "No intake data")
365
+ lab_results = state.get("vision_notes", "No lab data")
366
+ trend_analysis = state.get("trend_notes", "No trend data")
367
+
368
+ # Get RAG context
369
+ query = f"treatment {state.get('suspected_source', '')} antibiotic recommendation"
370
+ rag_context = get_context_for_agent(
371
+ agent_name="clinical_pharmacologist",
372
+ query=query,
373
+ patient_context={
374
+ "proposed_antibiotic": None, # Will be determined by agent
375
+ },
376
+ )
377
+
378
+ # Format prompt
379
+ prompt = f"{CLINICAL_PHARMACOLOGIST_SYSTEM}\n\n{CLINICAL_PHARMACOLOGIST_PROMPT.format(
380
+ intake_summary=intake_summary,
381
+ lab_results=lab_results,
382
+ trend_analysis=trend_analysis,
383
+ age=state.get('age_years', 'Unknown'),
384
+ weight=state.get('weight_kg', 'Unknown'),
385
+ crcl=state.get('creatinine_clearance_ml_min', 'Unknown'),
386
+ allergies=', '.join(state.get('allergies', [])) or 'No known allergies',
387
+ current_medications=', '.join(state.get('medications', [])) or 'None reported',
388
+ infection_site=state.get('infection_site', 'Unknown'),
389
+ suspected_source=state.get('suspected_source', 'Unknown'),
390
+ severity=state.get('intake_notes', {}).get('infection_severity', 'Unknown') if isinstance(state.get('intake_notes'), dict) else 'Unknown',
391
+ rag_context=rag_context,
392
+ )}"
393
+
394
+ try:
395
+ response = run_inference(
396
+ prompt=prompt,
397
+ model_name="medgemma_4b",
398
+ max_new_tokens=2048,
399
+ temperature=0.2,
400
+ )
401
+
402
+ parsed = safe_json_parse(response)
403
+ if parsed:
404
+ state["pharmacology_notes"] = json.dumps(parsed, indent=2)
405
+
406
+ # Build recommendation
407
+ primary = parsed.get("primary_recommendation", {})
408
+ recommendation = {
409
+ "primary_antibiotic": primary.get("antibiotic"),
410
+ "dose": primary.get("dose"),
411
+ "route": primary.get("route"),
412
+ "frequency": primary.get("frequency"),
413
+ "duration": primary.get("duration"),
414
+ "rationale": parsed.get("rationale"),
415
+ "references": parsed.get("guideline_references", []),
416
+ "safety_alerts": [a.get("message") for a in parsed.get("safety_alerts", [])],
417
+ }
418
+
419
+ # Add alternative if provided
420
+ alt = parsed.get("alternative_recommendation", {})
421
+ if alt.get("antibiotic"):
422
+ recommendation["backup_antibiotic"] = alt.get("antibiotic")
423
+
424
+ state["recommendation"] = recommendation
425
+
426
+ # Add safety alerts to state
427
+ for alert in parsed.get("safety_alerts", []):
428
+ if alert.get("level") in ["WARNING", "CRITICAL"]:
429
+ state.setdefault("safety_warnings", []).append(alert.get("message"))
430
+
431
+ # Run TxGemma safety check (optional)
432
+ if primary.get("antibiotic"):
433
+ safety_result = _run_txgemma_safety_check(
434
+ antibiotic=primary.get("antibiotic"),
435
+ dose=primary.get("dose"),
436
+ route=primary.get("route"),
437
+ duration=primary.get("duration"),
438
+ age=state.get("age_years"),
439
+ crcl=state.get("creatinine_clearance_ml_min"),
440
+ medications=state.get("medications", []),
441
+ )
442
+ if safety_result:
443
+ state.setdefault("debug_log", []).append(f"TxGemma safety: {safety_result}")
444
+
445
+ else:
446
+ state["pharmacology_notes"] = response
447
+ state["recommendation"] = {"rationale": response}
448
+
449
+ except Exception as e:
450
+ logger.error(f"Clinical Pharmacologist error: {e}")
451
+ state.setdefault("errors", []).append(f"Clinical Pharmacologist error: {e}")
452
+ state["pharmacology_notes"] = f"Error: {e}"
453
+
454
+ logger.info("Clinical Pharmacologist complete.")
455
+ return state
456
+
457
+
458
+ # =============================================================================
459
+ # HELPER FUNCTIONS
460
+ # =============================================================================
461
+
462
+ def _format_patient_data(state: InfectionState) -> str:
463
+ """Format patient data for prompt injection."""
464
+ lines = []
465
+
466
+ if state.get("patient_id"):
467
+ lines.append(f"Patient ID: {state['patient_id']}")
468
+
469
+ demographics = []
470
+ if state.get("age_years"):
471
+ demographics.append(f"{state['age_years']} years old")
472
+ if state.get("sex"):
473
+ demographics.append(state["sex"])
474
+ if demographics:
475
+ lines.append(f"Demographics: {', '.join(demographics)}")
476
+
477
+ if state.get("weight_kg"):
478
+ lines.append(f"Weight: {state['weight_kg']} kg")
479
+ if state.get("height_cm"):
480
+ lines.append(f"Height: {state['height_cm']} cm")
481
+
482
+ if state.get("serum_creatinine_mg_dl"):
483
+ lines.append(f"Serum Creatinine: {state['serum_creatinine_mg_dl']} mg/dL")
484
+ if state.get("creatinine_clearance_ml_min"):
485
+ crcl = state["creatinine_clearance_ml_min"]
486
+ category = get_renal_dose_category(crcl)
487
+ lines.append(f"CrCl: {crcl} mL/min ({category})")
488
+
489
+ if state.get("comorbidities"):
490
+ lines.append(f"Comorbidities: {', '.join(state['comorbidities'])}")
491
+
492
+ if state.get("vitals"):
493
+ vitals_str = ", ".join(f"{k}: {v}" for k, v in state["vitals"].items())
494
+ lines.append(f"Vitals: {vitals_str}")
495
+
496
+ return "\n".join(lines) if lines else "No patient data available"
497
+
498
+
499
+ def _run_txgemma_safety_check(
500
+ antibiotic: str,
501
+ dose: Optional[str],
502
+ route: Optional[str],
503
+ duration: Optional[str],
504
+ age: Optional[float],
505
+ crcl: Optional[float],
506
+ medications: list,
507
+ ) -> Optional[str]:
508
+ """
509
+ Run TxGemma safety check (supplementary).
510
+
511
+ TxGemma is used only for safety validation, not primary recommendations.
512
+ """
513
+ try:
514
+ prompt = TXGEMMA_SAFETY_PROMPT.format(
515
+ antibiotic=antibiotic,
516
+ dose=dose or "Not specified",
517
+ route=route or "Not specified",
518
+ duration=duration or "Not specified",
519
+ age=age or "Unknown",
520
+ crcl=crcl or "Unknown",
521
+ medications=", ".join(medications) if medications else "None",
522
+ )
523
+
524
+ response = run_inference(
525
+ prompt=prompt,
526
+ model_name="txgemma_2b", # Use smaller TxGemma for safety check
527
+ max_new_tokens=256,
528
+ temperature=0.1,
529
+ )
530
+
531
+ return response
532
+
533
+ except Exception as e:
534
+ logger.warning(f"TxGemma safety check failed: {e}")
535
+ return None
536
+
537
+
538
+ # =============================================================================
539
+ # AGENT REGISTRY
540
+ # =============================================================================
541
+
542
+ AGENTS = {
543
+ "intake_historian": run_intake_historian,
544
+ "vision_specialist": run_vision_specialist,
545
+ "trend_analyst": run_trend_analyst,
546
+ "clinical_pharmacologist": run_clinical_pharmacologist,
547
+ }
548
+
549
+
550
+ def run_agent(agent_name: str, state: InfectionState) -> InfectionState:
551
+ """
552
+ Run a specific agent by name.
553
+
554
+ Args:
555
+ agent_name: Name of the agent to run
556
+ state: Current infection state
557
+
558
+ Returns:
559
+ Updated infection state
560
+ """
561
+ if agent_name not in AGENTS:
562
+ raise ValueError(f"Unknown agent: {agent_name}")
563
+
564
+ return AGENTS[agent_name](state)
565
+
566
 
567
+ __all__ = [
568
+ "run_intake_historian",
569
+ "run_vision_specialist",
570
+ "run_trend_analyst",
571
+ "run_clinical_pharmacologist",
572
+ "run_agent",
573
+ "AGENTS",
574
+ ]
src/agents/__init__.py DELETED
File without changes
src/graph.py CHANGED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LangGraph Orchestrator for Med-I-C Multi-Agent System.
3
+
4
+ Implements the infection lifecycle workflow with conditional routing:
5
+
6
+ Stage 1 (Empirical - no lab results):
7
+ Intake Historian -> Clinical Pharmacologist
8
+
9
+ Stage 2 (Targeted - lab results available):
10
+ Intake Historian -> Vision Specialist -> Trend Analyst -> Clinical Pharmacologist
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import logging
16
+ from typing import Literal
17
+
18
+ from langgraph.graph import StateGraph, END
19
+
20
+ from .agents import (
21
+ run_intake_historian,
22
+ run_vision_specialist,
23
+ run_trend_analyst,
24
+ run_clinical_pharmacologist,
25
+ )
26
+ from .state import InfectionState
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ # =============================================================================
32
+ # NODE FUNCTIONS (Wrapper for agents)
33
+ # =============================================================================
34
+
35
+ def intake_historian_node(state: InfectionState) -> InfectionState:
36
+ """Node 1: Run Intake Historian agent."""
37
+ logger.info("Graph: Executing Intake Historian node")
38
+ return run_intake_historian(state)
39
+
40
+
41
+ def vision_specialist_node(state: InfectionState) -> InfectionState:
42
+ """Node 2: Run Vision Specialist agent."""
43
+ logger.info("Graph: Executing Vision Specialist node")
44
+ return run_vision_specialist(state)
45
+
46
+
47
+ def trend_analyst_node(state: InfectionState) -> InfectionState:
48
+ """Node 3: Run Trend Analyst agent."""
49
+ logger.info("Graph: Executing Trend Analyst node")
50
+ return run_trend_analyst(state)
51
+
52
+
53
+ def clinical_pharmacologist_node(state: InfectionState) -> InfectionState:
54
+ """Node 4: Run Clinical Pharmacologist agent."""
55
+ logger.info("Graph: Executing Clinical Pharmacologist node")
56
+ return run_clinical_pharmacologist(state)
57
+
58
+
59
+ # =============================================================================
60
+ # CONDITIONAL ROUTING FUNCTIONS
61
+ # =============================================================================
62
+
63
+ def route_after_intake(state: InfectionState) -> Literal["vision_specialist", "clinical_pharmacologist"]:
64
+ """
65
+ Determine routing after Intake Historian.
66
+
67
+ Routes to Vision Specialist if:
68
+ - stage is "targeted" AND
69
+ - route_to_vision is True (i.e., we have lab data to process)
70
+
71
+ Otherwise routes directly to Clinical Pharmacologist (empirical path).
72
+ """
73
+ stage = state.get("stage", "empirical")
74
+ has_lab_data = state.get("route_to_vision", False)
75
+
76
+ if stage == "targeted" and has_lab_data:
77
+ logger.info("Graph: Routing to Vision Specialist (targeted path)")
78
+ return "vision_specialist"
79
+ else:
80
+ logger.info("Graph: Routing to Clinical Pharmacologist (empirical path)")
81
+ return "clinical_pharmacologist"
82
+
83
+
84
+ def route_after_vision(state: InfectionState) -> Literal["trend_analyst", "clinical_pharmacologist"]:
85
+ """
86
+ Determine routing after Vision Specialist.
87
+
88
+ Routes to Trend Analyst if:
89
+ - route_to_trend_analyst is True (i.e., we have MIC data to analyze)
90
+
91
+ Otherwise skips to Clinical Pharmacologist.
92
+ """
93
+ should_analyze_trends = state.get("route_to_trend_analyst", False)
94
+
95
+ if should_analyze_trends:
96
+ logger.info("Graph: Routing to Trend Analyst")
97
+ return "trend_analyst"
98
+ else:
99
+ logger.info("Graph: Skipping Trend Analyst, routing to Clinical Pharmacologist")
100
+ return "clinical_pharmacologist"
101
+
102
+
103
+ # =============================================================================
104
+ # GRAPH CONSTRUCTION
105
+ # =============================================================================
106
+
107
+ def build_infection_graph() -> StateGraph:
108
+ """
109
+ Build the LangGraph StateGraph for the infection lifecycle workflow.
110
+
111
+ Returns:
112
+ Compiled StateGraph ready for execution
113
+ """
114
+ # Create the graph with InfectionState as the state schema
115
+ graph = StateGraph(InfectionState)
116
+
117
+ # Add nodes
118
+ graph.add_node("intake_historian", intake_historian_node)
119
+ graph.add_node("vision_specialist", vision_specialist_node)
120
+ graph.add_node("trend_analyst", trend_analyst_node)
121
+ graph.add_node("clinical_pharmacologist", clinical_pharmacologist_node)
122
+
123
+ # Set entry point
124
+ graph.set_entry_point("intake_historian")
125
+
126
+ # Add conditional edges from intake_historian
127
+ graph.add_conditional_edges(
128
+ "intake_historian",
129
+ route_after_intake,
130
+ {
131
+ "vision_specialist": "vision_specialist",
132
+ "clinical_pharmacologist": "clinical_pharmacologist",
133
+ }
134
+ )
135
+
136
+ # Add conditional edges from vision_specialist
137
+ graph.add_conditional_edges(
138
+ "vision_specialist",
139
+ route_after_vision,
140
+ {
141
+ "trend_analyst": "trend_analyst",
142
+ "clinical_pharmacologist": "clinical_pharmacologist",
143
+ }
144
+ )
145
+
146
+ # Add edge from trend_analyst to clinical_pharmacologist
147
+ graph.add_edge("trend_analyst", "clinical_pharmacologist")
148
+
149
+ # Add edge from clinical_pharmacologist to END
150
+ graph.add_edge("clinical_pharmacologist", END)
151
+
152
+ return graph
153
+
154
+
155
+ def compile_graph():
156
+ """
157
+ Build and compile the graph for execution.
158
+
159
+ Returns:
160
+ Compiled graph that can be invoked with .invoke(state)
161
+ """
162
+ graph = build_infection_graph()
163
+ return graph.compile()
164
+
165
+
166
+ # =============================================================================
167
+ # EXECUTION HELPERS
168
+ # =============================================================================
169
+
170
+ def run_pipeline(
171
+ patient_data: dict,
172
+ labs_raw_text: str | None = None,
173
+ ) -> InfectionState:
174
+ """
175
+ Run the full infection lifecycle pipeline.
176
+
177
+ This is the main entry point for executing the multi-agent workflow.
178
+
179
+ Args:
180
+ patient_data: Dict containing patient information:
181
+ - age_years: Patient age
182
+ - weight_kg: Patient weight
183
+ - sex: "male" or "female"
184
+ - serum_creatinine_mg_dl: Serum creatinine (optional)
185
+ - medications: List of current medications
186
+ - allergies: List of allergies
187
+ - comorbidities: List of comorbidities
188
+ - infection_site: Site of infection
189
+ - suspected_source: Suspected pathogen/source
190
+
191
+ labs_raw_text: Raw text from lab report (if available).
192
+ If provided, triggers targeted (Stage 2) pathway.
193
+
194
+ Returns:
195
+ Final InfectionState with recommendation
196
+
197
+ Example:
198
+ >>> state = run_pipeline(
199
+ ... patient_data={
200
+ ... "age_years": 65,
201
+ ... "weight_kg": 70,
202
+ ... "sex": "male",
203
+ ... "serum_creatinine_mg_dl": 1.2,
204
+ ... "medications": ["metformin", "lisinopril"],
205
+ ... "allergies": ["penicillin"],
206
+ ... "infection_site": "urinary",
207
+ ... "suspected_source": "community UTI",
208
+ ... },
209
+ ... labs_raw_text="E. coli isolated. Ciprofloxacin MIC: 0.5 mg/L (S)"
210
+ ... )
211
+ >>> print(state["recommendation"]["primary_antibiotic"])
212
+ """
213
+ # Build initial state from patient data
214
+ initial_state: InfectionState = {
215
+ "age_years": patient_data.get("age_years"),
216
+ "weight_kg": patient_data.get("weight_kg"),
217
+ "height_cm": patient_data.get("height_cm"),
218
+ "sex": patient_data.get("sex"),
219
+ "serum_creatinine_mg_dl": patient_data.get("serum_creatinine_mg_dl"),
220
+ "medications": patient_data.get("medications", []),
221
+ "allergies": patient_data.get("allergies", []),
222
+ "comorbidities": patient_data.get("comorbidities", []),
223
+ "infection_site": patient_data.get("infection_site"),
224
+ "suspected_source": patient_data.get("suspected_source"),
225
+ "country_or_region": patient_data.get("country_or_region"),
226
+ "vitals": patient_data.get("vitals", {}),
227
+ }
228
+
229
+ # Add lab data if provided
230
+ if labs_raw_text:
231
+ initial_state["labs_raw_text"] = labs_raw_text
232
+ initial_state["stage"] = "targeted"
233
+ else:
234
+ initial_state["stage"] = "empirical"
235
+
236
+ # Compile and run the graph
237
+ logger.info(f"Starting pipeline execution (stage: {initial_state['stage']})")
238
+
239
+ compiled_graph = compile_graph()
240
+ final_state = compiled_graph.invoke(initial_state)
241
+
242
+ logger.info("Pipeline execution complete")
243
+
244
+ return final_state
245
+
246
+
247
+ def run_empirical_pipeline(patient_data: dict) -> InfectionState:
248
+ """
249
+ Run Stage 1 (Empirical) pipeline only.
250
+
251
+ Shorthand for run_pipeline without lab data.
252
+ """
253
+ return run_pipeline(patient_data, labs_raw_text=None)
254
+
255
+
256
+ def run_targeted_pipeline(patient_data: dict, labs_raw_text: str) -> InfectionState:
257
+ """
258
+ Run Stage 2 (Targeted) pipeline with lab data.
259
+
260
+ Shorthand for run_pipeline with lab data.
261
+ """
262
+ return run_pipeline(patient_data, labs_raw_text=labs_raw_text)
263
+
264
+
265
+ # =============================================================================
266
+ # VISUALIZATION (for debugging)
267
+ # =============================================================================
268
+
269
+ def get_graph_mermaid() -> str:
270
+ """
271
+ Get Mermaid diagram representation of the graph.
272
+
273
+ Useful for documentation and debugging.
274
+ """
275
+ graph = build_infection_graph()
276
+ try:
277
+ return graph.compile().get_graph().draw_mermaid()
278
+ except Exception:
279
+ # Fallback: return manual diagram
280
+ return """
281
+ graph TD
282
+ A[intake_historian] --> B{route_after_intake}
283
+ B -->|targeted + lab data| C[vision_specialist]
284
+ B -->|empirical| E[clinical_pharmacologist]
285
+ C --> D{route_after_vision}
286
+ D -->|has MIC data| F[trend_analyst]
287
+ D -->|no MIC data| E
288
+ F --> E
289
+ E --> G[END]
290
+ """
291
+
292
+
293
+ __all__ = [
294
+ "build_infection_graph",
295
+ "compile_graph",
296
+ "run_pipeline",
297
+ "run_empirical_pipeline",
298
+ "run_targeted_pipeline",
299
+ "get_graph_mermaid",
300
+ ]
src/prompts.py CHANGED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompt templates for Med-I-C multi-agent system.
3
+
4
+ Each agent has a specific role in the infection lifecycle workflow:
5
+ - Agent 1: Intake Historian - Parse patient data, risk factors, calculate CrCl
6
+ - Agent 2: Vision Specialist - Extract structured data from lab reports (images/PDFs)
7
+ - Agent 3: Trend Analyst - Detect MIC creep and resistance velocity
8
+ - Agent 4: Clinical Pharmacologist - Final Rx recommendations + safety checks
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ # =============================================================================
14
+ # AGENT 1: INTAKE HISTORIAN
15
+ # =============================================================================
16
+
17
+ INTAKE_HISTORIAN_SYSTEM = """You are an expert clinical intake specialist. Your role is to:
18
+
19
+ 1. Parse and structure patient demographics and clinical history
20
+ 2. Calculate Creatinine Clearance (CrCl) using the Cockcroft-Gault equation when data is available
21
+ 3. Identify key risk factors for antimicrobial-resistant infections
22
+ 4. Determine the appropriate treatment stage (empirical vs targeted)
23
+
24
+ RISK FACTORS TO IDENTIFY:
25
+ - Prior MRSA or MDR infection history
26
+ - Recent antibiotic use (within 90 days)
27
+ - Healthcare-associated vs community-acquired infection
28
+ - Immunocompromised status
29
+ - Recent hospitalization or ICU stay
30
+ - Presence of medical devices (catheters, lines)
31
+ - Travel history to high-resistance regions
32
+ - Renal or hepatic impairment
33
+
34
+ OUTPUT FORMAT:
35
+ Provide a structured JSON response with the following fields:
36
+ {
37
+ "patient_summary": "Brief clinical summary",
38
+ "creatinine_clearance_ml_min": <number or null>,
39
+ "renal_dose_adjustment_needed": <boolean>,
40
+ "identified_risk_factors": ["list", "of", "factors"],
41
+ "suspected_pathogens": ["list", "of", "likely", "organisms"],
42
+ "infection_severity": "mild|moderate|severe|critical",
43
+ "recommended_stage": "empirical|targeted",
44
+ "notes": "Any additional clinical observations"
45
+ }
46
+ """
47
+
48
+ INTAKE_HISTORIAN_PROMPT = """Analyze the following patient information and provide a structured clinical assessment.
49
+
50
+ PATIENT DATA:
51
+ {patient_data}
52
+
53
+ CURRENT MEDICATIONS:
54
+ {medications}
55
+
56
+ KNOWN ALLERGIES:
57
+ {allergies}
58
+
59
+ CLINICAL CONTEXT:
60
+ - Suspected infection site: {infection_site}
61
+ - Suspected source: {suspected_source}
62
+
63
+ RAG CONTEXT (Relevant Guidelines):
64
+ {rag_context}
65
+
66
+ Provide your structured assessment following the system instructions."""
67
+
68
+
69
+ # =============================================================================
70
+ # AGENT 2: VISION SPECIALIST
71
+ # =============================================================================
72
+
73
+ VISION_SPECIALIST_SYSTEM = """You are an expert medical laboratory data extraction specialist. Your role is to:
74
+
75
+ 1. Extract structured data from laboratory reports (culture & sensitivity, antibiograms)
76
+ 2. Handle reports in ANY language - always output in English
77
+ 3. Identify pathogens, antibiotics tested, MIC values, and S/I/R interpretations
78
+ 4. Flag any critical or unusual findings
79
+
80
+ SUPPORTED REPORT TYPES:
81
+ - Culture & Sensitivity reports
82
+ - Antibiogram reports
83
+ - Blood culture reports
84
+ - Urine culture reports
85
+ - Wound culture reports
86
+ - Respiratory culture reports
87
+
88
+ OUTPUT FORMAT:
89
+ Provide a structured JSON response:
90
+ {
91
+ "specimen_type": "blood|urine|wound|respiratory|other",
92
+ "collection_date": "YYYY-MM-DD or null",
93
+ "identified_organisms": [
94
+ {
95
+ "organism_name": "Standardized English name",
96
+ "original_name": "Name as written in report",
97
+ "colony_count": "if available",
98
+ "significance": "pathogen|colonizer|contaminant"
99
+ }
100
+ ],
101
+ "susceptibility_results": [
102
+ {
103
+ "organism": "Organism name",
104
+ "antibiotic": "Standardized antibiotic name",
105
+ "mic_value": <number or null>,
106
+ "mic_unit": "mg/L",
107
+ "interpretation": "S|I|R",
108
+ "method": "disk diffusion|MIC|E-test"
109
+ }
110
+ ],
111
+ "critical_findings": ["List of urgent findings requiring immediate attention"],
112
+ "report_quality": "complete|partial|poor",
113
+ "extraction_confidence": 0.0-1.0,
114
+ "notes": "Any relevant observations about the report"
115
+ }
116
+ """
117
+
118
+ VISION_SPECIALIST_PROMPT = """Extract structured laboratory data from the following report.
119
+
120
+ REPORT CONTENT:
121
+ {report_content}
122
+
123
+ REPORT METADATA:
124
+ - Source format: {source_format}
125
+ - Language detected: {language}
126
+
127
+ Extract all pathogen identifications, susceptibility results, and MIC values.
128
+ Always standardize to English medical terminology.
129
+ Flag any critical findings that require urgent attention.
130
+
131
+ Provide your structured extraction following the system instructions."""
132
+
133
+
134
+ # =============================================================================
135
+ # AGENT 3: TREND ANALYST
136
+ # =============================================================================
137
+
138
+ TREND_ANALYST_SYSTEM = """You are an expert antimicrobial resistance trend analyst. Your role is to:
139
+
140
+ 1. Analyze MIC trends over time to detect "MIC Creep"
141
+ 2. Calculate resistance velocity and predict treatment failure risk
142
+ 3. Compare current MICs against EUCAST/CLSI breakpoints
143
+ 4. Identify emerging resistance patterns
144
+
145
+ MIC CREEP DEFINITION:
146
+ MIC creep is a gradual increase in MIC values over time, even while remaining
147
+ technically "Susceptible". This can predict treatment failure before formal
148
+ resistance develops.
149
+
150
+ RISK STRATIFICATION:
151
+ - LOW: Stable MIC, well below breakpoint (>4x margin)
152
+ - MODERATE: Rising trend but still 2-4x below breakpoint
153
+ - HIGH: Approaching breakpoint (<2x margin) or rapid increase
154
+ - CRITICAL: At or above breakpoint, or >4-fold increase over baseline
155
+
156
+ OUTPUT FORMAT:
157
+ Provide a structured JSON response:
158
+ {
159
+ "organism": "Pathogen name",
160
+ "antibiotic": "Antibiotic name",
161
+ "mic_history": [
162
+ {"date": "YYYY-MM-DD", "mic_value": <number>, "interpretation": "S|I|R"}
163
+ ],
164
+ "baseline_mic": <number>,
165
+ "current_mic": <number>,
166
+ "fold_change": <number>,
167
+ "trend": "stable|increasing|decreasing|fluctuating",
168
+ "resistance_velocity": <number per time unit>,
169
+ "breakpoint_susceptible": <number>,
170
+ "breakpoint_resistant": <number>,
171
+ "margin_to_breakpoint": <number>,
172
+ "risk_level": "LOW|MODERATE|HIGH|CRITICAL",
173
+ "predicted_time_to_resistance": "estimate or N/A",
174
+ "recommendation": "Continue current therapy|Consider alternatives|Urgent switch needed",
175
+ "alternative_antibiotics": ["list", "if", "applicable"],
176
+ "rationale": "Detailed explanation of risk assessment"
177
+ }
178
+ """
179
+
180
+ TREND_ANALYST_PROMPT = """Analyze the MIC trend data and assess resistance risk.
181
+
182
+ ORGANISM: {organism}
183
+ ANTIBIOTIC: {antibiotic}
184
+
185
+ HISTORICAL MIC DATA:
186
+ {mic_history}
187
+
188
+ CURRENT BREAKPOINTS (EUCAST v16.0):
189
+ {breakpoint_data}
190
+
191
+ REGIONAL RESISTANCE DATA:
192
+ {resistance_context}
193
+
194
+ Analyze the trend, calculate risk level, and provide recommendations.
195
+ Follow the system instructions for output format."""
196
+
197
+
198
+ # =============================================================================
199
+ # AGENT 4: CLINICAL PHARMACOLOGIST
200
+ # =============================================================================
201
+
202
+ CLINICAL_PHARMACOLOGIST_SYSTEM = """You are an expert clinical pharmacologist specializing in infectious diseases and antimicrobial stewardship. Your role is to:
203
+
204
+ 1. Synthesize all available clinical data into a final antibiotic recommendation
205
+ 2. Apply WHO AWaRe classification principles (ACCESS -> WATCH -> RESERVE)
206
+ 3. Perform comprehensive drug safety checks
207
+ 4. Adjust dosing for renal function
208
+ 5. Consider local resistance patterns and guideline recommendations
209
+
210
+ PRESCRIBING PRINCIPLES:
211
+ 1. Start narrow, escalate only when justified
212
+ 2. De-escalate when culture results allow
213
+ 3. Prefer ACCESS category antibiotics when appropriate
214
+ 4. Consider pharmacokinetic/pharmacodynamic (PK/PD) optimization
215
+ 5. Document rationale for WATCH/RESERVE antibiotic use
216
+
217
+ SAFETY CHECKS:
218
+ - Drug-drug interactions (especially warfarin, methotrexate, immunosuppressants)
219
+ - Drug-allergy cross-reactivity (especially beta-lactam allergies)
220
+ - Renal dose adjustments (use CrCl)
221
+ - QT prolongation risk (fluoroquinolones, azithromycin)
222
+ - Pregnancy/lactation considerations
223
+ - Age-related considerations (pediatric/geriatric)
224
+
225
+ OUTPUT FORMAT:
226
+ Provide a structured JSON response:
227
+ {
228
+ "primary_recommendation": {
229
+ "antibiotic": "Drug name",
230
+ "dose": "Amount and unit",
231
+ "route": "IV|PO|IM",
232
+ "frequency": "Dosing interval",
233
+ "duration": "Expected treatment duration",
234
+ "aware_category": "ACCESS|WATCH|RESERVE"
235
+ },
236
+ "alternative_recommendation": {
237
+ "antibiotic": "Alternative drug",
238
+ "dose": "Amount and unit",
239
+ "route": "IV|PO|IM",
240
+ "frequency": "Dosing interval",
241
+ "indication": "When to use alternative"
242
+ },
243
+ "dose_adjustments": {
244
+ "renal": "Adjustment details or 'None needed'",
245
+ "hepatic": "Adjustment details or 'None needed'"
246
+ },
247
+ "safety_alerts": [
248
+ {
249
+ "level": "INFO|WARNING|CRITICAL",
250
+ "type": "interaction|allergy|contraindication|monitoring",
251
+ "message": "Detailed alert message",
252
+ "action_required": "What to do"
253
+ }
254
+ ],
255
+ "monitoring_parameters": ["List of labs/vitals to monitor"],
256
+ "de_escalation_plan": "When and how to de-escalate",
257
+ "rationale": "Clinical reasoning for recommendation",
258
+ "guideline_references": ["Supporting guideline citations"],
259
+ "confidence_level": "high|moderate|low",
260
+ "requires_id_consult": <boolean>
261
+ }
262
+ """
263
+
264
+ CLINICAL_PHARMACOLOGIST_PROMPT = """Synthesize all clinical data and provide a final antibiotic recommendation.
265
+
266
+ PATIENT SUMMARY (from Intake Historian):
267
+ {intake_summary}
268
+
269
+ LAB RESULTS (from Vision Specialist):
270
+ {lab_results}
271
+
272
+ MIC TREND ANALYSIS (from Trend Analyst):
273
+ {trend_analysis}
274
+
275
+ PATIENT PARAMETERS:
276
+ - Age: {age} years
277
+ - Weight: {weight} kg
278
+ - CrCl: {crcl} mL/min
279
+ - Allergies: {allergies}
280
+ - Current medications: {current_medications}
281
+
282
+ INFECTION CONTEXT:
283
+ - Site: {infection_site}
284
+ - Source: {suspected_source}
285
+ - Severity: {severity}
286
+
287
+ RAG CONTEXT (Guidelines & Safety Data):
288
+ {rag_context}
289
+
290
+ Provide your final recommendation following the system instructions.
291
+ Ensure all safety checks are performed and documented."""
292
+
293
+
294
+ # =============================================================================
295
+ # TXGEMMA SAFETY CHECKER (Supplementary)
296
+ # =============================================================================
297
+
298
+ TXGEMMA_SAFETY_PROMPT = """Evaluate the safety profile of the following antibiotic prescription:
299
+
300
+ PROPOSED ANTIBIOTIC: {antibiotic}
301
+ DOSE: {dose}
302
+ ROUTE: {route}
303
+ DURATION: {duration}
304
+
305
+ PATIENT CONTEXT:
306
+ - Age: {age}
307
+ - Renal function (CrCl): {crcl} mL/min
308
+ - Current medications: {medications}
309
+
310
+ Evaluate for:
311
+ 1. Known toxicity concerns
312
+ 2. Drug-drug interaction potential
313
+ 3. Dose appropriateness for renal function
314
+
315
+ Provide a brief safety assessment (2-3 sentences) and a risk rating (LOW/MODERATE/HIGH)."""
316
+
317
+
318
+ # =============================================================================
319
+ # HELPER TEMPLATES
320
+ # =============================================================================
321
+
322
+ ERROR_RECOVERY_PROMPT = """The previous agent encountered an error or produced invalid output.
323
+
324
+ ERROR DETAILS:
325
+ {error_details}
326
+
327
+ ORIGINAL INPUT:
328
+ {original_input}
329
+
330
+ Please attempt to recover by providing a valid response or indicating what additional information is needed."""
331
+
332
+
333
+ FALLBACK_EMPIRICAL_PROMPT = """No culture data is available. Based on the clinical presentation, provide empirical antibiotic recommendations.
334
+
335
+ CLINICAL SCENARIO:
336
+ - Infection site: {infection_site}
337
+ - Patient risk factors: {risk_factors}
338
+ - Local resistance patterns: {local_resistance}
339
+
340
+ Recommend appropriate empirical therapy following WHO AWaRe principles."""
341
+
342
+
343
+ __all__ = [
344
+ "INTAKE_HISTORIAN_SYSTEM",
345
+ "INTAKE_HISTORIAN_PROMPT",
346
+ "VISION_SPECIALIST_SYSTEM",
347
+ "VISION_SPECIALIST_PROMPT",
348
+ "TREND_ANALYST_SYSTEM",
349
+ "TREND_ANALYST_PROMPT",
350
+ "CLINICAL_PHARMACOLOGIST_SYSTEM",
351
+ "CLINICAL_PHARMACOLOGIST_PROMPT",
352
+ "TXGEMMA_SAFETY_PROMPT",
353
+ "ERROR_RECOVERY_PROMPT",
354
+ "FALLBACK_EMPIRICAL_PROMPT",
355
+ ]
src/rag.py CHANGED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG (Retrieval Augmented Generation) module for Med-I-C.
3
+
4
+ Provides unified retrieval across multiple knowledge collections:
5
+ - antibiotic_guidelines: WHO/IDSA treatment guidelines
6
+ - mic_breakpoints: EUCAST/CLSI breakpoint tables
7
+ - drug_safety: Drug interactions, warnings, contraindications
8
+ - pathogen_resistance: Regional resistance patterns
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import logging
14
+ from pathlib import Path
15
+ from typing import Any, Dict, List, Optional
16
+
17
+ from .config import get_settings
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ # =============================================================================
23
+ # CHROMA CLIENT & EMBEDDING SETUP
24
+ # =============================================================================
25
+
26
+ _chroma_client = None
27
+ _embedding_function = None
28
+
29
+
30
+ def get_chroma_client():
31
+ """Get or create ChromaDB persistent client."""
32
+ global _chroma_client
33
+ if _chroma_client is None:
34
+ import chromadb
35
+
36
+ settings = get_settings()
37
+ chroma_path = settings.chroma_db_dir
38
+ chroma_path.mkdir(parents=True, exist_ok=True)
39
+ _chroma_client = chromadb.PersistentClient(path=str(chroma_path))
40
+ return _chroma_client
41
+
42
+
43
+ def get_embedding_function():
44
+ """Get or create the embedding function."""
45
+ global _embedding_function
46
+ if _embedding_function is None:
47
+ from chromadb.utils import embedding_functions
48
+
49
+ settings = get_settings()
50
+ _embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
51
+ model_name=settings.embedding_model_name.split("/")[-1]
52
+ )
53
+ return _embedding_function
54
+
55
+
56
+ def get_collection(name: str):
57
+ """
58
+ Get a ChromaDB collection by name.
59
+
60
+ Returns None if collection doesn't exist.
61
+ """
62
+ client = get_chroma_client()
63
+ ef = get_embedding_function()
64
+
65
+ try:
66
+ return client.get_collection(name=name, embedding_function=ef)
67
+ except Exception:
68
+ logger.warning(f"Collection '{name}' not found")
69
+ return None
70
+
71
+
72
+ # =============================================================================
73
+ # COLLECTION-SPECIFIC RETRIEVERS
74
+ # =============================================================================
75
+
76
+ def search_antibiotic_guidelines(
77
+ query: str,
78
+ n_results: int = 5,
79
+ pathogen_filter: Optional[str] = None,
80
+ ) -> List[Dict[str, Any]]:
81
+ """
82
+ Search antibiotic treatment guidelines.
83
+
84
+ Args:
85
+ query: Search query
86
+ n_results: Number of results to return
87
+ pathogen_filter: Optional pathogen type filter (e.g., "ESBL-E", "CRE")
88
+
89
+ Returns:
90
+ List of relevant guideline excerpts with metadata
91
+ """
92
+ collection = get_collection("idsa_treatment_guidelines")
93
+ if collection is None:
94
+ logger.warning("idsa_treatment_guidelines collection not available")
95
+ return []
96
+
97
+ where_filter = None
98
+ if pathogen_filter:
99
+ where_filter = {"pathogen_type": pathogen_filter}
100
+
101
+ try:
102
+ results = collection.query(
103
+ query_texts=[query],
104
+ n_results=n_results,
105
+ where=where_filter,
106
+ include=["documents", "metadatas", "distances"],
107
+ )
108
+ except Exception as e:
109
+ logger.error(f"Error querying guidelines: {e}")
110
+ return []
111
+
112
+ return _format_results(results)
113
+
114
+
115
+ def search_mic_breakpoints(
116
+ query: str,
117
+ n_results: int = 5,
118
+ organism: Optional[str] = None,
119
+ antibiotic: Optional[str] = None,
120
+ ) -> List[Dict[str, Any]]:
121
+ """
122
+ Search MIC breakpoint reference documentation.
123
+
124
+ Args:
125
+ query: Search query
126
+ n_results: Number of results
127
+ organism: Optional organism name filter
128
+ antibiotic: Optional antibiotic name filter
129
+
130
+ Returns:
131
+ List of relevant breakpoint information
132
+ """
133
+ collection = get_collection("mic_reference_docs")
134
+ if collection is None:
135
+ logger.warning("mic_reference_docs collection not available")
136
+ return []
137
+
138
+ # Build query with organism/antibiotic context if provided
139
+ enhanced_query = query
140
+ if organism:
141
+ enhanced_query = f"{organism} {enhanced_query}"
142
+ if antibiotic:
143
+ enhanced_query = f"{antibiotic} {enhanced_query}"
144
+
145
+ try:
146
+ results = collection.query(
147
+ query_texts=[enhanced_query],
148
+ n_results=n_results,
149
+ include=["documents", "metadatas", "distances"],
150
+ )
151
+ except Exception as e:
152
+ logger.error(f"Error querying breakpoints: {e}")
153
+ return []
154
+
155
+ return _format_results(results)
156
+
157
+
158
+ def search_drug_safety(
159
+ query: str,
160
+ n_results: int = 5,
161
+ drug_name: Optional[str] = None,
162
+ ) -> List[Dict[str, Any]]:
163
+ """
164
+ Search drug safety information (interactions, warnings, contraindications).
165
+
166
+ Args:
167
+ query: Search query
168
+ n_results: Number of results
169
+ drug_name: Optional drug name to focus search
170
+
171
+ Returns:
172
+ List of relevant safety information
173
+ """
174
+ collection = get_collection("drug_safety")
175
+ if collection is None:
176
+ # Fallback: try existing collections
177
+ logger.warning("drug_safety collection not available")
178
+ return []
179
+
180
+ enhanced_query = f"{drug_name} {query}" if drug_name else query
181
+
182
+ try:
183
+ results = collection.query(
184
+ query_texts=[enhanced_query],
185
+ n_results=n_results,
186
+ include=["documents", "metadatas", "distances"],
187
+ )
188
+ except Exception as e:
189
+ logger.error(f"Error querying drug safety: {e}")
190
+ return []
191
+
192
+ return _format_results(results)
193
+
194
+
195
+ def search_resistance_patterns(
196
+ query: str,
197
+ n_results: int = 5,
198
+ organism: Optional[str] = None,
199
+ region: Optional[str] = None,
200
+ ) -> List[Dict[str, Any]]:
201
+ """
202
+ Search pathogen resistance pattern data.
203
+
204
+ Args:
205
+ query: Search query
206
+ n_results: Number of results
207
+ organism: Optional organism filter
208
+ region: Optional geographic region filter
209
+
210
+ Returns:
211
+ List of relevant resistance data
212
+ """
213
+ collection = get_collection("pathogen_resistance")
214
+ if collection is None:
215
+ logger.warning("pathogen_resistance collection not available")
216
+ return []
217
+
218
+ enhanced_query = query
219
+ if organism:
220
+ enhanced_query = f"{organism} {enhanced_query}"
221
+ if region:
222
+ enhanced_query = f"{region} {enhanced_query}"
223
+
224
+ try:
225
+ results = collection.query(
226
+ query_texts=[enhanced_query],
227
+ n_results=n_results,
228
+ include=["documents", "metadatas", "distances"],
229
+ )
230
+ except Exception as e:
231
+ logger.error(f"Error querying resistance patterns: {e}")
232
+ return []
233
+
234
+ return _format_results(results)
235
+
236
+
237
+ # =============================================================================
238
+ # UNIFIED CONTEXT RETRIEVER
239
+ # =============================================================================
240
+
241
+ def get_context_for_agent(
242
+ agent_name: str,
243
+ query: str,
244
+ patient_context: Optional[Dict[str, Any]] = None,
245
+ n_results: int = 3,
246
+ ) -> str:
247
+ """
248
+ Get formatted RAG context string for a specific agent.
249
+
250
+ This is the main entry point for agents to retrieve context.
251
+
252
+ Args:
253
+ agent_name: Name of the requesting agent
254
+ query: The primary search query
255
+ patient_context: Optional dict with patient-specific info
256
+ n_results: Number of results per collection
257
+
258
+ Returns:
259
+ Formatted context string for injection into prompts
260
+ """
261
+ context_parts = []
262
+ patient_context = patient_context or {}
263
+
264
+ if agent_name == "intake_historian":
265
+ # Get empirical therapy guidelines
266
+ guidelines = search_antibiotic_guidelines(
267
+ query=query,
268
+ n_results=n_results,
269
+ pathogen_filter=patient_context.get("pathogen_type"),
270
+ )
271
+ if guidelines:
272
+ context_parts.append("RELEVANT TREATMENT GUIDELINES:")
273
+ for g in guidelines:
274
+ context_parts.append(f"- {g['content'][:500]}...")
275
+ context_parts.append(f" [Source: {g.get('source', 'IDSA Guidelines')}]")
276
+
277
+ elif agent_name == "vision_specialist":
278
+ # Get MIC reference info for lab interpretation
279
+ breakpoints = search_mic_breakpoints(
280
+ query=query,
281
+ n_results=n_results,
282
+ organism=patient_context.get("organism"),
283
+ antibiotic=patient_context.get("antibiotic"),
284
+ )
285
+ if breakpoints:
286
+ context_parts.append("RELEVANT BREAKPOINT INFORMATION:")
287
+ for b in breakpoints:
288
+ context_parts.append(f"- {b['content'][:400]}...")
289
+
290
+ elif agent_name == "trend_analyst":
291
+ # Get breakpoints and resistance trends
292
+ breakpoints = search_mic_breakpoints(
293
+ query=f"breakpoint {patient_context.get('organism', '')} {patient_context.get('antibiotic', '')}",
294
+ n_results=n_results,
295
+ )
296
+ resistance = search_resistance_patterns(
297
+ query=query,
298
+ n_results=n_results,
299
+ organism=patient_context.get("organism"),
300
+ region=patient_context.get("region"),
301
+ )
302
+
303
+ if breakpoints:
304
+ context_parts.append("EUCAST BREAKPOINT DATA:")
305
+ for b in breakpoints:
306
+ context_parts.append(f"- {b['content'][:400]}...")
307
+
308
+ if resistance:
309
+ context_parts.append("\nRESISTANCE PATTERN DATA:")
310
+ for r in resistance:
311
+ context_parts.append(f"- {r['content'][:400]}...")
312
+
313
+ elif agent_name == "clinical_pharmacologist":
314
+ # Get comprehensive context for final recommendation
315
+ guidelines = search_antibiotic_guidelines(
316
+ query=query,
317
+ n_results=n_results,
318
+ )
319
+ safety = search_drug_safety(
320
+ query=query,
321
+ n_results=n_results,
322
+ drug_name=patient_context.get("proposed_antibiotic"),
323
+ )
324
+
325
+ if guidelines:
326
+ context_parts.append("TREATMENT GUIDELINES:")
327
+ for g in guidelines:
328
+ context_parts.append(f"- {g['content'][:400]}...")
329
+
330
+ if safety:
331
+ context_parts.append("\nDRUG SAFETY INFORMATION:")
332
+ for s in safety:
333
+ context_parts.append(f"- {s['content'][:400]}...")
334
+
335
+ else:
336
+ # Generic retrieval
337
+ guidelines = search_antibiotic_guidelines(query, n_results=n_results)
338
+ if guidelines:
339
+ for g in guidelines:
340
+ context_parts.append(f"- {g['content'][:500]}...")
341
+
342
+ if not context_parts:
343
+ return "No relevant context found in knowledge base."
344
+
345
+ return "\n".join(context_parts)
346
+
347
+
348
+ def get_context_string(
349
+ query: str,
350
+ collections: Optional[List[str]] = None,
351
+ n_results_per_collection: int = 3,
352
+ **filters,
353
+ ) -> str:
354
+ """
355
+ Get a combined context string from multiple collections.
356
+
357
+ This is a simpler interface for general-purpose RAG retrieval.
358
+
359
+ Args:
360
+ query: Search query
361
+ collections: List of collection names to search (defaults to all)
362
+ n_results_per_collection: Results per collection
363
+ **filters: Additional filters (organism, antibiotic, region, etc.)
364
+
365
+ Returns:
366
+ Combined context string
367
+ """
368
+ default_collections = [
369
+ "idsa_treatment_guidelines",
370
+ "mic_reference_docs",
371
+ ]
372
+ collections = collections or default_collections
373
+
374
+ context_parts = []
375
+
376
+ for collection_name in collections:
377
+ if collection_name == "idsa_treatment_guidelines":
378
+ results = search_antibiotic_guidelines(
379
+ query,
380
+ n_results=n_results_per_collection,
381
+ pathogen_filter=filters.get("pathogen_type"),
382
+ )
383
+ elif collection_name == "mic_reference_docs":
384
+ results = search_mic_breakpoints(
385
+ query,
386
+ n_results=n_results_per_collection,
387
+ organism=filters.get("organism"),
388
+ antibiotic=filters.get("antibiotic"),
389
+ )
390
+ elif collection_name == "drug_safety":
391
+ results = search_drug_safety(
392
+ query,
393
+ n_results=n_results_per_collection,
394
+ drug_name=filters.get("drug_name"),
395
+ )
396
+ elif collection_name == "pathogen_resistance":
397
+ results = search_resistance_patterns(
398
+ query,
399
+ n_results=n_results_per_collection,
400
+ organism=filters.get("organism"),
401
+ region=filters.get("region"),
402
+ )
403
+ else:
404
+ continue
405
+
406
+ if results:
407
+ context_parts.append(f"=== {collection_name.upper()} ===")
408
+ for r in results:
409
+ context_parts.append(r["content"])
410
+ context_parts.append(f"[Relevance: {1 - r.get('distance', 0):.2f}]")
411
+ context_parts.append("")
412
+
413
+ return "\n".join(context_parts) if context_parts else "No relevant context found."
414
+
415
+
416
+ # =============================================================================
417
+ # HELPER FUNCTIONS
418
+ # =============================================================================
419
+
420
+ def _format_results(results: Dict[str, Any]) -> List[Dict[str, Any]]:
421
+ """Format ChromaDB query results into a standard format."""
422
+ if not results or not results.get("documents"):
423
+ return []
424
+
425
+ formatted = []
426
+ documents = results["documents"][0] if results["documents"] else []
427
+ metadatas = results.get("metadatas", [[]])[0]
428
+ distances = results.get("distances", [[]])[0]
429
+
430
+ for i, doc in enumerate(documents):
431
+ formatted.append({
432
+ "content": doc,
433
+ "metadata": metadatas[i] if i < len(metadatas) else {},
434
+ "distance": distances[i] if i < len(distances) else None,
435
+ "source": metadatas[i].get("source", "Unknown") if i < len(metadatas) else "Unknown",
436
+ "relevance_score": 1 - (distances[i] if i < len(distances) else 0),
437
+ })
438
+
439
+ return formatted
440
+
441
+
442
+ def list_available_collections() -> List[str]:
443
+ """List all available ChromaDB collections."""
444
+ client = get_chroma_client()
445
+ try:
446
+ collections = client.list_collections()
447
+ return [c.name for c in collections]
448
+ except Exception as e:
449
+ logger.error(f"Error listing collections: {e}")
450
+ return []
451
+
452
+
453
+ def get_collection_info(name: str) -> Optional[Dict[str, Any]]:
454
+ """Get information about a specific collection."""
455
+ collection = get_collection(name)
456
+ if collection is None:
457
+ return None
458
+
459
+ try:
460
+ return {
461
+ "name": collection.name,
462
+ "count": collection.count(),
463
+ "metadata": collection.metadata,
464
+ }
465
+ except Exception as e:
466
+ logger.error(f"Error getting collection info: {e}")
467
+ return None
468
+
469
+
470
+ __all__ = [
471
+ "get_chroma_client",
472
+ "get_embedding_function",
473
+ "get_collection",
474
+ "search_antibiotic_guidelines",
475
+ "search_mic_breakpoints",
476
+ "search_drug_safety",
477
+ "search_resistance_patterns",
478
+ "get_context_for_agent",
479
+ "get_context_string",
480
+ "list_available_collections",
481
+ "get_collection_info",
482
+ ]
src/utils.py CHANGED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for Med-I-C multi-agent system.
3
+
4
+ Includes:
5
+ - Creatinine Clearance (CrCl) calculator
6
+ - MIC trend analysis and creep detection
7
+ - Prescription card formatter
8
+ - Data validation helpers
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import math
15
+ from typing import Any, Dict, List, Literal, Optional, Tuple
16
+
17
+
18
+ # =============================================================================
19
+ # CREATININE CLEARANCE CALCULATOR
20
+ # =============================================================================
21
+
22
+ def calculate_crcl(
23
+ age_years: float,
24
+ weight_kg: float,
25
+ serum_creatinine_mg_dl: float,
26
+ sex: Literal["male", "female"],
27
+ use_ibw: bool = False,
28
+ height_cm: Optional[float] = None,
29
+ ) -> float:
30
+ """
31
+ Calculate Creatinine Clearance using the Cockcroft-Gault equation.
32
+
33
+ Formula:
34
+ CrCl = [(140 - age) × weight × (0.85 if female)] / (72 × SCr)
35
+
36
+ Args:
37
+ age_years: Patient age in years
38
+ weight_kg: Actual body weight in kg
39
+ serum_creatinine_mg_dl: Serum creatinine in mg/dL
40
+ sex: Patient sex ("male" or "female")
41
+ use_ibw: If True, use Ideal Body Weight instead of actual weight
42
+ height_cm: Height in cm (required if use_ibw=True)
43
+
44
+ Returns:
45
+ Estimated CrCl in mL/min
46
+ """
47
+ if serum_creatinine_mg_dl <= 0:
48
+ raise ValueError("Serum creatinine must be positive")
49
+
50
+ if age_years <= 0 or weight_kg <= 0:
51
+ raise ValueError("Age and weight must be positive")
52
+
53
+ # Calculate weight to use
54
+ weight = weight_kg
55
+ if use_ibw and height_cm:
56
+ weight = calculate_ibw(height_cm, sex)
57
+ # Use adjusted body weight if actual weight > IBW
58
+ if weight_kg > weight * 1.3:
59
+ weight = calculate_adjusted_bw(weight, weight_kg)
60
+
61
+ # Cockcroft-Gault equation
62
+ crcl = ((140 - age_years) * weight) / (72 * serum_creatinine_mg_dl)
63
+
64
+ # Apply sex factor
65
+ if sex == "female":
66
+ crcl *= 0.85
67
+
68
+ return round(crcl, 1)
69
+
70
+
71
+ def calculate_ibw(height_cm: float, sex: Literal["male", "female"]) -> float:
72
+ """
73
+ Calculate Ideal Body Weight using the Devine formula.
74
+
75
+ Args:
76
+ height_cm: Height in centimeters
77
+ sex: Patient sex
78
+
79
+ Returns:
80
+ Ideal body weight in kg
81
+ """
82
+ height_inches = height_cm / 2.54
83
+ height_over_60 = max(0, height_inches - 60)
84
+
85
+ if sex == "male":
86
+ ibw = 50 + 2.3 * height_over_60
87
+ else:
88
+ ibw = 45.5 + 2.3 * height_over_60
89
+
90
+ return round(ibw, 1)
91
+
92
+
93
+ def calculate_adjusted_bw(ibw: float, actual_weight: float) -> float:
94
+ """
95
+ Calculate Adjusted Body Weight for obese patients.
96
+
97
+ Formula: AdjBW = IBW + 0.4 × (Actual - IBW)
98
+ """
99
+ return round(ibw + 0.4 * (actual_weight - ibw), 1)
100
+
101
+
102
+ def get_renal_dose_category(crcl: float) -> str:
103
+ """
104
+ Categorize renal function for dosing purposes.
105
+
106
+ Returns:
107
+ Renal function category
108
+ """
109
+ if crcl >= 90:
110
+ return "normal"
111
+ elif crcl >= 60:
112
+ return "mild_impairment"
113
+ elif crcl >= 30:
114
+ return "moderate_impairment"
115
+ elif crcl >= 15:
116
+ return "severe_impairment"
117
+ else:
118
+ return "esrd"
119
+
120
+
121
+ # =============================================================================
122
+ # MIC TREND ANALYSIS
123
+ # =============================================================================
124
+
125
+ def calculate_mic_trend(
126
+ mic_values: List[Dict[str, Any]],
127
+ susceptible_breakpoint: Optional[float] = None,
128
+ resistant_breakpoint: Optional[float] = None,
129
+ ) -> Dict[str, Any]:
130
+ """
131
+ Analyze MIC trend over time and detect MIC creep.
132
+
133
+ Args:
134
+ mic_values: List of dicts with 'date' and 'mic_value' keys
135
+ susceptible_breakpoint: S breakpoint (optional)
136
+ resistant_breakpoint: R breakpoint (optional)
137
+
138
+ Returns:
139
+ Dict with trend analysis results
140
+ """
141
+ if len(mic_values) < 2:
142
+ return {
143
+ "trend": "insufficient_data",
144
+ "risk_level": "UNKNOWN",
145
+ "alert": "Need at least 2 MIC values for trend analysis",
146
+ }
147
+
148
+ # Extract MIC values
149
+ mics = [float(v["mic_value"]) for v in mic_values]
150
+
151
+ baseline_mic = mics[0]
152
+ current_mic = mics[-1]
153
+
154
+ # Calculate fold change
155
+ if baseline_mic > 0:
156
+ fold_change = current_mic / baseline_mic
157
+ else:
158
+ fold_change = float("inf")
159
+
160
+ # Calculate trend
161
+ if len(mics) >= 3:
162
+ # Linear regression slope
163
+ n = len(mics)
164
+ x_mean = (n - 1) / 2
165
+ y_mean = sum(mics) / n
166
+ numerator = sum((i - x_mean) * (mics[i] - y_mean) for i in range(n))
167
+ denominator = sum((i - x_mean) ** 2 for i in range(n))
168
+ slope = numerator / denominator if denominator != 0 else 0
169
+
170
+ if slope > 0.5:
171
+ trend = "increasing"
172
+ elif slope < -0.5:
173
+ trend = "decreasing"
174
+ else:
175
+ trend = "stable"
176
+ else:
177
+ if current_mic > baseline_mic * 1.5:
178
+ trend = "increasing"
179
+ elif current_mic < baseline_mic * 0.67:
180
+ trend = "decreasing"
181
+ else:
182
+ trend = "stable"
183
+
184
+ # Calculate resistance velocity (fold change per time point)
185
+ velocity = fold_change ** (1 / (len(mics) - 1)) if len(mics) > 1 else 1.0
186
+
187
+ # Determine risk level
188
+ risk_level, alert = _assess_mic_risk(
189
+ current_mic, baseline_mic, fold_change, trend,
190
+ susceptible_breakpoint, resistant_breakpoint
191
+ )
192
+
193
+ return {
194
+ "baseline_mic": baseline_mic,
195
+ "current_mic": current_mic,
196
+ "ratio": round(fold_change, 2),
197
+ "trend": trend,
198
+ "velocity": round(velocity, 3),
199
+ "risk_level": risk_level,
200
+ "alert": alert,
201
+ "n_readings": len(mics),
202
+ }
203
+
204
+
205
+ def _assess_mic_risk(
206
+ current_mic: float,
207
+ baseline_mic: float,
208
+ fold_change: float,
209
+ trend: str,
210
+ s_breakpoint: Optional[float],
211
+ r_breakpoint: Optional[float],
212
+ ) -> Tuple[str, str]:
213
+ """
214
+ Assess risk level based on MIC trends and breakpoints.
215
+
216
+ Returns:
217
+ Tuple of (risk_level, alert_message)
218
+ """
219
+ # If we have breakpoints, use them for risk assessment
220
+ if s_breakpoint is not None and r_breakpoint is not None:
221
+ margin = s_breakpoint / current_mic if current_mic > 0 else float("inf")
222
+
223
+ if current_mic > r_breakpoint:
224
+ return "CRITICAL", f"MIC ({current_mic}) exceeds resistant breakpoint ({r_breakpoint}). Organism is RESISTANT."
225
+
226
+ if current_mic > s_breakpoint:
227
+ return "HIGH", f"MIC ({current_mic}) exceeds susceptible breakpoint ({s_breakpoint}). Consider alternative therapy."
228
+
229
+ if margin < 2:
230
+ if trend == "increasing":
231
+ return "HIGH", f"MIC approaching breakpoint (margin: {margin:.1f}x) with increasing trend. High risk of resistance emergence."
232
+ else:
233
+ return "MODERATE", f"MIC close to breakpoint (margin: {margin:.1f}x). Monitor closely."
234
+
235
+ if margin < 4:
236
+ if trend == "increasing":
237
+ return "MODERATE", f"MIC rising with {margin:.1f}x margin to breakpoint. Consider enhanced monitoring."
238
+ else:
239
+ return "LOW", "MIC stable with adequate margin to breakpoint."
240
+
241
+ return "LOW", "MIC well below breakpoint with good safety margin."
242
+
243
+ # Without breakpoints, use fold change and trend
244
+ if fold_change >= 8:
245
+ return "CRITICAL", f"MIC increased {fold_change:.1f}-fold from baseline. Urgent review needed."
246
+
247
+ if fold_change >= 4:
248
+ return "HIGH", f"MIC increased {fold_change:.1f}-fold from baseline. High risk of treatment failure."
249
+
250
+ if fold_change >= 2:
251
+ if trend == "increasing":
252
+ return "MODERATE", f"MIC increased {fold_change:.1f}-fold with rising trend. Enhanced monitoring recommended."
253
+ else:
254
+ return "LOW", f"MIC increased {fold_change:.1f}-fold but trend is {trend}."
255
+
256
+ if trend == "increasing":
257
+ return "MODERATE", "MIC showing upward trend. Continue monitoring."
258
+
259
+ return "LOW", "MIC stable or decreasing. Current therapy appropriate."
260
+
261
+
262
+ def detect_mic_creep(
263
+ organism: str,
264
+ antibiotic: str,
265
+ mic_history: List[Dict[str, Any]],
266
+ breakpoints: Dict[str, float],
267
+ ) -> Dict[str, Any]:
268
+ """
269
+ Detect MIC creep for a specific organism-antibiotic pair.
270
+
271
+ Args:
272
+ organism: Pathogen name
273
+ antibiotic: Antibiotic name
274
+ mic_history: Historical MIC values with dates
275
+ breakpoints: Dict with 'susceptible' and 'resistant' keys
276
+
277
+ Returns:
278
+ Comprehensive MIC creep analysis
279
+ """
280
+ trend_analysis = calculate_mic_trend(
281
+ mic_history,
282
+ susceptible_breakpoint=breakpoints.get("susceptible"),
283
+ resistant_breakpoint=breakpoints.get("resistant"),
284
+ )
285
+
286
+ # Add organism/antibiotic context
287
+ trend_analysis["organism"] = organism
288
+ trend_analysis["antibiotic"] = antibiotic
289
+ trend_analysis["breakpoint_susceptible"] = breakpoints.get("susceptible")
290
+ trend_analysis["breakpoint_resistant"] = breakpoints.get("resistant")
291
+
292
+ # Calculate time to resistance estimate
293
+ if trend_analysis["trend"] == "increasing" and trend_analysis["velocity"] > 1.0:
294
+ current = trend_analysis["current_mic"]
295
+ s_bp = breakpoints.get("susceptible")
296
+ if s_bp and current < s_bp:
297
+ # Estimate doublings needed to reach breakpoint
298
+ doublings_needed = math.log2(s_bp / current) if current > 0 else 0
299
+ # Estimate time based on velocity
300
+ if trend_analysis["velocity"] > 1.0:
301
+ log_velocity = math.log(trend_analysis["velocity"]) / math.log(2)
302
+ if log_velocity > 0:
303
+ time_estimate = doublings_needed / log_velocity
304
+ trend_analysis["estimated_readings_to_resistance"] = round(time_estimate, 1)
305
+
306
+ return trend_analysis
307
+
308
+
309
+ # =============================================================================
310
+ # PRESCRIPTION FORMATTER
311
+ # =============================================================================
312
+
313
+ def format_prescription_card(recommendation: Dict[str, Any]) -> str:
314
+ """
315
+ Format a recommendation into a readable prescription card.
316
+
317
+ Args:
318
+ recommendation: Dict with recommendation details
319
+
320
+ Returns:
321
+ Formatted prescription card as string
322
+ """
323
+ lines = []
324
+ lines.append("=" * 50)
325
+ lines.append("ANTIBIOTIC PRESCRIPTION")
326
+ lines.append("=" * 50)
327
+
328
+ primary = recommendation.get("primary_recommendation", recommendation)
329
+
330
+ lines.append(f"\nDRUG: {primary.get('antibiotic', 'N/A')}")
331
+ lines.append(f"DOSE: {primary.get('dose', 'N/A')}")
332
+ lines.append(f"ROUTE: {primary.get('route', 'N/A')}")
333
+ lines.append(f"FREQUENCY: {primary.get('frequency', 'N/A')}")
334
+ lines.append(f"DURATION: {primary.get('duration', 'N/A')}")
335
+
336
+ if primary.get("aware_category"):
337
+ lines.append(f"WHO AWaRe: {primary.get('aware_category')}")
338
+
339
+ # Dose adjustments
340
+ adjustments = recommendation.get("dose_adjustments", {})
341
+ if adjustments.get("renal") and adjustments["renal"] != "None needed":
342
+ lines.append(f"\nRENAL ADJUSTMENT: {adjustments['renal']}")
343
+ if adjustments.get("hepatic") and adjustments["hepatic"] != "None needed":
344
+ lines.append(f"HEPATIC ADJUSTMENT: {adjustments['hepatic']}")
345
+
346
+ # Safety alerts
347
+ alerts = recommendation.get("safety_alerts", [])
348
+ if alerts:
349
+ lines.append("\n" + "-" * 50)
350
+ lines.append("SAFETY ALERTS:")
351
+ for alert in alerts:
352
+ level = alert.get("level", "INFO")
353
+ marker = {"CRITICAL": "[!!!]", "WARNING": "[!!]", "INFO": "[i]"}.get(level, "[?]")
354
+ lines.append(f" {marker} {alert.get('message', '')}")
355
+
356
+ # Monitoring
357
+ monitoring = recommendation.get("monitoring_parameters", [])
358
+ if monitoring:
359
+ lines.append("\n" + "-" * 50)
360
+ lines.append("MONITORING:")
361
+ for param in monitoring:
362
+ lines.append(f" - {param}")
363
+
364
+ # Rationale
365
+ if recommendation.get("rationale"):
366
+ lines.append("\n" + "-" * 50)
367
+ lines.append("RATIONALE:")
368
+ lines.append(f" {recommendation['rationale']}")
369
+
370
+ lines.append("\n" + "=" * 50)
371
+
372
+ return "\n".join(lines)
373
+
374
+
375
+ # =============================================================================
376
+ # JSON PARSING HELPERS
377
+ # =============================================================================
378
+
379
+ def safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
380
+ """
381
+ Safely parse JSON from agent output, handling common issues.
382
+
383
+ Attempts to extract JSON from text that may contain markdown code blocks
384
+ or other formatting.
385
+ """
386
+ if not text:
387
+ return None
388
+
389
+ # Try direct parse first
390
+ try:
391
+ return json.loads(text)
392
+ except json.JSONDecodeError:
393
+ pass
394
+
395
+ # Try to extract JSON from markdown code block
396
+ import re
397
+
398
+ json_patterns = [
399
+ r"```json\s*\n?(.*?)\n?```", # ```json ... ```
400
+ r"```\s*\n?(.*?)\n?```", # ``` ... ```
401
+ r"\{[\s\S]*\}", # Raw JSON object
402
+ ]
403
+
404
+ for pattern in json_patterns:
405
+ match = re.search(pattern, text, re.DOTALL)
406
+ if match:
407
+ try:
408
+ json_str = match.group(1) if match.lastindex else match.group(0)
409
+ return json.loads(json_str)
410
+ except (json.JSONDecodeError, IndexError):
411
+ continue
412
+
413
+ return None
414
+
415
+
416
+ def validate_agent_output(output: Dict[str, Any], required_fields: List[str]) -> Tuple[bool, List[str]]:
417
+ """
418
+ Validate that agent output contains required fields.
419
+
420
+ Args:
421
+ output: Agent output dict
422
+ required_fields: List of required field names
423
+
424
+ Returns:
425
+ Tuple of (is_valid, list_of_missing_fields)
426
+ """
427
+ missing = [field for field in required_fields if field not in output]
428
+ return len(missing) == 0, missing
429
+
430
+
431
+ # =============================================================================
432
+ # DATA NORMALIZATION
433
+ # =============================================================================
434
+
435
+ def normalize_antibiotic_name(name: str) -> str:
436
+ """
437
+ Normalize antibiotic name to standard format.
438
+ """
439
+ # Common name mappings
440
+ mappings = {
441
+ "amox": "amoxicillin",
442
+ "amox/clav": "amoxicillin-clavulanate",
443
+ "augmentin": "amoxicillin-clavulanate",
444
+ "pip/tazo": "piperacillin-tazobactam",
445
+ "zosyn": "piperacillin-tazobactam",
446
+ "tmp/smx": "trimethoprim-sulfamethoxazole",
447
+ "bactrim": "trimethoprim-sulfamethoxazole",
448
+ "cipro": "ciprofloxacin",
449
+ "levo": "levofloxacin",
450
+ "moxi": "moxifloxacin",
451
+ "vanc": "vancomycin",
452
+ "vanco": "vancomycin",
453
+ "mero": "meropenem",
454
+ "imi": "imipenem",
455
+ "gent": "gentamicin",
456
+ "tobra": "tobramycin",
457
+ "ceftriax": "ceftriaxone",
458
+ "rocephin": "ceftriaxone",
459
+ "cefepime": "cefepime",
460
+ "maxipime": "cefepime",
461
+ }
462
+
463
+ normalized = name.lower().strip()
464
+ return mappings.get(normalized, normalized)
465
+
466
+
467
+ def normalize_organism_name(name: str) -> str:
468
+ """
469
+ Normalize organism name to standard format.
470
+ """
471
+ name = name.strip()
472
+
473
+ # Common abbreviations
474
+ abbreviations = {
475
+ "e. coli": "Escherichia coli",
476
+ "e.coli": "Escherichia coli",
477
+ "k. pneumoniae": "Klebsiella pneumoniae",
478
+ "k.pneumoniae": "Klebsiella pneumoniae",
479
+ "p. aeruginosa": "Pseudomonas aeruginosa",
480
+ "p.aeruginosa": "Pseudomonas aeruginosa",
481
+ "s. aureus": "Staphylococcus aureus",
482
+ "s.aureus": "Staphylococcus aureus",
483
+ "mrsa": "Staphylococcus aureus (MRSA)",
484
+ "mssa": "Staphylococcus aureus (MSSA)",
485
+ "enterococcus": "Enterococcus species",
486
+ "vre": "Enterococcus (VRE)",
487
+ }
488
+
489
+ lower_name = name.lower()
490
+ return abbreviations.get(lower_name, name)
491
+
492
+
493
+ __all__ = [
494
+ "calculate_crcl",
495
+ "calculate_ibw",
496
+ "calculate_adjusted_bw",
497
+ "get_renal_dose_category",
498
+ "calculate_mic_trend",
499
+ "detect_mic_creep",
500
+ "format_prescription_card",
501
+ "safe_json_parse",
502
+ "validate_agent_output",
503
+ "normalize_antibiotic_name",
504
+ "normalize_organism_name",
505
+ ]