Krishwall committed on
Commit
f52734b
·
verified ·
1 Parent(s): 4d1d01a

Upload 6 files

Browse files
Files changed (6) hide show
  1. .env.example +0 -0
  2. .gitignore +48 -0
  3. README.md +0 -19
  4. app.py +178 -0
  5. create_patient_index.py +104 -0
  6. requirements.txt +0 -3
.env.example ADDED
File without changes
.gitignore ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Raw image formats
2
+ *.raw
3
+ *.dng
4
+ *.cr2
5
+ *.cr3
6
+ *.nef
7
+ *.arw
8
+ *.rw2
9
+ *.orf
10
+ *.srw
11
+ *.x3f
12
+ *.raf
13
+ *.dcr
14
+ *.k25
15
+ *.kdc
16
+ *.mrw
17
+
18
+ # Medical imaging formats (common in clinical AI)
19
+ *.dcm
20
+ *.dicom
21
+ *.nii
22
+ *.nii.gz
23
+ *.mha
24
+ *.mhd
25
+
26
+ # Compressed raw formats
27
+ *.tiff
28
+ *.tif
29
+
30
+ # Processed image formats that might be large
31
+ *.png
32
+ *.jpg
33
+ *.jpeg
34
+ *.bmp
35
+ *.gif
36
+ *.webp
37
+
38
+ # Model checkpoints
39
+ *.pt
40
+ *.pth
41
+ *.ckpt
42
+ *.model
43
+ *.h5
44
+ *.pb
45
+ *.onnx
46
+
47
+ # Checkpoint directories
48
+ checkpoints/
README.md CHANGED
@@ -1,19 +0,0 @@
1
- ---
2
- title: ChestX-Ray Diagnosis
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: This demo showcases a multimodal deep learning system that c
12
- ---
13
-
14
- # Welcome to Streamlit!
15
-
16
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
17
-
18
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
19
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from PIL import Image
5
+
6
+ from demo.utils.load_model import load_fusion_model
7
+ from demo.utils.grad_cam import GradCAM, overlay_cam
8
+ from demo.utils.saliency import (
9
+ compute_text_saliency,
10
+ merge_wordpieces,
11
+ filter_tokens,
12
+ highlight_text,
13
+ )
14
+
15
+ # --------------------------------------------------
16
+ # Page configuration
17
+ # --------------------------------------------------
18
+ st.set_page_config(
19
+ page_title="Multimodal Clinical AI",
20
+ layout="wide",
21
+ initial_sidebar_state="collapsed"
22
+ )
23
+
24
+ # --------------------------------------------------
25
+ # Header
26
+ # --------------------------------------------------
27
+ st.markdown(
28
+ """
29
+ <h2 style="margin-bottom:0">Multimodal Clinical Decision Support</h2>
30
+ <p style="color:gray; margin-top:4px">
31
+ Chest X-ray + Radiology Text → Ranked Diagnoses with Explainability
32
+ </p>
33
+ """,
34
+ unsafe_allow_html=True
35
+ )
36
+
37
+ st.divider()
38
+
39
# --------------------------------------------------
# Load model (cached)
# --------------------------------------------------
_CHECKPOINT_PATH = "checkpoints/fusion_model/fusion_layer4_tuned.pt"


@st.cache_resource
def load_all():
    """Load the fusion model bundle once; Streamlit reuses it across reruns."""
    bundle = load_fusion_model(_CHECKPOINT_PATH)
    return bundle


