Spaces:
Sleeping
Sleeping
DIVYANSH-TEJA-09 commited on
Commit Β·
06e1b21
0
Parent(s):
Initial commit with only essential weights
Browse files- .gitattributes +35 -0
- Dockerfile +20 -0
- README.md +19 -0
- app.py +36 -0
- best_metric_model.pth +3 -0
- pages/1_Classification.py +324 -0
- pages/2_Slice_Viewer.py +198 -0
- pages/3_3D_Visualization.py +328 -0
- requirements.txt +9 -0
- results/Setup_1/models/fedavg_best.pth +3 -0
- results/Setup_1/models/fedprox_best.pth +3 -0
- results/Setup_1/models/qpso_best.pth +3 -0
- results/Setup_2/models/fedavg_best.pth +3 -0
- results/Setup_2/models/fedprox_best.pth +3 -0
- results/Setup_2/models/qpso_best.pth +3 -0
- src/streamlit_app.py +40 -0
- utils/inference.py +212 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.13.5-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y \
|
| 6 |
+
build-essential \
|
| 7 |
+
curl \
|
| 8 |
+
git \
|
| 9 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
COPY requirements.txt ./
|
| 12 |
+
COPY src/ ./src/
|
| 13 |
+
|
| 14 |
+
RUN pip3 install -r requirements.txt
|
| 15 |
+
|
| 16 |
+
EXPOSE 8501
|
| 17 |
+
|
| 18 |
+
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
+
|
| 20 |
+
ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
README.md
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Brain Tumor AI Suite
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: red
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 8501
|
| 8 |
+
tags:
|
| 9 |
+
- streamlit
|
| 10 |
+
pinned: false
|
| 11 |
+
short_description: Streamlit template space
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# Welcome to Streamlit!
|
| 15 |
+
|
| 16 |
+
Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
|
| 17 |
+
|
| 18 |
+
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 19 |
+
forums](https://discuss.streamlit.io).
|
app.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import streamlit as st
|
| 3 |
+
|
| 4 |
+
st.set_page_config(
|
| 5 |
+
page_title="Brain Tumor AI Suite",
|
| 6 |
+
page_icon="π§ ",
|
| 7 |
+
layout="wide",
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
st.markdown('''
|
| 11 |
+
<style>
|
| 12 |
+
.hero-title {
|
| 13 |
+
font-size: 3rem; font-weight: 800;
|
| 14 |
+
background: linear-gradient(135deg, #667eea, #764ba2, #f093fb);
|
| 15 |
+
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
|
| 16 |
+
margin-bottom: 0;
|
| 17 |
+
}
|
| 18 |
+
</style>
|
| 19 |
+
''', unsafe_allow_html=True)
|
| 20 |
+
|
| 21 |
+
st.markdown('<p class="hero-title">π§ Brain Tumor AI Suite</p>', unsafe_allow_html=True)
|
| 22 |
+
st.markdown("**Federated Learning Classification & 3D Segmentation**")
|
| 23 |
+
st.markdown("---")
|
| 24 |
+
|
| 25 |
+
col1, col2 = st.columns(2)
|
| 26 |
+
|
| 27 |
+
with col1:
|
| 28 |
+
st.markdown("### π Tumor Classification (Federated Learning)")
|
| 29 |
+
st.markdown("Predict the class of a brain tumor (Glioma, Meningioma, Pituitary, No Tumor) using a SimpleCNN model trained across simulated hospitals with Layer-by-Layer QPSO aggregation.")
|
| 30 |
+
st.page_link("pages/1_Classification.py", label="Open Classification App β", icon="π")
|
| 31 |
+
|
| 32 |
+
with col2:
|
| 33 |
+
st.markdown("### π¬ 3D Tumor Segmentation")
|
| 34 |
+
st.markdown("View MRI slices and 3D volumetric renderings of brain tumors with segmentation overlays (Whole Tumor, Tumor Core, Enhancing Tumor) predicted by a 3D Attention U-Net.")
|
| 35 |
+
st.page_link("pages/2_Slice_Viewer.py", label="Open Slice Viewer β", icon="π¬")
|
| 36 |
+
st.page_link("pages/3_3D_Visualization.py", label="Open 3D Viewer β", icon="π")
|
best_metric_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4980fada1cc6fed5b19cd657528845f2a8598b383dcbbba179c1654b6f592c02
|
| 3 |
+
size 23731355
|
pages/1_Classification.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
π§ Brain Tumor Classification β Federated Learning Demo
|
| 3 |
+
=========================================================
|
| 4 |
+
Demonstrates the FL-trained SimpleCNN models (FedAvg, FedProx, QPSO).
|
| 5 |
+
Users can upload images or use sample test images from the dataset.
|
| 6 |
+
Shows predicted class with confidence bars and compares all 3 models.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import streamlit as st
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
import glob
|
| 13 |
+
import random
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torchvision.transforms as transforms
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import plotly.graph_objects as go
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# βββ paths βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
+
FL_ROOT = os.path.abspath(os.path.dirname(__file__))
|
| 25 |
+
RESULTS = os.path.join(FL_ROOT, "..", "results")
|
| 26 |
+
# Use Setup 1 models by default (best performing)
|
| 27 |
+
MODELS_DIR = os.path.join(RESULTS, "Setup_1", "models")
|
| 28 |
+
|
| 29 |
+
NUM_CLASSES = 4
|
| 30 |
+
CLASS_NAMES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
|
| 31 |
+
CLASS_COLORS = ["#E74C3C", "#3498DB", "#2ECC71", "#9B59B6"]
|
| 32 |
+
CLASS_ICONS = ["π΄", "π΅", "π’", "π£"]
|
| 33 |
+
IMG_SIZE = 112
|
| 34 |
+
|
| 35 |
+
# βββ model βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
class SimpleCNN(nn.Module):
|
| 37 |
+
def __init__(self, num_classes=4):
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.features = nn.Sequential(
|
| 40 |
+
nn.Conv2d(3, 16, kernel_size=3, padding=1),
|
| 41 |
+
nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
|
| 42 |
+
nn.Conv2d(16, 32, kernel_size=3, padding=1),
|
| 43 |
+
nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
|
| 44 |
+
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
| 45 |
+
nn.BatchNorm2d(64), nn.ReLU(),
|
| 46 |
+
nn.AdaptiveAvgPool2d(4),
|
| 47 |
+
)
|
| 48 |
+
self.classifier = nn.Sequential(
|
| 49 |
+
nn.Dropout(0.3),
|
| 50 |
+
nn.Linear(64 * 4 * 4, 128), nn.ReLU(),
|
| 51 |
+
nn.Dropout(0.3),
|
| 52 |
+
nn.Linear(128, num_classes),
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
x = self.features(x)
|
| 57 |
+
x = x.view(x.size(0), -1)
|
| 58 |
+
return self.classifier(x)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
TRANSFORM = transforms.Compose([
|
| 62 |
+
transforms.Resize((IMG_SIZE, IMG_SIZE)),
|
| 63 |
+
transforms.ToTensor(),
|
| 64 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
| 65 |
+
])
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@st.cache_resource
|
| 69 |
+
def load_model(path):
|
| 70 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 71 |
+
model = SimpleCNN(NUM_CLASSES).to(device)
|
| 72 |
+
state = torch.load(path, map_location=device, weights_only=True)
|
| 73 |
+
model.load_state_dict(state)
|
| 74 |
+
model.eval()
|
| 75 |
+
return model, device
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def predict(model, device, image):
|
| 79 |
+
"""Run inference on a PIL Image. Returns (class_idx, probabilities)."""
|
| 80 |
+
tensor = TRANSFORM(image).unsqueeze(0).to(device)
|
| 81 |
+
with torch.no_grad():
|
| 82 |
+
logits = model(tensor)
|
| 83 |
+
probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
|
| 84 |
+
return int(np.argmax(probs)), probs
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def render_prediction_card(title, color_accent, pred_idx, probs, image):
|
| 88 |
+
"""Render a styled prediction result."""
|
| 89 |
+
confidence = probs[pred_idx] * 100
|
| 90 |
+
st.markdown(
|
| 91 |
+
f'<div style="background:rgba(20,20,35,0.9);padding:20px;border-radius:12px;'
|
| 92 |
+
f'border-top:4px solid {color_accent};margin-bottom:16px;">'
|
| 93 |
+
f'<h3 style="color:{color_accent};margin:0 0 8px 0;">{title}</h3>'
|
| 94 |
+
f'<div style="color:white;font-size:28px;font-weight:800;margin:4px 0;">'
|
| 95 |
+
f'{CLASS_ICONS[pred_idx]} {CLASS_NAMES[pred_idx]}</div>'
|
| 96 |
+
f'<div style="color:#aaa;font-size:14px;">Confidence: {confidence:.1f}%</div>'
|
| 97 |
+
f'</div>',
|
| 98 |
+
unsafe_allow_html=True,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Probability bar chart
|
| 102 |
+
fig = go.Figure(go.Bar(
|
| 103 |
+
x=probs * 100,
|
| 104 |
+
y=CLASS_NAMES,
|
| 105 |
+
orientation='h',
|
| 106 |
+
marker_color=CLASS_COLORS,
|
| 107 |
+
text=[f"{p*100:.1f}%" for p in probs],
|
| 108 |
+
textposition='auto',
|
| 109 |
+
))
|
| 110 |
+
fig.update_layout(
|
| 111 |
+
height=200,
|
| 112 |
+
margin=dict(l=0, r=10, t=10, b=10),
|
| 113 |
+
paper_bgcolor="rgba(0,0,0,0)",
|
| 114 |
+
plot_bgcolor="rgba(20,20,35,0.5)",
|
| 115 |
+
xaxis=dict(range=[0, 100], title="Probability (%)", color="#aaa",
|
| 116 |
+
gridcolor="rgba(100,100,140,0.2)"),
|
| 117 |
+
yaxis=dict(color="white"),
|
| 118 |
+
font=dict(color="white"),
|
| 119 |
+
)
|
| 120 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# βββ CSS βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 124 |
+
st.markdown("""
|
| 125 |
+
<style>
|
| 126 |
+
[data-testid="stSidebar"] {
|
| 127 |
+
background: linear-gradient(180deg, #0f0f1a 0%, #1a1a2e 100%);
|
| 128 |
+
}
|
| 129 |
+
.hero-title {
|
| 130 |
+
font-size: 2.4rem; font-weight: 800;
|
| 131 |
+
background: linear-gradient(135deg, #667eea, #764ba2, #f093fb);
|
| 132 |
+
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
|
| 133 |
+
}
|
| 134 |
+
</style>
|
| 135 |
+
""", unsafe_allow_html=True)
|
| 136 |
+
|
| 137 |
+
# βββ sidebar βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 138 |
+
st.sidebar.title("βοΈ Classification Controls")
|
| 139 |
+
|
| 140 |
+
# Setup selector
|
| 141 |
+
setup = st.sidebar.radio("Experiment Setup", ["Setup 1 (Natural)", "Setup 2 (Label Skew)"])
|
| 142 |
+
setup_dir = "Setup_1" if "1" in setup else "Setup_2"
|
| 143 |
+
models_dir = os.path.join(RESULTS, setup_dir, "models")
|
| 144 |
+
|
| 145 |
+
# Check available models
|
| 146 |
+
model_files = {}
|
| 147 |
+
for name, fname in [("FedAvg", "fedavg_best.pth"),
|
| 148 |
+
("FedProx", "fedprox_best.pth"),
|
| 149 |
+
("QPSO-FL", "qpso_best.pth")]:
|
| 150 |
+
path = os.path.join(models_dir, fname)
|
| 151 |
+
if os.path.exists(path):
|
| 152 |
+
model_files[name] = path
|
| 153 |
+
|
| 154 |
+
if not model_files:
|
| 155 |
+
st.error(f"No model weights found in `{models_dir}`. Please ensure .pth files are present.")
|
| 156 |
+
st.stop()
|
| 157 |
+
|
| 158 |
+
selected_models = st.sidebar.multiselect(
|
| 159 |
+
"Compare Models",
|
| 160 |
+
list(model_files.keys()),
|
| 161 |
+
default=list(model_files.keys()),
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
st.sidebar.markdown("---")
|
| 165 |
+
input_mode = st.sidebar.radio("Image Source", ["Upload Image", "Sample from Dataset"])
|
| 166 |
+
|
| 167 |
+
# βββ title βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 168 |
+
st.markdown('<p class="hero-title">π§ Brain Tumor Classification</p>',
|
| 169 |
+
unsafe_allow_html=True)
|
| 170 |
+
st.markdown("**Federated Learning Β· SimpleCNN (~120K params) Β· "
|
| 171 |
+
"FedAvg vs FedProx vs QPSO-FL**")
|
| 172 |
+
st.markdown("---")
|
| 173 |
+
|
| 174 |
+
# βββ image input βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 175 |
+
image = None
|
| 176 |
+
|
| 177 |
+
if input_mode == "Upload Image":
|
| 178 |
+
uploaded = st.file_uploader(
|
| 179 |
+
"Upload a brain MRI image (JPG/PNG)",
|
| 180 |
+
type=["jpg", "jpeg", "png"],
|
| 181 |
+
)
|
| 182 |
+
if uploaded:
|
| 183 |
+
image = Image.open(uploaded).convert("RGB")
|
| 184 |
+
|
| 185 |
+
elif input_mode == "Sample from Dataset":
|
| 186 |
+
# Try to find sample images from the results plots (confusion matrices show sample data)
|
| 187 |
+
# Or look for actual dataset images
|
| 188 |
+
sample_dirs = [
|
| 189 |
+
os.path.join(FL_ROOT, "data"),
|
| 190 |
+
os.path.join(FL_ROOT, "sample_images"),
|
| 191 |
+
]
|
| 192 |
+
|
| 193 |
+
# Search for any sample images
|
| 194 |
+
sample_images = []
|
| 195 |
+
for d in sample_dirs:
|
| 196 |
+
if os.path.isdir(d):
|
| 197 |
+
for ext in ["*.jpg", "*.jpeg", "*.png"]:
|
| 198 |
+
sample_images.extend(glob.glob(os.path.join(d, "**", ext), recursive=True))
|
| 199 |
+
|
| 200 |
+
if sample_images:
|
| 201 |
+
# Group by class if possible
|
| 202 |
+
selected = st.selectbox("Select a sample image", sample_images,
|
| 203 |
+
format_func=lambda x: os.path.basename(x))
|
| 204 |
+
image = Image.open(selected).convert("RGB")
|
| 205 |
+
else:
|
| 206 |
+
st.info(
|
| 207 |
+
"π‘ No sample images found locally. You can either:\n"
|
| 208 |
+
"1. **Upload an image** using the sidebar option\n"
|
| 209 |
+
"2. **Add sample images** to `federated_learning/sample_images/` "
|
| 210 |
+
"(one subfolder per class: glioma/, meningioma/, notumor/, pituitary/)"
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# βββ inference βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 214 |
+
if image is not None:
|
| 215 |
+
# Show the input image
|
| 216 |
+
st.subheader("π· Input Image")
|
| 217 |
+
col_img, col_info = st.columns([1, 2])
|
| 218 |
+
with col_img:
|
| 219 |
+
st.image(image, caption="Input MRI", use_container_width=True)
|
| 220 |
+
with col_info:
|
| 221 |
+
w, h = image.size
|
| 222 |
+
st.markdown(f"**Resolution:** {w}Γ{h} β resized to {IMG_SIZE}Γ{IMG_SIZE}")
|
| 223 |
+
st.markdown(f"**Models:** {', '.join(selected_models)}")
|
| 224 |
+
st.markdown(f"**Setup:** {setup}")
|
| 225 |
+
|
| 226 |
+
st.markdown("---")
|
| 227 |
+
st.subheader("π¬ Classification Results")
|
| 228 |
+
|
| 229 |
+
if not selected_models:
|
| 230 |
+
st.warning("Select at least one model from the sidebar.")
|
| 231 |
+
else:
|
| 232 |
+
# Run all selected models
|
| 233 |
+
cols = st.columns(len(selected_models))
|
| 234 |
+
model_colors = {"FedAvg": "#1f77b4", "FedProx": "#ff7f0e", "QPSO-FL": "#2ca02c"}
|
| 235 |
+
|
| 236 |
+
results = {}
|
| 237 |
+
for idx, name in enumerate(selected_models):
|
| 238 |
+
with cols[idx]:
|
| 239 |
+
model, device = load_model(model_files[name])
|
| 240 |
+
pred_idx, probs = predict(model, device, image)
|
| 241 |
+
results[name] = (pred_idx, probs)
|
| 242 |
+
render_prediction_card(
|
| 243 |
+
name,
|
| 244 |
+
model_colors.get(name, "#666"),
|
| 245 |
+
pred_idx, probs, image,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Consensus section
|
| 249 |
+
if len(results) > 1:
|
| 250 |
+
st.markdown("---")
|
| 251 |
+
st.subheader("π€ Model Consensus")
|
| 252 |
+
|
| 253 |
+
predictions = [CLASS_NAMES[r[0]] for r in results.values()]
|
| 254 |
+
unanimous = len(set(predictions)) == 1
|
| 255 |
+
|
| 256 |
+
if unanimous:
|
| 257 |
+
st.success(
|
| 258 |
+
f"β
**All {len(results)} models agree:** "
|
| 259 |
+
f"{CLASS_ICONS[list(results.values())[0][0]]} "
|
| 260 |
+
f"**{predictions[0]}**"
|
| 261 |
+
)
|
| 262 |
+
else:
|
| 263 |
+
# Majority vote
|
| 264 |
+
from collections import Counter
|
| 265 |
+
votes = Counter(predictions)
|
| 266 |
+
winner, count = votes.most_common(1)[0]
|
| 267 |
+
winner_idx = CLASS_NAMES.index(winner)
|
| 268 |
+
st.warning(
|
| 269 |
+
f"β οΈ **Models disagree.** Majority vote ({count}/{len(results)}): "
|
| 270 |
+
f"{CLASS_ICONS[winner_idx]} **{winner}**"
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
# Show disagreement details
|
| 274 |
+
for name, (pred_idx, probs) in results.items():
|
| 275 |
+
emoji = "β
" if CLASS_NAMES[pred_idx] == winner else "β"
|
| 276 |
+
st.markdown(
|
| 277 |
+
f" {emoji} **{name}:** {CLASS_NAMES[pred_idx]} "
|
| 278 |
+
f"({probs[pred_idx]*100:.1f}%)"
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# Average confidence across models
|
| 282 |
+
if len(results) > 1:
|
| 283 |
+
avg_probs = np.mean([r[1] for r in results.values()], axis=0)
|
| 284 |
+
ensemble_pred = int(np.argmax(avg_probs))
|
| 285 |
+
|
| 286 |
+
st.markdown("---")
|
| 287 |
+
st.subheader("π Ensemble Average (All Models)")
|
| 288 |
+
render_prediction_card(
|
| 289 |
+
"Ensemble Average",
|
| 290 |
+
"#E91E63",
|
| 291 |
+
ensemble_pred, avg_probs, image,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
else:
|
| 295 |
+
# Welcome state
|
| 296 |
+
st.markdown("""
|
| 297 |
+
### How it works
|
| 298 |
+
1. **Choose a setup** β Natural heterogeneity (Setup 1) or Label Skew (Setup 2)
|
| 299 |
+
2. **Select models** β Compare FedAvg, FedProx, and QPSO-FL side by side
|
| 300 |
+
3. **Upload or select an image** β Any brain MRI (axial slice)
|
| 301 |
+
4. **See results** β Class prediction with confidence bars for each model
|
| 302 |
+
|
| 303 |
+
The models were trained using **federated learning** across 3 simulated hospitals,
|
| 304 |
+
each with different data distributions. The QPSO-FL model uses our novel
|
| 305 |
+
**Layer-by-Layer QPSO aggregation** for fairer global model performance.
|
| 306 |
+
""")
|
| 307 |
+
|
| 308 |
+
# Show model info cards
|
| 309 |
+
st.markdown("---")
|
| 310 |
+
info_cols = st.columns(3)
|
| 311 |
+
model_info = [
|
| 312 |
+
("FedAvg", "#1f77b4", "Weighted average of client updates. Standard baseline."),
|
| 313 |
+
("FedProx", "#ff7f0e", "Adds proximal regularization (ΞΌ=0.01) to prevent client drift."),
|
| 314 |
+
("QPSO-FL", "#2ca02c", "Layer-by-layer quantum PSO with validation-loss fitness. Our contribution."),
|
| 315 |
+
]
|
| 316 |
+
for col, (name, color, desc) in zip(info_cols, model_info):
|
| 317 |
+
with col:
|
| 318 |
+
st.markdown(
|
| 319 |
+
f'<div style="background:rgba(20,20,35,0.9);padding:20px;border-radius:12px;'
|
| 320 |
+
f'border-top:3px solid {color};">'
|
| 321 |
+
f'<h4 style="color:{color};margin:0 0 8px 0;">{name}</h4>'
|
| 322 |
+
f'<p style="color:#aaa;font-size:13px;margin:0;">{desc}</p></div>',
|
| 323 |
+
unsafe_allow_html=True,
|
| 324 |
+
)
|
pages/2_Slice_Viewer.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Slice-by-Slice Segmentation Viewer
|
| 3 |
+
====================================
|
| 4 |
+
View MRI slices with ground truth and AI prediction overlay.
|
| 5 |
+
Supports all 4 modalities and 3 tumor sub-regions.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import streamlit as st
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import glob
|
| 12 |
+
import nibabel as nib
|
| 13 |
+
import numpy as np
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import matplotlib.colors as mcolors
|
| 16 |
+
|
| 17 |
+
# st.set_page_config(page_title="Slice Viewer", layout="wide")
|
| 18 |
+
|
| 19 |
+
# βββ paths & inference βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
APP_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 21 |
+
if APP_DIR not in sys.path:
|
| 22 |
+
sys.path.insert(0, APP_DIR)
|
| 23 |
+
from utils.inference import ensure_prediction, get_all_patients, DEMO_DIR
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_nifti(path):
|
| 27 |
+
if not os.path.exists(path):
|
| 28 |
+
return None
|
| 29 |
+
return nib.load(path).get_fdata()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# βββ sidebar βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 33 |
+
st.sidebar.title("βοΈ Slice Viewer Controls")
|
| 34 |
+
|
| 35 |
+
samples = get_all_patients()
|
| 36 |
+
if not samples:
|
| 37 |
+
st.error("No demo data found. Please ensure patient volumes exist in demo_data/")
|
| 38 |
+
st.stop()
|
| 39 |
+
|
| 40 |
+
selected_id = st.sidebar.selectbox("π§ββοΈ Patient", samples)
|
| 41 |
+
|
| 42 |
+
modality = st.sidebar.selectbox(
|
| 43 |
+
"MRI Modality",
|
| 44 |
+
["FLAIR", "T1", "T1ce", "T2"],
|
| 45 |
+
index=0,
|
| 46 |
+
)
|
| 47 |
+
MOD_MAP = {"T1": 0, "T1ce": 1, "T2": 2, "FLAIR": 3}
|
| 48 |
+
|
| 49 |
+
overlay = st.sidebar.radio(
|
| 50 |
+
"Overlay",
|
| 51 |
+
["AI Prediction", "Ground Truth", "Both (Side-by-Side)", "None"],
|
| 52 |
+
index=0,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
overlay_alpha = st.sidebar.slider("Overlay Opacity", 0.1, 0.9, 0.5, 0.05)
|
| 56 |
+
|
| 57 |
+
# βββ load data (run inference if needed) βββββββββββββββββββββββββββββββββ
|
| 58 |
+
ensure_prediction(selected_id)
|
| 59 |
+
|
| 60 |
+
img_path = os.path.join(DEMO_DIR, f"{selected_id}_image.nii.gz")
|
| 61 |
+
pred_path = os.path.join(DEMO_DIR, f"{selected_id}_pred.nii.gz")
|
| 62 |
+
lbl_path = os.path.join(DEMO_DIR, f"{selected_id}_label.nii.gz")
|
| 63 |
+
|
| 64 |
+
img_data = load_nifti(img_path) # (D, H, W, 4)
|
| 65 |
+
pred_data = load_nifti(pred_path) # (D, H, W, 3) channels: 0=TC, 1=WT, 2=ET
|
| 66 |
+
lbl_data = load_nifti(lbl_path) # (D, H, W) labels: 1=NCR, 2=ED, 4=ET
|
| 67 |
+
|
| 68 |
+
if img_data is None:
|
| 69 |
+
st.error("Failed to load MRI volume.")
|
| 70 |
+
st.stop()
|
| 71 |
+
|
| 72 |
+
depth = img_data.shape[0]
|
| 73 |
+
|
| 74 |
+
# βββ title βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 75 |
+
st.title("π¬ Slice-by-Slice Segmentation Viewer")
|
| 76 |
+
st.markdown(f"**Patient:** `{selected_id}` Β· **Modality:** {modality} Β· "
|
| 77 |
+
f"**Volume:** {img_data.shape[0]}Γ{img_data.shape[1]}Γ{img_data.shape[2]}")
|
| 78 |
+
|
| 79 |
+
slice_idx = st.slider("Z-Axis Slice", 0, depth - 1, depth // 2)
|
| 80 |
+
|
| 81 |
+
# βββ color maps ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 82 |
+
# Tumor overlay: WT=green, TC=red, ET=yellow (matching 3D view)
|
| 83 |
+
TUMOR_COLORS = np.array([
|
| 84 |
+
[0, 0, 0, 0], # background (transparent)
|
| 85 |
+
[0.18, 0.80, 0.44, 1], # WT - green
|
| 86 |
+
[0.91, 0.30, 0.24, 1], # TC - red
|
| 87 |
+
[0.95, 0.77, 0.06, 1], # ET - gold
|
| 88 |
+
])
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def make_overlay_from_pred(pred_slice):
|
| 92 |
+
"""Convert (H, W, 3) prediction channels to (H, W, 4) RGBA overlay."""
|
| 93 |
+
h, w = pred_slice.shape[:2]
|
| 94 |
+
overlay_img = np.zeros((h, w, 4))
|
| 95 |
+
# Order: WT first (background), then TC, then ET on top
|
| 96 |
+
wt = pred_slice[:, :, 1] > 0.5
|
| 97 |
+
tc = pred_slice[:, :, 0] > 0.5
|
| 98 |
+
et = pred_slice[:, :, 2] > 0.5
|
| 99 |
+
overlay_img[wt] = TUMOR_COLORS[1]
|
| 100 |
+
overlay_img[tc] = TUMOR_COLORS[2]
|
| 101 |
+
overlay_img[et] = TUMOR_COLORS[3]
|
| 102 |
+
return overlay_img
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def make_overlay_from_gt(lbl_slice):
|
| 106 |
+
"""Convert integer label slice to (H, W, 4) RGBA overlay."""
|
| 107 |
+
h, w = lbl_slice.shape
|
| 108 |
+
overlay_img = np.zeros((h, w, 4))
|
| 109 |
+
wt = lbl_slice > 0
|
| 110 |
+
tc = (lbl_slice == 1) | (lbl_slice == 4)
|
| 111 |
+
et = lbl_slice == 4
|
| 112 |
+
overlay_img[wt] = TUMOR_COLORS[1]
|
| 113 |
+
overlay_img[tc] = TUMOR_COLORS[2]
|
| 114 |
+
overlay_img[et] = TUMOR_COLORS[3]
|
| 115 |
+
return overlay_img
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def render_slice(mri_slice, overlay_img, title, alpha):
|
| 119 |
+
"""Render an MRI slice with optional overlay."""
|
| 120 |
+
fig, ax = plt.subplots(figsize=(6, 6))
|
| 121 |
+
ax.imshow(mri_slice, cmap="gray", origin="lower")
|
| 122 |
+
if overlay_img is not None:
|
| 123 |
+
ax.imshow(overlay_img, alpha=alpha, origin="lower")
|
| 124 |
+
ax.set_title(title, fontsize=14, color="white", fontweight="bold")
|
| 125 |
+
ax.axis("off")
|
| 126 |
+
fig.patch.set_facecolor("#0a0a14")
|
| 127 |
+
return fig
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# βββ render ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 131 |
+
mri_slice = img_data[slice_idx, :, :, MOD_MAP[modality]]
|
| 132 |
+
|
| 133 |
+
if overlay == "None":
|
| 134 |
+
fig = render_slice(mri_slice, None, f"{modality} β Slice {slice_idx}", 0)
|
| 135 |
+
st.pyplot(fig)
|
| 136 |
+
|
| 137 |
+
elif overlay == "AI Prediction":
|
| 138 |
+
if pred_data is not None:
|
| 139 |
+
ov = make_overlay_from_pred(pred_data[slice_idx])
|
| 140 |
+
fig = render_slice(mri_slice, ov, f"AI Prediction β Slice {slice_idx}",
|
| 141 |
+
overlay_alpha)
|
| 142 |
+
st.pyplot(fig)
|
| 143 |
+
else:
|
| 144 |
+
st.warning("Prediction not available for this patient.")
|
| 145 |
+
|
| 146 |
+
elif overlay == "Ground Truth":
|
| 147 |
+
if lbl_data is not None:
|
| 148 |
+
ov = make_overlay_from_gt(lbl_data[slice_idx])
|
| 149 |
+
fig = render_slice(mri_slice, ov, f"Ground Truth β Slice {slice_idx}",
|
| 150 |
+
overlay_alpha)
|
| 151 |
+
st.pyplot(fig)
|
| 152 |
+
else:
|
| 153 |
+
st.warning("Ground truth not available.")
|
| 154 |
+
|
| 155 |
+
elif overlay == "Both (Side-by-Side)":
|
| 156 |
+
col1, col2 = st.columns(2)
|
| 157 |
+
with col1:
|
| 158 |
+
if lbl_data is not None:
|
| 159 |
+
ov = make_overlay_from_gt(lbl_data[slice_idx])
|
| 160 |
+
fig = render_slice(mri_slice, ov, "Ground Truth", overlay_alpha)
|
| 161 |
+
st.pyplot(fig)
|
| 162 |
+
else:
|
| 163 |
+
st.info("No ground truth available.")
|
| 164 |
+
with col2:
|
| 165 |
+
if pred_data is not None:
|
| 166 |
+
ov = make_overlay_from_pred(pred_data[slice_idx])
|
| 167 |
+
fig = render_slice(mri_slice, ov, "AI Prediction", overlay_alpha)
|
| 168 |
+
st.pyplot(fig)
|
| 169 |
+
else:
|
| 170 |
+
st.info("No prediction available.")
|
| 171 |
+
|
| 172 |
+
# βββ all modalities row ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 173 |
+
with st.expander("π All 4 Modalities (this slice)", expanded=False):
|
| 174 |
+
cols = st.columns(4)
|
| 175 |
+
mod_names = ["T1", "T1ce", "T2", "FLAIR"]
|
| 176 |
+
for i, col in enumerate(cols):
|
| 177 |
+
with col:
|
| 178 |
+
fig, ax = plt.subplots(figsize=(3, 3))
|
| 179 |
+
ax.imshow(img_data[slice_idx, :, :, i], cmap="gray", origin="lower")
|
| 180 |
+
ax.set_title(mod_names[i], fontsize=11, color="white")
|
| 181 |
+
ax.axis("off")
|
| 182 |
+
fig.patch.set_facecolor("#0a0a14")
|
| 183 |
+
st.pyplot(fig)
|
| 184 |
+
|
| 185 |
+
# βββ color legend ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 186 |
+
st.markdown("---")
|
| 187 |
+
legend_cols = st.columns(3)
|
| 188 |
+
labels = ["Whole Tumor (WT)", "Tumor Core (TC)", "Enhancing Tumor (ET)"]
|
| 189 |
+
colors = ["#2ECC71", "#E74C3C", "#F1C40F"]
|
| 190 |
+
for i, col in enumerate(legend_cols):
|
| 191 |
+
with col:
|
| 192 |
+
st.markdown(
|
| 193 |
+
f'<div style="display:flex;align-items:center;gap:8px;">'
|
| 194 |
+
f'<div style="width:16px;height:16px;background:{colors[i]};'
|
| 195 |
+
f'border-radius:3px;"></div>'
|
| 196 |
+
f'<span>{labels[i]}</span></div>',
|
| 197 |
+
unsafe_allow_html=True,
|
| 198 |
+
)
|
pages/3_3D_Visualization.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
3D Interactive Brain Tumor Visualization
|
| 3 |
+
=========================================
|
| 4 |
+
Renders the brain + tumor regions as interactive 3D surfaces.
|
| 5 |
+
Supports side-by-side Ground Truth vs AI Prediction comparison.
|
| 6 |
+
Runs live inference if prediction doesn't exist yet.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import streamlit as st
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
import glob
|
| 13 |
+
import nibabel as nib
|
| 14 |
+
import numpy as np
|
| 15 |
+
import plotly.graph_objects as go
|
| 16 |
+
from plotly.subplots import make_subplots
|
| 17 |
+
from skimage.measure import marching_cubes
|
| 18 |
+
|
| 19 |
+
# st.set_page_config(page_title="3D Tumor Visualization", layout="wide")
|
| 20 |
+
|
| 21 |
+
# βββ paths & inference βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
+
APP_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 23 |
+
if APP_DIR not in sys.path:
|
| 24 |
+
sys.path.insert(0, APP_DIR)
|
| 25 |
+
from utils.inference import ensure_prediction, get_all_patients, DEMO_DIR
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load_nifti(path):
|
| 29 |
+
if not os.path.exists(path):
|
| 30 |
+
return None
|
| 31 |
+
return nib.load(path).get_fdata()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def extract_mesh(volume, level=0.5, step_size=2):
|
| 36 |
+
vol = volume[::step_size, ::step_size, ::step_size]
|
| 37 |
+
if vol.sum() == 0:
|
| 38 |
+
return None
|
| 39 |
+
try:
|
| 40 |
+
verts, faces, _, _ = marching_cubes(vol, level=level)
|
| 41 |
+
verts = verts * step_size
|
| 42 |
+
return verts, faces
|
| 43 |
+
except Exception:
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def make_mesh_trace(volume, color, name, opacity, step_size, level=0.5,
|
| 48 |
+
flatshading=True, scene="scene"):
|
| 49 |
+
result = extract_mesh(volume, level=level, step_size=step_size)
|
| 50 |
+
if result is None:
|
| 51 |
+
return None
|
| 52 |
+
verts, faces = result
|
| 53 |
+
x, y, z = verts.T
|
| 54 |
+
i, j, k = faces.T
|
| 55 |
+
return go.Mesh3d(
|
| 56 |
+
x=x, y=y, z=z, i=i, j=j, k=k,
|
| 57 |
+
color=color, opacity=opacity,
|
| 58 |
+
name=name, showlegend=True,
|
| 59 |
+
flatshading=flatshading,
|
| 60 |
+
lighting=dict(ambient=0.6, diffuse=0.7, specular=0.2, roughness=0.6),
|
| 61 |
+
lightposition=dict(x=100, y=200, z=300),
|
| 62 |
+
scene=scene,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# βββ colors ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
+
PRED_COLORS = {
|
| 68 |
+
"Whole Tumor (WT)": "#2ECC71", # emerald green
|
| 69 |
+
"Tumor Core (TC)": "#E74C3C", # vivid red
|
| 70 |
+
"Enhancing Tumor (ET)": "#F1C40F", # bright gold
|
| 71 |
+
}
|
| 72 |
+
GT_COLORS = {
|
| 73 |
+
"Whole Tumor (WT)": "#1ABC9C", # turquoise
|
| 74 |
+
"Tumor Core (TC)": "#9B59B6", # amethyst
|
| 75 |
+
"Enhancing Tumor (ET)": "#3498DB", # ocean blue
|
| 76 |
+
}
|
| 77 |
+
PRED_CHANNELS = {"Whole Tumor (WT)": 1, "Tumor Core (TC)": 0, "Enhancing Tumor (ET)": 2}
|
| 78 |
+
BRAIN_COLOR = "#D5D8DC"
|
| 79 |
+
|
| 80 |
+
# βββ sidebar βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 81 |
+
st.sidebar.title("βοΈ 3D Controls")
|
| 82 |
+
|
| 83 |
+
samples = get_all_patients()
|
| 84 |
+
if not samples:
|
| 85 |
+
st.error("No processed prediction volumes found.")
|
| 86 |
+
st.stop()
|
| 87 |
+
|
| 88 |
+
selected_id = st.sidebar.selectbox("π§ββοΈ Patient", samples)
|
| 89 |
+
|
| 90 |
+
st.sidebar.markdown("---")
|
| 91 |
+
view_mode = st.sidebar.radio(
|
| 92 |
+
"View Mode",
|
| 93 |
+
["Prediction Only", "Ground Truth Only", "Side-by-Side Comparison"],
|
| 94 |
+
index=0,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
st.sidebar.markdown("---")
|
| 98 |
+
st.sidebar.subheader("π§ Brain Surface")
|
| 99 |
+
show_brain = st.sidebar.checkbox("Show Brain", value=True)
|
| 100 |
+
brain_opacity = st.sidebar.slider("Brain Opacity", 0.02, 0.30, 0.08, 0.02)
|
| 101 |
+
|
| 102 |
+
st.sidebar.subheader("π― Tumor")
|
| 103 |
+
region_choice = st.sidebar.multiselect(
|
| 104 |
+
"Regions",
|
| 105 |
+
["Whole Tumor (WT)", "Tumor Core (TC)", "Enhancing Tumor (ET)"],
|
| 106 |
+
default=["Whole Tumor (WT)", "Tumor Core (TC)", "Enhancing Tumor (ET)"],
|
| 107 |
+
)
|
| 108 |
+
tumor_opacity = st.sidebar.slider("Tumor Opacity", 0.20, 1.0, 0.70, 0.05)
|
| 109 |
+
|
| 110 |
+
st.sidebar.markdown("---")
|
| 111 |
+
step_size = st.sidebar.select_slider("Mesh Quality", options=[1, 2, 3, 4], value=2)
|
| 112 |
+
|
| 113 |
+
# βββ load data (run inference if needed) βββββββββββββββββββββββββββββββββ
|
| 114 |
+
ensure_prediction(selected_id)
|
| 115 |
+
|
| 116 |
+
img_data = load_nifti(os.path.join(DEMO_DIR, f"{selected_id}_image.nii.gz"))
|
| 117 |
+
pred_data = load_nifti(os.path.join(DEMO_DIR, f"{selected_id}_pred.nii.gz"))
|
| 118 |
+
lbl_data = load_nifti(os.path.join(DEMO_DIR, f"{selected_id}_label.nii.gz"))
|
| 119 |
+
|
| 120 |
+
# βββ title βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 121 |
+
st.title("π 3D Brain Tumor Visualization")
|
| 122 |
+
st.markdown(f"**Patient:** `{selected_id}` Β· **Drag** to rotate Β· **Scroll** to zoom")
|
| 123 |
+
|
| 124 |
+
# βββ color legend ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 125 |
+
legend_cols = st.columns(6)
|
| 126 |
+
for idx, (region, color) in enumerate(PRED_COLORS.items()):
|
| 127 |
+
with legend_cols[idx]:
|
| 128 |
+
st.markdown(
|
| 129 |
+
f'<div style="display:flex;align-items:center;gap:6px;">'
|
| 130 |
+
f'<div style="width:14px;height:14px;background:{color};'
|
| 131 |
+
f'border-radius:3px;"></div><span style="font-size:13px;">Pred: {region}</span></div>',
|
| 132 |
+
unsafe_allow_html=True,
|
| 133 |
+
)
|
| 134 |
+
if view_mode in ["Ground Truth Only", "Side-by-Side Comparison"]:
|
| 135 |
+
for idx, (region, color) in enumerate(GT_COLORS.items()):
|
| 136 |
+
with legend_cols[idx + 3]:
|
| 137 |
+
st.markdown(
|
| 138 |
+
f'<div style="display:flex;align-items:center;gap:6px;">'
|
| 139 |
+
f'<div style="width:14px;height:14px;background:{color};'
|
| 140 |
+
f'border-radius:3px;"></div><span style="font-size:13px;">GT: {region}</span></div>',
|
| 141 |
+
unsafe_allow_html=True,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# βββ helper: build brain trace βββββββββββββββββββββββββββββββββββββββββββ
|
| 146 |
+
def build_brain_trace(scene_name="scene"):
|
| 147 |
+
if img_data is None:
|
| 148 |
+
return None
|
| 149 |
+
flair = img_data[:, :, :, 3]
|
| 150 |
+
flair_norm = (flair - flair.min()) / (flair.max() - flair.min() + 1e-8)
|
| 151 |
+
brain_mask = (flair_norm > 0.15).astype(float)
|
| 152 |
+
return make_mesh_trace(
|
| 153 |
+
brain_mask, BRAIN_COLOR, "Brain",
|
| 154 |
+
opacity=brain_opacity,
|
| 155 |
+
step_size=max(step_size, 2),
|
| 156 |
+
flatshading=False, scene=scene_name,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def build_pred_traces(scene_name="scene"):
|
| 161 |
+
traces = []
|
| 162 |
+
if pred_data is None:
|
| 163 |
+
return traces
|
| 164 |
+
for region in region_choice:
|
| 165 |
+
ch = PRED_CHANNELS[region]
|
| 166 |
+
vol = pred_data[:, :, :, ch]
|
| 167 |
+
t = make_mesh_trace(vol, PRED_COLORS[region], f"Pred: {region}",
|
| 168 |
+
tumor_opacity, step_size, scene=scene_name)
|
| 169 |
+
if t:
|
| 170 |
+
traces.append(t)
|
| 171 |
+
return traces
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def build_gt_traces(scene_name="scene"):
|
| 175 |
+
traces = []
|
| 176 |
+
if lbl_data is None:
|
| 177 |
+
return traces
|
| 178 |
+
gt_masks = {
|
| 179 |
+
"Whole Tumor (WT)": (lbl_data > 0).astype(float),
|
| 180 |
+
"Tumor Core (TC)": ((lbl_data == 1) | (lbl_data == 4)).astype(float),
|
| 181 |
+
"Enhancing Tumor (ET)": (lbl_data == 4).astype(float),
|
| 182 |
+
}
|
| 183 |
+
for region in region_choice:
|
| 184 |
+
vol = gt_masks[region]
|
| 185 |
+
t = make_mesh_trace(vol, GT_COLORS[region], f"GT: {region}",
|
| 186 |
+
tumor_opacity, step_size, scene=scene_name)
|
| 187 |
+
if t:
|
| 188 |
+
traces.append(t)
|
| 189 |
+
return traces
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
SCENE_LAYOUT = dict(
|
| 193 |
+
xaxis=dict(visible=False),
|
| 194 |
+
yaxis=dict(visible=False),
|
| 195 |
+
zaxis=dict(visible=False),
|
| 196 |
+
bgcolor="rgb(10, 10, 20)",
|
| 197 |
+
aspectmode="data",
|
| 198 |
+
camera=dict(eye=dict(x=1.6, y=1.0, z=0.8), up=dict(x=0, y=0, z=1)),
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# βββ render ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 202 |
+
|
| 203 |
+
if view_mode == "Side-by-Side Comparison":
|
| 204 |
+
# Two 3D plots: GT on the left, Prediction on the right
|
| 205 |
+
col_left, col_right = st.columns(2)
|
| 206 |
+
|
| 207 |
+
with col_left:
|
| 208 |
+
st.markdown("### π’ Ground Truth")
|
| 209 |
+
gt_traces = []
|
| 210 |
+
if show_brain:
|
| 211 |
+
bt = build_brain_trace("scene")
|
| 212 |
+
if bt:
|
| 213 |
+
gt_traces.append(bt)
|
| 214 |
+
gt_traces.extend(build_gt_traces("scene"))
|
| 215 |
+
|
| 216 |
+
if gt_traces:
|
| 217 |
+
fig_gt = go.Figure(data=gt_traces)
|
| 218 |
+
fig_gt.update_layout(
|
| 219 |
+
scene=SCENE_LAYOUT,
|
| 220 |
+
margin=dict(l=0, r=0, t=0, b=0),
|
| 221 |
+
height=600,
|
| 222 |
+
paper_bgcolor="rgb(10, 10, 20)",
|
| 223 |
+
legend=dict(font=dict(color="white", size=11),
|
| 224 |
+
bgcolor="rgba(20,20,40,0.8)", x=0.01, y=0.99),
|
| 225 |
+
)
|
| 226 |
+
st.plotly_chart(fig_gt, width="stretch")
|
| 227 |
+
else:
|
| 228 |
+
st.info("No ground truth data available for this patient.")
|
| 229 |
+
|
| 230 |
+
with col_right:
|
| 231 |
+
st.markdown("### π΄ AI Prediction")
|
| 232 |
+
pred_traces = []
|
| 233 |
+
if show_brain:
|
| 234 |
+
bt = build_brain_trace("scene")
|
| 235 |
+
if bt:
|
| 236 |
+
pred_traces.append(bt)
|
| 237 |
+
pred_traces.extend(build_pred_traces("scene"))
|
| 238 |
+
|
| 239 |
+
if pred_traces:
|
| 240 |
+
fig_pred = go.Figure(data=pred_traces)
|
| 241 |
+
fig_pred.update_layout(
|
| 242 |
+
scene=SCENE_LAYOUT,
|
| 243 |
+
margin=dict(l=0, r=0, t=0, b=0),
|
| 244 |
+
height=600,
|
| 245 |
+
paper_bgcolor="rgb(10, 10, 20)",
|
| 246 |
+
legend=dict(font=dict(color="white", size=11),
|
| 247 |
+
bgcolor="rgba(20,20,40,0.8)", x=0.01, y=0.99),
|
| 248 |
+
)
|
| 249 |
+
st.plotly_chart(fig_pred, width="stretch")
|
| 250 |
+
else:
|
| 251 |
+
st.warning("No prediction data available.")
|
| 252 |
+
|
| 253 |
+
else:
|
| 254 |
+
# Single 3D view
|
| 255 |
+
all_traces = []
|
| 256 |
+
if show_brain:
|
| 257 |
+
bt = build_brain_trace("scene")
|
| 258 |
+
if bt:
|
| 259 |
+
all_traces.append(bt)
|
| 260 |
+
|
| 261 |
+
if view_mode == "Prediction Only":
|
| 262 |
+
all_traces.extend(build_pred_traces("scene"))
|
| 263 |
+
elif view_mode == "Ground Truth Only":
|
| 264 |
+
all_traces.extend(build_gt_traces("scene"))
|
| 265 |
+
|
| 266 |
+
if not all_traces:
|
| 267 |
+
st.warning("Nothing to render. Check that data exists and regions are selected.")
|
| 268 |
+
st.stop()
|
| 269 |
+
|
| 270 |
+
fig = go.Figure(data=all_traces)
|
| 271 |
+
fig.update_layout(
|
| 272 |
+
scene=SCENE_LAYOUT,
|
| 273 |
+
margin=dict(l=0, r=0, t=0, b=0),
|
| 274 |
+
height=750,
|
| 275 |
+
paper_bgcolor="rgb(10, 10, 20)",
|
| 276 |
+
legend=dict(font=dict(color="white", size=13),
|
| 277 |
+
bgcolor="rgba(20,20,40,0.85)",
|
| 278 |
+
bordercolor="rgba(100,100,140,0.5)", borderwidth=1,
|
| 279 |
+
x=0.01, y=0.99),
|
| 280 |
+
)
|
| 281 |
+
st.plotly_chart(fig, width="stretch")
|
| 282 |
+
|
| 283 |
+
# βββ volume stats ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 284 |
+
st.markdown("---")
|
| 285 |
+
st.subheader("π Tumor Volume Statistics")
|
| 286 |
+
|
| 287 |
+
if pred_data is not None:
|
| 288 |
+
cols = st.columns(3)
|
| 289 |
+
for idx, (region, ch) in enumerate(PRED_CHANNELS.items()):
|
| 290 |
+
vol = pred_data[:, :, :, ch]
|
| 291 |
+
voxel_count = int(vol.sum())
|
| 292 |
+
volume_cc = voxel_count / 1000.0
|
| 293 |
+
color = PRED_COLORS[region]
|
| 294 |
+
with cols[idx]:
|
| 295 |
+
st.markdown(
|
| 296 |
+
f'<div style="background:rgba(30,30,50,0.8);padding:14px;'
|
| 297 |
+
f'border-radius:10px;border-left:4px solid {color};">'
|
| 298 |
+
f'<div style="color:{color};font-size:13px;font-weight:600;">{region}</div>'
|
| 299 |
+
f'<div style="color:white;font-size:26px;font-weight:700;">{volume_cc:.1f} cmΒ³</div>'
|
| 300 |
+
f'<div style="color:#888;font-size:11px;">{voxel_count:,} voxels</div></div>',
|
| 301 |
+
unsafe_allow_html=True,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# βββ dice ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 305 |
+
if lbl_data is not None and pred_data is not None:
|
| 306 |
+
st.markdown("---")
|
| 307 |
+
st.subheader("π¬ Dice Scores")
|
| 308 |
+
cols = st.columns(3)
|
| 309 |
+
gt_masks_dice = {
|
| 310 |
+
"Whole Tumor (WT)": (lbl_data > 0).astype(float),
|
| 311 |
+
"Tumor Core (TC)": ((lbl_data == 1) | (lbl_data == 4)).astype(float),
|
| 312 |
+
"Enhancing Tumor (ET)": (lbl_data == 4).astype(float),
|
| 313 |
+
}
|
| 314 |
+
for idx, (region, ch) in enumerate(PRED_CHANNELS.items()):
|
| 315 |
+
p = pred_data[:, :, :, ch]
|
| 316 |
+
g = gt_masks_dice[region]
|
| 317 |
+
dice = (2.0 * (p * g).sum()) / (p.sum() + g.sum() + 1e-8)
|
| 318 |
+
color = PRED_COLORS[region]
|
| 319 |
+
grade = "Excellent" if dice > 0.8 else "Good" if dice > 0.6 else "Fair"
|
| 320 |
+
with cols[idx]:
|
| 321 |
+
st.markdown(
|
| 322 |
+
f'<div style="background:rgba(30,30,50,0.8);padding:14px;'
|
| 323 |
+
f'border-radius:10px;border-left:4px solid {color};">'
|
| 324 |
+
f'<div style="color:{color};font-size:13px;">{region}</div>'
|
| 325 |
+
f'<div style="color:white;font-size:30px;font-weight:700;">{dice:.4f}</div>'
|
| 326 |
+
f'<div style="color:#888;">{grade}</div></div>',
|
| 327 |
+
unsafe_allow_html=True,
|
| 328 |
+
)
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=1.12.0
|
| 2 |
+
torchvision>=0.13.0
|
| 3 |
+
numpy>=1.21.0
|
| 4 |
+
Pillow>=9.0.0
|
| 5 |
+
plotly>=5.10.0
|
| 6 |
+
streamlit>=1.20.0
|
| 7 |
+
matplotlib>=3.5.0
|
| 8 |
+
nibabel>=4.0.0
|
| 9 |
+
monai>=1.0.0
|
results/Setup_1/models/fedavg_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4d48f93f0f3ed5e4074a24f6bc645bd55d56068467573902be606e94a75363f0
|
| 3 |
+
size 631956
|
results/Setup_1/models/fedprox_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f7b6c27da43d10bbee7275de1d34eb417f3c7867694e8aabdf9ee3deebebbc52
|
| 3 |
+
size 631987
|
results/Setup_1/models/qpso_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6bcd1b0bcaf795d4519a8457dba2e778a25adbda949adf337e35bd6b185f38ef
|
| 3 |
+
size 631894
|
results/Setup_2/models/fedavg_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4d48f93f0f3ed5e4074a24f6bc645bd55d56068467573902be606e94a75363f0
|
| 3 |
+
size 631956
|
results/Setup_2/models/fedprox_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f7b6c27da43d10bbee7275de1d34eb417f3c7867694e8aabdf9ee3deebebbc52
|
| 3 |
+
size 631987
|
results/Setup_2/models/qpso_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6bcd1b0bcaf795d4519a8457dba2e778a25adbda949adf337e35bd6b185f38ef
|
| 3 |
+
size 631894
|
src/streamlit_app.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import altair as alt
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import streamlit as st
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
# Welcome to Streamlit!
|
| 8 |
+
|
| 9 |
+
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
+
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
+
forums](https://discuss.streamlit.io).
|
| 12 |
+
|
| 13 |
+
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
+
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
+
|
| 19 |
+
indices = np.linspace(0, 1, num_points)
|
| 20 |
+
theta = 2 * np.pi * num_turns * indices
|
| 21 |
+
radius = indices
|
| 22 |
+
|
| 23 |
+
x = radius * np.cos(theta)
|
| 24 |
+
y = radius * np.sin(theta)
|
| 25 |
+
|
| 26 |
+
df = pd.DataFrame({
|
| 27 |
+
"x": x,
|
| 28 |
+
"y": y,
|
| 29 |
+
"idx": indices,
|
| 30 |
+
"rand": np.random.randn(num_points),
|
| 31 |
+
})
|
| 32 |
+
|
| 33 |
+
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
+
.mark_point(filled=True)
|
| 35 |
+
.encode(
|
| 36 |
+
x=alt.X("x", axis=None),
|
| 37 |
+
y=alt.Y("y", axis=None),
|
| 38 |
+
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
+
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
+
))
|
utils/inference.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared inference module for 3D brain tumor segmentation.
|
| 3 |
+
Loads the AttentionUnet model and runs sliding_window_inference
|
| 4 |
+
on patients that don't have pre-computed predictions.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import shutil
|
| 9 |
+
import numpy as np
|
| 10 |
+
import nibabel as nib
|
| 11 |
+
import streamlit as st
|
| 12 |
+
import torch
|
| 13 |
+
from monai.inferers import sliding_window_inference
|
| 14 |
+
from monai.networks.nets import AttentionUnet
|
| 15 |
+
from monai.transforms import (
|
| 16 |
+
Compose,
|
| 17 |
+
LoadImaged,
|
| 18 |
+
NormalizeIntensityd,
|
| 19 |
+
Orientationd,
|
| 20 |
+
Spacingd,
|
| 21 |
+
EnsureChannelFirstd,
|
| 22 |
+
EnsureTyped,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# βββ paths βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
+
SEG_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "segmentation"))
|
| 28 |
+
DEMO_DIR = os.path.join(SEG_DIR, "demo_data")
|
| 29 |
+
# streamlit_app/ is inside segmentation/, so go up one level to reach segmentation/
|
| 30 |
+
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 31 |
+
|
| 32 |
+
# Model checkpoint β prefer refined model (better calibration)
|
| 33 |
+
# 1. Check streamlit_app/ (where refined model lives)
|
| 34 |
+
# 2. Check segmentation/ (parent dir, where base model lives)
|
| 35 |
+
_THIS_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 36 |
+
_candidates = [
|
| 37 |
+
os.path.join(_THIS_DIR, "best_metric_model_refined.pth"), # streamlit_app/
|
| 38 |
+
os.path.join(PROJECT_ROOT, "best_metric_model_refined.pth"), # segmentation/
|
| 39 |
+
os.path.join(_THIS_DIR, "best_metric_model.pth"), # streamlit_app/
|
| 40 |
+
os.path.join(PROJECT_ROOT, "best_metric_model.pth"), # segmentation/
|
| 41 |
+
]
|
| 42 |
+
CKPT_PATH = None
|
| 43 |
+
for _c in _candidates:
|
| 44 |
+
if os.path.exists(_c):
|
| 45 |
+
CKPT_PATH = _c
|
| 46 |
+
break
|
| 47 |
+
|
| 48 |
+
# MONAI transforms β must match training exactly
|
| 49 |
+
INFERENCE_TRANSFORMS = Compose([
|
| 50 |
+
LoadImaged(keys=["image", "label"]),
|
| 51 |
+
EnsureChannelFirstd(keys=["image", "label"]),
|
| 52 |
+
EnsureTyped(keys=["image", "label"]),
|
| 53 |
+
Orientationd(keys=["image", "label"], axcodes="RAS"),
|
| 54 |
+
Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
|
| 55 |
+
NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
|
| 56 |
+
])
|
| 57 |
+
|
| 58 |
+
# Transforms for image-only (no label available)
|
| 59 |
+
INFERENCE_TRANSFORMS_IMG_ONLY = Compose([
|
| 60 |
+
LoadImaged(keys=["image"]),
|
| 61 |
+
EnsureChannelFirstd(keys=["image"]),
|
| 62 |
+
EnsureTyped(keys=["image"]),
|
| 63 |
+
Orientationd(keys=["image"], axcodes="RAS"),
|
| 64 |
+
Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear",)),
|
| 65 |
+
NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
|
| 66 |
+
])
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@st.cache_resource
|
| 70 |
+
def load_seg_model():
|
| 71 |
+
"""Load the 3D Attention U-Net model (cached across sessions)."""
|
| 72 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 73 |
+
model = AttentionUnet(
|
| 74 |
+
spatial_dims=3,
|
| 75 |
+
in_channels=4,
|
| 76 |
+
out_channels=3,
|
| 77 |
+
channels=(16, 32, 64, 128, 256),
|
| 78 |
+
strides=(2, 2, 2, 2),
|
| 79 |
+
).to(device)
|
| 80 |
+
|
| 81 |
+
if os.path.exists(CKPT_PATH):
|
| 82 |
+
try:
|
| 83 |
+
model.load_state_dict(torch.load(CKPT_PATH, map_location=device))
|
| 84 |
+
model.eval()
|
| 85 |
+
return model, device
|
| 86 |
+
except Exception as e:
|
| 87 |
+
st.error(f"Failed to load model weights: {e}")
|
| 88 |
+
return None, None
|
| 89 |
+
else:
|
| 90 |
+
st.error(f"Model checkpoint not found at {CKPT_PATH}")
|
| 91 |
+
return None, None
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def ensure_prediction(patient_id):
|
| 95 |
+
"""
|
| 96 |
+
Ensure that the prediction volume exists for a patient.
|
| 97 |
+
If _pred.nii.gz already exists, returns True immediately.
|
| 98 |
+
Otherwise, runs live inference using the exact same MONAI
|
| 99 |
+
transforms as the training pipeline.
|
| 100 |
+
"""
|
| 101 |
+
pred_path = os.path.join(DEMO_DIR, f"{patient_id}_pred.nii.gz")
|
| 102 |
+
img_path = os.path.join(DEMO_DIR, f"{patient_id}_image.nii.gz")
|
| 103 |
+
lbl_path = os.path.join(DEMO_DIR, f"{patient_id}_label.nii.gz")
|
| 104 |
+
|
| 105 |
+
# Already have prediction β skip
|
| 106 |
+
if os.path.exists(pred_path) and os.path.exists(img_path):
|
| 107 |
+
return True
|
| 108 |
+
|
| 109 |
+
# Check if raw MRI modalities exist in patient subfolder
|
| 110 |
+
p_dir = os.path.join(DEMO_DIR, patient_id)
|
| 111 |
+
if not os.path.isdir(p_dir):
|
| 112 |
+
return False
|
| 113 |
+
|
| 114 |
+
# Build file paths (same order as extract_demo_data.py: t1, t1ce, t2, flair)
|
| 115 |
+
mod_paths = {
|
| 116 |
+
"t1": os.path.join(p_dir, f"{patient_id}_t1.nii.gz"),
|
| 117 |
+
"t1ce": os.path.join(p_dir, f"{patient_id}_t1ce.nii.gz"),
|
| 118 |
+
"t2": os.path.join(p_dir, f"{patient_id}_t2.nii.gz"),
|
| 119 |
+
"flair": os.path.join(p_dir, f"{patient_id}_flair.nii.gz"),
|
| 120 |
+
}
|
| 121 |
+
seg_path = os.path.join(p_dir, f"{patient_id}_seg.nii.gz")
|
| 122 |
+
|
| 123 |
+
for m, mp in mod_paths.items():
|
| 124 |
+
if not os.path.exists(mp):
|
| 125 |
+
st.warning(f"Missing modality: {m} at {mp}")
|
| 126 |
+
return False
|
| 127 |
+
|
| 128 |
+
# βββ Run live inference ββββββββββββββββββββββββββββββββββββββββββ
|
| 129 |
+
st.info(f"π§ **Running AI Inference** on `{patient_id}`... This may take 30-60 seconds.")
|
| 130 |
+
progress = st.progress(0)
|
| 131 |
+
status = st.empty()
|
| 132 |
+
|
| 133 |
+
try:
|
| 134 |
+
# Build MONAI data dict (image is a list of 4 modality paths)
|
| 135 |
+
has_label = os.path.exists(seg_path)
|
| 136 |
+
data_dict = {
|
| 137 |
+
"image": [mod_paths["t1"], mod_paths["t1ce"], mod_paths["t2"], mod_paths["flair"]],
|
| 138 |
+
}
|
| 139 |
+
if has_label:
|
| 140 |
+
data_dict["label"] = seg_path
|
| 141 |
+
|
| 142 |
+
# Apply MONAI transforms (Orientation, Spacing, Normalize β matching training)
|
| 143 |
+
status.text("Loading & preprocessing with MONAI transforms...")
|
| 144 |
+
if has_label:
|
| 145 |
+
sample_data = INFERENCE_TRANSFORMS(data_dict)
|
| 146 |
+
else:
|
| 147 |
+
sample_data = INFERENCE_TRANSFORMS_IMG_ONLY(data_dict)
|
| 148 |
+
progress.progress(30)
|
| 149 |
+
|
| 150 |
+
# Run model inference
|
| 151 |
+
status.text("Running 3D U-Net inference (sliding window)...")
|
| 152 |
+
model, device = load_seg_model()
|
| 153 |
+
if model is None:
|
| 154 |
+
return False
|
| 155 |
+
|
| 156 |
+
inputs = sample_data["image"].unsqueeze(0).to(device) # (1, 4, D, H, W)
|
| 157 |
+
with torch.no_grad():
|
| 158 |
+
outputs = sliding_window_inference(inputs, (96, 96, 96), 4, model)
|
| 159 |
+
outputs = (outputs.sigmoid() > 0.5).float()
|
| 160 |
+
progress.progress(80)
|
| 161 |
+
|
| 162 |
+
# Save processed image volume (D, H, W, 4)
|
| 163 |
+
status.text("Saving results...")
|
| 164 |
+
img_np = inputs[0].cpu().numpy().transpose(1, 2, 3, 0)
|
| 165 |
+
nib.save(nib.Nifti1Image(img_np, affine=np.eye(4)), img_path)
|
| 166 |
+
|
| 167 |
+
# Save prediction (D, H, W, 3)
|
| 168 |
+
pred_np = outputs[0].cpu().numpy().transpose(1, 2, 3, 0)
|
| 169 |
+
nib.save(nib.Nifti1Image(pred_np, affine=np.eye(4)), pred_path)
|
| 170 |
+
|
| 171 |
+
# Save ground truth label (D, H, W)
|
| 172 |
+
if has_label:
|
| 173 |
+
lbl_np = sample_data["label"][0].cpu().numpy()
|
| 174 |
+
nib.save(nib.Nifti1Image(lbl_np.astype(np.float32), affine=np.eye(4)), lbl_path)
|
| 175 |
+
elif not os.path.exists(lbl_path):
|
| 176 |
+
empty = np.zeros(pred_np.shape[:3])
|
| 177 |
+
nib.save(nib.Nifti1Image(empty.astype(np.float32), affine=np.eye(4)), lbl_path)
|
| 178 |
+
|
| 179 |
+
progress.progress(100)
|
| 180 |
+
status.text("β
Inference complete!")
|
| 181 |
+
return True
|
| 182 |
+
|
| 183 |
+
except Exception as e:
|
| 184 |
+
st.error(f"Inference failed: {e}")
|
| 185 |
+
import traceback
|
| 186 |
+
st.code(traceback.format_exc())
|
| 187 |
+
return False
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def get_all_patients():
|
| 191 |
+
"""
|
| 192 |
+
Return all patient IDs that have either pre-computed predictions
|
| 193 |
+
OR raw MRI data (can be inferred on-demand).
|
| 194 |
+
"""
|
| 195 |
+
patients = set()
|
| 196 |
+
|
| 197 |
+
# Patients with pre-computed predictions
|
| 198 |
+
import glob
|
| 199 |
+
for f in glob.glob(os.path.join(DEMO_DIR, "*_pred.nii.gz")):
|
| 200 |
+
pid = os.path.basename(f).replace("_pred.nii.gz", "")
|
| 201 |
+
patients.add(pid)
|
| 202 |
+
|
| 203 |
+
# Patients with raw MRI data (subfolder with modality files)
|
| 204 |
+
if os.path.isdir(DEMO_DIR):
|
| 205 |
+
for d in os.listdir(DEMO_DIR):
|
| 206 |
+
full = os.path.join(DEMO_DIR, d)
|
| 207 |
+
if os.path.isdir(full) and d.startswith("BraTS"):
|
| 208 |
+
# Check it has at least the flair file
|
| 209 |
+
if os.path.exists(os.path.join(full, f"{d}_flair.nii.gz")):
|
| 210 |
+
patients.add(d)
|
| 211 |
+
|
| 212 |
+
return sorted(patients)
|