ghitaben commited on
Commit
ba2715f
·
1 Parent(s): 18c0556

Enhance patient analysis form with dynamic site-specific fields and support for lab image uploads

Browse files
Files changed (7) hide show
  1. app.py +119 -6
  2. src/agents.py +40 -9
  3. src/form_config.py +327 -0
  4. src/graph.py +7 -2
  5. src/loader.py +49 -3
  6. src/prompts.py +6 -0
  7. src/state.py +2 -1
app.py CHANGED
@@ -12,6 +12,7 @@ import streamlit as st
12
  PROJECT_ROOT = Path(__file__).parent
13
  sys.path.insert(0, str(PROJECT_ROOT))
14
 
 
15
  from src.tools import (
16
  calculate_mic_trend,
17
  get_empirical_therapy_guidance,
@@ -380,13 +381,75 @@ def page_patient_analysis():
380
  height = st.number_input("Height (cm)", 50.0, 250.0, 170.0, step=0.5)
381
  with c2:
382
  sex = st.selectbox("Biological sex", ["male", "female"])
383
- creatinine = st.number_input("Serum Creatinine (mg/dL)", 0.1, 20.0, 1.2, step=0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  with c3:
385
  infection_site = st.selectbox(
386
  "Primary infection site",
387
  ["urinary", "respiratory", "bloodstream", "skin", "intra-abdominal", "CNS", "other"],
 
388
  )
389
- suspected_source = st.text_input("Suspected source", placeholder="e.g., community-acquired UTI")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
  with st.expander("Medical History"):
392
  c1, c2 = st.columns(2)
@@ -404,9 +467,56 @@ def page_patient_analysis():
404
  )
405
 
406
  with st.expander("Lab / Culture Results (optional — triggers targeted pathway)"):
407
- method = st.radio("Input method", ["None — empirical pathway only", "Paste lab text"], horizontal=True)
 
 
 
 
408
  labs_raw_text = None
409
- if method == "Paste lab text":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
  labs_raw_text = st.text_area(
411
  "Lab report",
412
  placeholder=(
@@ -422,6 +532,7 @@ def page_patient_analysis():
422
  run_btn = st.button("Run Agent Pipeline", type="primary", use_container_width=False)
423
 
424
  if run_btn:
 
425
  patient_data = {
426
  "age_years": age,
427
  "weight_kg": weight,
@@ -433,11 +544,13 @@ def page_patient_analysis():
433
  "medications": [m.strip() for m in medications.split("\n") if m.strip()],
434
  "allergies": [a.strip() for a in allergies.split("\n") if a.strip()],
435
  "comorbidities": list(comorbidities) + list(risk_factors),
 
 
436
  }
437
 
438
  stages = (
439
  ["Intake Historian", "Vision Specialist", "Trend Analyst", "Clinical Pharmacologist"]
440
- if labs_raw_text
441
  else ["Intake Historian", "Clinical Pharmacologist"]
442
  )
443
 
@@ -449,7 +562,7 @@ def page_patient_analysis():
449
  from src.graph import run_pipeline
450
  result = run_pipeline(patient_data, labs_raw_text)
451
  except Exception:
452
- result = _demo_result(patient_data, labs_raw_text)
453
 
454
  prog.progress(100, text="Complete")
455
  st.session_state.pipeline_result = result
 
12
  PROJECT_ROOT = Path(__file__).parent
13
  sys.path.insert(0, str(PROJECT_ROOT))
14
 
15
+ from src.form_config import CREATININE_PROMINENT_SITES, SITE_SPECIFIC_FIELDS, SUSPECTED_SOURCE_OPTIONS
16
  from src.tools import (
17
  calculate_mic_trend,
18
  get_empirical_therapy_guidance,
 
381
  height = st.number_input("Height (cm)", 50.0, 250.0, 170.0, step=0.5)
382
  with c2:
383
  sex = st.selectbox("Biological sex", ["male", "female"])
384
+ # Infection site is needed to decide creatinine visibility, so render it first
385
+ # (Streamlit reruns top-to-bottom, but c3 renders in the same pass, so we
386
+ # read infection_site from session state on the *next* rerun. We default
387
+ # to the current widget value via a placeholder key.)
388
+ infection_site = st.session_state.get("_infection_site_val", "urinary")
389
+ if infection_site in CREATININE_PROMINENT_SITES:
390
+ creatinine = st.number_input("Serum Creatinine (mg/dL)", 0.1, 20.0, 1.2, step=0.1,
391
+ help="Required for CrCl-based dose adjustment")
392
+ else:
393
+ renal_flag = st.checkbox("Known renal impairment / CKD?",
394
+ help="Check to enter serum creatinine for dose adjustment")
395
+ creatinine = (
396
+ st.number_input("Serum Creatinine (mg/dL)", 0.1, 20.0, 1.2, step=0.1)
397
+ if renal_flag else None
398
+ )
399
  with c3:
400
  infection_site = st.selectbox(
401
  "Primary infection site",
402
  ["urinary", "respiratory", "bloodstream", "skin", "intra-abdominal", "CNS", "other"],
403
+ key="_infection_site_val",
404
  )
405
+ source_options = SUSPECTED_SOURCE_OPTIONS.get(infection_site, [])
406
+ if source_options:
407
+ suspected_source = st.selectbox("Suspected source", source_options)
408
+ if suspected_source == "Other":
409
+ suspected_source = st.text_input(
410
+ "Specify source", placeholder="Describe the suspected source"
411
+ )
412
+ else:
413
+ suspected_source = st.text_input(
414
+ "Suspected source", placeholder="e.g., community-acquired infection"
415
+ )
416
+
417
+ # ── Site-specific assessment (dynamic per infection site) ──
418
+ site_vitals: dict[str, str] = {}
419
+ site_fields = SITE_SPECIFIC_FIELDS.get(infection_site, [])
420
+ if site_fields:
421
+ with st.expander(f"Site-Specific Assessment — {infection_site.title()}", expanded=True):
422
+ cols = st.columns(2)
423
+ for i, field in enumerate(site_fields):
424
+ col = cols[i % 2]
425
+ with col:
426
+ fkey = f"site_{field['key']}"
427
+ ftype = field["type"]
428
+ if ftype == "selectbox":
429
+ val = st.selectbox(field["label"], field["options"], key=fkey)
430
+ elif ftype == "multiselect":
431
+ val = st.multiselect(field["label"], field["options"], key=fkey)
432
+ val = ", ".join(val) if val else ""
433
+ elif ftype == "number_input":
434
+ val = st.number_input(
435
+ field["label"],
436
+ min_value=field.get("min", 0.0),
437
+ max_value=field.get("max", 999.0),
438
+ value=field.get("default", 0.0),
439
+ step=field.get("step", 1.0),
440
+ key=fkey,
441
+ )
442
+ val = str(val)
443
+ elif ftype == "checkbox":
444
+ val = st.checkbox(
445
+ field["label"], value=field.get("default", False), key=fkey
446
+ )
447
+ val = "Yes" if val else "No"
448
+ elif ftype == "text_input":
449
+ val = st.text_input(field["label"], key=fkey)
450
+ else:
451
+ continue
452
+ site_vitals[field["key"]] = str(val)
453
 
454
  with st.expander("Medical History"):
455
  c1, c2 = st.columns(2)
 
467
  )
468
 
469
  with st.expander("Lab / Culture Results (optional — triggers targeted pathway)"):
470
+ method = st.radio(
471
+ "Input method",
472
+ ["None — empirical pathway only", "Upload file (PDF / image)", "Paste lab text"],
473
+ horizontal=True,
474
+ )
475
  labs_raw_text = None
476
+ labs_image_bytes = None
477
+
478
+ if method == "Upload file (PDF / image)":
479
+ uploaded = st.file_uploader(
480
+ "Lab report file",
481
+ type=["pdf", "png", "jpg", "jpeg", "tiff", "tif", "bmp"],
482
+ help="Upload a culture & sensitivity report, antibiogram, or any lab document.",
483
+ )
484
+ if uploaded is not None:
485
+ file_bytes = uploaded.read()
486
+ ext = uploaded.name.rsplit(".", 1)[-1].lower()
487
+ if ext == "pdf":
488
+ # Extract text from PDF using pypdf
489
+ import pypdf
490
+ from io import BytesIO
491
+ try:
492
+ reader = pypdf.PdfReader(BytesIO(file_bytes))
493
+ extracted = "\n".join(
494
+ page.extract_text() or "" for page in reader.pages
495
+ ).strip()
496
+ if extracted:
497
+ labs_raw_text = extracted
498
+ st.success(f"PDF parsed — {len(reader.pages)} page(s), {len(extracted)} characters extracted.")
499
+ else:
500
+ st.warning(
501
+ "PDF text extraction returned empty content (scanned PDF?). "
502
+ "The file will be processed as an image by the vision model."
503
+ )
504
+ # Convert first page to image fallback via pillow (requires pypdf extras)
505
+ labs_image_bytes = file_bytes
506
+ except Exception as e:
507
+ st.error(f"PDF parsing failed: {e}")
508
+ else:
509
+ # Image file — pass directly to the multimodal model
510
+ labs_image_bytes = file_bytes
511
+ from PIL import Image as _PILImage
512
+ from io import BytesIO as _BytesIO
513
+ try:
514
+ thumb = _PILImage.open(_BytesIO(file_bytes))
515
+ st.image(thumb, caption=f"Uploaded: {uploaded.name}", width=320)
516
+ except Exception:
517
+ st.info(f"Image uploaded: {uploaded.name}")
518
+
519
+ elif method == "Paste lab text":
520
  labs_raw_text = st.text_area(
521
  "Lab report",
522
  placeholder=(
 
532
  run_btn = st.button("Run Agent Pipeline", type="primary", use_container_width=False)
533
 
534
  if run_btn:
535
+ has_lab_input = bool(labs_raw_text or labs_image_bytes)
536
  patient_data = {
537
  "age_years": age,
538
  "weight_kg": weight,
 
544
  "medications": [m.strip() for m in medications.split("\n") if m.strip()],
545
  "allergies": [a.strip() for a in allergies.split("\n") if a.strip()],
546
  "comorbidities": list(comorbidities) + list(risk_factors),
547
+ "vitals": site_vitals,
548
+ "labs_image_bytes": labs_image_bytes,
549
  }
550
 
551
  stages = (
552
  ["Intake Historian", "Vision Specialist", "Trend Analyst", "Clinical Pharmacologist"]
553
+ if has_lab_input
554
  else ["Intake Historian", "Clinical Pharmacologist"]
555
  )
556
 
 
562
  from src.graph import run_pipeline
563
  result = run_pipeline(patient_data, labs_raw_text)
564
  except Exception:
565
+ result = _demo_result(patient_data, labs_raw_text or bool(labs_image_bytes))
566
 
567
  prog.progress(100, text="Complete")
568
  st.session_state.pipeline_result = result
src/agents.py CHANGED
@@ -12,7 +12,7 @@ import logging
12
  from typing import Optional
13
 
14
  from .config import get_settings
15
- from .loader import run_inference, TextModelName
16
  from .prompts import (
17
  INTAKE_HISTORIAN_SYSTEM,
18
  INTAKE_HISTORIAN_PROMPT,
@@ -66,12 +66,17 @@ def run_intake_historian(state: InfectionState) -> InfectionState:
66
  patient_context={"pathogen_type": state.get("suspected_source")},
67
  )
68
 
 
 
 
 
69
  prompt = f"{INTAKE_HISTORIAN_SYSTEM}\n\n{INTAKE_HISTORIAN_PROMPT.format(
70
  patient_data=patient_data,
71
  medications=', '.join(state.get('medications', [])) or 'None reported',
72
  allergies=', '.join(state.get('allergies', [])) or 'No known allergies',
73
  infection_site=state.get('infection_site', 'Unknown'),
74
  suspected_source=state.get('suspected_source', 'Unknown'),
 
75
  rag_context=rag_context,
76
  )}"
77
 
@@ -105,14 +110,24 @@ def run_vision_specialist(state: InfectionState) -> InfectionState:
105
  logger.info("Running Vision Specialist agent...")
106
 
107
  labs_raw = state.get("labs_raw_text", "")
108
- if not labs_raw:
 
 
109
  logger.info("No lab data to process, skipping Vision Specialist")
110
  state["vision_notes"] = "No lab data provided"
111
  state["route_to_trend_analyst"] = False
112
  return state
113
 
114
- # Language detection is not implemented; we assume English or instruct the model to translate
115
- language = "English (assumed)"
 
 
 
 
 
 
 
 
116
  rag_context = get_context_for_agent(
117
  agent_name="vision_specialist",
118
  query="culture sensitivity susceptibility interpretation",
@@ -120,13 +135,22 @@ def run_vision_specialist(state: InfectionState) -> InfectionState:
120
  )
121
 
122
  prompt = f"{VISION_SPECIALIST_SYSTEM}\n\n{VISION_SPECIALIST_PROMPT.format(
123
- report_content=labs_raw,
124
- source_format='text',
125
  language=language,
126
  )}"
127
 
128
  try:
129
- response = run_inference(prompt=prompt, model_name="medgemma_4b", max_new_tokens=2048, temperature=0.1)
 
 
 
 
 
 
 
 
 
130
  parsed = safe_json_parse(response)
131
  if parsed:
132
  state["vision_notes"] = json.dumps(parsed, indent=2)
@@ -256,6 +280,10 @@ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
256
  patient_context={"proposed_antibiotic": None},
257
  )
