File size: 4,423 Bytes
fa7fb3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import streamlit as st
import io
import tempfile
from skimage.transform import resize

def _load_nifti_image(nifti_file):
    """Return a NIfTI image from a path, a file-like object, or a loaded image.

    Objects exposing ``get_fdata`` are treated as already-loaded images and
    returned unchanged. File-like objects (anything with ``read``) are read in
    full and decoded in memory, because ``nib.load`` only accepts filenames;
    gzip-compressed content (``.nii.gz``) is decompressed first. Anything else
    is passed to ``nib.load``.
    """
    if hasattr(nifti_file, "get_fdata"):
        return nifti_file
    if hasattr(nifti_file, "read"):
        import gzip

        # Rewind in case the stream was already consumed (e.g. inspected earlier).
        if hasattr(nifti_file, "seek"):
            nifti_file.seek(0)
        raw = nifti_file.read()
        if raw[:2] == b"\x1f\x8b":  # gzip magic number
            raw = gzip.decompress(raw)
        # NOTE(review): assumes single-file NIfTI-1 content; NIfTI-2 streams
        # would need nib.Nifti2Image.from_bytes instead.
        return nib.Nifti1Image.from_bytes(raw)
    return nib.load(nifti_file)


def preprocess_nifti(nifti_file):
    """
    Loads a NIfTI file, preprocesses it, and returns a PyTorch tensor.

    Each of the 4 channels is divided by its own maximum (channels whose
    maximum is not positive are left unscaled), then the axes are reordered
    from (X, Y, Z, C) to (C, Z, X, Y) and a leading batch axis is added.

    Args:
        nifti_file (str, os.PathLike, file-like object, or nibabel image):
            Path to the NIfTI file, a binary file-like object holding NIfTI
            (optionally gzip-compressed) content, or an already-loaded image.

    Returns:
        tuple: ``(original_data, tensor_data)`` where ``original_data`` is the
        unmodified float64 voxel array of shape (X, Y, Z, 4) and
        ``tensor_data`` is a float32 tensor of shape (1, 4, Z, X, Y).
        Returns ``(None, None)`` after reporting an error via ``st.error`` if
        the image is not 4D with 4 channels or any step fails.
    """
    try:
        nifti_img = _load_nifti_image(nifti_file)
        # get_fdata() caches its result and hands back the *same* array on
        # every call, so this array must never be modified in place: doing so
        # would silently turn the "original" data returned below into the
        # normalized data.
        original_data = nifti_img.get_fdata()

        if original_data.ndim != 4 or original_data.shape[-1] != 4:
            st.error("The uploaded NIfTI file must be a 4D image with 4 channels.")
            return None, None

        # Per-channel max normalization into a new array. Dividing by 1 keeps
        # channels whose max is <= 0 (or NaN) unchanged, matching the old loop.
        channel_max = original_data.max(axis=(0, 1, 2))
        scale = np.where(channel_max > 0, channel_max, 1.0)
        normalized = original_data / scale

        # (X, Y, Z, C) -> (C, Z, X, Y), then add the batch axis.
        tensor_data = torch.from_numpy(np.transpose(normalized, (3, 2, 0, 1))).float()
        tensor_data = torch.unsqueeze(tensor_data, 0)

        return original_data, tensor_data

    except Exception as e:
        st.error(f"Error during NIfTI preprocessing: {e}")
        return None, None

def postprocess_mask(prediction_tensor, original_shape):
    """
    Converts model output (tensor) into a visualizable mask (numpy array)
    and resizes it to the original image dimensions.

    Args:
        prediction_tensor (torch.Tensor): Model output with class scores on
            axis 1 and, presumably, a batch axis of size 1 first, i.e.
            (1, num_classes, Z, X, Y) mirroring the layout produced by
            ``preprocess_nifti`` -- TODO confirm against the model.
        original_shape (tuple): Shape of the original image; only its first
            three entries (the spatial size) are used as the resize target.

    Returns:
        numpy.ndarray or None: Label mask of shape ``original_shape[:3]``
        (nearest-neighbour resized, values kept as float64 class indices), or
        None after reporting an error via ``st.error``.
    """
    try:
        # Softmax is strictly increasing, so applying it before argmax would
        # not change the winning class; it only cost a full-volume pass.
        mask = torch.argmax(prediction_tensor, dim=1)
        mask = mask.detach().cpu().numpy()

        # Drop only the batch axis. np.squeeze would also remove any spatial
        # axis of length 1 and make the transpose below fail.
        if mask.ndim == 4 and mask.shape[0] == 1:
            mask = mask[0]

        # (Z, X, Y) -> (X, Y, Z) to match the original image's axis order.
        mask = np.transpose(mask, (1, 2, 0))

        # order=0 (nearest neighbour) keeps labels as exact class indices.
        resized_mask = resize(mask, original_shape[:3], order=0, preserve_range=True, anti_aliasing=False)

        return resized_mask
    except Exception as e:
        st.error(f"Error during mask post-processing: {e}")
        return None

