# Provenance (residue of the Hugging Face Hub file page): uploaded by hanjang with
# "Upload folder using huggingface_hub", commit 24e5510 (verified).
from concurrent.futures import ThreadPoolExecutor
from os import cpu_count
from time import time
from typing import Union, List, Tuple, Optional
import numpy as np
import torch
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice, crop_and_pad_nd
from batchgenerators.utilities.file_and_folder_operations import load_json, join, subdirs
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.helpers import dummy_context, empty_cache
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from torch import nn
from torch._dynamo import OptimizedModule
from torch.nn.functional import interpolate
import nnInteractive
from nnInteractive.interaction.point import PointInteraction_stub
from nnInteractive.trainer.nnInteractiveTrainer import nnInteractiveTrainer_stub
from nnInteractive.utils.bboxes import generate_bounding_boxes
from nnInteractive.utils.crop import crop_and_pad_into_buffer, paste_tensor, pad_cropped, crop_to_valid
from nnInteractive.utils.erosion_dilation import iterative_3x3_same_padding_pool3d
from nnInteractive.utils.rounding import round_to_nearest_odd
class nnInteractiveInferenceSession():
    """
    Interactive segmentation session for nnInteractive models.

    Typical use: load a model (initialize_from_trained_model_folder or manual_initialization), set an image
    (set_image) and a target buffer (set_target_buffer), then add interactions (bbox, point, scribble, lasso, initial
    segmentation). Each interaction is written into a 7-channel interaction tensor
    (0: previous/initial seg, 1: bbox+, 2: bbox-, 3: point+, 4: point-, 5: scribble+, 6: scribble-) and, unless
    run_prediction=False, triggers a prediction that is written into the interaction tensor (channel 0) and the
    target buffer.
    """

    def __init__(self,
                 device: torch.device = torch.device('cuda'),
                 use_torch_compile: bool = False,
                 verbose: bool = False,
                 torch_n_threads: int = 8,
                 do_autozoom: bool = True,
                 use_pinned_memory: bool = True,
                 ):
        """
        Only intended to work with nnInteractiveTrainerV2 and its derivatives

        Args:
            device: device the network runs on. Pinned host memory is only used if this is a CUDA device.
            use_torch_compile: must be False (see the assertion below).
            verbose: print additional progress information.
            torch_n_threads: number of torch CPU threads; capped at the number of available CPUs.
            do_autozoom: if False, predictions are always made at zoom out factor 1 (no zooming out).
            use_pinned_memory: keep the preprocessed image and the interactions in pinned host memory (CUDA only).
        """
        # set as part of initialization
        assert use_torch_compile is False, ('This implementation places the preprocessed image and the interactions '
                                            'into pinned memory for speed reasons. This is incompatible with '
                                            'torch.compile because of inconsistent strides in the memory layout. '
                                            'Note to self: .contiguous() on GPU could be a solution. Unclear whether '
                                            'that will yield a benefit though.')
        self.network = None
        self.label_manager = None
        self.dataset_json = None
        self.trainer_name = None
        self.configuration_manager = None
        self.plans_manager = None
        self.use_pinned_memory = use_pinned_memory
        self.device = device
        self.use_torch_compile = use_torch_compile
        self.interaction_decay = None

        # image specific
        self.interactions: torch.Tensor = None
        self.preprocessed_image: torch.Tensor = None
        self.preprocessed_props = None
        self.target_buffer: Union[np.ndarray, torch.Tensor] = None

        # this will be set when loading the model (initialize_from_trained_model_folder)
        self.pad_mode_data = self.preferred_scribble_thickness = self.point_interaction = None

        self.verbose = verbose

        self.do_autozoom: bool = do_autozoom

        # os.cpu_count() returns None when the number of CPUs cannot be determined; min(n, None) would raise
        torch.set_num_threads(min(torch_n_threads, cpu_count() or torch_n_threads))

        self.original_image_shape = None

        # patch centers / zoom factors queued by newly added interactions; consumed (and cleared) by _predict
        self.new_interaction_zoom_out_factors: List[float] = []
        self.new_interaction_centers = []

        self.has_positive_bbox = False

        # Create a thread pool executor for background tasks.
        # this only takes care of preprocessing and interaction memory initialization so there is no need to give it
        # more than 2 workers
        self.executor = ThreadPoolExecutor(max_workers=2)

        self.preprocess_future = None
        self.interactions_future = None

    def set_image(self, image: np.ndarray, image_properties: dict = None):
        """
        Image must be 4D to satisfy nnU-Net needs: [c, x, y, z]
        Offload the processing to a background thread.

        All state of the previous image is discarded first (see _reset_session). image_properties is forwarded to
        the background job but is currently not used there.
        """
        if image_properties is None:
            image_properties = {}
        self._reset_session()
        assert image.ndim == 4, f'expected a 4d image as input, got {image.ndim}d. Shape {image.shape}'
        if self.verbose:
            print(f'Initialize with raw image shape {image.shape}')

        # Offload all image preprocessing to a background thread.
        self.preprocess_future = self.executor.submit(self._background_set_image, image, image_properties)
        self.original_image_shape = image.shape

    def _finish_preprocessing_and_initialize_interactions(self):
        """
        Block until both the image preprocessing and the interactions tensor initialization
        are finished.

        The interactions tensor is initialized from within the preprocessing job (which waits for it), so waiting
        on preprocess_future covers both. Any exception raised by the background job is re-raised here.
        """
        if self.preprocess_future is not None:
            # Wait for image preprocessing to complete.
            self.preprocess_future.result()
            del self.preprocess_future
            self.preprocess_future = None

    def set_target_buffer(self, target_buffer: Union[np.ndarray, torch.Tensor]):
        """
        Must be 3d numpy array or torch.Tensor

        Predictions are pasted into this buffer in original (uncropped) image coordinates.
        """
        self.target_buffer = target_buffer

    def set_do_autozoom(self, do_propagation: bool, max_num_patches: Optional[int] = None):
        """
        Enable or disable auto zoom for subsequent predictions.

        Args:
            do_propagation: new value for self.do_autozoom.
            max_num_patches: accepted for API compatibility; currently ignored.
        """
        self.do_autozoom = do_propagation

    def _reset_session(self):
        """
        Drop all state that belongs to the current image so that a new image can be set.

        A preprocessing job that is still running is waited for first. Otherwise it could finish after the reset and
        overwrite the new image's state. Two concurrently running preprocessing jobs could also occupy both workers of
        the executor while each waits for an interactions-initialization job queued behind them, which deadlocks.
        """
        if self.preprocess_future is not None:
            try:
                self.preprocess_future.result()
            except Exception as e:
                # the previous image is discarded anyway, so its failure must not break setting the new one
                print(f'Discarding failed preprocessing of previous image: {e!r}')
        self.interactions_future = None
        self.preprocess_future = None

        del self.preprocessed_image
        del self.target_buffer
        del self.interactions
        del self.preprocessed_props
        self.preprocessed_image = None
        self.target_buffer = None
        self.interactions = None
        self.preprocessed_props = None
        empty_cache(self.device)
        self.original_image_shape = None
        self.has_positive_bbox = False
        # pending patch centers are in the coordinate frame of the previous image
        self.new_interaction_centers = []
        self.new_interaction_zoom_out_factors = []

    def _initialize_interactions(self, image_torch: torch.Tensor):
        """
        Allocate the zero-filled float16 interaction tensor (7 channels, spatial shape of image_torch) on the CPU,
        pinned if running on CUDA with use_pinned_memory.
        """
        if self.verbose:
            print(f'Initialize interactions. Pinned: {self.use_pinned_memory}')

        # Create the interaction tensor based on the target shape.
        self.interactions = torch.zeros(
            (7, *image_torch.shape[1:]),
            device='cpu',
            dtype=torch.float16,
            pin_memory=(self.device.type == 'cuda' and self.use_pinned_memory)
        )

    def _background_set_image(self, image: np.ndarray, image_properties: dict):
        """
        Background job started by set_image: crop the image to its nonzero bounding box, z-score normalize it
        (float32), optionally pin it, and initialize the interaction tensor for the cropped shape.

        Sets self.preprocessed_image and self.preprocessed_props['bbox_used_for_cropping'] (the spatial part of the
        crop bbox). image_properties is currently unused.
        """
        # Convert and clone the image tensor.
        image_torch = torch.clone(torch.from_numpy(image))

        # Crop to nonzero region.
        if self.verbose:
            print('Cropping input image to nonzero region')
        nonzero_idx = torch.where(image_torch != 0)
        if nonzero_idx[0].numel() > 0:
            # Create bounding box: for each dimension, get the min and max (plus one) of the nonzero indices.
            bbox = [[i.min().item(), i.max().item() + 1] for i in nonzero_idx]
        else:
            # all-zero image: nothing to crop away (min()/max() of an empty tensor would raise), keep full extent
            bbox = [[0, s] for s in image_torch.shape]
        del nonzero_idx

        slicer = bounding_box_to_slice(bbox)  # Assuming this returns a tuple of slices.
        image_torch = image_torch[slicer].float()
        if self.verbose:
            print(f'Cropped image shape: {image_torch.shape}')

        # As soon as we have the target shape, start initializing the interaction tensor in its own thread.
        self.interactions_future = self.executor.submit(self._initialize_interactions, image_torch)

        # Normalize the cropped image.
        if self.verbose:
            print('Normalizing cropped image')
        image_torch -= image_torch.mean()
        std = image_torch.std().item()
        # same guard as nnU-Net's z-score normalization: a constant image has std == 0 (a single voxel even gives
        # NaN), which would otherwise fill the image with inf/NaN. `not >` also catches NaN.
        if not std > 1e-8:
            std = 1e-8
        image_torch /= std

        self.preprocessed_image = image_torch
        if self.use_pinned_memory and self.device.type == 'cuda':
            if self.verbose:
                print('Pin memory: image')
            # Note: pin_memory() in PyTorch typically returns a new tensor.
            self.preprocessed_image = self.preprocessed_image.pin_memory()

        self.preprocessed_props = {'bbox_used_for_cropping': bbox[1:]}
        # wait here so that self.interactions is guaranteed to exist once preprocess_future has resolved
        self.interactions_future.result()
        del self.interactions_future
        self.interactions_future = None

    def reset_interactions(self):
        """
        Use this to reset all interactions and start from scratch for the current image. This includes the initial
        segmentation!

        Zeroes the interaction tensor and the target buffer (if set) and drops any queued prediction patches.
        """
        if self.interactions is not None:
            self.interactions.fill_(0)

        if self.target_buffer is not None:
            if isinstance(self.target_buffer, np.ndarray):
                self.target_buffer.fill(0)
            elif isinstance(self.target_buffer, torch.Tensor):
                self.target_buffer.zero_()
        # patches queued by interactions added with run_prediction=False refer to interactions just cleared
        self.new_interaction_centers = []
        self.new_interaction_zoom_out_factors = []
        empty_cache(self.device)
        self.has_positive_bbox = False

    def add_bbox_interaction(self, bbox_coords, include_interaction: bool, run_prediction: bool = True) -> None:
        """
        Place a bounding box interaction and (by default) run a prediction.

        Args:
            bbox_coords: per spatial dimension a [lower, upper] pair in original (uncropped) image coordinates; the
                upper value is used as an exclusive slice end.
            include_interaction: True writes into the positive bbox channel (-6), False into the negative one (-5).
            run_prediction: run self._predict() after placing the box.
        """
        if include_interaction:
            self.has_positive_bbox = True

        self._finish_preprocessing_and_initialize_interactions()

        lbs_transformed = [round(i) for i in transform_coordinates_noresampling([i[0] for i in bbox_coords],
                                                                             self.preprocessed_props['bbox_used_for_cropping'])]
        ubs_transformed = [round(i) for i in transform_coordinates_noresampling([i[1] for i in bbox_coords],
                                                                             self.preprocessed_props['bbox_used_for_cropping'])]
        transformed_bbox_coordinates = [[i, j] for i, j in zip(lbs_transformed, ubs_transformed)]

        if self.verbose:
            print(f'Added bounding box coordinates.\n'
                  f'Raw: {bbox_coords}\n'
                  f'Transformed: {transformed_bbox_coordinates}\n'
                  f"Crop Bbox: {self.preprocessed_props['bbox_used_for_cropping']}")

        # Prevent collapsed bounding boxes and clip to image shape
        image_shape = self.preprocessed_image.shape  # channel first; spatial sizes start at index 1

        for dim, (transformed_start, transformed_end) in enumerate(transformed_bbox_coordinates):
            size = image_shape[dim + 1]  # +1 to skip channel dim

            # Clip to image boundaries. The start is clipped from above too: a box lying entirely beyond the upper
            # border would otherwise end up inverted (start > end) and place nothing.
            transformed_start = min(max(0, transformed_start), size - 1)
            transformed_end = min(size, transformed_end)

            # Ensure the bounding box does not collapse to a single point
            if transformed_end <= transformed_start:
                if transformed_start == 0:
                    transformed_end = min(1, size)
                else:
                    transformed_start = max(transformed_start - 1, 0)

            transformed_bbox_coordinates[dim] = [transformed_start, transformed_end]

        if self.verbose:
            print(f'Bbox coordinates after clip to image boundaries and preventing dim collapse:\n'
                  f'Bbox: {transformed_bbox_coordinates}\n'
                  f'Internal image shape: {self.preprocessed_image.shape}')

        self._add_patch_for_bbox_interaction(transformed_bbox_coordinates)

        # decay old interactions
        self.interactions[-6:-4] *= self.interaction_decay

        # place bbox
        slicer = tuple([slice(*i) for i in transformed_bbox_coordinates])
        channel = -6 if include_interaction else -5
        self.interactions[(channel, *slicer)] = 1

        # forward pass
        if run_prediction:
            self._predict()

    def add_point_interaction(self, coordinates: Tuple[int, ...], include_interaction: bool, run_prediction: bool = True):
        """
        Place a point interaction and (by default) run a prediction.

        Args:
            coordinates: point in original (uncropped) image coordinates.
            include_interaction: True writes into the positive point channel (-4), False into the negative one (-3).
            run_prediction: run self._predict() after placing the point.
        """
        self._finish_preprocessing_and_initialize_interactions()

        transformed_coordinates = [round(i) for i in transform_coordinates_noresampling(coordinates,
                                                                                     self.preprocessed_props['bbox_used_for_cropping'])]

        self._add_patch_for_point_interaction(transformed_coordinates)

        # decay old interactions
        self.interactions[-4:-2] *= self.interaction_decay

        interaction_channel = -4 if include_interaction else -3
        self.interactions[interaction_channel] = self.point_interaction.place_point(
            transformed_coordinates, self.interactions[interaction_channel])
        if run_prediction:
            self._predict()

    def add_scribble_interaction(self, scribble_image: np.ndarray, include_interaction: bool, run_prediction: bool = True):
        """
        Add a scribble interaction and (by default) run a prediction.

        Args:
            scribble_image: array with the spatial shape of the original image; nonzero where the scribble is.
            include_interaction: True merges into the positive scribble channel (-2), False into the negative (-1).
            run_prediction: run self._predict() afterwards.
        """
        assert all([i == j for i, j in zip(self.original_image_shape[1:], scribble_image.shape)]), f'Given scribble image must match input image shape. Input image was: {self.original_image_shape[1:]}, given: {scribble_image.shape}'
        self._finish_preprocessing_and_initialize_interactions()

        scribble_image = torch.from_numpy(scribble_image)

        # crop (as in preprocessing)
        scribble_image = crop_and_pad_nd(scribble_image, self.preprocessed_props['bbox_used_for_cropping'])

        self._add_patch_for_scribble_interaction(scribble_image)

        # decay old interactions
        self.interactions[-2:] *= self.interaction_decay

        interaction_channel = -2 if include_interaction else -1
        # merge (elementwise max) with the scribbles already present in this channel
        torch.maximum(self.interactions[interaction_channel], scribble_image.to(self.interactions.device),
                      out=self.interactions[interaction_channel])
        del scribble_image
        empty_cache(self.device)
        if run_prediction:
            self._predict()

    def add_lasso_interaction(self, lasso_image: np.ndarray, include_interaction: bool, run_prediction: bool = True):
        """
        Add a lasso interaction and (by default) run a prediction. Lassos share the bbox channels.

        Args:
            lasso_image: array with the spatial shape of the original image; nonzero inside the lasso.
            include_interaction: True merges into the positive bbox channel (-6), False into the negative (-5).
            run_prediction: run self._predict() afterwards.
        """
        assert all([i == j for i, j in zip(self.original_image_shape[1:], lasso_image.shape)]), f'Given lasso image must match input image shape. Input image was: {self.original_image_shape[1:]}, given: {lasso_image.shape}'
        self._finish_preprocessing_and_initialize_interactions()

        lasso_image = torch.from_numpy(lasso_image)

        # crop (as in preprocessing)
        lasso_image = crop_and_pad_nd(lasso_image, self.preprocessed_props['bbox_used_for_cropping'])

        self._add_patch_for_lasso_interaction(lasso_image)

        # decay old interactions
        self.interactions[-6:-4] *= self.interaction_decay

        # lasso is written into bbox channel
        interaction_channel = -6 if include_interaction else -5
        torch.maximum(self.interactions[interaction_channel], lasso_image.to(self.interactions.device),
                      out=self.interactions[interaction_channel])
        del lasso_image
        empty_cache(self.device)
        if run_prediction:
            self._predict()

    def add_initial_seg_interaction(self, initial_seg: np.ndarray, run_prediction: bool = False):
        """
        WARNING THIS WILL RESET INTERACTIONS!

        Write initial_seg (spatial shape of the original image) into the target buffer and, cropped, into the
        previous-segmentation channel (-7). A prediction patch is only queued (and run) if run_prediction is True.
        """
        assert all([i == j for i, j in zip(self.original_image_shape[1:], initial_seg.shape)]), f'Given initial seg must match input image shape. Input image was: {self.original_image_shape[1:]}, given: {initial_seg.shape}'

        self._finish_preprocessing_and_initialize_interactions()

        self.reset_interactions()

        if isinstance(self.target_buffer, np.ndarray):
            self.target_buffer[:] = initial_seg

        initial_seg = torch.from_numpy(initial_seg)

        if isinstance(self.target_buffer, torch.Tensor):
            self.target_buffer[:] = initial_seg

        # crop (as in preprocessing)
        initial_seg = crop_and_pad_nd(initial_seg, self.preprocessed_props['bbox_used_for_cropping'])

        # initial seg is written into initial seg buffer
        interaction_channel = -7
        self.interactions[interaction_channel] = initial_seg
        empty_cache(self.device)
        if run_prediction:
            self._add_patch_for_initial_seg_interaction(initial_seg)
            del initial_seg
            self._predict()
        else:
            del initial_seg

    @torch.inference_mode()
    def _predict(self):
        """
        This function is a smoking mess to read. This is deliberate. Initially it was super pretty and easy to
        understand. Then the run time optimization began.
        If it feels like we are excessively transferring tensors between CPU and GPU, this is deliberate as well.
        Our goal is to keep this tool usable even for people with smaller GPUs (8-10GB VRAM). In an ideal world
        everyone would have 24GB+ of VRAM and all tensors would live on GPU all the time.
        The amount of hours spent optimizing this function is substantial. Almost every line was turned and twisted
        multiple times. If something appears odd, it is probably so for a reason. Don't change things all willy nilly
        without first understanding what is going on. And don't make changes without verifying that the run time or
        VRAM consumption is not adversely affected.

        For every queued interaction patch: predict at the queued zoom out factor (capped at 4), zoom out further
        by 1.5x while the prediction changes at the patch borders, then refine the changed region at full
        resolution with patch-sized boxes.

        Returns:
            None. Results are written into self.interactions[0] and self.target_buffer; the queue of new
            interaction centers / zoom factors is cleared.
        """
        assert self.pad_mode_data == 'constant', 'pad modes other than constant are not implemented here'

        assert len(self.new_interaction_centers) == len(self.new_interaction_zoom_out_factors)
        if len(self.new_interaction_centers) > 1:
            print('It seems like more than one interaction was added since the last prediction. This is not '
                  'recommended and may cause unexpected behavior or inefficient predictions')

        start_predict = time()
        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            for prediction_center, initial_zoom_out_factor in zip(self.new_interaction_centers, self.new_interaction_zoom_out_factors):
                # make a prediction at initial zoom out factor. If more zoom out is required, do this until the
                # entire object fits the FOV. Then go back to original resolution and refine.

                # we need this later.
                previous_prediction = torch.clone(self.interactions[0])

                if not self.do_autozoom:
                    initial_zoom_out_factor = 1

                initial_zoom_out_factor = min(initial_zoom_out_factor, 4)
                zoom_out_factor = initial_zoom_out_factor
                max_zoom_out_factor = initial_zoom_out_factor

                start_autozoom = time()
                while zoom_out_factor is not None and zoom_out_factor <= 4:
                    print('Performing prediction at zoom out factor', zoom_out_factor)
                    max_zoom_out_factor = max(max_zoom_out_factor, zoom_out_factor)
                    # initial prediction at initial_zoom_out_factor
                    scaled_patch_size = [round(i * zoom_out_factor) for i in self.configuration_manager.patch_size]
                    scaled_bbox = [[c - p // 2, c + p // 2 + p % 2] for c, p in zip(prediction_center, scaled_patch_size)]

                    crop_img, pad = crop_to_valid(self.preprocessed_image, scaled_bbox)
                    crop_img = crop_img.to(self.device, non_blocking=self.device.type == 'cuda')
                    crop_interactions, pad_interaction = crop_to_valid(self.interactions, scaled_bbox)

                    # resize input_for_predict (which may be larger than patch size) to patch size
                    # this implementation may not seem straightforward but it does save VRAM which is crucial here
                    if not all([i == j for i, j in zip(self.configuration_manager.patch_size, scaled_patch_size)]):
                        crop_interactions_resampled_gpu = torch.empty((7, *self.configuration_manager.patch_size), dtype=torch.float16, device=self.device)

                        # previous seg, bbox+, bbox-
                        for i in range(0, 3):
                            # this is area for a reason but I aint telling ya why
                            if any([x for y in pad_interaction for x in y]):
                                tmp = pad_cropped(crop_interactions[i].to(self.device, non_blocking=self.device.type == 'cuda'), pad_interaction)
                            else:
                                tmp = crop_interactions[i].to(self.device)
                            crop_interactions_resampled_gpu[i] = interpolate(tmp[None, None], self.configuration_manager.patch_size, mode='area')[0][0]
                        empty_cache(self.device)

                        max_pool_ks = round_to_nearest_odd(zoom_out_factor * 2 - 1)
                        # point+, point-, scribble+, scribble-
                        for i in range(3, 7):
                            if any([x for y in pad_interaction for x in y]):
                                tmp = pad_cropped(crop_interactions[i].to(self.device, non_blocking=self.device.type == 'cuda'), pad_interaction)
                            else:
                                tmp = crop_interactions[i].to(self.device, non_blocking=self.device.type == 'cuda')

                            if max_pool_ks > 1:
                                # dilate to preserve interactions after downsampling
                                tmp = iterative_3x3_same_padding_pool3d(tmp[None, None], max_pool_ks)[0, 0]
                            # this is 'area' for a reason but I aint telling ya why
                            crop_interactions_resampled_gpu[i] = interpolate(tmp[None, None], self.configuration_manager.patch_size, mode='area')[0][0]
                        del tmp

                        crop_img = interpolate(pad_cropped(crop_img, pad)[None] if any([x for y in pad_interaction for x in y]) else crop_img[None], self.configuration_manager.patch_size, mode='trilinear')[0]
                        crop_interactions = crop_interactions_resampled_gpu

                        del crop_interactions_resampled_gpu
                        empty_cache(self.device)
                    else:
                        # crop_img is already on device
                        crop_img = pad_cropped(crop_img, pad) if any([x for y in pad_interaction for x in y]) else crop_img
                        crop_interactions = pad_cropped(crop_interactions.to(self.device, non_blocking=self.device.type == 'cuda'), pad_interaction) if any([x for y in pad_interaction for x in y]) else crop_interactions.to(self.device, non_blocking=self.device.type == 'cuda')

                    input_for_predict = torch.cat((crop_img, crop_interactions))
                    del crop_img, crop_interactions

                    pred = self.network(input_for_predict[None])[0].argmax(0).detach()

                    del input_for_predict

                    # detect changes at borders
                    previous_zoom_prediction = crop_and_pad_nd(self.interactions[0], scaled_bbox).to(self.device, non_blocking=self.device.type == 'cuda')
                    if not all([i == j for i, j in zip(pred.shape, previous_zoom_prediction.shape)]):
                        previous_zoom_prediction = interpolate(previous_zoom_prediction[None, None].to(float), pred.shape, mode='nearest')[0, 0]

                    abs_pxl_change_threshold = 1500
                    rel_pxl_change_threshold = 0.2
                    min_pxl_change_threshold = 100
                    continue_zoom = False
                    if zoom_out_factor < 4 and self.do_autozoom:
                        # compare first and last slice along every axis with the previous state
                        for dim in range(len(scaled_bbox)):
                            if continue_zoom:
                                break
                            for idx in [0, pred.shape[dim] - 1]:
                                slice_prev = previous_zoom_prediction.index_select(dim, torch.tensor(idx, device=self.device))
                                slice_curr = pred.index_select(dim, torch.tensor(idx, device=self.device))
                                pixels_prev = torch.sum(slice_prev)
                                pixels_current = torch.sum(slice_curr)
                                pixels_diff = torch.sum(slice_prev != slice_curr)
                                rel_change = max(pixels_prev, pixels_current) / max(min(pixels_prev, pixels_current),
                                                                                    1e-5) - 1
                                if pixels_diff > abs_pxl_change_threshold:
                                    continue_zoom = True
                                    if self.verbose:
                                        print(f'continue zooming because change at borders of {pixels_diff} > {abs_pxl_change_threshold}')
                                    break
                                if pixels_diff > min_pxl_change_threshold and rel_change > rel_pxl_change_threshold:
                                    continue_zoom = True
                                    if self.verbose:
                                        print(f'continue zooming because relative change of {rel_change} > {rel_pxl_change_threshold} and n_pixels {pixels_diff} > {min_pxl_change_threshold}')
                                    break
                                del slice_prev, slice_curr, pixels_prev, pixels_current, pixels_diff

                    del previous_zoom_prediction

                    # resize prediction to correct size and place in target buffer + interactions
                    if not all([i == j for i, j in zip(pred.shape, scaled_patch_size)]):
                        pred = (interpolate(pred[None, None].to(float), scaled_patch_size, mode='trilinear')[0, 0] >= 0.5).to(torch.uint8)

                    # if we do not continue zooming we need a difference map for sampling patches
                    if not continue_zoom and zoom_out_factor > 1:
                        # wow this circus saves ~30ms relative to naive implementation
                        previous_prediction = previous_prediction.to(self.device, non_blocking=self.device.type == 'cuda')
                        seen_bbox = [[max(0, i[0]), min(i[1], s)] for i, s in zip(scaled_bbox, previous_prediction.shape)]
                        bbox_tmp = [[i[0] - s[0], i[1] - s[0]] for i, s in zip(seen_bbox, scaled_bbox)]
                        bbox_tmp = [[max(0, i[0]), min(i[1], s)] for i, s in zip(bbox_tmp, scaled_patch_size)]
                        slicer = bounding_box_to_slice(seen_bbox)
                        slicer2 = bounding_box_to_slice(bbox_tmp)
                        diff_map = pred[slicer2] != previous_prediction[slicer]
                        # dont allocate new memory, just reuse previous_prediction. We don't need it anymore
                        previous_prediction.zero_()
                        diff_map = paste_tensor(previous_prediction, diff_map, seen_bbox)

                        # open the difference map to keep computational load in check (fewer refinement boxes)
                        # open distance map
                        diff_map[slicer] = iterative_3x3_same_padding_pool3d(diff_map[slicer][None, None], kernel_size=5, use_min_pool=True)[0, 0]
                        diff_map[slicer] = iterative_3x3_same_padding_pool3d(diff_map[slicer][None, None], kernel_size=5, use_min_pool=False)[0, 0]

                        has_diff = torch.any(diff_map[slicer])

                        del previous_prediction
                    else:
                        has_diff = False

                    if zoom_out_factor == 1 or (not continue_zoom and has_diff):  # rare case where no changes are needed because of useless interaction. Need to check for not continue_zoom because otherwise diff_map won't exist
                        pred = pred.cpu()

                        if zoom_out_factor == 1:
                            paste_tensor(self.interactions[0], pred.half(), scaled_bbox)
                        else:
                            seen_bbox = [[max(0, i[0]), min(i[1], s)] for i, s in
                                         zip(scaled_bbox, diff_map.shape)]
                            bbox_tmp = [[i[0] - s[0], i[1] - s[0]] for i, s in zip(seen_bbox, scaled_bbox)]
                            bbox_tmp = [[max(0, i[0]), min(i[1], s)] for i, s in zip(bbox_tmp, scaled_patch_size)]
                            slicer = bounding_box_to_slice(seen_bbox)
                            slicer2 = bounding_box_to_slice(bbox_tmp)
                            mask = (diff_map[slicer] > 0).cpu()
                            self.interactions[0][slicer][mask] = pred[slicer2][mask].half()

                        # place into target buffer
                        bbox = [[i[0] + bbc[0], i[1] + bbc[0]] for i, bbc in
                                zip(scaled_bbox, self.preprocessed_props['bbox_used_for_cropping'])]
                        paste_tensor(self.target_buffer, pred, bbox)

                    del pred
                    empty_cache(self.device)

                    if continue_zoom:
                        zoom_out_factor *= 1.5
                        zoom_out_factor = min(4, zoom_out_factor)
                    else:
                        zoom_out_factor = None
                end = time()
                print(f'Auto zoom stage took {round(end - start_autozoom, ndigits=3)}s. Max zoom out factor was {max_zoom_out_factor}')

                if max_zoom_out_factor > 1 and has_diff:
                    start_refinement = time()
                    # only use the region that was previously looked at. Use last scaled_bbox
                    if self.has_positive_bbox:
                        # mask positive bbox channel with current segmentation to avoid bbox nonsense.
                        # Basically convert bbox to pseudo lasso
                        pos_bbox_idx = -6
                        self.interactions[pos_bbox_idx][(~(self.interactions[0] > 0.5)).cpu()] = 0
                        self.has_positive_bbox = False

                    bboxes_ordered = generate_bounding_boxes(diff_map, self.configuration_manager.patch_size, stride='auto', margin=(10, 10, 10), max_depth=3)

                    del diff_map
                    empty_cache(self.device)

                    if self.verbose:
                        print(f'Using {len(bboxes_ordered)} bounding boxes for refinement')

                    # channel 0: image, channels 1-7: interactions
                    preallocated_input = torch.zeros((8, *self.configuration_manager.patch_size), device=self.device, dtype=torch.float)
                    for nref, refinement_bbox in enumerate(bboxes_ordered):
                        assert self.pad_mode_data == 'constant'
                        crop_and_pad_into_buffer(preallocated_input[0], refinement_bbox, self.preprocessed_image[0])
                        crop_and_pad_into_buffer(preallocated_input[1:], refinement_bbox, self.interactions)

                        pred = self.network(preallocated_input[None])[0].argmax(0).detach().cpu()

                        paste_tensor(self.interactions[0], pred, refinement_bbox)
                        # place into target buffer
                        bbox = [[i[0] + bbc[0], i[1] + bbc[0]] for i, bbc in zip(refinement_bbox, self.preprocessed_props['bbox_used_for_cropping'])]
                        paste_tensor(self.target_buffer, pred, bbox)
                        del pred
                        preallocated_input.zero_()
                    del preallocated_input
                    empty_cache(self.device)
                    end_refinement = time()
                    print(f'Took {round(end_refinement - start_refinement, 3)} s for refining the segmentation with {len(bboxes_ordered)} bounding boxes')
                else:
                    print('No refinement necessary')
        print(f'Done. Total time {round(time() - start_predict, 3)}s')

        self.new_interaction_centers = []
        self.new_interaction_zoom_out_factors = []
        empty_cache(self.device)

    def _add_patch_for_point_interaction(self, coordinates):
        """Queue a prediction patch centered on a point at zoom out factor 1."""
        self.new_interaction_zoom_out_factors.append(1)
        self.new_interaction_centers.append(coordinates)
        print(f'Added new point interaction: scale {self.new_interaction_zoom_out_factors[-1]}, center {self.new_interaction_centers[-1]}')

    def _add_patch_for_bbox_interaction(self, bbox):
        """
        Queue a prediction patch centered on the bbox, zoomed out far enough that the box plus a context margin of
        patch_size // 3 fits the network patch (factor >= 1).
        """
        bbox_center = [round((i[0] + i[1]) / 2) for i in bbox]
        bbox_size = [i[1] - i[0] for i in bbox]
        # we want to see some context, so the crop we see for the initial prediction should be patch_size / 3 larger
        requested_size = [i + j // 3 for i, j in zip(bbox_size, self.configuration_manager.patch_size)]
        self.new_interaction_zoom_out_factors.append(max(1, max([i / j for i, j in zip(requested_size, self.configuration_manager.patch_size)])))
        self.new_interaction_centers.append(bbox_center)
        print(f'Added new bbox interaction: scale {self.new_interaction_zoom_out_factors[-1]}, center {self.new_interaction_centers[-1]}')

    def _add_patch_for_scribble_interaction(self, scribble_image):
        """Queue a prediction patch covering the nonzero region of the (cropped) scribble image."""
        return self._generic_add_patch_from_image(scribble_image)

    def _add_patch_for_lasso_interaction(self, lasso_image):
        """Queue a prediction patch covering the nonzero region of the (cropped) lasso image."""
        return self._generic_add_patch_from_image(lasso_image)

    def _add_patch_for_initial_seg_interaction(self, initial_seg):
        """Queue a prediction patch covering the nonzero region of the (cropped) initial segmentation."""
        return self._generic_add_patch_from_image(initial_seg)

    def _generic_add_patch_from_image(self, image: torch.Tensor):
        """
        Queue a prediction patch centered on the bounding box of the nonzero voxels of image, zoomed out far enough
        that this box plus a context margin of patch_size // 3 fits the network patch. Does nothing (besides a
        message) if image is all zero.
        """
        if not torch.any(image):
            print('Received empty image prompt. Cannot add patches for prediction')
            return
        nonzero_indices = torch.nonzero(image, as_tuple=False)
        mn = torch.min(nonzero_indices, dim=0)[0]
        mx = torch.max(nonzero_indices, dim=0)[0]
        roi = [[i.item(), x.item() + 1] for i, x in zip(mn, mx)]
        roi_center = [round((i[0] + i[1]) / 2) for i in roi]
        roi_size = [i[1] - i[0] for i in roi]
        requested_size = [i + j // 3 for i, j in zip(roi_size, self.configuration_manager.patch_size)]
        self.new_interaction_zoom_out_factors.append(max(1, max([i / j for i, j in zip(requested_size, self.configuration_manager.patch_size)])))
        self.new_interaction_centers.append(roi_center)
        print(f'Added new image interaction: scale {self.new_interaction_zoom_out_factors[-1]}, center {self.new_interaction_centers[-1]}')

    def initialize_from_trained_model_folder(self, model_training_output_dir: str,
                                             use_fold: Union[int, str] = None,
                                             checkpoint_name: str = 'checkpoint_final.pth'):
        """
        This is used when making predictions with a trained model

        Reads inference_session_class.json, dataset.json and plans.json from model_training_output_dir, loads the
        checkpoint from fold_<use_fold> (or the single fold_* folder if use_fold is None) and builds the network.
        """
        # load trainer specific settings
        expected_json_file = join(model_training_output_dir, 'inference_session_class.json')
        json_content = load_json(expected_json_file)
        if isinstance(json_content, str):
            # old convention where we only specified the inference class in this file. Set defaults for stuff
            point_interaction_radius = 4
            point_interaction_use_etd = True
            self.preferred_scribble_thickness = [2, 2, 2]
            self.point_interaction = PointInteraction_stub(
                point_interaction_radius,
                point_interaction_use_etd)
            self.pad_mode_data = "constant"
            self.interaction_decay = 0.9
        else:
            point_interaction_radius = json_content['point_radius']
            self.preferred_scribble_thickness = json_content['preferred_scribble_thickness']
            if not isinstance(self.preferred_scribble_thickness, (tuple, list)):
                self.preferred_scribble_thickness = [self.preferred_scribble_thickness] * 3
            self.interaction_decay = json_content['interaction_decay'] if 'interaction_decay' in json_content.keys() else 0.9
            point_interaction_use_etd = True  # so far this is not defined in that file so we stick with default
            self.point_interaction = PointInteraction_stub(point_interaction_radius, point_interaction_use_etd)
            # padding mode for data. See nnInteractiveTrainerV2_nodelete_reflectpad
            self.pad_mode_data = json_content['pad_mode_image'] if 'pad_mode_image' in json_content.keys() else "constant"

        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
        plans = load_json(join(model_training_output_dir, 'plans.json'))
        plans_manager = PlansManager(plans)

        if use_fold is not None:
            use_fold = int(use_fold) if use_fold != 'all' else use_fold
            fold_folder = f'fold_{use_fold}'
        else:
            fldrs = subdirs(model_training_output_dir, prefix='fold_', join=False)
            assert len(fldrs) == 1, f'Attempted to infer fold but there is != 1 fold_ folders: {fldrs}'
            fold_folder = fldrs[0]

        # NOTE(security): weights_only=False unpickles arbitrary objects - only load checkpoints from trusted sources
        checkpoint = torch.load(join(model_training_output_dir, fold_folder, checkpoint_name),
                                map_location=self.device, weights_only=False)
        trainer_name = checkpoint['trainer_name']
        configuration_name = checkpoint['init_args']['configuration']

        parameters = checkpoint['network_weights']

        configuration_manager = plans_manager.get_configuration(configuration_name)

        # restore network
        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
        trainer_class = recursive_find_python_class(join(nnInteractive.__path__[0], "trainer"),
                                                    trainer_name, 'nnInteractive.trainer')
        if trainer_class is None:
            print(f'Unable to locate trainer class {trainer_name} in nnInteractive.trainer. '
                  f'Please place it there (in any .py file)!')
            print('Attempting to use default nnInteractiveTrainer_stub. If you encounter errors, this is where you need to look!')
            trainer_class = nnInteractiveTrainer_stub

        network = trainer_class.build_network_architecture(
            configuration_manager.network_arch_class_name,
            configuration_manager.network_arch_init_kwargs,
            configuration_manager.network_arch_init_kwargs_req_import,
            num_input_channels,
            plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
            enable_deep_supervision=False
        ).to(self.device)
        network.load_state_dict(parameters)

        self.plans_manager = plans_manager
        self.configuration_manager = configuration_manager
        self.network = network
        self.dataset_json = dataset_json
        self.trainer_name = trainer_name
        self.label_manager = plans_manager.get_label_manager(dataset_json)
        if self.use_torch_compile and not isinstance(self.network, OptimizedModule):
            print('Using torch.compile')
            self.network = torch.compile(self.network)

    def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
                              configuration_manager: ConfigurationManager,
                              dataset_json: dict, trainer_name: str):
        """
        This is used by the nnUNetTrainer to initialize nnUNetPredictor for the final validation

        Compiles the network if use_torch_compile is set, unwraps an already compiled network otherwise, and moves
        it to self.device.
        """
        self.plans_manager = plans_manager
        self.configuration_manager = configuration_manager
        self.network = network
        self.dataset_json = dataset_json
        self.trainer_name = trainer_name
        self.label_manager = plans_manager.get_label_manager(dataset_json)

        if self.use_torch_compile and not isinstance(self.network, OptimizedModule):
            print('Using torch.compile')
            self.network = torch.compile(self.network)

        if not self.use_torch_compile and isinstance(self.network, OptimizedModule):
            self.network = self.network._orig_mod

        self.network = self.network.to(self.device)
def transform_coordinates_noresampling(
        coords_orig: Union[List[int], Tuple[int, ...]],
        nnunet_preprocessing_crop_bbox: List[Tuple[int, int]]
) -> Tuple[int, ...]:
    """
    Map coordinates from the original (uncropped) image into nnU-Net's internal crop-to-nonzero representation.

    No resampling is involved, so the mapping is a pure shift: every coordinate has the lower bound of the crop
    bounding box along the same axis subtracted. Only the first len(coords_orig) axes of the crop bbox are used.

    Args:
        coords_orig: one coordinate per spatial axis in original image space (non-integers pass through as-is).
        nnunet_preprocessing_crop_bbox: per axis a (lower, upper) pair; only the lower bound matters.

    Returns:
        Tuple of shifted coordinates with the same length as coords_orig.
    """
    # per-axis shift = lower crop bound; raises IndexError if the bbox has fewer axes than coords_orig
    offsets = [nnunet_preprocessing_crop_bbox[axis][0] for axis in range(len(coords_orig))]
    return tuple(coord - offset for coord, offset in zip(coords_orig, offsets))