from typing import Union, Tuple, List

from batchgenerators.dataloading.data_loader import DataLoader
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
from nnunetv2.utilities.label_handling.label_handling import LabelManager


class nnUNetDataLoaderBase(DataLoader):
    """Shared machinery for nnU-Net patch data loaders.

    Provides shape bookkeeping (``determine_shapes``), the per-sample foreground-oversampling decision
    (``get_do_oversample``) and patch bounding-box sampling (``get_bbox``). Batch assembly itself is not
    implemented here (presumably done by the 2D/3D subclasses -- confirm).
    """

    def __init__(self,
                 data: nnUNetDataset,
                 batch_size: int,
                 patch_size: Union[List[int], Tuple[int, ...], np.ndarray],
                 final_patch_size: Union[List[int], Tuple[int, ...], np.ndarray],
                 label_manager: LabelManager,
                 oversample_foreground_percent: float = 0.0,
                 sampling_probabilities: Union[List[int], Tuple[int, ...], np.ndarray] = None,
                 pad_sides: Union[List[int], Tuple[int, ...], np.ndarray] = None,
                 probabilistic_oversampling: bool = False):
        """
        :param data: dataset to sample from; must be an nnUNetDataset.
        :param batch_size: number of samples per batch.
        :param patch_size: spatial size of the patches that are cropped from the data.
        :param final_patch_size: spatial size the network will eventually receive. The difference to patch_size
            determines how much padding (need_to_pad) is allowed around the image.
        :param label_manager: provides all_labels (used as the "annotated" key) and has_ignore_label.
        :param oversample_foreground_percent: fraction of samples that should be forced to contain foreground.
        :param sampling_probabilities: forwarded to DataLoader and also stored on self.
        :param pad_sides: optional extra padding added on top of patch_size - final_patch_size.
        :param probabilistic_oversampling: if True, each sample is independently forced to foreground with
            probability oversample_foreground_percent; otherwise the last samples of each batch are forced.
        """
        # Positional args forwarded to batchgenerators' DataLoader. Presumably they map to
        # (num_threads_in_multithreaded=1, seed_for_shuffle=None, return_incomplete=True, shuffle=False,
        # infinite=True) -- TODO confirm against the installed batchgenerators version.
        super().__init__(data, batch_size, 1, None, True, False, True, sampling_probabilities)
        assert isinstance(data, nnUNetDataset), 'nnUNetDataLoaderBase only supports dictionaries as data'
        self.indices = list(data.keys())

        self.oversample_foreground_percent = oversample_foreground_percent
        self.final_patch_size = final_patch_size
        self.patch_size = patch_size
        self.list_of_keys = list(self._data.keys())
        # need_to_pad denotes by how much we need to pad the data so that if we sample a patch of size
        # final_patch_size (which is what the network will get) these patches will also cover the border of
        # the images
        self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int)
        if pad_sides is not None:
            if not isinstance(pad_sides, np.ndarray):
                pad_sides = np.array(pad_sides)
            self.need_to_pad += pad_sides
        self.num_channels = None
        self.pad_sides = pad_sides
        self.data_shape, self.seg_shape = self.determine_shapes()
        self.sampling_probabilities = sampling_probabilities
        # key under which class_locations stores the locations of all annotated voxels
        self.annotated_classes_key = tuple(label_manager.all_labels)
        self.has_ignore = label_manager.has_ignore_label
        # strategy callable: sample_idx -> bool ("must this sample contain foreground?")
        self.get_do_oversample = self._oversample_last_XX_percent if not probabilistic_oversampling \
            else self._probabilistic_oversampling

    def _oversample_last_XX_percent(self, sample_idx: int) -> bool:
        """
        determines whether sample sample_idx in a minibatch needs to be guaranteed foreground.

        Returns True for the trailing samples of the batch, i.e. for
        sample_idx >= round(batch_size * (1 - oversample_foreground_percent)).
        """
        return not sample_idx < round(self.batch_size * (1 - self.oversample_foreground_percent))

    def _probabilistic_oversampling(self, sample_idx: int) -> bool:
        """Return True with probability oversample_foreground_percent (sample_idx is ignored)."""
        return np.random.uniform() < self.oversample_foreground_percent

    def determine_shapes(self):
        """Load the first case and derive the (data, seg) batch array shapes.

        :return: (data_shape, seg_shape), each (batch_size, n_channels_of_that_array, *patch_size).
        """
        # load one case
        data, seg, properties = self._data.load_case(self.indices[0])
        num_color_channels = data.shape[0]

        data_shape = (self.batch_size, num_color_channels, *self.patch_size)
        seg_shape = (self.batch_size, seg.shape[0], *self.patch_size)
        return data_shape, seg_shape

    def get_bbox(self, data_shape: np.ndarray, force_fg: bool, class_locations: Union[dict, None],
                 overwrite_class: Union[int, Tuple[int, ...]] = None, verbose: bool = False):
        """Sample the lower/upper corner of a patch of size self.patch_size.

        :param data_shape: spatial shape of the case, one entry per patch dimension.
        :param force_fg: if True, the patch is placed around a voxel of a foreground class/region.
        :param class_locations: mapping class label (or region/annotated-key tuple) -> array of voxel
            locations. Each location's first coordinate is skipped (presumably the channel index -- confirm);
            the spatial coordinates follow. Required when force_fg is set or when the label manager has an
            ignore label.
        :param overwrite_class: if given and present in the case, this class is used instead of a random one.
        :param verbose: print a note when a forced-foreground case has no foreground.
        :return: (bbox_lbs, bbox_ubs) as lists with bbox_ubs = bbox_lbs + patch_size. The box may extend
            beyond [0, data_shape) by at most need_to_pad in total per axis; the caller presumably pads.
        """
        # in dataloader 2d we need to select the slice prior to this and also modify the class_locations to only
        # have locations for the given slice
        need_to_pad = self.need_to_pad.copy()
        dim = len(data_shape)

        for d in range(dim):
            # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both
            # sides always
            if need_to_pad[d] + data_shape[d] < self.patch_size[d]:
                need_to_pad[d] = self.patch_size[d] - data_shape[d]

        # Valid range for the lower bbox corner: we may pad floor(n/2) voxels below and ceil(n/2) voxels above,
        # i.e. exactly need_to_pad in total. The parentheses matter: -n // 2 would be (-n) // 2 == -ceil(n/2),
        # which together with the "% 2" term below allowed one voxel too much padding for odd n.
        lbs = [-(need_to_pad[i] // 2) for i in range(dim)]
        ubs = [data_shape[i] + need_to_pad[i] // 2 + need_to_pad[i] % 2 - self.patch_size[i] for i in range(dim)]

        # if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make sure we
        # get at least one of the foreground classes in the patch
        if not force_fg and not self.has_ignore:
            bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]
        else:
            if not force_fg and self.has_ignore:
                # with an ignore label, non-forced patches are centred on an annotated voxel instead
                selected_class = self.annotated_classes_key
                if len(class_locations[selected_class]) == 0:
                    # no annotated pixels in this case. Not good. But we can hardly skip it here
                    print('Warning! No annotated pixels in image!')
                    selected_class = None
            elif force_fg:
                assert class_locations is not None, 'if force_fg is set class_locations cannot be None'
                if overwrite_class is not None:
                    assert overwrite_class in class_locations.keys(), 'desired class ("overwrite_class") does not ' \
                                                                      'have class_locations (missing key)'
                # this saves us a np.unique. Preprocessing already did that for all cases. Neat.
                # class_locations keys can also be tuple
                eligible_classes_or_regions = [i for i in class_locations.keys() if len(class_locations[i]) > 0]

                # if we have annotated_classes_key locations and other classes are present, remove the
                # annotated_classes_key from the list. Strange formulation needed to circumvent
                # ValueError: The truth value of an array with more than one element is ambiguous.
                tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False
                       for i in eligible_classes_or_regions]
                if any(tmp):
                    if len(eligible_classes_or_regions) > 1:
                        eligible_classes_or_regions.pop(np.where(tmp)[0][0])

                if len(eligible_classes_or_regions) == 0:
                    # this only happens if some image does not contain foreground voxels at all
                    selected_class = None
                    if verbose:
                        print('case does not contain any foreground classes')
                elif overwrite_class is not None and overwrite_class in eligible_classes_or_regions:
                    selected_class = overwrite_class
                else:
                    selected_class = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))]
            else:
                # unreachable: the outer branch already guarantees force_fg or self.has_ignore
                raise RuntimeError('lol what!?')

            voxels_of_that_class = class_locations[selected_class] if selected_class is not None else None

            if voxels_of_that_class is not None and len(voxels_of_that_class) > 0:
                selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]
                # Selected voxel is the centre voxel: subtract half the patch size to get the lower bbox corner,
                # then clamp it into [lbs, ubs] so the patch never requires more padding than need_to_pad.
                # In both clamping cases the selected voxel still lies inside the patch.
                # i + 1 because the first coordinate of each location is not spatial (dimension 0).
                bbox_lbs = [min(ubs[i], max(lbs[i], selected_voxel[i + 1] - self.patch_size[i] // 2))
                            for i in range(dim)]
            else:
                # If the image does not contain any foreground classes, we fall back to random cropping
                bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]

        bbox_ubs = [bbox_lbs[i] + self.patch_size[i] for i in range(dim)]

        return bbox_lbs, bbox_ubs