Spaces:
Sleeping
Sleeping
| from concurrent.futures import ThreadPoolExecutor | |
| from os import cpu_count | |
| from time import time | |
| from typing import Union, List, Tuple, Optional | |
| import numpy as np | |
| import torch | |
| from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice, crop_and_pad_nd | |
| from batchgenerators.utilities.file_and_folder_operations import load_json, join, subdirs | |
| from nnunetv2.utilities.find_class_by_name import recursive_find_python_class | |
| from nnunetv2.utilities.helpers import dummy_context, empty_cache | |
| from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels | |
| from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager | |
| from torch import nn | |
| from torch._dynamo import OptimizedModule | |
| from torch.nn.functional import interpolate | |
| import nnInteractive | |
| from nnInteractive.interaction.point import PointInteraction_stub | |
| from nnInteractive.trainer.nnInteractiveTrainer import nnInteractiveTrainer_stub | |
| from nnInteractive.utils.bboxes import generate_bounding_boxes | |
| from nnInteractive.utils.crop import crop_and_pad_into_buffer, paste_tensor, pad_cropped, crop_to_valid | |
| from nnInteractive.utils.erosion_dilation import iterative_3x3_same_padding_pool3d | |
| from nnInteractive.utils.rounding import round_to_nearest_odd | |
| class nnInteractiveInferenceSession(): | |
| def __init__(self, | |
| device: torch.device = torch.device('cuda'), | |
| use_torch_compile: bool = False, | |
| verbose: bool = False, | |
| torch_n_threads: int = 8, | |
| do_autozoom: bool = True, | |
| use_pinned_memory: bool = True, | |
| ): | |
| """ | |
| Only intended to work with nnInteractiveTrainerV2 and its derivatives | |
| """ | |
| # set as part of initialization | |
| assert use_torch_compile is False, ('This implementation places the preprocessed image and the interactions ' | |
| 'into pinned memory for speed reasons. This is incompatible with ' | |
| 'torch.compile because of inconsistent strides in the memory layout. ' | |
| 'Note to self: .contiguous() on GPU could be a solution. Unclear whether ' | |
| 'that will yield a benefit though.') | |
| self.network = None | |
| self.label_manager = None | |
| self.dataset_json = None | |
| self.trainer_name = None | |
| self.configuration_manager = None | |
| self.plans_manager = None | |
| self.use_pinned_memory = use_pinned_memory | |
| self.device = device | |
| self.use_torch_compile = use_torch_compile | |
| self.interaction_decay = None | |
| # image specific | |
| self.interactions: torch.Tensor = None | |
| self.preprocessed_image: torch.Tensor = None | |
| self.preprocessed_props = None | |
| self.target_buffer: Union[np.ndarray, torch.Tensor] = None | |
| # this will be set when loading the model (initialize_from_trained_model_folder) | |
| self.pad_mode_data = self.preferred_scribble_thickness = self.point_interaction = None | |
| self.verbose = verbose | |
| self.do_autozoom: bool = do_autozoom | |
| torch.set_num_threads(min(torch_n_threads, cpu_count())) | |
| self.original_image_shape = None | |
| self.new_interaction_zoom_out_factors: List[float] = [] | |
| self.new_interaction_centers = [] | |
| self.has_positive_bbox = False | |
| # Create a thread pool executor for background tasks. | |
| # this only takes care of preprocessing and interaction memory initialization so there is no need to give it | |
| # more than 2 workers | |
| self.executor = ThreadPoolExecutor(max_workers=2) | |
| self.preprocess_future = None | |
| self.interactions_future = None | |
| def set_image(self, image: np.ndarray, image_properties: dict = None): | |
| """ | |
| Image must be 4D to satisfy nnU-Net needs: [c, x, y, z] | |
| Offload the processing to a background thread. | |
| """ | |
| if image_properties is None: | |
| image_properties = {} | |
| self._reset_session() | |
| assert image.ndim == 4, f'expected a 4d image as input, got {image.ndim}d. Shape {image.shape}' | |
| if self.verbose: | |
| print(f'Initialize with raw image shape {image.shape}') | |
| # Offload all image preprocessing to a background thread. | |
| self.preprocess_future = self.executor.submit(self._background_set_image, image, image_properties) | |
| self.original_image_shape = image.shape | |
| def _finish_preprocessing_and_initialize_interactions(self): | |
| """ | |
| Block until both the image preprocessing and the interactions tensor initialization | |
| are finished. | |
| """ | |
| if self.preprocess_future is not None: | |
| # Wait for image preprocessing to complete. | |
| self.preprocess_future.result() | |
| del self.preprocess_future | |
| self.preprocess_future = None | |
| def set_target_buffer(self, target_buffer: Union[np.ndarray, torch.Tensor]): | |
| """ | |
| Must be 3d numpy array or torch.Tensor | |
| """ | |
| self.target_buffer = target_buffer | |
| def set_do_autozoom(self, do_propagation: bool, max_num_patches: Optional[int] = None): | |
| self.do_autozoom = do_propagation | |
def _reset_session(self):
    """
    Drop everything that belongs to the current image: pending futures, preprocessed image, interactions,
    target buffer, crop properties and any prediction patches that were queued but not yet predicted.
    """
    # NOTE(review): a background job that is still running is not cancelled or awaited here, only forgotten
    self.interactions_future = None
    self.preprocess_future = None

    # release the (potentially large) tensors before the cache is emptied
    del self.preprocessed_image
    del self.target_buffer
    del self.interactions
    del self.preprocessed_props
    self.preprocessed_image = None
    self.target_buffer = None
    self.interactions = None
    self.preprocessed_props = None
    empty_cache(self.device)

    self.original_image_shape = None
    self.has_positive_bbox = False

    # queued patch centers are expressed in the coordinates of the previous (cropped) image. They must not
    # survive into the next image, otherwise the next _predict call would predict at stale locations.
    self.new_interaction_centers = []
    self.new_interaction_zoom_out_factors = []