model, tokenizer, image_transform, LABELS, device = load_all()
49
+
50
+ # --------------------------------------------------
51
+ # Input Section
52
+ # --------------------------------------------------
53
+ col1, col2 = st.columns(2)
54
+
55
+ with col1:
56
+ st.subheader("Chest X-ray")
57
+ uploaded_image = st.file_uploader(
58
+ "Upload Chest X-ray",
59
+ type=["png", "jpg", "jpeg"],
60
+ label_visibility="collapsed"
61
+ )
62
+
63
+ with col2:
64
+ st.subheader("Radiology Findings")
65
+ findings = st.text_area(
66
+ "Enter findings",
67
+ height=180,
68
+ placeholder="e.g. Enlarged cardiac silhouette with pulmonary congestion...",
69
+ label_visibility="collapsed"
70
+ )
71
+
72
+ st.markdown("<br>", unsafe_allow_html=True)
73
+ analyze = st.button("Analyze Case", use_container_width=True)
74
+ st.markdown("<br>", unsafe_allow_html=True)
75
+
76
# --------------------------------------------------
# Inference + Explainability
# --------------------------------------------------
# Runs only when the button was clicked AND both inputs are present.
if analyze and uploaded_image and findings:

    # ---- Preprocess inputs ----
    # Convert to RGB so grayscale uploads still yield 3 channels.
    image = Image.open(uploaded_image).convert("RGB")
    image_tensor = image_transform(image).unsqueeze(0).to(device)  # add batch dim

    enc = tokenizer(
        findings,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)

    # ---- Forward pass ----
    # No gradients needed for ranking; the Grad-CAM / saliency helpers
    # below run their own gradient-enabled passes internally.
    with torch.no_grad():
        logits = model(image_tensor, input_ids, attention_mask)
        probs = F.softmax(logits, dim=1)

    # Top-2 classes drive the primary/secondary diagnosis panels.
    top2_prob, top2_idx = torch.topk(probs, k=2, dim=1)
    primary_idx = top2_idx[0, 0].item()
    secondary_idx = top2_idx[0, 1].item()

    # --------------------------------------------------
    # Diagnosis Output
    # --------------------------------------------------
    col1, col2 = st.columns(2)

    with col1:
        st.markdown("### 🩺 Primary Diagnosis")
        st.success(
            f"{LABELS[primary_idx]} \nConfidence: {top2_prob[0,0]:.2f}"
        )

    with col2:
        st.markdown("### 🔍 Secondary Diagnosis")
        st.info(
            f"{LABELS[secondary_idx]} \nConfidence: {top2_prob[0,1]:.2f}"
        )

    # --------------------------------------------------
    # Explainability
    # --------------------------------------------------
    st.divider()
    st.markdown("## Explainability")

    col1, col2 = st.columns(2)

    # ---- Grad-CAM ----
    with col1:
        st.markdown("#### Image Evidence (Grad-CAM)")

        # Hook the image encoder's `layer4` feature maps (ResNet-style
        # backbone assumed — confirm against load_fusion_model).
        gradcam = GradCAM(model, model.image_encoder.layer4)
        cam = gradcam.generate(
            image_tensor,
            input_ids,
            attention_mask,
            class_idx=primary_idx
        )

        overlay = overlay_cam(image_tensor, cam)
        # use_container_width replaces the deprecated use_column_width
        # and matches the Analyze button's sizing above.
        st.image(
            overlay,
            use_container_width=True,
            caption="Regions influencing the primary diagnosis"
        )

    # ---- Text Saliency ----
    with col2:
        st.markdown("#### Text Evidence (Important Terms)")

        saliency, attn_mask = compute_text_saliency(
            model,
            input_ids,
            attention_mask,
            target_class=primary_idx
        )

        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

        # Drop padding/special tokens before display.
        tokens, scores = filter_tokens(tokens, saliency, attn_mask)

        # Merge wordpieces back into whole words.
        tokens, scores = merge_wordpieces(tokens, scores)

        # Render saliency-weighted highlighting as inline HTML.
        html_text = highlight_text(tokens, scores)
        st.markdown(html_text, unsafe_allow_html=True)

elif analyze:
    # The button was clicked with one or both inputs missing — tell the
    # user explicitly instead of silently doing nothing.
    st.warning("Please upload a chest X-ray and enter the radiology findings.")

