西牧慧 committed on
Commit
c9d94d6
·
1 Parent(s): 6275589
src/parcellation.py CHANGED
@@ -94,10 +94,8 @@ def run_inference(input_file, only_face_cropping, only_skull_stripping):
94
 
95
  # Load the pre-trained models from the fixed "model/" folder
96
  print("Loading models...")
97
- cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a = load_model("model/", device=device)
98
  print("Models loaded successfully.")
99
- # cnet, ssnet, pnet_a, hnet_c, hnet_a = load_model("model/", device=device)
100
-
101
  # --- Processing Flow (based on the original parcellation.py) ---
102
  # 1. Load the input image, convert to canonical orientation, and remove extra dimensions
103
  print("Loading and preprocessing the input image...")
@@ -115,7 +113,7 @@ def run_inference(input_file, only_face_cropping, only_skull_stripping):
115
 
116
  # 3. Cropping
117
  print("Cropping the input image...")
118
- cropped, out_filename = cropping(opt.o, basename, odata, data, cnet, device)
119
  print("Cropping completed.")
120
  if only_face_cropping:
121
  pass
@@ -123,18 +121,18 @@ def run_inference(input_file, only_face_cropping, only_skull_stripping):
123
  else:
124
  # 4. Skull stripping
125
  print("Performing skull stripping...")
126
- stripped, shift, out_filename = stripping(opt.o, basename, cropped, odata, data, ssnet, device)
127
  print("Skull stripping completed.")
128
  if only_skull_stripping:
129
  pass
130
  else:
131
  # 5. Parcellation
132
  print("Starting parcellation...")
133
- parcellated = parcellation(stripped, pnet_c, pnet_s, pnet_a, device)
134
  print("Parcellation completed.")
135
  # 6. Separate into hemispheres
136
  print("Separating hemispheres...")
137
- separated = hemisphere(stripped, hnet_c, hnet_a, device)
138
  print("Hemispheres separated.")
139
  # 7. Postprocessing
140
  print("Postprocessing the parcellated data...")
 
94
 
95
  # Load the pre-trained models from the fixed "model/" folder
96
  print("Loading models...")
97
+ cnet, ssnet, pnet, hnet = load_model("model/", device=device)
98
  print("Models loaded successfully.")
 
 
99
  # --- Processing Flow (based on the original parcellation.py) ---
100
  # 1. Load the input image, convert to canonical orientation, and remove extra dimensions
101
  print("Loading and preprocessing the input image...")
 
113
 
114
  # 3. Cropping
115
  print("Cropping the input image...")
116
+ cropped, shift, out_filename = cropping(opt.o, basename, odata, data, cnet, device)
117
  print("Cropping completed.")
118
  if only_face_cropping:
119
  pass
 
121
  else:
122
  # 4. Skull stripping
123
  print("Performing skull stripping...")
124
+ stripped, out_filename = stripping(opt.o, basename, cropped, odata, data, ssnet, shift, device)
125
  print("Skull stripping completed.")
126
  if only_skull_stripping:
127
  pass
128
  else:
129
  # 5. Parcellation
130
  print("Starting parcellation...")
131
+ parcellated = parcellation(stripped, pnet, device)
132
  print("Parcellation completed.")
133
  # 6. Separate into hemispheres
134
  print("Separating hemispheres...")
135
+ separated = hemisphere(stripped, hnet, device)
136
  print("Hemispheres separated.")
137
  # 7. Postprocessing
138
  print("Postprocessing the parcellated data...")
src/utils/cropping.py CHANGED
@@ -1,31 +1,44 @@
1
  import numpy as np
2
  import torch
3
  from scipy.ndimage import binary_closing
 
4
 
5
  from utils.functions import normalize, reimburse_conform
6
 
7
 
8
def crop(voxel, model, device):
    """Run the cropping model slice-by-slice over a (N, 256, 256) volume.

    Each 2D slice is fed to the model as a (1, 1, 256, 256) tensor; the
    sigmoid of the model output becomes the matching slice of a
    (256, 256, 256) prediction volume kept on `device`.

    Args:
        voxel (numpy.ndarray): Input slices of shape (N, 256, 256).
        model (torch.nn.Module): Per-slice cropping network.
        device (torch.device): Device used for inference.

    Returns:
        torch.Tensor: Predicted volume of shape (256, 256, 256).
    """
    model.eval()
    with torch.inference_mode():
        prediction = torch.zeros(256, 256, 256).to(device)
        for slice_idx, slice_2d in enumerate(voxel):
            net_in = torch.tensor(slice_2d.reshape(1, 1, 256, 256)).to(device)
            prediction[slice_idx] = torch.sigmoid(model(net_in)).detach()
        return prediction.reshape(256, 256, 256)
 
 
 
