Spaces:
Configuration error
Configuration error
Upload 6 files
Browse files- .env.example +0 -0
- .gitignore +48 -0
- README.md +0 -19
- app.py +178 -0
- create_patient_index.py +104 -0
- requirements.txt +0 -3
.env.example
ADDED
|
File without changes
|
.gitignore
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Raw image formats
|
| 2 |
+
*.raw
|
| 3 |
+
*.dng
|
| 4 |
+
*.cr2
|
| 5 |
+
*.cr3
|
| 6 |
+
*.nef
|
| 7 |
+
*.arw
|
| 8 |
+
*.rw2
|
| 9 |
+
*.orf
|
| 10 |
+
*.srw
|
| 11 |
+
*.x3f
|
| 12 |
+
*.raf
|
| 13 |
+
*.dcr
|
| 14 |
+
*.k25
|
| 15 |
+
*.kdc
|
| 16 |
+
*.mrw
|
| 17 |
+
|
| 18 |
+
# Medical imaging formats (common in clinical AI)
|
| 19 |
+
*.dcm
|
| 20 |
+
*.dicom
|
| 21 |
+
*.nii
|
| 22 |
+
*.nii.gz
|
| 23 |
+
*.mha
|
| 24 |
+
*.mhd
|
| 25 |
+
|
| 26 |
+
# Compressed raw formats
|
| 27 |
+
*.tiff
|
| 28 |
+
*.tif
|
| 29 |
+
|
| 30 |
+
# Processed image formats that might be large
|
| 31 |
+
*.png
|
| 32 |
+
*.jpg
|
| 33 |
+
*.jpeg
|
| 34 |
+
*.bmp
|
| 35 |
+
*.gif
|
| 36 |
+
*.webp
|
| 37 |
+
|
| 38 |
+
# Model checkpoints
|
| 39 |
+
*.pt
|
| 40 |
+
*.pth
|
| 41 |
+
*.ckpt
|
| 42 |
+
*.model
|
| 43 |
+
*.h5
|
| 44 |
+
*.pb
|
| 45 |
+
*.onnx
|
| 46 |
+
|
| 47 |
+
# Checkpoint directories
|
| 48 |
+
checkpoints/
|
README.md
CHANGED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: ChestX-Ray Diagnosis
|
| 3 |
-
emoji: 🚀
|
| 4 |
-
colorFrom: red
|
| 5 |
-
colorTo: red
|
| 6 |
-
sdk: docker
|
| 7 |
-
app_port: 8501
|
| 8 |
-
tags:
|
| 9 |
-
- streamlit
|
| 10 |
-
pinned: false
|
| 11 |
-
short_description: This demo showcases a multimodal deep learning system that c
|
| 12 |
-
---
|
| 13 |
-
|
| 14 |
-
# Welcome to Streamlit!
|
| 15 |
-
|
| 16 |
-
Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
|
| 17 |
-
|
| 18 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 19 |
-
forums](https://discuss.streamlit.io).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Streamlit demo: chest X-ray + radiology text -> ranked diagnoses with explainability."""

import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image

from demo.utils.load_model import load_fusion_model
from demo.utils.grad_cam import GradCAM, overlay_cam
from demo.utils.saliency import (
    compute_text_saliency,
    merge_wordpieces,
    filter_tokens,
    highlight_text,
)

# --------------------------------------------------
# Page configuration
# --------------------------------------------------
st.set_page_config(
    page_title="Multimodal Clinical AI",
    layout="wide",
    initial_sidebar_state="collapsed"
)

# --------------------------------------------------
# Header
# --------------------------------------------------
st.markdown(
    """
    <h2 style="margin-bottom:0">Multimodal Clinical Decision Support</h2>
    <p style="color:gray; margin-top:4px">
    Chest X-ray + Radiology Text → Ranked Diagnoses with Explainability
    </p>
    """,
    unsafe_allow_html=True
)

st.divider()

# --------------------------------------------------
# Load model (cached)
# --------------------------------------------------
@st.cache_resource
def load_all():
    """Load the fusion checkpoint once per process (cached by Streamlit)."""
    return load_fusion_model(
        "checkpoints/fusion_model/fusion_layer4_tuned.pt"
    )

model, tokenizer, image_transform, LABELS, device = load_all()

# --------------------------------------------------
# Input Section
# --------------------------------------------------
col1, col2 = st.columns(2)

with col1:
    st.subheader("Chest X-ray")
    uploaded_image = st.file_uploader(
        "Upload Chest X-ray",
        type=["png", "jpg", "jpeg"],
        label_visibility="collapsed"
    )

with col2:
    st.subheader("Radiology Findings")
    findings = st.text_area(
        "Enter findings",
        height=180,
        placeholder="e.g. Enlarged cardiac silhouette with pulmonary congestion...",
        label_visibility="collapsed"
    )

st.markdown("<br>", unsafe_allow_html=True)
analyze = st.button("Analyze Case", use_container_width=True)
st.markdown("<br>", unsafe_allow_html=True)

# --------------------------------------------------
# Inference + Explainability
# --------------------------------------------------
if analyze and uploaded_image and findings:

    # ---- Preprocess inputs ----
    image = Image.open(uploaded_image).convert("RGB")
    image_tensor = image_transform(image).unsqueeze(0).to(device)

    enc = tokenizer(
        findings,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)

    # ---- Forward pass (no grad needed for the class probabilities) ----
    with torch.no_grad():
        logits = model(image_tensor, input_ids, attention_mask)
        probs = F.softmax(logits, dim=1)

    top2_prob, top2_idx = torch.topk(probs, k=2, dim=1)
    primary_idx = top2_idx[0, 0].item()
    secondary_idx = top2_idx[0, 1].item()

    # --------------------------------------------------
    # Diagnosis Output
    # --------------------------------------------------
    col1, col2 = st.columns(2)

    with col1:
        st.markdown("### 🩺 Primary Diagnosis")
        st.success(
            f"{LABELS[primary_idx]} \nConfidence: {top2_prob[0,0]:.2f}"
        )

    with col2:
        st.markdown("### 🔍 Secondary Diagnosis")
        st.info(
            f"{LABELS[secondary_idx]} \nConfidence: {top2_prob[0,1]:.2f}"
        )

    # --------------------------------------------------
    # Explainability
    # --------------------------------------------------
    st.divider()
    st.markdown("## Explainability")

    col1, col2 = st.columns(2)

    # ---- Grad-CAM (needs gradients, so runs outside torch.no_grad) ----
    with col1:
        st.markdown("#### Image Evidence (Grad-CAM)")

        gradcam = GradCAM(model, model.image_encoder.layer4)
        cam = gradcam.generate(
            image_tensor,
            input_ids,
            attention_mask,
            class_idx=primary_idx
        )

        overlay = overlay_cam(image_tensor, cam)
        # NOTE: use_container_width replaces the deprecated use_column_width
        # (the "Analyze Case" button above already uses the new parameter).
        st.image(
            overlay,
            use_container_width=True,
            caption="Regions influencing the primary diagnosis"
        )

    # ---- Text Saliency ----
    with col2:
        st.markdown("#### Text Evidence (Important Terms)")

        saliency, attn_mask = compute_text_saliency(
            model,
            input_ids,
            attention_mask,
            target_class=primary_idx
        )

        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

        # Clean tokens
        tokens, scores = filter_tokens(tokens, saliency, attn_mask)

        # Merge wordpieces
        tokens, scores = merge_wordpieces(tokens, scores)

        # Highlight text
        html_text = highlight_text(tokens, scores)
        st.markdown(html_text, unsafe_allow_html=True)

elif analyze:
    # Clicking "Analyze Case" without both inputs used to be a silent no-op;
    # tell the user what is missing instead.
    st.warning("Please upload a chest X-ray and enter radiology findings before analyzing.")

# --------------------------------------------------
# Footer / Disclaimer
# --------------------------------------------------
st.divider()
st.caption(
    "⚠️ For educational and research purposes only. "
    "Not intended for clinical use."
)
|
create_patient_index.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset
from pathlib import Path
import pandas as pd
import re
from PIL import Image

# -----------------------------
# CONFIG
# -----------------------------
# Hugging Face dataset repo id and split to index.
DATASET_NAME = "itsanmolgupta/mimic-cxr-dataset"  # change this
SPLIT = "train"

# Where extracted images and the resulting CSV index are written.
IMAGE_DIR = Path("data/raw/images")
OUTPUT_CSV = Path("data/metadata/patient_index.csv")

# Create output directories up front so the pipeline never fails on a save.
IMAGE_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_CSV.parent.mkdir(parents=True, exist_ok=True)

# -----------------------------
# LABEL DEFINITIONS
# -----------------------------
# Keyword lists used to assign a weak label from the report impression text.
# Matching is case-insensitive substring search (see assign_label).
LABEL_KEYWORDS = {
    "PNEUMOTHORAX": ["pneumothorax"],
    "PNEUMONIA": ["pneumonia", "consolidation", "airspace disease"],
    "EDEMA": ["pulmonary edema", "vascular congestion"],
    "EFFUSION": ["pleural effusion"],
    "CARDIOMEGALY": ["cardiomegaly", "enlarged heart"],
    "NORMAL": [
        "no acute cardiopulmonary",
        "no acute abnormality",
        "no acute disease",
        "normal chest",
        "unremarkable"
    ]
}

# Order in which labels are checked: the first matching label wins, so
# pathological findings take precedence over "NORMAL" phrasing.
PRIORITY = [
    "PNEUMOTHORAX",
    "PNEUMONIA",
    "EDEMA",
    "EFFUSION",
    "CARDIOMEGALY",
    "NORMAL"
]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def assign_label(impression: str, keyword_map=None, priority=None) -> str:
    """Assign a single weak diagnosis label to a report impression.

    Labels are checked in priority order; the first label whose keyword
    appears as a substring of the normalized text wins.

    Args:
        impression: Free-text impression section of a radiology report.
        keyword_map: Optional {label: [keywords]} mapping; defaults to the
            module-level LABEL_KEYWORDS.
        priority: Optional ordered list of labels to try; defaults to the
            module-level PRIORITY.

    Returns:
        The first matching label, or "OTHER" when the input is not a string
        or no keyword matches.
    """
    if not isinstance(impression, str):
        return "OTHER"

    if keyword_map is None:
        keyword_map = LABEL_KEYWORDS
    if priority is None:
        priority = PRIORITY

    text = impression.lower()
    # Strip punctuation, then collapse runs of whitespace. Without the
    # collapse, "enlarged, heart" becomes "enlarged  heart" (double space)
    # and multi-word keywords such as "enlarged heart" never match.
    text = re.sub(r"[^\w\s]", " ", text)
    text = " ".join(text.split())

    for label in priority:
        for kw in keyword_map[label]:
            if kw in text:
                return label

    return "OTHER"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# -----------------------------
|
| 63 |
+
# MAIN PIPELINE
|
| 64 |
+
# -----------------------------
|
| 65 |
+
def main():
    """Download the HF dataset, dump images to disk, and write patient_index.csv.

    Each usable sample (image + findings + impression all present) is saved as
    data/raw/images/img_<idx>.png and recorded with its weak label in the CSV.
    """
    print("📥 Loading Hugging Face dataset...")
    ds = load_dataset(DATASET_NAME, split=SPLIT)

    rows = []
    for idx, sample in enumerate(ds):
        img = sample["image"]
        findings = sample["findings"]
        impression = sample["impression"]

        # Skip incomplete samples entirely.
        if img is None or findings is None or impression is None:
            continue

        # Persist the image locally (needed by the PyTorch Dataset later);
        # don't re-save on repeated runs.
        out_path = IMAGE_DIR / f"img_{idx}.png"
        if not out_path.exists():
            img.save(out_path)

        rows.append({
            "image_path": str(out_path),
            "findings": findings,
            "impression": impression,
            "label": assign_label(impression),
        })

        if idx % 1000 == 0:
            print(f"Processed {idx} samples...")

    index = pd.DataFrame(rows)
    index.to_csv(OUTPUT_CSV, index=False)

    print("\n✅ patient_index.csv created")
    print(index["label"].value_counts())


if __name__ == "__main__":
    main()
|
requirements.txt
CHANGED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
altair
|
| 2 |
-
pandas
|
| 3 |
-
streamlit
|
|
|
|
|
|
|
|
|
|
|
|