| def _initialize_interactions(self, image_torch: torch.Tensor): | |
| if self.verbose: | |
| print(f'Initialize interactions. Pinned: {self.use_pinned_memory}') | |
| # Create the interaction tensor based on the target shape. | |
| self.interactions = torch.zeros( | |
| (7, *image_torch.shape[1:]), | |
| device='cpu', | |
| dtype=torch.float16, | |
| pin_memory=(self.device.type == 'cuda' and self.use_pinned_memory) | |
| ) | |
def _background_set_image(self, image: np.ndarray, image_properties: dict):
    """
    Background job started by set_image: crop the image to its nonzero region, z-score normalize it and
    store the result (optionally in pinned memory) together with the crop bounding box.

    Also launches and awaits the allocation of the interactions tensor, so once this job is finished both
    self.preprocessed_image and self.interactions are ready.

    Args:
        image: 4d array [c, x, y, z]. It is copied, the caller's array is not modified.
        image_properties: currently not used by this implementation.
    """
    # Convert and clone the image tensor.
    image_torch = torch.clone(torch.from_numpy(image))

    # Crop to nonzero region.
    if self.verbose:
        print('Cropping input image to nonzero region')
    nonzero_idx = torch.where(image_torch != 0)
    if nonzero_idx[0].numel() > 0:
        # Create bounding box: for each dimension, get the min and max (plus one) of the nonzero indices.
        bbox = [[i.min().item(), i.max().item() + 1] for i in nonzero_idx]
    else:
        # all-zero image: there is no nonzero region (min/max of an empty tensor would raise),
        # so keep the full extent
        bbox = [[0, s] for s in image_torch.shape]
    del nonzero_idx
    slicer = bounding_box_to_slice(bbox)  # Assuming this returns a tuple of slices.
    image_torch = image_torch[slicer].float()
    if self.verbose:
        print(f'Cropped image shape: {image_torch.shape}')

    # As soon as we have the target shape, start initializing the interaction tensor in its own thread.
    # NOTE(review): this occupies a second worker of the 2-worker pool while this job waits for it below.
    self.interactions_future = self.executor.submit(self._initialize_interactions, image_torch)

    # Normalize the cropped image.
    if self.verbose:
        print('Normalizing cropped image')
    image_torch -= image_torch.mean()
    # a constant image has std 0; clamp to avoid dividing by zero (inf/nan would poison every prediction)
    image_torch /= torch.clamp(image_torch.std(), min=1e-8)

    self.preprocessed_image = image_torch
    if self.use_pinned_memory and self.device.type == 'cuda':
        if self.verbose:
            print('Pin memory: image')
        # Note: pin_memory() in PyTorch typically returns a new tensor.
        self.preprocessed_image = self.preprocessed_image.pin_memory()

    # the channel axis is dropped: only the spatial part of the crop box is needed to map coordinates
    self.preprocessed_props = {'bbox_used_for_cropping': bbox[1:]}

    # we need to wait for this here I believe
    self.interactions_future.result()
    del self.interactions_future
    self.interactions_future = None
def reset_interactions(self):
    """
    Use this to reset all interactions and start from scratch for the current image. This includes the initial
    segmentation!

    The interaction tensor and the target buffer are zeroed in place. Prediction patches that were queued for
    the erased interactions (added with run_prediction=False) are discarded as well.
    """
    if self.interactions is not None:
        self.interactions.fill_(0)

    if self.target_buffer is not None:
        if isinstance(self.target_buffer, np.ndarray):
            self.target_buffer.fill(0)
        elif isinstance(self.target_buffer, torch.Tensor):
            self.target_buffer.zero_()
    empty_cache(self.device)
    self.has_positive_bbox = False

    # the queued patches refer to interactions that no longer exist. Keeping them would make the next
    # _predict call run additional predictions at stale locations.
    self.new_interaction_centers = []
    self.new_interaction_zoom_out_factors = []
