|
|
import numpy as np |
|
|
from nnunetv2.training.dataloading.base_data_loader import nnUNetDataLoaderBase |
|
|
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset |
|
|
|
|
|
|
|
|
class nnUNetDataLoader2D(nnUNetDataLoaderBase): |
|
|
def generate_train_batch(self): |
|
|
selected_keys = self.get_indices() |
|
|
|
|
|
data_all = np.zeros(self.data_shape, dtype=np.float32) |
|
|
seg_all = np.zeros(self.seg_shape, dtype=np.int16) |
|
|
case_properties = [] |
|
|
|
|
|
for j, current_key in enumerate(selected_keys): |
|
|
|
|
|
|
|
|
force_fg = self.get_do_oversample(j) |
|
|
data, seg, properties = self._data.load_case(current_key) |
|
|
case_properties.append(properties) |
|
|
|
|
|
|
|
|
if not force_fg: |
|
|
if self.has_ignore: |
|
|
selected_class_or_region = self.annotated_classes_key |
|
|
else: |
|
|
selected_class_or_region = None |
|
|
else: |
|
|
|
|
|
eligible_classes_or_regions = [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) > 0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions] |
|
|
if any(tmp): |
|
|
if len(eligible_classes_or_regions) > 1: |
|
|
eligible_classes_or_regions.pop(np.where(tmp)[0][0]) |
|
|
|
|
|
selected_class_or_region = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \ |
|
|
len(eligible_classes_or_regions) > 0 else None |
|
|
if selected_class_or_region is not None: |
|
|
selected_slice = np.random.choice(properties['class_locations'][selected_class_or_region][:, 1]) |
|
|
else: |
|
|
selected_slice = np.random.choice(len(data[0])) |
|
|
|
|
|
data = data[:, selected_slice] |
|
|
seg = seg[:, selected_slice] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class_locations = { |
|
|
selected_class_or_region: properties['class_locations'][selected_class_or_region][properties['class_locations'][selected_class_or_region][:, 1] == selected_slice][:, (0, 2, 3)] |
|
|
} if (selected_class_or_region is not None) else None |
|
|
|
|
|
|
|
|
shape = data.shape[1:] |
|
|
dim = len(shape) |
|
|
bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg if selected_class_or_region is not None else None, |
|
|
class_locations, overwrite_class=selected_class_or_region) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)] |
|
|
valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
this_slice = tuple([slice(0, data.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) |
|
|
data = data[this_slice] |
|
|
|
|
|
this_slice = tuple([slice(0, seg.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) |
|
|
seg = seg[this_slice] |
|
|
|
|
|
padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)] |
|
|
data_all[j] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0) |
|
|
seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=-1) |
|
|
|
|
|
return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys} |
|
|
|
|
|
|
|
|
class nnUNetDataLoader2D_MRCT(nnUNetDataLoaderBase): |
|
|
def generate_train_batch(self): |
|
|
selected_keys = self.get_indices() |
|
|
|
|
|
data_all = np.zeros(self.data_shape, dtype=np.float32) |
|
|
seg_all = np.zeros(self.seg_shape, dtype=np.float32) |
|
|
case_properties = [] |
|
|
|
|
|
for j, current_key in enumerate(selected_keys): |
|
|
|
|
|
|
|
|
force_fg = self.get_do_oversample(j) |
|
|
data, seg, properties = self._data.load_case(current_key) |
|
|
case_properties.append(properties) |
|
|
|
|
|
|
|
|
if not force_fg: |
|
|
if self.has_ignore: |
|
|
selected_class_or_region = self.annotated_classes_key |
|
|
else: |
|
|
selected_class_or_region = None |
|
|
else: |
|
|
|
|
|
eligible_classes_or_regions = [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) > 0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions] |
|
|
if any(tmp): |
|
|
if len(eligible_classes_or_regions) > 1: |
|
|
eligible_classes_or_regions.pop(np.where(tmp)[0][0]) |
|
|
|
|
|
selected_class_or_region = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \ |
|
|
len(eligible_classes_or_regions) > 0 else None |
|
|
if selected_class_or_region is not None: |
|
|
selected_slice = np.random.choice(properties['class_locations'][selected_class_or_region][:, 1]) |
|
|
else: |
|
|
selected_slice = np.random.choice(len(data[0])) |
|
|
|
|
|
data = data[:, selected_slice] |
|
|
seg = seg[:, selected_slice] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class_locations = { |
|
|
selected_class_or_region: properties['class_locations'][selected_class_or_region][properties['class_locations'][selected_class_or_region][:, 1] == selected_slice][:, (0, 2, 3)] |
|
|
} if (selected_class_or_region is not None) else None |
|
|
|
|
|
|
|
|
shape = data.shape[1:] |
|
|
dim = len(shape) |
|
|
bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg if selected_class_or_region is not None else None, |
|
|
class_locations, overwrite_class=selected_class_or_region) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)] |
|
|
valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
this_slice = tuple([slice(0, data.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) |
|
|
data = data[this_slice] |
|
|
|
|
|
this_slice = tuple([slice(0, seg.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) |
|
|
seg = seg[this_slice] |
|
|
|
|
|
padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)] |
|
|
data_all[j] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0) |
|
|
seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=0) |
|
|
|
|
|
return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys} |
|
|
|