# --------------------------------------------------
# Footer / Disclaimer
# --------------------------------------------------
st.divider()
st.caption(
    "⚠️ For educational and research purposes only. "
    "Not intended for clinical use."
)
create_patient_index.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from pathlib import Path
3
+ import pandas as pd
4
+ import re
5
+ from PIL import Image
6
+
7
+ # -----------------------------
8
+ # CONFIG
9
+ # -----------------------------
10
+ DATASET_NAME = "itsanmolgupta/mimic-cxr-dataset" # change this
11
+ SPLIT = "train"
12
+
13
+ IMAGE_DIR = Path("data/raw/images")
14
+ OUTPUT_CSV = Path("data/metadata/patient_index.csv")
15
+
16
+ IMAGE_DIR.mkdir(parents=True, exist_ok=True)
17
+ OUTPUT_CSV.parent.mkdir(parents=True, exist_ok=True)
18
+
19
# -----------------------------
# LABEL DEFINITIONS
# -----------------------------
# Keyword triggers per diagnostic class, matched as substrings of the
# normalized impression text.
LABEL_KEYWORDS = {
    "PNEUMOTHORAX": ["pneumothorax"],
    "PNEUMONIA": ["pneumonia", "consolidation", "airspace disease"],
    "EDEMA": ["pulmonary edema", "vascular congestion"],
    "EFFUSION": ["pleural effusion"],
    "CARDIOMEGALY": ["cardiomegaly", "enlarged heart"],
    "NORMAL": [
        "no acute cardiopulmonary",
        "no acute abnormality",
        "no acute disease",
        "normal chest",
        "unremarkable",
    ],
}

# Tie-break order: the first class (scanning top to bottom) with a
# keyword hit wins, so pathology outranks NORMAL.
PRIORITY = [
    "PNEUMOTHORAX",
    "PNEUMONIA",
    "EDEMA",
    "EFFUSION",
    "CARDIOMEGALY",
    "NORMAL",
]


def assign_label(impression: str) -> str:
    """Map a free-text impression to a single diagnosis label.

    The impression is lower-cased and punctuation is blanked out, then the
    classes in PRIORITY are scanned in order and the first one with a
    keyword match is returned. Non-string input or an impression matching
    no keyword yields "OTHER".
    """
    if not isinstance(impression, str):
        return "OTHER"

    # Normalize: lower-case, then replace punctuation with spaces so
    # plain substring matching works.
    normalized = re.sub(r"[^\w\s]", " ", impression.lower())

    matched = next(
        (
            label
            for label in PRIORITY
            if any(keyword in normalized for keyword in LABEL_KEYWORDS[label])
        ),
        "OTHER",
    )
    return matched
60
+
61
+
62
# -----------------------------
# MAIN PIPELINE
# -----------------------------
def main():
    """Build the patient index CSV from the Hugging Face dataset.

    Streams the configured split, saves each image under IMAGE_DIR
    (skipping files that already exist, so reruns are cheap), assigns a
    weak keyword label from the impression text, and writes one CSV row
    per usable sample to OUTPUT_CSV.
    """
    print("📥 Loading Hugging Face dataset...")
    dataset = load_dataset(DATASET_NAME, split=SPLIT)

    records = []

    for idx, sample in enumerate(dataset):
        image = sample["image"]
        findings = sample["findings"]
        impression = sample["impression"]

        # Skip incomplete samples — every modality is required downstream.
        if image is None or findings is None or impression is None:
            continue

        # Save image locally (important for PyTorch Dataset later)
        image_path = IMAGE_DIR / f"img_{idx}.png"
        if not image_path.exists():
            image.save(image_path)

        label = assign_label(impression)

        records.append({
            "image_path": str(image_path),
            "findings": findings,
            "impression": impression,
            "label": label
        })

        # Skip idx == 0 so we don't print a misleading
        # "Processed 0 samples..." on the very first iteration.
        if idx and idx % 1000 == 0:
            print(f"Processed {idx} samples...")

    # Guard the empty case: df["label"] on an empty DataFrame would raise
    # KeyError, and writing an empty index is not useful.
    if not records:
        print("\n⚠️ No usable samples found — patient_index.csv not written.")
        return

    df = pd.DataFrame(records)
    df.to_csv(OUTPUT_CSV, index=False)

    print("\n✅ patient_index.csv created")
    print(df["label"].value_counts())
101
+
102
+
103
+ if __name__ == "__main__":
104
+ main()
requirements.txt CHANGED
@@ -1,3 +0,0 @@
1
- altair
2
- pandas
3
- streamlit