29
 
30
 
31
  def closing(voxel):
@@ -59,7 +72,7 @@ def cropping(output_dir, basename, odata, data, cnet, device):
59
  numpy.ndarray: The cropped medical imaging data.
60
  """
61
  voxel = data.get_fdata().astype("float32")
62
- voxel = normalize(voxel)
63
 
64
  coronal = voxel.transpose(1, 2, 0)
65
  sagittal = voxel
@@ -72,4 +85,18 @@ def cropping(output_dir, basename, odata, data, cnet, device):
72
 
73
  out_filename = reimburse_conform(output_dir, basename, "cropped", odata, data, out_e)
74
 
75
- return cropped, out_filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import torch
3
  from scipy.ndimage import binary_closing
4
+ from scipy import ndimage
5
 
6
  from utils.functions import normalize, reimburse_conform
7
 
8
 
9
def crop(voxel, model, device):
    """Predict a 3D cropping mask with a 2.5D sliding-window network.

    The volume is padded by one slice on each end, and every slice is
    scored together with its two neighbors as a 3-channel input. The
    per-slice sigmoid outputs are gathered into a (256, 256, 256) mask
    held on the CPU.

    Args:
        voxel (numpy.ndarray): Input volume of shape (N, 256, 256); the
            loop reads 256 slices, so N >= 256 is required.
        model (torch.nn.Module): Cropping network taking 3-channel slices.
        device (torch.device): Device used for inference.

    Returns:
        torch.Tensor: Predicted mask of shape (256, 256, 256).
    """
    # Constant padding supplies neighbors for the first and last slices.
    voxel = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())
    model.eval()

    with torch.inference_mode():
        mask = torch.zeros(256, 256, 256)
        for idx in range(256):
            triplet = np.stack([voxel[idx], voxel[idx + 1], voxel[idx + 2]])
            net_in = torch.tensor(triplet.reshape(1, 3, 256, 256)).to(device)
            mask[idx] = torch.sigmoid(model(net_in)).detach().cpu()
        return mask.reshape(256, 256, 256)
42
 
43
 
44
  def closing(voxel):
 
72
  numpy.ndarray: The cropped medical imaging data.
73
  """
74
  voxel = data.get_fdata().astype("float32")
75
+ voxel = normalize(voxel, "cropping")
76
 
77
  coronal = voxel.transpose(1, 2, 0)
78
  sagittal = voxel
 
85
 
86
  out_filename = reimburse_conform(output_dir, basename, "cropped", odata, data, out_e)
87
 
88
+ # Compute center of mass for the masked brain
89
+ x, y, z = map(int, ndimage.center_of_mass(out_e))
90
+
91
+ # Compute shifts required to center the brain
92
+ xd = 128 - x
93
+ yd = 120 - y
94
+ zd = 128 - z
95
+
96
+ # Translate (roll) the image to center the brain region
97
+ cropped = np.roll(cropped, (xd, yd, zd), axis=(0, 1, 2))
98
+
99
+ # Crop out boundary padding to reduce size and focus on the centered brain
100
+ cropped = cropped[16:-16, 16:-16, 16:-16]
101
+
102
+ return cropped, (xd, yd, zd), out_filename
src/utils/functions.py CHANGED
@@ -5,9 +5,13 @@ import numpy as np
5
  from nibabel import processing
6
 
7
 
8
def normalize(voxel):
    """Clip bright outliers and rescale intensities into [-1, 1].

    The upper clip bound is mean + 2*std of the strictly positive voxels;
    after clipping, the volume is min-max scaled to [0, 1] and then mapped
    linearly onto [-1, 1].

    Args:
        voxel (numpy.ndarray): Raw intensity volume.

    Returns:
        numpy.ndarray: Normalized float32 volume in [-1, 1].
    """
    foreground = voxel[voxel > 0]
    upper = np.mean(foreground) + np.std(foreground) * 2
    clipped = np.clip(voxel, 0, upper)
    lo = np.min(clipped)
    hi = np.max(clipped)
    scaled = (clipped - lo) / (hi - lo)
    return (scaled * 2 - 1).astype("float32")
 
5
  from nibabel import processing
6
 
7
 
