Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from utils.functions import normalize | |
def parcellate(
    voxel: np.ndarray,
    model: torch.nn.Module,
    device: torch.device,
    mode: str,
    n_classes: int = 142,
) -> torch.Tensor:
    """
    Perform 2.5D neural network inference for brain parcellation along one anatomical plane.

    The volume is processed slice by slice along its first axis. Each network input stacks a
    3-slice context window (previous, current, next) plus a constant-valued fourth channel that
    encodes the orientation mode, so a single network can tell which plane it is processing.
    Context slices that fall outside the volume are padded with the volume's minimum intensity.

    Args:
        voxel (numpy.ndarray): 3D array of shape (N, H, W); slices are taken along axis 0.
            Earlier versions hard-coded N = H = W = 224; any size now works as long as
            ``model`` accepts a (1, 4, H, W) float32 input.
        model (torch.nn.Module): The trained parcellation model. Its output for one slice is
            stored into an (n_classes, H, W) slot, so it is expected to return
            (1, n_classes, H, W) logits.
        device (torch.device): Device for inference (CPU, CUDA, or MPS). Results are
            collected on the CPU.
        mode (str): The anatomical plane used for inference. Must be one of
            {'Axial', 'Coronal', 'Sagittal'}; encoded as 1.0, -1.0 and 0.0 respectively.
        n_classes (int, optional): Number of output anatomical labels. Defaults to 142.

    Returns:
        torch.Tensor: float32 CPU tensor of shape (N, n_classes, H, W) containing softmax
        probabilities over classes for each voxel of each slice.

    Raises:
        ValueError: If ``mode`` is not one of the supported planes or ``voxel`` is not 3D.
    """
    model.eval()
    voxel = np.asarray(voxel, dtype=np.float32)

    # Constant value of the 4th channel that encodes the plane orientation
    section_values = {"Axial": 1.0, "Coronal": -1.0, "Sagittal": 0.0}
    if mode not in section_values:
        raise ValueError("mode must be one of {'Axial','Coronal','Sagittal'}")
    section_value = section_values[mode]

    if voxel.ndim != 3:
        raise ValueError(f"voxel must be a 3D array, got shape {voxel.shape}")
    n_slices, height, width = voxel.shape

    # Container for the network outputs (kept on the CPU for accumulation)
    box = torch.empty(
        (n_slices, n_classes, height, width), dtype=torch.float32, device="cpu"
    )
    if voxel.size == 0:
        # Nothing to infer; also avoids voxel.min() failing on an empty array.
        return box

    # Pad one slice on both ends so every slice has a full 3-slice context
    voxel_pad = np.pad(
        voxel,
        [(1, 1), (0, 0), (0, 0)],
        mode="constant",
        constant_values=float(voxel.min()),
    )
    # Orientation channel is identical for every slice, so build it once
    orientation = np.full((1, height, width), section_value, dtype=np.float32)

    with torch.inference_mode():
        for i in range(n_slices):
            # voxel_pad[i : i + 3] is (previous, current, next) for original slice i
            four_ch = np.concatenate((voxel_pad[i : i + 3], orientation), axis=0)
            inp = torch.from_numpy(four_ch).unsqueeze(0).to(device)
            logits = model(inp)
            probs = torch.softmax(logits, dim=1)
            # Drop the batch dimension; the copy also moves the result to the CPU
            box[i] = probs[0]
    return box
def parcellation(voxel, pnet, device):
    """
    Perform full 3D brain parcellation by fusing 2.5D predictions from three orientations.

    The input volume is intensity-normalized and then passed to ``parcellate`` three times
    (labelled 'Coronal', 'Sagittal' and 'Axial'), each time with a different axis moved to
    the front as the slicing axis. Each per-view probability tensor is permuted back to a
    common (n_classes, X, Y, Z) layout, where (X, Y, Z) is the shape of the normalized
    volume. The three views are summed, and each voxel is labelled with the argmax over
    classes of that sum.

    Args:
        voxel (numpy.ndarray): Input 3D brain volume (float array). NOTE(review): after
            ``normalize`` every axis becomes a slicing axis for ``pnet``, so the normalized
            volume must have a size that ``parcellate``/``pnet`` accept along all three
            axes -- confirm ``normalize`` guarantees this.
        pnet (torch.nn.Module): Trained parcellation network shared by all three views.
        device (torch.device): Device on which inference will be executed (CPU or GPU).

    Returns:
        numpy.ndarray: int64 label map with the same shape as the normalized volume.
    """
    # Normalize input intensities for network inference
    voxel = normalize(voxel, "parcellation")

    # Per view: (axes that bring the slicing axis to the front, plane name passed to
    # parcellate, permutation mapping parcellate's (slice, class, row, col) output back
    # to the common (class, X, Y, Z) layout). Order matches the original fusion order.
    views = (
        ((1, 2, 0), "Coronal", (1, 3, 0, 2)),
        ((0, 1, 2), "Sagittal", (1, 0, 2, 3)),
        ((2, 1, 0), "Axial", (1, 3, 2, 0)),
    )

    fused = None
    for axes, mode, to_common in views:
        probs = parcellate(voxel.transpose(axes), pnet, device, mode).permute(*to_common)
        # Release cached GPU memory between the large per-view passes
        # (does nothing if CUDA was never initialized)
        torch.cuda.empty_cache()
        if fused is None:
            fused = probs
        else:
            # Accumulate in place: an out-of-place sum would allocate another
            # full-size (n_classes, X, Y, Z) tensor and raise peak host memory.
            fused += probs
        del probs

    # Convert fused probability maps to final integer labels
    return torch.argmax(fused, 0).numpy()