DIVYANSH-TEJA-09 committed on
Commit
06e1b21
Β·
0 Parent(s):

Initial commit with only essential weights

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.13.5-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ build-essential \
7
+ curl \
8
+ git \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ COPY requirements.txt ./
12
+ COPY src/ ./src/
13
+
14
+ RUN pip3 install -r requirements.txt
15
+
16
+ EXPOSE 8501
17
+
18
+ HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
+
20
+ ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Brain Tumor AI Suite
3
+ emoji: πŸš€
4
+ colorFrom: red
5
+ colorTo: red
6
+ sdk: docker
7
+ app_port: 8501
8
+ tags:
9
+ - streamlit
10
+ pinned: false
11
+ short_description: Streamlit template space
12
+ ---
13
+
14
+ # Welcome to Streamlit!
15
+
16
+ Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
17
+
18
+ If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
19
+ forums](https://discuss.streamlit.io).
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import streamlit as st
3
+
4
+ st.set_page_config(
5
+ page_title="Brain Tumor AI Suite",
6
+ page_icon="🧠",
7
+ layout="wide",
8
+ )
9
+
10
+ st.markdown('''
11
+ <style>
12
+ .hero-title {
13
+ font-size: 3rem; font-weight: 800;
14
+ background: linear-gradient(135deg, #667eea, #764ba2, #f093fb);
15
+ -webkit-background-clip: text; -webkit-text-fill-color: transparent;
16
+ margin-bottom: 0;
17
+ }
18
+ </style>
19
+ ''', unsafe_allow_html=True)
20
+
21
+ st.markdown('<p class="hero-title">🧠 Brain Tumor AI Suite</p>', unsafe_allow_html=True)
22
+ st.markdown("**Federated Learning Classification & 3D Segmentation**")
23
+ st.markdown("---")
24
+
25
+ col1, col2 = st.columns(2)
26
+
27
+ with col1:
28
+ st.markdown("### πŸ“Š Tumor Classification (Federated Learning)")
29
+ st.markdown("Predict the class of a brain tumor (Glioma, Meningioma, Pituitary, No Tumor) using a SimpleCNN model trained across simulated hospitals with Layer-by-Layer QPSO aggregation.")
30
+ st.page_link("pages/1_Classification.py", label="Open Classification App β†’", icon="πŸ“Š")
31
+
32
+ with col2:
33
+ st.markdown("### πŸ”¬ 3D Tumor Segmentation")
34
+ st.markdown("View MRI slices and 3D volumetric renderings of brain tumors with segmentation overlays (Whole Tumor, Tumor Core, Enhancing Tumor) predicted by a 3D Attention U-Net.")
35
+ st.page_link("pages/2_Slice_Viewer.py", label="Open Slice Viewer β†’", icon="πŸ”¬")
36
+ st.page_link("pages/3_3D_Visualization.py", label="Open 3D Viewer β†’", icon="🌐")
best_metric_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4980fada1cc6fed5b19cd657528845f2a8598b383dcbbba179c1654b6f592c02
3
+ size 23731355
pages/1_Classification.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 🧠 Brain Tumor Classification β€” Federated Learning Demo
3
+ =========================================================
4
+ Demonstrates the FL-trained SimpleCNN models (FedAvg, FedProx, QPSO).
5
+ Users can upload images or use sample test images from the dataset.
6
+ Shows predicted class with confidence bars and compares all 3 models.
7
+ """
8
+
9
+ import streamlit as st
10
+ import os
11
+ import sys
12
+ import glob
13
+ import random
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torchvision.transforms as transforms
18
+ from PIL import Image
19
+ import plotly.graph_objects as go
20
+
21
+
22
+
23
+ # ─── paths ───────────────────────────────────────────────────────────────
24
+ FL_ROOT = os.path.abspath(os.path.dirname(__file__))
25
+ RESULTS = os.path.join(FL_ROOT, "..", "results")
26
+ # Use Setup 1 models by default (best performing)
27
+ MODELS_DIR = os.path.join(RESULTS, "Setup_1", "models")
28
+
29
+ NUM_CLASSES = 4
30
+ CLASS_NAMES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
31
+ CLASS_COLORS = ["#E74C3C", "#3498DB", "#2ECC71", "#9B59B6"]
32
+ CLASS_ICONS = ["πŸ”΄", "πŸ”΅", "🟒", "🟣"]
33
+ IMG_SIZE = 112
34
+
35
+ # ─── model ───────────────────────────────────────────────────────────────
36
+ class SimpleCNN(nn.Module):
37
+ def __init__(self, num_classes=4):
38
+ super().__init__()
39
+ self.features = nn.Sequential(
40
+ nn.Conv2d(3, 16, kernel_size=3, padding=1),
41
+ nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
42
+ nn.Conv2d(16, 32, kernel_size=3, padding=1),
43
+ nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
44
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
45
+ nn.BatchNorm2d(64), nn.ReLU(),
46
+ nn.AdaptiveAvgPool2d(4),
47
+ )
48
+ self.classifier = nn.Sequential(
49
+ nn.Dropout(0.3),
50
+ nn.Linear(64 * 4 * 4, 128), nn.ReLU(),
51
+ nn.Dropout(0.3),
52
+ nn.Linear(128, num_classes),
53
+ )
54
+
55
+ def forward(self, x):
56
+ x = self.features(x)
57
+ x = x.view(x.size(0), -1)
58
+ return self.classifier(x)
59
+
60
+
61
+ TRANSFORM = transforms.Compose([
62
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
63
+ transforms.ToTensor(),
64
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
65
+ ])
66
+
67
+
68
@st.cache_resource
def load_model(path):
    """Load a SimpleCNN checkpoint and return (model, device).

    Picks CUDA when available, otherwise CPU. Cached by Streamlit so each
    checkpoint file is deserialized at most once per session. The returned
    model is in eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = SimpleCNN(NUM_CLASSES).to(device)
    # weights_only=True restricts unpickling to tensors/state-dicts.
    weights = torch.load(path, map_location=device, weights_only=True)
    net.load_state_dict(weights)
    net.eval()
    return net, device
76
+
77
+
78
def predict(model, device, image):
    """Run inference on a PIL Image. Returns (class_idx, probabilities)."""
    batch = TRANSFORM(image).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1).cpu().numpy()[0]
    top_class = int(np.argmax(probs))
    return top_class, probs
