Spaces:
Sleeping
Sleeping
File size: 4,132 Bytes
03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import numpy as np
import torch
from scipy import ndimage
from utils.functions import normalize, reimburse_conform
def strip(voxel, model, device):
    """
    Perform slice-wise inference using the brain stripping model.

    The input 3D volume is processed slice by slice along its first axis.
    Each prediction is made from a three-slice context window (previous,
    current and next slice) fed to the model as a 3-channel image.

    The volume dimensions are taken from ``voxel.shape`` rather than being
    fixed to 224, so every slice of the input is processed regardless of
    the volume size.

    Args:
        voxel (numpy.ndarray): Input voxel data of shape (N, H, W). The
            pipeline in this file passes 224-cubed volumes.
        model (torch.nn.Module): The trained PyTorch brain stripping model.
            It is switched to eval mode by this function. Its output for a
            (1, 3, H, W) input is assumed to be assignable to an (H, W)
            slice, e.g. shape (1, 1, H, W) -- TODO confirm against the model.
        device (torch.device): Device used for inference.

    Returns:
        torch.Tensor: A CPU float tensor of shape (N, H, W) holding the
        per-voxel sigmoid output of the model. No threshold is applied here;
        callers binarize the result themselves.

    Raises:
        ValueError: If ``voxel`` is not three-dimensional.
    """
    if voxel.ndim != 3:
        raise ValueError(f"voxel must be 3-dimensional, got shape {voxel.shape}")

    model.eval()

    depth, height, width = voxel.shape

    # Pad one slice on both ends (filled with the volume minimum) so the
    # first and last slices still have a valid 3-slice context.
    padded = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())

    with torch.inference_mode():
        box = torch.zeros(depth, height, width)

        # Index i runs over the padded volume; slice i there corresponds to
        # slice i - 1 of the original volume.
        for i in range(1, depth + 1):
            image = np.stack([padded[i - 1], padded[i], padded[i + 1]])
            image = torch.tensor(image.reshape(1, 3, height, width)).to(device)
            x_out = torch.sigmoid(model(image)).detach().cpu()
            box[i - 1] = x_out

    return box
def stripping(output_dir, basename, voxel, odata, data, ssnet, shift, device):
    """
    Run the brain stripping network on a 3D volume and save the resulting mask.

    The volume is normalized, presented to ``strip`` in three axis orders
    (labelled coronal, sagittal and axial), and the three predictions are
    brought back to the input axis order and averaged. Voxels whose mean
    prediction exceeds 0.5 form the binary mask. The mask is multiplied with
    the un-normalized input, then padded, shifted back and handed to
    ``reimburse_conform`` for writing.

    Args:
        output_dir (str): Directory forwarded to ``reimburse_conform``.
        basename (str): Case name forwarded to ``reimburse_conform``.
        voxel (numpy.ndarray): Input 3D voxel data (preprocessed MRI image).
        odata (nibabel.Nifti1Image): Original NIfTI image before preprocessing.
        data (nibabel.Nifti1Image): Preprocessed NIfTI image used for model input.
        ssnet (torch.nn.Module): Trained brain stripping network.
        shift (tuple[int, int, int]): Per-axis offsets applied earlier during
            cropping; they are undone here with the opposite sign.
        device (torch.device): Device used for inference.

    Returns:
        tuple: ``(stripped, out_filename)`` where ``stripped`` is the input
        volume multiplied by the binary mask (same shape as ``voxel``) and
        ``out_filename`` is whatever ``reimburse_conform`` returns.
    """
    # Keep the raw intensities; the mask is applied to these, not to the
    # normalized copy that the network sees.
    raw = voxel.copy()
    normed = normalize(voxel, "stripping")

    # One prediction per axis order; each permute undoes the matching transpose.
    pred_coronal = strip(normed.transpose(1, 2, 0), ssnet, device).permute(2, 0, 1)
    pred_sagittal = strip(normed, ssnet, device)
    pred_axial = strip(normed.transpose(2, 1, 0), ssnet, device).permute(2, 1, 0)

    # Average the three predictions and binarize at 0.5.
    mean_pred = (pred_coronal + pred_sagittal + pred_axial) / 3
    mask = (mean_pred > 0.5).cpu().numpy()

    stripped = raw * mask

    # Grow the mask by 16 voxels on every side, then roll it by the negated
    # shift so it lines up with the geometry reimburse_conform expects.
    margin = 16
    full_mask = np.pad(mask, [(margin, margin)] * 3, "constant", constant_values=0)
    undo_shift = tuple(-shift[axis] for axis in range(3))
    full_mask = np.roll(full_mask, undo_shift, axis=(0, 1, 2))

    out_filename = reimburse_conform(output_dir, basename, "stripped", odata, data, full_mask)
    return stripped, out_filename
|