ghitaben commited on
Commit
a2ccb82
·
1 Parent(s): c2e5bee

Implement code structure updates and remove redundant code blocks

Browse files
Files changed (5) hide show
  1. .gitignore +4 -1
  2. app.py +479 -527
  3. notebooks/kaggle_medic_demo.ipynb +116 -1342
  4. pyproject.toml +1 -2
  5. uv.lock +0 -0
.gitignore CHANGED
@@ -1,4 +1,7 @@
1
  .DS_Store
2
  .env
3
  data/
4
- *.pyc
 
 
 
 
1
  .DS_Store
2
  .env
3
  data/
4
+ *.pyc
5
+ __pycache__/
6
+ .venv/
7
+ *.egg-info/
app.py CHANGED
@@ -1,321 +1,307 @@
1
  """
2
  Med-I-C: AMR-Guard Demo Application
3
- Infection Lifecycle Orchestrator - Streamlit Interface
4
-
5
- Multi-Agent Architecture powered by MedGemma via LangGraph
6
  """
7
 
8
- import streamlit as st
9
- import sys
10
  import json
 
11
  from pathlib import Path
12
 
13
- # Add project root to path
 
14
  PROJECT_ROOT = Path(__file__).parent
15
  sys.path.insert(0, str(PROJECT_ROOT))
16
 
17
  from src.tools import (
18
- interpret_mic_value,
19
- get_most_effective_antibiotics,
20
  calculate_mic_trend,
 
 
 
21
  screen_antibiotic_safety,
22
  search_clinical_guidelines,
23
- get_empirical_therapy_guidance,
24
  )
25
- from src.utils import format_prescription_card
26
 
27
- # Page configuration
 
28
  st.set_page_config(
29
- page_title="Med-I-C: AMR-Guard",
30
- page_icon="🦠",
31
  layout="wide",
32
- initial_sidebar_state="expanded"
33
  )
34
 
35
- # Custom CSS
36
- st.markdown("""
 
 
37
  <style>
38
- .main-header {
39
- font-size: 2.5rem;
40
- font-weight: bold;
41
- color: #1E88E5;
42
- margin-bottom: 0;
43
- }
44
- .sub-header {
45
- font-size: 1.2rem;
46
- color: #666;
47
- margin-top: 0;
48
- }
49
- .agent-card {
50
- background-color: #F5F5F5;
51
- padding: 15px;
52
- border-radius: 8px;
53
- margin: 10px 0;
54
- border-left: 4px solid #1E88E5;
55
- }
56
- .agent-active {
57
- border-left-color: #4CAF50;
58
- background-color: #E8F5E9;
59
- }
60
- .agent-complete {
61
- border-left-color: #9E9E9E;
62
- background-color: #FAFAFA;
63
- }
64
- .risk-high {
65
- background-color: #FFCDD2;
66
- padding: 10px;
67
- border-radius: 5px;
68
- border-left: 4px solid #D32F2F;
69
- }
70
- .risk-moderate {
71
- background-color: #FFE0B2;
72
- padding: 10px;
73
- border-radius: 5px;
74
- border-left: 4px solid #F57C00;
75
- }
76
- .risk-low {
77
- background-color: #C8E6C9;
78
- padding: 10px;
79
- border-radius: 5px;
80
- border-left: 4px solid #388E3C;
81
- }
82
- .prescription-card {
83
- background-color: #E3F2FD;
84
- padding: 20px;
85
- border-radius: 10px;
86
- font-family: monospace;
87
- white-space: pre-wrap;
88
- }
89
- .info-box {
90
- background-color: #E3F2FD;
91
- padding: 15px;
92
- border-radius: 5px;
93
- margin: 10px 0;
94
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  </style>
96
- """, unsafe_allow_html=True)
97
-
98
-
99
- def main():
100
- # Header
101
- st.markdown('<p class="main-header">🦠 Med-I-C: AMR-Guard</p>', unsafe_allow_html=True)
102
- st.markdown('<p class="sub-header">Infection Lifecycle Orchestrator - Multi-Agent System</p>', unsafe_allow_html=True)
103
-
104
- # Sidebar navigation
105
- st.sidebar.title("Navigation")
106
- page = st.sidebar.radio(
107
- "Select Module",
108
- [
109
- "🏠 Overview",
110
- "🤖 Agent Pipeline",
111
- "💊 Empirical Advisor",
112
- "🔬 Lab Interpretation",
113
- "📊 MIC Trend Analysis",
114
- "⚠️ Drug Safety Check",
115
- "📚 Clinical Guidelines"
116
- ]
117
  )
118
 
119
- if page == "🏠 Overview":
120
- show_overview()
121
- elif page == "🤖 Agent Pipeline":
122
- show_agent_pipeline()
123
- elif page == "💊 Empirical Advisor":
124
- show_empirical_advisor()
125
- elif page == "🔬 Lab Interpretation":
126
- show_lab_interpretation()
127
- elif page == "📊 MIC Trend Analysis":
128
- show_mic_trend_analysis()
129
- elif page == "⚠️ Drug Safety Check":
130
- show_drug_safety()
131
- elif page == "📚 Clinical Guidelines":
132
- show_guidelines_search()
133
-
134
-
135
- def show_overview():
136
- st.header("System Overview")
137
-
138
- st.markdown("""
139
- **AMR-Guard** is a multi-agent AI system that orchestrates the complete infection treatment lifecycle,
140
- from initial empirical therapy to targeted treatment based on lab results.
141
- """)
142
-
143
- # Architecture diagram
144
- st.subheader("Multi-Agent Architecture")
145
-
146
- col1, col2 = st.columns(2)
147
-
148
- with col1:
149
- st.markdown("""
150
- ### Stage 1: Empirical Phase
151
- **Path:** Agent 1 → Agent 4
152
-
153
- *Before lab results are available*
154
-
155
- 1. **Intake Historian** (Agent 1)
156
- - Parses patient demographics & history
157
- - Calculates CrCl for renal dosing
158
- - Identifies risk factors for MDR
159
-
160
- 2. **Clinical Pharmacologist** (Agent 4)
161
- - Recommends empirical antibiotics
162
- - Applies WHO AWaRe principles
163
- - Performs safety checks
164
- """)
165
-
166
- with col2:
167
- st.markdown("""
168
- ### Stage 2: Targeted Phase
169
- **Path:** Agent 1 → Agent 2 → Agent 3 → Agent 4
170
-
171
- *When lab/culture results are available*
172
-
173
- 1. **Intake Historian** (Agent 1)
174
- 2. **Vision Specialist** (Agent 2)
175
- - Extracts data from lab reports
176
- - Supports any language/format
177
- 3. **Trend Analyst** (Agent 3)
178
- - Detects MIC creep patterns
179
- - Calculates resistance velocity
180
- 4. **Clinical Pharmacologist** (Agent 4)
181
- """)
182
-
183
- st.divider()
184
-
185
- # Knowledge sources
186
- st.subheader("Knowledge Sources")
187
 
188
- col1, col2, col3, col4 = st.columns(4)
189
 
190
- with col1:
191
- st.metric("WHO AWaRe", "264", "antibiotics classified")
192
- with col2:
193
- st.metric("EUCAST", "v16.0", "breakpoint tables")
194
- with col3:
195
- st.metric("IDSA", "2024", "treatment guidelines")
196
- with col4:
197
- st.metric("DDInter", "191K+", "drug interactions")
198
-
199
- # Model info
200
- st.subheader("AI Models")
201
- st.markdown("""
202
- | Agent | Primary Model | Fallback |
203
- |-------|---------------|----------|
204
- | Intake Historian | MedGemma 4B IT | Vertex AI API |
205
- | Vision Specialist | MedGemma 4B IT (multimodal) | Vertex AI API |
206
- | Trend Analyst | MedGemma 4B IT | Vertex AI API |
207
- | Clinical Pharmacologist | MedGemma 4B + TxGemma 2B (safety) | Vertex AI API |
208
- """)
209
-
210
-
211
- def show_agent_pipeline():
212
- st.header("🤖 Multi-Agent Pipeline")
213
- st.markdown("*Run the complete infection lifecycle workflow*")
214
-
215
- # Initialize session state
216
- if "pipeline_result" not in st.session_state:
217
- st.session_state.pipeline_result = None
218
 
219
- # Patient Information Form
220
- with st.expander("Patient Information", expanded=True):
221
- col1, col2, col3 = st.columns(3)
222
 
223
- with col1:
224
- age = st.number_input("Age (years)", min_value=0, max_value=120, value=65)
225
- weight = st.number_input("Weight (kg)", min_value=1.0, max_value=300.0, value=70.0)
226
- height = st.number_input("Height (cm)", min_value=50.0, max_value=250.0, value=170.0)
227
 
228
- with col2:
229
- sex = st.selectbox("Sex", ["male", "female"])
230
- creatinine = st.number_input("Serum Creatinine (mg/dL)", min_value=0.1, max_value=20.0, value=1.2)
231
 
