Spaces:
Sleeping
Sleeping
# --- Patch-based Inference Helper ---
def run_patch_inference(model, tensor, patch_depth=32):
    """
    Run model inference on a 3D volume in slabs along the depth axis.

    The volume is split into consecutive, non-overlapping slabs of
    ``patch_depth`` slices. When the final slab would be shorter than
    ``patch_depth`` (and the volume is at least that deep), the slab is instead
    taken as the LAST ``patch_depth`` slices of the volume, and only its
    trailing, not-yet-covered slices are kept. This avoids feeding the network
    a tiny remainder slab (e.g. 1-15 slices), which a model that max-pools the
    depth axis several times cannot process.

    Args:
        model: The 3D segmentation model. If it exposes ``parameters()``,
            inputs are moved to the device of its first parameter; otherwise
            (or when it has no parameters) the CPU is used.
        tensor: Input tensor of shape [1, 4, D, H, W].
        patch_depth: Depth of each slab; must be a positive integer.

    Returns:
        The per-slab outputs concatenated along the depth axis (dim 2), on CPU.
        Assumes the model's output keeps the input depth of each slab.

    Raises:
        ValueError: If ``patch_depth`` is not positive or the volume has zero depth.
    """
    if patch_depth <= 0:
        raise ValueError(f"patch_depth must be positive, got {patch_depth}")
    depth = tensor.shape[2]
    if depth == 0:
        raise ValueError("Input tensor has zero depth; nothing to segment.")

    # next(..., None) also handles a module with a parameters() method but no
    # parameters, where a bare next() would raise StopIteration.
    first_param = next(model.parameters(), None) if hasattr(model, 'parameters') else None
    device = first_param.device if first_param is not None else torch.device('cpu')

    outputs = []
    with torch.no_grad():
        for start in range(0, depth, patch_depth):
            end = min(start + patch_depth, depth)
            # For every full slab this equals `start`; for a short trailing slab
            # it shifts the window back so the model still sees patch_depth slices.
            slab_start = max(0, end - patch_depth)
            slab_out = model(tensor[:, :, slab_start:end, :, :].to(device))
            # Keep only the slices this iteration is responsible for: [start, end).
            outputs.append(slab_out[:, :, start - slab_start:, :, :].cpu())
    # Concatenate along the depth axis
    return torch.cat(outputs, dim=2)
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import nibabel as nib | |
| import numpy as np | |
| import os | |
| import io | |
| import tempfile | |
| from utils import preprocess_nifti, postprocess_mask, visualize_prediction, combine_nifti_files | |
# --- Page Configuration ---
# Browser-tab title and layout; issued before any other st.* call in this script.
st.set_page_config(
    page_title="Brain Tumor Segmentation App",
    layout="wide",
    initial_sidebar_state="expanded",
)

# --- App Title and Description ---
st.title("Brain Tumor Segmentation")
st.write(
    "Upload the four 3D NIfTI brain scans (.nii or .nii.gz) for each modality "
    "to get a segmentation mask of the tumor."
)
st.markdown("---")
# --- Model Architecture ---
# A single block in the U-Net architecture.
class DoubleConv(nn.Module):
    """Two stacked (Conv3d k=3 -> GroupNorm -> ReLU) stages.

    Spatial size is preserved (kernel 3, padding 1); only the channel count
    changes from ``in_channels`` to ``out_channels``.

    Note: GroupNorm is built with ``out_channels // 2`` groups, which must
    divide ``out_channels`` evenly (always true for even channel counts).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Both stages are appended in order, so the Sequential indices (0..5) and
        # therefore the saved-weight keys "double_conv.{0,1,3,4}.*" are unchanged.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.GroupNorm(num_groups=out_channels // 2, num_channels=out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
# The downsampling part of the U-Net.
class Down(nn.Module):
    """Encoder step: 2x2x2 max-pool (floor-halving D, H, W), then a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(kernel_size=2)
        convs = DoubleConv(in_channels, out_channels)
        # Attribute name ``encoder`` and the (pool, convs) order are kept so
        # saved weights still resolve to the same submodules.
        self.encoder = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.encoder(x)
        return downsampled