8
def normalize(voxel, mode):
    """Clip intensity outliers and rescale a volume into [-1, 1].

    The clipping threshold depends on the pipeline stage: cropping and
    skull stripping clip at mean + 2*std of the strictly positive voxels,
    while parcellation and hemisphere separation clip at mean + 3*std.
    After clipping, the volume is min-max scaled to [0, 1] and then mapped
    linearly onto [-1, 1].

    Args:
        voxel (numpy.ndarray): Raw intensity volume.
        mode (str): One of "cropping", "stripping", "parcellation",
            "hemisphere"; selects the clipping strength.

    Returns:
        numpy.ndarray: Normalized float32 volume in [-1, 1].

    Raises:
        ValueError: If `mode` is not a recognized stage name.
    """
    nonzero = voxel[voxel > 0]
    if mode in ["cropping", "stripping"]:
        clip = 2
    elif mode in ["parcellation", "hemisphere"]:
        clip = 3
    else:
        # Fail fast with a clear message instead of crashing later with an
        # UnboundLocalError on `clip`.
        raise ValueError(f"unknown normalization mode: {mode!r}")
    voxel = np.clip(voxel, 0, np.mean(nonzero) + np.std(nonzero) * clip)
    voxel = (voxel - np.min(voxel)) / (np.max(voxel) - np.min(voxel))
    voxel = (voxel * 2) - 1
    return voxel.astype("float32")
src/utils/hemisphere.py CHANGED
@@ -1,92 +1,104 @@
1
  import torch
2
  from scipy.ndimage import binary_dilation
3
-
4
  from utils.functions import normalize
5
 
6
 
7
def separate(voxel, model, device, mode):
    """Slice-wise hemisphere inference over one anatomical view.

    Each 2D slice is scored individually by the network; the softmax over
    the three output classes is written into the matching slice of an
    output tensor kept on `device`.

    Args:
        voxel (list or numpy.ndarray): Input slices for one view.
        model (torch.nn.Module): Hemisphere network producing 3-class logits.
        device (torch.device): Device used for inference.
        mode (str): 'c' (coronal, output (224, 3, 192, 192)) or
            'a' (axial, output (192, 3, 224, 192)).

    Returns:
        torch.Tensor: Softmax probabilities, shape (depth, 3, height, width).
    """
    if mode == "c":
        depth, height, width = 224, 192, 192
    elif mode == "a":
        depth, height, width = 192, 224, 192

    model.eval()

    with torch.inference_mode():
        probs = torch.zeros(depth, 3, height, width).to(device)
        for slice_idx, plane in enumerate(voxel):
            net_in = torch.tensor(plane.reshape(1, 1, height, width)).to(device)
            probs[slice_idx] = torch.softmax(model(net_in), 1).detach()
        return probs
 
48
 
 
 
49
 
50
def hemisphere(voxel, hnet_c, hnet_a, device):
    """Separate a brain volume into hemispheres using two view-specific nets.

    The volume is normalized, reoriented into coronal and transverse views,
    and each view is scored by its own 3-class network. The two probability
    maps are summed, argmax'ed into labels {0, 1, 2}, and each hemisphere
    label is then dilated (5 iterations) while preserving the other label.

    Args:
        voxel (numpy.ndarray): Input brain volume.
        hnet_c (torch.nn.Module): Coronal-view hemisphere network.
        hnet_a (torch.nn.Module): Transverse-view hemisphere network.
        device (torch.device): Device used for inference.

    Returns:
        numpy.ndarray: int16 mask with 0 background, 1 and 2 hemispheres.
    """
    normalized = normalize(voxel)

    # Reorient the volume for the two inference views.
    view_c = normalized.transpose(1, 2, 0)
    view_t = normalized.transpose(2, 1, 0)

    # Per-view probabilities, permuted back into a common axis order.
    prob_c = separate(view_c, hnet_c, device, "c").permute(1, 3, 0, 2)
    prob_t = separate(view_t, hnet_a, device, "a").permute(1, 3, 2, 0)

    # Fuse the views and pick the most likely class per voxel.
    labels = torch.argmax(prob_c + prob_t, 0).cpu().numpy()

    torch.cuda.empty_cache()

    # Grow hemisphere 1 without overwriting hemisphere 2 ...
    grown = binary_dilation(labels == 1, iterations=5).astype("int16")
    grown[labels == 2] = 2
    # ... then grow hemisphere 2 without overwriting hemisphere 1.
    final = binary_dilation(grown == 2, iterations=5).astype("int16") * 2
    final[grown == 1] = 1
    return final
 
1
  import torch
2
  from scipy.ndimage import binary_dilation
3
+ import numpy as np
4
  from utils.functions import normalize
5
 
6
 
