# Dataloaders producing 2D training patches (slice-based sampling from 3D cases).
import numpy as np
from nnunetv2.training.dataloading.base_data_loader import nnUNetDataLoaderBase
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
class nnUNetDataLoader2D(nnUNetDataLoaderBase):
    """2D dataloader: for each sampled case, optionally picks a foreground
    class/region, selects a slice where it occurs, crops that slice to the
    patch bbox and zero-/(-1)-pads to the configured patch size.

    Segmentation is returned as int16, padded with -1 outside the image
    (the -1 marks out-of-image voxels for downstream handling).
    """

    def generate_train_batch(self):
        """Assemble one training batch.

        Returns:
            dict with keys:
                'data': float32 array of shape ``self.data_shape``
                'seg': int16 array of shape ``self.seg_shape``
                'properties': list of per-case properties dicts
                'keys': the selected case identifiers
        """
        selected_keys = self.get_indices()
        # preallocate memory for data and seg
        data_all = np.zeros(self.data_shape, dtype=np.float32)
        seg_all = np.zeros(self.seg_shape, dtype=np.int16)
        case_properties = []

        for j, current_key in enumerate(selected_keys):
            # oversampling foreground will improve stability of model training, especially if many patches are empty
            # (Lung for example)
            force_fg = self.get_do_oversample(j)

            # data/seg are assumed channel-first 3D volumes (c, z, y, x) — TODO confirm against load_case
            data, seg, properties = self._data.load_case(current_key)
            case_properties.append(properties)

            # select a class/region first, then a slice where this class is present, then crop to that area
            if not force_fg:
                if self.has_ignore:
                    selected_class_or_region = self.annotated_classes_key
                else:
                    selected_class_or_region = None
            else:
                # filter out all classes that are not present here
                eligible_classes_or_regions = [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) > 0]

                # if we have annotated_classes_key locations and other classes are present, remove the annotated_classes_key from the list
                # strange formulation needed to circumvent
                # ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
                tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions]
                if any(tmp):
                    if len(eligible_classes_or_regions) > 1:
                        eligible_classes_or_regions.pop(np.where(tmp)[0][0])

                selected_class_or_region = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \
                    len(eligible_classes_or_regions) > 0 else None

            if selected_class_or_region is not None:
                # class_locations rows are voxel coordinates; column 1 is the index along the
                # first spatial axis, so this picks a slice that contains the chosen class
                selected_slice = np.random.choice(properties['class_locations'][selected_class_or_region][:, 1])
            else:
                # no class constraint: uniform random slice along the first spatial axis
                selected_slice = np.random.choice(len(data[0]))

            # reduce to 2D: keep all channels, take the chosen slice
            data = data[:, selected_slice]
            seg = seg[:, selected_slice]

            # the line of death lol
            # this needs to be a separate variable because we could otherwise permanently overwrite
            # properties['class_locations']
            # selected_class_or_region is:
            # - None if we do not have an ignore label and force_fg is False OR if force_fg is True but there is no foreground in the image
            # - A tuple of all (non-ignore) labels if there is an ignore label and force_fg is False
            # - a class or region if force_fg is True
            # Keep only the coordinates lying in the selected slice, and drop the slice-axis
            # column (keep channel + remaining spatial coords) so they are valid 2D locations.
            class_locations = {
                selected_class_or_region: properties['class_locations'][selected_class_or_region][properties['class_locations'][selected_class_or_region][:, 1] == selected_slice][:, (0, 2, 3)]
            } if (selected_class_or_region is not None) else None

            # print(properties)
            shape = data.shape[1:]
            dim = len(shape)
            bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg if selected_class_or_region is not None else None,
                                               class_locations, overwrite_class=selected_class_or_region)

            # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the
            # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad.
            # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size
            # later
            valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)]
            valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)]

            # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage.
            # Why not just concatenate them here and forget about the if statements? Well that's because segneeds to
            # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also
            # remove label -1 in the data augmentation but this way it is less error prone)
            this_slice = tuple([slice(0, data.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)])
            data = data[this_slice]

            this_slice = tuple([slice(0, seg.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)])
            seg = seg[this_slice]

            # pad symmetrically by however much of the bbox fell outside the image;
            # channel axis (first) is never padded
            padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)]
            data_all[j] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0)
            seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=-1)

        return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys}
