FrancescoLR commited on
Commit
f2ca9b7
·
verified ·
1 Parent(s): 2d911b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -32
app.py CHANGED
@@ -96,11 +96,14 @@ def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
96
  # Define half the slice size
97
  half_size = slice_size // 2
98
 
99
- def extract_middle_slices(nifti_path, output_image_path, slice_size=180, center=None):
100
  """
101
- Extracts slices from a 3D NIfTI image. If a center is provided, it uses it;
102
- otherwise, computes the center of mass of non-zero voxels. Slices are taken
103
- along axial, coronal, and sagittal planes and saved as a single PNG.
 
 
 
104
  """
105
  # Load NIfTI image
106
  img = nib.load(nifti_path)
@@ -110,65 +113,80 @@ def extract_middle_slices(nifti_path, output_image_path, slice_size=180, center=
110
  # Resample the image to 1 mm isotropic
111
  resampled_data, _ = resample_to_isotropic(data, affine, target_spacing=1.0)
112
 
 
 
 
 
 
 
 
 
113
  # Compute or reuse the center of mass
114
  if center is None:
115
- com = center_of_mass(resampled_data > 0)
116
  center = np.round(com).astype(int)
117
 
118
  # Define half the slice size
119
  half_size = slice_size // 2
120
 
121
- # Safely extract and pad 2D slices
122
  def extract_2d_slice(data, center, axis):
123
  slices = [slice(None)] * 3
124
- slices[axis] = center[axis] # Fix the axis to extract a single slice
125
  extracted_slice = data[tuple(slices)]
126
 
127
- # Crop the 2D slice around the center in the remaining dimensions
128
  remaining_axes = [i for i in range(3) if i != axis]
129
  cropped_slice = extracted_slice[
130
  max(center[remaining_axes[0]] - half_size, 0):min(center[remaining_axes[0]] + half_size, extracted_slice.shape[0]),
131
  max(center[remaining_axes[1]] - half_size, 0):min(center[remaining_axes[1]] + half_size, extracted_slice.shape[1]),
132
  ]
133
 
134
- # Pad the slice to ensure 180x180 dimensions
135
  pad_height = slice_size - cropped_slice.shape[0]
136
  pad_width = slice_size - cropped_slice.shape[1]
137
- padded_slice = np.pad(cropped_slice,
138
- ((pad_height // 2, pad_height - pad_height // 2),
139
  (pad_width // 2, pad_width - pad_width // 2)),
140
  mode='constant', constant_values=0)
141
  return padded_slice
142
 
143
- # Extract slices in axial, coronal, and sagittal planes
144
- axial_slice = extract_2d_slice(resampled_data, center, axis=2) # Axial (z-axis)
145
- coronal_slice = extract_2d_slice(resampled_data, center, axis=1) # Coronal (y-axis)
146
- sagittal_slice = extract_2d_slice(resampled_data, center, axis=0) # Sagittal (x-axis)
147
 
148
- # Apply rotations to each slice
149
- axial_slice = np.rot90(axial_slice, k=-1) # 90 degrees clockwise
150
- coronal_slice = np.rot90(coronal_slice, k=1) # 90 degrees anticlockwise
151
- coronal_slice = np.rot90(coronal_slice, k=2) # Additional 180 degrees
152
- sagittal_slice = np.rot90(sagittal_slice, k=1) # 90 degrees anticlockwise
153
- sagittal_slice = np.rot90(sagittal_slice, k=2) # Additional 180 degrees
154
 
155
  # Create subplots
156
  fig, axes = plt.subplots(1, 3, figsize=(12, 4))
157
 
158
- # Plot each padded and rotated slice
159
- axes[0].imshow(axial_slice, cmap="gray", origin="lower")
160
- axes[0].axis("off")
161
-
162
- axes[1].imshow(coronal_slice, cmap="gray", origin="lower")
163
- axes[1].axis("off")
 
 
 
164
 
165
- axes[2].imshow(sagittal_slice, cmap="gray", origin="lower")
166
- axes[2].axis("off")
 
 
 
167
 
168
- # Save the figure
169
  plt.tight_layout()
170
  plt.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
171
  plt.close()
 
 
 
172
 
173
  # Function to run nnUNet inference
174
  @spaces.GPU(duration=90) # Decorate the function to allocate GPU for its execution
@@ -239,8 +257,8 @@ def run_nnunet_predict(nifti_file,hd_bet=False):
239
  # Extract and save 2D slices
240
  input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
241
  output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
242
- extract_middle_slices(input_path, input_slice_path, center=center)
243
- extract_middle_slices(new_output_file, output_slice_path, center=center)
244
 
245
  # Return paths for the Gradio interface
246
  return new_output_file, input_slice_path, output_slice_path
 
96
  # Define half the slice size
97
  half_size = slice_size // 2
98
 
99
+ def extract_middle_slices(nifti_path, output_image_path, slice_size=180, center=None, label_components=False):
100
  """
101
+ Extracts slices from a 3D NIfTI image.
102
+ If label_components=True, it assigns different labels (colors) to each connected component (26-connectivity)
103
+ and returns the labeled 3D mask.
104
+
105
+ Returns:
106
+ labeled_data (np.ndarray): The 3D array (either labeled or original).
107
  """
108
  # Load NIfTI image
109
  img = nib.load(nifti_path)
 
113
  # Resample the image to 1 mm isotropic
114
  resampled_data, _ = resample_to_isotropic(data, affine, target_spacing=1.0)
115
 
116
+ # Optionally label connected components
117
+ if label_components:
118
+ structure = generate_binary_structure(3, 3) # 3D, 26-connectivity
119
+ labeled_data, num_features = label(resampled_data > 0, structure=structure)
120
+ else:
121
+ labeled_data = resampled_data
122
+ num_features = None # Not needed if we're not labeling
123
+
124
  # Compute or reuse the center of mass
125
  if center is None:
126
+ com = center_of_mass(labeled_data > 0)
127
  center = np.round(com).astype(int)
128
 
129
  # Define half the slice size
130
  half_size = slice_size // 2
131
 
132
+ # Function to extract and pad slices
133
  def extract_2d_slice(data, center, axis):
134
  slices = [slice(None)] * 3
135
+ slices[axis] = center[axis]
136
  extracted_slice = data[tuple(slices)]
137
 
 
138
  remaining_axes = [i for i in range(3) if i != axis]
139
  cropped_slice = extracted_slice[
140
  max(center[remaining_axes[0]] - half_size, 0):min(center[remaining_axes[0]] + half_size, extracted_slice.shape[0]),
141
  max(center[remaining_axes[1]] - half_size, 0):min(center[remaining_axes[1]] + half_size, extracted_slice.shape[1]),
142
  ]
143
 
 
144
  pad_height = slice_size - cropped_slice.shape[0]
145
  pad_width = slice_size - cropped_slice.shape[1]
146
+ padded_slice = np.pad(cropped_slice,
147
+ ((pad_height // 2, pad_height - pad_height // 2),
148
  (pad_width // 2, pad_width - pad_width // 2)),
149
  mode='constant', constant_values=0)
150
  return padded_slice
151
 
152
+ # Extract slices
153
+ axial_slice = extract_2d_slice(labeled_data, center, axis=2)
154
+ coronal_slice = extract_2d_slice(labeled_data, center, axis=1)
155
+ sagittal_slice = extract_2d_slice(labeled_data, center, axis=0)
156
 
157
+ # Apply rotations
158
+ axial_slice = np.rot90(axial_slice, k=-1)
159
+ coronal_slice = np.rot90(coronal_slice, k=1)
160
+ coronal_slice = np.rot90(coronal_slice, k=2)
161
+ sagittal_slice = np.rot90(sagittal_slice, k=1)
162
+ sagittal_slice = np.rot90(sagittal_slice, k=2)
163
 
164
  # Create subplots
165
  fig, axes = plt.subplots(1, 3, figsize=(12, 4))
166
 
167
+ # Choose colormap
168
+ if label_components:
169
+ cmap = plt.cm.nipy_spectral # Colorful
170
+ vmin = 0
171
+ vmax = num_features
172
+ else:
173
+ cmap = "gray" # Normal
174
+ vmin = None
175
+ vmax = None
176
 
177
+ # Plot slices
178
+ for idx, slice_data in enumerate([axial_slice, coronal_slice, sagittal_slice]):
179
+ ax = axes[idx]
180
+ im = ax.imshow(slice_data, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax)
181
+ ax.axis("off")
182
 
183
+ # Save figure
184
  plt.tight_layout()
185
  plt.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
186
  plt.close()
187
+
188
+ # Return the labeled mask
189
+ return labeled_data
190
 
191
  # Function to run nnUNet inference
192
  @spaces.GPU(duration=90) # Decorate the function to allocate GPU for its execution
 
257
  # Extract and save 2D slices
258
  input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
259
  output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
260
+ image = extract_middle_slices(input_path, input_slice_path, center=center)
261
+ labeled_mask = extract_middle_slices(new_output_file, output_slice_path, center=center, label_components=True)
262
 
263
  # Return paths for the Gradio interface
264
  return new_output_file, input_slice_path, output_slice_path