Enhance patient analysis form with dynamic site-specific fields and support for lab image uploads
Browse files- app.py +119 -6
- src/agents.py +40 -9
- src/form_config.py +327 -0
- src/graph.py +7 -2
- src/loader.py +49 -3
- src/prompts.py +6 -0
- src/state.py +2 -1
app.py
CHANGED
|
@@ -12,6 +12,7 @@ import streamlit as st
|
|
| 12 |
PROJECT_ROOT = Path(__file__).parent
|
| 13 |
sys.path.insert(0, str(PROJECT_ROOT))
|
| 14 |
|
|
|
|
| 15 |
from src.tools import (
|
| 16 |
calculate_mic_trend,
|
| 17 |
get_empirical_therapy_guidance,
|
|
@@ -380,13 +381,75 @@ def page_patient_analysis():
|
|
| 380 |
height = st.number_input("Height (cm)", 50.0, 250.0, 170.0, step=0.5)
|
| 381 |
with c2:
|
| 382 |
sex = st.selectbox("Biological sex", ["male", "female"])
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
with c3:
|
| 385 |
infection_site = st.selectbox(
|
| 386 |
"Primary infection site",
|
| 387 |
["urinary", "respiratory", "bloodstream", "skin", "intra-abdominal", "CNS", "other"],
|
|
|
|
| 388 |
)
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
| 391 |
with st.expander("Medical History"):
|
| 392 |
c1, c2 = st.columns(2)
|
|
@@ -404,9 +467,56 @@ def page_patient_analysis():
|
|
| 404 |
)
|
| 405 |
|
| 406 |
with st.expander("Lab / Culture Results (optional — triggers targeted pathway)"):
|
| 407 |
-
method = st.radio(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
labs_raw_text = None
|
| 409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
labs_raw_text = st.text_area(
|
| 411 |
"Lab report",
|
| 412 |
placeholder=(
|
|
@@ -422,6 +532,7 @@ def page_patient_analysis():
|
|
| 422 |
run_btn = st.button("Run Agent Pipeline", type="primary", use_container_width=False)
|
| 423 |
|
| 424 |
if run_btn:
|
|
|
|
| 425 |
patient_data = {
|
| 426 |
"age_years": age,
|
| 427 |
"weight_kg": weight,
|
|
@@ -433,11 +544,13 @@ def page_patient_analysis():
|
|
| 433 |
"medications": [m.strip() for m in medications.split("\n") if m.strip()],
|
| 434 |
"allergies": [a.strip() for a in allergies.split("\n") if a.strip()],
|
| 435 |
"comorbidities": list(comorbidities) + list(risk_factors),
|
|
|
|
|
|
|
| 436 |
}
|
| 437 |
|
| 438 |
stages = (
|
| 439 |
["Intake Historian", "Vision Specialist", "Trend Analyst", "Clinical Pharmacologist"]
|
| 440 |
-
if
|
| 441 |
else ["Intake Historian", "Clinical Pharmacologist"]
|
| 442 |
)
|
| 443 |
|
|
@@ -449,7 +562,7 @@ def page_patient_analysis():
|
|
| 449 |
from src.graph import run_pipeline
|
| 450 |
result = run_pipeline(patient_data, labs_raw_text)
|
| 451 |
except Exception:
|
| 452 |
-
result = _demo_result(patient_data, labs_raw_text)
|
| 453 |
|
| 454 |
prog.progress(100, text="Complete")
|
| 455 |
st.session_state.pipeline_result = result
|
|
|
|
| 12 |
PROJECT_ROOT = Path(__file__).parent
|
| 13 |
sys.path.insert(0, str(PROJECT_ROOT))
|
| 14 |
|
| 15 |
+
from src.form_config import CREATININE_PROMINENT_SITES, SITE_SPECIFIC_FIELDS, SUSPECTED_SOURCE_OPTIONS
|
| 16 |
from src.tools import (
|
| 17 |
calculate_mic_trend,
|
| 18 |
get_empirical_therapy_guidance,
|
|
|
|
| 381 |
height = st.number_input("Height (cm)", 50.0, 250.0, 170.0, step=0.5)
|
| 382 |
with c2:
|
| 383 |
sex = st.selectbox("Biological sex", ["male", "female"])
|
| 384 |
+
# Infection site is needed to decide creatinine visibility, so render it first
|
| 385 |
+
# (Streamlit reruns top-to-bottom, but c3 renders in the same pass, so we
|
| 386 |
+
# read infection_site from session state on the *next* rerun. We default
|
| 387 |
+
# to the current widget value via a placeholder key.)
|
| 388 |
+
infection_site = st.session_state.get("_infection_site_val", "urinary")
|
| 389 |
+
if infection_site in CREATININE_PROMINENT_SITES:
|
| 390 |
+
creatinine = st.number_input("Serum Creatinine (mg/dL)", 0.1, 20.0, 1.2, step=0.1,
|
| 391 |
+
help="Required for CrCl-based dose adjustment")
|
| 392 |
+
else:
|
| 393 |
+
renal_flag = st.checkbox("Known renal impairment / CKD?",
|
| 394 |
+
help="Check to enter serum creatinine for dose adjustment")
|
| 395 |
+
creatinine = (
|
| 396 |
+
st.number_input("Serum Creatinine (mg/dL)", 0.1, 20.0, 1.2, step=0.1)
|
| 397 |
+
if renal_flag else None
|
| 398 |
+
)
|
| 399 |
with c3:
|
| 400 |
infection_site = st.selectbox(
|
| 401 |
"Primary infection site",
|
| 402 |
["urinary", "respiratory", "bloodstream", "skin", "intra-abdominal", "CNS", "other"],
|
| 403 |
+
key="_infection_site_val",
|
| 404 |
)
|
| 405 |
+
source_options = SUSPECTED_SOURCE_OPTIONS.get(infection_site, [])
|
| 406 |
+
if source_options:
|
| 407 |
+
suspected_source = st.selectbox("Suspected source", source_options)
|
| 408 |
+
if suspected_source == "Other":
|
| 409 |
+
suspected_source = st.text_input(
|
| 410 |
+
"Specify source", placeholder="Describe the suspected source"
|
| 411 |
+
)
|
| 412 |
+
else:
|
| 413 |
+
suspected_source = st.text_input(
|
| 414 |
+
"Suspected source", placeholder="e.g., community-acquired infection"
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
# ── Site-specific assessment (dynamic per infection site) ──
|
| 418 |
+
site_vitals: dict[str, str] = {}
|
| 419 |
+
site_fields = SITE_SPECIFIC_FIELDS.get(infection_site, [])
|
| 420 |
+
if site_fields:
|
| 421 |
+
with st.expander(f"Site-Specific Assessment — {infection_site.title()}", expanded=True):
|
| 422 |
+
cols = st.columns(2)
|
| 423 |
+
for i, field in enumerate(site_fields):
|
| 424 |
+
col = cols[i % 2]
|
| 425 |
+
with col:
|
| 426 |
+
fkey = f"site_{field['key']}"
|
| 427 |
+
ftype = field["type"]
|
| 428 |
+
if ftype == "selectbox":
|
| 429 |
+
val = st.selectbox(field["label"], field["options"], key=fkey)
|
| 430 |
+
elif ftype == "multiselect":
|
| 431 |
+
val = st.multiselect(field["label"], field["options"], key=fkey)
|
| 432 |
+
val = ", ".join(val) if val else ""
|
| 433 |
+
elif ftype == "number_input":
|
| 434 |
+
val = st.number_input(
|
| 435 |
+
field["label"],
|
| 436 |
+
min_value=field.get("min", 0.0),
|
| 437 |
+
max_value=field.get("max", 999.0),
|
| 438 |
+
value=field.get("default", 0.0),
|
| 439 |
+
step=field.get("step", 1.0),
|
| 440 |
+
key=fkey,
|
| 441 |
+
)
|
| 442 |
+
val = str(val)
|
| 443 |
+
elif ftype == "checkbox":
|
| 444 |
+
val = st.checkbox(
|
| 445 |
+
field["label"], value=field.get("default", False), key=fkey
|
| 446 |
+
)
|
| 447 |
+
val = "Yes" if val else "No"
|
| 448 |
+
elif ftype == "text_input":
|
| 449 |
+
val = st.text_input(field["label"], key=fkey)
|
| 450 |
+
else:
|
| 451 |
+
continue
|
| 452 |
+
site_vitals[field["key"]] = str(val)
|
| 453 |
|
| 454 |
with st.expander("Medical History"):
|
| 455 |
c1, c2 = st.columns(2)
|
|
|
|
| 467 |
)
|
| 468 |
|
| 469 |
with st.expander("Lab / Culture Results (optional — triggers targeted pathway)"):
|
| 470 |
+
method = st.radio(
|
| 471 |
+
"Input method",
|
| 472 |
+
["None — empirical pathway only", "Upload file (PDF / image)", "Paste lab text"],
|
| 473 |
+
horizontal=True,
|
| 474 |
+
)
|
| 475 |
labs_raw_text = None
|
| 476 |
+
labs_image_bytes = None
|
| 477 |
+
|
| 478 |
+
if method == "Upload file (PDF / image)":
|
| 479 |
+
uploaded = st.file_uploader(
|
| 480 |
+
"Lab report file",
|
| 481 |
+
type=["pdf", "png", "jpg", "jpeg", "tiff", "tif", "bmp"],
|
| 482 |
+
help="Upload a culture & sensitivity report, antibiogram, or any lab document.",
|
| 483 |
+
)
|
| 484 |
+
if uploaded is not None:
|
| 485 |
+
file_bytes = uploaded.read()
|
| 486 |
+
ext = uploaded.name.rsplit(".", 1)[-1].lower()
|
| 487 |
+
if ext == "pdf":
|
| 488 |
+
# Extract text from PDF using pypdf
|
| 489 |
+
import pypdf
|
| 490 |
+
from io import BytesIO
|
| 491 |
+
try:
|
| 492 |
+
reader = pypdf.PdfReader(BytesIO(file_bytes))
|
| 493 |
+
extracted = "\n".join(
|
| 494 |
+
page.extract_text() or "" for page in reader.pages
|
| 495 |
+
).strip()
|
| 496 |
+
if extracted:
|
| 497 |
+
labs_raw_text = extracted
|
| 498 |
+
st.success(f"PDF parsed — {len(reader.pages)} page(s), {len(extracted)} characters extracted.")
|
| 499 |
+
else:
|
| 500 |
+
st.warning(
|
| 501 |
+
"PDF text extraction returned empty content (scanned PDF?). "
|
| 502 |
+
"The file will be processed as an image by the vision model."
|
| 503 |
+
)
|
| 504 |
+
# Convert first page to image fallback via pillow (requires pypdf extras)
|
| 505 |
+
labs_image_bytes = file_bytes
|
| 506 |
+
except Exception as e:
|
| 507 |
+
st.error(f"PDF parsing failed: {e}")
|
| 508 |
+
else:
|
| 509 |
+
# Image file — pass directly to the multimodal model
|
| 510 |
+
labs_image_bytes = file_bytes
|
| 511 |
+
from PIL import Image as _PILImage
|
| 512 |
+
from io import BytesIO as _BytesIO
|
| 513 |
+
try:
|
| 514 |
+
thumb = _PILImage.open(_BytesIO(file_bytes))
|
| 515 |
+
st.image(thumb, caption=f"Uploaded: {uploaded.name}", width=320)
|
| 516 |
+
except Exception:
|
| 517 |
+
st.info(f"Image uploaded: {uploaded.name}")
|
| 518 |
+
|
| 519 |
+
elif method == "Paste lab text":
|
| 520 |
labs_raw_text = st.text_area(
|
| 521 |
"Lab report",
|
| 522 |
placeholder=(
|
|
|
|
| 532 |
run_btn = st.button("Run Agent Pipeline", type="primary", use_container_width=False)
|
| 533 |
|
| 534 |
if run_btn:
|
| 535 |
+
has_lab_input = bool(labs_raw_text or labs_image_bytes)
|
| 536 |
patient_data = {
|
| 537 |
"age_years": age,
|
| 538 |
"weight_kg": weight,
|
|
|
|
| 544 |
"medications": [m.strip() for m in medications.split("\n") if m.strip()],
|
| 545 |
"allergies": [a.strip() for a in allergies.split("\n") if a.strip()],
|
| 546 |
"comorbidities": list(comorbidities) + list(risk_factors),
|
| 547 |
+
"vitals": site_vitals,
|
| 548 |
+
"labs_image_bytes": labs_image_bytes,
|
| 549 |
}
|
| 550 |
|
| 551 |
stages = (
|
| 552 |
["Intake Historian", "Vision Specialist", "Trend Analyst", "Clinical Pharmacologist"]
|
| 553 |
+
if has_lab_input
|
| 554 |
else ["Intake Historian", "Clinical Pharmacologist"]
|
| 555 |
)
|
| 556 |
|
|
|
|
| 562 |
from src.graph import run_pipeline
|
| 563 |
result = run_pipeline(patient_data, labs_raw_text)
|
| 564 |
except Exception:
|
| 565 |
+
result = _demo_result(patient_data, labs_raw_text or bool(labs_image_bytes))
|
| 566 |
|
| 567 |
prog.progress(100, text="Complete")
|
| 568 |
st.session_state.pipeline_result = result
|
src/agents.py
CHANGED
|
@@ -12,7 +12,7 @@ import logging
|
|
| 12 |
from typing import Optional
|
| 13 |
|
| 14 |
from .config import get_settings
|
| 15 |
-
from .loader import run_inference, TextModelName
|
| 16 |
from .prompts import (
|
| 17 |
INTAKE_HISTORIAN_SYSTEM,
|
| 18 |
INTAKE_HISTORIAN_PROMPT,
|
|
@@ -66,12 +66,17 @@ def run_intake_historian(state: InfectionState) -> InfectionState:
|
|
| 66 |
patient_context={"pathogen_type": state.get("suspected_source")},
|
| 67 |
)
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
prompt = f"{INTAKE_HISTORIAN_SYSTEM}\n\n{INTAKE_HISTORIAN_PROMPT.format(
|
| 70 |
patient_data=patient_data,
|
| 71 |
medications=', '.join(state.get('medications', [])) or 'None reported',
|
| 72 |
allergies=', '.join(state.get('allergies', [])) or 'No known allergies',
|
| 73 |
infection_site=state.get('infection_site', 'Unknown'),
|
| 74 |
suspected_source=state.get('suspected_source', 'Unknown'),
|
|
|
|
| 75 |
rag_context=rag_context,
|
| 76 |
)}"
|
| 77 |
|
|
@@ -105,14 +110,24 @@ def run_vision_specialist(state: InfectionState) -> InfectionState:
|
|
| 105 |
logger.info("Running Vision Specialist agent...")
|
| 106 |
|
| 107 |
labs_raw = state.get("labs_raw_text", "")
|
| 108 |
-
|
|
|
|
|
|
|
| 109 |
logger.info("No lab data to process, skipping Vision Specialist")
|
| 110 |
state["vision_notes"] = "No lab data provided"
|
| 111 |
state["route_to_trend_analyst"] = False
|
| 112 |
return state
|
| 113 |
|
| 114 |
-
#
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
rag_context = get_context_for_agent(
|
| 117 |
agent_name="vision_specialist",
|
| 118 |
query="culture sensitivity susceptibility interpretation",
|
|
@@ -120,13 +135,22 @@ def run_vision_specialist(state: InfectionState) -> InfectionState:
|
|
| 120 |
)
|
| 121 |
|
| 122 |
prompt = f"{VISION_SPECIALIST_SYSTEM}\n\n{VISION_SPECIALIST_PROMPT.format(
|
| 123 |
-
report_content=
|
| 124 |
-
source_format=
|
| 125 |
language=language,
|
| 126 |
)}"
|
| 127 |
|
| 128 |
try:
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
parsed = safe_json_parse(response)
|
| 131 |
if parsed:
|
| 132 |
state["vision_notes"] = json.dumps(parsed, indent=2)
|
|
@@ -256,6 +280,10 @@ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
|
|
| 256 |
patient_context={"proposed_antibiotic": None},
|
| 257 |
)
|
| 258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
prompt = f"{CLINICAL_PHARMACOLOGIST_SYSTEM}\n\n{CLINICAL_PHARMACOLOGIST_PROMPT.format(
|
| 260 |
intake_summary=intake_summary,
|
| 261 |
lab_results=lab_results,
|
|
@@ -268,6 +296,7 @@ def run_clinical_pharmacologist(state: InfectionState) -> InfectionState:
|
|
| 268 |
infection_site=state.get('infection_site', 'Unknown'),
|
| 269 |
suspected_source=state.get('suspected_source', 'Unknown'),
|
| 270 |
severity=state.get('intake_notes', {}).get('infection_severity', 'Unknown') if isinstance(state.get('intake_notes'), dict) else 'Unknown',
|
|
|
|
| 271 |
rag_context=rag_context,
|
| 272 |
)}"
|
| 273 |
|
|
@@ -355,8 +384,10 @@ def _format_patient_data(state: InfectionState) -> str:
|
|
| 355 |
lines.append(f"Comorbidities: {', '.join(state['comorbidities'])}")
|
| 356 |
|
| 357 |
if state.get("vitals"):
|
| 358 |
-
|
| 359 |
-
|
|
|
|
|
|
|
| 360 |
|
| 361 |
return "\n".join(lines) if lines else "No patient data available"
|
| 362 |
|
|
|
|
| 12 |
from typing import Optional
|
| 13 |
|
| 14 |
from .config import get_settings
|
| 15 |
+
from .loader import run_inference, run_inference_with_image, TextModelName
|
| 16 |
from .prompts import (
|
| 17 |
INTAKE_HISTORIAN_SYSTEM,
|
| 18 |
INTAKE_HISTORIAN_PROMPT,
|
|
|
|
| 66 |
patient_context={"pathogen_type": state.get("suspected_source")},
|
| 67 |
)
|
| 68 |
|
| 69 |
+
site_vitals_str = "\n".join(
|
| 70 |
+
f"- {k.replace('_', ' ').title()}: {v}" for k, v in state.get("vitals", {}).items()
|
| 71 |
+
) or "None provided"
|
| 72 |
+
|
| 73 |
prompt = f"{INTAKE_HISTORIAN_SYSTEM}\n\n{INTAKE_HISTORIAN_PROMPT.format(
|
| 74 |
patient_data=patient_data,
|
| 75 |
medications=', '.join(state.get('medications', [])) or 'None reported',
|
| 76 |
allergies=', '.join(state.get('allergies', [])) or 'No known allergies',
|
| 77 |
infection_site=state.get('infection_site', 'Unknown'),
|
| 78 |
suspected_source=state.get('suspected_source', 'Unknown'),
|
| 79 |
+
site_vitals=site_vitals_str,
|
| 80 |
rag_context=rag_context,
|
| 81 |
)}"
|
| 82 |
|
|
|
|
| 110 |
logger.info("Running Vision Specialist agent...")
|
| 111 |
|
| 112 |
labs_raw = state.get("labs_raw_text", "")
|
| 113 |
+
labs_image_bytes = state.get("labs_image_bytes")
|
| 114 |
+
|
| 115 |
+
if not labs_raw and not labs_image_bytes:
|
| 116 |
logger.info("No lab data to process, skipping Vision Specialist")
|
| 117 |
state["vision_notes"] = "No lab data provided"
|
| 118 |
state["route_to_trend_analyst"] = False
|
| 119 |
return state
|
| 120 |
|
| 121 |
+
# Determine input modality and prepare prompt content description
|
| 122 |
+
if labs_image_bytes:
|
| 123 |
+
source_format = "image"
|
| 124 |
+
language = "Auto-detected"
|
| 125 |
+
report_content = "See attached image — extract all lab data visible in the image."
|
| 126 |
+
else:
|
| 127 |
+
source_format = "text"
|
| 128 |
+
language = "English (assumed)"
|
| 129 |
+
report_content = labs_raw
|
| 130 |
+
|
| 131 |
rag_context = get_context_for_agent(
|
| 132 |
agent_name="vision_specialist",
|
| 133 |
query="culture sensitivity susceptibility interpretation",
|
|
|
|
| 135 |
)
|
| 136 |
|
| 137 |
prompt = f"{VISION_SPECIALIST_SYSTEM}\n\n{VISION_SPECIALIST_PROMPT.format(
|
| 138 |
+
report_content=report_content,
|
| 139 |
+
source_format=source_format,
|
| 140 |
language=language,
|
| 141 |
)}"
|
| 142 |
|
| 143 |
try:
|
| 144 |
+
if labs_image_bytes:
|
| 145 |
+
from io import BytesIO
|
| 146 |
+
from PIL import Image as PILImage
|
| 147 |
+
image = PILImage.open(BytesIO(labs_image_bytes)).convert("RGB")
|
| 148 |
+
logger.info(f"Running vision inference on uploaded image ({image.size})")
|
| 149 |
+
response = run_inference_with_image(
|
| 150 |
+
prompt=prompt, image=image, model_name="medgemma_4b", max_new_tokens=2048, temperature=0.1
|
| 151 |
+
)
|
| 152 |
+
else:
|
| 153 |
+
response = run_inference(prompt=prompt, model_name="medgemma_4b", max_new_tokens=2048, temperature=0.1)
|
| 154 |
parsed = safe_json_parse(response)
|
| 155 |
if parsed:
|
| 156 |
state["vision_notes"] = json.dumps(parsed, indent=2)
|
|
|
|
| 280 |
patient_context={"proposed_antibiotic": None},
|
| 281 |
)
|
| 282 |
|
| 283 |
+
site_vitals_str = "\n".join(
|
| 284 |
+
f"- {k.replace('_', ' ').title()}: {v}" for k, v in state.get("vitals", {}).items()
|
| 285 |
+
) or "None provided"
|
| 286 |
+
|
| 287 |
prompt = f"{CLINICAL_PHARMACOLOGIST_SYSTEM}\n\n{CLINICAL_PHARMACOLOGIST_PROMPT.format(
|
| 288 |
intake_summary=intake_summary,
|
| 289 |
lab_results=lab_results,
|
|
|
|
| 296 |
infection_site=state.get('infection_site', 'Unknown'),
|
| 297 |
suspected_source=state.get('suspected_source', 'Unknown'),
|
| 298 |
severity=state.get('intake_notes', {}).get('infection_severity', 'Unknown') if isinstance(state.get('intake_notes'), dict) else 'Unknown',
|
| 299 |
+
site_vitals=site_vitals_str,
|
| 300 |
rag_context=rag_context,
|
| 301 |
)}"
|
| 302 |
|
|
|
|
| 384 |
lines.append(f"Comorbidities: {', '.join(state['comorbidities'])}")
|
| 385 |
|
| 386 |
if state.get("vitals"):
|
| 387 |
+
lines.append("Site-Specific Assessment:")
|
| 388 |
+
for k, v in state["vitals"].items():
|
| 389 |
+
label = k.replace("_", " ").title()
|
| 390 |
+
lines.append(f" - {label}: {v}")
|
| 391 |
|
| 392 |
return "\n".join(lines) if lines else "No patient data available"
|
| 393 |
|
src/form_config.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Declarative field definitions for the dynamic Patient Analysis form.
|
| 3 |
+
|
| 4 |
+
Each infection site maps to a list of site-specific fields and contextual
|
| 5 |
+
suspected-source options. Universal fields (age, sex, weight, height,
|
| 6 |
+
creatinine, medications, allergies, comorbidities, risk factors) are always
|
| 7 |
+
shown and are NOT listed here.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
SITE_SPECIFIC_FIELDS: dict[str, list[dict]] = {
|
| 11 |
+
"urinary": [
|
| 12 |
+
{
|
| 13 |
+
"key": "catheter_status",
|
| 14 |
+
"label": "Catheter status",
|
| 15 |
+
"type": "selectbox",
|
| 16 |
+
"options": [
|
| 17 |
+
"No catheter",
|
| 18 |
+
"Indwelling (Foley)",
|
| 19 |
+
"Intermittent",
|
| 20 |
+
"Suprapubic",
|
| 21 |
+
"Recently removed (<48 h)",
|
| 22 |
+
],
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"key": "urinary_symptoms",
|
| 26 |
+
"label": "Urinary symptoms",
|
| 27 |
+
"type": "multiselect",
|
| 28 |
+
"options": [
|
| 29 |
+
"Dysuria",
|
| 30 |
+
"Frequency",
|
| 31 |
+
"Urgency",
|
| 32 |
+
"Hematuria",
|
| 33 |
+
"Suprapubic pain",
|
| 34 |
+
"Flank pain",
|
| 35 |
+
"Fever / chills",
|
| 36 |
+
],
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"key": "urine_appearance",
|
| 40 |
+
"label": "Urine appearance",
|
| 41 |
+
"type": "selectbox",
|
| 42 |
+
"options": ["Clear", "Cloudy", "Turbid", "Malodorous", "Hematuria"],
|
| 43 |
+
},
|
| 44 |
+
],
|
| 45 |
+
"respiratory": [
|
| 46 |
+
{
|
| 47 |
+
"key": "o2_saturation",
|
| 48 |
+
"label": "O\u2082 Saturation (%)",
|
| 49 |
+
"type": "number_input",
|
| 50 |
+
"min": 50.0,
|
| 51 |
+
"max": 100.0,
|
| 52 |
+
"default": 97.0,
|
| 53 |
+
"step": 0.5,
|
| 54 |
+
},
|
| 55 |
+
{
|
| 56 |
+
"key": "ventilation_status",
|
| 57 |
+
"label": "Ventilation status",
|
| 58 |
+
"type": "selectbox",
|
| 59 |
+
"options": [
|
| 60 |
+
"Room air",
|
| 61 |
+
"Supplemental O\u2082 (nasal cannula)",
|
| 62 |
+
"Supplemental O\u2082 (mask)",
|
| 63 |
+
"Non-invasive (BiPAP / CPAP)",
|
| 64 |
+
"Mechanical ventilation",
|
| 65 |
+
],
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"key": "cough_type",
|
| 69 |
+
"label": "Cough type",
|
| 70 |
+
"type": "selectbox",
|
| 71 |
+
"options": ["None", "Dry", "Productive", "Hemoptysis"],
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"key": "sputum_character",
|
| 75 |
+
"label": "Sputum character",
|
| 76 |
+
"type": "selectbox",
|
| 77 |
+
"options": [
|
| 78 |
+
"None",
|
| 79 |
+
"Clear / white",
|
| 80 |
+
"Yellow",
|
| 81 |
+
"Green / purulent",
|
| 82 |
+
"Rust-colored",
|
| 83 |
+
"Blood-tinged",
|
| 84 |
+
],
|
| 85 |
+
},
|
| 86 |
+
],
|
| 87 |
+
"bloodstream": [
|
| 88 |
+
{
|
| 89 |
+
"key": "central_line_present",
|
| 90 |
+
"label": "Central line present",
|
| 91 |
+
"type": "checkbox",
|
| 92 |
+
"default": False,
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"key": "temperature_c",
|
| 96 |
+
"label": "Temperature (\u00b0C)",
|
| 97 |
+
"type": "number_input",
|
| 98 |
+
"min": 34.0,
|
| 99 |
+
"max": 43.0,
|
| 100 |
+
"default": 38.5,
|
| 101 |
+
"step": 0.1,
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"key": "heart_rate_bpm",
|
| 105 |
+
"label": "Heart rate (bpm)",
|
| 106 |
+
"type": "number_input",
|
| 107 |
+
"min": 30,
|
| 108 |
+
"max": 250,
|
| 109 |
+
"default": 90,
|
| 110 |
+
"step": 1,
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"key": "respiratory_rate",
|
| 114 |
+
"label": "Respiratory rate (/min)",
|
| 115 |
+
"type": "number_input",
|
| 116 |
+
"min": 5,
|
| 117 |
+
"max": 60,
|
| 118 |
+
"default": 18,
|
| 119 |
+
"step": 1,
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"key": "wbc_count",
|
| 123 |
+
"label": "WBC count (\u00d710\u2079/L)",
|
| 124 |
+
"type": "number_input",
|
| 125 |
+
"min": 0.0,
|
| 126 |
+
"max": 100.0,
|
| 127 |
+
"default": 12.0,
|
| 128 |
+
"step": 0.1,
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"key": "lactate_mmol",
|
| 132 |
+
"label": "Lactate (mmol/L)",
|
| 133 |
+
"type": "number_input",
|
| 134 |
+
"min": 0.0,
|
| 135 |
+
"max": 30.0,
|
| 136 |
+
"default": 1.0,
|
| 137 |
+
"step": 0.1,
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"key": "shock_status",
|
| 141 |
+
"label": "Shock status",
|
| 142 |
+
"type": "selectbox",
|
| 143 |
+
"options": [
|
| 144 |
+
"No shock",
|
| 145 |
+
"Compensated (SBP > 90, tachycardia)",
|
| 146 |
+
"Septic shock (vasopressors required)",
|
| 147 |
+
],
|
| 148 |
+
},
|
| 149 |
+
],
|
| 150 |
+
"skin": [
|
| 151 |
+
{
|
| 152 |
+
"key": "wound_type",
|
| 153 |
+
"label": "Wound type",
|
| 154 |
+
"type": "selectbox",
|
| 155 |
+
"options": [
|
| 156 |
+
"Laceration",
|
| 157 |
+
"Ulcer (diabetic / pressure)",
|
| 158 |
+
"Bite (animal / human)",
|
| 159 |
+
"Surgical site",
|
| 160 |
+
"Burn",
|
| 161 |
+
"Abscess",
|
| 162 |
+
"Cellulitis (no wound)",
|
| 163 |
+
],
|
| 164 |
+
},
|
| 165 |
+
{
|
| 166 |
+
"key": "cellulitis_extent",
|
| 167 |
+
"label": "Cellulitis extent",
|
| 168 |
+
"type": "selectbox",
|
| 169 |
+
"options": [
|
| 170 |
+
"None",
|
| 171 |
+
"Localized (< 5 cm)",
|
| 172 |
+
"Moderate (5\u201310 cm)",
|
| 173 |
+
"Extensive (> 10 cm)",
|
| 174 |
+
"Rapidly spreading",
|
| 175 |
+
],
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"key": "abscess_present",
|
| 179 |
+
"label": "Abscess present",
|
| 180 |
+
"type": "checkbox",
|
| 181 |
+
"default": False,
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"key": "foreign_body",
|
| 185 |
+
"label": "Foreign body / implant",
|
| 186 |
+
"type": "checkbox",
|
| 187 |
+
"default": False,
|
| 188 |
+
},
|
| 189 |
+
],
|
| 190 |
+
"intra-abdominal": [
|
| 191 |
+
{
|
| 192 |
+
"key": "abdominal_pain_location",
|
| 193 |
+
"label": "Pain location",
|
| 194 |
+
"type": "selectbox",
|
| 195 |
+
"options": [
|
| 196 |
+
"Diffuse",
|
| 197 |
+
"RUQ",
|
| 198 |
+
"LUQ",
|
| 199 |
+
"RLQ",
|
| 200 |
+
"LLQ",
|
| 201 |
+
"Epigastric",
|
| 202 |
+
"Periumbilical",
|
| 203 |
+
],
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"key": "peritonitis_signs",
|
| 207 |
+
"label": "Peritonitis signs",
|
| 208 |
+
"type": "multiselect",
|
| 209 |
+
"options": [
|
| 210 |
+
"Guarding",
|
| 211 |
+
"Rebound tenderness",
|
| 212 |
+
"Rigidity",
|
| 213 |
+
"Absent bowel sounds",
|
| 214 |
+
],
|
| 215 |
+
},
|
| 216 |
+
{
|
| 217 |
+
"key": "perforation_suspected",
|
| 218 |
+
"label": "Perforation suspected",
|
| 219 |
+
"type": "checkbox",
|
| 220 |
+
"default": False,
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"key": "ascites",
|
| 224 |
+
"label": "Ascites present",
|
| 225 |
+
"type": "checkbox",
|
| 226 |
+
"default": False,
|
| 227 |
+
},
|
| 228 |
+
],
|
| 229 |
+
"CNS": [
|
| 230 |
+
{
|
| 231 |
+
"key": "csf_obtained",
|
| 232 |
+
"label": "CSF obtained",
|
| 233 |
+
"type": "checkbox",
|
| 234 |
+
"default": False,
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"key": "neuro_symptoms",
|
| 238 |
+
"label": "Neurological symptoms",
|
| 239 |
+
"type": "multiselect",
|
| 240 |
+
"options": [
|
| 241 |
+
"Headache",
|
| 242 |
+
"Neck stiffness",
|
| 243 |
+
"Photophobia",
|
| 244 |
+
"Altered mental status",
|
| 245 |
+
"Seizures",
|
| 246 |
+
"Focal deficits",
|
| 247 |
+
],
|
| 248 |
+
},
|
| 249 |
+
{
|
| 250 |
+
"key": "recent_neurosurgery",
|
| 251 |
+
"label": "Recent neurosurgery",
|
| 252 |
+
"type": "checkbox",
|
| 253 |
+
"default": False,
|
| 254 |
+
},
|
| 255 |
+
{
|
| 256 |
+
"key": "gcs_score",
|
| 257 |
+
"label": "GCS score",
|
| 258 |
+
"type": "number_input",
|
| 259 |
+
"min": 3,
|
| 260 |
+
"max": 15,
|
| 261 |
+
"default": 15,
|
| 262 |
+
"step": 1,
|
| 263 |
+
},
|
| 264 |
+
],
|
| 265 |
+
"other": [],
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# Sites where serum creatinine is shown prominently in demographics.
|
| 270 |
+
# For all other sites a "renal impairment?" toggle is shown instead.
|
| 271 |
+
CREATININE_PROMINENT_SITES: frozenset[str] = frozenset(
|
| 272 |
+
{"urinary", "bloodstream", "CNS", "respiratory"}
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
SUSPECTED_SOURCE_OPTIONS: dict[str, list[str]] = {
|
| 276 |
+
"urinary": [
|
| 277 |
+
"Community-acquired UTI",
|
| 278 |
+
"Catheter-associated UTI (CAUTI)",
|
| 279 |
+
"Complicated UTI",
|
| 280 |
+
"Pyelonephritis",
|
| 281 |
+
"Urosepsis",
|
| 282 |
+
"Other",
|
| 283 |
+
],
|
| 284 |
+
"respiratory": [
|
| 285 |
+
"Community-acquired pneumonia (CAP)",
|
| 286 |
+
"Hospital-acquired pneumonia (HAP)",
|
| 287 |
+
"Ventilator-associated pneumonia (VAP)",
|
| 288 |
+
"Aspiration pneumonia",
|
| 289 |
+
"Lung abscess",
|
| 290 |
+
"Empyema",
|
| 291 |
+
"Other",
|
| 292 |
+
],
|
| 293 |
+
"bloodstream": [
|
| 294 |
+
"Primary bacteremia",
|
| 295 |
+
"Catheter-related BSI (CRBSI)",
|
| 296 |
+
"Secondary bacteremia (from known source)",
|
| 297 |
+
"Endocarditis",
|
| 298 |
+
"Unknown source",
|
| 299 |
+
"Other",
|
| 300 |
+
],
|
| 301 |
+
"skin": [
|
| 302 |
+
"Cellulitis",
|
| 303 |
+
"Surgical site infection",
|
| 304 |
+
"Diabetic foot infection",
|
| 305 |
+
"Bite wound infection",
|
| 306 |
+
"Necrotizing fasciitis",
|
| 307 |
+
"Abscess",
|
| 308 |
+
"Other",
|
| 309 |
+
],
|
| 310 |
+
"intra-abdominal": [
|
| 311 |
+
"Appendicitis",
|
| 312 |
+
"Cholecystitis / cholangitis",
|
| 313 |
+
"Diverticulitis",
|
| 314 |
+
"Peritonitis (SBP)",
|
| 315 |
+
"Post-surgical",
|
| 316 |
+
"Liver abscess",
|
| 317 |
+
"Other",
|
| 318 |
+
],
|
| 319 |
+
"CNS": [
|
| 320 |
+
"Community-acquired meningitis",
|
| 321 |
+
"Post-neurosurgical meningitis",
|
| 322 |
+
"Healthcare-associated ventriculitis",
|
| 323 |
+
"Brain abscess",
|
| 324 |
+
"Other",
|
| 325 |
+
],
|
| 326 |
+
"other": [],
|
| 327 |
+
}
|
src/graph.py
CHANGED
|
@@ -77,6 +77,9 @@ def run_pipeline(patient_data: dict, labs_raw_text: str | None = None) -> Infect
|
|
| 77 |
Pass labs_raw_text to trigger the targeted (Stage 2) pathway.
|
| 78 |
Without it, only the empirical (Stage 1) pathway runs.
|
| 79 |
"""
|
|
|
|
|
|
|
|
|
|
| 80 |
initial_state: InfectionState = {
|
| 81 |
"age_years": patient_data.get("age_years"),
|
| 82 |
"weight_kg": patient_data.get("weight_kg"),
|
|
@@ -90,15 +93,17 @@ def run_pipeline(patient_data: dict, labs_raw_text: str | None = None) -> Infect
|
|
| 90 |
"suspected_source": patient_data.get("suspected_source"),
|
| 91 |
"country_or_region": patient_data.get("country_or_region"),
|
| 92 |
"vitals": patient_data.get("vitals", {}),
|
| 93 |
-
"stage": "targeted" if
|
| 94 |
"errors": [],
|
| 95 |
"safety_warnings": [],
|
| 96 |
}
|
| 97 |
|
| 98 |
if labs_raw_text:
|
| 99 |
initial_state["labs_raw_text"] = labs_raw_text
|
|
|
|
|
|
|
| 100 |
|
| 101 |
-
logger.info(f"Starting pipeline (stage: {initial_state['stage']})")
|
| 102 |
logger.info(f"Patient: {patient_data.get('age_years')}y, {patient_data.get('sex')}, infection: {patient_data.get('infection_site')}")
|
| 103 |
|
| 104 |
try:
|
|
|
|
| 77 |
Pass labs_raw_text to trigger the targeted (Stage 2) pathway.
|
| 78 |
Without it, only the empirical (Stage 1) pathway runs.
|
| 79 |
"""
|
| 80 |
+
labs_image_bytes: bytes | None = patient_data.get("labs_image_bytes")
|
| 81 |
+
has_lab_input = bool(labs_raw_text or labs_image_bytes)
|
| 82 |
+
|
| 83 |
initial_state: InfectionState = {
|
| 84 |
"age_years": patient_data.get("age_years"),
|
| 85 |
"weight_kg": patient_data.get("weight_kg"),
|
|
|
|
| 93 |
"suspected_source": patient_data.get("suspected_source"),
|
| 94 |
"country_or_region": patient_data.get("country_or_region"),
|
| 95 |
"vitals": patient_data.get("vitals", {}),
|
| 96 |
+
"stage": "targeted" if has_lab_input else "empirical",
|
| 97 |
"errors": [],
|
| 98 |
"safety_warnings": [],
|
| 99 |
}
|
| 100 |
|
| 101 |
if labs_raw_text:
|
| 102 |
initial_state["labs_raw_text"] = labs_raw_text
|
| 103 |
+
if labs_image_bytes:
|
| 104 |
+
initial_state["labs_image_bytes"] = labs_image_bytes
|
| 105 |
|
| 106 |
+
logger.info(f"Starting pipeline (stage: {initial_state['stage']}, lab_text={bool(labs_raw_text)}, lab_image={bool(labs_image_bytes)})")
|
| 107 |
logger.info(f"Patient: {patient_data.get('age_years')}y, {patient_data.get('sex')}, infection: {patient_data.get('infection_site')}")
|
| 108 |
|
| 109 |
try:
|
src/loader.py
CHANGED
|
@@ -58,9 +58,20 @@ def _get_local_multimodal(model_name: TextModelName):
|
|
| 58 |
model = AutoModelForImageTextToText.from_pretrained(model_path, **load_kwargs)
|
| 59 |
logger.info(f"Model loaded successfully: {model_path}")
|
| 60 |
|
| 61 |
-
def _call(
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
inputs = processor.apply_chat_template(
|
| 65 |
messages, add_generation_prompt=True, tokenize=True,
|
| 66 |
return_dict=True, return_tensors="pt",
|
|
@@ -155,3 +166,38 @@ def run_inference(
|
|
| 155 |
except Exception as e:
|
| 156 |
logger.error(f"Inference failed for {model_name}: {e}", exc_info=True)
|
| 157 |
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
model = AutoModelForImageTextToText.from_pretrained(model_path, **load_kwargs)
|
| 59 |
logger.info(f"Model loaded successfully: {model_path}")
|
| 60 |
|
| 61 |
+
def _call(
|
| 62 |
+
prompt: str,
|
| 63 |
+
max_new_tokens: int = 512,
|
| 64 |
+
temperature: float = 0.2,
|
| 65 |
+
image=None, # optional PIL.Image.Image for vision-language inference
|
| 66 |
+
**generate_kwargs: Any,
|
| 67 |
+
) -> str:
|
| 68 |
+
# Build chat content; prepend image token when an image is provided
|
| 69 |
+
content = []
|
| 70 |
+
if image is not None:
|
| 71 |
+
content.append({"type": "image", "image": image})
|
| 72 |
+
content.append({"type": "text", "text": prompt})
|
| 73 |
+
messages = [{"role": "user", "content": content}]
|
| 74 |
+
|
| 75 |
inputs = processor.apply_chat_template(
|
| 76 |
messages, add_generation_prompt=True, tokenize=True,
|
| 77 |
return_dict=True, return_tensors="pt",
|
|
|
|
| 166 |
except Exception as e:
|
| 167 |
logger.error(f"Inference failed for {model_name}: {e}", exc_info=True)
|
| 168 |
raise
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def run_inference_with_image(
|
| 172 |
+
prompt: str,
|
| 173 |
+
image: Any, # PIL.Image.Image
|
| 174 |
+
model_name: TextModelName = "medgemma_4b",
|
| 175 |
+
max_new_tokens: int = 1024,
|
| 176 |
+
temperature: float = 0.1,
|
| 177 |
+
**kwargs: Any,
|
| 178 |
+
) -> str:
|
| 179 |
+
"""
|
| 180 |
+
Run vision-language inference passing a PIL image alongside the text prompt.
|
| 181 |
+
|
| 182 |
+
Falls back to text-only inference if the resolved model is not multimodal
|
| 183 |
+
(e.g. when medgemma_4b is remapped to a text-only model in the env config).
|
| 184 |
+
"""
|
| 185 |
+
logger.info(f"Running vision inference with {model_name}, max_tokens={max_new_tokens}")
|
| 186 |
+
try:
|
| 187 |
+
model_path = _get_model_path(model_name)
|
| 188 |
+
if not _is_multimodal(model_path):
|
| 189 |
+
logger.warning(
|
| 190 |
+
f"{model_name} ({model_path}) is not a multimodal model; "
|
| 191 |
+
"falling back to text-only inference."
|
| 192 |
+
)
|
| 193 |
+
return run_inference(prompt, model_name, max_new_tokens, temperature, **kwargs)
|
| 194 |
+
|
| 195 |
+
model_fn = _get_local_multimodal(model_name)
|
| 196 |
+
result = model_fn(
|
| 197 |
+
prompt, max_new_tokens=max_new_tokens, temperature=temperature, image=image, **kwargs
|
| 198 |
+
)
|
| 199 |
+
logger.info(f"Vision inference complete, response length: {len(result)} chars")
|
| 200 |
+
return result
|
| 201 |
+
except Exception as e:
|
| 202 |
+
logger.error(f"Vision inference failed for {model_name}: {e}", exc_info=True)
|
| 203 |
+
raise
|
src/prompts.py
CHANGED
|
@@ -49,6 +49,9 @@ CLINICAL CONTEXT:
|
|
| 49 |
- Suspected infection site: {infection_site}
|
| 50 |
- Suspected source: {suspected_source}
|
| 51 |
|
|
|
|
|
|
|
|
|
|
| 52 |
RAG CONTEXT (Relevant Guidelines):
|
| 53 |
{rag_context}
|
| 54 |
|
|
@@ -267,6 +270,9 @@ INFECTION CONTEXT:
|
|
| 267 |
- Source: {suspected_source}
|
| 268 |
- Severity: {severity}
|
| 269 |
|
|
|
|
|
|
|
|
|
|
| 270 |
RAG CONTEXT (Guidelines & Safety Data):
|
| 271 |
{rag_context}
|
| 272 |
|
|
|
|
| 49 |
- Suspected infection site: {infection_site}
|
| 50 |
- Suspected source: {suspected_source}
|
| 51 |
|
| 52 |
+
SITE-SPECIFIC ASSESSMENT:
|
| 53 |
+
{site_vitals}
|
| 54 |
+
|
| 55 |
RAG CONTEXT (Relevant Guidelines):
|
| 56 |
{rag_context}
|
| 57 |
|
|
|
|
| 270 |
- Source: {suspected_source}
|
| 271 |
- Severity: {severity}
|
| 272 |
|
| 273 |
+
SITE-SPECIFIC ASSESSMENT:
|
| 274 |
+
{site_vitals}
|
| 275 |
+
|
| 276 |
RAG CONTEXT (Guidelines & Safety Data):
|
| 277 |
{rag_context}
|
| 278 |
|
src/state.py
CHANGED
|
@@ -64,7 +64,8 @@ class InfectionState(TypedDict, total=False):
|
|
| 64 |
vitals: NotRequired[Dict[str, str]] # flexible key/value, e.g. {"BP": "120/80"}
|
| 65 |
|
| 66 |
# Lab data & MICs
|
| 67 |
-
labs_raw_text: NotRequired[Optional[str]]
|
|
|
|
| 68 |
labs_parsed: NotRequired[List[LabResult]]
|
| 69 |
mic_data: NotRequired[List[MICDatum]]
|
| 70 |
mic_trend_summary: NotRequired[Optional[str]]
|
|
|
|
| 64 |
vitals: NotRequired[Dict[str, str]] # flexible key/value, e.g. {"BP": "120/80"}
|
| 65 |
|
| 66 |
# Lab data & MICs
|
| 67 |
+
labs_raw_text: NotRequired[Optional[str]] # raw OCR or pasted text
|
| 68 |
+
labs_image_bytes: NotRequired[Optional[bytes]] # uploaded image (PNG/JPG/TIFF) for vision model
|
| 69 |
labs_parsed: NotRequired[List[LabResult]]
|
| 70 |
mic_data: NotRequired[List[MICDatum]]
|
| 71 |
mic_trend_summary: NotRequired[Optional[str]]
|