class nnUNetDataLoader2D_MRCT(nnUNetDataLoaderBase):  # float32 and background 0
    """Variant of the 2D dataloader for MR/CT-style targets.

    NOTE(review): this duplicates nnUNetDataLoader2D except for two details:
    'seg' is float32 instead of int16, and out-of-image regions of 'seg' are
    padded with 0 instead of -1 (suitable when 'seg' holds an intensity
    target rather than integer labels — TODO confirm with callers).
    """

    def generate_train_batch(self):
        """Assemble one training batch.

        Returns:
            dict with keys:
                'data': float32 array of shape ``self.data_shape``
                'seg': float32 array of shape ``self.seg_shape``
                'properties': list of per-case properties dicts
                'keys': the selected case identifiers
        """
        selected_keys = self.get_indices()
        # preallocate memory for data and seg
        data_all = np.zeros(self.data_shape, dtype=np.float32)
        # float32 here (not int16): 'seg' is treated as a continuous-valued target
        seg_all = np.zeros(self.seg_shape, dtype=np.float32)
        case_properties = []

        for j, current_key in enumerate(selected_keys):
            # oversampling foreground will improve stability of model training, especially if many patches are empty
            # (Lung for example)
            force_fg = self.get_do_oversample(j)

            # data/seg are assumed channel-first 3D volumes (c, z, y, x) — TODO confirm against load_case
            data, seg, properties = self._data.load_case(current_key)
            case_properties.append(properties)

            # select a class/region first, then a slice where this class is present, then crop to that area
            if not force_fg:
                if self.has_ignore:
                    selected_class_or_region = self.annotated_classes_key
                else:
                    selected_class_or_region = None
            else:
                # filter out all classes that are not present here
                eligible_classes_or_regions = [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) > 0]

                # if we have annotated_classes_key locations and other classes are present, remove the annotated_classes_key from the list
                # strange formulation needed to circumvent
                # ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
                tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions]
                if any(tmp):
                    if len(eligible_classes_or_regions) > 1:
                        eligible_classes_or_regions.pop(np.where(tmp)[0][0])

                selected_class_or_region = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \
                    len(eligible_classes_or_regions) > 0 else None

            if selected_class_or_region is not None:
                # class_locations rows are voxel coordinates; column 1 is the index along the
                # first spatial axis, so this picks a slice that contains the chosen class
                selected_slice = np.random.choice(properties['class_locations'][selected_class_or_region][:, 1])
            else:
                # no class constraint: uniform random slice along the first spatial axis
                selected_slice = np.random.choice(len(data[0]))

            # reduce to 2D: keep all channels, take the chosen slice
            data = data[:, selected_slice]
            seg = seg[:, selected_slice]

            # the line of death lol
            # this needs to be a separate variable because we could otherwise permanently overwrite
            # properties['class_locations']
            # selected_class_or_region is:
            # - None if we do not have an ignore label and force_fg is False OR if force_fg is True but there is no foreground in the image
            # - A tuple of all (non-ignore) labels if there is an ignore label and force_fg is False
            # - a class or region if force_fg is True
            # Keep only the coordinates lying in the selected slice, and drop the slice-axis
            # column (keep channel + remaining spatial coords) so they are valid 2D locations.
            class_locations = {
                selected_class_or_region: properties['class_locations'][selected_class_or_region][properties['class_locations'][selected_class_or_region][:, 1] == selected_slice][:, (0, 2, 3)]
            } if (selected_class_or_region is not None) else None

            # print(properties)
            shape = data.shape[1:]
            dim = len(shape)
            bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg if selected_class_or_region is not None else None,
                                               class_locations, overwrite_class=selected_class_or_region)

            # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the
            # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad.
            # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size
            # later
            valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)]
            valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)]

            # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage.
            # Why not just concatenate them here and forget about the if statements? Well that's because segneeds to
            # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also
            # remove label -1 in the data augmentation but this way it is less error prone)
            this_slice = tuple([slice(0, data.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)])
            data = data[this_slice]

            this_slice = tuple([slice(0, seg.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)])
            seg = seg[this_slice]

            # pad symmetrically by however much of the bbox fell outside the image;
            # channel axis (first) is never padded. Unlike nnUNetDataLoader2D, seg is
            # padded with 0 here (background value for a continuous target).
            padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)]
            data_all[j] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0)
            seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=0)

        return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys}
|