# --- Patch-based Inference Helper ---
def run_patch_inference(model, tensor, patch_depth=32):
    """Run 3D model inference patch-by-patch along the depth axis.

    Splits the volume into chunks of ``patch_depth`` slices so the whole
    tensor never needs to be on the device at once, then stitches the
    per-patch outputs back together.

    Args:
        model: The 3D segmentation model (a callable nn.Module).
        tensor: Input tensor of shape [1, C, D, H, W].
        patch_depth: Number of depth slices per patch.

    Returns:
        Output tensor concatenated along the depth axis (dim=2), on CPU.
    """
    # Infer the model's device from its first parameter. FIX: use the
    # two-argument next() with a CPU fallback — the original raised
    # StopIteration for parameter-less models, since hasattr('parameters')
    # is True for every nn.Module even when the iterator is empty.
    device = torch.device('cpu')
    if hasattr(model, 'parameters'):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
    _, c, d, h, w = tensor.shape
    output = []
    for start in range(0, d, patch_depth):
        end = min(start + patch_depth, d)
        patch = tensor[:, :, start:end, :, :]
        with torch.no_grad():
            patch_out = model(patch.to(device))
        output.append(patch_out.cpu())
    # Concatenate along the depth axis
    return torch.cat(output, dim=2)
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
import numpy as np
import os
import io
import tempfile
from utils import preprocess_nifti, postprocess_mask, visualize_prediction, combine_nifti_files
# --- Page Configuration ---
# Must be the first Streamlit call in the script; sets the browser-tab
# title and a wide layout with the sidebar open by default.
st.set_page_config(
    page_title="Brain Tumor Segmentation App",
    layout="wide",
    initial_sidebar_state="expanded"
)
# --- App Title and Description ---
st.title("Brain Tumor Segmentation")
st.write("Upload the four 3D NIfTI brain scans (.nii or .nii.gz) for each modality to get a segmentation mask of the tumor.")
st.markdown("---")
# --- Model Architecture ---
# A single block in the U-Net architecture.
class DoubleConv(nn.Module):
    """Two (Conv3d => GroupNorm => ReLU) stages applied back to back."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Build both stages in a loop: the first stage maps in_channels ->
        # out_channels, the second keeps out_channels. GroupNorm (rather
        # than BatchNorm) keeps training stable with small 3D batches.
        # The attribute name `double_conv` is preserved so checkpoint
        # state_dict keys stay compatible.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.GroupNorm(num_groups=out_channels // 2, num_channels=out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
# The downsampling part of the U-Net.
class Down(nn.Module):
    """Halve every spatial dimension with max-pooling, then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute name `encoder` is kept so state_dict keys match
        # existing checkpoints.
        pool = nn.MaxPool3d(2)
        convolve = DoubleConv(in_channels, out_channels)
        self.encoder = nn.Sequential(pool, convolve)

    def forward(self, x):
        return self.encoder(x)
# The upsampling part of the U-Net.
class Up(nn.Module):
    """Upsample 2x, pad to match the skip connection, then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Trilinear upsampling keeps the channel count; the conv then
        # fuses the concatenated skip-connection features.
        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # F.pad takes (before, after) pairs starting from the LAST axis,
        # so walk dims W, H, D (4, 3, 2), centering each padding amount.
        padding = []
        for axis in (4, 3, 2):
            gap = x2.size(axis) - x1.size(axis)
            padding.extend([gap // 2, gap - gap // 2])
        x1 = F.pad(x1, padding)
        # Skip features first, upsampled features second — matches the
        # channel ordering the decoder convs were trained with.
        return self.conv(torch.cat([x2, x1], dim=1))
# The final output convolutional layer.
class Out(nn.Module):
    """Project feature maps to per-class logits with a 1x1x1 convolution."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
# The complete 3D U-Net model.
class UNet3d(nn.Module):
    """Complete 3D U-Net: 4-level encoder/decoder with skip connections.

    Defaults to 4 input channels (one per MRI modality) and 3 output
    classes (tumor core, edema, enhancing tumor).
    """

    def __init__(self, n_channels=4, n_classes=3):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        # Contracting path: feature channels double at every level.
        self.conv = DoubleConv(n_channels, 16)
        self.enc1 = Down(16, 32)
        self.enc2 = Down(32, 64)
        self.enc3 = Down(64, 128)
        self.enc4 = Down(128, 256)
        # Expansive path: each Up consumes upsampled + skip features,
        # hence the summed in-channel counts.
        self.dec1 = Up(256 + 128, 128)
        self.dec2 = Up(128 + 64, 64)
        self.dec3 = Up(64 + 32, 32)
        self.dec4 = Up(32 + 16, 16)
        self.out = Out(16, n_classes)

    def forward(self, x):
        # Encoder: keep every level's activation for the skip connections.
        skip0 = self.conv(x)
        skip1 = self.enc1(skip0)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        bottleneck = self.enc4(skip3)
        # Decoder: progressively upsample, merging the matching skip.
        features = self.dec1(bottleneck, skip3)
        features = self.dec2(features, skip2)
        features = self.dec3(features, skip1)
        features = self.dec4(features, skip0)
        return self.out(features)
