| import warnings |
| import time |
| import numpy as np |
| from tensorflow import keras |
| import os |
| import h5py |
| import random |
| from PIL import Image |
| import nibabel as nib |
| from nilearn.image import resample_img |
| from skimage.exposure import equalize_adapthist |
| from scipy.ndimage import zoom |
| import tensorflow as tf |
|
|
| import ddmr.utils.constants as C |
| from ddmr.utils.operators import min_max_norm |
| from ddmr.utils.misc import segmentation_cardinal_to_ohe |
| from ddmr.utils.thin_plate_splines import ThinPlateSplines |
| from voxelmorph.tf.layers import SpatialTransformer |
| from Brain_study.format_dataset import SEGMENTATION_NR2LBL_LUT, SEGMENTATION_LBL2NR_LUT |
|
|
| from tensorflow.python.keras.preprocessing.image import Iterator |
| from tensorflow.python.keras.utils import Sequence |
| import sys |
|
|
| from collections import defaultdict |
|
|
| from Brain_study.format_dataset import SEGMENTATION_LOC |
|
|
| |
| |
|
|
| class BatchGenerator: |
| def __init__(self, |
| directory, |
| batch_size, |
| shuffle=True, |
| split=0.7, |
| combine_segmentations=True, |
| labels=['all'], |
| directory_val=None, |
| return_isotropic_shape=False): |
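        """
        Build training/validation batch iterators from a directory of HDF5 samples.

        :param directory: directory containing the .h5/.hd5 sample files.
        :param batch_size: number of samples per batch.
        :param shuffle: shuffle the file list; otherwise it is sorted.
        :param split: fraction of the files assigned to training, the remainder to
            validation. Ignored when directory_val is given.
        :param combine_segmentations: keep the segmentation as a single channel
            instead of splitting it into one channel per label.
        :param labels: list of labels to use: ['all'], ['none'], or label names/values.
        :param directory_val: optional separate directory with the validation files.
        :param return_isotropic_shape: also yield the stored isotropic shape of each sample.
        """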
| self.file_directory = directory |
| self.batch_size = batch_size |
| self.combine_segmentations = combine_segmentations |
| self.labels = labels |
| self.shuffle = shuffle |
| self.split = split |
        self.return_isotropic_shape = return_isotropic_shape
|
|
| if directory_val is None: |
| self.file_list = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('h5', 'hd5'))] |
            if self.shuffle:
                random.shuffle(self.file_list)
            else:
                self.file_list.sort()
| self.num_samples = len(self.file_list) |
| training_samples = self.file_list[:int(self.num_samples * self.split)] |
|
|
| self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels, return_isotropic_shape=return_isotropic_shape) |
| if self.split < 1.: |
| validation_samples = list(set(self.file_list) - set(training_samples)) |
| self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, ['all'], |
| validation=True, return_isotropic_shape=return_isotropic_shape) |
| else: |
| self.validation_iter = None |
| else: |
| training_samples = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('h5', 'hd5'))] |
            if self.shuffle:
                random.shuffle(training_samples)
            else:
                training_samples.sort()
|
|
| validation_samples = [os.path.join(directory_val, f) for f in os.listdir(directory_val) if f.endswith(('h5', 'hd5'))] |
            if self.shuffle:
                random.shuffle(validation_samples)
            else:
                validation_samples.sort()
|
|
| self.num_samples = len(training_samples) + len(validation_samples) |
| self.file_list = training_samples + validation_samples |
|
|
| self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels) |
| self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, labels, |
| validation=True) |
|
|
| def get_train_generator(self): |
| return self.train_iter |
|
|
| def get_validation_generator(self): |
| if self.validation_iter is not None: |
| return self.validation_iter |
| else: |
            raise ValueError('No validation iterator: set split < 1.0 or provide directory_val')
|
|
| def get_file_list(self): |
| return self.file_list |
|
|
| def get_data_shape(self): |
| return self.train_iter.get_data_shape() |
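
# Minimal usage sketch (hypothetical paths and model; the iterators implement the
# keras Sequence interface, so they can be passed directly to Model.fit):
#
#   data_gen = BatchGenerator('/path/to/train_h5', batch_size=8, shuffle=True, split=0.8)
#   train_iter = data_gen.get_train_generator()
#   val_iter = data_gen.get_validation_generator()
#   model.fit(train_iter, validation_data=val_iter, epochs=100)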
|
|
|
|
ALL_LABELS = {2., 3., 4., 6., 8., 9., 11., 12., 14., 16., 20., 23., 29., 33., 39., 53., 67., 76., 102., 203., 210.,
              211., 218., 219., 232., 233., 254., 255.}
# Map each raw label value to the channel index used for its one-hot encoding
ALL_LABELS_LOC = {label: loc for loc, label in enumerate(ALL_LABELS)}
|
|
|
|
| class BatchIterator(Sequence): |
| def __init__(self, file_list, batch_size, shuffle, combine_segmentations=True, labels=['all'], |
| zero_grads=[64, 64, 64, 3], validation=False, sequential_labels=True, |
| return_isotropic_shape=False, **kwargs): |
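        """
        keras Sequence that yields batches of image/segmentation volumes read from HDF5 files.

        :param file_list: list of paths to the .h5/.hd5 sample files.
        :param batch_size: number of samples per batch.
        :param shuffle: reshuffle the sample order at the end of every epoch.
        :param combine_segmentations: keep the segmentation as a single channel
            instead of one-hot encoding it.
        :param labels: ['all'] (every known label), ['none'] (images only), or a list
            of label names/values to keep.
        :param zero_grads: shape of the all-zeros array returned alongside each sample.
        :param validation: if True, the segmentation is also returned as a target.
        :param sequential_labels: label values are already consecutive integers and can
            be one-hot encoded directly.
        :param return_isotropic_shape: also return the stored isotropic shape of each sample.
        """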
| |
| |
| |
| |
| |
| self.batch_size = batch_size |
| self.shuffle = shuffle |
| self.file_list = file_list |
| self.combine_segmentations = combine_segmentations |
| self.labels = labels |
| self.zero_grads = np.zeros(zero_grads) |
| self.idx_list = np.arange(0, len(self.file_list)) |
| self.validation = validation |
| self.sequential_labels = sequential_labels |
| self.return_isotropic_shape = return_isotropic_shape |
| self._initialize() |
| self.shuffle_samples() |
|
|
| def _initialize(self): |
| if (isinstance(self.labels[0], str) and self.labels[0].lower() != 'none'): |
| if self.labels[0] != 'all': |
| |
| self.labels = [SEGMENTATION_LBL2NR_LUT[lbl] for lbl in self.labels] |
| if not self.sequential_labels: |
| self.labels = [SEGMENTATION_LOC[lbl] for lbl in self.labels] |
| self.labels_dict = lambda x: SEGMENTATION_LOC[x] if x in self.labels else 0 |
| else: |
| self.labels_dict = lambda x: ALL_LABELS_LOC[x] if x in self.labels else 0 |
| else: |
| |
| if self.sequential_labels: |
| self.labels = list(set(SEGMENTATION_LOC.values())) |
| self.labels_dict = lambda x: SEGMENTATION_LOC[x] if x else 0 |
| else: |
| self.labels = list(ALL_LABELS) |
| self.labels_dict = lambda x: ALL_LABELS_LOC[x] if x in self.labels else 0 |
| elif hasattr(self.labels[0], 'lower') and self.labels[0].lower() == 'none': |
| |
| self.labels_dict = dict() |
| else: |
| assert np.all([isinstance(lbl, (int, float)) for lbl in self.labels]), "Labels must be a str, int or float" |
| |
|
|
| self.num_steps = len(self.file_list) // self.batch_size + (1 if len(self.file_list) % self.batch_size else 0) |
| |
| |
|
|
| with h5py.File(self.file_list[0], 'r') as f: |
            # Infer the sample shapes from the first file (no need to load the voxel data)
            self.image_shape = list(f['image'].shape)
| self.segm_shape = self.image_shape.copy() |
| self.segm_shape[-1] = len(self.labels) if not self.combine_segmentations else 1 |
|
|
| self.batch_shape = self.image_shape.copy() |
| self.batch_shape[-1] = self.image_shape[-1] + self.segm_shape[-1] |
|
|
| def shuffle_samples(self): |
| np.random.shuffle(self.idx_list) |
|
|
| def __len__(self): |
| return self.num_steps |
|
|
| def _filter_segmentations(self, segm, segm_labels): |
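        """Keep only the requested label channels of a one-hot segmentation.

        Channels are looked up through ALL_LABELS_LOC; requested labels that are
        missing from `segm_labels` end up as all-zero channels.
        """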
| if self.combine_segmentations: |
| |
            warnings.warn('Cannot select labels when the combine_segmentations option is active')
| if self.labels[0] != 'all': |
| if set(self.labels).issubset(set(segm_labels)): |
| |
| idx = [ALL_LABELS_LOC[l] for l in self.labels] |
| segm = segm[..., idx] |
| else: |
| |
| idx = [ALL_LABELS_LOC[l] for l in list(set(self.labels).intersection(set(segm_labels)))] |
| aux = segm.copy() |
| segm = np.zeros(self.segm_shape) |
| segm[..., :len(idx)] = aux[..., idx] |
| |
| return segm |
|
|
| def _load_sample(self, file_path): |
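        """Load one HDF5 sample.

        Returns the network input (image concatenated with the segmentation along the
        channel axis, unless labels is 'none' and this is not a validation iterator),
        a target tuple pairing the image (and the segmentation when validating) with
        the all-zeros `zero_grads` array, and the stored isotropic shape.
        """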
| with h5py.File(file_path, 'r') as f: |
| img = f['image'][:] |
| segm = f['segmentation'][:] |
| isot_shape = f['isotropic_shape'][:] |
|
|
| if not self.combine_segmentations: |
| if self.sequential_labels: |
| |
| segm = np.squeeze(np.eye(len(self.labels))[segm]) |
| else: |
| lbls_list = list(ALL_LABELS) if self.labels[0] == 'all' else self.labels |
| segm = segmentation_cardinal_to_ohe(segm, lbls_list) |
| |
| |
| |
| |
| |
|
|
| img = np.asarray(img, dtype=np.float32) |
| segm = np.asarray(segm, dtype=np.float32) |
| if not isinstance(self.labels[0], str) or self.labels[0].lower() != 'none' or self.validation: |
| |
|
|
| if self.validation: |
| ret_val = np.concatenate([img, segm], axis=-1), (img, segm, self.zero_grads), isot_shape |
| else: |
| ret_val = np.concatenate([img, segm], axis=-1), (img, self.zero_grads), isot_shape |
| else: |
| ret_val = img, (img, self.zero_grads), isot_shape |
| return ret_val |
|
|
| def __getitem__(self, idx): |
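        """Assemble batch `idx` as an (inputs, targets) pair; both are the stacked
        image+segmentation volumes, optionally followed by the isotropic shapes."""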
| in_batch = list() |
| isotropic_shape = list() |
| |
|
|
| batch_idxs = self.idx_list[idx * self.batch_size:(idx + 1) * self.batch_size] |
| file_list = [self.file_list[i] for i in batch_idxs] |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
        for f in file_list:
            b, _, isot_shape = self._load_sample(f)
| |
| if self.return_isotropic_shape: |
| isotropic_shape.append(isot_shape) |
| in_batch.append(b) |
| |
|
|
| in_batch = np.asarray(in_batch, dtype=np.float32) |
| ret_val = (in_batch, in_batch) |
| if self.return_isotropic_shape: |
            isotropic_shape = np.asarray(isotropic_shape, dtype=int)
| ret_val += (isotropic_shape,) |
| |
| return ret_val |
|
|
| def __iter__(self): |
| """Create a generator that iterate over the Sequence.""" |
| for item in (self[i] for i in range(len(self))): |
| yield item |
|
|
| def get_data_shape(self): |
| return self.batch_shape, self.image_shape, self.segm_shape |
|
|
| def on_epoch_end(self): |
| self.shuffle_samples() |
|
|
| def get_segmentation_labels(self): |
| if self.combine_segmentations: |
| labels = [1] |
| else: |
| labels = self.labels |
| return labels |
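
# Shape/usage sketch (hypothetical file list): the iterator yields (batch, batch)
# pairs in which image and segmentation are stacked along the channel axis.
#
#   it = BatchIterator(file_paths, batch_size=4, shuffle=True, labels=['all'])
#   batch_shape, image_shape, segm_shape = it.get_data_shape()
#   x, y = it[0]    # x.shape == (4,) + tuple(batch_shape)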
|
|
|
|
|
|
|
|
|
|
|
|
| ''' |
| def get_size(obj, seen=None): |
| """Recursively finds size of objects""" |
| size = sys.getsizeof(obj) |
| if seen is None: |
| seen = set() |
| obj_id = id(obj) |
| if obj_id in seen: |
| return 0 |
| # Important mark as seen *before* entering recursion to gracefully handle |
| # self-referential objects |
| seen.add(obj_id) |
| if isinstance(obj, dict): |
| size += sum([get_size(v, seen) for v in obj.values()]) |
| size += sum([get_size(k, seen) for k in obj.keys()]) |
| elif hasattr(obj, '__dict__'): |
| size += get_size(obj.__dict__, seen) |
| elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)): |
| size += sum([get_size(i, seen) for i in obj]) |
| return size |
| |
| |
| class BatchIterator(Iterator): |
| def __init__(self, generator, file_list, input_shape, output_shape, batch_size, shuffle, all_files_in_batch): |
| self.file_list = file_list |
| self.generator = generator |
| self.input_shape = input_shape |
| self.nr_of_inputs = len(input_shape) |
| self.output_shape = output_shape |
| self.nr_of_outputs = len(output_shape) |
| self.all_files_in_batch = all_files_in_batch |
| self.preload_to_memory = False |
| self.file_cache = {} |
| self.max_cache_size = 10*1024 |
| self.verbose = False |
| if self.preload_to_memory: |
| for filename, file_index in self.file_list: |
| file = h5py.File(filename, 'r') |
| inputs = {} |
| for name, data in file['input'].items(): |
| inputs[name] = np.copy(data) |
| self.file_cache[filename] = {'input': inputs, 'output': np.copy(file['output'])} |
| file.close() |
| if get_size(self.file_cache) / (1024*1024) >= self.max_cache_size: |
| print('File cache has reached limit of', self.max_cache_size, 'MBs') |
| break |
| epoch_size = len(file_list) |
| if all_files_in_batch: |
| epoch_size = len(file_list) * 10 |
| super(BatchIterator, self).__init__(epoch_size, batch_size, shuffle, None) |
| |
| def _get_sample(self, index): |
| filename, file_index = self.file_list[index] |
| if filename in self.file_cache: |
| file = self.file_cache[filename] |
| else: |
| file = h5py.File(filename, 'r') |
| inputs = [] |
| outputs = [] |
| for name, data in file['input'].items(): |
| inputs.append(data[file_index, :]) |
| for name, data in file['output'].items(): |
| if len(data.shape) > 1: |
| outputs.append(data[file_index, :]) |
| else: |
| outputs.append(data[file_index]) |
| #outputs.append(file['output'][file_index, :]) # TODO fix |
| if filename not in self.file_cache: |
| file.close() |
| return inputs, outputs |
| |
| def _get_random_sample_in_file(self, file_index): |
| filename = self.file_list[file_index] |
| file = h5py.File(filename, 'r') |
| x = file['output/0'] |
| sample = np.random.randint(0, x.shape[0]) |
| #print('Sampling image', sample, 'from file', filename) |
| inputs = [] |
| outputs = [] |
| for name, data in file['input'].items(): |
| inputs.append(data[sample, :]) |
| for name, data in file['output'].items(): |
| outputs.append(data[file_index, :]) |
| #outputs.append(file['output'][sample, :]) # TODO FIX output |
| file.close() |
| return inputs, outputs |
| |
| def next(self): |
| |
| with self.lock: |
| index_array = next(self.index_generator) |
| |
| #print(len(index_array)) |
| return self._get_batches_of_transformed_samples(index_array) |
| |
| def _get_batches_of_transformed_samples(self, index_array): |
| start_batch = time.time() |
| batches_x = [] |
| batches_y = [] |
| for input_index in range(self.nr_of_inputs): |
| batches_x.append(np.zeros(tuple([len(index_array)] + list(self.input_shape[input_index])))) |
| for output_index in range(self.nr_of_outputs): |
| batches_y.append(np.zeros(tuple([len(index_array)] + list(self.output_shape[output_index])))) |
| |
| timings_sampling = np.zeros((len(index_array,))) |
| timings_transform = np.zeros((len(index_array,))) |
| for batch_index, sample_index in enumerate(index_array): |
| # Have to copy here in order to not modify original data |
| start = time.time() |
| if self.all_files_in_batch: |
| input, output = self._get_random_sample_in_file(batch_index) |
| else: |
| input, output = self._get_sample(sample_index) |
| timings_sampling[batch_index] = time.time() - start |
| start = time.time() |
| input, output = self.generator.transform(input, output) |
| timings_transform[batch_index] = time.time() - start |
| |
| #print('inputs', self.nr_of_inputs, len(input)) |
| for input_index in range(self.nr_of_inputs): |
| batches_x[input_index][batch_index] = input[input_index] |
| for output_index in range(self.nr_of_outputs): |
| batches_y[output_index][batch_index] = output[output_index] |
| |
| elapsed = time.time() - start_batch |
| if self.verbose: |
| print('Time to prepare batch:', round(elapsed,3), 'seconds') |
| print('Sampling mean:', round(timings_sampling.mean(), 3), 'seconds') |
| print('Transform mean:', round(timings_transform.mean(), 3), 'seconds') |
| |
| return batches_x, batches_y |
| |
| |
| CLASSIFICATION = 'classification' |
| SEGMENTATION = 'segmentation' |
| |
| |
| class BatchGenerator(): |
| def __init__(self, filelist, all_files_in_batch=False): |
| self.methods = [] |
| self.args = [] |
| self.crop_width_to = None |
| self.image_list = [] |
| self.input_shape = [] |
| self.output_shape = [] |
| self.all_files_in_batch = all_files_in_batch |
| self.transforms = [] |
| |
| if all_files_in_batch: |
| file = h5py.File(filelist[0], 'r') |
| for name, data in file['input'].items(): |
| self.input_shape.append(data.shape[1:]) |
| for name, data in file['output'].items(): |
| self.output_shape.append(data.shape[1:]) |
| # TODO fix |
| #self.output_shape.append(file['output'].shape[1:]) |
| file.close() |
| self.image_list = filelist |
| return |
| |
| # Go through filelist |
| first = True |
| for filename in filelist: |
| samples = None |
| # Open file to see how many samples it has |
| file = h5py.File(filename, 'r') |
| for name, data in file['input'].items(): |
| if first: |
| self.input_shape.append(data.shape[1:]) |
| samples = data.shape[0] |
| # TODO fix |
| for name, data in file['output'].items(): |
| if first: |
| self.output_shape.append(data.shape[1:]) |
| if samples != data.shape[0]: |
| raise ValueError() |
| #self.output_shape.append(file['output'].shape[1:]) |
| if len(self.output_shape) == 1: |
| self.problem_type = CLASSIFICATION |
| else: |
| self.problem_type = SEGMENTATION |
| |
| file.close() |
| if samples is None: |
| raise ValueError() |
| # Append a tuple to image_list for each image consisting of filename and index |
| print(filename, samples) |
| for i in range(samples): |
| self.image_list.append((filename, i)) |
| first = False |
| |
| print('Image generator with', len(self.image_list), ' image samples created') |
| |
| def flow(self, batch_size, shuffle=True): |
| |
| return BatchIterator(self, self.image_list, self.input_shape, self.output_shape, batch_size, shuffle, self.all_files_in_batch) |
| |
| def transform(self, inputs, outputs): |
| #input = input.astype(np.float32) # TODO |
| #output = output.astype(np.float32) |
| for input_indices, output_indices, transform in self.transforms: |
| transform.randomize() |
| inputs, outputs = transform.transform_all(inputs, outputs, input_indices, output_indices) |
| return inputs, outputs |
| |
| def add_transform(self, input_indices: Union[int, List[int], None], output_indices: Union[int, List[int], None], transform: Transform): |
| if type(input_indices) is int: |
| input_indices = [input_indices] |
| if type(output_indices) is int: |
| output_indices = [output_indices] |
| |
| self.transforms.append(( |
| input_indices, |
| output_indices, |
| transform |
| )) |
| |
| def get_size(self): |
| if self.all_files_in_batch: |
| return 10*len(self.image_list) |
| else: |
| return len(self.image_list) |
| |
| ''' |
|
|