OpenMAP-T1 / src /utils /stripping.py
西牧慧
V3
c9d94d6
import numpy as np
import torch
from scipy import ndimage
from utils.functions import normalize, reimburse_conform
def strip(voxel, model, device):
    """
    Perform slice-wise inference using the brain stripping model.

    The input volume is processed slice by slice along its first axis. Each
    prediction sees a three-slice context window (previous, current and next
    slice) stacked as the channel dimension. The volume is padded with one
    extra slice at each end, filled with the volume minimum, so that the first
    and last slices also get a full window.

    The volume dimensions are read from ``voxel.shape`` instead of being
    hard-coded to 224, so volumes of any size the model accepts are supported;
    results for 224 x 224 x 224 inputs are unchanged.

    Args:
        voxel (numpy.ndarray): Input volume of shape (N, H, W); slices are
            taken along the first axis.
        model (torch.nn.Module): The trained PyTorch brain stripping model. It
            is called with a (1, 3, H, W) tensor and must return a tensor that
            can be assigned to one (H, W) slice, e.g. of shape (1, 1, H, W).
        device (torch.device): Device the input tensors are moved to for
            inference (CPU, CUDA, or MPS).

    Returns:
        torch.Tensor: A CPU tensor of shape (N, H, W) holding the sigmoid of
        the model output for every slice. No threshold is applied here.
    """
    model.eval()
    depth, height, width = voxel.shape

    # Pad one slice on both ends to ensure valid 3-slice context at the boundaries
    padded = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())

    with torch.inference_mode():
        box = torch.zeros(depth, height, width)
        # Index i addresses the padded volume, so slice i - 1 of the input is
        # predicted from padded slices i - 1, i and i + 1.
        for i in range(1, depth + 1):
            image = np.stack([padded[i - 1], padded[i], padded[i + 1]])
            image = torch.tensor(image.reshape(1, 3, height, width)).to(device)
            x_out = torch.sigmoid(model(image)).detach().cpu()
            # Leading singleton dimensions of x_out are dropped on assignment.
            box[i - 1] = x_out

    return box
def stripping(output_dir, basename, voxel, odata, data, ssnet, shift, device):
    """
    Skull-strip a 3D volume by fusing slice-wise predictions from three axes.

    The volume is normalised, then the stripping network is run three times
    via ``strip``: once on the volume as given and twice on transposed copies,
    so that each of the three axes serves as the slicing axis once. The three
    outputs are permuted back to the input axis order, averaged, and
    thresholded at 0.5 to give a boolean mask. The mask is multiplied into the
    un-normalised intensities, and a padded, shift-corrected copy of the mask
    is handed to ``reimburse_conform`` for writing.

    Args:
        output_dir (str): Output directory, forwarded to ``reimburse_conform``.
        basename (str): Base name of the current case, forwarded to
            ``reimburse_conform``.
        voxel (numpy.ndarray): Input 3D voxel data (preprocessed MRI image).
        odata (nibabel.Nifti1Image): Original NIfTI image before preprocessing;
            forwarded to ``reimburse_conform``.
        data (nibabel.Nifti1Image): Preprocessed NIfTI image; forwarded to
            ``reimburse_conform``.
        ssnet (torch.nn.Module): Trained brain stripping network.
        shift (tuple[int, int, int]): Per-axis offsets applied earlier during
            cropping; the mask is rolled by their negatives to undo them.
        device (torch.device): Device used for inference (CPU, CUDA, or MPS).

    Returns:
        tuple: ``(stripped, out_filename)`` where ``stripped`` is the input
        volume with everything outside the mask set to zero (numpy.ndarray)
        and ``out_filename`` is whatever ``reimburse_conform`` returns.
    """
    # The mask is applied to the untouched intensities, so copy them before
    # normalisation.
    raw_intensity = voxel.copy()
    model_input = normalize(voxel, "stripping")

    # One inference pass per orientation; each permute undoes the transpose
    # that was applied to the input of that pass.
    prob_coronal = strip(model_input.transpose(1, 2, 0), ssnet, device).permute(2, 0, 1)
    prob_sagittal = strip(model_input, ssnet, device)
    prob_axial = strip(model_input.transpose(2, 1, 0), ssnet, device).permute(2, 1, 0)

    # Average the three predictions and keep voxels above 0.5.
    mean_prob = (prob_coronal + prob_sagittal + prob_axial) / 3
    brain_mask = (mean_prob > 0.5).cpu().numpy()

    # Zero out everything outside the mask.
    stripped = raw_intensity * brain_mask

    # Grow the mask by 16 voxels on every side, then roll it back by the
    # cropping shift before it is written out.
    full_mask = np.pad(brain_mask, 16, mode="constant", constant_values=0)
    undo_shift = tuple(-shift[axis] for axis in range(3))
    full_mask = np.roll(full_mask, undo_shift, axis=(0, 1, 2))

    out_filename = reimburse_conform(output_dir, basename, "stripped", odata, data, full_mask)
    return stripped, out_filename