258
 
 
 
 
 
259
  prompt = f"{CLINICAL_PHARMACOLOGIST_SYSTEM}\n\n{CLINICAL_PHARMACOLOGIST_PROMPT.format(
260
  intake_summary=intake_summary,
261
  lab_results=lab_results,
@@ -268,6 +296,7 @@ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
268
  infection_site=state.get('infection_site', 'Unknown'),
269
  suspected_source=state.get('suspected_source', 'Unknown'),
270
  severity=state.get('intake_notes', {}).get('infection_severity', 'Unknown') if isinstance(state.get('intake_notes'), dict) else 'Unknown',
 
271
  rag_context=rag_context,
272
  )}"
273
 
@@ -355,8 +384,10 @@ def _format_patient_data(state: InfectionState) -> str:
355
  lines.append(f"Comorbidities: {', '.join(state['comorbidities'])}")
356
 
357
  if state.get("vitals"):
358
- vitals_str = ", ".join(f"{k}: {v}" for k, v in state["vitals"].items())
359
- lines.append(f"Vitals: {vitals_str}")
 
 
360
 
361
  return "\n".join(lines) if lines else "No patient data available"
362
 
 
12
  from typing import Optional
13
 
14
  from .config import get_settings
15
+ from .loader import run_inference, run_inference_with_image, TextModelName
16
  from .prompts import (
17
  INTAKE_HISTORIAN_SYSTEM,
18
  INTAKE_HISTORIAN_PROMPT,
 
66
  patient_context={"pathogen_type": state.get("suspected_source")},
67
  )
