Spaces:
Sleeping
Sleeping
西牧慧 commited on
Commit ·
c9d94d6
1
Parent(s): 6275589
V3
Browse files- src/parcellation.py +5 -7
- src/utils/cropping.py +41 -14
- src/utils/functions.py +6 -2
- src/utils/hemisphere.py +65 -53
- src/utils/load_model.py +47 -52
- src/utils/parcellation.py +92 -61
- src/utils/stripping.py +58 -62
src/parcellation.py
CHANGED
|
@@ -94,10 +94,8 @@ def run_inference(input_file, only_face_cropping, only_skull_stripping):
|
|
| 94 |
|
| 95 |
# Load the pre-trained models from the fixed "model/" folder
|
| 96 |
print("Loading models...")
|
| 97 |
-
cnet, ssnet,
|
| 98 |
print("Models loaded successfully.")
|
| 99 |
-
# cnet, ssnet, pnet_a, hnet_c, hnet_a = load_model("model/", device=device)
|
| 100 |
-
|
| 101 |
# --- Processing Flow (based on the original parcellation.py) ---
|
| 102 |
# 1. Load the input image, convert to canonical orientation, and remove extra dimensions
|
| 103 |
print("Loading and preprocessing the input image...")
|
|
@@ -115,7 +113,7 @@ def run_inference(input_file, only_face_cropping, only_skull_stripping):
|
|
| 115 |
|
| 116 |
# 3. Cropping
|
| 117 |
print("Cropping the input image...")
|
| 118 |
-
cropped, out_filename = cropping(opt.o, basename, odata, data, cnet, device)
|
| 119 |
print("Cropping completed.")
|
| 120 |
if only_face_cropping:
|
| 121 |
pass
|
|
@@ -123,18 +121,18 @@ def run_inference(input_file, only_face_cropping, only_skull_stripping):
|
|
| 123 |
else:
|
| 124 |
# 4. Skull stripping
|
| 125 |
print("Performing skull stripping...")
|
| 126 |
-
stripped,
|
| 127 |
print("Skull stripping completed.")
|
| 128 |
if only_skull_stripping:
|
| 129 |
pass
|
| 130 |
else:
|
| 131 |
# 5. Parcellation
|
| 132 |
print("Starting parcellation...")
|
| 133 |
-
parcellated = parcellation(stripped,
|
| 134 |
print("Parcellation completed.")
|
| 135 |
# 6. Separate into hemispheres
|
| 136 |
print("Separating hemispheres...")
|
| 137 |
-
separated = hemisphere(stripped,
|
| 138 |
print("Hemispheres separated.")
|
| 139 |
# 7. Postprocessing
|
| 140 |
print("Postprocessing the parcellated data...")
|
|
|
|
| 94 |
|
| 95 |
# Load the pre-trained models from the fixed "model/" folder
|
| 96 |
print("Loading models...")
|
| 97 |
+
cnet, ssnet, pnet, hnet = load_model("model/", device=device)
|
| 98 |
print("Models loaded successfully.")
|
|
|
|
|
|
|
| 99 |
# --- Processing Flow (based on the original parcellation.py) ---
|
| 100 |
# 1. Load the input image, convert to canonical orientation, and remove extra dimensions
|
| 101 |
print("Loading and preprocessing the input image...")
|
|
|
|
| 113 |
|
| 114 |
# 3. Cropping
|
| 115 |
print("Cropping the input image...")
|
| 116 |
+
cropped, shift, out_filename = cropping(opt.o, basename, odata, data, cnet, device)
|
| 117 |
print("Cropping completed.")
|
| 118 |
if only_face_cropping:
|
| 119 |
pass
|
|
|
|
| 121 |
else:
|
| 122 |
# 4. Skull stripping
|
| 123 |
print("Performing skull stripping...")
|
| 124 |
+
stripped, out_filename = stripping(opt.o, basename, cropped, odata, data, ssnet, shift, device)
|
| 125 |
print("Skull stripping completed.")
|
| 126 |
if only_skull_stripping:
|
| 127 |
pass
|
| 128 |
else:
|
| 129 |
# 5. Parcellation
|
| 130 |
print("Starting parcellation...")
|
| 131 |
+
parcellated = parcellation(stripped, pnet, device)
|
| 132 |
print("Parcellation completed.")
|
| 133 |
# 6. Separate into hemispheres
|
| 134 |
print("Separating hemispheres...")
|
| 135 |
+
separated = hemisphere(stripped, hnet, device)
|
| 136 |
print("Hemispheres separated.")
|
| 137 |
# 7. Postprocessing
|
| 138 |
print("Postprocessing the parcellated data...")
|
src/utils/cropping.py
CHANGED
|
@@ -1,31 +1,44 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
from scipy.ndimage import binary_closing
|
|
|
|
| 4 |
|
| 5 |
from utils.functions import normalize, reimburse_conform
|
| 6 |
|
| 7 |
|
| 8 |
def crop(voxel, model, device):
|
| 9 |
"""
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
Args:
|
| 13 |
-
voxel (numpy.ndarray):
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
| 16 |
|
| 17 |
Returns:
|
| 18 |
-
torch.Tensor: The
|
| 19 |
"""
|
|
|
|
|
|
|
| 20 |
model.eval()
|
|
|
|
| 21 |
with torch.inference_mode():
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
def closing(voxel):
|
|
@@ -59,7 +72,7 @@ def cropping(output_dir, basename, odata, data, cnet, device):
|
|
| 59 |
numpy.ndarray: The cropped medical imaging data.
|
| 60 |
"""
|
| 61 |
voxel = data.get_fdata().astype("float32")
|
| 62 |
-
voxel = normalize(voxel)
|
| 63 |
|
| 64 |
coronal = voxel.transpose(1, 2, 0)
|
| 65 |
sagittal = voxel
|
|
@@ -72,4 +85,18 @@ def cropping(output_dir, basename, odata, data, cnet, device):
|
|
| 72 |
|
| 73 |
out_filename = reimburse_conform(output_dir, basename, "cropped", odata, data, out_e)
|
| 74 |
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
from scipy.ndimage import binary_closing
|
| 4 |
+
from scipy import ndimage
|
| 5 |
|
| 6 |
from utils.functions import normalize, reimburse_conform
|
| 7 |
|
| 8 |
|
| 9 |
def crop(voxel, model, device):
|
| 10 |
"""
|
| 11 |
+
Apply a neural network-based cropping operation on 3D voxel data.
|
| 12 |
+
|
| 13 |
+
This function slides a 3-slice window across the input volume along the first axis
|
| 14 |
+
and predicts a binary mask for each slice using the given model. The outputs are then
|
| 15 |
+
aggregated into a full 3D prediction volume.
|
| 16 |
|
| 17 |
Args:
|
| 18 |
+
voxel (numpy.ndarray): Input 3D array of shape (N, 256, 256). The first dimension
|
| 19 |
+
corresponds to the slice index (typically coronal or sagittal).
|
| 20 |
+
model (torch.nn.Module): The trained PyTorch model that predicts binary masks
|
| 21 |
+
for each input slice triplet.
|
| 22 |
+
device (torch.device): The device (CPU, CUDA, or MPS) on which inference will run.
|
| 23 |
|
| 24 |
Returns:
|
| 25 |
+
torch.Tensor: The predicted 3D binary mask of shape (256, 256, 256).
|
| 26 |
"""
|
| 27 |
+
# Pad the input volume by one slice at each end to allow 3-slice context
|
| 28 |
+
voxel = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())
|
| 29 |
model.eval()
|
| 30 |
+
|
| 31 |
with torch.inference_mode():
|
| 32 |
+
box = torch.zeros(256, 256, 256)
|
| 33 |
+
|
| 34 |
+
# Iterate through each target slice and predict using a 3-slice input context
|
| 35 |
+
for i in range(1, 257):
|
| 36 |
+
image = np.stack([voxel[i - 1], voxel[i], voxel[i + 1]])
|
| 37 |
+
image = torch.tensor(image.reshape(1, 3, 256, 256)).to(device)
|
| 38 |
+
x_out = torch.sigmoid(model(image)).detach().cpu()
|
| 39 |
+
box[i - 1] = x_out
|
| 40 |
+
|
| 41 |
+
return box.reshape(256, 256, 256)
|
| 42 |
|
| 43 |
|
| 44 |
def closing(voxel):
|
|
|
|
| 72 |
numpy.ndarray: The cropped medical imaging data.
|
| 73 |
"""
|
| 74 |
voxel = data.get_fdata().astype("float32")
|
| 75 |
+
voxel = normalize(voxel, "cropping")
|
| 76 |
|
| 77 |
coronal = voxel.transpose(1, 2, 0)
|
| 78 |
sagittal = voxel
|
|
|
|
| 85 |
|
| 86 |
out_filename = reimburse_conform(output_dir, basename, "cropped", odata, data, out_e)
|
| 87 |
|
| 88 |
+
# Compute center of mass for the masked brain
|
| 89 |
+
x, y, z = map(int, ndimage.center_of_mass(out_e))
|
| 90 |
+
|
| 91 |
+
# Compute shifts required to center the brain
|
| 92 |
+
xd = 128 - x
|
| 93 |
+
yd = 120 - y
|
| 94 |
+
zd = 128 - z
|
| 95 |
+
|
| 96 |
+
# Translate (roll) the image to center the brain region
|
| 97 |
+
cropped = np.roll(cropped, (xd, yd, zd), axis=(0, 1, 2))
|
| 98 |
+
|
| 99 |
+
# Crop out boundary padding to reduce size and focus on the centered brain
|
| 100 |
+
cropped = cropped[16:-16, 16:-16, 16:-16]
|
| 101 |
+
|
| 102 |
+
return cropped, (xd, yd, zd), out_filename
|
src/utils/functions.py
CHANGED
|
@@ -5,9 +5,13 @@ import numpy as np
|
|
| 5 |
from nibabel import processing
|
| 6 |
|
| 7 |
|
| 8 |
-
def normalize(voxel):
|
| 9 |
nonzero = voxel[voxel > 0]
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
voxel = (voxel - np.min(voxel)) / (np.max(voxel) - np.min(voxel))
|
| 12 |
voxel = (voxel * 2) - 1
|
| 13 |
return voxel.astype("float32")
|
|
|
|
| 5 |
from nibabel import processing
|
| 6 |
|
| 7 |
|
| 8 |
+
def normalize(voxel, mode):
|
| 9 |
nonzero = voxel[voxel > 0]
|
| 10 |
+
if mode in ["cropping", "stripping"]:
|
| 11 |
+
clip = 2
|
| 12 |
+
elif mode in ["parcellation", "hemisphere"]:
|
| 13 |
+
clip = 3
|
| 14 |
+
voxel = np.clip(voxel, 0, np.mean(nonzero) + np.std(nonzero) * clip)
|
| 15 |
voxel = (voxel - np.min(voxel)) / (np.max(voxel) - np.min(voxel))
|
| 16 |
voxel = (voxel * 2) - 1
|
| 17 |
return voxel.astype("float32")
|
src/utils/hemisphere.py
CHANGED
|
@@ -1,92 +1,104 @@
|
|
| 1 |
import torch
|
| 2 |
from scipy.ndimage import binary_dilation
|
| 3 |
-
|
| 4 |
from utils.functions import normalize
|
| 5 |
|
| 6 |
|
| 7 |
-
def separate(voxel, model, device
|
| 8 |
"""
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
Args:
|
| 12 |
-
voxel (
|
| 13 |
-
model (torch.nn.Module):
|
| 14 |
-
device (torch.device):
|
| 15 |
-
mode (str): The mode of separation, either 'c' for coronal or 'a' for axial.
|
| 16 |
|
| 17 |
Returns:
|
| 18 |
-
torch.Tensor:
|
|
|
|
| 19 |
"""
|
| 20 |
-
if mode == "c":
|
| 21 |
-
# Set the stack dimensions for coronal mode
|
| 22 |
-
stack = (224, 192, 192)
|
| 23 |
-
elif mode == "a":
|
| 24 |
-
# Set the stack dimensions for axial mode
|
| 25 |
-
stack = (192, 224, 192)
|
| 26 |
-
|
| 27 |
-
# Set the model to evaluation mode
|
| 28 |
model.eval()
|
| 29 |
|
| 30 |
-
#
|
|
|
|
|
|
|
| 31 |
with torch.inference_mode():
|
| 32 |
-
#
|
| 33 |
-
|
| 34 |
|
| 35 |
-
# Iterate
|
| 36 |
-
for i
|
| 37 |
-
|
| 38 |
-
image = torch.tensor(
|
| 39 |
-
# Move the tensor to the specified device
|
| 40 |
-
image = image.to(device)
|
| 41 |
-
# Perform a forward pass through the model and apply softmax
|
| 42 |
-
x_out = torch.softmax(model(image), 1).detach()
|
| 43 |
-
# Store the output in the corresponding slice of the output tensor
|
| 44 |
-
output[i] = x_out
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
|
|
|
| 48 |
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
|
|
|
| 51 |
"""
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
Args:
|
| 55 |
-
voxel (
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
device (torch.device): The device to run the neural networks on (e.g., 'cpu' or 'cuda').
|
| 59 |
|
| 60 |
Returns:
|
| 61 |
-
numpy.ndarray:
|
|
|
|
|
|
|
|
|
|
| 62 |
"""
|
| 63 |
-
# Normalize
|
| 64 |
-
voxel = normalize(voxel)
|
| 65 |
|
| 66 |
-
#
|
| 67 |
coronal = voxel.transpose(1, 2, 0)
|
| 68 |
transverse = voxel.transpose(2, 1, 0)
|
| 69 |
|
| 70 |
-
#
|
| 71 |
-
out_c = separate(coronal,
|
| 72 |
-
out_a = separate(transverse,
|
| 73 |
|
| 74 |
-
#
|
| 75 |
out_e = out_c + out_a
|
| 76 |
|
| 77 |
-
#
|
| 78 |
-
out_e = torch.argmax(out_e, 0).cpu().numpy()
|
| 79 |
|
| 80 |
-
#
|
| 81 |
torch.cuda.empty_cache()
|
| 82 |
|
| 83 |
-
#
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
dilated_mask_1[out_e == 2] = 2
|
| 86 |
|
| 87 |
-
#
|
| 88 |
-
dilated_mask_2 = binary_dilation(dilated_mask_1 == 2, iterations=
|
|
|
|
| 89 |
dilated_mask_2[dilated_mask_1 == 1] = 1
|
| 90 |
|
| 91 |
-
# Return the final dilated mask
|
| 92 |
return dilated_mask_2
|
|
|
|
| 1 |
import torch
|
| 2 |
from scipy.ndimage import binary_dilation
|
| 3 |
+
import numpy as np
|
| 4 |
from utils.functions import normalize
|
| 5 |
|
| 6 |
|
| 7 |
+
def separate(voxel, model, device):
|
| 8 |
"""
|
| 9 |
+
Perform slice-wise inference using a hemisphere separation model.
|
| 10 |
+
|
| 11 |
+
This function runs a 2.5D neural network across slices of a 3D input volume.
|
| 12 |
+
Each slice is processed in the context of its immediate neighbors (previous
|
| 13 |
+
and next slices) to improve spatial coherence. The model outputs a
|
| 14 |
+
three-class probability map distinguishing background, left hemisphere,
|
| 15 |
+
and right hemisphere regions.
|
| 16 |
|
| 17 |
Args:
|
| 18 |
+
voxel (numpy.ndarray): Input voxel data of shape (N, 224, 224).
|
| 19 |
+
model (torch.nn.Module): Trained hemisphere segmentation model (U-Net architecture).
|
| 20 |
+
device (torch.device): Computational device (CPU, CUDA, or MPS).
|
|
|
|
| 21 |
|
| 22 |
Returns:
|
| 23 |
+
torch.Tensor: A tensor of shape (224, 3, 224, 224) containing softmax
|
| 24 |
+
probabilities for each class at every voxel.
|
| 25 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
model.eval()
|
| 27 |
|
| 28 |
+
# Pad the volume by one slice on both ends to provide full 3-slice context
|
| 29 |
+
voxel = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())
|
| 30 |
+
|
| 31 |
with torch.inference_mode():
|
| 32 |
+
# Output tensor for storing model predictions (class probabilities)
|
| 33 |
+
box = torch.zeros(224, 3, 224, 224)
|
| 34 |
|
| 35 |
+
# Iterate slice-by-slice along the first axis
|
| 36 |
+
for i in range(1, 225):
|
| 37 |
+
image = np.stack([voxel[i - 1], voxel[i], voxel[i + 1]])
|
| 38 |
+
image = torch.tensor(image.reshape(1, 3, 224, 224)).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
+
# Model inference with softmax normalization across classes
|
| 41 |
+
x_out = torch.softmax(model(image), dim=1).detach().cpu()
|
| 42 |
+
box[i - 1] = x_out
|
| 43 |
|
| 44 |
+
# Return complete 3D probability map
|
| 45 |
+
return box.reshape(224, 3, 224, 224)
|
| 46 |
|
| 47 |
+
|
| 48 |
+
def hemisphere(voxel, hnet, device):
|
| 49 |
"""
|
| 50 |
+
Perform hemisphere separation on a brain MRI volume using a deep learning model.
|
| 51 |
+
|
| 52 |
+
The function predicts left and right hemisphere regions from a normalized
|
| 53 |
+
3D MRI volume using multi-view inference (coronal and transverse planes).
|
| 54 |
+
Predictions from both orientations are fused to improve robustness. The final
|
| 55 |
+
label map is post-processed using binary dilation to smooth and expand hemisphere
|
| 56 |
+
boundaries, ensuring anatomical continuity.
|
| 57 |
|
| 58 |
Args:
|
| 59 |
+
voxel (numpy.ndarray): Input 3D brain volume to be separated into hemispheres.
|
| 60 |
+
hnet (torch.nn.Module): Trained hemisphere segmentation model.
|
| 61 |
+
device (torch.device): Target device for computation (e.g., 'cuda', 'cpu').
|
|
|
|
| 62 |
|
| 63 |
Returns:
|
| 64 |
+
numpy.ndarray: A 3D integer array representing the hemisphere mask:
|
| 65 |
+
- 0: Background
|
| 66 |
+
- 1: Left hemisphere
|
| 67 |
+
- 2: Right hemisphere
|
| 68 |
"""
|
| 69 |
+
# Normalize voxel intensities for inference
|
| 70 |
+
voxel = normalize(voxel, "hemisphere")
|
| 71 |
|
| 72 |
+
# Prepare different anatomical orientations for inference
|
| 73 |
coronal = voxel.transpose(1, 2, 0)
|
| 74 |
transverse = voxel.transpose(2, 1, 0)
|
| 75 |
|
| 76 |
+
# Perform inference for both coronal and transverse orientations
|
| 77 |
+
out_c = separate(coronal, hnet, device).permute(1, 3, 0, 2)
|
| 78 |
+
out_a = separate(transverse, hnet, device).permute(1, 3, 2, 0)
|
| 79 |
|
| 80 |
+
# Fuse both outputs by summing class probabilities
|
| 81 |
out_e = out_c + out_a
|
| 82 |
|
| 83 |
+
# Determine final class labels (0, 1, or 2) by selecting the most probable class
|
| 84 |
+
out_e = torch.argmax(out_e, dim=0).cpu().numpy()
|
| 85 |
|
| 86 |
+
# Release any residual GPU memory
|
| 87 |
torch.cuda.empty_cache()
|
| 88 |
|
| 89 |
+
# --------------------------
|
| 90 |
+
# Post-processing step: binary dilation
|
| 91 |
+
# --------------------------
|
| 92 |
+
|
| 93 |
+
# First, dilate the left hemisphere (class 1)
|
| 94 |
+
dilated_mask_1 = binary_dilation(out_e == 1, iterations=1).astype("int16")
|
| 95 |
+
# Preserve right hemisphere voxels from the original prediction
|
| 96 |
dilated_mask_1[out_e == 2] = 2
|
| 97 |
|
| 98 |
+
# Then, dilate the right hemisphere (class 2) symmetrically
|
| 99 |
+
dilated_mask_2 = binary_dilation(dilated_mask_1 == 2, iterations=1).astype("int16") * 2
|
| 100 |
+
# Restore left hemisphere voxels to prevent overwriting
|
| 101 |
dilated_mask_2[dilated_mask_1 == 1] = 1
|
| 102 |
|
| 103 |
+
# Return the final dilated and fused hemisphere mask
|
| 104 |
return dilated_mask_2
|
src/utils/load_model.py
CHANGED
|
@@ -8,70 +8,65 @@ from utils.network import UNet
|
|
| 8 |
|
| 9 |
def load_model(model_dir, device):
|
| 10 |
"""
|
| 11 |
-
|
| 12 |
-
The models loaded are:
|
| 13 |
-
1. CNet: A U-Net model for some specific task.
|
| 14 |
-
2. SSNet: Another U-Net model for a different task.
|
| 15 |
-
3. PNet coronal: A U-Net model for coronal plane predictions.
|
| 16 |
-
4. PNet sagittal: A U-Net model for sagittal plane predictions.
|
| 17 |
-
5. PNet axial: A U-Net model for axial plane predictions.
|
| 18 |
-
6. HNet coronal: A U-Net model for coronal plane predictions with different input/output channels.
|
| 19 |
-
7. HNet axial: A U-Net model for axial plane predictions with different input/output channels.
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
Returns:
|
| 26 |
-
|
|
|
|
|
|
|
| 27 |
"""
|
| 28 |
-
# Unzip the Model.zip file
|
| 29 |
model_zip_path = os.path.join(model_dir, "model.zip")
|
| 30 |
with zipfile.ZipFile(model_zip_path, "r") as zip_ref:
|
| 31 |
zip_ref.extractall(model_dir)
|
| 32 |
-
|
| 33 |
-
# Load CNet
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
cnet.to(device)
|
| 37 |
cnet.eval()
|
| 38 |
|
| 39 |
-
#
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
ssnet.to(device)
|
| 43 |
ssnet.eval()
|
| 44 |
|
| 45 |
-
#
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
pnet_s.load_state_dict(torch.load(os.path.join(model_dir, "PNet", "sagittal.pth"), weights_only=True))
|
| 54 |
-
pnet_s.to(device)
|
| 55 |
-
pnet_s.eval()
|
| 56 |
-
|
| 57 |
-
# Load PNet axial model
|
| 58 |
-
pnet_a = UNet(3, 142)
|
| 59 |
-
pnet_a.load_state_dict(torch.load(os.path.join(model_dir, "PNet", "axial.pth"), weights_only=True))
|
| 60 |
-
pnet_a.to(device)
|
| 61 |
-
pnet_a.eval()
|
| 62 |
-
|
| 63 |
-
# Load HNet coronal model
|
| 64 |
-
hnet_c = UNet(1, 3)
|
| 65 |
-
hnet_c.load_state_dict(torch.load(os.path.join(model_dir, "HNet", "coronal.pth"), weights_only=True))
|
| 66 |
-
hnet_c.to(device)
|
| 67 |
-
hnet_c.eval()
|
| 68 |
|
| 69 |
-
#
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
-
# Return all loaded models
|
| 76 |
-
return cnet, ssnet,
|
| 77 |
-
# return cnet, ssnet, pnet_a, hnet_c, hnet_a
|
|
|
|
| 8 |
|
| 9 |
def load_model(model_dir, device):
|
| 10 |
"""
|
| 11 |
+
Load and initialize the pretrained neural network models required for the OpenMAP-T1 pipeline.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
This function loads four U-Net–based models from the specified pretrained model directory.
|
| 14 |
+
Each model is moved to the target device (CPU, CUDA, or MPS) and set to evaluation mode.
|
| 15 |
+
|
| 16 |
+
Models loaded:
|
| 17 |
+
1. **CNet (Cropping Network)** — Performs face cropping and brain localization.
|
| 18 |
+
2. **SSNet (Skull Stripping Network)** — Removes non-brain tissues from MRI scans.
|
| 19 |
+
3. **PNet (Parcellation Network)** — Predicts fine-grained anatomical labels across 142 regions.
|
| 20 |
+
4. **HNet (Hemisphere Network)** — Segments the brain into hemispheric masks (left/right/other).
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
opt (argparse.Namespace): Parsed command-line arguments containing the pretrained model directory path (`opt.m`).
|
| 24 |
+
device (torch.device): Target device on which to load models (e.g., `torch.device('cuda')`).
|
| 25 |
|
| 26 |
Returns:
|
| 27 |
+
tuple:
|
| 28 |
+
A tuple containing four initialized and evaluation-ready models:
|
| 29 |
+
(cnet, ssnet, pnet, hnet).
|
| 30 |
"""
|
|
|
|
| 31 |
model_zip_path = os.path.join(model_dir, "model.zip")
|
| 32 |
with zipfile.ZipFile(model_zip_path, "r") as zip_ref:
|
| 33 |
zip_ref.extractall(model_dir)
|
| 34 |
+
# --------------------------
|
| 35 |
+
# Load CNet (Cropping Network)
|
| 36 |
+
# --------------------------
|
| 37 |
+
# Input: 3-channel (neighboring slices), Output: 1-channel binary mask
|
| 38 |
+
cnet = UNet(3, 1)
|
| 39 |
+
print(os.path.join(model_dir, "model", "CNet", "CNet.pth"))
|
| 40 |
+
cnet.load_state_dict(torch.load(os.path.join(model_dir, "model", "CNet", "CNet.pth"), weights_only=True))
|
| 41 |
cnet.to(device)
|
| 42 |
cnet.eval()
|
| 43 |
|
| 44 |
+
# ------------------------------
|
| 45 |
+
# Load SSNet (Skull Stripping Network)
|
| 46 |
+
# ------------------------------
|
| 47 |
+
# Input: 3-channel (neighboring slices), Output: 1-channel brain mask
|
| 48 |
+
ssnet = UNet(3, 1)
|
| 49 |
+
ssnet.load_state_dict(torch.load(os.path.join(model_dir, "model", "SSNet", "SSNet.pth"), weights_only=True))
|
| 50 |
ssnet.to(device)
|
| 51 |
ssnet.eval()
|
| 52 |
|
| 53 |
+
# -----------------------------
|
| 54 |
+
# Load PNet (Parcellation Network)
|
| 55 |
+
# -----------------------------
|
| 56 |
+
# Input: 4 channels (multi-modal or augmented context), Output: 142 anatomical regions
|
| 57 |
+
pnet = UNet(4, 142)
|
| 58 |
+
pnet.load_state_dict(torch.load(os.path.join(model_dir, "model", "PNet", "PNet.pth"), weights_only=True))
|
| 59 |
+
pnet.to(device)
|
| 60 |
+
pnet.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
# -----------------------------
|
| 63 |
+
# Load HNet (Hemisphere Network)
|
| 64 |
+
# -----------------------------
|
| 65 |
+
# Input: 3 channels, Output: 3-class hemisphere mask (left, right, background)
|
| 66 |
+
hnet = UNet(3, 3)
|
| 67 |
+
hnet.load_state_dict(torch.load(os.path.join(model_dir, "model", "HNet", "HNet.pth"), weights_only=True))
|
| 68 |
+
hnet.to(device)
|
| 69 |
+
hnet.eval()
|
| 70 |
|
| 71 |
+
# Return all loaded, device-initialized, and evaluation-ready models
|
| 72 |
+
return cnet, ssnet, pnet, hnet
|
|
|
src/utils/parcellation.py
CHANGED
|
@@ -4,102 +4,133 @@ import torch
|
|
| 4 |
from utils.functions import normalize
|
| 5 |
|
| 6 |
|
| 7 |
-
def parcellate(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
Args:
|
| 12 |
-
voxel (numpy.ndarray):
|
| 13 |
-
model (torch.nn.Module): The
|
| 14 |
-
device (torch.device):
|
| 15 |
-
mode (str): The
|
|
|
|
| 16 |
|
| 17 |
Returns:
|
| 18 |
-
torch.Tensor:
|
|
|
|
| 19 |
"""
|
| 20 |
-
if mode == "c":
|
| 21 |
-
stack = (224, 192, 192)
|
| 22 |
-
elif mode == "s":
|
| 23 |
-
stack = (192, 224, 192)
|
| 24 |
-
elif mode == "a":
|
| 25 |
-
stack = (192, 224, 192)
|
| 26 |
-
|
| 27 |
-
# Set the model to evaluation mode
|
| 28 |
model.eval()
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
with torch.inference_mode():
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
|
| 45 |
-
|
| 46 |
-
x_out = torch.softmax(model(image), 1).detach().cpu()
|
| 47 |
|
| 48 |
-
#
|
| 49 |
-
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
|
|
|
|
| 54 |
|
| 55 |
-
|
|
|
|
| 56 |
"""
|
| 57 |
-
Perform
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
Args:
|
| 60 |
-
voxel (
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
pnet_a (torch.nn.Module): The neural network model for axial view parcellation.
|
| 64 |
-
device (torch.device): The device (CPU or GPU) to perform computations on.
|
| 65 |
|
| 66 |
Returns:
|
| 67 |
-
numpy.ndarray:
|
| 68 |
"""
|
| 69 |
-
# Normalize
|
| 70 |
-
voxel = normalize(voxel)
|
| 71 |
|
| 72 |
-
# Prepare
|
| 73 |
coronal = voxel.transpose(1, 2, 0)
|
| 74 |
sagittal = voxel
|
| 75 |
axial = voxel.transpose(2, 1, 0)
|
| 76 |
|
| 77 |
-
#
|
| 78 |
-
|
| 79 |
-
|
|
|
|
| 80 |
torch.cuda.empty_cache()
|
| 81 |
-
print("Parcellation for coronal view completed.")
|
| 82 |
|
| 83 |
-
#
|
| 84 |
-
|
| 85 |
-
|
|
|
|
| 86 |
torch.cuda.empty_cache()
|
| 87 |
-
print("Parcellation for sagittal view completed.")
|
| 88 |
|
| 89 |
-
#
|
| 90 |
out_e = out_c + out_s
|
| 91 |
del out_c, out_s
|
| 92 |
-
|
| 93 |
-
#
|
| 94 |
-
|
|
|
|
|
|
|
| 95 |
torch.cuda.empty_cache()
|
| 96 |
-
print("Parcellation for axial view completed.")
|
| 97 |
|
| 98 |
-
# Combine
|
| 99 |
-
out_e =
|
| 100 |
del out_a
|
| 101 |
|
| 102 |
-
#
|
| 103 |
parcellated = torch.argmax(out_e, 0).numpy()
|
| 104 |
|
| 105 |
return parcellated
|
|
|
|
| 4 |
from utils.functions import normalize
|
| 5 |
|
| 6 |
|
| 7 |
+
def parcellate(
|
| 8 |
+
voxel: np.ndarray,
|
| 9 |
+
model: torch.nn.Module,
|
| 10 |
+
device: torch.device,
|
| 11 |
+
mode: str,
|
| 12 |
+
n_classes: int = 142,
|
| 13 |
+
) -> torch.Tensor:
|
| 14 |
"""
|
| 15 |
+
Perform 2.5D neural network inference for brain parcellation along a specific anatomical plane.
|
| 16 |
+
|
| 17 |
+
The function processes a 3D volume slice by slice using a 3-slice context window (previous,
|
| 18 |
+
current, next). An additional constant-valued fourth channel encodes the orientation mode
|
| 19 |
+
(Axial, Coronal, or Sagittal), allowing the network to distinguish the processing plane.
|
| 20 |
|
| 21 |
Args:
|
| 22 |
+
voxel (numpy.ndarray): 3D voxel data of shape (N, 224, 224), representing a single anatomical view.
|
| 23 |
+
model (torch.nn.Module): The trained PyTorch parcellation model.
|
| 24 |
+
device (torch.device): Device for inference (CPU, CUDA, or MPS).
|
| 25 |
+
mode (str): The anatomical plane used for inference. Must be one of {'Axial', 'Coronal', 'Sagittal'}.
|
| 26 |
+
n_classes (int, optional): Number of output anatomical labels. Defaults to 142.
|
| 27 |
|
| 28 |
Returns:
|
| 29 |
+
torch.Tensor: A tensor of shape (224, n_classes, 224, 224) containing softmax probabilities
|
| 30 |
+
for each class at each voxel position.
|
| 31 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
model.eval()
|
| 33 |
+
voxel = voxel.astype(np.float32)
|
| 34 |
+
|
| 35 |
+
# Set the constant value for the 4th channel to encode plane orientation
|
| 36 |
+
if mode == "Axial":
|
| 37 |
+
section_value = 1.0
|
| 38 |
+
elif mode == "Coronal":
|
| 39 |
+
section_value = -1.0
|
| 40 |
+
elif mode == "Sagittal":
|
| 41 |
+
section_value = 0.0
|
| 42 |
+
else:
|
| 43 |
+
raise ValueError("mode must be one of {'Axial','Coronal','Sagittal'}")
|
| 44 |
+
|
| 45 |
+
# Pad one slice on both ends to safely allow 3-slice context
|
| 46 |
+
voxel_pad = np.pad(
|
| 47 |
+
voxel,
|
| 48 |
+
[(1, 1), (0, 0), (0, 0)],
|
| 49 |
+
mode="constant",
|
| 50 |
+
constant_values=float(voxel.min()),
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Initialize a container for the network outputs (CPU for accumulation)
|
| 54 |
+
box = torch.empty((224, n_classes, 224, 224), dtype=torch.float32, device="cpu")
|
| 55 |
+
|
| 56 |
+
# Inference loop: iterate over slices and feed triplets to the model
|
| 57 |
with torch.inference_mode():
|
| 58 |
+
for i in range(1, 225):
|
| 59 |
+
prev_ = voxel_pad[i - 1]
|
| 60 |
+
curr_ = voxel_pad[i]
|
| 61 |
+
next_ = voxel_pad[i + 1]
|
| 62 |
|
| 63 |
+
# Build 4-channel input (3 context slices + orientation encoding)
|
| 64 |
+
four_ch = np.empty((4, 224, 224), dtype=np.float32)
|
| 65 |
+
four_ch[0] = prev_
|
| 66 |
+
four_ch[1] = curr_
|
| 67 |
+
four_ch[2] = next_
|
| 68 |
+
four_ch[3].fill(section_value)
|
| 69 |
|
| 70 |
+
inp = torch.from_numpy(four_ch).unsqueeze(0).to(device)
|
|
|
|
| 71 |
|
| 72 |
+
# Model inference with softmax normalization
|
| 73 |
+
logits = model(inp)
|
| 74 |
+
probs = torch.softmax(logits, dim=1)
|
| 75 |
|
| 76 |
+
# Store softmax output for this slice
|
| 77 |
+
box[i - 1] = probs
|
| 78 |
|
| 79 |
+
return box
|
| 80 |
|
| 81 |
+
|
| 82 |
+
def parcellation(voxel, pnet, device):
|
| 83 |
"""
|
| 84 |
+
Perform full 3D brain parcellation by aggregating predictions across multiple anatomical planes.
|
| 85 |
+
|
| 86 |
+
The function normalizes the input MRI volume, generates three differently oriented representations
|
| 87 |
+
(coronal, sagittal, axial), and performs 2.5D inference on each using a shared parcellation network.
|
| 88 |
+
The resulting probability maps are fused by summation and converted into a discrete segmentation map
|
| 89 |
+
via argmax over anatomical classes.
|
| 90 |
|
| 91 |
Args:
|
| 92 |
+
voxel (numpy.ndarray): Input 3D brain volume (float array).
|
| 93 |
+
pnet (torch.nn.Module): Trained parcellation network (U-Net or similar architecture).
|
| 94 |
+
device (torch.device): Device on which inference will be executed (CPU or GPU).
|
|
|
|
|
|
|
| 95 |
|
| 96 |
Returns:
|
| 97 |
+
numpy.ndarray: Final 3D parcellation map (integer label image) with voxel-wise anatomical labels.
|
| 98 |
"""
|
| 99 |
+
# Normalize input intensities for network inference
|
| 100 |
+
voxel = normalize(voxel, "parcellation")
|
| 101 |
|
| 102 |
+
# Prepare three anatomical views for 2.5D inference
|
| 103 |
coronal = voxel.transpose(1, 2, 0)
|
| 104 |
sagittal = voxel
|
| 105 |
axial = voxel.transpose(2, 1, 0)
|
| 106 |
|
| 107 |
+
# ------------------------
|
| 108 |
+
# Coronal view inference
|
| 109 |
+
# ------------------------
|
| 110 |
+
out_c = parcellate(coronal, pnet, device, "Coronal").permute(1, 3, 0, 2)
|
| 111 |
torch.cuda.empty_cache()
|
|
|
|
| 112 |
|
| 113 |
+
# ------------------------
|
| 114 |
+
# Sagittal view inference
|
| 115 |
+
# ------------------------
|
| 116 |
+
out_s = parcellate(sagittal, pnet, device, "Sagittal").permute(1, 0, 2, 3)
|
| 117 |
torch.cuda.empty_cache()
|
|
|
|
| 118 |
|
| 119 |
+
# Fuse coronal and sagittal predictions
|
| 120 |
out_e = out_c + out_s
|
| 121 |
del out_c, out_s
|
| 122 |
+
|
| 123 |
+
# ------------------------
|
| 124 |
+
# Axial view inference
|
| 125 |
+
# ------------------------
|
| 126 |
+
out_a = parcellate(axial, pnet, device, "Axial").permute(1, 3, 2, 0)
|
| 127 |
torch.cuda.empty_cache()
|
|
|
|
| 128 |
|
| 129 |
+
# Combine outputs from all three anatomical orientations
|
| 130 |
+
out_e = out_e + out_a
|
| 131 |
del out_a
|
| 132 |
|
| 133 |
+
# Convert probability maps to final integer labels
|
| 134 |
parcellated = torch.argmax(out_e, 0).numpy()
|
| 135 |
|
| 136 |
return parcellated
|
src/utils/stripping.py
CHANGED
|
@@ -7,96 +7,92 @@ from utils.functions import normalize, reimburse_conform
|
|
| 7 |
|
| 8 |
def strip(voxel, model, device):
|
| 9 |
"""
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
Args:
|
| 13 |
-
voxel (numpy.ndarray):
|
| 14 |
-
|
| 15 |
-
|
|
|
|
| 16 |
|
| 17 |
Returns:
|
| 18 |
-
torch.Tensor: A
|
|
|
|
| 19 |
"""
|
| 20 |
-
# Set the model to evaluation mode
|
| 21 |
model.eval()
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
|
| 25 |
-
# Initialize an empty tensor to store the output
|
| 26 |
-
output = torch.zeros(256, 256, 256).to(device)
|
| 27 |
-
|
| 28 |
-
# Iterate over each slice in the voxel data
|
| 29 |
-
for i, v in enumerate(voxel):
|
| 30 |
-
# Reshape the slice to match the model's input dimensions
|
| 31 |
-
image = v.reshape(1, 1, 256, 256)
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
# Apply the model to the input image and apply the sigmoid activation function
|
| 37 |
-
x_out = torch.sigmoid(model(image)).detach()
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
#
|
| 43 |
-
return
|
| 44 |
|
| 45 |
|
| 46 |
-
def stripping(output_dir, basename, voxel, odata, data, ssnet, device):
|
| 47 |
"""
|
| 48 |
-
Perform brain stripping
|
| 49 |
|
| 50 |
-
This function
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
| 53 |
|
| 54 |
Args:
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
Returns:
|
| 61 |
-
|
| 62 |
-
- stripped (numpy.ndarray): The stripped and processed brain image.
|
| 63 |
-
- (xd, yd, zd) (tuple of int): The shifts applied to center the brain image in the x, y, and z directions.
|
| 64 |
"""
|
| 65 |
-
#
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
# Prepare
|
| 69 |
coronal = voxel.transpose(1, 2, 0)
|
| 70 |
sagittal = voxel
|
| 71 |
axial = voxel.transpose(2, 1, 0)
|
| 72 |
|
| 73 |
-
# Apply the
|
| 74 |
-
out_c = strip(coronal, ssnet, device).permute(2, 0, 1)
|
| 75 |
-
out_s = strip(sagittal, ssnet, device)
|
| 76 |
-
out_a = strip(axial, ssnet, device).permute(2, 1, 0)
|
| 77 |
|
| 78 |
-
#
|
| 79 |
out_e = ((out_c + out_s + out_a) / 3) > 0.5
|
| 80 |
out_e = out_e.cpu().numpy()
|
| 81 |
|
| 82 |
-
#
|
| 83 |
-
stripped =
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
# Calculate the shifts needed to center the brain image
|
| 91 |
-
xd = 128 - x
|
| 92 |
-
yd = 120 - y
|
| 93 |
-
zd = 128 - z
|
| 94 |
|
| 95 |
-
|
| 96 |
-
stripped = np.roll(stripped, (xd, yd, zd), axis=(0, 1, 2))
|
| 97 |
-
|
| 98 |
-
# Crop the centered brain image
|
| 99 |
-
stripped = stripped[32:-32, 16:-16, 32:-32]
|
| 100 |
|
| 101 |
-
|
| 102 |
-
return stripped, (xd, yd, zd), out_filename
|
|
|
|
| 7 |
|
| 8 |
def strip(voxel, model, device):
    """
    Perform slice-wise inference using the brain stripping model.

    The input 3D volume is processed slice by slice along the first axis,
    feeding the model a three-slice context window (previous, current, next
    slice) for each prediction.  The result is a volume of per-voxel brain
    probabilities with the same shape as the input.

    Args:
        voxel (numpy.ndarray): Input voxel data of shape (N, H, W), typically
            a single anatomical orientation (e.g., coronal or sagittal view).
        model (torch.nn.Module): The trained PyTorch brain stripping model.
            It must accept a (1, 3, H, W) tensor and return an output
            broadcastable to a single (H, W) slice.
        device (torch.device): Device used for inference (CPU, CUDA, or MPS).

    Returns:
        torch.Tensor: A float tensor of shape (N, H, W) on the CPU holding the
        predicted brain probabilities (sigmoid outputs) for every slice.
    """
    model.eval()

    # Pad one slice on both ends so every slice has a valid 3-slice context,
    # filling with the volume minimum (i.e. background intensity).
    voxel = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())

    # Generalized from the previous hard-coded 224^3 volume: derive the slice
    # count and in-plane size from the (padded) input itself.
    depth, height, width = voxel.shape

    with torch.inference_mode():
        # Output buffer matching the unpadded input volume.
        box = torch.zeros(depth - 2, height, width)

        # Run the model once per slice, with the neighbouring slices supplied
        # as extra input channels.
        for i in range(1, depth - 1):
            image = np.stack([voxel[i - 1], voxel[i], voxel[i + 1]])
            image = torch.tensor(image.reshape(1, 3, height, width)).to(device)
            x_out = torch.sigmoid(model(image)).detach().cpu()
            box[i - 1] = x_out

    # Probability volume aligned with the unpadded input slices.
    return box
|
| 43 |
|
| 44 |
|
| 45 |
+
def stripping(output_dir, basename, voxel, odata, data, ssnet, shift, device):
    """
    Perform full 3D brain stripping using a deep learning model.

    Inference is run along the coronal, sagittal, and axial orientations, and
    the three probability maps are averaged into one binary brain mask.  The
    mask is applied to the (unnormalized) input volume; it is then restored to
    the conformed geometry and written to disk.

    Args:
        output_dir (str): Directory where intermediate and final results will be saved.
        basename (str): Base name of the current case (used for file naming).
        voxel (numpy.ndarray): Input 3D voxel data (preprocessed MRI image).
        odata (nibabel.Nifti1Image): Original NIfTI image before preprocessing.
        data (nibabel.Nifti1Image): Preprocessed NIfTI image used for model input.
        ssnet (torch.nn.Module): Trained brain stripping network.
        shift (tuple[int, int, int]): The (x, y, z) offsets applied previously during cropping.
        device (torch.device): Device used for inference (CPU, CUDA, or MPS).

    Returns:
        tuple: ``(stripped, out_filename)`` — the skull-stripped 3D brain
        volume (numpy.ndarray) and the path of the saved mask file.
    """
    # Keep the raw intensities so the mask can be applied to the original data.
    raw = voxel.copy()

    # Intensity normalization expected by the stripping network.
    normalized = normalize(voxel, "stripping")

    # The same volume viewed along the three anatomical planes.
    views = {
        "coronal": normalized.transpose(1, 2, 0),
        "sagittal": normalized,
        "axial": normalized.transpose(2, 1, 0),
    }

    # Run the network per view, mapping each prediction back to native orientation.
    pred_c = strip(views["coronal"], ssnet, device).permute(2, 0, 1)
    pred_s = strip(views["sagittal"], ssnet, device)
    pred_a = strip(views["axial"], ssnet, device).permute(2, 1, 0)

    # Fuse the three probability maps by averaging, then threshold at 0.5.
    mask = ((pred_c + pred_s + pred_a) / 3) > 0.5
    mask = mask.cpu().numpy()

    # Zero out everything outside the predicted brain region.
    stripped = raw * mask

    # Restore the mask to the original conformed geometry: pad back to the
    # full field of view, then undo the centering shift applied during cropping.
    mask = np.pad(mask, [(16, 16), (16, 16), (16, 16)], "constant", constant_values=0)
    mask = np.roll(mask, (-shift[0], -shift[1], -shift[2]), axis=(0, 1, 2))

    # Persist the mask in the original image geometry.
    out_filename = reimburse_conform(output_dir, basename, "stripped", odata, data, mask)

    return stripped, out_filename
|
|
|