Spaces:
Sleeping
Sleeping
File size: 4,178 Bytes
03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import torch
from scipy.ndimage import binary_dilation
import numpy as np
from utils.functions import normalize
def separate(voxel, model, device):
    """
    Perform slice-wise inference using a hemisphere separation model.

    This runs a 2.5D network across the slices (first axis) of a 3D volume.
    Each slice is stacked with its previous and next slice into a 3-channel
    image and passed to the model. A softmax is applied over the channel axis
    of the output, giving per-slice class probabilities. The output is
    expected to have three channels (background and the two hemispheres).

    Args:
        voxel (numpy.ndarray): Input volume of shape (N, H, W). The original
            pipeline uses 224 x 224 x 224 volumes; other sizes work as long
            as the model accepts a (1, 3, H, W) input.
        model (torch.nn.Module): Trained hemisphere segmentation model that
            maps a (1, 3, H, W) tensor to a (1, 3, H, W) tensor of logits.
        device (torch.device): Computational device (CPU, CUDA, or MPS).

    Returns:
        torch.Tensor: Tensor of shape (N, 3, H, W) holding the softmax
        probabilities for every slice. It is allocated with ``torch.zeros``
        (CPU unless the default device was changed), and model outputs are
        moved to CPU before being stored.

    Raises:
        ValueError: If ``voxel`` is not 3-dimensional (shape unpacking fails).
    """
    model.eval()
    n_slices, height, width = voxel.shape
    # Pad by one slice on both ends (filled with the volume minimum) so the
    # first and last slices also get a full 3-slice context.
    padded = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())
    with torch.inference_mode():
        # Output tensor for the per-slice class probabilities.
        box = torch.zeros(n_slices, 3, height, width)
        for i in range(n_slices):
            # padded[i:i + 3] == (previous, current, next) for original slice i.
            window = padded[i:i + 3].reshape(1, 3, height, width)
            image = torch.tensor(window).to(device)
            # Softmax over the class/channel axis.
            x_out = torch.softmax(model(image), dim=1).cpu()
            box[i] = x_out[0]
    return box
def hemisphere(voxel, hnet, device):
    """
    Label the hemispheres of a brain volume using two-view inference.

    The volume is normalized with ``normalize(voxel, "hemisphere")`` and then
    passed through ``separate`` twice, once for each of two axis orderings
    (named "coronal" and "transverse" in this pipeline). Each probability
    map is permuted back to the normalized volume's axis order. The two maps
    are summed, and an argmax over the class axis gives a label volume.
    Finally, each hemisphere label is grown by one step of binary dilation.

    Args:
        voxel (numpy.ndarray): 3D brain volume. Presumably the cube size
            ``separate`` is run on in this pipeline (224^3) — confirm.
        hnet (torch.nn.Module): Trained hemisphere segmentation model.
        device (torch.device): Device used for model inference.

    Returns:
        numpy.ndarray: int16 label volume with the same shape as the
        normalized input:
            - 0: Background
            - 1: Left hemisphere
            - 2: Right hemisphere
        (The side-to-label mapping follows the model's training convention;
        it is not checked here.)
    """
    normed = normalize(voxel, "hemisphere")

    # View 1: axes (1, 2, 0) -> separate gives (D1, C, D2, D0) -> back to (C, D0, D1, D2).
    probs_view1 = separate(normed.transpose(1, 2, 0), hnet, device).permute(1, 3, 0, 2)
    # View 2: axes (2, 1, 0) -> separate gives (D2, C, D1, D0) -> back to (C, D0, D1, D2).
    probs_view2 = separate(normed.transpose(2, 1, 0), hnet, device).permute(1, 3, 2, 0)

    # Sum the class probabilities of both views and keep the most likely class.
    fused = probs_view1 + probs_view2
    labels = fused.argmax(dim=0).cpu().numpy()

    # Free cached CUDA memory (this call only affects the CUDA allocator).
    torch.cuda.empty_cache()

    # Stage 1: grow label 1 by one voxel, then restore every voxel that was
    # originally label 2, so label 1 only expands into background.
    stage_one = binary_dilation(labels == 1, iterations=1).astype("int16")
    stage_one[labels == 2] = 2

    # Stage 2: grow label 2 by one voxel, then restore every stage-1 label-1
    # voxel, so label 2 only expands into remaining background. This is not
    # symmetric: background touching both hemispheres ends up as label 1.
    stage_two = binary_dilation(stage_one == 2, iterations=1).astype("int16") * 2
    stage_two[stage_one == 1] = 1

    return stage_two
|