import numpy as np
import torch
from scipy import ndimage

from utils.functions import normalize, reimburse_conform


def strip(voxel, model, device):
    """
    Perform slice-wise inference using the brain stripping model.

    The input 3D volume is processed slice by slice along its first axis.
    Each prediction sees a three-slice context window (previous, current and
    next slice) stacked as the three input channels of the model. The result
    holds the per-voxel sigmoid outputs of the model; turning them into a
    binary mask (thresholding) is left to the caller.

    Args:
        voxel (numpy.ndarray): Input volume of shape (N, H, W), typically one
            anatomical orientation (e.g., coronal or sagittal view). The
            pipeline in this module passes (224, 224, 224) volumes.
            NOTE(review): the model is assumed to accept (1, 3, H, W) input
            and to produce one H x W map per slice — confirm for other sizes.
        model (torch.nn.Module): The trained PyTorch brain stripping model.
            It is switched to evaluation mode as a side effect.
        device (torch.device): Device used for inference (CPU, CUDA, or MPS).

    Returns:
        torch.Tensor: A CPU tensor of shape (N, H, W) with values in [0, 1]
        (sigmoid probabilities of the brain region, not yet binarized).

    Raises:
        ValueError: If ``voxel`` is not three-dimensional.
    """
    if voxel.ndim != 3:
        raise ValueError(f"strip expects a 3D volume, got shape {voxel.shape}")
    # Derive sizes from the input instead of hard-coding 224, so that the
    # number of processed slices always matches the volume (the old fixed
    # loop silently truncated longer volumes and crashed on shorter ones).
    n_slices, height, width = voxel.shape

    model.eval()

    # Pad one slice on both ends (with the volume minimum) so the first and
    # last slices also get a full 3-slice context.
    padded = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())

    with torch.inference_mode():
        box = torch.zeros(n_slices, height, width)

        # Output slice i uses padded slices i, i+1, i+2, i.e. original slices
        # i-1, i, i+1 (with the padded boundary slices at the edges).
        for i in range(n_slices):
            context = np.stack([padded[i], padded[i + 1], padded[i + 2]])
            image = torch.tensor(context.reshape(1, 3, height, width)).to(device)
            # Model output is presumably (1, 1, H, W); the leading singleton
            # dimensions are dropped when assigning into the (H, W) slice.
            box[i] = torch.sigmoid(model(image)).cpu()

    return box


def stripping(output_dir, basename, voxel, odata, data, ssnet, shift, device):
    """
    Perform full 3D brain stripping using a deep learning model.

    The network is run slice-wise on three orientations of the normalized
    volume (coronal, sagittal and axial views). The three probability maps
    are mapped back to the native axis order, averaged, and thresholded at
    0.5 to form a binary brain mask. The mask is multiplied with the
    un-normalized input to produce the stripped volume. Separately, the mask
    is padded and un-shifted back to the conformed geometry and passed to
    ``reimburse_conform`` together with ``odata`` and ``data``.

    Args:
        output_dir (str): Directory where intermediate and final results will be saved.
        basename (str): Base name of the current case (used for file naming).
        voxel (numpy.ndarray): Input 3D voxel data (preprocessed MRI image),
            presumably the (224, 224, 224) crop — confirm against the caller.
        odata (nibabel.Nifti1Image): Original NIfTI image before preprocessing.
        data (nibabel.Nifti1Image): Preprocessed NIfTI image used for model input.
        ssnet (torch.nn.Module): Trained brain stripping network.
        shift (tuple[int, int, int]): The (x, y, z) offsets applied previously during cropping.
        device (torch.device): Device used for inference (CPU, CUDA, or MPS).

    Returns:
        tuple: ``(stripped, out_filename)`` where ``stripped`` (numpy.ndarray)
        is the un-normalized input with non-brain voxels set to zero (same
        shape as ``voxel``), and ``out_filename`` is the value returned by
        ``reimburse_conform(..., "stripped", ...)`` — presumably the path of
        the saved mask; confirm against utils.functions.
    """
    # Keep the un-normalized intensities; the final mask is applied to these.
    original = voxel.copy()

    # Normalize the voxel intensities for model input
    voxel = normalize(voxel, "stripping")

    # Present the volume to the slice-wise model in three orientations;
    # each view slices along a different axis of the native array.
    coronal = voxel.transpose(1, 2, 0)
    sagittal = voxel
    axial = voxel.transpose(2, 1, 0)

    # Run the model per view and permute each prediction back to the native
    # axis order: permute(2, 0, 1) inverts transpose(1, 2, 0), and the
    # (2, 1, 0) axis swap is its own inverse.
    out_c = strip(coronal, ssnet, device).permute(2, 0, 1)
    out_s = strip(sagittal, ssnet, device)
    out_a = strip(axial, ssnet, device).permute(2, 1, 0)

    # Fuse predictions by averaging across the three views and thresholding
    # at 0.5, giving a boolean mask.
    out_e = ((out_c + out_s + out_a) / 3) > 0.5
    out_e = out_e.cpu().numpy()

    # Zero out non-brain voxels in the original-intensity volume.
    stripped = original * out_e

    # Restore the mask to the conformed geometry: pad 16 voxels on every side
    # (224 + 2 * 16 = 256 for the default crop) and undo the shift applied
    # earlier. NOTE(review): assumes the earlier crop removed exactly 16
    # voxels per side — confirm against the cropping step.
    out_e = np.pad(out_e, [(16, 16), (16, 16), (16, 16)], "constant", constant_values=0)
    out_e = np.roll(out_e, (-shift[0], -shift[1], -shift[2]), axis=(0, 1, 2))

    out_filename = reimburse_conform(output_dir, basename, "stripped", odata, data, out_e)

    return stripped, out_filename