7
def separate(voxel, model, device):
    """Score hemisphere probabilities slice-by-slice with 3-slice context.

    The input is padded by one slice at each end; every slice is stacked
    with its neighbors into a 3-channel image and scored by `model`. The
    per-slice softmax maps (3 classes) are collected on the CPU.

    Args:
        voxel (numpy.ndarray): Input view volume of shape (N, 224, 224);
            the loop reads 224 slices, so N >= 224 is required.
        model (torch.nn.Module): Hemisphere network emitting 3-class logits.
        device (torch.device): Device used for inference.

    Returns:
        torch.Tensor: Probabilities of shape (224, 3, 224, 224).
    """
    model.eval()

    # Constant padding supplies neighbors for the first and last slices.
    voxel = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())

    with torch.inference_mode():
        probs = torch.zeros(224, 3, 224, 224)
        for idx in range(224):
            triplet = np.stack([voxel[idx], voxel[idx + 1], voxel[idx + 2]])
            net_in = torch.tensor(triplet.reshape(1, 3, 224, 224)).to(device)
            probs[idx] = torch.softmax(model(net_in), dim=1).detach().cpu()

    return probs.reshape(224, 3, 224, 224)
46
 
47
+
48
def hemisphere(voxel, hnet, device):
    """Split a brain volume into left/right hemisphere labels.

    A single network scores two reorientations of the normalized volume
    (coronal and transverse); the probability maps are fused by summation,
    argmax'ed into labels {0, 1, 2}, and lightly dilated (1 iteration per
    hemisphere) to smooth each boundary without overwriting the other side.

    Args:
        voxel (numpy.ndarray): Input brain volume.
        hnet (torch.nn.Module): Hemisphere segmentation network.
        device (torch.device): Device used for inference.

    Returns:
        numpy.ndarray: int16 mask — 0 background, labels 1 and 2 for the
        two hemispheres.
    """
    scaled = normalize(voxel, "hemisphere")

    # Both inference orientations share the same network.
    view_c = scaled.transpose(1, 2, 0)
    view_t = scaled.transpose(2, 1, 0)

    prob_c = separate(view_c, hnet, device).permute(1, 3, 0, 2)
    prob_t = separate(view_t, hnet, device).permute(1, 3, 2, 0)

    # Sum view-wise probabilities, then take the per-voxel winner.
    labels = torch.argmax(prob_c + prob_t, dim=0).cpu().numpy()

    torch.cuda.empty_cache()

    # Dilate label 1, keeping label 2 intact ...
    grown = binary_dilation(labels == 1, iterations=1).astype("int16")
    grown[labels == 2] = 2
    # ... then dilate label 2, restoring label 1.
    final = binary_dilation(grown == 2, iterations=1).astype("int16") * 2
    final[grown == 1] = 1
    return final
src/utils/load_model.py CHANGED
@@ -8,70 +8,65 @@ from utils.network import UNet
8
 
9
def load_model(model_dir, device):
    """Load the seven pretrained U-Net models and set them to eval mode.

    Extracts `model.zip` found in `model_dir`, then loads:
      - CNet  (in=1, out=1): cropping network.
      - SSNet (in=1, out=1): skull-stripping network.
      - PNet coronal / sagittal / axial (in=3, out=142): parcellation networks.
      - HNet coronal / axial (in=1, out=3): hemisphere networks.

    Args:
        model_dir (str): Directory containing `model.zip` and the extracted
            `CNet/`, `SSNet/`, `PNet/`, `HNet/` subfolders.
        device (torch.device): Device onto which the models are loaded.

    Returns:
        tuple: (cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a), each
        an evaluation-ready `UNet` on `device`.
    """

    def _load(subdir, filename, in_ch, out_ch):
        # Shared loader: build, restore weights, move to device, eval mode.
        net = UNet(in_ch, out_ch)
        weights_path = os.path.join(model_dir, subdir, filename)
        net.load_state_dict(torch.load(weights_path, weights_only=True))
        net.to(device)
        net.eval()
        return net

    # Unpack the bundled weights next to the archive.
    model_zip_path = os.path.join(model_dir, "model.zip")
    with zipfile.ZipFile(model_zip_path, "r") as zip_ref:
        zip_ref.extractall(model_dir)

    cnet = _load("CNet", "CNet.pth", 1, 1)
    ssnet = _load("SSNet", "SSNet.pth", 1, 1)
    pnet_c = _load("PNet", "coronal.pth", 3, 142)
    pnet_s = _load("PNet", "sagittal.pth", 3, 142)
    pnet_a = _load("PNet", "axial.pth", 3, 142)
    hnet_c = _load("HNet", "coronal.pth", 1, 3)
    hnet_a = _load("HNet", "axial.pth", 1, 3)

    return cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a
 
