File size: 8,943 Bytes
19c1f58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import numpy as np
from nnunetv2.training.dataloading.base_data_loader import nnUNetDataLoaderBase
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
import time
class nnUNetDataLoader3D(nnUNetDataLoaderBase):
    def generate_train_batch(self):
        """Assemble one training batch of 3D patches.

        Returns a dict with:
            'data':       float32 array of shape self.data_shape
            'seg':        int16 array of shape self.seg_shape (padding filled with -1)
            'properties': list of per-case property dicts
            'keys':       the case identifiers sampled for this batch
        """
        selected_keys = self.get_indices()
        # Preallocate the batch arrays once; each case is written into its slot.
        data_all = np.zeros(self.data_shape, dtype=np.float32)
        seg_all = np.zeros(self.seg_shape, dtype=np.int16)
        case_properties = []

        for batch_idx, key in enumerate(selected_keys):
            # Oversampling foreground stabilizes training when many patches are
            # empty (e.g. Lung).
            oversample_fg = self.get_do_oversample(batch_idx)
            data, seg, properties = self._data.load_case(key)
            case_properties.append(properties)

            # In the cascade setting, the previous-stage segmentation has already
            # been appended by self._data.load_case (see nnUNetDataset.load_case).
            spatial_shape = data.shape[1:]
            bbox_lbs, bbox_ubs = self.get_bbox(spatial_shape, oversample_fg,
                                               properties['class_locations'])

            # Crop first to the part of the bbox that actually lies inside the
            # volume — the smaller array is cheaper to pad afterwards. The crop
            # is then padded back up to the full patch size below.
            crop_lbs = [max(0, lb) for lb in bbox_lbs]
            crop_ubs = [min(extent, ub) for extent, ub in zip(spatial_shape, bbox_ubs)]
            spatial_slices = [slice(lb, ub) for lb, ub in zip(crop_lbs, crop_ubs)]

            # seg is treated separately from seg_from_previous_stage because seg
            # must be padded with -1 while seg_from_previous_stage needs 0s
            # (removing label -1 in augmentation instead would be more error prone).
            data = data[tuple([slice(0, data.shape[0])] + spatial_slices)]
            seg = seg[tuple([slice(0, seg.shape[0])] + spatial_slices)]

            # Per-axis (before, after) padding for the parts of the bbox that
            # fell outside the volume; the channel axis gets no padding.
            padding = [(-min(0, lb), max(ub - extent, 0))
                       for lb, ub, extent in zip(bbox_lbs, bbox_ubs, spatial_shape)]
            data_all[batch_idx] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0)
            seg_all[batch_idx] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=-1)

        return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys}
class nnUNetDataLoader3D_MRCT(nnUNetDataLoaderBase):
    def generate_train_batch(self):
        """Assemble one training batch of 3D patches for the MR/CT variant.

        Differs from nnUNetDataLoader3D in that 'seg' is float32 and its
        out-of-volume padding is filled with 0 instead of -1.

        Returns a dict with:
            'data':       float32 array of shape self.data_shape
            'seg':        float32 array of shape self.seg_shape (padding filled with 0)
            'properties': list of per-case property dicts
            'keys':       the case identifiers sampled for this batch
        """
        selected_keys = self.get_indices()
        # Preallocate the batch arrays once; each case is written into its slot.
        data_all = np.zeros(self.data_shape, dtype=np.float32)
        seg_all = np.zeros(self.seg_shape, dtype=np.float32)
        case_properties = []

        for batch_idx, key in enumerate(selected_keys):
            # Oversampling foreground stabilizes training when many patches are
            # empty (e.g. Lung).
            oversample_fg = self.get_do_oversample(batch_idx)
            data, seg, properties = self._data.load_case(key)
            case_properties.append(properties)

            # In the cascade setting, the previous-stage segmentation has already
            # been appended by self._data.load_case (see nnUNetDataset.load_case).
            spatial_shape = data.shape[1:]
            bbox_lbs, bbox_ubs = self.get_bbox(spatial_shape, oversample_fg,
                                               properties['class_locations'])

            # Crop first to the part of the bbox that actually lies inside the
            # volume — the smaller array is cheaper to pad afterwards. The crop
            # is then padded back up to the full patch size below.
            crop_lbs = [max(0, lb) for lb in bbox_lbs]
            crop_ubs = [min(extent, ub) for extent, ub in zip(spatial_shape, bbox_ubs)]
            spatial_slices = [slice(lb, ub) for lb, ub in zip(crop_lbs, crop_ubs)]

            data = data[tuple([slice(0, data.shape[0])] + spatial_slices)]
            seg = seg[tuple([slice(0, seg.shape[0])] + spatial_slices)]

            # Per-axis (before, after) padding for the parts of the bbox that
            # fell outside the volume; the channel axis gets no padding.
            padding = [(-min(0, lb), max(ub - extent, 0))
                       for lb, ub, extent in zip(bbox_lbs, bbox_ubs, spatial_shape)]
            data_all[batch_idx] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0)
            seg_all[batch_idx] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=0)

        return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys}
