import SimpleITK as sitk import numpy as np from skimage.transform import resize def resize_image(image, old_spacing, new_spacing, order=3): new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False) def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)): spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]] image = sitk.GetArrayFromImage(itk_image).astype(float) assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed" if not is_seg: if np.any([[i != j] for i, j in zip(spacing, spacing_target)]): image = resize_image(image, spacing, spacing_target).astype(np.float32) image -= image.mean() image /= image.std() else: new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))), int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))), int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2])))) image = resize_segmentation(image, new_shape, 1) return image def load_and_preprocess(mri_file): images = {} # t1 images["T1"] = sitk.ReadImage(mri_file) properties_dict = { "spacing": images["T1"].GetSpacing(), "direction": images["T1"].GetDirection(), "size": images["T1"].GetSize(), "origin": images["T1"].GetOrigin() } for k in images.keys(): images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5)) properties_dict['size_before_cropping'] = images["T1"].shape imgs = [] for seq in ['T1']: imgs.append(images[seq][None]) all_data = np.vstack(imgs) print("image shape after preprocessing: ", str(all_data[0].shape)) return all_data, properties_dict def save_segmentation_nifti(segmentation, dct, out_fname, order=1): ''' segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out of the original image dct: size_before_cropping brain_bbox size -> this is the original size of the dataset, if the image was not resampled, this is the same as size_before_cropping spacing origin direction :param segmentation: :param dct: :param out_fname: :return: ''' old_size = dct.get('size_before_cropping') bbox = dct.get('brain_bbox') if bbox is not None: seg_old_size = np.zeros(old_size) for c in range(3): bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c])) seg_old_size[bbox[0][0]:bbox[0][1], bbox[1][0]:bbox[1][1], bbox[2][0]:bbox[2][1]] = segmentation else: seg_old_size = segmentation if np.any(np.array(seg_old_size) != np.array(dct['size'])[[2, 1, 0]]): seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order) else: seg_old_spacing = seg_old_size seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32)) seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]]) seg_resized_itk.SetOrigin(dct['origin']) seg_resized_itk.SetDirection(dct['direction']) sitk.WriteImage(seg_resized_itk, out_fname) def resize_segmentation(segmentation, new_shape, order=3, cval=0): ''' Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one hot encoding which is resized and transformed back to a segmentation map. This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2]) :param segmentation: :param new_shape: :param order: :return: ''' tpe = segmentation.dtype unique_labels = np.unique(segmentation) assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation" if order == 0: return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe) else: reshaped = np.zeros(new_shape, dtype=segmentation.dtype) for i, c in enumerate(unique_labels): reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False) reshaped[reshaped_multihot >= 0.5] = c return reshaped