8
 
9
def load_model(model_dir, device):
    """Load the four pretrained U-Net models used by the pipeline.

    Extracts `model.zip` inside `model_dir`, then builds each network,
    restores its weights from the extracted `model/` tree, moves it to
    `device`, and switches it to evaluation mode.

    Models:
      - CNet  (in=3, out=1): cropping mask prediction.
      - SSNet (in=3, out=1): skull-stripping mask prediction.
      - PNet  (in=4, out=142): parcellation over 142 anatomical labels.
      - HNet  (in=3, out=3): 3-class hemisphere separation.

    Args:
        model_dir (str): Directory containing `model.zip`; weights are
            extracted to and read from `<model_dir>/model/...`.
        device (torch.device): Device onto which the models are loaded.

    Returns:
        tuple: (cnet, ssnet, pnet, hnet), all evaluation-ready.
    """

    def _load(subdir, filename, in_ch, out_ch):
        # Shared loader: build, restore weights, move to device, eval mode.
        net = UNet(in_ch, out_ch)
        weights_path = os.path.join(model_dir, "model", subdir, filename)
        net.load_state_dict(torch.load(weights_path, weights_only=True))
        net.to(device)
        net.eval()
        return net

    # Unpack the bundled weights next to the archive.
    model_zip_path = os.path.join(model_dir, "model.zip")
    with zipfile.ZipFile(model_zip_path, "r") as zip_ref:
        zip_ref.extractall(model_dir)

    cnet = _load("CNet", "CNet.pth", 3, 1)
    ssnet = _load("SSNet", "SSNet.pth", 3, 1)
    pnet = _load("PNet", "PNet.pth", 4, 142)
    hnet = _load("HNet", "HNet.pth", 3, 3)

    # All models loaded, device-resident, and in eval mode.
    return cnet, ssnet, pnet, hnet
 
src/utils/parcellation.py CHANGED
@@ -4,102 +4,133 @@ import torch
4
  from utils.functions import normalize
5
 
6
 
7
def parcellate(voxel, model, device, mode):
    """Run 2.5D parcellation inference over one anatomical view.

    Each target slice is presented to the model together with its two
    neighbors (three stacked channels); the softmax over 142 classes is
    collected into a CPU tensor of shape (depth, 142, height, width).

    Args:
        voxel (numpy.ndarray): Input view volume.
        model (torch.nn.Module): Parcellation network (3 -> 142 channels).
        device (torch.device): Device used for inference.
        mode (str): 'c' (dims (224, 192, 192)) or 's'/'a' (dims (192, 224, 192)).

    Returns:
        torch.Tensor: Class probabilities, shape (depth, 142, height, width).
    """
    if mode == "c":
        depth, height, width = 224, 192, 192
    elif mode in ("s", "a"):
        # Sagittal and axial share the same per-view dimensions.
        depth, height, width = 192, 224, 192

    model.eval()

    # One-slice padding on both ends gives every slice a full 3-slice context.
    padded = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())

    with torch.inference_mode():
        box = torch.zeros(depth, 142, height, width)
        for idx in range(depth):
            triplet = np.stack([padded[idx], padded[idx + 1], padded[idx + 2]])
            net_in = torch.tensor(triplet.reshape(1, 3, height, width)).to(device)
            box[idx] = torch.softmax(model(net_in), 1).detach().cpu()
        return box.reshape(depth, 142, height, width)
53
 
 
54
 
55
def parcellation(voxel, pnet_c, pnet_s, pnet_a, device):
    """Fuse coronal, sagittal and axial parcellation predictions.

    The normalized volume is scored by one network per anatomical view;
    the three 142-class probability maps are permuted into a shared axis
    order, summed, and reduced to a label map via argmax.

    Args:
        voxel (numpy.ndarray): Input brain volume.
        pnet_c (torch.nn.Module): Coronal-view parcellation network.
        pnet_s (torch.nn.Module): Sagittal-view parcellation network.
        pnet_a (torch.nn.Module): Axial-view parcellation network.
        device (torch.device): Device used for inference.

    Returns:
        numpy.ndarray: Voxel-wise anatomical label map.
    """
    normalized = normalize(voxel)

    # One reorientation per view.
    view_c = normalized.transpose(1, 2, 0)
    view_s = normalized
    view_a = normalized.transpose(2, 1, 0)

    print("Performing parcellation for coronal view...")
    out_c = parcellate(view_c, pnet_c, device, "c").permute(1, 3, 0, 2)
    torch.cuda.empty_cache()
    print("Parcellation for coronal view completed.")

    print("Performing parcellation for sagittal view...")
    out_s = parcellate(view_s, pnet_s, device, "s").permute(1, 0, 2, 3)
    torch.cuda.empty_cache()
    print("Parcellation for sagittal view completed.")

    # Free the per-view maps as soon as they are fused to limit peak memory.
    fused = out_c + out_s
    del out_c, out_s
    print("Combining results from coronal and sagittal views...")
    out_a = parcellate(view_a, pnet_a, device, "a").permute(1, 3, 2, 0)
    torch.cuda.empty_cache()
    print("Parcellation for axial view completed.")

    fused = out_a + fused
    del out_a

    # Most probable class per voxel.
    return torch.argmax(fused, 0).numpy()
 