68
 
69
+ site_vitals_str = "\n".join(
70
+ f"- {k.replace('_', ' ').title()}: {v}" for k, v in state.get("vitals", {}).items()
71
+ ) or "None provided"
72
+
73
  prompt = f"{INTAKE_HISTORIAN_SYSTEM}\n\n{INTAKE_HISTORIAN_PROMPT.format(
74
  patient_data=patient_data,
75
  medications=', '.join(state.get('medications', [])) or 'None reported',
76
  allergies=', '.join(state.get('allergies', [])) or 'No known allergies',
77
  infection_site=state.get('infection_site', 'Unknown'),
78
  suspected_source=state.get('suspected_source', 'Unknown'),
79
+ site_vitals=site_vitals_str,
80
  rag_context=rag_context,
81
  )}"
82
 
 
110
  logger.info("Running Vision Specialist agent...")
111
 
112
  labs_raw = state.get("labs_raw_text", "")
113
+ labs_image_bytes = state.get("labs_image_bytes")
114
+
115
+ if not labs_raw and not labs_image_bytes:
116
  logger.info("No lab data to process, skipping Vision Specialist")
117
  state["vision_notes"] = "No lab data provided"
118
  state["route_to_trend_analyst"] = False
119
  return state
120
 
121
+ # Determine input modality and prepare prompt content description
122
+ if labs_image_bytes:
123
+ source_format = "image"
124
+ language = "Auto-detected"
125
+ report_content = "See attached image — extract all lab data visible in the image."
126
+ else:
127
+ source_format = "text"
128
+ language = "English (assumed)"
129
+ report_content = labs_raw
130
+
131
  rag_context = get_context_for_agent(
132
  agent_name="vision_specialist",
133
  query="culture sensitivity susceptibility interpretation",
 
135
  )