def add_bbox_interaction(self, bbox_coords, include_interaction: bool, run_prediction: bool = True) -> None:
    """
    Add a bounding box interaction.

    Args:
        bbox_coords: one [lower, upper] pair per spatial axis, given in coordinates of the original
            (uncropped) image.
        include_interaction: True writes into the positive bbox channel, False into the negative one.
        run_prediction: run _predict right away.

    Returns:
        None. Results are written into the interactions tensor (and the target buffer by _predict).
    """
    if include_interaction:
        self.has_positive_bbox = True

    self._finish_preprocessing_and_initialize_interactions()

    crop_bbox = self.preprocessed_props['bbox_used_for_cropping']
    # shift lower and upper corner from original image space into the cropped internal image space
    lbs_transformed = [round(i) for i in transform_coordinates_noresampling([i[0] for i in bbox_coords],
                                                                            crop_bbox)]
    ubs_transformed = [round(i) for i in transform_coordinates_noresampling([i[1] for i in bbox_coords],
                                                                            crop_bbox)]
    transformed_bbox_coordinates = [[i, j] for i, j in zip(lbs_transformed, ubs_transformed)]

    if self.verbose:
        print(f'Added bounding box coordinates.\n'
              f'Raw: {bbox_coords}\n'
              f'Transformed: {transformed_bbox_coordinates}\n'
              f"Crop Bbox: {self.preprocessed_props['bbox_used_for_cropping']}")

    # Prevent collapsed bounding boxes and clip to image shape
    image_shape = self.preprocessed_image.shape  # Assuming shape is (C, H, W, D) or similar

    for dim, (transformed_start, transformed_end) in enumerate(transformed_bbox_coordinates):
        dim_size = image_shape[dim + 1]  # +1 to skip channel dim

        # Clip to image boundaries
        transformed_start = max(0, transformed_start)
        transformed_end = min(dim_size, transformed_end)

        # Ensure the bounding box does not collapse to a single point
        if transformed_end <= transformed_start:
            if transformed_start == 0:
                transformed_end = min(1, dim_size)
            else:
                # A box that collapsed inside the image gets its start moved back by one (as before). A box
                # that lies beyond the upper border or has swapped corners needs more than that: anchor the
                # start just below the (clipped) end and make sure the end is above the start, otherwise
                # the slice below would be empty or wrap around via negative indices.
                transformed_start = max(min(transformed_start, transformed_end) - 1, 0)
                transformed_end = max(transformed_end, transformed_start + 1)

        transformed_bbox_coordinates[dim] = [transformed_start, transformed_end]

    if self.verbose:
        print(f'Bbox coordinates after clip to image boundaries and preventing dim collapse:\n'
              f'Bbox: {transformed_bbox_coordinates}\n'
              f'Internal image shape: {self.preprocessed_image.shape}')

    self._add_patch_for_bbox_interaction(transformed_bbox_coordinates)

    # decay old interactions
    self.interactions[-6:-4] *= self.interaction_decay

    # place bbox
    slicer = tuple([slice(*i) for i in transformed_bbox_coordinates])
    channel = -6 if include_interaction else -5
    self.interactions[(channel, *slicer)] = 1

    # forward pass
    if run_prediction:
        self._predict()
def add_point_interaction(self, coordinates: Tuple[int, ...], include_interaction: bool, run_prediction: bool = True):
    """
    Add a point interaction.

    Args:
        coordinates: point location in coordinates of the original (uncropped) image.
        include_interaction: True writes into the positive point channel, False into the negative one.
        run_prediction: run _predict right away.
    """
    self._finish_preprocessing_and_initialize_interactions()

    # map the point into the cropped internal image space
    crop_bbox = self.preprocessed_props['bbox_used_for_cropping']
    shifted = transform_coordinates_noresampling(coordinates, crop_bbox)
    point = [round(c) for c in shifted]

    self._add_patch_for_point_interaction(point)

    # decay old interactions
    self.interactions[-4:-2] *= self.interaction_decay

    if include_interaction:
        target_channel = -4
    else:
        target_channel = -3
    updated = self.point_interaction.place_point(point, self.interactions[target_channel])
    self.interactions[target_channel] = updated

    if run_prediction:
        self._predict()