# The upsampling part of the U-Net.
class Up(nn.Module):
    """Decoder step: upsample x1, pad it to x2's size, concatenate, DoubleConv.

    ``in_channels`` must equal x1's channel count plus x2's, because the two
    tensors are concatenated along the channel axis before ``conv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Trilinear (3D) interpolation that doubles D, H and W.
        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # F.pad takes (before, after) pairs starting from the LAST dimension:
        # W (dim 4), then H (dim 3), then D (dim 2). Odd gaps put the extra
        # voxel on the "after" side.
        padding = []
        for dim in (4, 3, 2):
            gap = x2.size(dim) - upsampled.size(dim)
            padding.extend((gap // 2, gap - gap // 2))
        upsampled = F.pad(upsampled, padding)
        # Skip tensor first, then the upsampled tensor (channel order matters
        # for the conv weights).
        return self.conv(torch.cat([x2, upsampled], dim=1))
# The final output convolutional layer.
class Out(nn.Module):
    """Pointwise (1x1x1) Conv3d mapping feature channels to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=(1, 1, 1))

    def forward(self, x):
        return self.conv(x)
# The complete 3D U-Net model.
class UNet3d(nn.Module):
    """3D U-Net with four pooling encoder steps and four upsampling decoder steps.

    Channel widths go 16 -> 32 -> 64 -> 128 -> 256 down the encoder. Each
    decoder ``Up`` receives the upsampled map concatenated with the matching
    encoder skip, so its in_channels is the sum of the two widths.

    Args:
        n_channels: Number of input channels (default 4).
        n_classes: Number of output logit channels (default 3; the original
            comment names them tumor core, edema, enhancing tumor).
    """

    def __init__(self, n_channels=4, n_classes=3):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        # Contracting path. Submodule names and creation order are unchanged,
        # so saved weights keep mapping onto the same attributes.
        self.conv = DoubleConv(n_channels, 16)
        self.enc1 = Down(16, 32)
        self.enc2 = Down(32, 64)
        self.enc3 = Down(64, 128)
        self.enc4 = Down(128, 256)
        # Expansive path: in_channels = upsampled width + skip width.
        self.dec1 = Up(256 + 128, 128)
        self.dec2 = Up(128 + 64, 64)
        self.dec3 = Up(64 + 32, 32)
        self.dec4 = Up(32 + 16, 16)
        self.out = Out(16, n_classes)

    def forward(self, x):
        # Encoder: keep every intermediate map for the skip connections.
        features = [self.conv(x)]
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            features.append(encoder(features[-1]))
        # Decoder: start from the deepest map and consume the skips from the
        # next-deepest back up to the first.
        y = features.pop()
        for decoder in (self.dec1, self.dec2, self.dec3, self.dec4):
            y = decoder(y, features.pop())
        return self.out(y)
# --- Model Loading ---
def load_model(model_path):
    """Load the trained segmentation model from a ``.pth`` file.

    Accepts either a fully pickled model object (what the training code saved)
    or a plain state dict. A state dict is loaded into a fresh ``UNet3d()``
    built with its defaults (4 input channels, 3 classes).

    Args:
        model_path: Path to the checkpoint file.

    Returns:
        The model in eval mode with tensors mapped to CPU. Returns ``None`` if
        loading failed; in that case the error is shown in the Streamlit UI.
    """
    try:
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects and can
        # execute code from the file. It is needed for a pickled custom model
        # class, and acceptable only because this path is the app's own bundled
        # checkpoint. Never pass a user-uploaded file here.
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
        if isinstance(checkpoint, dict):
            # A state dict has no .eval(): rebuild the architecture defined in
            # this file and load the weights into it.
            model = UNet3d()
            model.load_state_dict(checkpoint)
        else:
            model = checkpoint
        model.eval()
        st.success("Model loaded successfully!")
        return model
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None
# --- Main App Logic ---
model_file_path = "unet3d_model.pth"
if not os.path.exists(model_file_path):
    st.warning("Model file 'unet3d_model.pth' not found. Please ensure it is in the same directory.")
    model = None
else:
    model = load_model(model_file_path)

