File size: 8,320 Bytes
19c1f58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from typing import Union, Tuple
from batchgenerators.dataloading.data_loader import DataLoader
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
from nnunetv2.utilities.label_handling.label_handling import LabelManager
class nnUNetDataLoaderBase(DataLoader):
    """
    Common base for nnU-Net data loaders on top of batchgenerators' ``DataLoader``.

    Holds the patch geometry, computes how much padding is needed so that patches of ``final_patch_size`` can also
    cover the image borders, derives the batch array shapes from one loaded case and implements ``get_bbox`` for
    random or foreground-forced patch sampling. Batch assembly itself is not implemented here (presumably done by
    subclasses -- TODO confirm).
    """

    def __init__(self,
                 data: nnUNetDataset,
                 batch_size: int,
                 patch_size: Union[List[int], Tuple[int, ...], np.ndarray],
                 final_patch_size: Union[List[int], Tuple[int, ...], np.ndarray],
                 label_manager: LabelManager,
                 oversample_foreground_percent: float = 0.0,
                 sampling_probabilities: Union[List[int], Tuple[int, ...], np.ndarray] = None,
                 pad_sides: Union[List[int], Tuple[int, ...], np.ndarray] = None,
                 probabilistic_oversampling: bool = False):
        """
        :param data: dataset providing the cases (must be an nnUNetDataset)
        :param batch_size: number of samples per batch
        :param patch_size: spatial size of the patches cropped from each case; also used for the batch shapes
        :param final_patch_size: spatial size of the patches the network will eventually get. The difference
            patch_size - final_patch_size becomes the base padding (need_to_pad)
        :param label_manager: provides all_labels (whose tuple is the key of the annotated region in
            class_locations) and has_ignore_label
        :param oversample_foreground_percent: fraction of samples that should be forced to contain foreground
        :param sampling_probabilities: forwarded to the DataLoader base class and also stored on the instance
        :param pad_sides: optional extra per-axis padding added on top of need_to_pad
        :param probabilistic_oversampling: if True, each sample is forced to foreground with probability
            oversample_foreground_percent; otherwise the trailing samples of every batch are
        """
        # The positional arguments presumably map to num_threads_in_multithreaded=1, seed_for_shuffle=None,
        # return_incomplete=True, shuffle=False, infinite=True.
        # NOTE(review): verify against the installed batchgenerators DataLoader signature before switching to
        # keyword arguments.
        super().__init__(data, batch_size, 1, None, True, False, True, sampling_probabilities)
        assert isinstance(data, nnUNetDataset), 'nnUNetDataLoaderBase only supports nnUNetDataset as data'
        self.indices = list(data.keys())
        self.oversample_foreground_percent = oversample_foreground_percent
        self.final_patch_size = final_patch_size
        self.patch_size = patch_size
        self.list_of_keys = list(self._data.keys())
        # need_to_pad denotes by how much we need to pad the data so that if we sample a patch of size
        # final_patch_size (which is what the network will get) these patches will also cover the border of the
        # images
        self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int)
        if pad_sides is not None:
            if not isinstance(pad_sides, np.ndarray):
                pad_sides = np.array(pad_sides)
            self.need_to_pad += pad_sides
        self.num_channels = None
        self.pad_sides = pad_sides
        # needs batch_size (presumably set by the base class), indices and patch_size, so it must stay down here
        self.data_shape, self.seg_shape = self.determine_shapes()
        self.sampling_probabilities = sampling_probabilities
        self.annotated_classes_key = tuple(label_manager.all_labels)
        self.has_ignore = label_manager.has_ignore_label
        if probabilistic_oversampling:
            self.get_do_oversample = self._probabilistic_oversampling
        else:
            self.get_do_oversample = self._oversample_last_XX_percent

    def _oversample_last_XX_percent(self, sample_idx: int) -> bool:
        """
        Determine whether sample ``sample_idx`` of a minibatch must be guaranteed to contain foreground.

        This holds for every sample whose index is >= round(batch_size * (1 - oversample_foreground_percent)),
        i.e. for the trailing part of each batch.
        """
        return sample_idx >= round(self.batch_size * (1 - self.oversample_foreground_percent))

    def _probabilistic_oversampling(self, sample_idx: int) -> bool:
        """
        Force foreground with probability ``oversample_foreground_percent``.

        ``sample_idx`` is ignored; it is only accepted so that both oversampling strategies share one signature.
        """
        return np.random.uniform() < self.oversample_foreground_percent

    def determine_shapes(self):
        """
        Derive the batch array shapes from the first case of the dataset.

        :return: (data_shape, seg_shape), each (batch_size, channels, *patch_size). The channel counts are taken
            from the first axis of the loaded data and segmentation arrays.
        """
        data, seg, _properties = self._data.load_case(self.indices[0])
        num_color_channels = data.shape[0]
        data_shape = (self.batch_size, num_color_channels, *self.patch_size)
        seg_shape = (self.batch_size, seg.shape[0], *self.patch_size)
        return data_shape, seg_shape

    def get_bbox(self, data_shape: np.ndarray, force_fg: bool, class_locations: Union[dict, None],
                 overwrite_class: Union[int, Tuple[int, ...]] = None, verbose: bool = False):
        """
        Sample the bounding box of a patch of size ``self.patch_size`` for a case with spatial shape ``data_shape``.

        Per the original authors, 2D loaders have to select the slice before calling this and restrict
        class_locations to that slice.

        :param data_shape: spatial shape of the case, one entry per axis of patch_size
        :param force_fg: if True, center the patch on a voxel of a randomly chosen foreground class/region
        :param class_locations: mapping class/region key -> voxel locations. Required when force_fg is set or
            when an ignore label is used. Keys can be ints or tuples.
        :param overwrite_class: with force_fg, use this class instead of a random one if it has locations
        :param verbose: print a message when force_fg is set but the case contains no foreground
        :return: (bbox_lbs, bbox_ubs), per-axis lower corner and lower corner + patch_size. They may lie outside
            [0, data_shape), so the caller presumably has to pad.
        """
        need_to_pad = self.need_to_pad.copy()
        dim = len(data_shape)

        # If data_shape + need_to_pad is still smaller than the patch size we need to pad more. Padding is always
        # split across both sides.
        for d in range(dim):
            if need_to_pad[d] + data_shape[d] < self.patch_size[d]:
                need_to_pad[d] = self.patch_size[d] - data_shape[d]

        # The lower patch corner may be drawn from [lbs, ubs] (both inclusive):
        # lbs = -ceil(need_to_pad / 2), ubs = data_shape + ceil(need_to_pad / 2) - patch_size
        lbs = [-need_to_pad[i] // 2 for i in range(dim)]
        ubs = [data_shape[i] + need_to_pad[i] // 2 + need_to_pad[i] % 2 - self.patch_size[i] for i in range(dim)]

        if not force_fg and not self.has_ignore:
            # plain random crop
            bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]
        else:
            if force_fg:
                assert class_locations is not None, 'if force_fg is set class_locations cannot be None'
                if overwrite_class is not None:
                    assert overwrite_class in class_locations.keys(), 'desired class ("overwrite_class") does not ' \
                                                                      'have class_locations (missing key)'
                # Locations were precomputed per class/region (saves an np.unique here); only classes that actually
                # occur in this case are eligible.
                eligible_classes_or_regions = [i for i in class_locations.keys() if len(class_locations[i]) > 0]
                # If the annotated-region key is present alongside real classes, drop it so a real class is picked.
                # The isinstance check avoids comparing non-tuple keys with a tuple, which can raise
                # "ValueError: The truth value of an array with more than one element is ambiguous".
                is_annotated_key = [isinstance(i, tuple) and i == self.annotated_classes_key
                                    for i in eligible_classes_or_regions]
                if any(is_annotated_key) and len(eligible_classes_or_regions) > 1:
                    eligible_classes_or_regions.pop(is_annotated_key.index(True))

                if len(eligible_classes_or_regions) == 0:
                    # this only happens if some image does not contain foreground voxels at all
                    selected_class = None
                    if verbose:
                        print('case does not contain any foreground classes')
                elif overwrite_class is not None and overwrite_class in eligible_classes_or_regions:
                    selected_class = overwrite_class
                else:
                    selected_class = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))]
            else:
                # Only reached when an ignore label is used: sample around a voxel of the annotated region (stored
                # under annotated_classes_key) so the patch is not entirely unlabeled.
                assert class_locations is not None, 'if an ignore label is used class_locations cannot be None'
                selected_class = self.annotated_classes_key
                if len(class_locations[selected_class]) == 0:
                    # no annotated pixels in this case. Not good. But we can hardly skip it here
                    print('Warning! No annotated pixels in image!')
                    selected_class = None

            voxels_of_that_class = class_locations[selected_class] if selected_class is not None else None
            if voxels_of_that_class is not None and len(voxels_of_that_class) > 0:
                selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]
                # Center the patch on the selected voxel. Only the lower bound is clamped (to lbs); the patch may
                # extend beyond ubs. Entry 0 of each location is skipped (i + 1); the original authors note that
                # this first dimension is always 0 (presumably a channel index -- TODO confirm).
                bbox_lbs = [max(lbs[i], selected_voxel[i + 1] - self.patch_size[i] // 2) for i in range(dim)]
            else:
                # no usable voxel (no foreground / no annotated pixels): fall back to random cropping
                bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]

        bbox_ubs = [bbox_lbs[i] + self.patch_size[i] for i in range(dim)]
        return bbox_lbs, bbox_ubs
|