136
 
137
  prompt = f"{VISION_SPECIALIST_SYSTEM}\n\n{VISION_SPECIALIST_PROMPT.format(
138
+ report_content=report_content,
139
+ source_format=source_format,
140
  language=language,
141
  )}"
142
 
143
  try:
144
+ if labs_image_bytes:
145
+ from io import BytesIO
146
+ from PIL import Image as PILImage
147
+ image = PILImage.open(BytesIO(labs_image_bytes)).convert("RGB")
148
+ logger.info(f"Running vision inference on uploaded image ({image.size})")
149
+ response = run_inference_with_image(
150
+ prompt=prompt, image=image, model_name="medgemma_4b", max_new_tokens=2048, temperature=0.1
151
+ )
152
+ else:
153
+ response = run_inference(prompt=prompt, model_name="medgemma_4b", max_new_tokens=2048, temperature=0.1)
154
  parsed = safe_json_parse(response)
155
  if parsed:
156
  state["vision_notes"] = json.dumps(parsed, indent=2)
 
280
  patient_context={"proposed_antibiotic": None},
281
  )
282
 
283
+ site_vitals_str = "\n".join(
284
+ f"- {k.replace('_', ' ').title()}: {v}" for k, v in state.get("vitals", {}).items()
285
+ ) or "None provided"
286
+
287
  prompt = f"{CLINICAL_PHARMACOLOGIST_SYSTEM}\n\n{CLINICAL_PHARMACOLOGIST_PROMPT.format(
288
  intake_summary=intake_summary,
289
  lab_results=lab_results,
 
296
  infection_site=state.get('infection_site', 'Unknown'),
297
  suspected_source=state.get('suspected_source', 'Unknown'),
298
  severity=state.get('intake_notes', {}).get('infection_severity', 'Unknown') if isinstance(state.get('intake_notes'), dict) else 'Unknown',
299
+ site_vitals=site_vitals_str,
300
  rag_context=rag_context,
301
  )}"
302
 
 
384
  lines.append(f"Comorbidities: {', '.join(state['comorbidities'])}")
385
 
386
  if state.get("vitals"):
