OpenMAP-T1 / src /utils /hemisphere.py
西牧慧
V3
c9d94d6
import torch
from scipy.ndimage import binary_dilation
import numpy as np
from utils.functions import normalize
def separate(voxel, model, device):
    """
    Perform slice-wise inference using a hemisphere separation model.

    This function runs a 2.5D neural network across the slices of a 3D input
    volume. Each slice is fed to the model together with its immediate
    neighbours (previous and next slice) as a 3-channel image. The volume is
    padded by one slice on both ends, filled with the volume minimum, so the
    first and last slices also get a full 3-slice context.

    The slice count and in-plane size are taken from ``voxel.shape`` rather
    than being fixed to 224, so any ``(N, H, W)`` volume the model accepts can
    be processed. A ``(224, 224, 224)`` input yields the same
    ``(224, 3, 224, 224)`` result as before.

    Args:
        voxel (numpy.ndarray): Input voxel data of shape (N, H, W); slices are
            taken along the first axis.
        model (torch.nn.Module): Trained hemisphere segmentation model. It is
            called with a (1, 3, H, W) tensor and is expected to return
            3-class scores of shape (1, 3, H, W).
        device (torch.device): Device the input slices are moved to before
            being passed to the model.

    Returns:
        torch.Tensor: A CPU tensor of shape (N, 3, H, W) containing softmax
        probabilities (over dim 1) for each class at every voxel.

    Raises:
        ValueError: If ``voxel`` is not 3-dimensional.
    """
    if voxel.ndim != 3:
        raise ValueError(f"voxel must be 3-dimensional (N, H, W), got shape {voxel.shape}")

    model.eval()
    num_slices, height, width = voxel.shape

    # Pad the volume by one slice on both ends to provide full 3-slice context
    padded = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())

    with torch.inference_mode():
        # Output tensor for storing model predictions (class probabilities)
        box = torch.zeros(num_slices, 3, height, width)

        # Iterate slice-by-slice along the first axis; index i addresses the
        # padded volume, so i - 1 is the matching index in the original volume.
        for i in range(1, num_slices + 1):
            image = np.stack([padded[i - 1], padded[i], padded[i + 1]])
            # Add the batch dimension: (3, H, W) -> (1, 3, H, W)
            image = torch.tensor(image[np.newaxis]).to(device)

            # Model inference with softmax normalization across classes
            x_out = torch.softmax(model(image), dim=1).detach().cpu()
            box[i - 1] = x_out[0]

    # Complete 3D probability map, already shaped (N, 3, H, W)
    return box
def hemisphere(voxel, hnet, device):
    """
    Split a brain MRI volume into left and right hemispheres with a neural network.

    The volume is normalized, then passed through ``separate`` twice: once on
    a coronal re-ordering of the axes and once on a transverse re-ordering.
    Both probability maps are permuted back to a common (class, x, y, z)
    layout and summed, and the most probable class is taken per voxel.
    The label map is then post-processed with one iteration of binary
    dilation per hemisphere.

    Args:
        voxel (numpy.ndarray): Input 3D brain volume. Each transposed view
            must meet the shape requirements of ``separate``.
        hnet (torch.nn.Module): Trained hemisphere segmentation model.
        device (torch.device): Target device for computation (e.g., 'cuda', 'cpu').

    Returns:
        numpy.ndarray: A 3D ``int16`` array representing the hemisphere mask:
            - 0: Background
            - 1: Left hemisphere
            - 2: Right hemisphere
    """
    # Intensity normalization specific to the hemisphere model
    normalized = normalize(voxel, "hemisphere")

    # Two anatomical views of the same volume (axis re-orderings only)
    coronal_view = normalized.transpose(1, 2, 0)
    transverse_view = normalized.transpose(2, 1, 0)

    # Run the slice-wise model on each view and bring the results back to a
    # shared (class, x, y, z) layout so they can be added element-wise.
    prob_coronal = separate(coronal_view, hnet, device).permute(1, 3, 0, 2)
    prob_transverse = separate(transverse_view, hnet, device).permute(1, 3, 2, 0)

    # Fuse the two views by summing probabilities, then pick the best class
    fused = prob_coronal + prob_transverse
    labels = torch.argmax(fused, dim=0).cpu().numpy()

    # Release any cached GPU memory
    torch.cuda.empty_cache()

    # --------------------------
    # Post-processing: binary dilation
    # --------------------------
    predicted_left = labels == 1
    predicted_right = labels == 2

    # Grow the left hemisphere by one step, then re-stamp the originally
    # predicted right voxels so the growth cannot claim them.
    left_grown = binary_dilation(predicted_left, iterations=1).astype("int16")
    left_grown[predicted_right] = 2

    # Grow the right hemisphere by one step (labelled 2), then re-stamp the
    # (already grown) left voxels so they are not overwritten.
    keep_left = left_grown == 1
    result = binary_dilation(left_grown == 2, iterations=1).astype("int16") * 2
    result[keep_left] = 1

    return result