#    Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
#    (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

from typing import Tuple, Union, List
import numpy as np
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
import SimpleITK as sitk


class SimpleITKIO(BaseReaderWriter):
    """Reader/writer that loads and stores images through SimpleITK."""

    supported_file_endings = [
        '.nii.gz',
        '.nrrd',
        '.mha'
    ]

    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
        """Read one or more image files and stack them along the first array axis.

        Every file is loaded with ``sitk.ReadImage`` and converted to a numpy array. 2d arrays get two
        leading singleton axes, 3d arrays get one, and 4d arrays are used as they are, so that every
        per-file array is 4d before stacking with ``np.vstack``.

        :param image_fnames: paths of the image files to read.
        :return: tuple ``(data, properties)``. ``data`` is the stacked array cast to ``np.float32``.
            ``properties['sitk_stuff']`` holds spacing, origin and direction of the first file exactly as
            SimpleITK reported them; ``properties['spacing']`` is the absolute, axis-reversed spacing of the
            first file (with a large dummy first entry for 2d input, and with the first reversed entry
            dropped for 4d input).
        :raises RuntimeError: if a file has an unsupported number of dimensions, or if shapes, spacings or
            the derived spacings differ between the files.
        """
        images = []
        spacings = []
        origins = []
        directions = []
        spacings_for_nnunet = []

        for f in image_fnames:
            itk_image = sitk.ReadImage(f)
            spacings.append(itk_image.GetSpacing())
            origins.append(itk_image.GetOrigin())
            directions.append(itk_image.GetDirection())
            npy_image = sitk.GetArrayFromImage(itk_image)

            # SimpleITK reports spacing in x,y,z order while the numpy array is indexed z,y,x, so the spacing
            # is reversed with [::-1] to line up with the array axes.
            reversed_spacing = list(spacings[-1])[::-1]
            if npy_image.ndim == 2:
                # 2d: add channel axis and a dummy first spatial axis with a very large spacing
                npy_image = npy_image[None, None]
                max_spacing = max(spacings[-1])
                spacings_for_nnunet.append((max_spacing * 999, *reversed_spacing))
            elif npy_image.ndim == 3:
                # 3d, as in original nnunet: only the channel axis is added
                npy_image = npy_image[None]
                spacings_for_nnunet.append(reversed_spacing)
            elif npy_image.ndim == 4:
                # 4d, multiple modalities in one file: the first array axis is used as is and its spacing
                # entry is dropped
                spacings_for_nnunet.append(reversed_spacing[1:])
            else:
                raise RuntimeError(f"Unexpected number of dimensions: {npy_image.ndim} in file {f}")

            images.append(npy_image)
            spacings_for_nnunet[-1] = list(np.abs(spacings_for_nnunet[-1]))

        shapes = [i.shape for i in images]
        if not self._check_all_same(shapes):
            print('ERROR! Not all input images have the same shape!')
            print('Shapes:')
            print(shapes)
            print('Image files:')
            print(image_fnames)
            raise RuntimeError(f'Not all input images have the same shape! Shapes: {shapes}. '
                               f'Image files: {image_fnames}')
        if not self._check_all_same(spacings):
            print('ERROR! Not all input images have the same spacing!')
            print('Spacings:')
            print(spacings)
            print('Image files:')
            print(image_fnames)
            raise RuntimeError(f'Not all input images have the same spacing! Spacings: {spacings}. '
                               f'Image files: {image_fnames}')
        if not self._check_all_same(origins):
            # differing origins are only reported, not treated as an error
            print('WARNING! Not all input images have the same origin!')
            print('Origins:')
            print(origins)
            print('Image files:')
            print(image_fnames)
            print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify '
                  'that segmentations and data overlap.')
        if not self._check_all_same(directions):
            # differing directions are only reported, not treated as an error
            print('WARNING! Not all input images have the same direction!')
            print('Directions:')
            print(directions)
            print('Image files:')
            print(image_fnames)
            print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify '
                  'that segmentations and data overlap.')
        if not self._check_all_same(spacings_for_nnunet):
            print('ERROR! Not all input images have the same spacing_for_nnunet! (This should not happen and must be a '
                  'bug. Please report!')
            print('spacings_for_nnunet:')
            print(spacings_for_nnunet)
            print('Image files:')
            print(image_fnames)
            raise RuntimeError(f'Not all input images have the same spacing_for_nnunet! '
                               f'spacings_for_nnunet: {spacings_for_nnunet}. Image files: {image_fnames}')

        stacked_images = np.vstack(images)
        properties = {
            'sitk_stuff': {
                # this saves the sitk geometry information. This part is NOT used by nnU-Net!
                'spacing': spacings[0],
                'origin': origins[0],
                'direction': directions[0]
            },
            # axis-reversed (and absolute) spacing, see the comment in the loop above
            'spacing': spacings_for_nnunet[0]
        }
        return stacked_images.astype(np.float32), properties

    def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
        """Read a single segmentation file; delegates to ``read_images`` with a one-element tuple.

        :param seg_fname: path of the segmentation file.
        :return: whatever ``read_images`` returns for that single file.
        """
        return self.read_images((seg_fname, ))

    def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
        """Write a segmentation to ``output_fname`` using the geometry stored in ``properties``.

        :param seg: 3d array. For a 2d export it must have shape 1,x,y; the leading axis is removed.
        :param output_fname: destination file path passed to ``sitk.WriteImage``.
        :param properties: dict with a ``'sitk_stuff'`` entry providing ``'spacing'``, ``'origin'`` and
            ``'direction'`` (as produced by ``read_images``).
        :raises AssertionError: if ``seg`` is not 3d or the stored spacing is not 2- or 3-dimensional.
        """
        assert seg.ndim == 3, 'segmentation must be 3d. If you are exporting a 2d segmentation, please provide it as shape 1,x,y'
        sitk_stuff = properties['sitk_stuff']
        output_dimension = len(sitk_stuff['spacing'])
        assert 1 < output_dimension < 4, f'only 2d and 3d output is supported, got {output_dimension}d spacing'
        if output_dimension == 2:
            seg = seg[0]

        if seg.dtype == "float16":
            # float16 input is written as float32 instead of being cast to uint8
            # TODO: to improve
            itk_image = sitk.GetImageFromArray(seg.astype(np.float32))
        else:
            # copy=False skips the extra copy when seg is already uint8
            itk_image = sitk.GetImageFromArray(seg.astype(np.uint8, copy=False))
        itk_image.SetSpacing(sitk_stuff['spacing'])
        itk_image.SetOrigin(sitk_stuff['origin'])
        itk_image.SetDirection(sitk_stuff['direction'])

        # third positional argument enables compression
        sitk.WriteImage(itk_image, output_fname, True)