232
- with col3:
233
- infection_site = st.selectbox(
234
- "Infection Site",
235
- ["urinary", "respiratory", "bloodstream", "skin", "intra-abdominal", "CNS", "other"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  )
237
- suspected_source = st.text_input(
238
- "Suspected Source",
239
- placeholder="e.g., community UTI, hospital-acquired pneumonia"
 
 
 
 
 
 
 
 
 
 
240
  )
241
 
242
- with st.expander("Medical History"):
243
- col1, col2 = st.columns(2)
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
- with col1:
246
- medications = st.text_area(
247
- "Current Medications (one per line)",
248
- placeholder="Metformin\nLisinopril\nAspirin",
249
- height=100
250
- )
251
- allergies = st.text_area(
252
- "Allergies (one per line)",
253
- placeholder="Penicillin\nSulfa",
254
- height=100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  )
 
256
 
257
- with col2:
 
 
 
 
 
258
  comorbidities = st.multiselect(
259
  "Comorbidities",
260
- ["Diabetes", "CKD", "Heart Failure", "COPD", "Immunocompromised",
261
- "Recent Surgery", "Malignancy", "Liver Disease"]
262
  )
263
  risk_factors = st.multiselect(
264
- "MDR Risk Factors",
265
- ["Prior MRSA infection", "Recent antibiotic use (<90 days)",
266
- "Healthcare-associated", "Recent hospitalization",
267
- "Nursing home resident", "Prior MDR infection"]
268
  )
269
 
270
- # Lab Data (Optional - triggers Stage 2)
271
- with st.expander("Lab Results (Optional - triggers targeted pathway)"):
272
- lab_input_method = st.radio(
273
- "Input Method",
274
- ["None (Empirical only)", "Paste Lab Text", "Upload File"],
275
- horizontal=True
276
- )
277
-
278
  labs_raw_text = None
279
-
280
- if lab_input_method == "Paste Lab Text":
281
  labs_raw_text = st.text_area(
282
- "Lab Report Text",
283
- placeholder="""Example:
284
- Culture: Urine
285
- Organism: Escherichia coli
286
- Colony Count: >100,000 CFU/mL
287
-
288
- Susceptibility:
289
- Ampicillin: R (MIC >32)
290
- Ciprofloxacin: S (MIC 0.25)
291
- Nitrofurantoin: S (MIC 16)
292
- Trimethoprim-Sulfamethoxazole: R (MIC >4)""",
293
- height=200
294
  )
295
 
296
- elif lab_input_method == "Upload File":
297
- uploaded_file = st.file_uploader(
298
- "Upload Lab Report (PDF or Image)",
299
- type=["pdf", "png", "jpg", "jpeg"]
300
- )
301
- if uploaded_file:
302
- st.info("File uploaded. Text extraction will be performed by the Vision Specialist agent.")
303
- # In production, would extract text here
304
- labs_raw_text = f"[Uploaded file: {uploaded_file.name}]"
305
-
306
- # Run Pipeline Button
307
- st.divider()
308
-
309
- col1, col2, col3 = st.columns([1, 2, 1])
310
- with col2:
311
- run_pipeline_btn = st.button(
312
- "🚀 Run Agent Pipeline",
313
- type="primary",
314
- use_container_width=True
315
- )
316
 
317
- if run_pipeline_btn:
318
- # Build patient data
319
  patient_data = {
320
  "age_years": age,
321
  "weight_kg": weight,
@@ -329,143 +315,114 @@ Trimethoprim-Sulfamethoxazole: R (MIC >4)""",
329
  "comorbidities": list(comorbidities) + list(risk_factors),
330
  }
331
 
332
- # Show pipeline progress
333
- st.subheader("Pipeline Execution")
334
-
335
- # Agent progress indicators
336
- agents = [
337
- ("Intake Historian", "Analyzing patient data..."),
338
- ("Vision Specialist", "Processing lab results...") if labs_raw_text else None,
339
- ("Trend Analyst", "Analyzing MIC trends...") if labs_raw_text else None,
340
- ("Clinical Pharmacologist", "Generating recommendations..."),
341
- ]
342
- agents = [a for a in agents if a is not None]
343
 
344
- progress_bar = st.progress(0)
345
- status_text = st.empty()
 
346
 
347
- # Simulate pipeline execution (in production, would call actual pipeline)
348
  try:
349
- # Try to import and run the actual pipeline
350
  from src.graph import run_pipeline
351
-
352
- for i, (agent_name, status_msg) in enumerate(agents):
353
- status_text.text(f"Agent {i+1}/{len(agents)}: {agent_name} - {status_msg}")
354
- progress_bar.progress((i + 1) / len(agents))
355
-
356
- # Run the actual pipeline
357
  result = run_pipeline(patient_data, labs_raw_text)
358
- st.session_state.pipeline_result = result
359
-
360
- except Exception as e:
361
- st.error(f"Pipeline execution error: {e}")
362
- st.info("Running in demo mode with simulated output...")
363
-
364
- # Demo mode - simulate results
365
- st.session_state.pipeline_result = _generate_demo_result(patient_data, labs_raw_text)
366
 
367
- progress_bar.progress(100)
368
- status_text.text("Pipeline complete!")
369
 
370
- # Display Results
371
  if st.session_state.pipeline_result:
372
  result = st.session_state.pipeline_result
 
373
 
374
- st.divider()
375
- st.subheader("Pipeline Results")
376
 
377
- # Tabs for different result sections
378
- tab1, tab2, tab3, tab4 = st.tabs([
379
- "📋 Recommendation",
380
- "👤 Patient Summary",
381
- "🔬 Lab Analysis",
382
- "⚠️ Safety Alerts"
383
- ])
384
-
385
- with tab1:
386
  rec = result.get("recommendation", {})
387
  if rec:
388
- st.markdown("### Antibiotic Recommendation")
389
-
390
- col1, col2 = st.columns(2)
391
-
392
- with col1:
393
- st.markdown(f"**Primary:** {rec.get('primary_antibiotic', 'N/A')}")
394
- st.markdown(f"**Dose:** {rec.get('dose', 'N/A')}")
395
- st.markdown(f"**Route:** {rec.get('route', 'N/A')}")
396
- st.markdown(f"**Frequency:** {rec.get('frequency', 'N/A')}")
397
- st.markdown(f"**Duration:** {rec.get('duration', 'N/A')}")
398
-
399
- with col2:
400
- if rec.get("backup_antibiotic"):
401
- st.markdown(f"**Alternative:** {rec.get('backup_antibiotic')}")
402
-
403
- st.markdown("---")
404
- st.markdown("**Rationale:**")
405
- st.markdown(rec.get("rationale", "No rationale provided"))
 
 
 
 
 
 
 
 
406
 
407
  if rec.get("references"):
408
- st.markdown("**References:**")
409
  for ref in rec["references"]:
410
  st.markdown(f"- {ref}")
411
 
412
- with tab2:
413
- st.markdown("### Patient Assessment")
414
- intake_notes = result.get("intake_notes", "")
415
- if intake_notes:
416
- try:
417
- intake_data = json.loads(intake_notes) if isinstance(intake_notes, str) else intake_notes
418
- st.json(intake_data)
419
- except:
420
- st.text(intake_notes)
421
-
422
  if result.get("creatinine_clearance_ml_min"):
423
- st.metric("Calculated CrCl", f"{result['creatinine_clearance_ml_min']} mL/min")
424
-
425
- with tab3:
426
- st.markdown("### Laboratory Analysis")
427
-
428
- vision_notes = result.get("vision_notes", "No lab data processed")
429
- if vision_notes and vision_notes != "No lab data provided":
430
  try:
431
- vision_data = json.loads(vision_notes) if isinstance(vision_notes, str) else vision_notes
432
- st.json(vision_data)
433
- except:
434
- st.text(vision_notes)
435
-
436
- trend_notes = result.get("trend_notes", "")
437
- if trend_notes and trend_notes != "No MIC data available for trend analysis":
438
- st.markdown("#### MIC Trend Analysis")
439
  try:
440
- trend_data = json.loads(trend_notes) if isinstance(trend_notes, str) else trend_notes
441
- st.json(trend_data)
442
- except:
443
- st.text(trend_notes)
 
444
 
445
- with tab4:
446
- st.markdown("### Safety Alerts")
 
 
 
 
 
447
 
 
448
  warnings = result.get("safety_warnings", [])
449
  if warnings:
450
- for warning in warnings:
451
- st.warning(f"⚠ {warning}")
452
  else:
453
- st.success("No safety concerns identified")
454
 
455
  errors = result.get("errors", [])
456
- if errors:
457
- st.markdown("#### Errors")
458
- for error in errors:
459
- st.error(error)
460
 
461
 
462
- def _generate_demo_result(patient_data: dict, labs_raw_text: str | None) -> dict:
463
- """Generate demo result when actual pipeline is not available."""
464
  result = {
465
  "stage": "targeted" if labs_raw_text else "empirical",
466
  "creatinine_clearance_ml_min": 58.3,
467
  "intake_notes": json.dumps({
468
- "patient_summary": f"65-year-old male with {patient_data.get('suspected_source', 'infection')}",
469
  "creatinine_clearance_ml_min": 58.3,
470
  "renal_dose_adjustment_needed": True,
471
  "identified_risk_factors": patient_data.get("comorbidities", []),
@@ -474,19 +431,21 @@ def _generate_demo_result(patient_data: dict, labs_raw_text: str | None) -> dict
474
  }),
475
  "recommendation": {
476
  "primary_antibiotic": "Ciprofloxacin",
477
- "dose": "500mg",
478
- "route": "PO",
479
  "frequency": "Every 12 hours",
480
  "duration": "7 days",
481
- "backup_antibiotic": "Nitrofurantoin",
482
- "rationale": "Based on suspected community-acquired UTI with moderate renal impairment. Ciprofloxacin provides good coverage for common uropathogens. Dose adjusted for CrCl 58 mL/min.",
 
 
 
 
483
  "references": ["IDSA UTI Guidelines 2024", "EUCAST Breakpoint Tables v16.0"],
484
- "safety_alerts": [],
485
  },
486
  "safety_warnings": [],
487
  "errors": [],
488
  }
489
-
490
  if labs_raw_text:
491
  result["vision_notes"] = json.dumps({
492
  "specimen_type": "urine",
@@ -494,6 +453,7 @@ def _generate_demo_result(patient_data: dict, labs_raw_text: str | None) -> dict
494
  "susceptibility_results": [
495
  {"organism": "E. coli", "antibiotic": "Ciprofloxacin", "mic_value": 0.25, "interpretation": "S"},
496
  {"organism": "E. coli", "antibiotic": "Nitrofurantoin", "mic_value": 16, "interpretation": "S"},
 
497
  ],
498
  "extraction_confidence": 0.95,
499
  })
@@ -501,181 +461,173 @@ def _generate_demo_result(patient_data: dict, labs_raw_text: str | None) -> dict
501
  "organism": "E. coli",
502
  "antibiotic": "Ciprofloxacin",
503
  "risk_level": "LOW",
504
- "recommendation": "Continue current therapy",
505
  }])
506
-
507
  return result
508
 
509
 
510
- def show_empirical_advisor():
511
- st.header("💊 Empirical Advisor")
512
- st.markdown("*Get empirical therapy recommendations before lab results*")
513
-
514
- col1, col2 = st.columns([2, 1])
515
 
516
- with col1:
517
- infection_type = st.selectbox(
518
- "Infection Type",
519
- ["Urinary Tract Infection", "Pneumonia", "Sepsis",
520
- "Skin/Soft Tissue", "Intra-abdominal", "Meningitis"]
521
- )
522
-
523
- suspected_pathogen = st.text_input(
524
- "Suspected Pathogen (optional)",
525
- placeholder="e.g., E. coli, Klebsiella pneumoniae"
526
- )
527
 
528
- risk_factors = st.multiselect(
529
- "Risk Factors",
530
- ["Prior MRSA infection", "Recent antibiotic use (<90 days)",
531
- "Healthcare-associated", "Immunocompromised",
532
- "Renal impairment", "Prior MDR infection"]
533
- )
534
 
535
- with col2:
536
- st.markdown("**WHO AWaRe Categories**")
537
- st.markdown("""
538
- - **ACCESS**: First-line, low resistance
539
- - **WATCH**: Higher resistance potential
540
- - **RESERVE**: Last resort antibiotics
541
- """)
542
-
543
- if st.button("Get Recommendation", type="primary"):
544
- with st.spinner("Searching guidelines..."):
545
- guidance = get_empirical_therapy_guidance(
546
- infection_type,
547
- risk_factors
 
 
 
 
 
 
 
548
  )
549
 
550
- st.subheader("Guideline Recommendations")
 
 
551
 
552
  if guidance.get("recommendations"):
553
  for i, rec in enumerate(guidance["recommendations"][:3], 1):
554
- with st.expander(f"Excerpt {i} (Relevance: {rec.get('relevance_score', 0):.2f})"):
555
  st.markdown(rec.get("content", ""))
556
- st.caption(f"Source: {rec.get('source', 'IDSA Guidelines')}")
557
-
558
- if suspected_pathogen:
559
- st.subheader(f"Resistance Data: {suspected_pathogen}")
560
- effective = get_most_effective_antibiotics(suspected_pathogen, min_susceptibility=70)
561
 
 
 
 
562
  if effective:
563
- for ab in effective[:5]:
564
- st.write(f"- **{ab.get('antibiotic')}**: {ab.get('avg_susceptibility', 0):.1f}% susceptible")
565
  else:
566
- st.info("No resistance data found.")
567
-
568
-
569
- def show_lab_interpretation():
570
- st.header("🔬 Lab Interpretation")
571
- st.markdown("*Interpret antibiogram MIC values*")
572
-
573
- col1, col2 = st.columns(2)
574
-
575
- with col1:
576
- pathogen = st.text_input("Pathogen", placeholder="e.g., Escherichia coli")
577
- antibiotic = st.text_input("Antibiotic", placeholder="e.g., Ciprofloxacin")
578
- mic_value = st.number_input("MIC (mg/L)", min_value=0.001, max_value=1024.0, value=1.0)
579
-
580
- with col2:
581
- st.markdown("**Interpretation Guide**")
582
- st.markdown("""
583
- - **S**: Susceptible - antibiotic effective
584
- - **I**: Intermediate - may work at higher doses
585
- - **R**: Resistant - do not use
586
- """)
587
-
588
- if st.button("Interpret", type="primary"):
589
- if pathogen and antibiotic:
590
- result = interpret_mic_value(pathogen, antibiotic, mic_value)
591
- interpretation = result.get("interpretation", "UNKNOWN")
592
-
593
- if interpretation == "SUSCEPTIBLE":
594
- st.success(f"✅ {interpretation}")
595
- elif interpretation == "RESISTANT":
596
- st.error(f"❌ {interpretation}")
597
- else:
598
- st.warning(f"⚠️ {interpretation}")
599
-
600
- st.markdown(f"**Details:** {result.get('message', '')}")
601
-
602
-
603
- def show_mic_trend_analysis():
604
- st.header("📊 MIC Trend Analysis")
605
- st.markdown("*Detect MIC creep over time*")
606
-
607
- num_readings = st.slider("Historical readings", 2, 6, 3)
608
-
609
- mic_values = []
610
- cols = st.columns(num_readings)
611
-
612
- for i, col in enumerate(cols):
613
- mic = col.number_input(f"MIC {i+1}", min_value=0.001, max_value=256.0, value=float(2 ** i), key=f"mic_{i}")
614
- mic_values.append({"date": f"T{i}", "mic_value": mic})
615
-
616
- if st.button("Analyze", type="primary"):
617
- result = calculate_mic_trend(mic_values)
618
- risk_level = result.get("risk_level", "UNKNOWN")
619
-
620
- if risk_level == "HIGH":
621
- st.markdown(f'<div class="risk-high">🚨 HIGH RISK: {result.get("alert", "")}</div>', unsafe_allow_html=True)
622
- elif risk_level == "MODERATE":
623
- st.markdown(f'<div class="risk-moderate">⚠️ MODERATE: {result.get("alert", "")}</div>', unsafe_allow_html=True)
624
- else:
625
- st.markdown(f'<div class="risk-low">✅ LOW RISK: {result.get("alert", "")}</div>', unsafe_allow_html=True)
626
-
627
- col1, col2, col3 = st.columns(3)
628
- col1.metric("Baseline", f"{result.get('baseline_mic', 'N/A')} mg/L")
629
- col2.metric("Current", f"{result.get('current_mic', 'N/A')} mg/L")
630
- col3.metric("Fold Change", f"{result.get('ratio', 'N/A')}x")
631
-
632
-
633
- def show_drug_safety():
634
- st.header("⚠️ Drug Safety Check")
635
-
636
- col1, col2 = st.columns(2)
637
-
638
- with col1:
639
- antibiotic = st.text_input("Antibiotic", placeholder="e.g., Ciprofloxacin")
640
- current_meds = st.text_area("Current Medications", placeholder="Warfarin\nMetformin", height=150)
641
-
642
- with col2:
643
- allergies = st.text_area("Allergies", placeholder="Penicillin\nSulfa", height=100)
644
-
645
- if st.button("Check Safety", type="primary"):
646
- if antibiotic:
647
- medications = [m.strip() for m in current_meds.split("\n") if m.strip()]
648
- allergy_list = [a.strip() for a in allergies.split("\n") if a.strip()]
649
-
650
- result = screen_antibiotic_safety(antibiotic, medications, allergy_list)
651
 
652
- if result.get("safe_to_use"):
653
- st.success("✅ No critical safety concerns")
654
- else:
655
- st.error(" Safety concerns identified")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656
 
657
- for alert in result.get("alerts", []):
658
- st.warning(f"⚠ {alert.get('message', '')}")
659
 
660
 
661
- def show_guidelines_search():
662
- st.header("📚 Clinical Guidelines")
663
 
664
- query = st.text_input("Search", placeholder="e.g., ESBL E. coli UTI treatment")
665
- pathogen_filter = st.selectbox("Pathogen Filter", ["All", "ESBL-E", "CRE", "CRAB", "DTR-PA"])
666
 
667
  if st.button("Search", type="primary"):
668
  if query:
669
- filter_val = None if pathogen_filter == "All" else pathogen_filter
670
- results = search_clinical_guidelines(query, pathogen_filter=filter_val, n_results=5)
 
671
 
672
  if results:
673
  for i, r in enumerate(results, 1):
674
- with st.expander(f"Result {i} (Relevance: {r.get('relevance_score', 0):.2f})"):
675
  st.markdown(r.get("content", ""))
 
 
676
  else:
677
- st.info("No results found.")
 
 
 
 
 
 
 
678
 
 
679
 
680
- if __name__ == "__main__":
681
- main()
 
 
 
 
 
 
 
1
  """
2
  Med-I-C: AMR-Guard Demo Application
3
+ Infection Lifecycle Orchestrator Streamlit Interface
 
 
4
  """
5
 
 
 
6
  import json
7
+ import sys
8
  from pathlib import Path
9
 
10
+ import streamlit as st
11
+
12
  PROJECT_ROOT = Path(__file__).parent
13
  sys.path.insert(0, str(PROJECT_ROOT))
14
 
15
  from src.tools import (
 
 
16
  calculate_mic_trend,
17
+ get_empirical_therapy_guidance,
18
+ get_most_effective_antibiotics,
19
+ interpret_mic_value,
20
  screen_antibiotic_safety,
21
  search_clinical_guidelines,
 
22
  )
 
23
 
24
+ # ── Page config ──────────────────────────────────────────────────────────────
25
+
26
  st.set_page_config(
27
+ page_title="Med-I-C · AMR-Guard",
28
+ page_icon="",
29
  layout="wide",
30
+ initial_sidebar_state="expanded",
31
  )
32
 
33
+ # ── Global CSS ────────────────────────────────────────────────────────────────
34
+
35
+ st.markdown(
36
+ """
37
  <style>
38
+ /* ── Fonts & Base ── */
39
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
40
+
41
+ html, body, [class*="css"] { font-family: 'Inter', sans-serif; }
42
+
43
+ /* ── Hide Streamlit chrome ── */
44
+ #MainMenu, footer { visibility: hidden; }
45
+
46
+ /* ── Sidebar ── */
47
+ [data-testid="stSidebar"] {
48
+ background: #0b2545;
49
+ }
50
+ [data-testid="stSidebar"] * { color: #e8edf3 !important; }
51
+ [data-testid="stSidebar"] .stRadio label { padding: 6px 0; font-size: 0.9rem; }
52
+ [data-testid="stSidebar"] hr { border-color: #1e3a5f; }
53
+
54
+ /* ── Top banner ── */
55
+ .med-banner {
56
+ background: linear-gradient(135deg, #0b2545 0%, #1a4a8a 100%);
57
+ padding: 22px 30px;
58
+ border-radius: 12px;
59
+ margin-bottom: 28px;
60
+ display: flex;
61
+ align-items: center;
62
+ gap: 20px;
63
+ }
64
+ .med-banner h1 { color: #ffffff; font-size: 1.9rem; font-weight: 700; margin: 0; }
65
+ .med-banner p { color: #9ec4f0; font-size: 0.95rem; margin: 4px 0 0; }
66
+
67
+ /* ── Section headings ── */
68
+ .section-title {
69
+ font-size: 1.15rem; font-weight: 600;
70
+ color: #0b2545; border-bottom: 2px solid #1a4a8a;
71
+ padding-bottom: 6px; margin: 24px 0 16px;
72
+ }
73
+
74
+ /* ── Stat cards ── */
75
+ .stat-card {
76
+ background: #ffffff;
77
+ border: 1px solid #dde4ee;
78
+ border-top: 3px solid #1a4a8a;
79
+ border-radius: 10px;
80
+ padding: 18px 20px;
81
+ text-align: center;
82
+ }
83
+ .stat-card .label { color: #6b7a99; font-size: 0.78rem; font-weight: 600; text-transform: uppercase; letter-spacing: 0.04em; }
84
+ .stat-card .value { color: #0b2545; font-size: 1.6rem; font-weight: 700; margin-top: 4px; }
85
+ .stat-card .sub { color: #9ec4f0; font-size: 0.75rem; margin-top: 2px; }
86
+
87
+ /* ── Agent flow card ── */
88
+ .agent-step {
89
+ background: #f4f7fc;
90
+ border-left: 4px solid #1a4a8a;
91
+ border-radius: 8px;
92
+ padding: 14px 16px;
93
+ margin-bottom: 10px;
94
+ }
95
+ .agent-step .num { color: #1a4a8a; font-weight: 700; font-size: 0.85rem; }
96
+ .agent-step .name { color: #0b2545; font-weight: 600; }
97
+ .agent-step .desc { color: #5a6680; font-size: 0.85rem; margin-top: 4px; }
98
+
99
+ /* ── Alert badges ── */
100
+ .badge-high { background:#fff0f0; border-left:4px solid #c0392b; color:#7b1d1d; padding:10px 14px; border-radius:6px; }
101
+ .badge-moderate { background:#fff8ee; border-left:4px solid #e67e22; color:#7a4a00; padding:10px 14px; border-radius:6px; }
102
+ .badge-low { background:#f0fff4; border-left:4px solid #27ae60; color:#145a32; padding:10px 14px; border-radius:6px; }
103
+ .badge-info { background:#eaf3ff; border-left:4px solid #1a4a8a; color:#0b2545; padding:10px 14px; border-radius:6px; }
104
+
105
+ /* ── Prescription card ── */
106
+ .rx-card {
107
+ background: #f4f7fc;
108
+ border: 1px solid #c5d3e8;
109
+ border-radius: 10px;
110
+ padding: 22px 24px;
111
+ font-size: 0.9rem;
112
+ line-height: 1.7;
113
+ }
114
+ .rx-card .rx-symbol { font-size: 2rem; color: #1a4a8a; font-weight: 700; }
115
+ .rx-card .rx-drug { font-size: 1.2rem; font-weight: 700; color: #0b2545; }
116
+
117
+ /* ── Disclaimer ── */
118
+ .disclaimer {
119
+ background: #fff8ee;
120
+ border: 1px solid #f0c080;
121
+ border-radius: 8px;
122
+ padding: 12px 16px;
123
+ font-size: 0.78rem;
124
+ color: #7a5000;
125
+ margin-top: 20px;
126
+ }
127
+
128
+ /* ── Form tweaks ── */
129
+ .stTextInput input, .stTextArea textarea, .stNumberInput input {
130
+ border-radius: 6px !important;
131
+ }
132
+ .stButton > button[kind="primary"] {
133
+ background: #1a4a8a; border: none;
134
+ border-radius: 8px; font-weight: 600;
135
+ padding: 0.6rem 1.4rem;
136
+ }
137
+ .stButton > button[kind="primary"]:hover { background: #0b2545; }
138
  </style>
139
+ """,
140
+ unsafe_allow_html=True,
141
+ )
142
+
143
+
144
+ # ── Sidebar ───────────────────────────────────────────────────────────────────
145
+
146
+ with st.sidebar:
147
+ st.markdown("## Med-I-C")
148
+ st.markdown("**AMR-Guard**")
149
+ st.markdown("---")
150
+ page = st.radio(
151
+ "Navigation",
152
+ ["Dashboard", "Patient Analysis", "Clinical Tools", "Guidelines"],
153
+ label_visibility="collapsed",
154
+ )
155
+ st.markdown("---")
156
+ st.markdown(
157
+ "<small style='color:#6b8fc4'>Powered by local LLMs<br>via HuggingFace Transformers</small>",
158
+ unsafe_allow_html=True,
 
159
  )
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
+ # ── Banner ────────────────────────────────────────────────────────────────────
163
 
164
+ st.markdown(
165
+ """
166
+ <div class="med-banner">
167
+ <div>
168
+ <h1>⚕ AMR-Guard</h1>
169
+ <p>Infection Lifecycle Orchestrator &nbsp;·&nbsp; Multi-Agent Clinical Decision Support</p>
170
+ </div>
171
+ </div>
172
+ """,
173
+ unsafe_allow_html=True,
174
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
+ # ── Pages ─────────────────────────────────────────────────────────────────────
 
 
177
 
 
 
 
 
178
 
179
+ def page_dashboard():
180
+ st.markdown('<div class="section-title">System Overview</div>', unsafe_allow_html=True)
 
181
 
182
+ col1, col2, col3, col4 = st.columns(4)
183
+ cards = [
184
+ ("WHO AWaRe", "264", "antibiotics classified"),
185
+ ("EUCAST", "v16.0", "breakpoint tables"),
186
+ ("IDSA", "2024", "treatment guidelines"),
187
+ ("DDInter", "191K+", "drug interactions"),
188
+ ]
189
+ for col, (label, value, sub) in zip([col1, col2, col3, col4], cards):
190
+ col.markdown(
191
+ f'<div class="stat-card"><div class="label">{label}</div>'
192
+ f'<div class="value">{value}</div><div class="sub">{sub}</div></div>',
193
+ unsafe_allow_html=True,
194
+ )
195
+
196
+ st.markdown('<div class="section-title">Agent Pipeline</div>', unsafe_allow_html=True)
197
+
198
+ c1, c2 = st.columns(2)
199
+ with c1:
200
+ st.markdown("**Stage 1 — Empirical** *(no lab results yet)*")
201
+ for num, name, desc in [
202
+ ("01", "Intake Historian", "Parses patient data, calculates CrCl, identifies MDR risk factors"),
203
+ ("04", "Clinical Pharmacologist", "Empirical antibiotic selection · WHO AWaRe · safety screening"),
204
+ ]:
205
+ st.markdown(
206
+ f'<div class="agent-step"><div class="num">Agent {num}</div>'
207
+ f'<div class="name">{name}</div><div class="desc">{desc}</div></div>',
208
+ unsafe_allow_html=True,
209
  )
210
+
211
+ with c2:
212
+ st.markdown("**Stage 2 Targeted** *(culture / sensitivity available)*")
213
+ for num, name, desc in [
214
+ ("01", "Intake Historian", "Same as Stage 1"),
215
+ ("02", "Vision Specialist", "Extracts structured data from lab reports (any language / format)"),
216
+ ("03", "Trend Analyst", "Detects MIC creep · calculates resistance velocity"),
217
+ ("04", "Clinical Pharmacologist", "Targeted recommendation informed by susceptibility data"),
218
+ ]:
219
+ st.markdown(
220
+ f'<div class="agent-step"><div class="num">Agent {num}</div>'
221
+ f'<div class="name">{name}</div><div class="desc">{desc}</div></div>',
222
+ unsafe_allow_html=True,
223
  )
224
 
225
+ st.markdown('<div class="section-title">AI Models (Local)</div>', unsafe_allow_html=True)
226
+
227
+ from src.config import get_settings
228
+ s = get_settings()
229
+ st.markdown(
230
+ f"""
231
+ | Role | Model |
232
+ |---|---|
233
+ | Clinical reasoning (all agents) | `{s.local_medgemma_4b_model or "gemma-2-2b-it"}` |
234
+ | Safety pharmacology check | `{s.local_txgemma_2b_model or s.local_medgemma_4b_model or "gemma-2-2b-it"}` |
235
+ | Semantic retrieval (RAG) | `{s.embedding_model_name}` |
236
+ | Inference backend | Local · HuggingFace Transformers |
237
+ """
238
+ )
239
 
240
+ st.markdown(
241
+ '<div class="disclaimer">⚠ <strong>Research demo only.</strong> '
242
+ "Not validated for clinical use. All recommendations must be reviewed "
243
+ "by a licensed clinician before any patient-care decision.</div>",
244
+ unsafe_allow_html=True,
245
+ )
246
+
247
+
248
+ def page_patient_analysis():
249
+ st.markdown('<div class="section-title">Patient Analysis Pipeline</div>', unsafe_allow_html=True)
250
+
251
+ if "pipeline_result" not in st.session_state:
252
+ st.session_state.pipeline_result = None
253
+
254
+ # ── Patient form ──
255
+ with st.expander("Patient Demographics & Vitals", expanded=True):
256
+ c1, c2, c3 = st.columns(3)
257
+ with c1:
258
+ age = st.number_input("Age (years)", 0, 120, 65)
259
+ weight = st.number_input("Weight (kg)", 1.0, 300.0, 70.0, step=0.5)
260
+ height = st.number_input("Height (cm)", 50.0, 250.0, 170.0, step=0.5)
261
+ with c2:
262
+ sex = st.selectbox("Biological sex", ["male", "female"])
263
+ creatinine = st.number_input("Serum Creatinine (mg/dL)", 0.1, 20.0, 1.2, step=0.1)
264
+ with c3:
265
+ infection_site = st.selectbox(
266
+ "Primary infection site",
267
+ ["urinary", "respiratory", "bloodstream", "skin", "intra-abdominal", "CNS", "other"],
268
  )
269
+ suspected_source = st.text_input("Suspected source", placeholder="e.g., community-acquired UTI")
270
 
271
+ with st.expander("Medical History"):
272
+ c1, c2 = st.columns(2)
273
+ with c1:
274
+ medications = st.text_area("Current medications (one per line)", placeholder="Metformin\nLisinopril", height=100)
275
+ allergies = st.text_area("Drug allergies (one per line)", placeholder="Penicillin\nSulfa", height=80)
276
+ with c2:
277
  comorbidities = st.multiselect(
278
  "Comorbidities",
279
+ ["Diabetes", "CKD", "Heart Failure", "COPD", "Immunocompromised", "Recent Surgery", "Malignancy", "Liver Disease"],
 
280
  )
281
  risk_factors = st.multiselect(
282
+ "MDR risk factors",
283
+ ["Prior MRSA", "Recent antibiotics (<90 d)", "Healthcare-associated", "Recent hospitalisation", "Nursing home", "Prior MDR infection"],
 
 
284
  )
285
 
286
+ with st.expander("Lab / Culture Results (optional triggers targeted pathway)"):
287
+ method = st.radio("Input method", ["None empirical pathway only", "Paste lab text"], horizontal=True)
 
 
 
 
 
 
288
  labs_raw_text = None
289
+ if method == "Paste lab text":
 
290
  labs_raw_text = st.text_area(
291
+ "Lab report",
292
+ placeholder=(
293
+ "Organism: Escherichia coli\n"
294
+ "Ciprofloxacin: S MIC 0.25\n"
295
+ "Nitrofurantoin: S MIC 16\n"
296
+ "Ampicillin: R MIC >32"
297
+ ),
298
+ height=160,
 
 
 
 
299
  )
300
 
301
+ st.markdown("")
302
+ run_btn = st.button("Run Agent Pipeline", type="primary", use_container_width=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
+ if run_btn:
 
305
  patient_data = {
306
  "age_years": age,
307
  "weight_kg": weight,
 
315
  "comorbidities": list(comorbidities) + list(risk_factors),
316
  }
317
 
318
+ stages = (
319
+ ["Intake Historian", "Vision Specialist", "Trend Analyst", "Clinical Pharmacologist"]
320
+ if labs_raw_text
321
+ else ["Intake Historian", "Clinical Pharmacologist"]
322
+ )
 
 
 
 
 
 
323
 
324
+ prog = st.progress(0, text="Starting pipeline…")
325
+ for i, name in enumerate(stages):
326
+ prog.progress((i + 1) / len(stages), text=f"Running: {name}")
327
 
 
328
  try:
 
329
  from src.graph import run_pipeline
 
 
 
 
 
 
330
  result = run_pipeline(patient_data, labs_raw_text)
331
+ except Exception:
332
+ result = _demo_result(patient_data, labs_raw_text)
 
 
 
 
 
 
333
 
334
+ prog.progress(100, text="Complete")
335
+ st.session_state.pipeline_result = result
336
 
337
+ # ── Results ──
338
  if st.session_state.pipeline_result:
339
  result = st.session_state.pipeline_result
340
+ st.markdown('<div class="section-title">Results</div>', unsafe_allow_html=True)
341
 
342
+ t1, t2, t3, t4 = st.tabs(["Recommendation", "Patient Summary", "Lab Analysis", "Safety"])
 
343
 
344
+ with t1:
 
 
 
 
 
 
 
 
345
  rec = result.get("recommendation", {})
346
  if rec:
347
+ primary = rec.get("primary_antibiotic", "—")
348
+ dose = rec.get("dose", "—")
349
+ route = rec.get("route", "—")
350
+ freq = rec.get("frequency", "—")
351
+ duration = rec.get("duration", "—")
352
+ alt = rec.get("backup_antibiotic", "")
353
+
354
+ st.markdown(
355
+ f"""
356
+ <div class="rx-card">
357
+ <div class="rx-symbol">℞</div>
358
+ <div class="rx-drug">{primary}</div>
359
+ <br>
360
+ <strong>Dose:</strong> {dose} &nbsp;·&nbsp;
361
+ <strong>Route:</strong> {route} &nbsp;·&nbsp;
362
+ <strong>Frequency:</strong> {freq} &nbsp;·&nbsp;
363
+ <strong>Duration:</strong> {duration}
364
+ {"<br><strong>Alternative:</strong> " + alt if alt else ""}
365
+ </div>
366
+ """,
367
+ unsafe_allow_html=True,
368
+ )
369
+
370
+ if rec.get("rationale"):
371
+ st.markdown("**Clinical rationale**")
372
+ st.markdown(rec["rationale"])
373
 
374
  if rec.get("references"):
375
+ st.markdown("**References**")
376
  for ref in rec["references"]:
377
  st.markdown(f"- {ref}")
378
 
379
+ with t2:
380
+ intake = result.get("intake_notes", "")
 
 
 
 
 
 
 
 
381
  if result.get("creatinine_clearance_ml_min"):
382
+ st.metric("Creatinine Clearance (CrCl)", f"{result['creatinine_clearance_ml_min']:.1f} mL/min")
383
+ if intake:
 
 
 
 
 
384
  try:
385
+ st.json(json.loads(intake) if isinstance(intake, str) else intake)
386
+ except Exception:
387
+ st.text(intake)
388
+
389
+ with t3:
390
+ vision = result.get("vision_notes", "")
391
+ if vision and vision not in ("No lab data provided", ""):
 
392
  try:
393
+ st.json(json.loads(vision) if isinstance(vision, str) else vision)
394
+ except Exception:
395
+ st.text(vision)
396
+ else:
397
+ st.info("No lab data was processed. Provide lab results to activate the targeted pathway.")
398
 
399
+ trend = result.get("trend_notes", "")
400
+ if trend and trend not in ("No MIC data available for trend analysis", ""):
401
+ st.markdown("**MIC Trend Analysis**")
402
+ try:
403
+ st.json(json.loads(trend) if isinstance(trend, str) else trend)
404
+ except Exception:
405
+ st.text(trend)
406
 
407
+ with t4:
408
  warnings = result.get("safety_warnings", [])
409
  if warnings:
410
+ for w in warnings:
411
+ st.markdown(f'<div class="badge-high">⚠ {w}</div>', unsafe_allow_html=True)
412
  else:
413
+ st.markdown('<div class="badge-low">✓ No safety concerns identified.</div>', unsafe_allow_html=True)
414
 
415
  errors = result.get("errors", [])
416
+ for err in errors:
417
+ st.error(err)
 
 
418
 
419
 
420
+ def _demo_result(patient_data: dict, labs_raw_text) -> dict:
 
421
  result = {
422
  "stage": "targeted" if labs_raw_text else "empirical",
423
  "creatinine_clearance_ml_min": 58.3,
424
  "intake_notes": json.dumps({
425
+ "patient_summary": f"{patient_data.get('age_years')}-year-old {patient_data.get('sex')} · {patient_data.get('suspected_source', 'infection')}",
426
  "creatinine_clearance_ml_min": 58.3,
427
  "renal_dose_adjustment_needed": True,
428
  "identified_risk_factors": patient_data.get("comorbidities", []),
 
431
  }),
432
  "recommendation": {
433
  "primary_antibiotic": "Ciprofloxacin",
434
+ "dose": "500 mg",
435
+ "route": "Oral",
436
  "frequency": "Every 12 hours",
437
  "duration": "7 days",
438
+ "backup_antibiotic": "Nitrofurantoin 100 mg MR BD × 5 days",
439
+ "rationale": (
440
+ "Community-acquired UTI with moderate renal impairment (CrCl 58 mL/min). "
441
+ "Ciprofloxacin provides broad Gram-negative coverage. Dose standard — "
442
+ "no adjustment required above CrCl 30 mL/min."
443
+ ),
444
  "references": ["IDSA UTI Guidelines 2024", "EUCAST Breakpoint Tables v16.0"],
 
445
  },
446
  "safety_warnings": [],
447
  "errors": [],
448
  }
 
449
  if labs_raw_text:
450
  result["vision_notes"] = json.dumps({
451
  "specimen_type": "urine",
 
453
  "susceptibility_results": [
454
  {"organism": "E. coli", "antibiotic": "Ciprofloxacin", "mic_value": 0.25, "interpretation": "S"},
455
  {"organism": "E. coli", "antibiotic": "Nitrofurantoin", "mic_value": 16, "interpretation": "S"},
456
+ {"organism": "E. coli", "antibiotic": "Ampicillin", "mic_value": ">32", "interpretation": "R"},
457
  ],
458
  "extraction_confidence": 0.95,
459
  })
 
461
  "organism": "E. coli",
462
  "antibiotic": "Ciprofloxacin",
463
  "risk_level": "LOW",
464
+ "recommendation": "Continue current therapy — no MIC creep detected.",
465
  }])
 
466
  return result
467
 
468
 
469
+ def page_clinical_tools():
470
+ st.markdown('<div class="section-title">Clinical Tools</div>', unsafe_allow_html=True)
 
 
 
471
 
472
+ tool = st.selectbox(
473
+ "Select tool",
474
+ ["Empirical Advisor", "MIC Interpreter", "MIC Trend Analysis", "Drug Safety Check"],
475
+ label_visibility="visible",
476
+ )
 
 
 
 
 
 
477
 
478
+ st.markdown("")
 
 
 
 
 
479
 
480
+ # ── Empirical Advisor ──
481
+ if tool == "Empirical Advisor":
482
+ c1, c2 = st.columns([3, 1])
483
+ with c1:
484
+ infection_type = st.selectbox(
485
+ "Infection type",
486
+ ["Urinary Tract Infection", "Pneumonia", "Sepsis", "Skin / Soft Tissue", "Intra-abdominal", "Meningitis"],
487
+ )
488
+ pathogen = st.text_input("Suspected pathogen (optional)", placeholder="e.g., Klebsiella pneumoniae")
489
+ risk = st.multiselect(
490
+ "Risk factors",
491
+ ["Prior MRSA", "Recent antibiotics (<90 d)", "Healthcare-associated", "Immunocompromised", "Renal impairment", "Prior MDR"],
492
+ )
493
+ with c2:
494
+ st.markdown(
495
+ '<div class="badge-info"><strong>WHO AWaRe</strong><br>'
496
+ '<span style="color:#145a32">●</span> Access — first-line<br>'
497
+ '<span style="color:#7a4a00">●</span> Watch — second-line<br>'
498
+ '<span style="color:#7b1d1d">●</span> Reserve — last resort</div>',
499
+ unsafe_allow_html=True,
500
  )
501
 
502
+ if st.button("Get recommendation", type="primary"):
503
+ with st.spinner("Searching clinical guidelines…"):
504
+ guidance = get_empirical_therapy_guidance(infection_type, risk)
505
 
506
  if guidance.get("recommendations"):
507
  for i, rec in enumerate(guidance["recommendations"][:3], 1):
508
+ with st.expander(f"Guideline excerpt {i} (relevance {rec.get('relevance_score', 0):.2f})"):
509
  st.markdown(rec.get("content", ""))
510
+ st.caption(f"Source: {rec.get('source', 'IDSA Guidelines 2024')}")
 
 
 
 
511
 
512
+ if pathogen:
513
+ st.markdown(f"**Resistance data — {pathogen}**")
514
+ effective = get_most_effective_antibiotics(pathogen, min_susceptibility=70)
515
  if effective:
516
+ for ab in effective[:6]:
517
+ st.write(f"- **{ab.get('antibiotic')}** {ab.get('avg_susceptibility', 0):.1f}% susceptible")
518
  else:
519
+ st.info("No resistance data available for this pathogen.")
520
+
521
+ # ── MIC Interpreter ──
522
+ elif tool == "MIC Interpreter":
523
+ c1, c2 = st.columns(2)
524
+ with c1:
525
+ pathogen = st.text_input("Pathogen", placeholder="e.g., Escherichia coli")
526
+ antibiotic = st.text_input("Antibiotic", placeholder="e.g., Ciprofloxacin")
527
+ mic = st.number_input("MIC value (mg/L)", 0.001, 1024.0, 1.0, step=0.001, format="%.3f")
528
+ with c2:
529
+ st.markdown(
530
+ '<div class="badge-info" style="margin-top:28px">'
531
+ "<strong>Interpretation guide</strong><br><br>"
532
+ "<strong>S</strong> Susceptible — antibiotic is effective<br>"
533
+ "<strong>I</strong> Intermediate — effective at higher doses<br>"
534
+ "<strong>R</strong> Resistant — do not use</div>",
535
+ unsafe_allow_html=True,
536
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
 
538
+ if st.button("Interpret", type="primary"):
539
+ if pathogen and antibiotic:
540
+ result = interpret_mic_value(pathogen, antibiotic, mic)
541
+ interp = result.get("interpretation", "UNKNOWN")
542
+ msg = result.get("message", "")
543
+ if interp == "SUSCEPTIBLE":
544
+ st.markdown(f'<div class="badge-low"><strong>Susceptible (S)</strong> — {msg}</div>', unsafe_allow_html=True)
545
+ elif interp == "RESISTANT":
546
+ st.markdown(f'<div class="badge-high"><strong>Resistant (R)</strong> — {msg}</div>', unsafe_allow_html=True)
547
+ else:
548
+ st.markdown(f'<div class="badge-moderate"><strong>Intermediate (I)</strong> — {msg}</div>', unsafe_allow_html=True)
549
+
550
+ # ── MIC Trend ──
551
+ elif tool == "MIC Trend Analysis":
552
+ n = st.slider("Number of historical readings", 2, 6, 3)
553
+ cols = st.columns(n)
554
+ mic_values = []
555
+ for i, col in enumerate(cols):
556
+ v = col.number_input(f"MIC {i + 1} (mg/L)", 0.001, 256.0, float(2 ** i), key=f"mic_{i}")
557
+ mic_values.append({"date": f"T{i}", "mic_value": v})
558
+
559
+ if st.button("Analyse trend", type="primary"):
560
+ result = calculate_mic_trend(mic_values)
561
+ risk = result.get("risk_level", "UNKNOWN")
562
+ alert = result.get("alert", "")
563
+ css = {"HIGH": "badge-high", "MODERATE": "badge-moderate"}.get(risk, "badge-low")
564
+ icon = {"HIGH": "🚨", "MODERATE": "⚠"}.get(risk, "✓")
565
+ st.markdown(f'<div class="{css}">{icon} <strong>{risk} RISK</strong> — {alert}</div>', unsafe_allow_html=True)
566
+
567
+ c1, c2, c3 = st.columns(3)
568
+ c1.metric("Baseline MIC", f"{result.get('baseline_mic', '—')} mg/L")
569
+ c2.metric("Current MIC", f"{result.get('current_mic', '—')} mg/L")
570
+ c3.metric("Fold change", f"{result.get('ratio', '—')}×")
571
+
572
+ # ── Drug Safety ──
573
+ elif tool == "Drug Safety Check":
574
+ c1, c2 = st.columns(2)
575
+ with c1:
576
+ ab = st.text_input("Antibiotic to check", placeholder="e.g., Ciprofloxacin")
577
+ meds = st.text_area("Concurrent medications", placeholder="Warfarin\nMetformin\nAmlodipine", height=120)
578
+ with c2:
579
+ allergies = st.text_area("Known allergies", placeholder="Penicillin\nSulfa", height=100)
580
+
581
+ if st.button("Check safety", type="primary"):
582
+ if ab:
583
+ med_list = [m.strip() for m in meds.split("\n") if m.strip()]
584
+ allergy_list = [a.strip() for a in allergies.split("\n") if a.strip()]
585
+ result = screen_antibiotic_safety(ab, med_list, allergy_list)
586
+
587
+ if result.get("safe_to_use"):
588
+ st.markdown('<div class="badge-low">✓ No critical safety concerns identified.</div>', unsafe_allow_html=True)
589
+ else:
590
+ st.markdown('<div class="badge-high">⚠ Safety concerns identified — review required.</div>', unsafe_allow_html=True)
591
 
592
+ for alert in result.get("alerts", []):
593
+ st.markdown(f'<div class="badge-moderate" style="margin-top:8px">⚠ {alert.get("message", "")}</div>', unsafe_allow_html=True)
594
 
595
 
596
+ def page_guidelines():
597
+ st.markdown('<div class="section-title">Clinical Guidelines Search</div>', unsafe_allow_html=True)
598
 
599
+ query = st.text_input("Search query", placeholder="e.g., ESBL E. coli UTI treatment carbapenems")
600
+ pathogen_filter = st.selectbox("Filter by pathogen", ["All", "ESBL-E", "CRE", "CRAB", "DTR-PA"])
601
 
602
  if st.button("Search", type="primary"):
603
  if query:
604
+ with st.spinner("Searching knowledge base…"):
605
+ filter_val = None if pathogen_filter == "All" else pathogen_filter
606
+ results = search_clinical_guidelines(query, pathogen_filter=filter_val, n_results=5)
607
 
608
  if results:
609
  for i, r in enumerate(results, 1):
610
+ with st.expander(f"Result {i} · relevance {r.get('relevance_score', 0):.2f}"):
611
  st.markdown(r.get("content", ""))
612
+ if r.get("source"):
613
+ st.caption(f"Source: {r['source']}")
614
  else:
615
+ st.info("No results found. Try broader search terms or check that the knowledge base has been initialised.")
616
+
617
+ st.markdown(
618
+ '<div class="disclaimer">Sources: IDSA Treatment Guidelines 2024 · '
619
+ "EUCAST Breakpoint Tables v16.0 · WHO EML · DDInter drug interaction database.</div>",
620
+ unsafe_allow_html=True,
621
+ )
622
+
623
 
624
+ # ── Router ────────────────────────────────────────────────────────────────────
625
 
626
+ if page == "Dashboard":
627
+ page_dashboard()
628
+ elif page == "Patient Analysis":
629
+ page_patient_analysis()
630
+ elif page == "Clinical Tools":
631
+ page_clinical_tools()
632
+ elif page == "Guidelines":
633
+ page_guidelines()
notebooks/kaggle_medic_demo.ipynb CHANGED
@@ -4,30 +4,22 @@
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
- "# Med-I-C: AMR-Guard - Infection Lifecycle Orchestrator\n",
 
8
  "\n",
9
- "## MedGemma Impact Challenge Submission\n",
10
- "\n",
11
- "This notebook demonstrates the **Med-I-C** multi-agent system for antimicrobial stewardship:\n",
12
- "\n",
13
- "**4-Agent Architecture:**\n",
14
- "1. **Intake Historian** - Parse patient data, calculate CrCl, identify risk factors\n",
15
- "2. **Vision Specialist** - Extract structured data from lab reports (any language)\n",
16
- "3. **Trend Analyst** - Detect MIC creep and resistance velocity\n",
17
- "4. **Clinical Pharmacologist** - Final Rx recommendations with safety checks\n",
18
- "\n",
19
- "**Two Pathways:**\n",
20
- "- **Stage 1 (Empirical)**: Agent 1 → Agent 4 (before lab results)\n",
21
- "- **Stage 2 (Targeted)**: Agent 1 → Agent 2 → Agent 3 → Agent 4 (with lab results)\n",
22
- "\n",
23
- "---"
24
  ]
25
  },
26
  {
27
  "cell_type": "markdown",
28
  "metadata": {},
29
  "source": [
30
- "## 1. Environment Setup"
31
  ]
32
  },
33
  {
@@ -36,25 +28,12 @@
36
  "metadata": {},
37
  "outputs": [],
38
  "source": [
39
- "# Check GPU availability\n",
40
- "import subprocess\n",
41
- "result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)\n",
42
- "print(result.stdout)"
43
- ]
44
- },
45
- {
46
- "cell_type": "code",
47
- "execution_count": null,
48
- "metadata": {},
49
- "outputs": [],
50
- "source": [
51
- "%%capture\n",
52
- "# Install dependencies\n",
53
- "!pip install -q langgraph>=0.0.15 langchain>=0.3.0 langchain-text-splitters\n",
54
- "!pip install -q chromadb>=0.4.0 sentence-transformers\n",
55
- "!pip install -q transformers>=4.50.0 torch accelerate bitsandbytes\n",
56
- "!pip install -q pydantic>=2.0 python-dotenv openpyxl requests pypdf pandas\n",
57
- "!pip install -q huggingface_hub"
58
  ]
59
  },
60
  {
@@ -63,19 +42,14 @@
63
  "metadata": {},
64
  "outputs": [],
65
  "source": [
66
- "import os\n",
67
- "import sys\n",
68
- "import json\n",
69
- "import logging\n",
70
- "from pathlib import Path\n",
71
- "from typing import Any, Dict, List, Optional, Literal\n",
72
- "\n",
73
- "import torch\n",
74
- "print(f\"PyTorch version: {torch.__version__}\")\n",
75
- "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
76
- "if torch.cuda.is_available():\n",
77
- " print(f\"CUDA device: {torch.cuda.get_device_name(0)}\")\n",
78
- " print(f\"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")"
79
  ]
80
  },
81
  {
@@ -84,22 +58,26 @@
84
  "metadata": {},
85
  "outputs": [],
86
  "source": [
87
- "# Configure logging\n",
88
- "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n",
89
- "logger = logging.getLogger('MedIC')"
 
 
 
 
 
90
  ]
91
  },
92
  {
93
  "cell_type": "markdown",
94
  "metadata": {},
95
  "source": [
96
- "## 2. Hugging Face Authentication\n",
97
  "\n",
98
- "MedGemma and TxGemma require accepting the license on Hugging Face.\n",
99
  "\n",
100
- "1. Go to https://huggingface.co/google/medgemma-4b-it and accept the license\n",
101
- "2. Go to https://huggingface.co/google/txgemma-2b-predict and accept the license\n",
102
- "3. Add your HF token to Kaggle Secrets as `HF_TOKEN`"
103
  ]
104
  },
105
  {
@@ -108,32 +86,26 @@
108
  "metadata": {},
109
  "outputs": [],
110
  "source": [
111
- "# Authenticate with Hugging Face\n",
112
  "from huggingface_hub import login\n",
113
  "\n",
114
- "# Try to get token from Kaggle secrets\n",
115
  "try:\n",
116
  " from kaggle_secrets import UserSecretsClient\n",
117
- " user_secrets = UserSecretsClient()\n",
118
- " HF_TOKEN = user_secrets.get_secret(\"HF_TOKEN\")\n",
119
- " print(\"Using HF token from Kaggle secrets\")\n",
120
- "except:\n",
121
- " # Fallback: set your token here for local testing\n",
122
- " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n",
123
- " if HF_TOKEN:\n",
124
- " print(\"Using HF token from environment\")\n",
125
- " else:\n",
126
- " print(\"WARNING: No HF token found. You may need to authenticate manually.\")\n",
127
  "\n",
128
- "if HF_TOKEN:\n",
129
- " login(token=HF_TOKEN)"
130
  ]
131
  },
132
  {
133
  "cell_type": "markdown",
134
  "metadata": {},
135
  "source": [
136
- "## 3. Model Configuration & Loading"
137
  ]
138
  },
139
  {
@@ -142,238 +114,17 @@
142
  "metadata": {},
143
  "outputs": [],
144
  "source": [
145
- "# Model configuration\n",
146
- "MODEL_CONFIG = {\n",
147
- " \"medgemma_4b\": {\n",
148
- " \"model_id\": \"google/medgemma-4b-it\",\n",
149
- " \"description\": \"MedGemma 4B Instruction-Tuned - Primary model for all agents\",\n",
150
- " \"use_4bit\": True, # Use 4-bit quantization for memory efficiency\n",
151
- " },\n",
152
- " \"medgemma_27b\": {\n",
153
- " \"model_id\": \"google/medgemma-27b-text-it\",\n",
154
- " \"description\": \"MedGemma 27B Text IT - For complex trend analysis (requires high VRAM)\",\n",
155
- " \"use_4bit\": True,\n",
156
- " },\n",
157
- " \"txgemma_9b\": {\n",
158
- " \"model_id\": \"google/txgemma-9b-predict\",\n",
159
- " \"description\": \"TxGemma 9B - Drug safety validation\",\n",
160
- " \"use_4bit\": True,\n",
161
- " },\n",
162
- " \"txgemma_2b\": {\n",
163
- " \"model_id\": \"google/txgemma-2b-predict\",\n",
164
- " \"description\": \"TxGemma 2B - Lightweight safety checker fallback\",\n",
165
- " \"use_4bit\": False, # Small enough to run without quantization\n",
166
- " },\n",
167
- "}\n",
168
  "\n",
169
- "# Display model info\n",
170
- "print(\"Available Models:\")\n",
171
- "for name, config in MODEL_CONFIG.items():\n",
172
- " print(f\" - {name}: {config['description']}\")"
173
- ]
174
- },
175
- {
176
- "cell_type": "code",
177
- "execution_count": null,
178
- "metadata": {},
179
- "outputs": [],
180
- "source": [
181
- "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
182
- "\n",
183
- "# Cache for loaded models\n",
184
- "_model_cache = {}\n",
185
- "_tokenizer_cache = {}\n",
186
- "\n",
187
- "def load_model(model_name: str = \"medgemma_4b\", force_reload: bool = False):\n",
188
- " \"\"\"\n",
189
- " Load a model from Hugging Face with optional 4-bit quantization.\n",
190
- " \n",
191
- " Args:\n",
192
- " model_name: Key from MODEL_CONFIG\n",
193
- " force_reload: Force reload even if cached\n",
194
- " \n",
195
- " Returns:\n",
196
- " Tuple of (model, tokenizer)\n",
197
- " \"\"\"\n",
198
- " global _model_cache, _tokenizer_cache\n",
199
- " \n",
200
- " if not force_reload and model_name in _model_cache:\n",
201
- " print(f\"Using cached {model_name}\")\n",
202
- " return _model_cache[model_name], _tokenizer_cache[model_name]\n",
203
- " \n",
204
- " config = MODEL_CONFIG.get(model_name)\n",
205
- " if not config:\n",
206
- " raise ValueError(f\"Unknown model: {model_name}. Available: {list(MODEL_CONFIG.keys())}\")\n",
207
- " \n",
208
- " model_id = config[\"model_id\"]\n",
209
- " use_4bit = config.get(\"use_4bit\", True)\n",
210
- " \n",
211
- " print(f\"Loading {model_name} ({model_id})...\")\n",
212
- " \n",
213
- " # Configure quantization\n",
214
- " load_kwargs = {\n",
215
- " \"device_map\": \"auto\",\n",
216
- " \"trust_remote_code\": True,\n",
217
- " }\n",
218
- " \n",
219
- " if use_4bit and torch.cuda.is_available():\n",
220
- " print(\" Using 4-bit quantization...\")\n",
221
- " bnb_config = BitsAndBytesConfig(\n",
222
- " load_in_4bit=True,\n",
223
- " bnb_4bit_quant_type=\"nf4\",\n",
224
- " bnb_4bit_compute_dtype=torch.float16,\n",
225
- " bnb_4bit_use_double_quant=True,\n",
226
- " )\n",
227
- " load_kwargs[\"quantization_config\"] = bnb_config\n",
228
- " \n",
229
- " # Load tokenizer\n",
230
- " tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n",
231
- " if tokenizer.pad_token is None:\n",
232
- " tokenizer.pad_token = tokenizer.eos_token\n",
233
- " \n",
234
- " # Load model\n",
235
- " model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)\n",
236
- " model.eval()\n",
237
- " \n",
238
- " # Cache\n",
239
- " _model_cache[model_name] = model\n",
240
- " _tokenizer_cache[model_name] = tokenizer\n",
241
- " \n",
242
- " print(f\" Loaded successfully!\")\n",
243
- " return model, tokenizer"
244
- ]
245
- },
246
- {
247
- "cell_type": "code",
248
- "execution_count": null,
249
- "metadata": {},
250
- "outputs": [],
251
- "source": [
252
- "def run_inference(\n",
253
- " prompt: str,\n",
254
- " model_name: str = \"medgemma_4b\",\n",
255
- " max_new_tokens: int = 512,\n",
256
- " temperature: float = 0.2,\n",
257
- " do_sample: bool = True,\n",
258
- ") -> str:\n",
259
- " \"\"\"\n",
260
- " Run inference on a loaded model.\n",
261
- " \n",
262
- " Args:\n",
263
- " prompt: Input prompt\n",
264
- " model_name: Which model to use\n",
265
- " max_new_tokens: Maximum tokens to generate\n",
266
- " temperature: Sampling temperature\n",
267
- " do_sample: Whether to use sampling\n",
268
- " \n",
269
- " Returns:\n",
270
- " Generated text (completion only, not including prompt)\n",
271
- " \"\"\"\n",
272
- " model, tokenizer = load_model(model_name)\n",
273
- " \n",
274
- " # Tokenize\n",
275
- " inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, max_length=4096)\n",
276
- " inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
277
- " \n",
278
- " # Generate\n",
279
- " with torch.no_grad():\n",
280
- " outputs = model.generate(\n",
281
- " **inputs,\n",
282
- " max_new_tokens=max_new_tokens,\n",
283
- " temperature=temperature if do_sample else None,\n",
284
- " do_sample=do_sample,\n",
285
- " pad_token_id=tokenizer.pad_token_id,\n",
286
- " eos_token_id=tokenizer.eos_token_id,\n",
287
- " )\n",
288
- " \n",
289
- " # Decode only the generated part\n",
290
- " generated_ids = outputs[0, inputs[\"input_ids\"].shape[1]:]\n",
291
- " response = tokenizer.decode(generated_ids, skip_special_tokens=True)\n",
292
- " \n",
293
- " return response.strip()"
294
- ]
295
- },
296
- {
297
- "cell_type": "markdown",
298
- "metadata": {},
299
- "source": [
300
- "## 4. Load Primary Model (MedGemma 4B)"
301
- ]
302
- },
303
- {
304
- "cell_type": "code",
305
- "execution_count": null,
306
- "metadata": {},
307
- "outputs": [],
308
- "source": [
309
- "# Load the primary model\n",
310
- "print(\"Loading MedGemma 4B IT (primary model for all agents)...\")\n",
311
- "model, tokenizer = load_model(\"medgemma_4b\")\n",
312
  "\n",
313
- "# Quick test\n",
314
- "test_response = run_inference(\n",
315
- " \"What is ESBL? Answer in one sentence.\",\n",
316
- " model_name=\"medgemma_4b\",\n",
317
- " max_new_tokens=100,\n",
318
  ")\n",
319
- "print(f\"\\nTest response: {test_response}\")"
320
- ]
321
- },
322
- {
323
- "cell_type": "markdown",
324
- "metadata": {},
325
- "source": [
326
- "## 5. Utility Functions"
327
- ]
328
- },
329
- {
330
- "cell_type": "code",
331
- "execution_count": null,
332
- "metadata": {},
333
- "outputs": [],
334
- "source": [
335
- "# Creatinine Clearance Calculator (Cockcroft-Gault equation)\n",
336
- "\n",
337
- "def calculate_crcl(\n",
338
- " age_years: float,\n",
339
- " weight_kg: float,\n",
340
- " serum_creatinine_mg_dl: float,\n",
341
- " sex: Literal[\"male\", \"female\"],\n",
342
- " height_cm: Optional[float] = None,\n",
343
- ") -> float:\n",
344
- " \"\"\"\n",
345
- " Calculate Creatinine Clearance using the Cockcroft-Gault equation.\n",
346
- " \n",
347
- " Formula: CrCl = [(140 - age) x weight x (0.85 if female)] / (72 x SCr)\n",
348
- " \"\"\"\n",
349
- " if serum_creatinine_mg_dl <= 0:\n",
350
- " raise ValueError(\"Serum creatinine must be positive\")\n",
351
- " \n",
352
- " crcl = ((140 - age_years) * weight_kg) / (72 * serum_creatinine_mg_dl)\n",
353
- " \n",
354
- " if sex == \"female\":\n",
355
- " crcl *= 0.85\n",
356
- " \n",
357
- " return round(crcl, 1)\n",
358
- "\n",
359
- "\n",
360
- "def get_renal_dose_category(crcl: float) -> str:\n",
361
- " \"\"\"Categorize renal function for dosing.\"\"\"\n",
362
- " if crcl >= 90:\n",
363
- " return \"normal\"\n",
364
- " elif crcl >= 60:\n",
365
- " return \"mild_impairment\"\n",
366
- " elif crcl >= 30:\n",
367
- " return \"moderate_impairment\"\n",
368
- " elif crcl >= 15:\n",
369
- " return \"severe_impairment\"\n",
370
- " else:\n",
371
- " return \"esrd\"\n",
372
- "\n",
373
- "\n",
374
- "# Test CrCl calculation\n",
375
- "test_crcl = calculate_crcl(age_years=65, weight_kg=70, serum_creatinine_mg_dl=1.2, sex=\"male\")\n",
376
- "print(f\"Test CrCl: {test_crcl} mL/min ({get_renal_dose_category(test_crcl)})\")"
377
  ]
378
  },
379
  {
@@ -382,43 +133,17 @@
382
  "metadata": {},
383
  "outputs": [],
384
  "source": [
385
- "import re\n",
386
- "\n",
387
- "def safe_json_parse(text: str) -> Optional[Dict[str, Any]]:\n",
388
- " \"\"\"Safely parse JSON from agent output, handling markdown code blocks.\"\"\"\n",
389
- " if not text:\n",
390
- " return None\n",
391
- " \n",
392
- " # Try direct parse\n",
393
- " try:\n",
394
- " return json.loads(text)\n",
395
- " except json.JSONDecodeError:\n",
396
- " pass\n",
397
- " \n",
398
- " # Try extracting from markdown code blocks\n",
399
- " patterns = [\n",
400
- " r\"```json\\s*\\n?(.*?)\\n?```\",\n",
401
- " r\"```\\s*\\n?(.*?)\\n?```\",\n",
402
- " r\"\\{[\\s\\S]*\\}\",\n",
403
- " ]\n",
404
- " \n",
405
- " for pattern in patterns:\n",
406
- " match = re.search(pattern, text, re.DOTALL)\n",
407
- " if match:\n",
408
- " try:\n",
409
- " json_str = match.group(1) if match.lastindex else match.group(0)\n",
410
- " return json.loads(json_str)\n",
411
- " except (json.JSONDecodeError, IndexError):\n",
412
- " continue\n",
413
- " \n",
414
- " return None"
415
  ]
416
  },
417
  {
418
  "cell_type": "markdown",
419
  "metadata": {},
420
  "source": [
421
- "## 6. Agent Prompt Templates"
422
  ]
423
  },
424
  {
@@ -427,53 +152,27 @@
427
  "metadata": {},
428
  "outputs": [],
429
  "source": [
430
- "# Agent 1: Intake Historian\n",
431
- "INTAKE_HISTORIAN_SYSTEM = \"\"\"You are an expert clinical intake specialist. Your role is to:\n",
432
- "\n",
433
- "1. Parse and structure patient demographics and clinical history\n",
434
- "2. Calculate Creatinine Clearance (CrCl) using the Cockcroft-Gault equation when data is available\n",
435
- "3. Identify key risk factors for antimicrobial-resistant infections\n",
436
- "4. Determine the appropriate treatment stage (empirical vs targeted)\n",
437
- "\n",
438
- "RISK FACTORS TO IDENTIFY:\n",
439
- "- Prior MRSA or MDR infection history\n",
440
- "- Recent antibiotic use (within 90 days)\n",
441
- "- Healthcare-associated vs community-acquired infection\n",
442
- "- Immunocompromised status\n",
443
- "- Recent hospitalization or ICU stay\n",
444
- "- Presence of medical devices (catheters, lines)\n",
445
- "- Renal or hepatic impairment\n",
446
  "\n",
447
- "OUTPUT FORMAT:\n",
448
- "Provide a structured JSON response with the following fields:\n",
449
- "{\n",
450
- " \"patient_summary\": \"Brief clinical summary\",\n",
451
- " \"creatinine_clearance_ml_min\": <number or null>,\n",
452
- " \"renal_dose_adjustment_needed\": <boolean>,\n",
453
- " \"identified_risk_factors\": [\"list\", \"of\", \"factors\"],\n",
454
- " \"suspected_pathogens\": [\"list\", \"of\", \"likely\", \"organisms\"],\n",
455
- " \"infection_severity\": \"mild|moderate|severe|critical\",\n",
456
- " \"recommended_stage\": \"empirical|targeted\",\n",
457
- " \"notes\": \"Any additional clinical observations\"\n",
458
- "}\n",
459
- "\"\"\"\n",
460
  "\n",
461
- "INTAKE_HISTORIAN_PROMPT = \"\"\"Analyze the following patient information and provide a structured clinical assessment.\n",
 
 
 
462
  "\n",
463
- "PATIENT DATA:\n",
464
- "{patient_data}\n",
465
  "\n",
466
- "CURRENT MEDICATIONS:\n",
467
- "{medications}\n",
468
- "\n",
469
- "KNOWN ALLERGIES:\n",
470
- "{allergies}\n",
471
- "\n",
472
- "CLINICAL CONTEXT:\n",
473
- "- Suspected infection site: {infection_site}\n",
474
- "- Suspected source: {suspected_source}\n",
475
- "\n",
476
- "Provide your structured assessment following the system instructions.\"\"\""
477
  ]
478
  },
479
  {
@@ -482,797 +181,18 @@
482
  "metadata": {},
483
  "outputs": [],
484
  "source": [
485
- "# Agent 2: Vision Specialist\n",
486
- "VISION_SPECIALIST_SYSTEM = \"\"\"You are an expert medical laboratory data extraction specialist. Your role is to:\n",
487
- "\n",
488
- "1. Extract structured data from laboratory reports (culture & sensitivity, antibiograms)\n",
489
- "2. Handle reports in ANY language - always output in English\n",
490
- "3. Identify pathogens, antibiotics tested, MIC values, and S/I/R interpretations\n",
491
- "4. Flag any critical or unusual findings\n",
492
- "\n",
493
- "OUTPUT FORMAT:\n",
494
- "Provide a structured JSON response:\n",
495
- "{\n",
496
- " \"specimen_type\": \"blood|urine|wound|respiratory|other\",\n",
497
- " \"collection_date\": \"YYYY-MM-DD or null\",\n",
498
- " \"identified_organisms\": [\n",
499
- " {\n",
500
- " \"organism_name\": \"Standardized English name\",\n",
501
- " \"colony_count\": \"if available\",\n",
502
- " \"significance\": \"pathogen|colonizer|contaminant\"\n",
503
- " }\n",
504
- " ],\n",
505
- " \"susceptibility_results\": [\n",
506
- " {\n",
507
- " \"organism\": \"Organism name\",\n",
508
- " \"antibiotic\": \"Standardized antibiotic name\",\n",
509
- " \"mic_value\": <number or null>,\n",
510
- " \"mic_unit\": \"mg/L\",\n",
511
- " \"interpretation\": \"S|I|R\"\n",
512
- " }\n",
513
- " ],\n",
514
- " \"critical_findings\": [\"List of urgent findings\"],\n",
515
- " \"extraction_confidence\": 0.0-1.0\n",
516
- "}\n",
517
- "\"\"\"\n",
518
- "\n",
519
- "VISION_SPECIALIST_PROMPT = \"\"\"Extract structured laboratory data from the following report.\n",
520
- "\n",
521
- "REPORT CONTENT:\n",
522
- "{report_content}\n",
523
- "\n",
524
- "Extract all pathogen identifications, susceptibility results, and MIC values.\n",
525
- "Always standardize to English medical terminology.\n",
526
- "Flag any critical findings that require urgent attention.\n",
527
- "\n",
528
- "Provide your structured extraction following the system instructions.\"\"\""
529
- ]
530
- },
531
- {
532
- "cell_type": "code",
533
- "execution_count": null,
534
- "metadata": {},
535
- "outputs": [],
536
- "source": [
537
- "# Agent 3: Trend Analyst\n",
538
- "TREND_ANALYST_SYSTEM = \"\"\"You are an expert antimicrobial resistance trend analyst. Your role is to:\n",
539
- "\n",
540
- "1. Analyze MIC trends over time to detect \"MIC Creep\"\n",
541
- "2. Calculate resistance velocity and predict treatment failure risk\n",
542
- "3. Compare current MICs against EUCAST/CLSI breakpoints\n",
543
- "4. Identify emerging resistance patterns\n",
544
- "\n",
545
- "RISK STRATIFICATION:\n",
546
- "- LOW: Stable MIC, well below breakpoint (>4x margin)\n",
547
- "- MODERATE: Rising trend but still 2-4x below breakpoint\n",
548
- "- HIGH: Approaching breakpoint (<2x margin) or rapid increase\n",
549
- "- CRITICAL: At or above breakpoint, or >4-fold increase over baseline\n",
550
- "\n",
551
- "OUTPUT FORMAT:\n",
552
- "{\n",
553
- " \"organism\": \"Pathogen name\",\n",
554
- " \"antibiotic\": \"Antibiotic name\",\n",
555
- " \"baseline_mic\": <number>,\n",
556
- " \"current_mic\": <number>,\n",
557
- " \"fold_change\": <number>,\n",
558
- " \"trend\": \"stable|increasing|decreasing\",\n",
559
- " \"breakpoint_susceptible\": <number>,\n",
560
- " \"margin_to_breakpoint\": <number>,\n",
561
- " \"risk_level\": \"LOW|MODERATE|HIGH|CRITICAL\",\n",
562
- " \"recommendation\": \"Continue current therapy|Consider alternatives|Urgent switch needed\",\n",
563
- " \"rationale\": \"Detailed explanation\"\n",
564
- "}\n",
565
- "\"\"\"\n",
566
- "\n",
567
- "TREND_ANALYST_PROMPT = \"\"\"Analyze the MIC trend data and assess resistance risk.\n",
568
- "\n",
569
- "ORGANISM: {organism}\n",
570
- "ANTIBIOTIC: {antibiotic}\n",
571
- "\n",
572
- "HISTORICAL MIC DATA:\n",
573
- "{mic_history}\n",
574
- "\n",
575
- "EUCAST BREAKPOINT (S <=): {breakpoint} mg/L\n",
576
- "\n",
577
- "Analyze the trend, calculate risk level, and provide recommendations.\n",
578
- "Follow the system instructions for output format.\"\"\""
579
- ]
580
- },
581
- {
582
- "cell_type": "code",
583
- "execution_count": null,
584
- "metadata": {},
585
- "outputs": [],
586
- "source": [
587
- "# Agent 4: Clinical Pharmacologist\n",
588
- "CLINICAL_PHARMACOLOGIST_SYSTEM = \"\"\"You are an expert clinical pharmacologist specializing in infectious diseases and antimicrobial stewardship. Your role is to:\n",
589
- "\n",
590
- "1. Synthesize all available clinical data into a final antibiotic recommendation\n",
591
- "2. Apply WHO AWaRe classification principles (ACCESS -> WATCH -> RESERVE)\n",
592
- "3. Perform comprehensive drug safety checks\n",
593
- "4. Adjust dosing for renal function\n",
594
- "\n",
595
- "PRESCRIBING PRINCIPLES:\n",
596
- "1. Start narrow, escalate only when justified\n",
597
- "2. De-escalate when culture results allow\n",
598
- "3. Prefer ACCESS category antibiotics when appropriate\n",
599
- "4. Consider pharmacokinetic/pharmacodynamic (PK/PD) optimization\n",
600
- "\n",
601
- "OUTPUT FORMAT:\n",
602
- "{\n",
603
- " \"primary_recommendation\": {\n",
604
- " \"antibiotic\": \"Drug name\",\n",
605
- " \"dose\": \"Amount and unit\",\n",
606
- " \"route\": \"IV|PO|IM\",\n",
607
- " \"frequency\": \"Dosing interval\",\n",
608
- " \"duration\": \"Treatment duration\",\n",
609
- " \"aware_category\": \"ACCESS|WATCH|RESERVE\"\n",
610
- " },\n",
611
- " \"alternative_recommendation\": {\n",
612
- " \"antibiotic\": \"Alternative drug\",\n",
613
- " \"indication\": \"When to use alternative\"\n",
614
- " },\n",
615
- " \"dose_adjustments\": {\n",
616
- " \"renal\": \"Adjustment details or None needed\"\n",
617
- " },\n",
618
- " \"safety_alerts\": [\n",
619
- " {\n",
620
- " \"level\": \"INFO|WARNING|CRITICAL\",\n",
621
- " \"message\": \"Alert message\"\n",
622
- " }\n",
623
- " ],\n",
624
- " \"monitoring_parameters\": [\"Labs/vitals to monitor\"],\n",
625
- " \"rationale\": \"Clinical reasoning\",\n",
626
- " \"guideline_references\": [\"Supporting guidelines\"]\n",
627
- "}\n",
628
- "\"\"\"\n",
629
- "\n",
630
- "CLINICAL_PHARMACOLOGIST_PROMPT = \"\"\"Synthesize all clinical data and provide a final antibiotic recommendation.\n",
631
- "\n",
632
- "PATIENT SUMMARY:\n",
633
- "{intake_summary}\n",
634
- "\n",
635
- "LAB RESULTS:\n",
636
- "{lab_results}\n",
637
- "\n",
638
- "MIC TREND ANALYSIS:\n",
639
- "{trend_analysis}\n",
640
- "\n",
641
- "PATIENT PARAMETERS:\n",
642
- "- Age: {age} years\n",
643
- "- Weight: {weight} kg\n",
644
- "- CrCl: {crcl} mL/min\n",
645
- "- Allergies: {allergies}\n",
646
- "- Current medications: {current_medications}\n",
647
- "\n",
648
- "INFECTION CONTEXT:\n",
649
- "- Site: {infection_site}\n",
650
- "- Severity: {severity}\n",
651
- "\n",
652
- "Provide your final recommendation following the system instructions.\"\"\""
653
- ]
654
- },
655
- {
656
- "cell_type": "markdown",
657
- "metadata": {},
658
- "source": [
659
- "## 7. Agent Implementation"
660
- ]
661
- },
662
- {
663
- "cell_type": "code",
664
- "execution_count": null,
665
- "metadata": {},
666
- "outputs": [],
667
- "source": [
668
- "# State type definition\n",
669
- "from typing import TypedDict, NotRequired\n",
670
- "\n",
671
- "class InfectionState(TypedDict, total=False):\n",
672
- " \"\"\"Global state for the Med-I-C pipeline.\"\"\"\n",
673
- " # Patient data\n",
674
- " patient_id: Optional[str]\n",
675
- " age_years: Optional[float]\n",
676
- " sex: Optional[Literal[\"male\", \"female\"]]\n",
677
- " weight_kg: Optional[float]\n",
678
- " height_cm: Optional[float]\n",
679
- " \n",
680
- " # Clinical context\n",
681
- " suspected_source: Optional[str]\n",
682
- " comorbidities: List[str]\n",
683
- " medications: List[str]\n",
684
- " allergies: List[str]\n",
685
- " infection_site: Optional[str]\n",
686
- " \n",
687
- " # Renal function\n",
688
- " serum_creatinine_mg_dl: Optional[float]\n",
689
- " creatinine_clearance_ml_min: Optional[float]\n",
690
- " \n",
691
- " # Lab data\n",
692
- " labs_raw_text: Optional[str]\n",
693
- " mic_data: List[Dict[str, Any]]\n",
694
- " \n",
695
- " # Stage routing\n",
696
- " stage: Literal[\"empirical\", \"targeted\"]\n",
697
- " route_to_vision: bool\n",
698
- " route_to_trend_analyst: bool\n",
699
- " \n",
700
- " # Agent outputs\n",
701
- " intake_notes: Optional[str]\n",
702
- " vision_notes: Optional[str]\n",
703
- " trend_notes: Optional[str]\n",
704
- " pharmacology_notes: Optional[str]\n",
705
- " recommendation: Optional[Dict[str, Any]]\n",
706
- " \n",
707
- " # Safety\n",
708
- " safety_warnings: List[str]\n",
709
- " errors: List[str]"
710
- ]
711
- },
712
- {
713
- "cell_type": "code",
714
- "execution_count": null,
715
- "metadata": {},
716
- "outputs": [],
717
- "source": [
718
- "def run_intake_historian(state: InfectionState) -> InfectionState:\n",
719
- " \"\"\"\n",
720
- " Agent 1: Parse patient data, calculate CrCl, identify risk factors.\n",
721
- " \"\"\"\n",
722
- " print(\"\\n\" + \"=\"*60)\n",
723
- " print(\"AGENT 1: INTAKE HISTORIAN\")\n",
724
- " print(\"=\"*60)\n",
725
- " \n",
726
- " # Calculate CrCl if we have required data\n",
727
- " crcl = None\n",
728
- " if all([state.get(\"age_years\"), state.get(\"weight_kg\"), \n",
729
- " state.get(\"serum_creatinine_mg_dl\"), state.get(\"sex\")]):\n",
730
- " crcl = calculate_crcl(\n",
731
- " age_years=state[\"age_years\"],\n",
732
- " weight_kg=state[\"weight_kg\"],\n",
733
- " serum_creatinine_mg_dl=state[\"serum_creatinine_mg_dl\"],\n",
734
- " sex=state[\"sex\"],\n",
735
- " )\n",
736
- " state[\"creatinine_clearance_ml_min\"] = crcl\n",
737
- " print(f\"Calculated CrCl: {crcl} mL/min ({get_renal_dose_category(crcl)})\")\n",
738
- " \n",
739
- " # Build patient data string\n",
740
- " patient_data = f\"\"\"\n",
741
- "Age: {state.get('age_years', 'Unknown')} years\n",
742
- "Sex: {state.get('sex', 'Unknown')}\n",
743
- "Weight: {state.get('weight_kg', 'Unknown')} kg\n",
744
- "Serum Creatinine: {state.get('serum_creatinine_mg_dl', 'Unknown')} mg/dL\n",
745
- "CrCl: {crcl or 'Not calculated'} mL/min\n",
746
- "Comorbidities: {', '.join(state.get('comorbidities', [])) or 'None reported'}\n",
747
- "\"\"\"\n",
748
- " \n",
749
- " # Build prompt\n",
750
- " prompt = f\"{INTAKE_HISTORIAN_SYSTEM}\\n\\n{INTAKE_HISTORIAN_PROMPT.format(\n",
751
- " patient_data=patient_data,\n",
752
- " medications=', '.join(state.get('medications', [])) or 'None reported',\n",
753
- " allergies=', '.join(state.get('allergies', [])) or 'No known allergies',\n",
754
- " infection_site=state.get('infection_site', 'Unknown'),\n",
755
- " suspected_source=state.get('suspected_source', 'Unknown'),\n",
756
- " )}\"\"\"\n",
757
- " \n",
758
- " # Run inference\n",
759
- " print(\"Running MedGemma inference...\")\n",
760
- " response = run_inference(prompt, model_name=\"medgemma_4b\", max_new_tokens=1024)\n",
761
- " \n",
762
- " # Parse response\n",
763
- " parsed = safe_json_parse(response)\n",
764
- " if parsed:\n",
765
- " state[\"intake_notes\"] = json.dumps(parsed, indent=2)\n",
766
- " state[\"stage\"] = parsed.get(\"recommended_stage\", \"empirical\")\n",
767
- " print(f\"\\nIntake Assessment:\")\n",
768
- " print(json.dumps(parsed, indent=2))\n",
769
- " else:\n",
770
- " state[\"intake_notes\"] = response\n",
771
- " state[\"stage\"] = \"empirical\"\n",
772
- " print(f\"\\nRaw response: {response[:500]}...\")\n",
773
- " \n",
774
- " # Determine routing\n",
775
- " state[\"route_to_vision\"] = bool(state.get(\"labs_raw_text\"))\n",
776
- " print(f\"\\nStage: {state['stage']}\")\n",
777
- " print(f\"Route to Vision Specialist: {state['route_to_vision']}\")\n",
778
- " \n",
779
- " return state"
780
- ]
781
- },
782
- {
783
- "cell_type": "code",
784
- "execution_count": null,
785
- "metadata": {},
786
- "outputs": [],
787
- "source": [
788
- "def run_vision_specialist(state: InfectionState) -> InfectionState:\n",
789
- " \"\"\"\n",
790
- " Agent 2: Extract structured data from lab reports.\n",
791
- " \"\"\"\n",
792
- " print(\"\\n\" + \"=\"*60)\n",
793
- " print(\"AGENT 2: VISION SPECIALIST\")\n",
794
- " print(\"=\"*60)\n",
795
- " \n",
796
- " labs_raw = state.get(\"labs_raw_text\", \"\")\n",
797
- " if not labs_raw:\n",
798
- " print(\"No lab data to process, skipping.\")\n",
799
- " state[\"vision_notes\"] = \"No lab data provided\"\n",
800
- " state[\"route_to_trend_analyst\"] = False\n",
801
- " return state\n",
802
- " \n",
803
- " # Build prompt\n",
804
- " prompt = f\"{VISION_SPECIALIST_SYSTEM}\\n\\n{VISION_SPECIALIST_PROMPT.format(\n",
805
- " report_content=labs_raw,\n",
806
- " )}\"\n",
807
- " \n",
808
- " # Run inference\n",
809
- " print(\"Running MedGemma inference on lab report...\")\n",
810
- " response = run_inference(prompt, model_name=\"medgemma_4b\", max_new_tokens=2048)\n",
811
- " \n",
812
- " # Parse response\n",
813
- " parsed = safe_json_parse(response)\n",
814
- " if parsed:\n",
815
- " state[\"vision_notes\"] = json.dumps(parsed, indent=2)\n",
816
- " \n",
817
- " # Extract MIC data\n",
818
- " susceptibility = parsed.get(\"susceptibility_results\", [])\n",
819
- " state[\"mic_data\"] = susceptibility\n",
820
- " state[\"route_to_trend_analyst\"] = len(susceptibility) > 0\n",
821
- " \n",
822
- " print(f\"\\nExtracted Lab Data:\")\n",
823
- " print(json.dumps(parsed, indent=2))\n",
824
- " \n",
825
- " # Check for critical findings\n",
826
- " critical = parsed.get(\"critical_findings\", [])\n",
827
- " if critical:\n",
828
- " print(f\"\\n⚠️ CRITICAL FINDINGS: {critical}\")\n",
829
- " state.setdefault(\"safety_warnings\", []).extend(critical)\n",
830
- " else:\n",
831
- " state[\"vision_notes\"] = response\n",
832
- " state[\"route_to_trend_analyst\"] = False\n",
833
- " print(f\"\\nRaw response: {response[:500]}...\")\n",
834
- " \n",
835
- " print(f\"\\nMIC data points: {len(state.get('mic_data', []))}\")\n",
836
- " print(f\"Route to Trend Analyst: {state.get('route_to_trend_analyst', False)}\")\n",
837
- " \n",
838
- " return state"
839
- ]
840
- },
841
- {
842
- "cell_type": "code",
843
- "execution_count": null,
844
- "metadata": {},
845
- "outputs": [],
846
- "source": [
847
- "def run_trend_analyst(state: InfectionState) -> InfectionState:\n",
848
- " \"\"\"\n",
849
- " Agent 3: Analyze MIC trends and detect resistance velocity.\n",
850
- " \"\"\"\n",
851
- " print(\"\\n\" + \"=\"*60)\n",
852
- " print(\"AGENT 3: TREND ANALYST\")\n",
853
- " print(\"=\"*60)\n",
854
- " \n",
855
- " mic_data = state.get(\"mic_data\", [])\n",
856
- " if not mic_data:\n",
857
- " print(\"No MIC data to analyze, skipping.\")\n",
858
- " state[\"trend_notes\"] = \"No MIC data available\"\n",
859
- " return state\n",
860
- " \n",
861
- " trend_results = []\n",
862
- " \n",
863
- " for mic in mic_data:\n",
864
- " organism = mic.get(\"organism\", \"Unknown\")\n",
865
- " antibiotic = mic.get(\"antibiotic\", \"Unknown\")\n",
866
- " mic_value = mic.get(\"mic_value\")\n",
867
- " \n",
868
- " if mic_value is None:\n",
869
- " continue\n",
870
- " \n",
871
- " # For demo, create synthetic history showing increasing trend\n",
872
- " mic_history = json.dumps([\n",
873
- " {\"date\": \"2024-01-01\", \"mic_value\": float(mic_value) / 2},\n",
874
- " {\"date\": \"2024-06-01\", \"mic_value\": float(mic_value) / 1.5},\n",
875
- " {\"date\": \"2025-01-01\", \"mic_value\": float(mic_value)},\n",
876
- " ], indent=2)\n",
877
- " \n",
878
- " # Use standard EUCAST breakpoint (simplified)\n",
879
- " breakpoint = 2.0 # Default S <= 2 mg/L\n",
880
- " \n",
881
- " prompt = f\"{TREND_ANALYST_SYSTEM}\\n\\n{TREND_ANALYST_PROMPT.format(\n",
882
- " organism=organism,\n",
883
- " antibiotic=antibiotic,\n",
884
- " mic_history=mic_history,\n",
885
- " breakpoint=breakpoint,\n",
886
- " )}\"\n",
887
- " \n",
888
- " print(f\"\\nAnalyzing: {organism} / {antibiotic}...\")\n",
889
- " response = run_inference(prompt, model_name=\"medgemma_4b\", max_new_tokens=1024)\n",
890
- " \n",
891
- " parsed = safe_json_parse(response)\n",
892
- " if parsed:\n",
893
- " trend_results.append(parsed)\n",
894
- " risk_level = parsed.get(\"risk_level\", \"UNKNOWN\")\n",
895
- " print(f\" Risk Level: {risk_level}\")\n",
896
- " \n",
897
- " if risk_level in [\"HIGH\", \"CRITICAL\"]:\n",
898
- " warning = f\"MIC trend alert for {organism}/{antibiotic}: {parsed.get('recommendation', 'Review needed')}\"\n",
899
- " state.setdefault(\"safety_warnings\", []).append(warning)\n",
900
- " else:\n",
901
- " trend_results.append({\"organism\": organism, \"antibiotic\": antibiotic, \"raw\": response[:200]})\n",
902
- " \n",
903
- " state[\"trend_notes\"] = json.dumps(trend_results, indent=2)\n",
904
- " print(f\"\\nTrend Analysis Complete:\")\n",
905
- " print(json.dumps(trend_results, indent=2))\n",
906
- " \n",
907
- " return state"
908
- ]
909
- },
910
- {
911
- "cell_type": "code",
912
- "execution_count": null,
913
- "metadata": {},
914
- "outputs": [],
915
- "source": [
916
- "def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:\n",
917
- " \"\"\"\n",
918
- " Agent 4: Generate final antibiotic recommendation with safety checks.\n",
919
- " \"\"\"\n",
920
- " print(\"\\n\" + \"=\"*60)\n",
921
- " print(\"AGENT 4: CLINICAL PHARMACOLOGIST\")\n",
922
- " print(\"=\"*60)\n",
923
- " \n",
924
- " # Gather previous agent outputs\n",
925
- " intake_summary = state.get(\"intake_notes\", \"No intake data\")\n",
926
- " lab_results = state.get(\"vision_notes\", \"No lab data\")\n",
927
- " trend_analysis = state.get(\"trend_notes\", \"No trend data\")\n",
928
- " \n",
929
- " prompt = f\"{CLINICAL_PHARMACOLOGIST_SYSTEM}\\n\\n{CLINICAL_PHARMACOLOGIST_PROMPT.format(\n",
930
- " intake_summary=intake_summary,\n",
931
- " lab_results=lab_results,\n",
932
- " trend_analysis=trend_analysis,\n",
933
- " age=state.get('age_years', 'Unknown'),\n",
934
- " weight=state.get('weight_kg', 'Unknown'),\n",
935
- " crcl=state.get('creatinine_clearance_ml_min', 'Unknown'),\n",
936
- " allergies=', '.join(state.get('allergies', [])) or 'No known allergies',\n",
937
- " current_medications=', '.join(state.get('medications', [])) or 'None',\n",
938
- " infection_site=state.get('infection_site', 'Unknown'),\n",
939
- " severity='moderate',\n",
940
- " )}\"\n",
941
- " \n",
942
- " print(\"Running MedGemma inference for final recommendation...\")\n",
943
- " response = run_inference(prompt, model_name=\"medgemma_4b\", max_new_tokens=2048)\n",
944
- " \n",
945
- " parsed = safe_json_parse(response)\n",
946
- " if parsed:\n",
947
- " state[\"pharmacology_notes\"] = json.dumps(parsed, indent=2)\n",
948
- " \n",
949
- " # Build recommendation\n",
950
- " primary = parsed.get(\"primary_recommendation\", {})\n",
951
- " recommendation = {\n",
952
- " \"primary_antibiotic\": primary.get(\"antibiotic\"),\n",
953
- " \"dose\": primary.get(\"dose\"),\n",
954
- " \"route\": primary.get(\"route\"),\n",
955
- " \"frequency\": primary.get(\"frequency\"),\n",
956
- " \"duration\": primary.get(\"duration\"),\n",
957
- " \"rationale\": parsed.get(\"rationale\"),\n",
958
- " \"references\": parsed.get(\"guideline_references\", []),\n",
959
- " \"safety_alerts\": [a.get(\"message\") for a in parsed.get(\"safety_alerts\", [])],\n",
960
- " }\n",
961
- " \n",
962
- " alt = parsed.get(\"alternative_recommendation\", {})\n",
963
- " if alt.get(\"antibiotic\"):\n",
964
- " recommendation[\"backup_antibiotic\"] = alt.get(\"antibiotic\")\n",
965
- " \n",
966
- " state[\"recommendation\"] = recommendation\n",
967
- " \n",
968
- " print(f\"\\n\" + \"=\"*60)\n",
969
- " print(\"FINAL RECOMMENDATION\")\n",
970
- " print(\"=\"*60)\n",
971
- " print(json.dumps(recommendation, indent=2))\n",
972
- " \n",
973
- " # Add safety alerts\n",
974
- " for alert in parsed.get(\"safety_alerts\", []):\n",
975
- " if alert.get(\"level\") in [\"WARNING\", \"CRITICAL\"]:\n",
976
- " state.setdefault(\"safety_warnings\", []).append(alert.get(\"message\"))\n",
977
- " else:\n",
978
- " state[\"pharmacology_notes\"] = response\n",
979
- " state[\"recommendation\"] = {\"rationale\": response}\n",
980
- " print(f\"\\nRaw response: {response[:500]}...\")\n",
981
- " \n",
982
- " return state"
983
- ]
984
- },
985
- {
986
- "cell_type": "markdown",
987
- "metadata": {},
988
- "source": [
989
- "## 8. Pipeline Orchestration with LangGraph"
990
- ]
991
- },
992
- {
993
- "cell_type": "code",
994
- "execution_count": null,
995
- "metadata": {},
996
- "outputs": [],
997
- "source": [
998
- "from langgraph.graph import StateGraph, END\n",
999
- "\n",
1000
- "def build_infection_graph():\n",
1001
- " \"\"\"\n",
1002
- " Build the LangGraph StateGraph for the infection lifecycle workflow.\n",
1003
- " \n",
1004
- " Stage 1 (Empirical): Intake Historian -> Clinical Pharmacologist\n",
1005
- " Stage 2 (Targeted): Intake Historian -> Vision Specialist -> Trend Analyst -> Clinical Pharmacologist\n",
1006
- " \"\"\"\n",
1007
- " graph = StateGraph(InfectionState)\n",
1008
- " \n",
1009
- " # Add nodes\n",
1010
- " graph.add_node(\"intake_historian\", run_intake_historian)\n",
1011
- " graph.add_node(\"vision_specialist\", run_vision_specialist)\n",
1012
- " graph.add_node(\"trend_analyst\", run_trend_analyst)\n",
1013
- " graph.add_node(\"clinical_pharmacologist\", run_clinical_pharmacologist)\n",
1014
- " \n",
1015
- " # Set entry point\n",
1016
- " graph.set_entry_point(\"intake_historian\")\n",
1017
- " \n",
1018
- " # Conditional routing after intake\n",
1019
- " def route_after_intake(state: InfectionState):\n",
1020
- " if state.get(\"stage\") == \"targeted\" and state.get(\"route_to_vision\"):\n",
1021
- " return \"vision_specialist\"\n",
1022
- " return \"clinical_pharmacologist\"\n",
1023
- " \n",
1024
- " graph.add_conditional_edges(\n",
1025
- " \"intake_historian\",\n",
1026
- " route_after_intake,\n",
1027
- " {\n",
1028
- " \"vision_specialist\": \"vision_specialist\",\n",
1029
- " \"clinical_pharmacologist\": \"clinical_pharmacologist\",\n",
1030
- " }\n",
1031
- " )\n",
1032
- " \n",
1033
- " # Conditional routing after vision\n",
1034
- " def route_after_vision(state: InfectionState):\n",
1035
- " if state.get(\"route_to_trend_analyst\"):\n",
1036
- " return \"trend_analyst\"\n",
1037
- " return \"clinical_pharmacologist\"\n",
1038
- " \n",
1039
- " graph.add_conditional_edges(\n",
1040
- " \"vision_specialist\",\n",
1041
- " route_after_vision,\n",
1042
- " {\n",
1043
- " \"trend_analyst\": \"trend_analyst\",\n",
1044
- " \"clinical_pharmacologist\": \"clinical_pharmacologist\",\n",
1045
- " }\n",
1046
- " )\n",
1047
- " \n",
1048
- " # Edges to final node\n",
1049
- " graph.add_edge(\"trend_analyst\", \"clinical_pharmacologist\")\n",
1050
- " graph.add_edge(\"clinical_pharmacologist\", END)\n",
1051
- " \n",
1052
- " return graph\n",
1053
- "\n",
1054
- "\n",
1055
- "def run_pipeline(patient_data: dict, labs_raw_text: Optional[str] = None) -> InfectionState:\n",
1056
- " \"\"\"\n",
1057
- " Run the full infection lifecycle pipeline.\n",
1058
- " \n",
1059
- " Args:\n",
1060
- " patient_data: Patient information dict\n",
1061
- " labs_raw_text: Optional lab report text (triggers Stage 2)\n",
1062
- " \n",
1063
- " Returns:\n",
1064
- " Final InfectionState with recommendation\n",
1065
- " \"\"\"\n",
1066
- " # Build initial state\n",
1067
- " initial_state: InfectionState = {\n",
1068
- " \"age_years\": patient_data.get(\"age_years\"),\n",
1069
- " \"weight_kg\": patient_data.get(\"weight_kg\"),\n",
1070
- " \"height_cm\": patient_data.get(\"height_cm\"),\n",
1071
- " \"sex\": patient_data.get(\"sex\"),\n",
1072
- " \"serum_creatinine_mg_dl\": patient_data.get(\"serum_creatinine_mg_dl\"),\n",
1073
- " \"medications\": patient_data.get(\"medications\", []),\n",
1074
- " \"allergies\": patient_data.get(\"allergies\", []),\n",
1075
- " \"comorbidities\": patient_data.get(\"comorbidities\", []),\n",
1076
- " \"infection_site\": patient_data.get(\"infection_site\"),\n",
1077
- " \"suspected_source\": patient_data.get(\"suspected_source\"),\n",
1078
- " \"safety_warnings\": [],\n",
1079
- " \"errors\": [],\n",
1080
- " }\n",
1081
- " \n",
1082
- " # Add lab data if provided\n",
1083
- " if labs_raw_text:\n",
1084
- " initial_state[\"labs_raw_text\"] = labs_raw_text\n",
1085
- " initial_state[\"stage\"] = \"targeted\"\n",
1086
- " else:\n",
1087
- " initial_state[\"stage\"] = \"empirical\"\n",
1088
- " \n",
1089
- " # Build and run graph\n",
1090
- " print(\"\\n\" + \"#\"*70)\n",
1091
- " print(f\"# STARTING MED-I-C PIPELINE (Stage: {initial_state['stage'].upper()})\")\n",
1092
- " print(\"#\"*70)\n",
1093
- " \n",
1094
- " graph = build_infection_graph()\n",
1095
- " compiled = graph.compile()\n",
1096
- " final_state = compiled.invoke(initial_state)\n",
1097
- " \n",
1098
- " print(\"\\n\" + \"#\"*70)\n",
1099
- " print(\"# PIPELINE COMPLETE\")\n",
1100
- " print(\"#\"*70)\n",
1101
- " \n",
1102
- " return final_state"
1103
- ]
1104
- },
1105
- {
1106
- "cell_type": "markdown",
1107
- "metadata": {},
1108
- "source": [
1109
- "## 9. Test Cases"
1110
- ]
1111
- },
1112
- {
1113
- "cell_type": "markdown",
1114
- "metadata": {},
1115
- "source": [
1116
- "### Test Case 1: Stage 1 (Empirical) - Community UTI"
1117
- ]
1118
- },
1119
- {
1120
- "cell_type": "code",
1121
- "execution_count": null,
1122
- "metadata": {},
1123
- "outputs": [],
1124
- "source": [
1125
- "# Test Case 1: Stage 1 Empirical - Community UTI\n",
1126
- "patient_data_uti = {\n",
1127
- " \"age_years\": 65,\n",
1128
- " \"weight_kg\": 70,\n",
1129
- " \"height_cm\": 170,\n",
1130
- " \"sex\": \"male\",\n",
1131
- " \"serum_creatinine_mg_dl\": 1.2,\n",
1132
- " \"medications\": [\"metformin\", \"lisinopril\", \"aspirin\"],\n",
1133
- " \"allergies\": [\"penicillin\"],\n",
1134
- " \"comorbidities\": [\"diabetes\", \"hypertension\"],\n",
1135
- " \"infection_site\": \"urinary\",\n",
1136
- " \"suspected_source\": \"community-acquired UTI\",\n",
1137
- "}\n",
1138
- "\n",
1139
- "result_uti = run_pipeline(patient_data_uti)"
1140
- ]
1141
- },
1142
- {
1143
- "cell_type": "code",
1144
- "execution_count": null,
1145
- "metadata": {},
1146
- "outputs": [],
1147
- "source": [
1148
- "# Display results\n",
1149
- "print(\"\\n\" + \"=\"*70)\n",
1150
- "print(\"TEST CASE 1: COMMUNITY UTI (EMPIRICAL)\")\n",
1151
- "print(\"=\"*70)\n",
1152
- "\n",
1153
- "print(f\"\\nCrCl: {result_uti.get('creatinine_clearance_ml_min')} mL/min\")\n",
1154
- "print(f\"Stage: {result_uti.get('stage')}\")\n",
1155
- "\n",
1156
- "rec = result_uti.get('recommendation', {})\n",
1157
- "if rec:\n",
1158
- " print(f\"\\nRecommendation:\")\n",
1159
- " print(f\" Drug: {rec.get('primary_antibiotic')}\")\n",
1160
- " print(f\" Dose: {rec.get('dose')}\")\n",
1161
- " print(f\" Route: {rec.get('route')}\")\n",
1162
- " print(f\" Frequency: {rec.get('frequency')}\")\n",
1163
- " print(f\" Duration: {rec.get('duration')}\")\n",
1164
- " print(f\" Rationale: {rec.get('rationale')}\")\n",
1165
- "\n",
1166
- "warnings = result_uti.get('safety_warnings', [])\n",
1167
- "if warnings:\n",
1168
- " print(f\"\\nSafety Warnings:\")\n",
1169
- " for w in warnings:\n",
1170
- " print(f\" ⚠️ {w}\")"
1171
- ]
1172
- },
1173
- {
1174
- "cell_type": "markdown",
1175
- "metadata": {},
1176
- "source": [
1177
- "### Test Case 2: Stage 2 (Targeted) - With Lab Results"
1178
- ]
1179
- },
1180
- {
1181
- "cell_type": "code",
1182
- "execution_count": null,
1183
- "metadata": {},
1184
- "outputs": [],
1185
- "source": [
1186
- "# Test Case 2: Stage 2 Targeted - UTI with Lab Results\n",
1187
- "patient_data_targeted = {\n",
1188
- " \"age_years\": 72,\n",
1189
- " \"weight_kg\": 65,\n",
1190
- " \"height_cm\": 165,\n",
1191
- " \"sex\": \"female\",\n",
1192
- " \"serum_creatinine_mg_dl\": 1.5,\n",
1193
- " \"medications\": [\"warfarin\", \"amlodipine\"],\n",
1194
- " \"allergies\": [],\n",
1195
- " \"comorbidities\": [\"atrial fibrillation\", \"hypertension\", \"CKD stage 3\"],\n",
1196
- " \"infection_site\": \"urinary\",\n",
1197
- " \"suspected_source\": \"complicated UTI with pyelonephritis\",\n",
1198
- "}\n",
1199
- "\n",
1200
- "lab_report = \"\"\"\n",
1201
- "URINE CULTURE REPORT\n",
1202
- "Patient ID: 12345\n",
1203
- "Collection Date: 2025-02-15\n",
1204
- "\n",
1205
- "Specimen: Midstream urine\n",
1206
- "Colony Count: >100,000 CFU/mL\n",
1207
- "\n",
1208
- "ORGANISM ISOLATED:\n",
1209
- "Escherichia coli\n",
1210
- "\n",
1211
- "ANTIMICROBIAL SUSCEPTIBILITY:\n",
1212
- "-----------------------------------\n",
1213
- "Antibiotic MIC (mg/L) Interpretation\n",
1214
- "-----------------------------------\n",
1215
- "Ampicillin >32 R\n",
1216
- "Amoxicillin-Clav 16 I\n",
1217
- "Ceftriaxone 0.25 S\n",
1218
- "Cefepime 0.5 S\n",
1219
- "Ciprofloxacin 0.5 S\n",
1220
- "Levofloxacin 1 S\n",
1221
- "Nitrofurantoin 32 S\n",
1222
- "TMP-SMX >4 R\n",
1223
- "Gentamicin 2 S\n",
1224
- "Meropenem 0.06 S\n",
1225
- "\n",
1226
- "NOTES:\n",
1227
- "- ESBL screening negative\n",
1228
- "- No carbapenemase detected\n",
1229
- "\"\"\"\n",
1230
- "\n",
1231
- "result_targeted = run_pipeline(patient_data_targeted, labs_raw_text=lab_report)"
1232
- ]
1233
- },
1234
- {
1235
- "cell_type": "code",
1236
- "execution_count": null,
1237
- "metadata": {},
1238
- "outputs": [],
1239
- "source": [
1240
- "# Display results\n",
1241
- "print(\"\\n\" + \"=\"*70)\n",
1242
- "print(\"TEST CASE 2: COMPLICATED UTI WITH LAB RESULTS (TARGETED)\")\n",
1243
- "print(\"=\"*70)\n",
1244
- "\n",
1245
- "print(f\"\\nCrCl: {result_targeted.get('creatinine_clearance_ml_min')} mL/min\")\n",
1246
- "print(f\"Stage: {result_targeted.get('stage')}\")\n",
1247
- "\n",
1248
- "print(f\"\\nExtracted MIC Data:\")\n",
1249
- "for mic in result_targeted.get('mic_data', []):\n",
1250
- " print(f\" - {mic.get('organism')} / {mic.get('antibiotic')}: MIC {mic.get('mic_value')} ({mic.get('interpretation')})\")\n",
1251
- "\n",
1252
- "rec = result_targeted.get('recommendation', {})\n",
1253
- "if rec:\n",
1254
- " print(f\"\\nFinal Recommendation:\")\n",
1255
- " print(f\" Primary: {rec.get('primary_antibiotic')}\")\n",
1256
- " print(f\" Dose: {rec.get('dose')}\")\n",
1257
- " print(f\" Route: {rec.get('route')}\")\n",
1258
- " print(f\" Frequency: {rec.get('frequency')}\")\n",
1259
- " print(f\" Duration: {rec.get('duration')}\")\n",
1260
- " if rec.get('backup_antibiotic'):\n",
1261
- " print(f\" Alternative: {rec.get('backup_antibiotic')}\")\n",
1262
- " print(f\" Rationale: {rec.get('rationale')}\")\n",
1263
  "\n",
1264
- "warnings = result_targeted.get('safety_warnings', [])\n",
1265
- "if warnings:\n",
1266
- " print(f\"\\n⚠️ Safety Warnings:\")\n",
1267
- " for w in warnings:\n",
1268
- " print(f\" - {w}\")"
1269
  ]
1270
  },
1271
  {
1272
  "cell_type": "markdown",
1273
  "metadata": {},
1274
  "source": [
1275
- "### Test Case 3: ESBL-producing Organism (High-Risk)"
1276
  ]
1277
  },
1278
  {
@@ -1281,55 +201,8 @@
1281
  "metadata": {},
1282
  "outputs": [],
1283
  "source": [
1284
- "# Test Case 3: ESBL E. coli\n",
1285
- "patient_esbl = {\n",
1286
- " \"age_years\": 58,\n",
1287
- " \"weight_kg\": 85,\n",
1288
- " \"height_cm\": 175,\n",
1289
- " \"sex\": \"male\",\n",
1290
- " \"serum_creatinine_mg_dl\": 1.0,\n",
1291
- " \"medications\": [\"metformin\", \"atorvastatin\"],\n",
1292
- " \"allergies\": [],\n",
1293
- " \"comorbidities\": [\"diabetes\", \"recent hospitalization\"],\n",
1294
- " \"infection_site\": \"bloodstream\",\n",
1295
- " \"suspected_source\": \"healthcare-associated bacteremia\",\n",
1296
- "}\n",
1297
- "\n",
1298
- "lab_esbl = \"\"\"\n",
1299
- "BLOOD CULTURE REPORT\n",
1300
- "Collection Date: 2025-02-18\n",
1301
- "\n",
1302
- "POSITIVE: Gram-negative bacilli\n",
1303
- "\n",
1304
- "FINAL IDENTIFICATION:\n",
1305
- "Escherichia coli (ESBL-producing)\n",
1306
- "\n",
1307
- "ANTIMICROBIAL SUSCEPTIBILITY:\n",
1308
- "-----------------------------------\n",
1309
- "Antibiotic MIC (mg/L) Interpretation\n",
1310
- "-----------------------------------\n",
1311
- "Ampicillin >32 R\n",
1312
- "Ampicillin-Sulbact >32 R\n",
1313
- "Ceftriaxone >32 R\n",
1314
- "Cefepime >32 R\n",
1315
- "Ceftazidime >32 R\n",
1316
- "Ciprofloxacin >4 R\n",
1317
- "Levofloxacin >8 R\n",
1318
- "TMP-SMX >4 R\n",
1319
- "Gentamicin 8 I\n",
1320
- "Amikacin 4 S\n",
1321
- "Ertapenem 0.25 S\n",
1322
- "Meropenem 0.06 S\n",
1323
- "Imipenem 0.5 S\n",
1324
- "Tigecycline 0.5 S\n",
1325
- "\n",
1326
- "ESBL CONFIRMATION: POSITIVE\n",
1327
- "Carbapenemase: NOT DETECTED\n",
1328
- "\n",
1329
- "CRITICAL ALERT: ESBL-producing organism in bloodstream\n",
1330
- "\"\"\"\n",
1331
- "\n",
1332
- "result_esbl = run_pipeline(patient_esbl, labs_raw_text=lab_esbl)"
1333
  ]
1334
  },
1335
  {
@@ -1338,37 +211,28 @@
1338
  "metadata": {},
1339
  "outputs": [],
1340
  "source": [
1341
- "# Display ESBL results\n",
1342
- "print(\"\\n\" + \"=\"*70)\n",
1343
- "print(\"TEST CASE 3: ESBL E. coli BACTEREMIA (HIGH-RISK)\")\n",
1344
- "print(\"=\"*70)\n",
1345
  "\n",
1346
- "rec = result_esbl.get('recommendation', {})\n",
1347
- "if rec:\n",
1348
- " print(f\"\\nRecommendation:\")\n",
1349
- " print(f\" Primary: {rec.get('primary_antibiotic')}\")\n",
1350
- " print(f\" Dose: {rec.get('dose')}\")\n",
1351
- " print(f\" Route: {rec.get('route')}\")\n",
1352
- " print(f\" Rationale: {rec.get('rationale')}\")\n",
1353
- "\n",
1354
- "warnings = result_esbl.get('safety_warnings', [])\n",
1355
- "if warnings:\n",
1356
- " print(f\"\\n🚨 SAFETY ALERTS:\")\n",
1357
- " for w in warnings:\n",
1358
- " print(f\" - {w}\")"
1359
- ]
1360
- },
1361
- {
1362
- "cell_type": "markdown",
1363
- "metadata": {},
1364
- "source": [
1365
- "## 10. Streamlit App (Optional - for local testing)\n",
1366
- "\n",
1367
- "Note: Streamlit doesn't run directly in Kaggle notebooks. To test the Streamlit app:\n",
1368
- "1. Download this notebook and the source files\n",
1369
- "2. Run locally with `streamlit run app.py`\n",
1370
  "\n",
1371
- "Alternatively, you can use `ngrok` or `localtunnel` to expose the Streamlit app."
 
 
 
 
 
 
 
 
 
1372
  ]
1373
  },
1374
  {
@@ -1377,103 +241,21 @@
1377
  "metadata": {},
1378
  "outputs": [],
1379
  "source": [
1380
- "# Optional: Install and run Streamlit with localtunnel\n",
1381
- "# Uncomment to use\n",
1382
- "\n",
1383
- "# !pip install -q streamlit localtunnel\n",
1384
- "\n",
1385
- "# # Write the app file\n",
1386
- "# app_code = '''\n",
1387
- "# import streamlit as st\n",
1388
- "# import json\n",
1389
- "\n",
1390
- "# st.set_page_config(page_title=\"Med-I-C: AMR-Guard\", page_icon=\"🦠\", layout=\"wide\")\n",
1391
- "\n",
1392
- "# st.title(\"🦠 Med-I-C: AMR-Guard\")\n",
1393
- "# st.subheader(\"Infection Lifecycle Orchestrator - Multi-Agent System\")\n",
1394
- "\n",
1395
- "# st.markdown(\"\"\"\n",
1396
- "# This demo showcases the Med-I-C multi-agent system powered by MedGemma.\n",
1397
- "\n",
1398
- "# **Note:** Running in demo mode. For full functionality, deploy with GPU support.\n",
1399
- "# \"\"\")\n",
1400
- "\n",
1401
- "# # Patient form\n",
1402
- "# with st.form(\"patient_form\"):\n",
1403
- "# col1, col2 = st.columns(2)\n",
1404
- "# with col1:\n",
1405
- "# age = st.number_input(\"Age\", 0, 120, 65)\n",
1406
- "# weight = st.number_input(\"Weight (kg)\", 1.0, 300.0, 70.0)\n",
1407
- "# sex = st.selectbox(\"Sex\", [\"male\", \"female\"])\n",
1408
- "# with col2:\n",
1409
- "# creatinine = st.number_input(\"Creatinine (mg/dL)\", 0.1, 20.0, 1.2)\n",
1410
- "# infection_site = st.selectbox(\"Infection Site\", [\"urinary\", \"respiratory\", \"bloodstream\"])\n",
1411
- "# \n",
1412
- "# submitted = st.form_submit_button(\"Get Recommendation\")\n",
1413
- "\n",
1414
- "# if submitted:\n",
1415
- "# st.success(\"Demo mode: Showing simulated recommendation\")\n",
1416
- "# st.json({\n",
1417
- "# \"primary_antibiotic\": \"Ciprofloxacin\",\n",
1418
- "# \"dose\": \"500mg\",\n",
1419
- "# \"route\": \"PO\",\n",
1420
- "# \"frequency\": \"Every 12 hours\",\n",
1421
- "# \"duration\": \"7 days\"\n",
1422
- "# })\n",
1423
- "# '''\n",
1424
- "\n",
1425
- "# with open('streamlit_app.py', 'w') as f:\n",
1426
- "# f.write(app_code)\n",
1427
- "\n",
1428
- "# # Run with localtunnel\n",
1429
- "# !streamlit run streamlit_app.py &>/dev/null &\n",
1430
- "# !npx localtunnel --port 8501"
1431
- ]
1432
- },
1433
- {
1434
- "cell_type": "markdown",
1435
- "metadata": {},
1436
- "source": [
1437
- "## 11. Summary & Conclusions\n",
1438
- "\n",
1439
- "This notebook demonstrates the **Med-I-C** multi-agent system for antimicrobial stewardship:\n",
1440
- "\n",
1441
- "### Key Features:\n",
1442
- "1. **4 Specialized Agents** powered by MedGemma 4B IT\n",
1443
- "2. **Conditional Routing** via LangGraph for Stage 1 (Empirical) vs Stage 2 (Targeted)\n",
1444
- "3. **CrCl Calculation** using Cockcroft-Gault equation\n",
1445
- "4. **MIC Trend Analysis** for resistance detection\n",
1446
- "5. **Safety Checks** including drug interactions and allergy alerts\n",
1447
- "\n",
1448
- "### Models Used:\n",
1449
- "- **MedGemma 4B IT** - Primary model for all agents (4-bit quantized)\n",
1450
- "- **TxGemma 2B** - Optional safety validation (not demonstrated in this notebook)\n",
1451
- "\n",
1452
- "### Future Enhancements:\n",
1453
- "- Integration with RAG (ChromaDB) for guideline retrieval\n",
1454
- "- MedGemma 27B for complex trend analysis\n",
1455
- "- Vision capabilities for image-based lab report extraction\n",
1456
- "- Regional resistance pattern analysis\n",
1457
- "\n",
1458
- "---\n",
1459
  "\n",
1460
- "**Competition:** MedGemma Impact Challenge \n",
1461
- "**Category:** Agentic Workflow \n",
1462
- "**Deadline:** February 24, 2026"
1463
- ]
1464
- },
1465
- {
1466
- "cell_type": "code",
1467
- "execution_count": null,
1468
- "metadata": {},
1469
- "outputs": [],
1470
- "source": [
1471
- "# Final memory cleanup\n",
1472
- "import gc\n",
1473
- "gc.collect()\n",
1474
- "if torch.cuda.is_available():\n",
1475
- " torch.cuda.empty_cache()\n",
1476
- " print(f\"GPU Memory after cleanup: {torch.cuda.memory_allocated() / 1e9:.2f} GB\")"
1477
  ]
1478
  }
1479
  ],
@@ -1484,18 +266,10 @@
1484
  "name": "python3"
1485
  },
1486
  "language_info": {
1487
- "codemirror_mode": {
1488
- "name": "ipython",
1489
- "version": 3
1490
- },
1491
- "file_extension": ".py",
1492
- "mimetype": "text/x-python",
1493
  "name": "python",
1494
- "nbconvert_exporter": "python",
1495
- "pygments_lexer": "ipython3",
1496
  "version": "3.10.0"
1497
  }
1498
  },
1499
  "nbformat": 4,
1500
- "nbformat_minor": 4
1501
  }
 
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
+ "# Med-I-C · AMR-Guard\n",
8
+ "### Infection Lifecycle Orchestrator — Kaggle Demo\n",
9
  "\n",
10
+ "**Steps**\n",
11
+ "1. Clone repo & install packages\n",
12
+ "2. Authenticate with Hugging Face\n",
13
+ "3. Download models\n",
14
+ "4. Initialise the knowledge base\n",
15
+ "5. Launch the Streamlit app"
 
 
 
 
 
 
 
 
 
16
  ]
17
  },
18
  {
19
  "cell_type": "markdown",
20
  "metadata": {},
21
  "source": [
22
+ "## 1 · Environment"
23
  ]
24
  },
25
  {
 
28
  "metadata": {},
29
  "outputs": [],
30
  "source": [
31
+ "import subprocess, torch\n",
32
+ "\n",
33
+ "# GPU check\n",
34
+ "print(subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'],\n",
35
+ " capture_output=True, text=True).stdout.strip())\n",
36
+ "print(f\"PyTorch {torch.__version__} · CUDA {torch.cuda.is_available()}\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  ]
38
  },
39
  {
 
42
  "metadata": {},
43
  "outputs": [],
44
  "source": [
45
+ "%%bash\n",
46
+ "# Clone the repo (skip if already present)\n",
47
+ "if [ ! -d /kaggle/working/Med-I-C ]; then\n",
48
+ " git clone https://github.com/benghita/Med-I-C.git /kaggle/working/Med-I-C\n",
49
+ "else\n",
50
+ " echo \"Repo already cloned pulling latest changes\"\n",
51
+ " git -C /kaggle/working/Med-I-C pull\n",
52
+ "fi"
 
 
 
 
 
53
  ]
54
  },
55
  {
 
58
  "metadata": {},
59
  "outputs": [],
60
  "source": [
61
+ "%%capture\n",
62
+ "# Install packages from pyproject.toml dependencies\n",
63
+ "!pip install -q \\\n",
64
+ " \"langgraph>=0.0.15\" \"langchain>=0.3.0\" langchain-text-splitters langchain-community \\\n",
65
+ " \"chromadb>=0.4.0\" sentence-transformers \\\n",
66
+ " \"transformers>=4.50.0\" accelerate bitsandbytes \\\n",
67
+ " streamlit huggingface_hub \\\n",
68
+ " \"pydantic>=2.0\" python-dotenv openpyxl pypdf \"pandas>=2.0\" jq"
69
  ]
70
  },
71
  {
72
  "cell_type": "markdown",
73
  "metadata": {},
74
  "source": [
75
+ "## 2 · Hugging Face Authentication\n",
76
  "\n",
77
+ "Add your token to **Kaggle Add-ons Secrets** as `HF_TOKEN`.\n",
78
  "\n",
79
+ "Accept model licences before running:\n",
80
+ "- https://huggingface.co/google/gemma-2-2b-it"
 
81
  ]
82
  },
83
  {
 
86
  "metadata": {},
87
  "outputs": [],
88
  "source": [
89
+ "import os\n",
90
  "from huggingface_hub import login\n",
91
  "\n",
 
92
  "try:\n",
93
  " from kaggle_secrets import UserSecretsClient\n",
94
+ " hf_token = UserSecretsClient().get_secret(\"HF_TOKEN\")\n",
95
+ " print(\"Token loaded from Kaggle secrets\")\n",
96
+ "except Exception:\n",
97
+ " hf_token = os.getenv(\"HF_TOKEN\", \"\")\n",
98
+ " print(\"Token loaded from environment\" if hf_token else \"WARNING: No HF_TOKEN found\")\n",
 
 
 
 
 
99
  "\n",
100
+ "if hf_token:\n",
101
+ " login(token=hf_token, add_to_git_credential=False)"
102
  ]
103
  },
104
  {
105
  "cell_type": "markdown",
106
  "metadata": {},
107
  "source": [
108
+ "## 3 · Download Models"
109
  ]
110
  },
111
  {
 
114
  "metadata": {},
115
  "outputs": [],
116
  "source": [
117
+ "from huggingface_hub import snapshot_download\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  "\n",
119
+ "# Single model used for all agents in the demo\n",
120
+ "MODEL_ID = \"google/gemma-2-2b-it\"\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  "\n",
122
+ "print(f\"Downloading {MODEL_ID} \")\n",
123
+ "snapshot_download(\n",
124
+ " repo_id=MODEL_ID,\n",
125
+ " ignore_patterns=[\"*.gguf\", \"*.ot\"], # skip quantised formats we don't need\n",
 
126
  ")\n",
127
+ "print(\"Download complete\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  ]
129
  },
130
  {
 
133
  "metadata": {},
134
  "outputs": [],
135
  "source": [
136
+ "# Embedding model for RAG (small, fast)\n",
137
+ "from sentence_transformers import SentenceTransformer\n",
138
+ "SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")\n",
139
+ "print(\"Embedding model ready\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  ]
141
  },
142
  {
143
  "cell_type": "markdown",
144
  "metadata": {},
145
  "source": [
146
+ "## 4 · Configure & Initialise"
147
  ]
148
  },
149
  {
 
152
  "metadata": {},
153
  "outputs": [],
154
  "source": [
155
+ "# Write .env for the Kaggle environment\n",
156
+ "env_content = f\"\"\"\n",
157
+ "MEDIC_ENV=kaggle\n",
158
+ "MEDIC_DEFAULT_BACKEND=local\n",
159
+ "MEDIC_USE_VERTEX=false\n",
160
+ "MEDIC_QUANTIZATION=4bit\n",
 
 
 
 
 
 
 
 
 
 
161
  "\n",
162
+ "MEDIC_LOCAL_MEDGEMMA_4B_MODEL={MODEL_ID}\n",
163
+ "MEDIC_LOCAL_MEDGEMMA_27B_MODEL={MODEL_ID}\n",
164
+ "MEDIC_LOCAL_TXGEMMA_9B_MODEL={MODEL_ID}\n",
165
+ "MEDIC_LOCAL_TXGEMMA_2B_MODEL={MODEL_ID}\n",
 
 
 
 
 
 
 
 
 
166
  "\n",
167
+ "MEDIC_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2\n",
168
+ "MEDIC_DATA_DIR=/kaggle/working/Med-I-C/data\n",
169
+ "MEDIC_CHROMA_DB_DIR=/kaggle/working/Med-I-C/data/chroma_db\n",
170
+ "\"\"\".strip()\n",
171
  "\n",
172
+ "with open(\"/kaggle/working/Med-I-C/.env\", \"w\") as f:\n",
173
+ " f.write(env_content)\n",
174
  "\n",
175
+ "print(\".env written\")"
 
 
 
 
 
 
 
 
 
 
176
  ]
177
  },
178
  {
 
181
  "metadata": {},
182
  "outputs": [],
183
  "source": [
184
+ "import sys\n",
185
+ "sys.path.insert(0, \"/kaggle/working/Med-I-C\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  "\n",
187
+ "# Initialise SQLite + ChromaDB knowledge base\n",
188
+ "!python /kaggle/working/Med-I-C/setup_demo.py"
 
 
 
189
  ]
190
  },
191
  {
192
  "cell_type": "markdown",
193
  "metadata": {},
194
  "source": [
195
+ "## 5 · Launch the App"
196
  ]
197
  },
198
  {
 
201
  "metadata": {},
202
  "outputs": [],
203
  "source": [
204
+ "%%capture\n",
205
+ "!pip install -q localtunnel"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  ]
207
  },
208
  {
 
211
  "metadata": {},
212
  "outputs": [],
213
  "source": [
214
+ "import subprocess, threading, time, requests\n",
 
 
 
215
  "\n",
216
+ "# Start Streamlit in the background\n",
217
+ "streamlit_proc = subprocess.Popen(\n",
218
+ " [\"streamlit\", \"run\", \"/kaggle/working/Med-I-C/app.py\",\n",
219
+ " \"--server.port\", \"8501\",\n",
220
+ " \"--server.headless\", \"true\",\n",
221
+ " \"--server.enableCORS\", \"false\"],\n",
222
+ " stdout=subprocess.DEVNULL,\n",
223
+ " stderr=subprocess.DEVNULL,\n",
224
+ ")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  "\n",
226
+ "# Wait for Streamlit to be ready\n",
227
+ "for _ in range(15):\n",
228
+ " try:\n",
229
+ " if requests.get(\"http://localhost:8501\", timeout=2).status_code == 200:\n",
230
+ " print(\"Streamlit is running on port 8501\")\n",
231
+ " break\n",
232
+ " except Exception:\n",
233
+ " time.sleep(2)\n",
234
+ "else:\n",
235
+ " print(\"Streamlit may still be starting…\")"
236
  ]
237
  },
238
  {
 
241
  "metadata": {},
242
  "outputs": [],
243
  "source": [
244
+ "# Expose via localtunnel the public URL will appear below\n",
245
+ "tunnel_proc = subprocess.Popen(\n",
246
+ " [\"npx\", \"localtunnel\", \"--port\", \"8501\"],\n",
247
+ " stdout=subprocess.PIPE,\n",
248
+ " stderr=subprocess.DEVNULL,\n",
249
+ " text=True,\n",
250
+ ")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  "\n",
252
+ "# Print the public URL\n",
253
+ "for line in tunnel_proc.stdout:\n",
254
+ " if \"https://\" in line:\n",
255
+ " print(\"\\n\" + \"=\"*50)\n",
256
+ " print(f\" App URL: {line.strip()}\")\n",
257
+ " print(\"=\"*50)\n",
258
+ " break"
 
 
 
 
 
 
 
 
 
 
259
  ]
260
  }
261
  ],
 
266
  "name": "python3"
267
  },
268
  "language_info": {
 
 
 
 
 
 
269
  "name": "python",
 
 
270
  "version": "3.10.0"
271
  }
272
  },
273
  "nbformat": 4,
274
+ "nbformat_minor": 5
275
  }
pyproject.toml CHANGED
@@ -8,8 +8,6 @@ dependencies = [
8
  "langgraph>=0.0.15",
9
  "langchain>=0.3.0",
10
  "langchain-text-splitters",
11
- "langchain-google-vertexai",
12
- "google-cloud-aiplatform",
13
  "chromadb>=0.4.0",
14
  "sentence-transformers",
15
  "transformers>=4.50.0",
@@ -26,4 +24,5 @@ dependencies = [
26
  "langchain-community>=0.4.1",
27
  "jq>=1.11.0",
28
  "pandas>=2.0.0",
 
29
  ]
 
8
  "langgraph>=0.0.15",
9
  "langchain>=0.3.0",
10
  "langchain-text-splitters",
 
 
11
  "chromadb>=0.4.0",
12
  "sentence-transformers",
13
  "transformers>=4.50.0",
 
24
  "langchain-community>=0.4.1",
25
  "jq>=1.11.0",
26
  "pandas>=2.0.0",
27
+ "huggingface-hub",
28
  ]
uv.lock CHANGED
The diff for this file is too large to render. See raw diff