85
+
86
+
87
def render_prediction_card(title, color_accent, pred_idx, probs, image):
    """Render a styled prediction result.

    Draws an HTML card (model name, predicted class with icon, confidence %)
    followed by a horizontal Plotly bar chart of all class probabilities.

    Args:
        title: Heading shown on the card (typically the model name).
        color_accent: CSS color used for the card's top border and heading.
        pred_idx: Index into CLASS_NAMES of the predicted class.
        probs: Per-class probabilities (values in [0, 1], same order as CLASS_NAMES).
        image: NOTE(review) — accepted but never used in this function;
            confirm whether callers still need to pass it.
    """
    confidence = probs[pred_idx] * 100
    st.markdown(
        f'<div style="background:rgba(20,20,35,0.9);padding:20px;border-radius:12px;'
        f'border-top:4px solid {color_accent};margin-bottom:16px;">'
        f'<h3 style="color:{color_accent};margin:0 0 8px 0;">{title}</h3>'
        f'<div style="color:white;font-size:28px;font-weight:800;margin:4px 0;">'
        f'{CLASS_ICONS[pred_idx]} {CLASS_NAMES[pred_idx]}</div>'
        f'<div style="color:#aaa;font-size:14px;">Confidence: {confidence:.1f}%</div>'
        f'</div>',
        unsafe_allow_html=True,
    )

    # Probability bar chart: one horizontal bar per class, colored to match
    # the class legend, labeled with the percentage.
    fig = go.Figure(go.Bar(
        x=probs * 100,
        y=CLASS_NAMES,
        orientation='h',
        marker_color=CLASS_COLORS,
        text=[f"{p*100:.1f}%" for p in probs],
        textposition='auto',
    ))
    fig.update_layout(
        height=200,
        margin=dict(l=0, r=10, t=10, b=10),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(20,20,35,0.5)",
        xaxis=dict(range=[0, 100], title="Probability (%)", color="#aaa",
                   gridcolor="rgba(100,100,140,0.2)"),
        yaxis=dict(color="white"),
        font=dict(color="white"),
    )
    st.plotly_chart(fig, use_container_width=True)