4
  from utils.functions import normalize
5
 
6
 
7
def parcellate(
    voxel: np.ndarray,
    model: torch.nn.Module,
    device: torch.device,
    mode: str,
    n_classes: int = 142,
) -> torch.Tensor:
    """Run 2.5D parcellation inference along one anatomical plane.

    Each slice is scored with a 3-slice context window (previous, current,
    next) plus a constant fourth channel that encodes the plane (Axial,
    Coronal, or Sagittal), letting a single network serve all three views.

    Args:
        voxel (numpy.ndarray): 3D voxel data of shape (N, 224, 224) for one view.
        model (torch.nn.Module): Trained parcellation network.
        device (torch.device): Device used for inference (CPU, CUDA, or MPS).
        mode (str): One of {'Axial', 'Coronal', 'Sagittal'}.
        n_classes (int, optional): Number of output labels. Defaults to 142.

    Returns:
        torch.Tensor: Softmax probabilities of shape (224, n_classes, 224, 224).

    Raises:
        ValueError: If `mode` is not a recognized plane name.
    """
    model.eval()
    vol = voxel.astype(np.float32)

    # Constant fourth-channel value that encodes the processing plane.
    plane_codes = {"Axial": 1.0, "Coronal": -1.0, "Sagittal": 0.0}
    if mode not in plane_codes:
        raise ValueError("mode must be one of {'Axial','Coronal','Sagittal'}")
    section_value = plane_codes[mode]

    # Pad one slice on each end so every slice has a full 3-slice context.
    padded = np.pad(
        vol,
        [(1, 1), (0, 0), (0, 0)],
        mode="constant",
        constant_values=float(vol.min()),
    )

    # CPU-side accumulator for the per-slice class probabilities.
    box = torch.empty((224, n_classes, 224, 224), dtype=torch.float32, device="cpu")

    # The orientation channel is identical for every slice; build it once.
    plane_channel = np.full((1, 224, 224), section_value, dtype=np.float32)

    with torch.inference_mode():
        for idx in range(224):
            # 4-channel input: three context slices + orientation encoding.
            stacked = np.concatenate([padded[idx:idx + 3], plane_channel], axis=0)
            net_in = torch.from_numpy(stacked.reshape(1, 4, 224, 224)).to(device)

            # Softmax over the class dimension, stored per slice.
            box[idx] = torch.softmax(model(net_in), dim=1)

    return box
80
 
81
+
82
def parcellation(voxel, pnet, device):
    """Aggregate multi-view predictions into a full 3D parcellation map.

    The normalized volume is scored in three orientations (coronal,
    sagittal, axial) by a single shared network; the per-view probability
    maps are permuted into a common axis order, summed, and reduced to a
    discrete label map via argmax over the class dimension.

    Args:
        voxel (numpy.ndarray): Input 3D brain volume.
        pnet (torch.nn.Module): Trained parcellation network.
        device (torch.device): Device used for inference.

    Returns:
        numpy.ndarray: 3D integer label map with voxel-wise anatomical labels.
    """
    # Normalize input intensities for network inference.
    scaled = normalize(voxel, "parcellation")

    # (view volume, plane name, permutation back to the common axis order).
    views = [
        (scaled.transpose(1, 2, 0), "Coronal", (1, 3, 0, 2)),
        (scaled, "Sagittal", (1, 0, 2, 3)),
        (scaled.transpose(2, 1, 0), "Axial", (1, 3, 2, 0)),
    ]

    fused = None
    for view, plane, perm in views:
        probs = parcellate(view, pnet, device, plane).permute(*perm)
        torch.cuda.empty_cache()
        # Accumulate, dropping the per-view map immediately to limit memory.
        fused = probs if fused is None else fused + probs
        del probs

    # Most probable anatomical class per voxel.
    return torch.argmax(fused, 0).numpy()