# --- Model Loading ---
@st.cache_resource
def load_model(model_path):
    """Load the trained PyTorch model object from a .pth file.

    Cached by Streamlit so deserialization happens once per session.
    Shows a success/error banner in the UI and returns the model, or
    None when loading fails.
    """
    try:
        # The checkpoint stores the full pickled model object (not just a
        # state_dict), so weights_only=False is required for the custom
        # classes. NOTE(review): this is pickle-based deserialization —
        # only load checkpoints from a trusted source.
        loaded = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
        loaded.eval()
        st.success("Model loaded successfully!")
        return loaded
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None
# --- Main App Logic ---
# Path to the serialized model checkpoint, expected next to this script.
model_file_path = "unet3d_model.pth"
if not os.path.exists(model_file_path):
    st.warning("Model file 'unet3d_model.pth' not found. Please ensure it is in the same directory.")
    model = None
else:
    model = load_model(model_file_path)
# Sidebar uploaders: one NIfTI file per MRI modality (T1, T1ce, T2, FLAIR).
st.sidebar.header("Upload NIfTI Files")
t1_file = st.sidebar.file_uploader("Choose a T1 scan (.nii or .nii.gz)", type=["nii", "gz"], key="t1")
t1ce_file = st.sidebar.file_uploader("Choose a T1ce scan (.nii or .nii.gz)", type=["nii", "gz"], key="t1ce")
t2_file = st.sidebar.file_uploader("Choose a T2 scan (.nii or .nii.gz)", type=["nii", "gz"], key="t2")
flair_file = st.sidebar.file_uploader("Choose a FLAIR scan (.nii or .nii.gz)", type=["nii", "gz"], key="flair")
if t1_file and t1ce_file and t2_file and flair_file and model is not None:
    st.info("All files uploaded successfully. Processing...")
    # Defined up front so the finally-clause can safely test it even when
    # an exception fires before the combined file is written.
    temp_combined_file_path = None
    with st.spinner("Combining NIfTI files and making prediction..."):
        try:
            # Persist each upload to a named temporary file so the
            # combine step can open it by path.
            with tempfile.NamedTemporaryFile(suffix=f"_{t1_file.name}") as t1_temp, \
                 tempfile.NamedTemporaryFile(suffix=f"_{t1ce_file.name}") as t1ce_temp, \
                 tempfile.NamedTemporaryFile(suffix=f"_{t2_file.name}") as t2_temp, \
                 tempfile.NamedTemporaryFile(suffix=f"_{flair_file.name}") as flair_temp:
                t1_temp.write(t1_file.getvalue())
                t1ce_temp.write(t1ce_file.getvalue())
                t2_temp.write(t2_file.getvalue())
                flair_temp.write(flair_file.getvalue())
                # FIX: flush buffered writes so the reader below sees the
                # complete file contents on disk, not a truncated buffer.
                t1_temp.flush()
                t1ce_temp.flush()
                t2_temp.flush()
                flair_temp.flush()
                # Stack the four modalities into a single 4D NIfTI image.
                combined_nifti_img = combine_nifti_files(t1_temp.name, t1ce_temp.name, t2_temp.name, flair_temp.name)
                original_data = combined_nifti_img.get_fdata()
                # preprocess_nifti works from a file path, so the combined
                # image has to be written out first.
                temp_combined_file_path = "combined_4d.nii.gz"
                nib.save(combined_nifti_img, temp_combined_file_path)
                _, processed_tensor = preprocess_nifti(temp_combined_file_path)
                if original_data is not None and processed_tensor is not None:
                    st.success("Preprocessing complete!")
                    # --- Patch-based Model Prediction ---
                    st.info("Running patch-based model inference...")
                    try:
                        prediction_tensor = run_patch_inference(model, processed_tensor, patch_depth=32)
                        st.success("Prediction complete!")
                    except Exception as e:
                        st.error(f"Error during patch-based inference: {e}")
                        raise
                    # Post-process the prediction into a label mask,
                    # resized back to the original volume's shape.
                    predicted_mask = postprocess_mask(prediction_tensor, original_data.shape)
                    if predicted_mask is not None:
                        st.header("Results")
                        # Slider selects an axial slice; mask cast to int
                        # so the overlay renders as discrete labels.
                        max_slices = original_data.shape[2]
                        slice_index = st.slider("Select an axial slice to view", 0, max_slices - 1, max_slices // 2)
                        fig = visualize_prediction(original_data, predicted_mask.astype(int), slice_index=slice_index)
                        st.pyplot(fig)
                    else:
                        st.error("Could not post-process the model's prediction.")
        except Exception as e:
            st.error(f"An error occurred during processing: {e}")
            st.error("Please ensure the uploaded files are valid NIfTI files with the same dimensions.")
        finally:
            # FIX: guard against temp_combined_file_path still being None
            # (failure before nib.save), which would make os.path.exists
            # raise TypeError from inside the cleanup itself.
            if temp_combined_file_path and os.path.exists(temp_combined_file_path):
                os.remove(temp_combined_file_path)
# --- Footer ---
# Horizontal rule separates the results area from the credit line.
st.markdown("---")
st.markdown("Developed with PyTorch and Streamlit.")