121
+
122
+
123
+ # ─── CSS ─────────────────────────────────────────────────────────────────
124
+ st.markdown("""
125
+ <style>
126
+ [data-testid="stSidebar"] {
127
+ background: linear-gradient(180deg, #0f0f1a 0%, #1a1a2e 100%);
128
+ }
129
+ .hero-title {
130
+ font-size: 2.4rem; font-weight: 800;
131
+ background: linear-gradient(135deg, #667eea, #764ba2, #f093fb);
132
+ -webkit-background-clip: text; -webkit-text-fill-color: transparent;
133
+ }
134
+ </style>
135
+ """, unsafe_allow_html=True)
136
+
137
+ # ─── sidebar ─────────────────────────────────────────────────────────────
138
+ st.sidebar.title("βš™οΈ Classification Controls")
139
+
140
+ # Setup selector
141
+ setup = st.sidebar.radio("Experiment Setup", ["Setup 1 (Natural)", "Setup 2 (Label Skew)"])
142
+ setup_dir = "Setup_1" if "1" in setup else "Setup_2"
143
+ models_dir = os.path.join(RESULTS, setup_dir, "models")
144
+
145
+ # Check available models
146
+ model_files = {}
147
+ for name, fname in [("FedAvg", "fedavg_best.pth"),
148
+ ("FedProx", "fedprox_best.pth"),
149
+ ("QPSO-FL", "qpso_best.pth")]:
150
+ path = os.path.join(models_dir, fname)
151
+ if os.path.exists(path):
152
+ model_files[name] = path
153
+
154
+ if not model_files:
155
+ st.error(f"No model weights found in `{models_dir}`. Please ensure .pth files are present.")
156
+ st.stop()
157
+
158
+ selected_models = st.sidebar.multiselect(
159
+ "Compare Models",
160
+ list(model_files.keys()),
161
+ default=list(model_files.keys()),
162
+ )
163
+
164
+ st.sidebar.markdown("---")
165
+ input_mode = st.sidebar.radio("Image Source", ["Upload Image", "Sample from Dataset"])
166
+
167
+ # ─── title ───────────────────────────────────────────────────────────────
168
+ st.markdown('<p class="hero-title">🧠 Brain Tumor Classification</p>',
169
+ unsafe_allow_html=True)
170
+ st.markdown("**Federated Learning Β· SimpleCNN (~120K params) Β· "
171
+ "FedAvg vs FedProx vs QPSO-FL**")
172
+ st.markdown("---")
173
+
174
+ # ─── image input ─────────────────────────────────────────────────────────
175
+ image = None
176
+
177
+ if input_mode == "Upload Image":
178
+ uploaded = st.file_uploader(
179
+ "Upload a brain MRI image (JPG/PNG)",
180
+ type=["jpg", "jpeg", "png"],
181
+ )
182
+ if uploaded:
183
+ image = Image.open(uploaded).convert("RGB")
184
+
185
+ elif input_mode == "Sample from Dataset":
186
+ # Try to find sample images from the results plots (confusion matrices show sample data)
187
+ # Or look for actual dataset images
188
+ sample_dirs = [
189
+ os.path.join(FL_ROOT, "data"),
190
+ os.path.join(FL_ROOT, "sample_images"),
191
+ ]
192
+
193
+ # Search for any sample images
194
+ sample_images = []
195
+ for d in sample_dirs:
196
+ if os.path.isdir(d):
197
+ for ext in ["*.jpg", "*.jpeg", "*.png"]:
198
+ sample_images.extend(glob.glob(os.path.join(d, "**", ext), recursive=True))
199
+
200
+ if sample_images:
201
+ # Group by class if possible
202
+ selected = st.selectbox("Select a sample image", sample_images,
203
+ format_func=lambda x: os.path.basename(x))
204
+ image = Image.open(selected).convert("RGB")
205
+ else:
206
+ st.info(
207
+ "πŸ’‘ No sample images found locally. You can either:\n"
208
+ "1. **Upload an image** using the sidebar option\n"
209
+ "2. **Add sample images** to `federated_learning/sample_images/` "
210
+ "(one subfolder per class: glioma/, meningioma/, notumor/, pituitary/)"
211
+ )
212
+
213
+ # ─── inference ───────────────────────────────────────────────────────────
214
+ if image is not None:
215
+ # Show the input image
216
+ st.subheader("πŸ“· Input Image")
217
+ col_img, col_info = st.columns([1, 2])
218
+ with col_img:
219
+ st.image(image, caption="Input MRI", use_container_width=True)
220
+ with col_info:
221
+ w, h = image.size
222
+ st.markdown(f"**Resolution:** {w}Γ—{h} β†’ resized to {IMG_SIZE}Γ—{IMG_SIZE}")
223
+ st.markdown(f"**Models:** {', '.join(selected_models)}")
224
+ st.markdown(f"**Setup:** {setup}")
225
+
226
+ st.markdown("---")
227
+ st.subheader("πŸ”¬ Classification Results")
228
+
229
+ if not selected_models:
230
+ st.warning("Select at least one model from the sidebar.")
231
+ else:
232
+ # Run all selected models
233
+ cols = st.columns(len(selected_models))
234
+ model_colors = {"FedAvg": "#1f77b4", "FedProx": "#ff7f0e", "QPSO-FL": "#2ca02c"}
235
+
236
+ results = {}
237
+ for idx, name in enumerate(selected_models):
238
+ with cols[idx]:
239
+ model, device = load_model(model_files[name])
240
+ pred_idx, probs = predict(model, device, image)
241
+ results[name] = (pred_idx, probs)
242
+ render_prediction_card(
243
+ name,
244
+ model_colors.get(name, "#666"),
245
+ pred_idx, probs, image,
246
+ )
247
+
248
+ # Consensus section
249
+ if len(results) > 1:
250
+ st.markdown("---")
251
+ st.subheader("🀝 Model Consensus")
252
+
253
+ predictions = [CLASS_NAMES[r[0]] for r in results.values()]
254
+ unanimous = len(set(predictions)) == 1
255
+
256
+ if unanimous:
257
+ st.success(
258
+ f"βœ… **All {len(results)} models agree:** "
259
+ f"{CLASS_ICONS[list(results.values())[0][0]]} "
260
+ f"**{predictions[0]}**"
261
+ )
262
+ else:
263
+ # Majority vote
264
+ from collections import Counter
265
+ votes = Counter(predictions)
266
+ winner, count = votes.most_common(1)[0]
267
+ winner_idx = CLASS_NAMES.index(winner)
268
+ st.warning(
269
+ f"⚠️ **Models disagree.** Majority vote ({count}/{len(results)}): "
270
+ f"{CLASS_ICONS[winner_idx]} **{winner}**"
271
+ )
272
+
273
+ # Show disagreement details
274
+ for name, (pred_idx, probs) in results.items():
275
+ emoji = "βœ…" if CLASS_NAMES[pred_idx] == winner else "❌"
276
+ st.markdown(
277
+ f" {emoji} **{name}:** {CLASS_NAMES[pred_idx]} "
278
+ f"({probs[pred_idx]*100:.1f}%)"
279
+ )
280
+
281
+ # Average confidence across models
282
+ if len(results) > 1:
283
+ avg_probs = np.mean([r[1] for r in results.values()], axis=0)
284
+ ensemble_pred = int(np.argmax(avg_probs))
285
+
286
+ st.markdown("---")
287
+ st.subheader("πŸ“Š Ensemble Average (All Models)")
288
+ render_prediction_card(
289
+ "Ensemble Average",
290
+ "#E91E63",
291
+ ensemble_pred, avg_probs, image,
292
+ )
293
+
294
+ else:
295
+ # Welcome state
296
+ st.markdown("""
297
+ ### How it works
298
+ 1. **Choose a setup** β€” Natural heterogeneity (Setup 1) or Label Skew (Setup 2)
299
+ 2. **Select models** β€” Compare FedAvg, FedProx, and QPSO-FL side by side
300
+ 3. **Upload or select an image** β€” Any brain MRI (axial slice)
301
+ 4. **See results** β€” Class prediction with confidence bars for each model
302
+
303
+ The models were trained using **federated learning** across 3 simulated hospitals,
304
+ each with different data distributions. The QPSO-FL model uses our novel
305
+ **Layer-by-Layer QPSO aggregation** for fairer global model performance.
306
+ """)
307
+
308
+ # Show model info cards
309
+ st.markdown("---")
310
+ info_cols = st.columns(3)
311
+ model_info = [
312
+ ("FedAvg", "#1f77b4", "Weighted average of client updates. Standard baseline."),
313
+ ("FedProx", "#ff7f0e", "Adds proximal regularization (ΞΌ=0.01) to prevent client drift."),
314
+ ("QPSO-FL", "#2ca02c", "Layer-by-layer quantum PSO with validation-loss fitness. Our contribution."),
315
+ ]
316
+ for col, (name, color, desc) in zip(info_cols, model_info):
317
+ with col:
318
+ st.markdown(
319
+ f'<div style="background:rgba(20,20,35,0.9);padding:20px;border-radius:12px;'
320
+ f'border-top:3px solid {color};">'
321
+ f'<h4 style="color:{color};margin:0 0 8px 0;">{name}</h4>'
322
+ f'<p style="color:#aaa;font-size:13px;margin:0;">{desc}</p></div>',
323
+ unsafe_allow_html=True,
324
+ )
pages/2_Slice_Viewer.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Slice-by-Slice Segmentation Viewer
3
+ ====================================
4
+ View MRI slices with ground truth and AI prediction overlay.
5
+ Supports all 4 modalities and 3 tumor sub-regions.
6
+ """
7
+
8
+ import streamlit as st
9
+ import os
10
+ import sys
11
+ import glob
12
+ import nibabel as nib
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ import matplotlib.colors as mcolors
16
+
17
+ # st.set_page_config(page_title="Slice Viewer", layout="wide")
18
+
19
+ # ─── paths & inference ───────────────────────────────────────────────────
20
+ APP_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
21
+ if APP_DIR not in sys.path:
22
+ sys.path.insert(0, APP_DIR)
23
+ from utils.inference import ensure_prediction, get_all_patients, DEMO_DIR
24
+
25
+
26
def load_nifti(path):
    """Load a NIfTI volume as an ndarray, or return None if the file is absent."""
    if os.path.exists(path):
        return nib.load(path).get_fdata()
    return None
30
+
31
+
32
+ # ─── sidebar ─────────────────────────────────────────────────────────────
33
+ st.sidebar.title("βš™οΈ Slice Viewer Controls")
34
+
35
+ samples = get_all_patients()
36
+ if not samples:
37
+ st.error("No demo data found. Please ensure patient volumes exist in demo_data/")
38
+ st.stop()
39
+
40
+ selected_id = st.sidebar.selectbox("πŸ§‘β€βš•οΈ Patient", samples)
41
+
42
+ modality = st.sidebar.selectbox(
43
+ "MRI Modality",
44
+ ["FLAIR", "T1", "T1ce", "T2"],
45
+ index=0,
46
+ )
47
+ MOD_MAP = {"T1": 0, "T1ce": 1, "T2": 2, "FLAIR": 3}
48
+
49
+ overlay = st.sidebar.radio(
50
+ "Overlay",
51
+ ["AI Prediction", "Ground Truth", "Both (Side-by-Side)", "None"],
52
+ index=0,
53
+ )
54
+
55
+ overlay_alpha = st.sidebar.slider("Overlay Opacity", 0.1, 0.9, 0.5, 0.05)
56
+
57
+ # ─── load data (run inference if needed) ─────────────────────────────────
58
+ ensure_prediction(selected_id)
59
+
60
+ img_path = os.path.join(DEMO_DIR, f"{selected_id}_image.nii.gz")
61
+ pred_path = os.path.join(DEMO_DIR, f"{selected_id}_pred.nii.gz")
62
+ lbl_path = os.path.join(DEMO_DIR, f"{selected_id}_label.nii.gz")
63
+
64
+ img_data = load_nifti(img_path) # (D, H, W, 4)
65
+ pred_data = load_nifti(pred_path) # (D, H, W, 3) channels: 0=TC, 1=WT, 2=ET
66
+ lbl_data = load_nifti(lbl_path) # (D, H, W) labels: 1=NCR, 2=ED, 4=ET
67
+
68
+ if img_data is None:
69
+ st.error("Failed to load MRI volume.")
70
+ st.stop()
71
+
72
+ depth = img_data.shape[0]
73
+
74
+ # ─── title ───────────────────────────────────────────────────────────────
75
+ st.title("πŸ”¬ Slice-by-Slice Segmentation Viewer")
76
+ st.markdown(f"**Patient:** `{selected_id}` Β· **Modality:** {modality} Β· "
77
+ f"**Volume:** {img_data.shape[0]}Γ—{img_data.shape[1]}Γ—{img_data.shape[2]}")
78
+
79
+ slice_idx = st.slider("Z-Axis Slice", 0, depth - 1, depth // 2)
80
+
81
+ # ─── color maps ──────────────────────────────────────────────────────────
82
+ # Tumor overlay: WT=green, TC=red, ET=yellow (matching 3D view)
83
+ TUMOR_COLORS = np.array([
84
+ [0, 0, 0, 0], # background (transparent)
85
+ [0.18, 0.80, 0.44, 1], # WT - green
86
+ [0.91, 0.30, 0.24, 1], # TC - red
87
+ [0.95, 0.77, 0.06, 1], # ET - gold
88
+ ])
89
+
90
+
91
def make_overlay_from_pred(pred_slice):
    """Convert (H, W, 3) prediction channels to (H, W, 4) RGBA overlay.

    Prediction channels are ordered 0=TC, 1=WT, 2=ET. Regions are painted
    WT first (largest), then TC, then ET, so inner regions stay visible.
    """
    rgba = np.zeros(pred_slice.shape[:2] + (4,))
    layers = (
        (pred_slice[:, :, 1] > 0.5, TUMOR_COLORS[1]),  # WT - bottom layer
        (pred_slice[:, :, 0] > 0.5, TUMOR_COLORS[2]),  # TC
        (pred_slice[:, :, 2] > 0.5, TUMOR_COLORS[3]),  # ET - top layer
    )
    for mask, color in layers:
        rgba[mask] = color
    return rgba
103
+
104
+
105
def make_overlay_from_gt(lbl_slice):
    """Convert integer label slice to (H, W, 4) RGBA overlay.

    BraTS labels: 1=NCR, 2=ED, 4=ET. Derived regions: WT = any label > 0,
    TC = labels {1, 4}, ET = label 4; painted in that order so ET is on top.
    """
    rgba = np.zeros(lbl_slice.shape + (4,))
    rgba[lbl_slice > 0] = TUMOR_COLORS[1]
    rgba[(lbl_slice == 1) | (lbl_slice == 4)] = TUMOR_COLORS[2]
    rgba[lbl_slice == 4] = TUMOR_COLORS[3]
    return rgba
116
+
117
+
118
def render_slice(mri_slice, overlay_img, title, alpha):
    """Render an MRI slice with optional overlay.

    Args:
        mri_slice: 2-D grayscale array displayed with origin="lower".
        overlay_img: RGBA overlay array blended on top, or None for no overlay.
        title: Figure title.
        alpha: Overlay opacity in [0, 1]; ignored when overlay_img is None.

    Returns:
        The matplotlib Figure (dark-themed background for the Streamlit UI).
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(mri_slice, cmap="gray", origin="lower")
    if overlay_img is not None:
        ax.imshow(overlay_img, alpha=alpha, origin="lower")
    ax.set_title(title, fontsize=14, color="white", fontweight="bold")
    ax.axis("off")
    fig.patch.set_facecolor("#0a0a14")
    return fig