src/utils/stripping.py CHANGED
@@ -7,96 +7,92 @@ from utils.functions import normalize, reimburse_conform
7
 
8
  def strip(voxel, model, device):
9
  """
10
- Applies a given model to a 3D voxel array and returns the processed output.
 
 
 
 
11
 
12
  Args:
13
- voxel (numpy.ndarray): A 3D numpy array of shape (256, 256, 256) representing the input voxel data.
14
- model (torch.nn.Module): A PyTorch model to be used for processing the voxel data.
15
- device (torch.device): The device (CPU or GPU) on which the model and data should be loaded.
 
16
 
17
  Returns:
18
- torch.Tensor: A 3D tensor of shape (256, 256, 256) containing the processed output.
 
19
  """
20
- # Set the model to evaluation mode
21
  model.eval()
22
 
23
- # Disable gradient calculation for inference
24
- with torch.inference_mode():
25
- # Initialize an empty tensor to store the output
26
- output = torch.zeros(256, 256, 256).to(device)
27
-
28
- # Iterate over each slice in the voxel data
29
- for i, v in enumerate(voxel):
30
- # Reshape the slice to match the model's input dimensions
31
- image = v.reshape(1, 1, 256, 256)
32
 
33
- # Convert the numpy array to a PyTorch tensor and move it to the specified device
34
- image = torch.tensor(image).to(device)
35
-
36
- # Apply the model to the input image and apply the sigmoid activation function
37
- x_out = torch.sigmoid(model(image)).detach()
38
 
39
- # Store the output in the corresponding slice of the output tensor
40
- output[i] = x_out
 
 
 
 
41
 
42
- # Reshape the output tensor to the original voxel dimensions and return it
43
- return output.reshape(256, 256, 256)
44
 
45
 
46
- def stripping(output_dir, basename, voxel, odata, data, ssnet, device):
47
  """
48
- Perform brain stripping on a given voxel using a specified neural network.
49
 
50
- This function normalizes the input voxel, applies brain stripping in three anatomical planes
51
- (coronal, sagittal, and axial), and combines the results to produce a final stripped brain image.
52
- The stripped image is then centered and cropped.
 
 
53
 
54
  Args:
55
- voxel (numpy.ndarray): The input 3D voxel data to be stripped.
56
- data (nibabel.Nifti1Image): The original neuroimaging data.
57
- ssnet (torch.nn.Module): The neural network model used for brain stripping.
58
- device (torch.device): The device on which the neural network model is loaded (e.g., CPU or GPU).
 
 
 
 
59
 
60
  Returns:
61
- tuple: A tuple containing:
62
- - stripped (numpy.ndarray): The stripped and processed brain image.
63
- - (xd, yd, zd) (tuple of int): The shifts applied to center the brain image in the x, y, and z directions.
64
  """
65
- # Normalize the input voxel data
66
- voxel = normalize(voxel)
 
 
 
67
 
68
- # Prepare the voxel data in three anatomical planes: coronal, sagittal, and axial
69
  coronal = voxel.transpose(1, 2, 0)
70
  sagittal = voxel
71
  axial = voxel.transpose(2, 1, 0)
72
 
73
- # Apply the brain stripping model to each plane
74
- out_c = strip(coronal, ssnet, device).permute(2, 0, 1)
75
- out_s = strip(sagittal, ssnet, device)
76
- out_a = strip(axial, ssnet, device).permute(2, 1, 0)
77
 
78
- # Combine the results from the three planes and threshold the output
79
  out_e = ((out_c + out_s + out_a) / 3) > 0.5
80
  out_e = out_e.cpu().numpy()
81
 
82
- # Multiply the original data by the thresholded output to get the stripped brain image
83
- stripped = data.get_fdata().astype("float32") * out_e
84
 
85
- out_filename = reimburse_conform(output_dir, basename, "stripped", odata, data, out_e)
86
-
87
- # Calculate the center of mass of the stripped brain image
88
- x, y, z = map(int, ndimage.center_of_mass(out_e))
89
-
90
- # Calculate the shifts needed to center the brain image
91
- xd = 128 - x
92
- yd = 120 - y
93
- zd = 128 - z
94
 
95
- # Apply the shifts to center the brain image
96
- stripped = np.roll(stripped, (xd, yd, zd), axis=(0, 1, 2))
97
-
98
- # Crop the centered brain image
99
- stripped = stripped[32:-32, 16:-16, 32:-32]
100
 
101
- # Return the stripped brain image and the shifts applied
102
- return stripped, (xd, yd, zd), out_filename
 
7
 
