|
|
from typing import Union, Tuple |
|
|
|
|
|
from batchgenerators.dataloading.data_loader import DataLoader |
|
|
import numpy as np |
|
|
from batchgenerators.utilities.file_and_folder_operations import * |
|
|
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset |
|
|
from nnunetv2.utilities.label_handling.label_handling import LabelManager |
|
|
|
|
|
|
|
|
class nnUNetDataLoaderBase(DataLoader):
    """Base class for data loaders that crop fixed-size patches out of nnUNetDataset cases.

    This class stores the patch geometry, decides per sample whether a patch must be
    centred on labelled voxels (see ``get_do_oversample``) and computes the bounding box
    of the patch to extract (see ``get_bbox``). It does not assemble batches itself.
    """

    def __init__(self,
                 data: nnUNetDataset,
                 batch_size: int,
                 patch_size: Union[List[int], Tuple[int, ...], np.ndarray],
                 final_patch_size: Union[List[int], Tuple[int, ...], np.ndarray],
                 label_manager: LabelManager,
                 oversample_foreground_percent: float = 0.0,
                 sampling_probabilities: Union[List[float], Tuple[float, ...], np.ndarray, None] = None,
                 pad_sides: Union[List[int], Tuple[int, ...], np.ndarray, None] = None,
                 probabilistic_oversampling: bool = False):
        """
        :param data: dataset to draw cases from; must be an nnUNetDataset.
        :param batch_size: number of samples per batch.
        :param patch_size: spatial size of the patches whose bounding boxes are sampled.
        :param final_patch_size: ``patch_size - final_patch_size`` is the total amount (per axis)
            by which sampled boxes are allowed to reach outside the image.
        :param label_manager: provides ``all_labels`` and ``has_ignore_label``.
        :param oversample_foreground_percent: fraction of samples that are forced to contain
            foreground (interpretation depends on ``probabilistic_oversampling``).
        :param sampling_probabilities: forwarded to the parent DataLoader and stored on the instance.
        :param pad_sides: optional extra per-axis allowance that is added on top of
            ``patch_size - final_patch_size``.
        :param probabilistic_oversampling: if False, the last share of each batch is forced
            foreground (deterministic by sample index); if True, each sample is forced foreground
            with probability ``oversample_foreground_percent``.
        """
        # Validate before handing the object to the parent class.
        assert isinstance(data, nnUNetDataset), 'nnUNetDataLoaderBase only supports nnUNetDataset instances as data'

        # NOTE(review): the positional arguments after batch_size follow the parent DataLoader's
        # signature (presumably num_threads=1, seed=None, return_incomplete=True, shuffle=False,
        # infinite=True) - confirm against the installed batchgenerators version.
        super().__init__(data, batch_size, 1, None, True, False, True, sampling_probabilities)
        self.indices = list(data.keys())

        self.oversample_foreground_percent = oversample_foreground_percent
        self.final_patch_size = final_patch_size
        self.patch_size = patch_size
        self.list_of_keys = list(self._data.keys())

        # Total number of voxels (per axis) a sampled box may extend beyond the image.
        self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int)
        if pad_sides is not None:
            if not isinstance(pad_sides, np.ndarray):
                pad_sides = np.array(pad_sides)
            self.need_to_pad += pad_sides
        self.num_channels = None
        self.pad_sides = pad_sides

        # Loads the first case once to find out how many channels data and seg have.
        self.data_shape, self.seg_shape = self.determine_shapes()
        self.sampling_probabilities = sampling_probabilities

        # Key under which class_locations holds the locations of all annotated voxels.
        self.annotated_classes_key = tuple(label_manager.all_labels)
        self.has_ignore = label_manager.has_ignore_label

        # Callable(sample_idx) -> bool deciding whether a sample must contain foreground.
        if probabilistic_oversampling:
            self.get_do_oversample = self._probabilistic_oversampling
        else:
            self.get_do_oversample = self._oversample_last_XX_percent

    def _oversample_last_XX_percent(self, sample_idx: int) -> bool:
        """
        determines whether sample sample_idx in a minibatch needs to be guaranteed foreground

        Samples whose index is at or beyond ``round(batch_size * (1 - oversample_foreground_percent))``
        are forced foreground, i.e. the trailing part of every batch.
        """
        first_forced_idx = round(self.batch_size * (1 - self.oversample_foreground_percent))
        return not sample_idx < first_forced_idx

    def _probabilistic_oversampling(self, sample_idx: int) -> bool:
        """
        Decide at random whether a sample needs to be guaranteed foreground.

        ``sample_idx`` is accepted for interface parity with ``_oversample_last_XX_percent`` but
        is not used; the result is True with probability ``oversample_foreground_percent``.
        """
        return np.random.uniform() < self.oversample_foreground_percent

    def determine_shapes(self) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
        """
        Derive the batch shapes of data and segmentation from the first case of the dataset.

        :return: ``(data_shape, seg_shape)``, each ``(batch_size, channels, *patch_size)`` where
            ``channels`` is the size of the first axis of the loaded data / seg array.
        """
        # The third returned item (case properties) is not needed here.
        data, seg, _properties = self._data.load_case(self.indices[0])
        num_color_channels = data.shape[0]

        data_shape = (self.batch_size, num_color_channels, *self.patch_size)
        seg_shape = (self.batch_size, seg.shape[0], *self.patch_size)
        return data_shape, seg_shape

    def get_bbox(self, data_shape: np.ndarray, force_fg: bool, class_locations: Union[dict, None],
                 overwrite_class: Union[int, Tuple[int, ...], None] = None, verbose: bool = False
                 ) -> Tuple[List[int], List[int]]:
        """
        Sample the bounding box of one patch of size ``self.patch_size``.

        Three sampling modes exist:

        * ``force_fg`` is True: a class (or region) with at least one stored location is picked
          (``overwrite_class`` if it is eligible, otherwise uniformly at random) and the box is
          centred on a random location of that class. The "all annotated voxels" entry
          (``self.annotated_classes_key``) is only used if nothing else is eligible.
        * ``force_fg`` is False but an ignore label is in use: the box is centred on a random
          annotated voxel, so that patches do not consist of ignore label only.
        * otherwise: the box position is drawn uniformly at random.

        Whenever no usable location is available the uniform random fallback is used.

        :param data_shape: spatial shape of the case (one entry per spatial axis).
        :param force_fg: whether the patch must be centred on a labelled location.
        :param class_locations: maps class / region keys to sequences of voxel locations. Must not
            be None when ``force_fg`` is True. When an ignore label is in use it must contain
            ``self.annotated_classes_key``.
        :param overwrite_class: if given (and ``force_fg`` is True) this class is preferred; it
            must be a key of ``class_locations``.
        :param verbose: print a message if ``force_fg`` was requested but no class has locations.
        :return: ``(bbox_lbs, bbox_ubs)``, lower (inclusive) and upper (exclusive) corner per
            spatial axis. Lower bounds can be negative and upper bounds can exceed ``data_shape``,
            so callers need to handle boxes that reach outside the image.
        """
        # Work on a copy: the allowance may have to be enlarged for this particular case.
        need_to_pad = self.need_to_pad.copy()
        dim = len(data_shape)

        for d in range(dim):
            # If the case plus its allowance is still smaller than the patch, raise the allowance
            # so that exactly one patch fits.
            if need_to_pad[d] + data_shape[d] < self.patch_size[d]:
                need_to_pad[d] = self.patch_size[d] - data_shape[d]

        # Admissible range for the lower corner of the box. Note the floor division is applied
        # to the negated allowance, which rounds odd allowances away from zero.
        lbs = [(-need_to_pad[i]) // 2 for i in range(dim)]
        ubs = [data_shape[i] + need_to_pad[i] // 2 + need_to_pad[i] % 2 - self.patch_size[i] for i in range(dim)]

        # None means "no class to centre on" -> uniform random box.
        selected_class = None

        if force_fg:
            assert class_locations is not None, 'if force_fg is set class_locations cannot be None'
            if overwrite_class is not None:
                assert overwrite_class in class_locations.keys(), 'desired class ("overwrite_class") does not ' \
                                                                   'have class_locations (missing key)'

            # Only classes / regions that actually have stored locations can be sampled.
            eligible_classes_or_regions = [i for i in class_locations.keys() if len(class_locations[i]) > 0]

            # The "all annotated voxels" entry is not a real foreground class. Drop it unless it
            # is the only thing left to sample from.
            if len(eligible_classes_or_regions) > 1:
                for idx, candidate in enumerate(eligible_classes_or_regions):
                    if isinstance(candidate, tuple) and candidate == self.annotated_classes_key:
                        eligible_classes_or_regions.pop(idx)
                        break

            if len(eligible_classes_or_regions) == 0:
                if verbose:
                    print('case does not contain any foreground classes')
            elif overwrite_class is None or overwrite_class not in eligible_classes_or_regions:
                selected_class = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))]
            else:
                selected_class = overwrite_class
        elif self.has_ignore:
            # No foreground requested, but with an ignore label we still want annotated voxels
            # in the patch.
            selected_class = self.annotated_classes_key
            if len(class_locations[selected_class]) == 0:
                print('Warning! No annotated pixels in image!')
                selected_class = None

        voxels_of_that_class = class_locations[selected_class] if selected_class is not None else None

        if voxels_of_that_class is not None and len(voxels_of_that_class) > 0:
            selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]
            # The selected voxel becomes the patch centre. Entry 0 of a stored location is skipped
            # (presumably a channel index - confirm against the code that writes class_locations).
            # NOTE(review): only the lower limit is enforced here; the result is not clamped
            # against ubs.
            bbox_lbs = [max(lbs[i], selected_voxel[i + 1] - self.patch_size[i] // 2) for i in range(dim)]
        else:
            # np.random.randint excludes the upper limit, hence the + 1.
            bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]

        bbox_ubs = [bbox_lbs[i] + self.patch_size[i] for i in range(dim)]

        return bbox_lbs, bbox_ubs
|
|
|