class nnUNetDataLoader3D_MRCT_mask(nnUNetDataLoaderBase):
    def determine_shapes(self):
        """Determine the batch array shapes from one sample case.

        Returns (data_shape, seg_shape) — the two-value return is kept so the
        base class's ``self.data_shape, self.seg_shape = self.determine_shapes()``
        unpacking (assumed from the sibling loaders — TODO confirm against
        nnUNetDataLoaderBase) continues to work. The mask shape is stored on
        ``self.mask_shape`` as a side effect.
        """
        # load one case to probe the channel counts of data / seg / mask
        data, seg, properties, mask = self._data.load_case(self.indices[0])
        num_color_channels = data.shape[0]
        data_shape = (self.batch_size, num_color_channels, *self.patch_size)
        seg_shape = (self.batch_size, seg.shape[0], *self.patch_size)
        # BUGFIX: mask_shape was previously computed and then discarded, forcing
        # generate_train_batch to (incorrectly) reuse data_shape for the mask.
        # Remember it here instead; mask.shape[0] may differ from data.shape[0].
        self.mask_shape = (self.batch_size, mask.shape[0], *self.patch_size)
        return data_shape, seg_shape

    def generate_train_batch(self):
        """Assemble one training batch of 3D patches plus the per-case mask.

        Returns a dict with:
            'data':       float32 array of shape self.data_shape
            'seg':        float32 array of shape self.seg_shape (padding filled with 0)
            'mask':       float32 array of shape self.mask_shape (padding filled with 0)
            'properties': list of per-case property dicts
            'keys':       the case identifiers sampled for this batch
        """
        selected_keys = self.get_indices()
        # preallocate memory for data, seg and mask
        data_all = np.zeros(self.data_shape, dtype=np.float32)
        seg_all = np.zeros(self.seg_shape, dtype=np.float32)
        # BUGFIX: was np.zeros(self.data_shape, ...) (marked "bx: todo"); the
        # mask gets its own shape, set by determine_shapes above.
        mask_all = np.zeros(self.mask_shape, dtype=np.float32)
        case_properties = []

        for j, key in enumerate(selected_keys):
            # Oversampling foreground stabilizes training when many patches are
            # empty (e.g. Lung).
            force_fg = self.get_do_oversample(j)
            data, seg, properties, mask = self._data.load_case(key)
            case_properties.append(properties)

            # In the cascade setting, the previous-stage segmentation has already
            # been appended by self._data.load_case (see nnUNetDataset.load_case).
            shape = data.shape[1:]
            dim = len(shape)
            bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg, properties['class_locations'])

            # Crop first to the part of the bbox that actually lies inside the
            # volume — the smaller array is cheaper to pad afterwards. The crop
            # is then padded back up to the full patch size below.
            valid_bbox_lbs = [max(0, bbox_lbs[d]) for d in range(dim)]
            valid_bbox_ubs = [min(shape[d], bbox_ubs[d]) for d in range(dim)]
            spatial_slices = [slice(lb, ub) for lb, ub in zip(valid_bbox_lbs, valid_bbox_ubs)]

            data = data[tuple([slice(0, data.shape[0])] + spatial_slices)]
            seg = seg[tuple([slice(0, seg.shape[0])] + spatial_slices)]
            # BUGFIX: mask was sliced with seg's channel extent; use its own so
            # a channel-count mismatch between seg and mask cannot truncate it.
            mask = mask[tuple([slice(0, mask.shape[0])] + spatial_slices)]

            # Per-axis (before, after) padding for the parts of the bbox that
            # fell outside the volume; the channel axis gets no padding.
            padding = [(-min(0, bbox_lbs[d]), max(bbox_ubs[d] - shape[d], 0)) for d in range(dim)]
            data_all[j] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0)
            seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=0)
            mask_all[j] = np.pad(mask, ((0, 0), *padding), 'constant', constant_values=0)

        return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys, 'mask': mask_all}
|