def visualize_prediction(original_image, predicted_mask, slice_index=75, channel=3):
    """
    Creates a 2-panel visualization of the original image and the predicted mask.

    Args:
        original_image (numpy.ndarray): 4D volume indexed as [x, y, z, channel].
        predicted_mask (numpy.ndarray): 3D label volume indexed as [x, y, z];
            label 0 is treated as background and left transparent.
        slice_index (int): Index along the third axis to display. Values
            outside the volume are clamped to the nearest valid slice;
            negative values keep their numpy meaning.
        channel (int): Channel of ``original_image`` shown in both panels.
            Defaults to 3, the FLAIR channel in the order used by
            ``combine_nifti_files``.

    Returns:
        matplotlib.figure.Figure: The new figure (it is not closed here).
    """
    depth = min(original_image.shape[2], predicted_mask.shape[2])
    # Clamp instead of raising IndexError when the volume has fewer slices
    # than the requested (or default) index.
    slice_index = int(np.clip(slice_index, -depth, depth - 1))

    background = np.rot90(original_image[:, :, slice_index, channel])
    mask_slice = np.rot90(predicted_mask[:, :, slice_index])
    # Fix the colour range to the full label range so each label keeps the
    # same colour on every slice (otherwise 'jet' rescales to the labels that
    # happen to be present in this slice).
    max_label = max(float(np.max(predicted_mask)), 1.0)

    if channel == 3:
        title = 'Original Image (FLAIR)'
    else:
        title = f'Original Image (channel {channel})'

    fig, axes = plt.subplots(1, 2, figsize=(15, 7))
    axes[0].imshow(background, cmap='bone')
    axes[0].set_title(title, fontsize=16)
    axes[0].axis('off')

    axes[1].imshow(background, cmap='bone')
    # Overlay the predicted mask; background (label 0) is masked out.
    axes[1].imshow(
        np.ma.masked_where(mask_slice == 0, mask_slice),
        cmap='jet',
        alpha=0.5,
        vmin=0,
        vmax=max_label,
    )
    axes[1].set_title('Predicted Tumor Mask', fontsize=16)
    axes[1].axis('off')
    return fig

def combine_nifti_files(t1_file_path, t1ce_file_path, t2_file_path, flair_file_path):
    """
    Stack four 3D NIfTI volumes into one 4D NIfTI image.

    The volumes are stacked along a new last axis in the order
    T1, T1ce, T2, FLAIR, and the result reuses the T1 image's affine.

    Args:
        t1_file_path, t1ce_file_path, t2_file_path, flair_file_path (str):
            Paths to the per-modality NIfTI files.

    Returns:
        nib.Nifti1Image or None: The combined 4D image, or None if the inputs
        differ in shape or any step fails (the problem is reported through
        ``st.error``).
    """
    try:
        paths = (t1_file_path, t1ce_file_path, t2_file_path, flair_file_path)
        # Open every file first, then decode the voxel data, in input order.
        images = [nib.load(path) for path in paths]
        volumes = [image.get_fdata() for image in images]

        reference_shape = volumes[0].shape
        if any(volume.shape != reference_shape for volume in volumes[1:]):
            st.error("Error: Input NIfTI files do not have matching dimensions.")
            return None

        stacked = np.stack(volumes, axis=-1)
        return nib.Nifti1Image(stacked, images[0].affine)
    except Exception as err:
        st.error(f"Error combining NIfTI files: {err}")
        return None