387
+ lines.append("Site-Specific Assessment:")
388
+ for k, v in state["vitals"].items():
389
+ label = k.replace("_", " ").title()
390
+ lines.append(f" - {label}: {v}")
391
 
392
  return "\n".join(lines) if lines else "No patient data available"
393
 
src/form_config.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Declarative field definitions for the dynamic Patient Analysis form.
3
+
4
+ Each infection site maps to a list of site-specific fields and contextual
5
+ suspected-source options. Universal fields (age, sex, weight, height,
6
+ creatinine, medications, allergies, comorbidities, risk factors) are always
7
+ shown and are NOT listed here.
8
+ """
9
+
10
+ SITE_SPECIFIC_FIELDS: dict[str, list[dict]] = {
11
+ "urinary": [
12
+ {
13
+ "key": "catheter_status",
14
+ "label": "Catheter status",
15
+ "type": "selectbox",
16
+ "options": [
17
+ "No catheter",
18
+ "Indwelling (Foley)",
19
+ "Intermittent",
20
+ "Suprapubic",
21
+ "Recently removed (<48 h)",
22
+ ],
23
+ },
24
+ {
25
+ "key": "urinary_symptoms",
26
+ "label": "Urinary symptoms",
27
+ "type": "multiselect",
28
+ "options": [
29
+ "Dysuria",
30
+ "Frequency",
31
+ "Urgency",
32
+ "Hematuria",
33
+ "Suprapubic pain",
34
+ "Flank pain",
35
+ "Fever / chills",
36
+ ],
37
+ },
38
+ {
39
+ "key": "urine_appearance",
40
+ "label": "Urine appearance",
41
+ "type": "selectbox",
42
+ "options": ["Clear", "Cloudy", "Turbid", "Malodorous", "Hematuria"],
43
+ },
44
+ ],
45
+ "respiratory": [
46
+ {
47
+ "key": "o2_saturation",
48
+ "label": "O\u2082 Saturation (%)",
49
+ "type": "number_input",
50
+ "min": 50.0,
51
+ "max": 100.0,
52
+ "default": 97.0,
53
+ "step": 0.5,
54
+ },
55
+ {
56
+ "key": "ventilation_status",
57
+ "label": "Ventilation status",
58
+ "type": "selectbox",
59
+ "options": [
60
+ "Room air",
61
+ "Supplemental O\u2082 (nasal cannula)",
62
+ "Supplemental O\u2082 (mask)",
63
+ "Non-invasive (BiPAP / CPAP)",
64
+ "Mechanical ventilation",
65
+ ],
66
+ },
67
+ {
68
+ "key": "cough_type",
69
+ "label": "Cough type",
70
+ "type": "selectbox",
71
+ "options": ["None", "Dry", "Productive", "Hemoptysis"],
72
+ },
73
+ {
74
+ "key": "sputum_character",
75
+ "label": "Sputum character",
76
+ "type": "selectbox",
77
+ "options": [
78
+ "None",
79
+ "Clear / white",
80
+ "Yellow",
81
+ "Green / purulent",
82
+ "Rust-colored",
83
+ "Blood-tinged",
84
+ ],
85
+ },
86
+ ],
87
+ "bloodstream": [
88
+ {
89
+ "key": "central_line_present",
90
+ "label": "Central line present",
91
+ "type": "checkbox",
92
+ "default": False,
93
+ },
94
+ {
95
+ "key": "temperature_c",
96
+ "label": "Temperature (\u00b0C)",
97
+ "type": "number_input",
98
+ "min": 34.0,
99
+ "max": 43.0,
100
+ "default": 38.5,
101
+ "step": 0.1,
102
+ },
103
+ {
104
+ "key": "heart_rate_bpm",
105
+ "label": "Heart rate (bpm)",
106
+ "type": "number_input",
107
+ "min": 30,
108
+ "max": 250,
109
+ "default": 90,
110
+ "step": 1,
111
+ },
112
+ {
113
+ "key": "respiratory_rate",
114
+ "label": "Respiratory rate (/min)",
115
+ "type": "number_input",
116
+ "min": 5,
117
+ "max": 60,
118
+ "default": 18,
119
+ "step": 1,
120
+ },
121
+ {
122
+ "key": "wbc_count",
123
+ "label": "WBC count (\u00d710\u2079/L)",
124
+ "type": "number_input",
125
+ "min": 0.0,
126
+ "max": 100.0,
127
+ "default": 12.0,
128
+ "step": 0.1,
129
+ },
130
+ {
131
+ "key": "lactate_mmol",
132
+ "label": "Lactate (mmol/L)",
133
+ "type": "number_input",
134
+ "min": 0.0,
135
+ "max": 30.0,
136
+ "default": 1.0,
137
+ "step": 0.1,
138
+ },
139
+ {
140
+ "key": "shock_status",
141
+ "label": "Shock status",
142
+ "type": "selectbox",
143
+ "options": [
144
+ "No shock",
145
+ "Compensated (SBP > 90, tachycardia)",
146
+ "Septic shock (vasopressors required)",
147
+ ],
148
+ },
149
+ ],
150
+ "skin": [
151
+ {
152
+ "key": "wound_type",
153
+ "label": "Wound type",
154
+ "type": "selectbox",
155
+ "options": [
156
+ "Laceration",
157
+ "Ulcer (diabetic / pressure)",
158
+ "Bite (animal / human)",
159
+ "Surgical site",
160
+ "Burn",
161
+ "Abscess",
162
+ "Cellulitis (no wound)",
163
+ ],
164
+ },
165
+ {
166
+ "key": "cellulitis_extent",
167
+ "label": "Cellulitis extent",
168
+ "type": "selectbox",
169
+ "options": [
170
+ "None",
171
+ "Localized (< 5 cm)",
172
+ "Moderate (5\u201310 cm)",
173
+ "Extensive (> 10 cm)",
174
+ "Rapidly spreading",
175
+ ],
176
+ },
177
+ {
178
+ "key": "abscess_present",
179
+ "label": "Abscess present",
180
+ "type": "checkbox",
181
+ "default": False,
182
+ },
183
+ {
184
+ "key": "foreign_body",
185
+ "label": "Foreign body / implant",
186
+ "type": "checkbox",
187
+ "default": False,
188
+ },
189
+ ],
190
+ "intra-abdominal": [
191
+ {
192
+ "key": "abdominal_pain_location",
193
+ "label": "Pain location",
194
+ "type": "selectbox",
195
+ "options": [
196
+ "Diffuse",
197
+ "RUQ",
198
+ "LUQ",
199
+ "RLQ",
200
+ "LLQ",
201
+ "Epigastric",
202
+ "Periumbilical",
203
+ ],
204
+ },
205
+ {
206
+ "key": "peritonitis_signs",
207
+ "label": "Peritonitis signs",
208
+ "type": "multiselect",
209
+ "options": [
210
+ "Guarding",
211
+ "Rebound tenderness",
212
+ "Rigidity",
213
+ "Absent bowel sounds",
214
+ ],
215
+ },
216
+ {
217
+ "key": "perforation_suspected",
218
+ "label": "Perforation suspected",
219
+ "type": "checkbox",
220
+ "default": False,
221
+ },
222
+ {
223
+ "key": "ascites",
224
+ "label": "Ascites present",
225
+ "type": "checkbox",
226
+ "default": False,
227
+ },
228
+ ],
229
+ "CNS": [
230
+ {
231
+ "key": "csf_obtained",
232
+ "label": "CSF obtained",
233
+ "type": "checkbox",
234
+ "default": False,
235
+ },
236
+ {
237
+ "key": "neuro_symptoms",
238
+ "label": "Neurological symptoms",
239
+ "type": "multiselect",
240
+ "options": [
241
+ "Headache",
242
+ "Neck stiffness",
243
+ "Photophobia",
244
+ "Altered mental status",
245
+ "Seizures",
246
+ "Focal deficits",
247
+ ],
248
+ },
249
+ {
250
+ "key": "recent_neurosurgery",
251
+ "label": "Recent neurosurgery",
252
+ "type": "checkbox",
253
+ "default": False,
254
+ },
255
+ {
256
+ "key": "gcs_score",
257
+ "label": "GCS score",
258
+ "type": "number_input",
259
+ "min": 3,
260
+ "max": 15,
261
+ "default": 15,
262
+ "step": 1,
263
+ },
264
+ ],
265
+ "other": [],
266
+ }
267
+
268
+
269
+ # Sites where serum creatinine is shown prominently in demographics.
270
+ # For all other sites a "renal impairment?" toggle is shown instead.
271
+ CREATININE_PROMINENT_SITES: frozenset[str] = frozenset(
272
+ {"urinary", "bloodstream", "CNS", "respiratory"}
273
+ )
274
+
275
+ SUSPECTED_SOURCE_OPTIONS: dict[str, list[str]] = {
276
+ "urinary": [
277
+ "Community-acquired UTI",
278
+ "Catheter-associated UTI (CAUTI)",
279
+ "Complicated UTI",
280
+ "Pyelonephritis",
281
+ "Urosepsis",
282
+ "Other",
283
+ ],
284
+ "respiratory": [
285
+ "Community-acquired pneumonia (CAP)",
286
+ "Hospital-acquired pneumonia (HAP)",
287
+ "Ventilator-associated pneumonia (VAP)",
288
+ "Aspiration pneumonia",
289
+ "Lung abscess",
290
+ "Empyema",
291
+ "Other",
292
+ ],
293
+ "bloodstream": [
294
+ "Primary bacteremia",
295
+ "Catheter-related BSI (CRBSI)",
296
+ "Secondary bacteremia (from known source)",
297
+ "Endocarditis",
298
+ "Unknown source",
299
+ "Other",
300
+ ],
301
+ "skin": [
302
+ "Cellulitis",
303
+ "Surgical site infection",
304
+ "Diabetic foot infection",
305
+ "Bite wound infection",
306
+ "Necrotizing fasciitis",
307
+ "Abscess",
308
+ "Other",
309
+ ],
310
+ "intra-abdominal": [
311
+ "Appendicitis",
312
+ "Cholecystitis / cholangitis",
313
+ "Diverticulitis",
314
+ "Peritonitis (SBP)",
315
+ "Post-surgical",
316
+ "Liver abscess",
317
+ "Other",
318
+ ],
319
+ "CNS": [
320
+ "Community-acquired meningitis",
321
+ "Post-neurosurgical meningitis",
322
+ "Healthcare-associated ventriculitis",
323
+ "Brain abscess",
324
+ "Other",
325
+ ],
326
+ "other": [],
327
+ }
src/graph.py CHANGED
@@ -77,6 +77,9 @@ def run_pipeline(patient_data: dict, labs_raw_text: str | None = None) -> Infect
77
  Pass labs_raw_text to trigger the targeted (Stage 2) pathway.
78
  Without it, only the empirical (Stage 1) pathway runs.
79
  """
 
 
 
