hanjang committed (verified)
Commit fbbcdcd · 1 Parent(s): 9588088

Upload /Interactive_MEN_RT_predictor.py with huggingface_hub

Files changed (1): Interactive_MEN_RT_predictor.py (+1012, -0)
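
For reference, an upload like this can be reproduced with the huggingface_hub client. A minimal sketch (the repo id and token handling below are assumptions for illustration, not taken from this commit):

    from huggingface_hub import HfApi

    api = HfApi()  # assumes you are already logged in, e.g. via `huggingface-cli login`
    api.upload_file(
        path_or_fileobj="Interactive_MEN_RT_predictor.py",  # local file
        path_in_repo="Interactive_MEN_RT_predictor.py",     # destination path in the repo
        repo_id="hanjang/<repo-name>",                      # placeholder repo id
        commit_message="Upload /Interactive_MEN_RT_predictor.py with huggingface_hub",
    )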
Interactive_MEN_RT_predictor.py ADDED
@@ -0,0 +1,1012 @@
import os
import numpy as np
import torch
from typing import Union, List, Tuple, Optional, Dict
from concurrent.futures import ThreadPoolExecutor
from time import time
import sys
import importlib
import math

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from batchgenerators.utilities.file_and_folder_operations import load_json, join, subdirs
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice, crop_and_pad_nd
from torch import nn
from torch._dynamo import OptimizedModule  # needed by manual_initialization's torch.compile handling
import torch.nn.functional as F
from torch.nn.functional import interpolate
import nnunetv2

import nnInteractive
from nnInteractive.interaction.point import PointInteraction_stub
from nnInteractive.utils.bboxes import generate_bounding_boxes
from nnInteractive.utils.crop import crop_and_pad_into_buffer, paste_tensor, pad_cropped, crop_to_valid
from nnInteractive.utils.erosion_dilation import iterative_3x3_same_padding_pool3d
from nnInteractive.utils.rounding import round_to_nearest_odd


class InteractiveMENRTPredictor:
    """
    Interactive MEN-RT predictor: interactive segmentation driven by point, bounding box,
    scribble and lasso interactions.
    """

    def __init__(self,
                 device: torch.device = torch.device('cuda'),
                 use_torch_compile: bool = False,
                 verbose: bool = False,
                 torch_n_threads: int = 8,
                 do_autozoom: bool = True,
                 use_pinned_memory: bool = True
                 ):
        """
        Only intended to work with nnInteractiveTrainerV2 and its derivatives.
        """
        # set as part of initialization
        assert use_torch_compile is False, ('This implementation places the preprocessed image and the interactions '
                                            'into pinned memory for speed reasons. This is incompatible with '
                                            'torch.compile because of inconsistent strides in the memory layout. '
                                            'Note to self: .contiguous() on GPU could be a solution. Unclear whether '
                                            'that will yield a benefit though.')
        self.network = None
        self.label_manager = None
        self.dataset_json = None
        self.trainer_name = None
        self.configuration_manager = None
        self.plans_manager = None
        self.use_pinned_memory = use_pinned_memory
        self.device = device
        self.use_torch_compile = use_torch_compile

        # Interactive session state
        self.interactions: torch.Tensor = None
        self.preprocessed_image: torch.Tensor = None
        self.preprocessed_props = None
        self.target_buffer: Union[np.ndarray, torch.Tensor] = None

        self.pad_mode_data = self.preferred_scribble_thickness = self.point_interaction = None
        self.verbose = verbose

        self.do_autozoom: bool = do_autozoom
        torch.set_num_threads(min(torch_n_threads, os.cpu_count()))

        self.original_image_shape = None

        self.new_interaction_zoom_out_factors: List[float] = []
        self.new_interaction_centers = []
        self.has_positive_bbox = False

        # Create a thread pool executor for background tasks. It only handles image preprocessing
        # and interaction-memory initialization, so there is no need for more than 2 workers.
        self.executor = ThreadPoolExecutor(max_workers=2)
        self.preprocess_future = None
        self.interactions_future = None

    def set_image(self, image: np.ndarray, image_properties: dict = None):
        """
        The image must be 4D to satisfy nnU-Net's needs: [c, x, y, z].
        Offloads the processing to a background thread.
        """
        if image_properties is None:
            image_properties = {}
        self._reset_session()
        assert image.ndim == 4, f'expected a 4d image as input, got {image.ndim}d. Shape {image.shape}'
        if self.verbose:
            print(f'Initialize with raw image shape {image.shape}')

        # Offload all image preprocessing to a background thread.
        self.preprocess_future = self.executor.submit(self._background_set_image, image, image_properties)
        self.original_image_shape = image.shape

    def _finish_preprocessing_and_initialize_interactions(self):
        """
        Block until both the image preprocessing and the interactions tensor initialization are finished.
        """
        if self.preprocess_future is not None:
            # Wait for image preprocessing to complete.
            self.preprocess_future.result()
            del self.preprocess_future
            self.preprocess_future = None

    def set_target_buffer(self, target_buffer: Union[np.ndarray, torch.Tensor]):
        """
        Must be a 3d numpy array or torch.Tensor.
        """
        self.target_buffer = target_buffer

    def set_do_autozoom(self, do_propagation: bool, max_num_patches: Optional[int] = None):
        self.do_autozoom = do_propagation

    def _reset_session(self):
        self.interactions_future = None
        self.preprocess_future = None

        del self.preprocessed_image
        del self.target_buffer
        del self.interactions
        del self.preprocessed_props
        self.preprocessed_image = None
        self.target_buffer = None
        self.interactions = None
        self.preprocessed_props = None
        empty_cache(self.device)
        self.original_image_shape = None
        self.has_positive_bbox = False

    def _initialize_interactions(self, image_torch: torch.Tensor):
        if self.verbose:
            print(f'Initialize interactions. Pinned: {self.use_pinned_memory}')
        # Create the interaction tensor based on the target shape.
        self.interactions = torch.zeros(
            (7, *image_torch.shape[1:]),
            device='cpu',
            dtype=torch.float16,
            pin_memory=(self.device.type == 'cuda' and self.use_pinned_memory)
        )
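        # Channel layout of self.interactions, as used by the add_*_interaction methods below
        # (index 0 == -7, ..., index 6 == -1):
        #   0 / -7: previous / initial segmentation
        #   1 / -6: bbox & lasso, positive     2 / -5: bbox & lasso, negative
        #   3 / -4: point, positive            4 / -3: point, negative
        #   5 / -2: scribble, positive         6 / -1: scribble, negative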

    def _background_set_image(self, image: np.ndarray, image_properties: dict):
        """Background preprocessing of the image."""
        # Convert to torch tensor
        image_torch = torch.clone(torch.from_numpy(image))

        # Crop to the nonzero region
        if self.verbose:
            print('Cropping input image to nonzero region')
        nonzero_idx = torch.where(image_torch != 0)
        bbox = [[i.min().item(), i.max().item() + 1] for i in nonzero_idx]

        # Ensure the bbox is at least as large as the patch size
        if hasattr(self, 'configuration_manager') and self.configuration_manager is not None:
            patch_size = self.configuration_manager.patch_size
            for dim in range(1, len(bbox)):
                bbox_size = bbox[dim][1] - bbox[dim][0]
                if bbox_size < patch_size[dim - 1]:
                    # Center the bbox and extend it to patch_size
                    center = (bbox[dim][0] + bbox[dim][1]) // 2
                    bbox[dim][0] = max(0, center - patch_size[dim - 1] // 2)
                    bbox[dim][1] = min(image_torch.shape[dim], center + patch_size[dim - 1] // 2 + patch_size[dim - 1] % 2)

        del nonzero_idx
        slicer = bounding_box_to_slice(bbox)
        image_torch = image_torch[slicer].float()

        if self.verbose:
            print(f'Cropped image shape: {image_torch.shape}')

        # Initialize the interactions tensor
        self._initialize_interactions(image_torch)

        # Normalize the image (zero mean, unit variance)
        if self.verbose:
            print('Normalizing cropped image')
        image_torch -= image_torch.mean()
        image_torch /= image_torch.std()

        self.preprocessed_image = image_torch
        if self.use_pinned_memory and self.device.type == 'cuda':
            if self.verbose:
                print('Pin memory: image')
            self.preprocessed_image = self.preprocessed_image.pin_memory()

        self.preprocessed_props = {'bbox_used_for_cropping': bbox[1:]}

    def reset_interactions(self):
        """
        Use this to reset all interactions and start from scratch for the current image. This includes the initial
        segmentation!
        """
        if self.interactions is not None:
            self.interactions.fill_(0)

        if self.target_buffer is not None:
            if isinstance(self.target_buffer, np.ndarray):
                self.target_buffer.fill(0)
            elif isinstance(self.target_buffer, torch.Tensor):
                self.target_buffer.zero_()
        empty_cache(self.device)
        self.has_positive_bbox = False

    def add_bbox_interaction(self, bbox_coords, include_interaction: bool, run_prediction: bool = True) -> np.ndarray:
        if include_interaction:
            self.has_positive_bbox = True

        self._finish_preprocessing_and_initialize_interactions()

        lbs_transformed = [round(i) for i in transform_coordinates_noresampling([i[0] for i in bbox_coords],
                                                                                self.preprocessed_props['bbox_used_for_cropping'])]
        ubs_transformed = [round(i) for i in transform_coordinates_noresampling([i[1] for i in bbox_coords],
                                                                                self.preprocessed_props['bbox_used_for_cropping'])]
        transformed_bbox_coordinates = [[i, j] for i, j in zip(lbs_transformed, ubs_transformed)]

        if self.verbose:
            print(f'Added bounding box coordinates.\n'
                  f'Raw: {bbox_coords}\n'
                  f'Transformed: {transformed_bbox_coordinates}\n'
                  f"Crop Bbox: {self.preprocessed_props['bbox_used_for_cropping']}")

        # Prevent collapsed bounding boxes and clip to the image shape
        image_shape = self.preprocessed_image.shape  # (c, x, y, z)

        for dim in range(len(transformed_bbox_coordinates)):
            transformed_start, transformed_end = transformed_bbox_coordinates[dim]

            # Clip to the image boundaries (dim + 1 skips the channel dimension)
            transformed_start = max(0, transformed_start)
            transformed_end = min(image_shape[dim + 1], transformed_end)

            # Ensure the bounding box does not collapse to a single point
            if transformed_end <= transformed_start:
                if transformed_start == 0:
                    transformed_end = min(1, image_shape[dim + 1])
                else:
                    transformed_start = max(transformed_start - 1, 0)

            transformed_bbox_coordinates[dim] = [transformed_start, transformed_end]

        if self.verbose:
            print(f'Bbox coordinates after clipping to the image boundaries and preventing dimension collapse:\n'
                  f'Bbox: {transformed_bbox_coordinates}\n'
                  f'Internal image shape: {self.preprocessed_image.shape}')

        self._add_patch_for_bbox_interaction(transformed_bbox_coordinates)

        # decay old interactions
        self.interactions[-6:-4] *= self.interaction_decay

        # place the bbox
        slicer = tuple([slice(*i) for i in transformed_bbox_coordinates])
        channel = -6 if include_interaction else -5
        self.interactions[(channel, *slicer)] = 1

        # forward pass
        if run_prediction:
            self._predict()

    def add_point_interaction(self, coordinates: Tuple[int, ...], include_interaction: bool, run_prediction: bool = True):
        self._finish_preprocessing_and_initialize_interactions()

        transformed_coordinates = [round(i) for i in transform_coordinates_noresampling(coordinates,
                                                                                        self.preprocessed_props['bbox_used_for_cropping'])]

        self._add_patch_for_point_interaction(transformed_coordinates)

        # decay old interactions
        self.interactions[-4:-2] *= self.interaction_decay

        interaction_channel = -4 if include_interaction else -3
        self.interactions[interaction_channel] = self.point_interaction.place_point(
            transformed_coordinates, self.interactions[interaction_channel])
        if run_prediction:
            self._predict()

    def add_scribble_interaction(self, scribble_image: np.ndarray, include_interaction: bool, run_prediction: bool = True):
        assert all([i == j for i, j in zip(self.original_image_shape[1:], scribble_image.shape)]), \
            f'Given scribble image must match input image shape. Input image was: {self.original_image_shape[1:]}, given: {scribble_image.shape}'
        self._finish_preprocessing_and_initialize_interactions()

        scribble_image = torch.from_numpy(scribble_image)

        # crop (as in preprocessing)
        scribble_image = crop_and_pad_nd(scribble_image, self.preprocessed_props['bbox_used_for_cropping'])

        self._add_patch_for_scribble_interaction(scribble_image)

        # decay old interactions
        self.interactions[-2:] *= self.interaction_decay

        interaction_channel = -2 if include_interaction else -1
        torch.maximum(self.interactions[interaction_channel], scribble_image.to(self.interactions.device),
                      out=self.interactions[interaction_channel])
        del scribble_image
        empty_cache(self.device)
        if run_prediction:
            self._predict()

    def add_lasso_interaction(self, lasso_image: np.ndarray, include_interaction: bool, run_prediction: bool = True):
        assert all([i == j for i, j in zip(self.original_image_shape[1:], lasso_image.shape)]), \
            f'Given lasso image must match input image shape. Input image was: {self.original_image_shape[1:]}, given: {lasso_image.shape}'
        self._finish_preprocessing_and_initialize_interactions()

        lasso_image = torch.from_numpy(lasso_image)

        # crop (as in preprocessing)
        lasso_image = crop_and_pad_nd(lasso_image, self.preprocessed_props['bbox_used_for_cropping'])

        self._add_patch_for_lasso_interaction(lasso_image)

        # decay old interactions
        self.interactions[-6:-4] *= self.interaction_decay

        # the lasso is written into the bbox channel
        interaction_channel = -6 if include_interaction else -5
        torch.maximum(self.interactions[interaction_channel], lasso_image.to(self.interactions.device),
                      out=self.interactions[interaction_channel])
        del lasso_image
        empty_cache(self.device)
        if run_prediction:
            self._predict()

    def add_initial_seg_interaction(self, initial_seg: np.ndarray, run_prediction: bool = False):
        """
        WARNING: THIS WILL RESET ALL INTERACTIONS!
        """
        assert all([i == j for i, j in zip(self.original_image_shape[1:], initial_seg.shape)]), \
            f'Given initial seg must match input image shape. Input image was: {self.original_image_shape[1:]}, given: {initial_seg.shape}'

        self._finish_preprocessing_and_initialize_interactions()

        self.reset_interactions()

        if isinstance(self.target_buffer, np.ndarray):
            self.target_buffer[:] = initial_seg

        initial_seg = torch.from_numpy(initial_seg)

        if isinstance(self.target_buffer, torch.Tensor):
            self.target_buffer[:] = initial_seg

        # crop (as in preprocessing)
        initial_seg = crop_and_pad_nd(initial_seg, self.preprocessed_props['bbox_used_for_cropping'])

        # the initial seg is written into the initial seg buffer
        interaction_channel = -7
        self.interactions[interaction_channel] = initial_seg
        empty_cache(self.device)
        if run_prediction:
            self._add_patch_for_initial_seg_interaction(initial_seg)
            del initial_seg
            self._predict()
        else:
            del initial_seg

    @torch.inference_mode()
    def _predict(self):
        """
        Run a tiled prediction around the current interactions:
        1. Find the bounding box containing all interactions and grow it to at least one patch.
        2. Tile that region with overlapping patches and run the network on image + interaction channels.
        3. Average the softmax outputs, threshold, and write the result into interactions[0] and the target buffer.
        """
        assert self.pad_mode_data == 'constant', 'pad modes other than constant are not implemented here'

        start_predict = time()
        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            # Find the region containing all interactions (combine all interaction channels)
            interaction_mask = torch.any(self.interactions[1:] > 0, dim=0)
            if not torch.any(interaction_mask):
                print('No interactions found, skipping prediction')
                return

            # Get the bounding box of the interaction region (as Python ints)
            nonzero_indices = torch.nonzero(interaction_mask)
            min_coords = torch.min(nonzero_indices, dim=0)[0].tolist()
            max_coords = torch.max(nonzero_indices, dim=0)[0].tolist()

            # Initialize the bbox with the interaction region
            patch_size = self.configuration_manager.patch_size
            half_patch_size = [p // 2 for p in patch_size]
            image_shape = self.preprocessed_image.shape[1:]

            # For each dimension, calculate the bbox ensuring:
            # 1. bbox start >= 0
            # 2. bbox end <= image_shape
            # 3. bbox size >= patch_size
            bbox = []
            for i, (min_c, max_c, h, p) in enumerate(zip(min_coords, max_coords, half_patch_size, patch_size)):
                start = max(0, min(image_shape[i] - p, (min_c + max_c) // 2 - p // 2))
                end = min(image_shape[i], start + p)
                bbox.append([start, end])

            # Calculate the number of patches needed per axis
            overlap = [64, 64, 64]  # [O_z, O_y, O_x]
            num_patches = [
                1 if (b1 - b0) <= P
                else math.ceil(((b1 - b0) - P) / (P - O)) + 1
                for (b0, b1), P, O in zip(bbox, patch_size, overlap)
            ]
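            # Worked example for the patch-count formula (values illustrative): for an axis of
            # extent b1 - b0 = 300 with patch size P = 192 and overlap O = 64, the stride is
            # P - O = 128, so num_patches = ceil((300 - 192) / 128) + 1 = 2.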

            # Initialize prediction tensors for soft merging
            final_pred_soft = torch.zeros((2, *self.preprocessed_image.shape[1:]), dtype=torch.float32, device='cpu')
            prediction_count = torch.zeros(self.preprocessed_image.shape[1:], dtype=torch.float32, device='cpu')

            # Process each patch
            for x in range(num_patches[0]):
                for y in range(num_patches[1]):
                    for z in range(num_patches[2]):
                        # Calculate the patch boundaries. The stride is patch_size - overlap so
                        # that neighboring patches overlap as computed above.
                        step_index = [x, y, z]
                        start_coords = [bbox[i][0] + step_index[i] * (p - o) for i, p, o in zip([0, 1, 2], patch_size, overlap)]
                        end_coords = [min(bbox[i][1], start_coords[i] + p) for i, p in zip([0, 1, 2], patch_size)]

                        for i in range(len(patch_size)):
                            if end_coords[i] - start_coords[i] < patch_size[i]:
                                if end_coords[i] >= bbox[i][1]:
                                    start_coords[i] = bbox[i][1] - patch_size[i]

                        # Extract the image patch
                        image_patch = self.preprocessed_image[:, start_coords[0]:end_coords[0],
                                                              start_coords[1]:end_coords[1],
                                                              start_coords[2]:end_coords[2]]

                        # Extract the interaction patch
                        interaction_patch = self.interactions[:, start_coords[0]:end_coords[0],
                                                              start_coords[1]:end_coords[1],
                                                              start_coords[2]:end_coords[2]]

                        # Pad to patch_size if necessary
                        if not all([e - s == p for s, e, p in zip(start_coords, end_coords, patch_size)]):
                            pad_size = [(0, p - (e - s)) for s, e, p in zip(start_coords, end_coords, patch_size)]
                            image_patch = F.pad(image_patch, [item for sublist in reversed(pad_size) for item in sublist])
                            interaction_patch = F.pad(interaction_patch, [item for sublist in reversed(pad_size) for item in sublist])

                        # Move to device
                        image_patch = image_patch.to(self.device, non_blocking=self.device.type == 'cuda')
                        interaction_patch = interaction_patch.to(self.device, non_blocking=self.device.type == 'cuda')

                        # Concatenate the image and interaction channels
                        input_for_predict = torch.cat((image_patch, interaction_patch))

                        # Make the prediction
                        pred_raw = self.network(input_for_predict[None])[0]
                        pred_prob = F.softmax(pred_raw, dim=0)

                        del input_for_predict, pred_raw, image_patch, interaction_patch

                        # Resize the prediction if needed
                        if not all([e - s == p for s, e, p in zip(start_coords, end_coords, patch_size)]):
                            pred_prob = interpolate(pred_prob[None],
                                                    [e - s for s, e in zip(start_coords, end_coords)],
                                                    mode='trilinear')[0]

                        # Add to the accumulated predictions
                        pred_prob = pred_prob.cpu()
                        final_pred_soft[:, start_coords[0]:end_coords[0],
                                        start_coords[1]:end_coords[1],
                                        start_coords[2]:end_coords[2]] += pred_prob
                        prediction_count[start_coords[0]:end_coords[0],
                                         start_coords[1]:end_coords[1],
                                         start_coords[2]:end_coords[2]] += 1

                        del pred_prob
                        empty_cache(self.device)

            # Average the predictions and convert to a binary mask
            final_pred_soft = final_pred_soft / prediction_count.clamp(min=1)
            # final_pred_soft = self._iterative_adjust_prediction(final_pred_soft, self.interactions)
            final_pred = (final_pred_soft[1] >= 0.5).to(torch.uint8)

            # Update the interactions and target buffer
            self.interactions[0][:] = final_pred
            paste_tensor(self.target_buffer, final_pred, self.preprocessed_props['bbox_used_for_cropping'])

        print(f'Done. Total time {round(time() - start_predict, 3)}s')

        self.new_interaction_centers = []
        empty_cache(self.device)

    @torch.inference_mode()
    def _predict_without_interaction(self):
        """
        Simplified version of _predict that tiles the entire preprocessed image: no interaction-centered
        cropping and no zooming or refinement. All interaction channels (previous segmentation, bbox,
        point, scribble) are still fed to the network.
        """
        assert self.pad_mode_data == 'constant', 'pad modes other than constant are not implemented here'

        start_predict = time()
        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            # Get the image dimensions (without the channel dimension)
            image_shape = self.preprocessed_image.shape[1:]

            # The bbox spans the whole image
            patch_size = self.configuration_manager.patch_size
            bbox = [[0, i] for i in image_shape]

            # Calculate the number of patches needed per axis
            overlap = [64, 64, 64]  # [O_z, O_y, O_x]
            num_patches = [
                1 if (b1 - b0) <= P
                else math.ceil(((b1 - b0) - P) / (P - O)) + 1
                for (b0, b1), P, O in zip(bbox, patch_size, overlap)
            ]

            # Initialize prediction tensors for soft merging (2 channels for binary segmentation)
            pred_soft = torch.zeros((2, *image_shape), dtype=torch.float32, device='cpu')
            pred_count = torch.zeros(image_shape, dtype=torch.float32, device='cpu')

            # Process each patch
            for x in range(num_patches[0]):
                for y in range(num_patches[1]):
                    for z in range(num_patches[2]):
                        # Calculate the patch boundaries (stride = patch_size - overlap, as in _predict)
                        step_index = [x, y, z]
                        start_coords = [bbox[i][0] + step_index[i] * (p - o) for i, p, o in zip([0, 1, 2], patch_size, overlap)]
                        end_coords = [min(bbox[i][1], start_coords[i] + p) for i, p in zip([0, 1, 2], patch_size)]

                        for i in range(len(patch_size)):
                            if end_coords[i] - start_coords[i] < patch_size[i]:
                                if end_coords[i] >= bbox[i][1]:
                                    start_coords[i] = bbox[i][1] - patch_size[i]

                        # Extract the image patch
                        image_patch = self.preprocessed_image[:, start_coords[0]:end_coords[0],
                                                              start_coords[1]:end_coords[1],
                                                              start_coords[2]:end_coords[2]]

                        # Extract the interaction patch
                        interaction_patch = self.interactions[:, start_coords[0]:end_coords[0],
                                                              start_coords[1]:end_coords[1],
                                                              start_coords[2]:end_coords[2]]

                        # Pad if necessary
                        if not all([e - s == p for s, e, p in zip(start_coords, end_coords, patch_size)]):
                            pad_size = [(0, p - (e - s)) for s, e, p in zip(start_coords, end_coords, patch_size)]
                            image_patch = F.pad(image_patch, [item for sublist in reversed(pad_size) for item in sublist])
                            interaction_patch = F.pad(interaction_patch, [item for sublist in reversed(pad_size) for item in sublist])

                        # Move to device
                        image_patch = image_patch.to(self.device, non_blocking=self.device.type == 'cuda')
                        interaction_patch = interaction_patch.to(self.device, non_blocking=self.device.type == 'cuda')

                        # Concatenate the image and interaction channels
                        input_for_predict = torch.cat((image_patch, interaction_patch))

                        # Make the prediction and get soft probabilities
                        patch_pred = self.network(input_for_predict[None])[0]
                        patch_prob = F.softmax(patch_pred, dim=0)

                        # Resize the prediction to the original patch size if necessary
                        if not all([e - s == p for s, e, p in zip(start_coords, end_coords, patch_size)]):
                            patch_prob = interpolate(patch_prob[None],
                                                     [e - s for s, e in zip(start_coords, end_coords)],
                                                     mode='trilinear')[0]

                        # Add to the accumulated predictions
                        pred_soft[:, start_coords[0]:end_coords[0],
                                  start_coords[1]:end_coords[1],
                                  start_coords[2]:end_coords[2]] += patch_prob.cpu()
                        pred_count[start_coords[0]:end_coords[0],
                                   start_coords[1]:end_coords[1],
                                   start_coords[2]:end_coords[2]] += 1

                        del image_patch, interaction_patch, input_for_predict, patch_pred, patch_prob
                        empty_cache(self.device)

            # Average the predictions and convert to a binary mask
            pred_soft = pred_soft / pred_count.clamp(min=1)
            pred = (pred_soft[1] >= 0.5).to(torch.uint8)

            # Update the interactions and target buffer
            self.interactions[0][:] = pred
            paste_tensor(self.target_buffer, pred, self.preprocessed_props['bbox_used_for_cropping'])

        print(f'Done. Total time {round(time() - start_predict, 3)}s')
        empty_cache(self.device)

    def _add_patch_for_point_interaction(self, coordinates):
        self.new_interaction_centers.append(coordinates)
        print(f'Added new point interaction: center {coordinates}')

    def _add_patch_for_bbox_interaction(self, bbox):
        bbox_center = [round((i[0] + i[1]) / 2) for i in bbox]
        bbox_size = [i[1] - i[0] for i in bbox]
        # we want to see some context, so the crop we see for the initial prediction should be patch_size / 3 larger
        requested_size = [i + j // 3 for i, j in zip(bbox_size, self.configuration_manager.patch_size)]
        self.new_interaction_zoom_out_factors.append(max(1, max([i / j for i, j in zip(requested_size, self.configuration_manager.patch_size)])))
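        # Worked example (values illustrative): with a patch size of 192 along an axis and a bbox
        # extent of 250, requested = 250 + 192 // 3 = 314, so the zoom-out factor is
        # max(1, 314 / 192) ≈ 1.64.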
        self.new_interaction_centers.append(bbox_center)
        print(f'Added new bbox interaction: center {bbox_center}')

    def _add_patch_for_scribble_interaction(self, scribble_image):
        return self._generic_add_patch_from_image(scribble_image)

    def _add_patch_for_lasso_interaction(self, lasso_image):
        return self._generic_add_patch_from_image(lasso_image)

    def _add_patch_for_initial_seg_interaction(self, initial_seg):
        return self._generic_add_patch_from_image(initial_seg)

    def _generic_add_patch_from_image(self, image: torch.Tensor):
        if not torch.any(image):
            print('Received an empty image prompt. Cannot add patches for prediction')
            return
        nonzero_indices = torch.nonzero(image, as_tuple=False)
        mn = torch.min(nonzero_indices, dim=0)[0]
        mx = torch.max(nonzero_indices, dim=0)[0]
        roi = [[i.item(), x.item() + 1] for i, x in zip(mn, mx)]
        roi_center = [round((i[0] + i[1]) / 2) for i in roi]
        roi_size = [i[1] - i[0] for i in roi]
        requested_size = [i + j // 3 for i, j in zip(roi_size, self.configuration_manager.patch_size)]
        self.new_interaction_zoom_out_factors.append(max(1, max([i / j for i, j in zip(requested_size, self.configuration_manager.patch_size)])))
        self.new_interaction_centers.append(roi_center)
        print(f'Added new image interaction: scale {self.new_interaction_zoom_out_factors[-1]}, center {roi_center}')

    def initialize_from_trained_model_folder(self,
                                             model_training_output_dir: str,
                                             use_fold: Union[int, str] = None,
                                             checkpoint_name: str = 'checkpoint_final.pth'):
        """
        Initialize the predictor from a trained model folder.
        """
        # Determine the fold folder
        if use_fold is not None:
            use_fold = int(use_fold) if use_fold != 'all' else use_fold
            fold_folder = f'fold_{use_fold}'
        else:
            fldrs = subdirs(model_training_output_dir, prefix='fold_', join=False)
            assert len(fldrs) == 1, f'Attempted to infer the fold but there is != 1 fold_ folder: {fldrs}'
            fold_folder = fldrs[0]

        # Load the trainer-specific settings
        expected_json_file = join(model_training_output_dir, fold_folder, 'inference_session_class.json')
        json_content = load_json(expected_json_file)
        if isinstance(json_content, str):
            # Old convention where this file only specified the inference class. Fall back to defaults.
            point_interaction_radius = 4
            point_interaction_use_etd = True
            self.preferred_scribble_thickness = [2, 2, 2]
            self.point_interaction = PointInteraction_stub(
                point_interaction_radius,
                point_interaction_use_etd)
            self.pad_mode_data = "constant"
            self.interaction_decay = 0.9
        else:
            point_interaction_radius = json_content['point_radius']
            self.preferred_scribble_thickness = json_content['preferred_scribble_thickness']
            if not isinstance(self.preferred_scribble_thickness, (tuple, list)):
                self.preferred_scribble_thickness = [self.preferred_scribble_thickness] * 3
            self.interaction_decay = json_content.get('interaction_decay', 0.9)
            point_interaction_use_etd = json_content.get('use_distance_transform', True)
            self.point_interaction = PointInteraction_stub(point_interaction_radius, point_interaction_use_etd)
            # padding mode for the data. See nnInteractiveTrainerV2_nodelete_reflectpad
            self.pad_mode_data = json_content.get('pad_mode_image', "constant")
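            # Example of the expected new-convention file content (values illustrative, matching
            # the old-convention defaults above):
            # {"point_radius": 4, "preferred_scribble_thickness": [2, 2, 2],
            #  "interaction_decay": 0.9, "use_distance_transform": true,
            #  "pad_mode_image": "constant"}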

        # Load the dataset and plans
        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
        plans = load_json(join(model_training_output_dir, 'plans.json'))
        plans_manager = PlansManager(plans)

        # Load the checkpoint
        checkpoint = torch.load(join(model_training_output_dir, fold_folder, checkpoint_name),
                                map_location=self.device, weights_only=False)
        trainer_name = checkpoint['trainer_name']
        configuration_name = checkpoint['init_args']['configuration']
        parameters = checkpoint['network_weights']

        # Get the configuration
        configuration_manager = plans_manager.get_configuration(configuration_name)

        # Restore the network
        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
        network = nnUNetTrainer.build_network_architecture(
            configuration_manager.network_arch_class_name,
            configuration_manager.network_arch_init_kwargs,
            configuration_manager.network_arch_init_kwargs_req_import,
            num_input_channels,
            plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
            enable_deep_supervision=False
        ).to(self.device)
        network.load_state_dict(parameters)

        # Store the necessary information
        self.plans_manager = plans_manager
        self.configuration_manager = configuration_manager
        self.network = network
        self.dataset_json = dataset_json
        self.trainer_name = trainer_name
        self.label_manager = plans_manager.get_label_manager(dataset_json)

        if self.use_torch_compile:
            print('Using torch.compile')
            self.network = torch.compile(self.network)

        if self.verbose:
            print(f"Loaded interactive config: point_radius={self.point_interaction.point_radius}, "
                  f"scribble_thickness={self.preferred_scribble_thickness}, "
                  f"interaction_decay={self.interaction_decay}")

    def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
                              configuration_manager: ConfigurationManager,
                              dataset_json: dict, trainer_name: str):
        """
        Used by the nnUNetTrainer to initialize nnUNetPredictor for the final validation.
        """
        self.plans_manager = plans_manager
        self.configuration_manager = configuration_manager
        self.network = network
        self.dataset_json = dataset_json
        self.trainer_name = trainer_name
        self.label_manager = plans_manager.get_label_manager(dataset_json)

        if self.use_torch_compile and not isinstance(self.network, OptimizedModule):
            print('Using torch.compile')
            self.network = torch.compile(self.network)

        if not self.use_torch_compile and isinstance(self.network, OptimizedModule):
            self.network = self.network._orig_mod

        self.network = self.network.to(self.device)

    @torch.inference_mode()
    def _predict_autozoom(self):
        """
        Run a prediction around each newly added interaction, starting at the interaction's zoom-out
        factor and zooming out further (up to 4x) as long as the prediction keeps changing at the
        patch borders.
        """
        assert self.pad_mode_data == 'constant', 'pad modes other than constant are not implemented here'
        assert len(self.new_interaction_centers) == len(self.new_interaction_zoom_out_factors)
        if len(self.new_interaction_centers) > 1:
            print('It seems like more than one interaction was added since the last prediction. This is not '
                  'recommended and may cause unexpected behavior or inefficient predictions')

        start_predict = time()
        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            for prediction_center, initial_zoom_out_factor in zip(self.new_interaction_centers, self.new_interaction_zoom_out_factors):
                # Store the previous prediction for comparison
                previous_prediction = torch.clone(self.interactions[0])

                if not self.do_autozoom:
                    initial_zoom_out_factor = 1

                initial_zoom_out_factor = min(initial_zoom_out_factor, 4)
                zoom_out_factor = initial_zoom_out_factor
                max_zoom_out_factor = initial_zoom_out_factor

                start_autozoom = time()
                while zoom_out_factor is not None and zoom_out_factor <= 4:
                    print('Performing prediction at zoom out factor', zoom_out_factor)
                    max_zoom_out_factor = max(max_zoom_out_factor, zoom_out_factor)

                    # Calculate the scaled patch size and bounding box
                    scaled_patch_size = [round(i * zoom_out_factor) for i in self.configuration_manager.patch_size]
                    scaled_bbox = [[int(c - p // 2), int(c + p // 2 + p % 2)] for c, p in zip(prediction_center, scaled_patch_size)]

                    # Crop and prepare the input
                    crop_img, pad = crop_to_valid(self.preprocessed_image, scaled_bbox)
                    crop_img = crop_img.to(self.device, non_blocking=self.device.type == 'cuda')
                    crop_interactions, pad_interaction = crop_to_valid(self.interactions, scaled_bbox)

                    # Resize if needed
                    if not all([i == j for i, j in zip(self.configuration_manager.patch_size, scaled_patch_size)]):
                        crop_interactions_resampled_gpu = torch.empty((7, *self.configuration_manager.patch_size), dtype=torch.float16, device=self.device)

                        # Handle the previous segmentation and bbox channels
                        for i in range(0, 3):
                            if any([x for y in pad_interaction for x in y]):
                                tmp = pad_cropped(crop_interactions[i].to(self.device, non_blocking=self.device.type == 'cuda'), pad_interaction)
                            else:
                                tmp = crop_interactions[i].to(self.device)
                            crop_interactions_resampled_gpu[i] = interpolate(tmp[None, None], self.configuration_manager.patch_size, mode='area')[0][0]
                        empty_cache(self.device)

                        # Handle the point and scribble channels with dilation
                        max_pool_ks = round_to_nearest_odd(zoom_out_factor * 2 - 1)
                        for i in range(3, 7):
                            if any([x for y in pad_interaction for x in y]):
                                tmp = pad_cropped(crop_interactions[i].to(self.device, non_blocking=self.device.type == 'cuda'), pad_interaction)
                            else:
                                tmp = crop_interactions[i].to(self.device, non_blocking=self.device.type == 'cuda')
                            if max_pool_ks > 1:
                                tmp = iterative_3x3_same_padding_pool3d(tmp[None, None], max_pool_ks)[0, 0]
                            crop_interactions_resampled_gpu[i] = interpolate(tmp[None, None], self.configuration_manager.patch_size, mode='area')[0][0]
                        del tmp

                        crop_img = interpolate(pad_cropped(crop_img, pad)[None] if any([x for y in pad_interaction for x in y]) else crop_img[None],
                                               self.configuration_manager.patch_size, mode='trilinear')[0]
                        crop_interactions = crop_interactions_resampled_gpu
                        del crop_interactions_resampled_gpu
                        empty_cache(self.device)
                    else:
                        crop_img = pad_cropped(crop_img, pad) if any([x for y in pad_interaction for x in y]) else crop_img
                        crop_interactions = pad_cropped(crop_interactions.to(self.device, non_blocking=self.device.type == 'cuda'), pad_interaction) if any([x for y in pad_interaction for x in y]) else crop_interactions.to(self.device, non_blocking=self.device.type == 'cuda')

                    # Make the prediction
                    input_for_predict = torch.cat((crop_img, crop_interactions))
                    del crop_img, crop_interactions
                    pred = self.network(input_for_predict[None])[0].argmax(0).detach()
                    del input_for_predict

                    # Check for changes at the borders
                    previous_zoom_prediction = crop_and_pad_nd(self.interactions[0], scaled_bbox).to(self.device, non_blocking=self.device.type == 'cuda')
                    if not all([i == j for i, j in zip(pred.shape, previous_zoom_prediction.shape)]):
                        previous_zoom_prediction = interpolate(previous_zoom_prediction[None, None].to(float), pred.shape, mode='nearest')[0, 0]

                    # Determine whether we need to continue zooming out
                    continue_zoom = False
                    if zoom_out_factor < 4 and self.do_autozoom:
                        for dim in range(len(scaled_bbox)):
                            if continue_zoom:
                                break
                            for idx in [0, pred.shape[dim] - 1]:
                                slice_prev = previous_zoom_prediction.index_select(dim, torch.tensor(idx, device=self.device))
                                slice_curr = pred.index_select(dim, torch.tensor(idx, device=self.device))
                                pixels_prev = torch.sum(slice_prev)
                                pixels_current = torch.sum(slice_curr)
                                pixels_diff = torch.sum(slice_prev != slice_curr)
                                rel_change = max(pixels_prev, pixels_current) / max(min(pixels_prev, pixels_current), 1e-5) - 1

                                if pixels_diff > 1500 or (pixels_diff > 100 and rel_change > 0.2):
                                    continue_zoom = True
                                    if self.verbose:
                                        print('Continuing zoom due to significant changes at the borders')
                                    break
                                del slice_prev, slice_curr, pixels_prev, pixels_current, pixels_diff
                    del previous_zoom_prediction

                    # Resize the prediction if needed
                    if not all([i == j for i, j in zip(pred.shape, scaled_patch_size)]):
                        pred = (interpolate(pred[None, None].to(float), scaled_patch_size, mode='trilinear')[0, 0] >= 0.5).to(torch.uint8)

                    # Update the interactions and target buffer
                    if zoom_out_factor == 1 or not continue_zoom:
                        pred = pred.cpu()
                        paste_tensor(self.interactions[0], pred.half(), scaled_bbox)

                        # Update the target buffer (shift back into original image coordinates)
                        bbox = [[i[0] + bbc[0], i[1] + bbc[0]] for i, bbc in zip(scaled_bbox, self.preprocessed_props['bbox_used_for_cropping'])]
                        paste_tensor(self.target_buffer, pred, bbox)

                    del pred
                    empty_cache(self.device)

                    if continue_zoom:
                        zoom_out_factor *= 1.5
                        zoom_out_factor = min(4, zoom_out_factor)
                    else:
                        zoom_out_factor = None

                end = time()
                print(f'Auto zoom stage took {round(end - start_autozoom, ndigits=3)}s. Max zoom out factor was {max_zoom_out_factor}')

        print(f'Done. Total time {round(time() - start_predict, 3)}s')

        self.new_interaction_centers = []
        self.new_interaction_zoom_out_factors = []
        empty_cache(self.device)

    def _iterative_adjust_prediction(self, pred_prob: torch.Tensor, crop_interactions: torch.Tensor,
                                     max_iterations: int = 15, prob_increase_factor: float = 1.5) -> torch.Tensor:
        """
        Perform iterative prediction adjustment when positive interactions exist.

        Args:
            pred_prob: Probability prediction tensor [C, H, W, D]
            crop_interactions: Interaction tensor [7, H, W, D]
            max_iterations: Maximum number of iterations to try
            prob_increase_factor: Factor by which to increase the foreground probability in each iteration

        Returns:
            Adjusted prediction tensor
        """
        # Check whether there are any positive interactions
        crop_interactions_pos = crop_interactions[1:7:2]
        pos_mask = torch.any(crop_interactions_pos > 0, dim=0)
        pos_mask_np = pos_mask.cpu().numpy()
        max_iterations = max_iterations if np.any(pos_mask_np) else 1

        iteration = 0
        while iteration < max_iterations:
            pred_prob = self._adjust_prediction_with_interactions(pred_prob, crop_interactions)
            pred_np = pred_prob[1].cpu().numpy()

            # If the prediction is all zero, try again with adjusted probabilities
            if not np.any(pred_np):
                # Increase the foreground probability for regions with positive interactions
                pred_prob[1, pos_mask] = torch.clamp(pred_prob[1, pos_mask] * prob_increase_factor, 0, 1)
                pred_prob[0, pos_mask] = 1 - pred_prob[1, pos_mask]
                iteration += 1
            else:
                break

        return pred_prob

    def _adjust_prediction_with_interactions(self, pred_prob: torch.Tensor, crop_interactions: torch.Tensor) -> torch.Tensor:
        """
        Adjust the prediction based on the interaction masks using superpixel segmentation.

        Args:
            pred_prob: Probability prediction tensor [C, H, W, D]
            crop_interactions: Interaction tensor [7, H, W, D]

        Returns:
            Adjusted prediction tensor
        """
        # Separate positive and negative interactions
        crop_interactions_pos = crop_interactions[1:7:2]
        crop_interactions_neg = crop_interactions[2:7:2]

        pos_mask = torch.any(crop_interactions_pos > 0, dim=0)
        neg_mask = torch.any(crop_interactions_neg > 0, dim=0)

        # Separate connected components (imported lazily; only needed by this optional adjustment)
        import scipy.ndimage
        from skimage.segmentation import slic
        # Get the initial prediction for labeling using a threshold
        pred_np = (pred_prob[1].cpu().numpy() > 0.5).astype(np.uint8)
        labeled_pred, num_components = scipy.ndimage.label(pred_np)

        # Convert the masks to numpy for overlap checking
        pos_mask_np = pos_mask.cpu().numpy()
        neg_mask_np = neg_mask.cpu().numpy()

        # Check the overlap for each component and adjust pred_prob
        for comp_id in range(1, num_components + 1):
            comp_mask = (labeled_pred == comp_id).astype(np.uint8)

            # Check the overlap with the positive and negative masks
            overlap_pos = np.logical_and(comp_mask, pos_mask_np)
            overlap_neg = np.logical_and(comp_mask, neg_mask_np)

            # If the component overlaps with both positive and negative masks
            if np.any(overlap_pos) and np.any(overlap_neg):
                # Get the bounding box of the component
                bbox = scipy.ndimage.find_objects(comp_mask)[0]
                comp_region = comp_mask[bbox]
                pos_region = overlap_pos[bbox]
                neg_region = overlap_neg[bbox]

                # Get the pred_prob values for the region
                pred_region_prob = pred_prob[:, bbox[0], bbox[1], bbox[2]].cpu().numpy()

                # Create an RGB-like image from the probabilities
                pred_rgb = np.transpose(pred_region_prob, (1, 2, 3, 0))  # [H, W, D, C]
                # pred_rgb = np.mean(pred_rgb, axis=-1, keepdims=True)  # Average across channels
                # pred_rgb = np.repeat(pred_rgb, 3, axis=-1)  # Repeat for RGB

                # Create superpixels based on the pred_prob values
                n_segments = min(100, np.sum(comp_region))  # Limit the number of segments
                segments = slic(pred_rgb, n_segments=n_segments, compactness=10, channel_axis=-1)

                # Process each superpixel
                for seg_id in range(1, segments.max() + 1):
                    seg_mask = (segments == seg_id)
                    seg_pos = np.logical_and(seg_mask, pos_region)
                    seg_neg = np.logical_and(seg_mask, neg_region)

                    # Get the global coordinates for this segment
                    seg_coords = np.where(seg_mask)
                    global_coords = tuple(c + b for c, b in zip(seg_coords, [b.start for b in bbox]))

                    # Assign values based on the interaction overlap
                    if np.any(seg_pos) and not np.any(seg_neg):
                        pred_prob[0, global_coords] = 0.0
                        pred_prob[1, global_coords] = 1.0
                    elif np.any(seg_neg) and not np.any(seg_pos):
                        pred_prob[0, global_coords] = 1.0
                        pred_prob[1, global_coords] = 0.0
                    # If the segment has both interactions, keep the original prediction
                    else:
                        continue

            # If the component only overlaps with the positive mask, force it to foreground
            elif np.any(overlap_pos):
                pred_prob[0, comp_mask > 0] = 0.0  # Set background to 0
                pred_prob[1, comp_mask > 0] = 1.0  # Set foreground to 1

            # If the component only overlaps with the negative mask, force it to background
            elif np.any(overlap_neg):
                pred_prob[0, comp_mask > 0] = 1.0  # Set background to 1
                pred_prob[1, comp_mask > 0] = 0.0  # Set foreground to 0

            # # If the component does not overlap with any mask, force it to background
            # else:
            #     pred_prob[0, comp_mask > 0] = 1.0  # Set background to 1
            #     pred_prob[1, comp_mask > 0] = 0.0  # Set foreground to 0

        # Return the adjusted probabilities (thresholding happens in the caller)
        return pred_prob


def transform_coordinates_noresampling(
        coords_orig: Union[List[int], Tuple[int, ...]],
        nnunet_preprocessing_crop_bbox: List[Tuple[int, int]]
) -> Tuple[int, ...]:
    """
    Converts coordinates in the original, uncropped image to the internal cropped representation,
    e.g. coords (10, 20, 30) with a crop bbox starting at (5, 5, 5) become (5, 15, 25).
    Man, I really hate nnU-Net's crop to nonzero!
    """
    return tuple([coords_orig[d] - nnunet_preprocessing_crop_bbox[d][0] for d in range(len(coords_orig))])
+ return tuple([coords_orig[d] - nnunet_preprocessing_crop_bbox[d][0] for d in range(len(coords_orig))])
1011
+
1012
+
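

if __name__ == '__main__':
    # Minimal usage sketch (not part of the uploaded file; the model folder path, fold and image
    # shape below are placeholders / assumptions for illustration only).
    predictor = InteractiveMENRTPredictor(device=torch.device('cuda'), verbose=True)
    predictor.initialize_from_trained_model_folder('/path/to/trained_model_folder', use_fold=0)

    # nnU-Net style 4D input: [c, x, y, z]
    img = np.random.rand(1, 192, 192, 192).astype(np.float32)
    predictor.set_image(img)

    # The 3D target buffer receives the binary segmentation after every prediction
    target = np.zeros(img.shape[1:], dtype=np.uint8)
    predictor.set_target_buffer(target)

    # A positive point prompt in original image coordinates; this triggers _predict()
    predictor.add_point_interaction((96, 96, 96), include_interaction=True)
    print('Foreground voxels:', int(target.sum()))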