128
+
129
+
130
+ # ─── render ──────────────────────────────────────────────────────────────
131
+ mri_slice = img_data[slice_idx, :, :, MOD_MAP[modality]]
132
+
133
+ if overlay == "None":
134
+ fig = render_slice(mri_slice, None, f"{modality} β€” Slice {slice_idx}", 0)
135
+ st.pyplot(fig)
136
+
137
+ elif overlay == "AI Prediction":
138
+ if pred_data is not None:
139
+ ov = make_overlay_from_pred(pred_data[slice_idx])
140
+ fig = render_slice(mri_slice, ov, f"AI Prediction β€” Slice {slice_idx}",
141
+ overlay_alpha)
142
+ st.pyplot(fig)
143
+ else:
144
+ st.warning("Prediction not available for this patient.")
145
+
146
+ elif overlay == "Ground Truth":
147
+ if lbl_data is not None:
148
+ ov = make_overlay_from_gt(lbl_data[slice_idx])
149
+ fig = render_slice(mri_slice, ov, f"Ground Truth β€” Slice {slice_idx}",
150
+ overlay_alpha)
151
+ st.pyplot(fig)
152
+ else:
153
+ st.warning("Ground truth not available.")
154
+
155
+ elif overlay == "Both (Side-by-Side)":
156
+ col1, col2 = st.columns(2)
157
+ with col1:
158
+ if lbl_data is not None:
159
+ ov = make_overlay_from_gt(lbl_data[slice_idx])
160
+ fig = render_slice(mri_slice, ov, "Ground Truth", overlay_alpha)
161
+ st.pyplot(fig)
162
+ else:
163
+ st.info("No ground truth available.")
164
+ with col2:
165
+ if pred_data is not None:
166
+ ov = make_overlay_from_pred(pred_data[slice_idx])
167
+ fig = render_slice(mri_slice, ov, "AI Prediction", overlay_alpha)
168
+ st.pyplot(fig)
169
+ else:
170
+ st.info("No prediction available.")
171
+
172
+ # ─── all modalities row ──────────────────────────────────────────────────
173
+ with st.expander("πŸ“‹ All 4 Modalities (this slice)", expanded=False):
174
+ cols = st.columns(4)
175
+ mod_names = ["T1", "T1ce", "T2", "FLAIR"]
176
+ for i, col in enumerate(cols):
177
+ with col:
178
+ fig, ax = plt.subplots(figsize=(3, 3))
179
+ ax.imshow(img_data[slice_idx, :, :, i], cmap="gray", origin="lower")
180
+ ax.set_title(mod_names[i], fontsize=11, color="white")
181
+ ax.axis("off")
182
+ fig.patch.set_facecolor("#0a0a14")
183
+ st.pyplot(fig)
184
+
185
+ # ─── color legend ────────────────────────────────────────────────────────
186
+ st.markdown("---")
187
+ legend_cols = st.columns(3)
188
+ labels = ["Whole Tumor (WT)", "Tumor Core (TC)", "Enhancing Tumor (ET)"]
189
+ colors = ["#2ECC71", "#E74C3C", "#F1C40F"]
190
+ for i, col in enumerate(legend_cols):
191
+ with col:
192
+ st.markdown(
193
+ f'<div style="display:flex;align-items:center;gap:8px;">'
194
+ f'<div style="width:16px;height:16px;background:{colors[i]};'
195
+ f'border-radius:3px;"></div>'
196
+ f'<span>{labels[i]}</span></div>',
197
+ unsafe_allow_html=True,
198
+ )
pages/3_3D_Visualization.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 3D Interactive Brain Tumor Visualization
3
+ =========================================
4
+ Renders the brain + tumor regions as interactive 3D surfaces.
5
+ Supports side-by-side Ground Truth vs AI Prediction comparison.
6
+ Runs live inference if prediction doesn't exist yet.
7
+ """
8
+
9
+ import streamlit as st
10
+ import os
11
+ import sys
12
+ import glob
13
+ import nibabel as nib
14
+ import numpy as np
15
+ import plotly.graph_objects as go
16
+ from plotly.subplots import make_subplots
17
+ from skimage.measure import marching_cubes
18
+
19
+ # st.set_page_config(page_title="3D Tumor Visualization", layout="wide")
20
+
21
+ # ─── paths & inference ───────────────────────────────────────────────────
22
+ APP_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
23
+ if APP_DIR not in sys.path:
24
+ sys.path.insert(0, APP_DIR)
25
+ from utils.inference import ensure_prediction, get_all_patients, DEMO_DIR
26
+
27
+
28
def load_nifti(path):
    """Return the NIfTI volume at *path* as an ndarray (None if missing)."""
    return nib.load(path).get_fdata() if os.path.exists(path) else None
32
+
33
+
34
+
35
def extract_mesh(volume, level=0.5, step_size=2):
    """Downsample a volume and extract an isosurface via marching cubes.

    Returns (verts, faces) with vertex coordinates rescaled back to the
    original grid, or None when the downsampled volume is empty or meshing
    fails (e.g. degenerate/too-small input).
    """
    sub = volume[::step_size, ::step_size, ::step_size]
    if not sub.sum():
        return None
    try:
        verts, faces, _, _ = marching_cubes(sub, level=level)
    except Exception:
        return None
    return verts * step_size, faces
45
+
46
+
47
def make_mesh_trace(volume, color, name, opacity, step_size, level=0.5,
                    flatshading=True, scene="scene"):
    """Build a Plotly Mesh3d trace for a binary volume, or None if empty."""
    mesh = extract_mesh(volume, level=level, step_size=step_size)
    if mesh is None:
        return None
    verts, faces = mesh
    return go.Mesh3d(
        x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
        i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
        color=color, opacity=opacity,
        name=name, showlegend=True,
        flatshading=flatshading,
        # Mostly-ambient lighting keeps overlapping translucent surfaces readable.
        lighting=dict(ambient=0.6, diffuse=0.7, specular=0.2, roughness=0.6),
        lightposition=dict(x=100, y=200, z=300),
        scene=scene,
    )
64
+
65
+
66
+ # ─── colors ──────────────────────────────────────────────────────────────
67
+ PRED_COLORS = {
68
+ "Whole Tumor (WT)": "#2ECC71", # emerald green
69
+ "Tumor Core (TC)": "#E74C3C", # vivid red
70
+ "Enhancing Tumor (ET)": "#F1C40F", # bright gold
71
+ }
72
+ GT_COLORS = {
73
+ "Whole Tumor (WT)": "#1ABC9C", # turquoise
74
+ "Tumor Core (TC)": "#9B59B6", # amethyst
75
+ "Enhancing Tumor (ET)": "#3498DB", # ocean blue
76
+ }
77
+ PRED_CHANNELS = {"Whole Tumor (WT)": 1, "Tumor Core (TC)": 0, "Enhancing Tumor (ET)": 2}
78
+ BRAIN_COLOR = "#D5D8DC"
79
+
80
+ # ─── sidebar ─────────────────────────────────────────────────────────────
81
+ st.sidebar.title("βš™οΈ 3D Controls")
82
+
83
+ samples = get_all_patients()
84
+ if not samples:
85
+ st.error("No processed prediction volumes found.")
86
+ st.stop()
87
+
88
+ selected_id = st.sidebar.selectbox("πŸ§‘β€βš•οΈ Patient", samples)
89
+
90
+ st.sidebar.markdown("---")
91
+ view_mode = st.sidebar.radio(
92
+ "View Mode",
93
+ ["Prediction Only", "Ground Truth Only", "Side-by-Side Comparison"],
94
+ index=0,
95
+ )
96
+
97
+ st.sidebar.markdown("---")
98
+ st.sidebar.subheader("🧠 Brain Surface")
99
+ show_brain = st.sidebar.checkbox("Show Brain", value=True)
100
+ brain_opacity = st.sidebar.slider("Brain Opacity", 0.02, 0.30, 0.08, 0.02)
101
+
102
+ st.sidebar.subheader("🎯 Tumor")
103
+ region_choice = st.sidebar.multiselect(
104
+ "Regions",
105
+ ["Whole Tumor (WT)", "Tumor Core (TC)", "Enhancing Tumor (ET)"],
106
+ default=["Whole Tumor (WT)", "Tumor Core (TC)", "Enhancing Tumor (ET)"],
107
+ )
108
+ tumor_opacity = st.sidebar.slider("Tumor Opacity", 0.20, 1.0, 0.70, 0.05)
109
+
110
+ st.sidebar.markdown("---")
111
+ step_size = st.sidebar.select_slider("Mesh Quality", options=[1, 2, 3, 4], value=2)
112
+
113
+ # ─── load data (run inference if needed) ─────────────────────────────────
114
+ ensure_prediction(selected_id)
115
+
116
+ img_data = load_nifti(os.path.join(DEMO_DIR, f"{selected_id}_image.nii.gz"))
117
+ pred_data = load_nifti(os.path.join(DEMO_DIR, f"{selected_id}_pred.nii.gz"))
118
+ lbl_data = load_nifti(os.path.join(DEMO_DIR, f"{selected_id}_label.nii.gz"))
119
+
120
+ # ─── title ───────────────────────────────────────────────────────────────
121
+ st.title("🌐 3D Brain Tumor Visualization")
122
+ st.markdown(f"**Patient:** `{selected_id}` Β· **Drag** to rotate Β· **Scroll** to zoom")
123
+
124
+ # ─── color legend ────────────────────────────────────────────────────────
125
+ legend_cols = st.columns(6)
126
+ for idx, (region, color) in enumerate(PRED_COLORS.items()):
127
+ with legend_cols[idx]:
128
+ st.markdown(
129
+ f'<div style="display:flex;align-items:center;gap:6px;">'
130
+ f'<div style="width:14px;height:14px;background:{color};'
131
+ f'border-radius:3px;"></div><span style="font-size:13px;">Pred: {region}</span></div>',
132
+ unsafe_allow_html=True,
133
+ )
134
+ if view_mode in ["Ground Truth Only", "Side-by-Side Comparison"]:
135
+ for idx, (region, color) in enumerate(GT_COLORS.items()):
136
+ with legend_cols[idx + 3]:
137
+ st.markdown(
138
+ f'<div style="display:flex;align-items:center;gap:6px;">'
139
+ f'<div style="width:14px;height:14px;background:{color};'
140
+ f'border-radius:3px;"></div><span style="font-size:13px;">GT: {region}</span></div>',
141
+ unsafe_allow_html=True,
142
+ )
143
+
144
+
145
+ # ─── helper: build brain trace ───────────────────────────────────────────
146
def build_brain_trace(scene_name="scene"):
    """Build a translucent brain-surface trace from the FLAIR channel.

    Returns None when no image volume is loaded or the mask yields no mesh.
    """
    if img_data is None:
        return None
    flair = img_data[:, :, :, 3]
    # Min-max normalize (epsilon guards flat volumes), then threshold for a
    # rough brain mask.
    span = flair.max() - flair.min() + 1e-8
    brain_mask = ((flair - flair.min()) / span > 0.15).astype(float)
    return make_mesh_trace(
        brain_mask, BRAIN_COLOR, "Brain",
        opacity=brain_opacity,
        step_size=max(step_size, 2),  # brain surface never needs full resolution
        flatshading=False, scene=scene_name,
    )
158
+
159
+
160
def build_pred_traces(scene_name="scene"):
    """Build one Mesh3d trace per selected region from the AI prediction."""
    if pred_data is None:
        return []
    traces = []
    for region in region_choice:
        channel = PRED_CHANNELS[region]
        trace = make_mesh_trace(
            pred_data[:, :, :, channel],
            PRED_COLORS[region], f"Pred: {region}",
            tumor_opacity, step_size, scene=scene_name,
        )
        if trace:
            traces.append(trace)
    return traces
172
+
173
+
174
def build_gt_traces(scene_name="scene"):
    """Return one mesh trace per selected region, built from the GT labels.

    Region masks are derived from the integer label volume; the 1/4 label
    values appear to follow the BraTS convention — confirm against the
    dataset docs.
    """
    if lbl_data is None:
        return []
    gt_masks = {
        "Whole Tumor (WT)": (lbl_data > 0).astype(float),
        "Tumor Core (TC)": ((lbl_data == 1) | (lbl_data == 4)).astype(float),
        "Enhancing Tumor (ET)": (lbl_data == 4).astype(float),
    }
    traces = []
    for region in region_choice:
        mesh = make_mesh_trace(
            gt_masks[region],
            GT_COLORS[region],
            f"GT: {region}",
            tumor_opacity,
            step_size,
            scene=scene_name,
        )
        if mesh:
            traces.append(mesh)
    return traces
190
+
191
+
192
# Shared Plotly 3D scene settings: hidden axes, dark background, aspect
# ratio true to the voxel data, and a fixed initial camera position.
SCENE_LAYOUT = dict(
    xaxis=dict(visible=False),
    yaxis=dict(visible=False),
    zaxis=dict(visible=False),
    bgcolor="rgb(10, 10, 20)",
    aspectmode="data",
    camera=dict(eye=dict(x=1.6, y=1.0, z=0.8), up=dict(x=0, y=0, z=1)),
)
200
+
201
# ─── render ──────────────────────────────────────────────────────────────

def _render_figure(traces, height, legend_size, bordered=False):
    """Build and draw one dark-themed 3D figure from *traces*.

    Factored out of the three near-identical figure-layout blocks the
    original repeated for GT, prediction, and single view.
    """
    legend = dict(
        font=dict(color="white", size=legend_size),
        bgcolor="rgba(20,20,40,0.85)" if bordered else "rgba(20,20,40,0.8)",
        x=0.01,
        y=0.99,
    )
    if bordered:
        legend["bordercolor"] = "rgba(100,100,140,0.5)"
        legend["borderwidth"] = 1
    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=SCENE_LAYOUT,
        margin=dict(l=0, r=0, t=0, b=0),
        height=height,
        paper_bgcolor="rgb(10, 10, 20)",
        legend=legend,
    )
    st.plotly_chart(fig, width="stretch")


if view_mode == "Side-by-Side Comparison":
    # Two 3D plots: GT on the left, prediction on the right.
    col_left, col_right = st.columns(2)

    with col_left:
        # Mojibake repaired: header emoji rendered as "🟒" in the dump.
        st.markdown("### 🟢 Ground Truth")
        gt_traces = []
        if show_brain:
            bt = build_brain_trace("scene")
            if bt:
                gt_traces.append(bt)
        gt_traces.extend(build_gt_traces("scene"))

        if gt_traces:
            _render_figure(gt_traces, height=600, legend_size=11)
        else:
            st.info("No ground truth data available for this patient.")

    with col_right:
        # Mojibake repaired: header emoji rendered as "πŸ”΄" in the dump.
        st.markdown("### 🔴 AI Prediction")
        pred_traces = []
        if show_brain:
            bt = build_brain_trace("scene")
            if bt:
                pred_traces.append(bt)
        pred_traces.extend(build_pred_traces("scene"))

        if pred_traces:
            _render_figure(pred_traces, height=600, legend_size=11)
        else:
            st.warning("No prediction data available.")

else:
    # Single 3D view: brain shell (optional) plus whichever overlay matches
    # the selected view mode.
    all_traces = []
    if show_brain:
        bt = build_brain_trace("scene")
        if bt:
            all_traces.append(bt)

    if view_mode == "Prediction Only":
        all_traces.extend(build_pred_traces("scene"))
    elif view_mode == "Ground Truth Only":
        all_traces.extend(build_gt_traces("scene"))

    if not all_traces:
        st.warning("Nothing to render. Check that data exists and regions are selected.")
        st.stop()

    _render_figure(all_traces, height=750, legend_size=13, bordered=True)
282
+
283
# ─── volume stats ────────────────────────────────────────────────────────
st.markdown("---")
# Mojibake repaired: subheader emoji rendered as "πŸ“Š" in the dump.
st.subheader("📊 Tumor Volume Statistics")

if pred_data is not None:
    cols = st.columns(3)
    for idx, (region, ch) in enumerate(PRED_CHANNELS.items()):
        vol = pred_data[:, :, :, ch]
        voxel_count = int(vol.sum())
        # Assumes 1 mm isotropic voxels (1000 voxels = 1 cm³) — confirm
        # against the preprocessing spacing.
        volume_cc = voxel_count / 1000.0
        color = PRED_COLORS[region]
        with cols[idx]:
            # "cmΒ³" mojibake repaired to "cm³" in the card markup below.
            st.markdown(
                f'<div style="background:rgba(30,30,50,0.8);padding:14px;'
                f'border-radius:10px;border-left:4px solid {color};">'
                f'<div style="color:{color};font-size:13px;font-weight:600;">{region}</div>'
                f'<div style="color:white;font-size:26px;font-weight:700;">{volume_cc:.1f} cm³</div>'
                f'<div style="color:#888;font-size:11px;">{voxel_count:,} voxels</div></div>',
                unsafe_allow_html=True,
            )
303
+
304
# ─── dice ────────────────────────────────────────────────────────────────
if lbl_data is not None and pred_data is not None:
    st.markdown("---")
    # Mojibake repaired: subheader emoji rendered as "πŸ”¬" in the dump.
    st.subheader("🔬 Dice Scores")
    cols = st.columns(3)
    gt_masks_dice = {
        "Whole Tumor (WT)": (lbl_data > 0).astype(float),
        "Tumor Core (TC)": ((lbl_data == 1) | (lbl_data == 4)).astype(float),
        "Enhancing Tumor (ET)": (lbl_data == 4).astype(float),
    }
    for idx, (region, ch) in enumerate(PRED_CHANNELS.items()):
        p = pred_data[:, :, :, ch]
        g = gt_masks_dice[region]
        # Dice = 2|P∩G| / (|P| + |G|); epsilon guards against empty masks.
        dice = (2.0 * (p * g).sum()) / (p.sum() + g.sum() + 1e-8)
        color = PRED_COLORS[region]
        grade = "Excellent" if dice > 0.8 else "Good" if dice > 0.6 else "Fair"
        with cols[idx]:
            st.markdown(
                f'<div style="background:rgba(30,30,50,0.8);padding:14px;'
                f'border-radius:10px;border-left:4px solid {color};">'
                f'<div style="color:{color};font-size:13px;">{region}</div>'
                f'<div style="color:white;font-size:30px;font-weight:700;">{dice:.4f}</div>'
                f'<div style="color:#888;">{grade}</div></div>',
                unsafe_allow_html=True,
            )
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.12.0
2
+ torchvision>=0.13.0
3
+ numpy>=1.21.0
4
+ Pillow>=9.0.0
5
+ plotly>=5.10.0
6
+ streamlit>=1.20.0
7
+ matplotlib>=3.5.0
8
+ nibabel>=4.0.0
9
+ monai>=1.0.0
results/Setup_1/models/fedavg_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d48f93f0f3ed5e4074a24f6bc645bd55d56068467573902be606e94a75363f0
3
+ size 631956
results/Setup_1/models/fedprox_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7b6c27da43d10bbee7275de1d34eb417f3c7867694e8aabdf9ee3deebebbc52
3
+ size 631987
results/Setup_1/models/qpso_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bcd1b0bcaf795d4519a8457dba2e778a25adbda949adf337e35bd6b185f38ef
3
+ size 631894
results/Setup_2/models/fedavg_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d48f93f0f3ed5e4074a24f6bc645bd55d56068467573902be606e94a75363f0
3
+ size 631956
results/Setup_2/models/fedprox_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7b6c27da43d10bbee7275de1d34eb417f3c7867694e8aabdf9ee3deebebbc52
3
+ size 631987
results/Setup_2/models/qpso_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bcd1b0bcaf795d4519a8457dba2e778a25adbda949adf337e35bd6b185f38ef
3
+ size 631894
src/streamlit_app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st

"""
# Welcome to Streamlit!

Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
forums](https://discuss.streamlit.io).

In the meantime, below is an example of what you can do with just a few lines of code:
"""

# Interactive controls for the spiral demo.
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)

# Parameterise the spiral: radius grows linearly with t while the angle
# winds num_turns full revolutions.
t = np.linspace(0, 1, num_points)
angle = 2 * np.pi * num_turns * t

df = pd.DataFrame({
    "x": t * np.cos(angle),
    "y": t * np.sin(angle),
    "idx": t,
    "rand": np.random.randn(num_points),
})

# Point size is randomised per point; colour encodes position along the spiral.
chart = (
    alt.Chart(df, height=700, width=700)
    .mark_point(filled=True)
    .encode(
        x=alt.X("x", axis=None),
        y=alt.Y("y", axis=None),
        color=alt.Color("idx", legend=None, scale=alt.Scale()),
        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
    )
)
st.altair_chart(chart)
utils/inference.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shared inference module for 3D brain tumor segmentation.
3
+ Loads the AttentionUnet model and runs sliding_window_inference
4
+ on patients that don't have pre-computed predictions.
5
+ """
6
+
7
+ import os
8
+ import shutil
9
+ import numpy as np
10
+ import nibabel as nib
11
+ import streamlit as st
12
+ import torch
13
+ from monai.inferers import sliding_window_inference
14
+ from monai.networks.nets import AttentionUnet
15
+ from monai.transforms import (
16
+ Compose,
17
+ LoadImaged,
18
+ NormalizeIntensityd,
19
+ Orientationd,
20
+ Spacingd,
21
+ EnsureChannelFirstd,
22
+ EnsureTyped,
23
+ )
24
+
25
+
26
+ # ─── paths ───────────────────────────────────────────────────────────────
27
+ SEG_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "segmentation"))
28
+ DEMO_DIR = os.path.join(SEG_DIR, "demo_data")
29
+ # streamlit_app/ is inside segmentation/, so go up one level to reach segmentation/
30
+ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
31
+
32
+ # Model checkpoint β€” prefer refined model (better calibration)
33
+ # 1. Check streamlit_app/ (where refined model lives)
34
+ # 2. Check segmentation/ (parent dir, where base model lives)
35
+ _THIS_DIR = os.path.dirname(os.path.dirname(__file__))
36
+ _candidates = [
37
+ os.path.join(_THIS_DIR, "best_metric_model_refined.pth"), # streamlit_app/
38
+ os.path.join(PROJECT_ROOT, "best_metric_model_refined.pth"), # segmentation/
39
+ os.path.join(_THIS_DIR, "best_metric_model.pth"), # streamlit_app/
40
+ os.path.join(PROJECT_ROOT, "best_metric_model.pth"), # segmentation/
41
+ ]
42
+ CKPT_PATH = None
43
+ for _c in _candidates:
44
+ if os.path.exists(_c):
45
+ CKPT_PATH = _c
46
+ break
47
+
48
+ # MONAI transforms β€” must match training exactly
49
+ INFERENCE_TRANSFORMS = Compose([
50
+ LoadImaged(keys=["image", "label"]),
51
+ EnsureChannelFirstd(keys=["image", "label"]),
52
+ EnsureTyped(keys=["image", "label"]),
53
+ Orientationd(keys=["image", "label"], axcodes="RAS"),
54
+ Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
55
+ NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
56
+ ])
57
+
58
+ # Transforms for image-only (no label available)
59
+ INFERENCE_TRANSFORMS_IMG_ONLY = Compose([
60
+ LoadImaged(keys=["image"]),
61
+ EnsureChannelFirstd(keys=["image"]),
62
+ EnsureTyped(keys=["image"]),
63
+ Orientationd(keys=["image"], axcodes="RAS"),
64
+ Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear",)),
65
+ NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
66
+ ])
67
+
68
+
69
@st.cache_resource
def load_seg_model():
    """Load the 3D Attention U-Net model (cached across sessions).

    Returns:
        (model, device) on success, or (None, None) when the checkpoint is
        missing or cannot be loaded (an error is surfaced in the UI).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = AttentionUnet(
        spatial_dims=3,
        in_channels=4,   # one channel per MRI modality
        out_channels=3,  # one channel per tumor region
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
    ).to(device)

    # Bug fix: CKPT_PATH is None when no candidate checkpoint was found at
    # import time, and the original os.path.exists(CKPT_PATH) call raised
    # TypeError instead of reporting a missing checkpoint.
    if CKPT_PATH is None or not os.path.exists(CKPT_PATH):
        st.error(f"Model checkpoint not found at {CKPT_PATH}")
        return None, None
    try:
        model.load_state_dict(torch.load(CKPT_PATH, map_location=device))
    except Exception as e:
        st.error(f"Failed to load model weights: {e}")
        return None, None
    model.eval()
    return model, device
92
+
93
+
94
def ensure_prediction(patient_id):
    """
    Ensure that the prediction volume exists for a patient.

    If ``<patient>_pred.nii.gz`` already exists, returns True immediately.
    Otherwise runs live inference using the exact same MONAI transforms as
    the training pipeline, saving the processed image, prediction, and
    label volumes into DEMO_DIR. Returns False when the raw modalities are
    missing or inference fails.
    """
    pred_path = os.path.join(DEMO_DIR, f"{patient_id}_pred.nii.gz")
    img_path = os.path.join(DEMO_DIR, f"{patient_id}_image.nii.gz")
    lbl_path = os.path.join(DEMO_DIR, f"{patient_id}_label.nii.gz")

    # Already have a cached prediction — skip.
    if os.path.exists(pred_path) and os.path.exists(img_path):
        return True

    # Check whether raw MRI modalities exist in a per-patient subfolder.
    p_dir = os.path.join(DEMO_DIR, patient_id)
    if not os.path.isdir(p_dir):
        return False

    # Build file paths (same order as extract_demo_data.py: t1, t1ce, t2, flair).
    mod_paths = {
        "t1": os.path.join(p_dir, f"{patient_id}_t1.nii.gz"),
        "t1ce": os.path.join(p_dir, f"{patient_id}_t1ce.nii.gz"),
        "t2": os.path.join(p_dir, f"{patient_id}_t2.nii.gz"),
        "flair": os.path.join(p_dir, f"{patient_id}_flair.nii.gz"),
    }
    seg_path = os.path.join(p_dir, f"{patient_id}_seg.nii.gz")

    for m, mp in mod_paths.items():
        if not os.path.exists(mp):
            st.warning(f"Missing modality: {m} at {mp}")
            return False

    # ─── Run live inference ──────────────────────────────────────────
    st.info(f"🧠 **Running AI Inference** on `{patient_id}`... This may take 30-60 seconds.")
    progress = st.progress(0)
    status = st.empty()

    try:
        # Build the MONAI data dict (image is a list of 4 modality paths,
        # which LoadImaged stacks into a 4-channel volume).
        has_label = os.path.exists(seg_path)
        data_dict = {
            "image": [mod_paths["t1"], mod_paths["t1ce"], mod_paths["t2"], mod_paths["flair"]],
        }
        if has_label:
            data_dict["label"] = seg_path

        # Apply MONAI transforms (orientation, spacing, normalisation —
        # matching training).
        status.text("Loading & preprocessing with MONAI transforms...")
        if has_label:
            sample_data = INFERENCE_TRANSFORMS(data_dict)
        else:
            sample_data = INFERENCE_TRANSFORMS_IMG_ONLY(data_dict)
        progress.progress(30)

        # Run model inference.
        status.text("Running 3D U-Net inference (sliding window)...")
        model, device = load_seg_model()
        if model is None:
            return False

        inputs = sample_data["image"].unsqueeze(0).to(device)  # (1, 4, D, H, W)
        with torch.no_grad():
            outputs = sliding_window_inference(inputs, (96, 96, 96), 4, model)
            # Sigmoid + 0.5 threshold -> binary mask per region channel.
            outputs = (outputs.sigmoid() > 0.5).float()
        progress.progress(80)

        # Save the processed image volume (D, H, W, 4).
        status.text("Saving results...")
        img_np = inputs[0].cpu().numpy().transpose(1, 2, 3, 0)
        nib.save(nib.Nifti1Image(img_np, affine=np.eye(4)), img_path)

        # Save the prediction (D, H, W, 3).
        pred_np = outputs[0].cpu().numpy().transpose(1, 2, 3, 0)
        nib.save(nib.Nifti1Image(pred_np, affine=np.eye(4)), pred_path)

        # Save the ground-truth label (D, H, W); write an all-zero volume
        # when the patient has no segmentation so downstream loaders succeed.
        if has_label:
            lbl_np = sample_data["label"][0].cpu().numpy()
            nib.save(nib.Nifti1Image(lbl_np.astype(np.float32), affine=np.eye(4)), lbl_path)
        elif not os.path.exists(lbl_path):
            empty = np.zeros(pred_np.shape[:3])
            nib.save(nib.Nifti1Image(empty.astype(np.float32), affine=np.eye(4)), lbl_path)

        progress.progress(100)
        # Mojibake repaired: the check mark rendered as "βœ…" in the dump.
        status.text("✅ Inference complete!")
        return True

    except Exception as e:
        st.error(f"Inference failed: {e}")
        import traceback
        st.code(traceback.format_exc())
        return False
188
+
189
+
190
def get_all_patients():
    """
    Return sorted patient IDs that either have a pre-computed prediction
    or raw MRI data (so a prediction can be produced on demand).
    """
    import glob

    found = set()

    # IDs with a cached prediction volume in DEMO_DIR.
    suffix = "_pred.nii.gz"
    for path in glob.glob(os.path.join(DEMO_DIR, f"*{suffix}")):
        found.add(os.path.basename(path).replace(suffix, ""))

    # IDs with a raw-modality subfolder; require at least the FLAIR file.
    if os.path.isdir(DEMO_DIR):
        for name in os.listdir(DEMO_DIR):
            sub = os.path.join(DEMO_DIR, name)
            if (
                name.startswith("BraTS")
                and os.path.isdir(sub)
                and os.path.exists(os.path.join(sub, f"{name}_flair.nii.gz"))
            ):
                found.add(name)

    return sorted(found)