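# NOTE: the following three lines appear to be web-page residue (uploader / commit info), not code.
# FelixzeroSun's picture
# Upload folder using huggingface_hub
# 19c1f58 verified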
import numpy as np
from nnunetv2.training.dataloading.base_data_loader import nnUNetDataLoaderBase
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
class nnUNetDataLoader2D(nnUNetDataLoaderBase):
    """Data loader that contributes one 2D slice per selected case to a batch.

    The slice is taken along axis 1 of the arrays returned by
    ``self._data.load_case``. Image data is collected as float32 and padded
    with 0; the segmentation is collected as int16 and padded with -1.
    """

    def generate_train_batch(self):
        """Assemble one training batch of 2D patches.

        For every key returned by ``self.get_indices()`` the case is loaded, a
        slice index along axis 1 is drawn, and the slice is cropped to the
        bounding box obtained from ``self.get_bbox``. Parts of the box that lie
        outside of the slice are filled by constant padding.

        Returns:
            dict with the entries ``'data'`` (float32 array of shape
            ``self.data_shape``), ``'seg'`` (int16 array of shape
            ``self.seg_shape``), ``'properties'`` (list holding the properties
            of every loaded case, in batch order) and ``'keys'`` (the keys
            returned by ``self.get_indices()``).
        """
        batch_keys = self.get_indices()

        # output buffers for the whole batch, filled sample by sample
        image_batch = np.zeros(self.data_shape, dtype=np.float32)
        label_batch = np.zeros(self.seg_shape, dtype=np.int16)
        properties_per_case = []

        for sample_idx, case_key in enumerate(batch_keys):
            # foreground oversampling is decided per position in the batch
            oversample_fg = self.get_do_oversample(sample_idx)
            data, seg, properties = self._data.load_case(case_key)
            properties_per_case.append(properties)

            # Step 1: decide which class/region (if any) the patch is centered on.
            # target ends up being
            #   - None if there is no ignore label and no oversampling, or if
            #     oversampling was requested but no class has any location
            #   - self.annotated_classes_key if there is an ignore label and no
            #     oversampling
            #   - one of the classes/regions with locations when oversampling
            if oversample_fg:
                all_locations = properties['class_locations']
                # only classes/regions that actually have locations qualify
                candidates = [c for c in all_locations.keys() if len(all_locations[c]) > 0]

                # Flag the annotated_classes_key entry. The comparison is
                # restricted to tuple keys (the original code notes that a plain
                # membership test can raise numpy's "truth value of an array is
                # ambiguous" ValueError).
                is_annotated_key = [
                    c == self.annotated_classes_key if isinstance(c, tuple) else False
                    for c in candidates
                ]
                # annotated_classes_key is only kept if it is the sole candidate
                if any(is_annotated_key) and len(candidates) > 1:
                    first_hit = np.where(is_annotated_key)[0][0]
                    candidates.pop(first_hit)

                if len(candidates) > 0:
                    target = candidates[np.random.choice(len(candidates))]
                else:
                    target = None
            elif self.has_ignore:
                # NOTE(review): unlike the oversampling branch, this path does not
                # check that the entry for annotated_classes_key is non-empty -
                # confirm that this is guaranteed upstream.
                target = self.annotated_classes_key
            else:
                target = None

            # Step 2: pick the slice. With a target, the index is drawn from
            # column 1 of its location array; otherwise any index along axis 1.
            if target is None:
                slice_idx = np.random.choice(len(data[0]))
                slice_locations = None
                bbox_force_fg = None
            else:
                target_locations = properties['class_locations'][target]
                slice_idx = np.random.choice(target_locations[:, 1])
                # Built as a new dict so that properties['class_locations'] is
                # never overwritten. Only rows belonging to the chosen slice are
                # kept, and the slice column (1) is dropped.
                in_slice = target_locations[:, 1] == slice_idx
                slice_locations = {target: target_locations[in_slice][:, (0, 2, 3)]}
                bbox_force_fg = oversample_fg

            data = data[:, slice_idx]
            seg = seg[:, slice_idx]

            # Step 3: bounding box of the patch within the 2D slice
            shape = data.shape[1:]
            dim = len(shape)
            bbox_lbs, bbox_ubs = self.get_bbox(shape, bbox_force_fg, slice_locations,
                                               overwrite_class=target)

            # Step 4: crop to the part of the box that lies inside the slice
            # first (smaller array, cheaper to pad), then pad the remainder.
            crop = []
            padding = []
            for axis in range(dim):
                lower, upper, extent = bbox_lbs[axis], bbox_ubs[axis], shape[axis]
                crop.append(slice(max(0, lower), min(extent, upper)))
                padding.append((-min(0, lower), max(upper - extent, 0)))
            crop = tuple(crop)

            data = data[(slice(0, data.shape[0]),) + crop]
            seg = seg[(slice(0, seg.shape[0]),) + crop]

            # the leading (channel) axis is never padded; image gets 0, seg gets -1
            pad_width = ((0, 0), *padding)
            image_batch[sample_idx] = np.pad(data, pad_width, mode='constant', constant_values=0)
            label_batch[sample_idx] = np.pad(seg, pad_width, mode='constant', constant_values=-1)

        return {
            'data': image_batch,
            'seg': label_batch,
            'properties': properties_per_case,
            'keys': batch_keys,
        }
class nnUNetDataLoader2D_MRCT(nnUNetDataLoaderBase):  # float32 and background 0
    """2D slice data loader whose target array is float32 and padded with 0.

    The slice is taken along axis 1 of the arrays returned by
    ``self._data.load_case``. Both the image data and the ``seg`` array are
    collected as float32 and both are padded with 0.
    """

    def generate_train_batch(self):
        """Assemble one training batch of 2D patches.

        For every key returned by ``self.get_indices()`` the case is loaded, a
        slice index along axis 1 is drawn, and the slice is cropped to the
        bounding box obtained from ``self.get_bbox``. Parts of the box that lie
        outside of the slice are filled by constant padding.

        Returns:
            dict with the entries ``'data'`` (float32 array of shape
            ``self.data_shape``), ``'seg'`` (float32 array of shape
            ``self.seg_shape``), ``'properties'`` (list holding the properties
            of every loaded case, in batch order) and ``'keys'`` (the keys
            returned by ``self.get_indices()``).
        """
        batch_keys = self.get_indices()

        # output buffers for the whole batch, filled sample by sample
        image_batch = np.zeros(self.data_shape, dtype=np.float32)
        target_batch = np.zeros(self.seg_shape, dtype=np.float32)
        properties_per_case = []

        for sample_idx, case_key in enumerate(batch_keys):
            # foreground oversampling is decided per position in the batch
            oversample_fg = self.get_do_oversample(sample_idx)
            data, seg, properties = self._data.load_case(case_key)
            properties_per_case.append(properties)

            # Step 1: decide which class/region (if any) the patch is centered on.
            # target ends up being
            #   - None if there is no ignore label and no oversampling, or if
            #     oversampling was requested but no class has any location
            #   - self.annotated_classes_key if there is an ignore label and no
            #     oversampling
            #   - one of the classes/regions with locations when oversampling
            if oversample_fg:
                all_locations = properties['class_locations']
                # only classes/regions that actually have locations qualify
                candidates = [c for c in all_locations.keys() if len(all_locations[c]) > 0]

                # Flag the annotated_classes_key entry. The comparison is
                # restricted to tuple keys (the original code notes that a plain
                # membership test can raise numpy's "truth value of an array is
                # ambiguous" ValueError).
                is_annotated_key = [
                    c == self.annotated_classes_key if isinstance(c, tuple) else False
                    for c in candidates
                ]
                # annotated_classes_key is only kept if it is the sole candidate
                if any(is_annotated_key) and len(candidates) > 1:
                    first_hit = np.where(is_annotated_key)[0][0]
                    candidates.pop(first_hit)

                if len(candidates) > 0:
                    target = candidates[np.random.choice(len(candidates))]
                else:
                    target = None
            elif self.has_ignore:
                # NOTE(review): unlike the oversampling branch, this path does not
                # check that the entry for annotated_classes_key is non-empty -
                # confirm that this is guaranteed upstream.
                target = self.annotated_classes_key
            else:
                target = None

            # Step 2: pick the slice. With a target, the index is drawn from
            # column 1 of its location array; otherwise any index along axis 1.
            if target is None:
                slice_idx = np.random.choice(len(data[0]))
                slice_locations = None
                bbox_force_fg = None
            else:
                target_locations = properties['class_locations'][target]
                slice_idx = np.random.choice(target_locations[:, 1])
                # Built as a new dict so that properties['class_locations'] is
                # never overwritten. Only rows belonging to the chosen slice are
                # kept, and the slice column (1) is dropped.
                in_slice = target_locations[:, 1] == slice_idx
                slice_locations = {target: target_locations[in_slice][:, (0, 2, 3)]}
                bbox_force_fg = oversample_fg

            data = data[:, slice_idx]
            seg = seg[:, slice_idx]

            # Step 3: bounding box of the patch within the 2D slice
            shape = data.shape[1:]
            dim = len(shape)
            bbox_lbs, bbox_ubs = self.get_bbox(shape, bbox_force_fg, slice_locations,
                                               overwrite_class=target)

            # Step 4: crop to the part of the box that lies inside the slice
            # first (smaller array, cheaper to pad), then pad the remainder.
            crop = []
            padding = []
            for axis in range(dim):
                lower, upper, extent = bbox_lbs[axis], bbox_ubs[axis], shape[axis]
                crop.append(slice(max(0, lower), min(extent, upper)))
                padding.append((-min(0, lower), max(upper - extent, 0)))
            crop = tuple(crop)

            data = data[(slice(0, data.shape[0]),) + crop]
            seg = seg[(slice(0, seg.shape[0]),) + crop]

            # the leading (channel) axis is never padded; both arrays are padded with 0
            pad_width = ((0, 0), *padding)
            image_batch[sample_idx] = np.pad(data, pad_width, mode='constant', constant_values=0)
            target_batch[sample_idx] = np.pad(seg, pad_width, mode='constant', constant_values=0)

        return {
            'data': image_batch,
            'seg': target_batch,
            'properties': properties_per_case,
            'keys': batch_keys,
        }