| def add_scribble_interaction(self, scribble_image: np.ndarray, include_interaction: bool, run_prediction: bool = True): | |
| assert all([i == j for i, j in zip(self.original_image_shape[1:], scribble_image.shape)]), f'Given scribble image must match input image shape. Input image was: {self.original_image_shape[1:]}, given: {scribble_image.shape}' | |
| self._finish_preprocessing_and_initialize_interactions() | |
| scribble_image = torch.from_numpy(scribble_image) | |
| # crop (as in preprocessing) | |
| scribble_image = crop_and_pad_nd(scribble_image, self.preprocessed_props['bbox_used_for_cropping']) | |
| self._add_patch_for_scribble_interaction(scribble_image) | |
| # decay old interactions | |
| self.interactions[-2:] *= self.interaction_decay | |
| interaction_channel = -2 if include_interaction else -1 | |
| torch.maximum(self.interactions[interaction_channel], scribble_image.to(self.interactions.device), | |
| out=self.interactions[interaction_channel]) | |
| del scribble_image | |
| empty_cache(self.device) | |
| if run_prediction: | |
| self._predict() | |
| def add_lasso_interaction(self, lasso_image: np.ndarray, include_interaction: bool, run_prediction: bool = True): | |
| assert all([i == j for i, j in zip(self.original_image_shape[1:], lasso_image.shape)]), f'Given lasso image must match input image shape. Input image was: {self.original_image_shape[1:]}, given: {lasso_image.shape}' | |
| self._finish_preprocessing_and_initialize_interactions() | |
| lasso_image = torch.from_numpy(lasso_image) | |
| # crop (as in preprocessing) | |
| lasso_image = crop_and_pad_nd(lasso_image, self.preprocessed_props['bbox_used_for_cropping']) | |
| self._add_patch_for_lasso_interaction(lasso_image) | |
| # decay old interactions | |
| self.interactions[-6:-4] *= self.interaction_decay | |
| # lasso is written into bbox channel | |
| interaction_channel = -6 if include_interaction else -5 | |
| torch.maximum(self.interactions[interaction_channel], lasso_image.to(self.interactions.device), | |
| out=self.interactions[interaction_channel]) | |
| del lasso_image | |
| empty_cache(self.device) | |
| if run_prediction: | |
| self._predict() | |
| def add_initial_seg_interaction(self, initial_seg: np.ndarray, run_prediction: bool = False): | |
| """ | |
| WARNING THIS WILL RESET INTERACTIONS! | |
| """ | |
| assert all([i == j for i, j in zip(self.original_image_shape[1:], initial_seg.shape)]), f'Given initial seg must match input image shape. Input image was: {self.original_image_shape[1:]}, given: {initial_seg.shape}' | |
| self._finish_preprocessing_and_initialize_interactions() | |
| self.reset_interactions() | |
| if isinstance(self.target_buffer, np.ndarray): | |
| self.target_buffer[:] = initial_seg | |
| initial_seg = torch.from_numpy(initial_seg) | |
| if isinstance(self.target_buffer, torch.Tensor): | |
| self.target_buffer[:] = initial_seg | |
| # crop (as in preprocessing) | |
| initial_seg = crop_and_pad_nd(initial_seg, self.preprocessed_props['bbox_used_for_cropping']) | |
| # initial seg is written into initial seg buffer | |
| interaction_channel = -7 | |
| self.interactions[interaction_channel] = initial_seg | |
| empty_cache(self.device) | |
| if run_prediction: | |
| self._add_patch_for_initial_seg_interaction(initial_seg) | |
| del initial_seg | |
| self._predict() | |
| else: | |
| del initial_seg | |
| def _predict(self): | |
| """ | |
| This function is a smoking mess to read. This is deliberate. Initially it was super pretty and easy to | |
| understand. Then the run time optimization began. | |
| If it feels like we are excessively transferring tensors between CPU and GPU, this is deliberate as well. | |
| Our goal is to keep this tool usable even for people with smaller GPUs (8-10GB VRAM). In an ideal world | |
| everyone would have 24GB+ of VRAM and all tensors would like on GPU all the time. | |
| The amount of hours spent optimizing this function is substantial. Almost every line was turned and twisted | |
| multiple times. If something appears odd, it is probably so for a reason. Don't change things all willy nilly | |
| without first understanding what is going on. And don't make changes without verifying that the run time or | |
| VRAM consumption is not adversely affected. | |
| Returns: | |
| """ | |
# Review summary: for every queued interaction center this function
#   (1) predicts on a patch around the center and keeps enlarging the field of view by a factor of 1.5
#       (capped at 4) for as long as the prediction changes at the border of the patch, then
#   (2) if a zoom out factor > 1 was used and the result differs from the previous segmentation, re-predicts
#       the changed regions patch by patch without resizing.
# Predictions are written into self.interactions[0] and self.target_buffer. The queues of pending centers and
# zoom out factors are emptied at the end.
| assert self.pad_mode_data == 'constant', 'pad modes other than constant are not implemented here' | |
| assert len(self.new_interaction_centers) == len(self.new_interaction_zoom_out_factors) | |
| if len(self.new_interaction_centers) > 1: | |
| print('It seems like more than one interaction was added since the last prediction. This is not ' | |
| 'recommended and may cause unexpected behavior or inefficient predictions') | |
| start_predict = time() | |
| with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): | |
| for prediction_center, initial_zoom_out_factor in zip(self.new_interaction_centers, self.new_interaction_zoom_out_factors): | |
| # make a prediction at initial zoom out factor. If more zoom out is required, do this until the | |
| # entire object fits the FOV. Then go back to original resolution and refine. | |
| # we need this later. | |
| previous_prediction = torch.clone(self.interactions[0]) | |
| if not self.do_autozoom: | |
| initial_zoom_out_factor = 1 | |
# the zoom out factor never exceeds 4
| initial_zoom_out_factor = min(initial_zoom_out_factor, 4) | |
| zoom_out_factor = initial_zoom_out_factor | |
| max_zoom_out_factor = initial_zoom_out_factor | |
| start_autozoom = time() | |
| while zoom_out_factor is not None and zoom_out_factor <= 4: | |
| print('Performing prediction at zoom out factor', zoom_out_factor) | |
| max_zoom_out_factor = max(max_zoom_out_factor, zoom_out_factor) | |
| # initial prediction at initial_zoom_out_factor | |
| scaled_patch_size = [round(i * zoom_out_factor) for i in self.configuration_manager.patch_size] | |
# box of scaled_patch_size around the center; the upper bound receives the extra voxel of odd sizes
| scaled_bbox = [[c - p // 2, c + p // 2 + p % 2] for c, p in zip(prediction_center, scaled_patch_size)] | |
| crop_img, pad = crop_to_valid(self.preprocessed_image, scaled_bbox) | |
| crop_img = crop_img.to(self.device, non_blocking=self.device.type == 'cuda') | |
| crop_interactions, pad_interaction = crop_to_valid(self.interactions, scaled_bbox) | |
| # resize input_for_predict (which may be larger than patch size) to patch size | |
| # this implementation may not seem straightforward but it does save VRAM which is crucial here | |
| if not all([i == j for i, j in zip(self.configuration_manager.patch_size, scaled_patch_size)]): | |
| crop_interactions_resampled_gpu = torch.empty((7, *self.configuration_manager.patch_size), dtype=torch.float16, device=self.device) | |
| # previous seg, bbox+, bbox- | |
| for i in range(0, 3): | |
| # this is area for a reason but I aint telling ya why | |
| if any([x for y in pad_interaction for x in y]): | |
| tmp = pad_cropped(crop_interactions[i].to(self.device, non_blocking=self.device.type == 'cuda'), pad_interaction) | |
| else: | |
| tmp = crop_interactions[i].to(self.device) | |
| crop_interactions_resampled_gpu[i] = interpolate(tmp[None, None], self.configuration_manager.patch_size, mode='area')[0][0] | |
| empty_cache(self.device) | |
| max_pool_ks = round_to_nearest_odd(zoom_out_factor * 2 - 1) | |
| # point+, point-, scribble+, scribble- | |
| for i in range(3, 7): | |
| if any([x for y in pad_interaction for x in y]): | |
| tmp = pad_cropped(crop_interactions[i].to(self.device, non_blocking=self.device.type == 'cuda'), pad_interaction) | |
| else: | |
| tmp = crop_interactions[i].to(self.device, non_blocking=self.device.type == 'cuda') | |
| if max_pool_ks > 1: | |
| # dilate to preserve interactions after downsampling | |
| tmp = iterative_3x3_same_padding_pool3d(tmp[None, None], max_pool_ks)[0, 0] | |
| # this is 'area' for a reason but I aint telling ya why | |
| crop_interactions_resampled_gpu[i] = interpolate(tmp[None, None], self.configuration_manager.patch_size, mode='area')[0][0] | |
| del tmp | |
# NOTE(review): the decision whether the image crop needs padding is taken from pad_interaction, not pad
| crop_img = interpolate(pad_cropped(crop_img, pad)[None] if any([x for y in pad_interaction for x in y]) else crop_img[None], self.configuration_manager.patch_size, mode='trilinear')[0] | |
| crop_interactions = crop_interactions_resampled_gpu | |
| del crop_interactions_resampled_gpu | |
| empty_cache(self.device) | |
| else: | |
| # crop_img is already on device | |
| crop_img = pad_cropped(crop_img, pad) if any([x for y in pad_interaction for x in y]) else crop_img | |
| crop_interactions = pad_cropped(crop_interactions.to(self.device, non_blocking=self.device.type == 'cuda'), pad_interaction) if any([x for y in pad_interaction for x in y]) else crop_interactions.to(self.device, non_blocking=self.device.type == 'cuda') | |
| input_for_predict = torch.cat((crop_img, crop_interactions)) | |
| del crop_img, crop_interactions | |
| pred = self.network(input_for_predict[None])[0].argmax(0).detach() | |
| del input_for_predict | |
| # detect changes at borders | |
| previous_zoom_prediction = crop_and_pad_nd(self.interactions[0], scaled_bbox).to(self.device, non_blocking=self.device.type == 'cuda') | |
| if not all([i == j for i, j in zip(pred.shape, previous_zoom_prediction.shape)]): | |
| previous_zoom_prediction = interpolate(previous_zoom_prediction[None, None].to(float), pred.shape, mode='nearest')[0, 0] | |
# thresholds on the number of differing voxels on a border face of the patch
| abs_pxl_change_threshold = 1500 | |
| rel_pxl_change_threshold = 0.2 | |
| min_pxl_change_threshold = 100 | |
| continue_zoom = False | |
| if zoom_out_factor < 4 and self.do_autozoom: | |
# compare the first and last slice along every axis between the previous and the new prediction
| for dim in range(len(scaled_bbox)): | |
| if continue_zoom: | |
| break | |
| for idx in [0, pred.shape[dim] - 1]: | |
| slice_prev = previous_zoom_prediction.index_select(dim, torch.tensor(idx, device=self.device)) | |
| slice_curr = pred.index_select(dim, torch.tensor(idx, device=self.device)) | |
| pixels_prev = torch.sum(slice_prev) | |
| pixels_current = torch.sum(slice_curr) | |
| pixels_diff = torch.sum(slice_prev != slice_curr) | |
| rel_change = max(pixels_prev, pixels_current) / max(min(pixels_prev, pixels_current), | |
| 1e-5) - 1 | |
| if pixels_diff > abs_pxl_change_threshold: | |
| continue_zoom = True | |
| if self.verbose: | |
| print(f'continue zooming because change at borders of {pixels_diff} > {abs_pxl_change_threshold}') | |
| break | |
| if pixels_diff > min_pxl_change_threshold and rel_change > rel_pxl_change_threshold: | |
| continue_zoom = True | |
| if self.verbose: | |
| print(f'continue zooming because relative change of {rel_change} > {rel_pxl_change_threshold} and n_pixels {pixels_diff} > {min_pxl_change_threshold}') | |
| break | |
| del slice_prev, slice_curr, pixels_prev, pixels_current, pixels_diff | |
| del previous_zoom_prediction | |
| # resize prediction to correct size and place in target buffer + interactions | |
| if not all([i == j for i, j in zip(pred.shape, scaled_patch_size)]): | |
| pred = (interpolate(pred[None, None].to(float), scaled_patch_size, mode='trilinear')[0, 0] >= 0.5).to(torch.uint8) | |
| # if we do not continue zooming we need a difference map for sampling patches | |
| if not continue_zoom and zoom_out_factor > 1: | |
| # wow this circus saves ~30ms relative to naive implementation | |
| previous_prediction = previous_prediction.to(self.device, non_blocking=self.device.type == 'cuda') | |
| seen_bbox = [[max(0, i[0]), min(i[1], s)] for i, s in zip(scaled_bbox, previous_prediction.shape)] | |
| bbox_tmp = [[i[0] - s[0], i[1] - s[0]] for i, s in zip(seen_bbox, scaled_bbox)] | |
| bbox_tmp = [[max(0, i[0]), min(i[1], s)] for i, s in zip(bbox_tmp, scaled_patch_size)] | |
| slicer = bounding_box_to_slice(seen_bbox) | |
| slicer2 = bounding_box_to_slice(bbox_tmp) | |
| diff_map = pred[slicer2] != previous_prediction[slicer] | |
| # dont allocate new memory, just reuse previous_prediction. We don't need it anymore | |
| previous_prediction.zero_() | |
| diff_map = paste_tensor(previous_prediction, diff_map, seen_bbox) | |
| # open the difference map to keep computational load in check (fewer refinement boxes) | |
| # open distance map | |
| diff_map[slicer] = iterative_3x3_same_padding_pool3d(diff_map[slicer][None, None], kernel_size=5, use_min_pool=True)[0, 0] | |
| diff_map[slicer] = iterative_3x3_same_padding_pool3d(diff_map[slicer][None, None], kernel_size=5, use_min_pool=False)[0, 0] | |
| has_diff = torch.any(diff_map[slicer]) | |
| del previous_prediction | |
| else: | |
| has_diff = False | |
| if zoom_out_factor == 1 or (not continue_zoom and has_diff): # rare case where no changes are needed because of useless interaction. Need to check for not continue_zoom because otherwise diff_map won't exist | |
| pred = pred.cpu() | |
| if zoom_out_factor == 1: | |
| paste_tensor(self.interactions[0], pred.half(), scaled_bbox) | |
| else: | |
| seen_bbox = [[max(0, i[0]), min(i[1], s)] for i, s in | |
| zip(scaled_bbox, diff_map.shape)] | |
| bbox_tmp = [[i[0] - s[0], i[1] - s[0]] for i, s in zip(seen_bbox, scaled_bbox)] | |
| bbox_tmp = [[max(0, i[0]), min(i[1], s)] for i, s in zip(bbox_tmp, scaled_patch_size)] | |
| slicer = bounding_box_to_slice(seen_bbox) | |
| slicer2 = bounding_box_to_slice(bbox_tmp) | |
| mask = (diff_map[slicer] > 0).cpu() | |
| self.interactions[0][slicer][mask] = pred[slicer2][mask].half() | |
| # place into target buffer | |
| bbox = [[i[0] + bbc[0], i[1] + bbc[0]] for i, bbc in | |
| zip(scaled_bbox, self.preprocessed_props['bbox_used_for_cropping'])] | |
| paste_tensor(self.target_buffer, pred, bbox) | |
| del pred | |
| empty_cache(self.device) | |
| if continue_zoom: | |
| zoom_out_factor *= 1.5 | |
| zoom_out_factor = min(4, zoom_out_factor) | |
| else: | |
| zoom_out_factor = None | |
| end = time() | |
| print(f'Auto zoom stage took {round(end - start_autozoom, ndigits=3)}s. Max zoom out factor was {max_zoom_out_factor}') | |
# refinement stage: only entered if a zoomed out prediction was made and it changed the segmentation
| if max_zoom_out_factor > 1 and has_diff: | |
| start_refinement = time() | |
| # only use the region that was previously looked at. Use last scaled_bbox | |
| if self.has_positive_bbox: | |
| # mask positive bbox channel with current segmentation to avoid bbox nonsense. | |
| # Basically convert bbox to pseudo lasso | |
| pos_bbox_idx = -6 | |
| self.interactions[pos_bbox_idx][(~(self.interactions[0] > 0.5)).cpu()] = 0 | |
| self.has_positive_bbox = False | |
| bboxes_ordered = generate_bounding_boxes(diff_map, self.configuration_manager.patch_size, stride='auto', margin=(10, 10, 10), max_depth=3) | |
| del diff_map | |
| empty_cache(self.device) | |
| if self.verbose: | |
| print(f'Using {len(bboxes_ordered)} bounding boxes for refinement') | |
# 8 input channels: channel 0 takes the first image channel, channels 1-7 take the 7 interaction channels
| preallocated_input = torch.zeros((8, *self.configuration_manager.patch_size), device=self.device, dtype=torch.float) | |
| for nref, refinement_bbox in enumerate(bboxes_ordered): | |
| assert self.pad_mode_data == 'constant' | |
| crop_and_pad_into_buffer(preallocated_input[0], refinement_bbox, self.preprocessed_image[0]) | |
| crop_and_pad_into_buffer(preallocated_input[1:], refinement_bbox, self.interactions) | |
| pred = self.network(preallocated_input[None])[0].argmax(0).detach().cpu() | |
| paste_tensor(self.interactions[0], pred, refinement_bbox) | |
| # place into target buffer | |
| bbox = [[i[0] + bbc[0], i[1] + bbc[0]] for i, bbc in zip(refinement_bbox, self.preprocessed_props['bbox_used_for_cropping'])] | |
| paste_tensor(self.target_buffer, pred, bbox) | |
| del pred | |
| preallocated_input.zero_() | |
| del preallocated_input | |
| empty_cache(self.device) | |
| end_refinement = time() | |
| print(f'Took {round(end_refinement - start_refinement, 3)} s for refining the segmentation with {len(bboxes_ordered)} bounding boxes') | |
| else: | |
| print('No refinement necessary') | |
| print(f'Done. Total time {round(time() - start_predict, 3)}s') | |
# all queued interactions have been processed
| self.new_interaction_centers = [] | |
| self.new_interaction_zoom_out_factors = [] | |
| empty_cache(self.device) | |
| def _add_patch_for_point_interaction(self, coordinates): | |
| self.new_interaction_zoom_out_factors.append(1) | |
| self.new_interaction_centers.append(coordinates) | |
| print(f'Added new point interaction: center {self.new_interaction_zoom_out_factors[-1]}, scale {self.new_interaction_centers}') | |
| def _add_patch_for_bbox_interaction(self, bbox): | |
| bbox_center = [round((i[0] + i[1]) / 2) for i in bbox] | |
| bbox_size = [i[1]-i[0] for i in bbox] | |
| # we want to see some context, so the crop we see for the initial prediction should be patch_size / 3 larger | |
| requested_size = [i + j // 3 for i, j in zip(bbox_size, self.configuration_manager.patch_size)] | |
| self.new_interaction_zoom_out_factors.append(max(1, max([i / j for i, j in zip(requested_size, self.configuration_manager.patch_size)]))) | |
| self.new_interaction_centers.append(bbox_center) | |
| print(f'Added new bbox interaction: center {self.new_interaction_zoom_out_factors[-1]}, scale {self.new_interaction_centers}') | |
| def _add_patch_for_scribble_interaction(self, scribble_image): | |
| return self._generic_add_patch_from_image(scribble_image) | |
| def _add_patch_for_lasso_interaction(self, lasso_image): | |
| return self._generic_add_patch_from_image(lasso_image) | |
| def _add_patch_for_initial_seg_interaction(self, initial_seg): | |
| return self._generic_add_patch_from_image(initial_seg) | |
| def _generic_add_patch_from_image(self, image: torch.Tensor): | |
| if not torch.any(image): | |
| print('Received empty image prompt. Cannot add patches for prediction') | |
| return | |
| nonzero_indices = torch.nonzero(image, as_tuple=False) | |
| mn = torch.min(nonzero_indices, dim=0)[0] | |
| mx = torch.max(nonzero_indices, dim=0)[0] | |
| roi = [[i.item(), x.item() + 1] for i, x in zip(mn, mx)] | |
| roi_center = [round((i[0] + i[1]) / 2) for i in roi] | |
| roi_size = [i[1]- i[0] for i in roi] | |
| requested_size = [i + j // 3 for i, j in zip(roi_size, self.configuration_manager.patch_size)] | |
| self.new_interaction_zoom_out_factors.append(max(1, max([i / j for i, j in zip(requested_size, self.configuration_manager.patch_size)]))) | |
| self.new_interaction_centers.append(roi_center) | |
| print(f'Added new image interaction: scale {self.new_interaction_zoom_out_factors[-1]}, center {self.new_interaction_centers}') | |
def initialize_from_trained_model_folder(self, model_training_output_dir: str,
                                         use_fold: Union[int, str] = None,
                                         checkpoint_name: str = 'checkpoint_final.pth'):
    """
    Set up the session from the output folder of a finished training.

    Reads inference_session_class.json, dataset.json and plans.json from the folder, loads the checkpoint of
    the requested fold, rebuilds the network and loads its weights.

    Args:
        model_training_output_dir: folder that contains the json files and the fold_X subfolders.
        use_fold: fold to load (int or 'all'). If None the folder must contain exactly one fold_ subfolder.
        checkpoint_name: file name of the checkpoint inside the fold folder.
    """
    # load trainer specific settings
    session_settings = load_json(join(model_training_output_dir, 'inference_session_class.json'))

    # so far this is not defined in that file so we stick with default
    use_etd = True

    if isinstance(session_settings, str):
        # old convention where we only specified the inference class in this file. Set defaults for stuff
        point_radius = 4
        self.preferred_scribble_thickness = [2, 2, 2]
        self.point_interaction = PointInteraction_stub(point_radius, use_etd)
        self.pad_mode_data = "constant"
        self.interaction_decay = 0.9
    else:
        point_radius = session_settings['point_radius']
        scribble_thickness = session_settings['preferred_scribble_thickness']
        if not isinstance(scribble_thickness, (tuple, list)):
            # a single number applies to all three axes
            scribble_thickness = [scribble_thickness] * 3
        self.preferred_scribble_thickness = scribble_thickness
        self.interaction_decay = session_settings.get('interaction_decay', 0.9)
        self.point_interaction = PointInteraction_stub(point_radius, use_etd)
        # padding mode for data. See nnInteractiveTrainerV2_nodelete_reflectpad
        self.pad_mode_data = session_settings.get('pad_mode_image', "constant")

    dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
    plans = load_json(join(model_training_output_dir, 'plans.json'))
    plans_manager = PlansManager(plans)

    if use_fold is None:
        # no fold requested: only unambiguous if there is exactly one
        fldrs = subdirs(model_training_output_dir, prefix='fold_', join=False)
        assert len(fldrs) == 1, f'Attempted to infer fold but there is != 1 fold_ folders: {fldrs}'
        fold_folder = fldrs[0]
    else:
        if use_fold != 'all':
            use_fold = int(use_fold)
        fold_folder = f'fold_{use_fold}'

    # NOTE(review): weights_only=False unpickles arbitrary objects. Only load checkpoints from trusted sources.
    checkpoint_file = join(model_training_output_dir, fold_folder, checkpoint_name)
    checkpoint = torch.load(checkpoint_file, map_location=self.device, weights_only=False)

    trainer_name = checkpoint['trainer_name']
    configuration_name = checkpoint['init_args']['configuration']
    parameters = checkpoint['network_weights']
    configuration_manager = plans_manager.get_configuration(configuration_name)

    # restore network
    num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
    trainer_class = recursive_find_python_class(join(nnInteractive.__path__[0], "trainer"),
                                                trainer_name, 'nnInteractive.trainer')
    if trainer_class is None:
        print(f'Unable to locate trainer class {trainer_name} in nnInteractive.trainer. '
              f'Please place it there (in any .py file)!')
        print('Attempting to use default nnInteractiveTrainer_stub. If you encounter errors, this is where you need to look!')
        trainer_class = nnInteractiveTrainer_stub

    num_output_heads = plans_manager.get_label_manager(dataset_json).num_segmentation_heads
    network = trainer_class.build_network_architecture(
        configuration_manager.network_arch_class_name,
        configuration_manager.network_arch_init_kwargs,
        configuration_manager.network_arch_init_kwargs_req_import,
        num_input_channels,
        num_output_heads,
        enable_deep_supervision=False
    ).to(self.device)
    network.load_state_dict(parameters)

    self.plans_manager = plans_manager
    self.configuration_manager = configuration_manager
    self.network = network
    self.dataset_json = dataset_json
    self.trainer_name = trainer_name
    self.label_manager = plans_manager.get_label_manager(dataset_json)

    already_compiled = isinstance(self.network, OptimizedModule)
    if self.use_torch_compile and not already_compiled:
        print('Using torch.compile')
        self.network = torch.compile(self.network)
def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
                          configuration_manager: ConfigurationManager,
                          dataset_json: dict, trainer_name: str):
    """
    Set up the session from objects that already exist in memory instead of loading them from disk.
    This is used by the nnUNetTrainer to initialize nnUNetPredictor for the final validation

    The network is compiled with torch.compile if the session requests it and unwrapped if it arrives
    compiled although the session does not use torch.compile. Finally it is moved to the session device.
    """
    self.plans_manager = plans_manager
    self.configuration_manager = configuration_manager
    self.network = network
    self.dataset_json = dataset_json
    self.trainer_name = trainer_name
    self.label_manager = plans_manager.get_label_manager(dataset_json)

    net = self.network
    if self.use_torch_compile:
        if not isinstance(net, OptimizedModule):
            print('Using torch.compile')
            net = torch.compile(net)
            self.network = net
    elif isinstance(net, OptimizedModule):
        # get the plain module back out of the compiled wrapper
        net = net._orig_mod
        self.network = net

    self.network = net.to(self.device)
def transform_coordinates_noresampling(
        coords_orig: Union[List[int], Tuple[int, ...]],
        nnunet_preprocessing_crop_bbox: List[Tuple[int, int]]
) -> Tuple[int, ...]:
    """
    Map coordinates from the original uncropped image into the internal cropped representation by
    subtracting the lower bound of the crop bounding box along every axis. No resampling is involved.

    Args:
        coords_orig: one coordinate per axis in the original image.
        nnunet_preprocessing_crop_bbox: one (lower, upper) pair per axis describing the crop.

    Returns:
        Tuple with the shifted coordinates, same length as coords_orig.
    """
    shifted = []
    for axis, coord in enumerate(coords_orig):
        lower_bound = nnunet_preprocessing_crop_bbox[axis][0]
        shifted.append(coord - lower_bound)
    return tuple(shifted)