st.sidebar.header("Upload NIfTI Files")
t1_file = st.sidebar.file_uploader("Choose a T1 scan (.nii or .nii.gz)", type=["nii", "gz"], key="t1")
t1ce_file = st.sidebar.file_uploader("Choose a T1ce scan (.nii or .nii.gz)", type=["nii", "gz"], key="t1ce")
t2_file = st.sidebar.file_uploader("Choose a T2 scan (.nii or .nii.gz)", type=["nii", "gz"], key="t2")
flair_file = st.sidebar.file_uploader("Choose a FLAIR scan (.nii or .nii.gz)", type=["nii", "gz"], key="flair")


def _nifti_suffix(filename):
    """Return the extension to use for a temporary copy of an uploaded scan.

    Only the ``.nii`` / ``.nii.gz`` ending is kept, not the client-supplied
    name, so no user-controlled text ends up in a filesystem path. The ending
    is preserved (as the original suffix did), presumably so the NIfTI loader
    can recognise the format.
    """
    return ".nii.gz" if filename.lower().endswith(".gz") else ".nii"


def _save_upload(uploaded_file, directory, stem):
    """Write an uploaded file's bytes to ``directory/<stem><ext>`` and return the path.

    The file is fully written and closed before its path is returned, so a
    reader never sees data still sitting in an unflushed write buffer.
    """
    path = os.path.join(directory, stem + _nifti_suffix(uploaded_file.name))
    with open(path, "wb") as out_file:
        out_file.write(uploaded_file.getvalue())
    return path


if t1_file and t1ce_file and t2_file and flair_file and model is not None:
    st.info("All files uploaded successfully. Processing...")
    with st.spinner("Combining NIfTI files and making prediction..."):
        try:
            # A private directory per run:
            # - concurrent sessions can no longer overwrite a shared
            #   "combined_4d.nii.gz" in the working directory;
            # - every temp file is removed on exit, even when an error occurs,
            #   so no separate finally/cleanup step is needed.
            with tempfile.TemporaryDirectory() as work_dir:
                t1_path = _save_upload(t1_file, work_dir, "t1")
                t1ce_path = _save_upload(t1ce_file, work_dir, "t1ce")
                t2_path = _save_upload(t2_file, work_dir, "t2")
                flair_path = _save_upload(flair_file, work_dir, "flair")
                # Pass the temporary file paths to the combine function
                combined_nifti_img = combine_nifti_files(t1_path, t1ce_path, t2_path, flair_path)
                original_data = combined_nifti_img.get_fdata()
                # preprocess_nifti takes a file path, so save the combined 4D image first.
                temp_combined_file_path = os.path.join(work_dir, "combined_4d.nii.gz")
                nib.save(combined_nifti_img, temp_combined_file_path)
                _, processed_tensor = preprocess_nifti(temp_combined_file_path)

            if original_data is None or processed_tensor is None:
                # Previously this case silently rendered nothing.
                st.error("Could not preprocess the uploaded scans.")
            else:
                st.success("Preprocessing complete!")
                # --- Patch-based Model Prediction ---
                st.info("Running patch-based model inference...")
                try:
                    prediction_tensor = run_patch_inference(model, processed_tensor, patch_depth=32)
                    st.success("Prediction complete!")
                except Exception as e:
                    st.error(f"Error during patch-based inference: {e}")
                    raise
                # Post-process the prediction to get a mask, resizing back to original size
                predicted_mask = postprocess_mask(prediction_tensor, original_data.shape)
                if predicted_mask is not None:
                    st.header("Results")
                    max_slices = original_data.shape[2]
                    # NOTE(review): Streamlit re-executes the whole script when this
                    # slider moves, repeating the full pipeline above. Caching the
                    # result (e.g. in st.session_state) would avoid that.
                    slice_index = st.slider("Select an axial slice to view", 0, max_slices - 1, max_slices // 2)
                    fig = visualize_prediction(original_data, predicted_mask.astype(int), slice_index=slice_index)
                    st.pyplot(fig)
                else:
                    st.error("Could not post-process the model's prediction.")
        except Exception as e:
            st.error(f"An error occurred during processing: {e}")
            st.error("Please ensure the uploaded files are valid NIfTI files with the same dimensions.")

# --- Footer ---
st.markdown("---")
st.markdown("Developed with PyTorch and Streamlit.")