8
def strip(voxel, model, device):
    """
    Run the skull-stripping network over a volume one slice at a time.

    Each prediction uses a 3-slice context window (previous, current, next
    slice) stacked as channels, so the volume is padded by one slice on each
    end before iteration.

    Args:
        voxel (numpy.ndarray): Input voxel data of shape (224, 224, 224) in a
            single anatomical orientation (e.g., coronal or sagittal view).
        model (torch.nn.Module): Trained skull-stripping network expecting a
            (1, 3, 224, 224) input.
        device (torch.device): Device used for inference (CPU, CUDA, or MPS).

    Returns:
        torch.Tensor: Tensor of shape (224, 224, 224) holding the per-voxel
        sigmoid outputs of the network (predicted brain mask).
    """
    model.eval()

    # One extra slice at each end keeps the 3-slice window valid at the borders.
    padded = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())

    with torch.inference_mode():
        result = torch.zeros(224, 224, 224)

        # Predict every slice from its 3-slice neighbourhood.
        for idx in range(224):
            window = np.stack([padded[idx], padded[idx + 1], padded[idx + 2]])
            batch = torch.tensor(window.reshape(1, 3, 224, 224)).to(device)
            result[idx] = torch.sigmoid(model(batch)).detach().cpu()

    # Return as a 3D mask tensor
    return result.reshape(224, 224, 224)
43
 
44
 
45
def stripping(output_dir, basename, voxel, odata, data, ssnet, shift, device):
    """
    Perform full 3D brain stripping using a deep learning model.

    Runs the skull-stripping network along three anatomical orientations
    (coronal, sagittal, axial), fuses the three probability maps by
    averaging, thresholds at 0.5 to obtain a binary brain mask, applies the
    mask to the original-intensity volume, and saves the mask restored to
    the original conformed geometry.

    Args:
        output_dir (str): Directory where intermediate and final results will be saved.
        basename (str): Base name of the current case (used for file naming).
        voxel (numpy.ndarray): Input 3D voxel data (preprocessed MRI image);
            assumed shape (224, 224, 224) — TODO confirm against caller.
        odata (nibabel.Nifti1Image): Original NIfTI image before preprocessing.
        data (nibabel.Nifti1Image): Preprocessed NIfTI image used for model input.
        ssnet (torch.nn.Module): Trained brain stripping network.
        shift (tuple[int, int, int]): The (x, y, z) offsets applied previously
            during cropping; reversed here to restore the mask position.
        device (torch.device): Device used for inference (CPU, CUDA, or MPS).

    Returns:
        tuple:
            - numpy.ndarray: The skull-stripped 3D brain volume (original
              intensities multiplied by the binary mask).
            - str: Filename returned by ``reimburse_conform`` for the saved
              stripped mask.

    NOTE(review): the call site in src/parcellation.py still invokes
    ``stripping(opt.o, basename, cropped, odata, data, ssnet, device)`` —
    i.e. without ``shift`` — and unpacks three return values; verify the
    caller was updated to match this signature and two-element return.
    """
    # Preserve original intensity data; the mask is applied to these values,
    # not to the normalized copy used for inference.
    original = voxel.copy()

    # Normalize the voxel intensities for model input
    voxel = normalize(voxel, "stripping")

    # Prepare data in three anatomical orientations
    coronal = voxel.transpose(1, 2, 0)
    sagittal = voxel
    axial = voxel.transpose(2, 1, 0)

    # Apply the model along each anatomical plane, permuting each result
    # back to the native orientation.
    out_c = strip(coronal, ssnet, device).permute(2, 0, 1)  # coronal → native orientation
    out_s = strip(sagittal, ssnet, device)  # sagittal
    out_a = strip(axial, ssnet, device).permute(2, 1, 0)  # axial → native orientation

    # Fuse predictions by averaging across the three planes and apply threshold
    out_e = ((out_c + out_s + out_a) / 3) > 0.5
    out_e = out_e.cpu().numpy()

    # Apply the binary mask to extract the brain region
    stripped = original * out_e

    # Restore the mask to the original conformed geometry:
    # pad back to full size and reverse the previously applied shift.
    # NOTE(review): the 16-voxel pad per side assumes the cropping step
    # removed a 16-voxel margin on every axis — confirm against cropping().
    out_e = np.pad(out_e, [(16, 16), (16, 16), (16, 16)], "constant", constant_values=0)
    out_e = np.roll(out_e, (-shift[0], -shift[1], -shift[2]), axis=(0, 1, 2))

    out_filename = reimburse_conform(output_dir, basename, "stripped", odata, data, out_e)

    return stripped, out_filename