80
  initial_state: InfectionState = {
81
  "age_years": patient_data.get("age_years"),
82
  "weight_kg": patient_data.get("weight_kg"),
@@ -90,15 +93,17 @@ def run_pipeline(patient_data: dict, labs_raw_text: str | None = None) -> Infect
90
  "suspected_source": patient_data.get("suspected_source"),
91
  "country_or_region": patient_data.get("country_or_region"),
92
  "vitals": patient_data.get("vitals", {}),
93
- "stage": "targeted" if labs_raw_text else "empirical",
94
  "errors": [],
95
  "safety_warnings": [],
96
  }
97
 
98
  if labs_raw_text:
99
  initial_state["labs_raw_text"] = labs_raw_text
 
 
100
 
101
- logger.info(f"Starting pipeline (stage: {initial_state['stage']})")
102
  logger.info(f"Patient: {patient_data.get('age_years')}y, {patient_data.get('sex')}, infection: {patient_data.get('infection_site')}")
103
 
104
  try:
 
77
  Pass labs_raw_text to trigger the targeted (Stage 2) pathway.
78
  Without it, only the empirical (Stage 1) pathway runs.
79
  """
80
+ labs_image_bytes: bytes | None = patient_data.get("labs_image_bytes")
81
+ has_lab_input = bool(labs_raw_text or labs_image_bytes)
82
+
83
  initial_state: InfectionState = {
84
  "age_years": patient_data.get("age_years"),
85
  "weight_kg": patient_data.get("weight_kg"),
 
93
  "suspected_source": patient_data.get("suspected_source"),
94
  "country_or_region": patient_data.get("country_or_region"),
95
  "vitals": patient_data.get("vitals", {}),
96
+ "stage": "targeted" if has_lab_input else "empirical",
97
  "errors": [],
98
  "safety_warnings": [],
99
  }
100
 
101
  if labs_raw_text:
102
  initial_state["labs_raw_text"] = labs_raw_text
103
+ if labs_image_bytes:
104
+ initial_state["labs_image_bytes"] = labs_image_bytes
105
 
106
+ logger.info(f"Starting pipeline (stage: {initial_state['stage']}, lab_text={bool(labs_raw_text)}, lab_image={bool(labs_image_bytes)})")
107
  logger.info(f"Patient: {patient_data.get('age_years')}y, {patient_data.get('sex')}, infection: {patient_data.get('infection_site')}")
108
 
109
  try:
src/loader.py CHANGED
@@ -58,9 +58,20 @@ def _get_local_multimodal(model_name: TextModelName):
58
  model = AutoModelForImageTextToText.from_pretrained(model_path, **load_kwargs)
59
  logger.info(f"Model loaded successfully: {model_path}")
60
 
61
- def _call(prompt: str, max_new_tokens: int = 512, temperature: float = 0.2, **generate_kwargs: Any) -> str:
62
- # Build a chat-style input for text-only queries
63
- messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
 
 
 
 
 
 
 
 
 
 
 
64
  inputs = processor.apply_chat_template(
65
  messages, add_generation_prompt=True, tokenize=True,
66
  return_dict=True, return_tensors="pt",
@@ -155,3 +166,38 @@ def run_inference(
155
  except Exception as e:
156
  logger.error(f"Inference failed for {model_name}: {e}", exc_info=True)
157
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  model = AutoModelForImageTextToText.from_pretrained(model_path, **load_kwargs)
59
  logger.info(f"Model loaded successfully: {model_path}")
60
 
61
+ def _call(
62
+ prompt: str,
63
+ max_new_tokens: int = 512,
64
+ temperature: float = 0.2,
65
+ image=None, # optional PIL.Image.Image for vision-language inference
66
+ **generate_kwargs: Any,
67
+ ) -> str:
68
+ # Build chat content; prepend image token when an image is provided
69
+ content = []
70
+ if image is not None:
71
+ content.append({"type": "image", "image": image})
72
+ content.append({"type": "text", "text": prompt})
73
+ messages = [{"role": "user", "content": content}]
74
+
75
  inputs = processor.apply_chat_template(
76
  messages, add_generation_prompt=True, tokenize=True,
77
  return_dict=True, return_tensors="pt",
 
166
  except Exception as e:
167
  logger.error(f"Inference failed for {model_name}: {e}", exc_info=True)
168
  raise
169
+
170
+
171
+ def run_inference_with_image(
172
+ prompt: str,
173
+ image: Any, # PIL.Image.Image
174
+ model_name: TextModelName = "medgemma_4b",
175
+ max_new_tokens: int = 1024,
176
+ temperature: float = 0.1,
177
+ **kwargs: Any,
178
+ ) -> str:
179
+ """
180
+ Run vision-language inference passing a PIL image alongside the text prompt.
181
+
182
+ Falls back to text-only inference if the resolved model is not multimodal
183
+ (e.g. when medgemma_4b is remapped to a text-only model in the env config).
184
+ """
185
+ logger.info(f"Running vision inference with {model_name}, max_tokens={max_new_tokens}")
186
+ try:
187
+ model_path = _get_model_path(model_name)
188
+ if not _is_multimodal(model_path):
189
+ logger.warning(
190
+ f"{model_name} ({model_path}) is not a multimodal model; "
191
+ "falling back to text-only inference."
192
+ )
193
+ return run_inference(prompt, model_name, max_new_tokens, temperature, **kwargs)
194
+
195
+ model_fn = _get_local_multimodal(model_name)
196
+ result = model_fn(
197
+ prompt, max_new_tokens=max_new_tokens, temperature=temperature, image=image, **kwargs
198
+ )
199
+ logger.info(f"Vision inference complete, response length: {len(result)} chars")
200
+ return result
201
+ except Exception as e:
202
+ logger.error(f"Vision inference failed for {model_name}: {e}", exc_info=True)
203
+ raise
src/prompts.py CHANGED
@@ -49,6 +49,9 @@ CLINICAL CONTEXT:
49
  - Suspected infection site: {infection_site}
50
  - Suspected source: {suspected_source}
51
 
 
 
 
52
  RAG CONTEXT (Relevant Guidelines):
53
  {rag_context}
54
 
@@ -267,6 +270,9 @@ INFECTION CONTEXT:
267
  - Source: {suspected_source}
268
  - Severity: {severity}
269
 
 
 
 
270
  RAG CONTEXT (Guidelines & Safety Data):
271
  {rag_context}
272
 
 
49
  - Suspected infection site: {infection_site}
50
  - Suspected source: {suspected_source}
51
 
52
+ SITE-SPECIFIC ASSESSMENT:
53
+ {site_vitals}
54
+
55
  RAG CONTEXT (Relevant Guidelines):
56
  {rag_context}
57
 
 
270
  - Source: {suspected_source}
271
  - Severity: {severity}
272
 
273
+ SITE-SPECIFIC ASSESSMENT:
274
+ {site_vitals}
275
+
276
  RAG CONTEXT (Guidelines & Safety Data):
277
  {rag_context}
278
 
src/state.py CHANGED
@@ -64,7 +64,8 @@ class InfectionState(TypedDict, total=False):
64
  vitals: NotRequired[Dict[str, str]] # flexible key/value, e.g. {"BP": "120/80"}
65
 
66
  # Lab data & MICs
67
- labs_raw_text: NotRequired[Optional[str]] # raw OCR or PDF text
 
68
  labs_parsed: NotRequired[List[LabResult]]
69
  mic_data: NotRequired[List[MICDatum]]
70
  mic_trend_summary: NotRequired[Optional[str]]
 
64
  vitals: NotRequired[Dict[str, str]] # flexible key/value, e.g. {"BP": "120/80"}
65
 
66
  # Lab data & MICs
67
+ labs_raw_text: NotRequired[Optional[str]] # raw OCR or pasted text
68
+ labs_image_bytes: NotRequired[Optional[bytes]] # uploaded image (PNG/JPG/TIFF) for vision model
69
  labs_parsed: NotRequired[List[LabResult]]
70
  mic_data: NotRequired[List[MICDatum]]
71
  mic_trend_summary: NotRequired[Optional[str]]