import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image
from demo.utils.load_model import load_fusion_model
from demo.utils.grad_cam import GradCAM, overlay_cam
from demo.utils.saliency import (
compute_text_saliency,
merge_wordpieces,
filter_tokens,
highlight_text,
)
# --------------------------------------------------
# Page configuration
# --------------------------------------------------
# Page-level settings: browser tab title, wide layout, and a sidebar that
# starts collapsed.
_PAGE_SETTINGS = {
    "page_title": "Multimodal Clinical AI",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)
# --------------------------------------------------
# Header
# --------------------------------------------------
# Render the app title as a Markdown heading with the tagline below it.
# The blank line between them is deliberate: Markdown joins adjacent lines
# into one paragraph, so without heading markup and a paragraph break the
# title and tagline would be shown as a single run of plain text.
# No raw HTML is used here, so unsafe_allow_html is not needed.
st.markdown(
    "# Multimodal Clinical Decision Support\n"
    "\n"
    "Chest X-ray + Radiology Text → Ranked Diagnoses with Explainability\n"
)
st.divider()
# --------------------------------------------------
# Load model (cached)
# --------------------------------------------------
@st.cache_resource
def load_all(
    checkpoint_path: str = "checkpoints/fusion_model/fusion_layer4_tuned.pt",
):
    """Load the fusion model bundle, cached across Streamlit reruns.

    Args:
        checkpoint_path: Path of the checkpoint handed to
            ``load_fusion_model``. Defaults to the tuned layer-4 fusion
            checkpoint so existing ``load_all()`` calls keep working.

    Returns:
        Whatever ``load_fusion_model`` returns. The caller below unpacks it
        as ``(model, tokenizer, image_transform, LABELS, device)``.
    """
    return load_fusion_model(checkpoint_path)


# Module-level handles used by the inference section further down.
model, tokenizer, image_transform, LABELS, device = load_all()
# --------------------------------------------------
# Input Section
# --------------------------------------------------
# Two side-by-side columns: image upload on the left, free-text findings on
# the right. Both widgets hide their labels (label_visibility="collapsed")
# because each column already has a subheader.
col1, col2 = st.columns(2)

with col1:
    st.subheader("Chest X-ray")
    uploaded_image = st.file_uploader(
        "Upload Chest X-ray",
        type=["png", "jpg", "jpeg"],
        label_visibility="collapsed",
    )

with col2:
    st.subheader("Radiology Findings")
    findings = st.text_area(
        "Enter findings",
        height=180,
        placeholder="e.g. Enlarged cardiac silhouette with pulmonary congestion...",
        label_visibility="collapsed",
    )

# Vertical spacers around the button. These were string literals split
# across two physical lines (a syntax error); an explicit <br> is a valid
# single-line literal and is what unsafe_allow_html=True is there to render.
st.markdown("<br>", unsafe_allow_html=True)
analyze = st.button("Analyze Case", use_container_width=True)
st.markdown("<br>", unsafe_allow_html=True)
# --------------------------------------------------
# Inference + Explainability
# --------------------------------------------------
# Both inputs are required. Whitespace-only findings count as missing so the
# model is never run on an effectively empty report.
inputs_ready = uploaded_image is not None and bool(findings and findings.strip())

if analyze and not inputs_ready:
    # Without this branch, clicking the button with a missing input would
    # silently do nothing, leaving the user without any feedback.
    st.warning(
        "Please upload a chest X-ray and enter the radiology findings "
        "before running the analysis."
    )

if analyze and inputs_ready:
    # ---- Preprocess inputs ----
    image = Image.open(uploaded_image).convert("RGB")
    # unsqueeze(0) adds the leading batch dimension for a single image.
    image_tensor = image_transform(image).unsqueeze(0).to(device)

    enc = tokenizer(
        findings,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)

    # ---- Forward pass ----
    # Gradients are not needed for the prediction itself; the explainability
    # steps below run outside this context because they need them.
    with torch.no_grad():
        logits = model(image_tensor, input_ids, attention_mask)
        probs = F.softmax(logits, dim=1)

    # Two highest-probability classes for the single item in the batch.
    top2_prob, top2_idx = torch.topk(probs, k=2, dim=1)
    primary_idx = top2_idx[0, 0].item()
    secondary_idx = top2_idx[0, 1].item()
    primary_conf = top2_prob[0, 0].item()
    secondary_conf = top2_prob[0, 1].item()

    # --------------------------------------------------
    # Diagnosis Output
    # --------------------------------------------------
    col1, col2 = st.columns(2)

    # Two trailing spaces before "\n" force a Markdown hard line break, so
    # the confidence is shown on its own line under the label rather than
    # being folded into the same line.
    with col1:
        st.markdown("### 🩺 Primary Diagnosis")
        st.success(
            f"{LABELS[primary_idx]}  \nConfidence: {primary_conf:.2f}"
        )

    with col2:
        st.markdown("### 🔍 Secondary Diagnosis")
        st.info(
            f"{LABELS[secondary_idx]}  \nConfidence: {secondary_conf:.2f}"
        )

    # --------------------------------------------------
    # Explainability
    # --------------------------------------------------
    st.divider()
    st.markdown("## Explainability")
    col1, col2 = st.columns(2)

    # ---- Grad-CAM ----
    with col1:
        st.markdown("#### Image Evidence (Grad-CAM)")
        # NOTE(review): a new GradCAM is built on the cached model every
        # time this section runs. If GradCAM registers hooks on layer4
        # without removing them, they may accumulate across analyses -
        # confirm against demo.utils.grad_cam.
        gradcam = GradCAM(model, model.image_encoder.layer4)
        cam = gradcam.generate(
            image_tensor,
            input_ids,
            attention_mask,
            class_idx=primary_idx,
        )
        overlay = overlay_cam(image_tensor, cam)
        # NOTE(review): newer Streamlit releases deprecate use_column_width
        # in favour of use_container_width; switch once the minimum
        # supported Streamlit version is confirmed.
        st.image(
            overlay,
            use_column_width=True,
            caption="Regions influencing the primary diagnosis",
        )

    # ---- Text Saliency ----
    with col2:
        st.markdown("#### Text Evidence (Important Terms)")
        saliency, attn_mask = compute_text_saliency(
            model,
            input_ids,
            attention_mask,
            target_class=primary_idx,
        )
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

        # Clean tokens
        tokens, scores = filter_tokens(tokens, saliency, attn_mask)

        # Merge wordpieces
        tokens, scores = merge_wordpieces(tokens, scores)

        # Highlight text
        # NOTE(review): the result is rendered as raw HTML and is derived
        # from user-entered text; verify that highlight_text escapes token
        # content before wrapping it in markup.
        html_text = highlight_text(tokens, scores)
        st.markdown(html_text, unsafe_allow_html=True)
# --------------------------------------------------
# Footer / Disclaimer
# --------------------------------------------------
# Closing separator followed by the usage disclaimer shown on every run.
disclaimer_text = (
    "⚠️ For educational and research purposes only. "
    "Not intended for clinical use."
)
st.divider()
st.caption(disclaimer_text)