|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple, Union, List |
|
|
import numpy as np |
|
|
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter |
|
|
import SimpleITK as sitk |
|
|
|
|
|
|
|
|
class SimpleITKIO(BaseReaderWriter):
    """Image reader/writer backed by SimpleITK.

    Images are read with ``sitk.ReadImage`` and converted to numpy arrays.
    The SimpleITK geometry (spacing, origin, direction) is kept in the
    returned properties under ``'sitk_stuff'`` so that ``write_seg`` can
    restore it on the exported segmentation.
    """

    supported_file_endings = [
        '.nii.gz',
        '.nrrd',
        '.mha',
    ]

    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
        """Read one or more image files and stack them along the first axis.

        Each file may hold a 2D, 3D or 4D image:

        * 2D images get two leading singleton axes (``1, 1, y, x``),
        * 3D images get one leading singleton axis (``1, z, y, x``),
        * 4D images are used as they are; their first array axis ends up as
          the stacking axis.

        :param image_fnames: paths of the files to read.
        :return: tuple ``(data, properties)``. ``data`` is the float32 array
            obtained by ``np.vstack`` of all per-file arrays. ``properties``
            has the keys ``'sitk_stuff'`` (``spacing``, ``origin`` and
            ``direction`` of the first file, exactly as reported by
            SimpleITK) and ``'spacing'`` (absolute spacing values of the
            first file in numpy axis order).
        :raises ValueError: if ``image_fnames`` contains no files.
        :raises RuntimeError: if a file has an unsupported number of
            dimensions, or if shapes / spacings differ between the files.
        """
        images = []
        spacings = []
        origins = []
        directions = []
        spacings_for_nnunet = []

        for f in image_fnames:
            itk_image = sitk.ReadImage(f)
            spacings.append(itk_image.GetSpacing())
            origins.append(itk_image.GetOrigin())
            directions.append(itk_image.GetDirection())
            npy_image = sitk.GetArrayFromImage(itk_image)

            # SimpleITK reports spacing in the opposite axis order of the numpy
            # array it returns, hence the reversal below.
            reversed_spacing = list(spacings[-1])[::-1]

            if npy_image.ndim == 2:
                # 2D: add a stacking axis and a dummy slice axis. The dummy
                # axis gets a spacing far larger than the in-plane spacing
                # (presumably so it is never mistaken for a real, finely
                # sampled axis downstream - confirm with the consumers).
                npy_image = npy_image[None, None]
                max_spacing = max(spacings[-1])
                spacings_for_nnunet.append((max_spacing * 999, *reversed_spacing))
            elif npy_image.ndim == 3:
                # 3D: only the stacking axis has to be added.
                npy_image = npy_image[None]
                spacings_for_nnunet.append(reversed_spacing)
            elif npy_image.ndim == 4:
                # 4D: the first array axis is used as the stacking axis, so
                # its spacing entry is dropped.
                spacings_for_nnunet.append(reversed_spacing[1:])
            else:
                raise RuntimeError(f"Unexpected number of dimensions: {npy_image.ndim} in file {f}")

            images.append(npy_image)
            spacings_for_nnunet[-1] = list(np.abs(spacings_for_nnunet[-1]))

        if not images:
            # Fail early with a clear message instead of an unrelated error
            # from np.vstack / indexing an empty list further down.
            raise ValueError('read_images needs at least one image file but image_fnames is empty')

        shapes = [i.shape for i in images]
        if not self._check_all_same(shapes):
            print('ERROR! Not all input images have the same shape!')
            print('Shapes:')
            print(shapes)
            print('Image files:')
            print(image_fnames)
            raise RuntimeError(f'Not all input images have the same shape. Shapes: {shapes}. '
                               f'Image files: {image_fnames}')
        if not self._check_all_same(spacings):
            print('ERROR! Not all input images have the same spacing!')
            print('Spacings:')
            print(spacings)
            print('Image files:')
            print(image_fnames)
            raise RuntimeError(f'Not all input images have the same spacing. Spacings: {spacings}. '
                               f'Image files: {image_fnames}')
        # Differing origins / directions are only reported, not treated as errors.
        if not self._check_all_same(origins):
            print('WARNING! Not all input images have the same origin!')
            print('Origins:')
            print(origins)
            print('Image files:')
            print(image_fnames)
            print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify '
                  'that segmentations and data overlap.')
        if not self._check_all_same(directions):
            print('WARNING! Not all input images have the same direction!')
            print('Directions:')
            print(directions)
            print('Image files:')
            print(image_fnames)
            print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify '
                  'that segmentations and data overlap.')
        if not self._check_all_same(spacings_for_nnunet):
            print('ERROR! Not all input images have the same spacing_for_nnunet! (This should not happen and must be a '
                  'bug. Please report!')
            print('spacings_for_nnunet:')
            print(spacings_for_nnunet)
            print('Image files:')
            print(image_fnames)
            raise RuntimeError(f'Not all input images have the same spacing_for_nnunet: {spacings_for_nnunet}. '
                               f'Image files: {image_fnames}')

        stacked_images = np.vstack(images)
        image_properties = {
            'sitk_stuff': {
                # Stored untouched so that write_seg can hand the values
                # straight back to SimpleITK.
                'spacing': spacings[0],
                'origin': origins[0],
                'direction': directions[0]
            },
            # Spacing in numpy axis order (see the per-dimension handling above).
            'spacing': spacings_for_nnunet[0]
        }
        return stacked_images.astype(np.float32), image_properties

    def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
        """Read a single segmentation file.

        Delegates to ``read_images`` with a one-element tuple, so the return
        value has the same layout (including the float32 dtype).
        """
        return self.read_images((seg_fname, ))

    @staticmethod
    def _smallest_uint_dtype_for(seg: np.ndarray):
        """Return the smallest of uint8/uint16/uint32 that holds ``seg.max()``.

        uint8 is the fallback (empty arrays, NaN maxima), which matches the
        dtype that was used unconditionally before.
        """
        max_label = seg.max() if seg.size > 0 else 0
        if max_label > np.iinfo(np.uint16).max:
            return np.uint32
        if max_label > np.iinfo(np.uint8).max:
            return np.uint16
        return np.uint8

    def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
        """Write a segmentation to ``output_fname`` using SimpleITK.

        :param seg: 3D array. A 2D segmentation must be passed with shape
            ``(1, x, y)``; the leading axis is removed before writing when the
            stored spacing is two-dimensional.
        :param output_fname: target file name.
        :param properties: properties dict as returned by ``read_images``;
            ``properties['sitk_stuff']`` supplies spacing, origin and
            direction of the written image.
        :raises AssertionError: if ``seg`` is not 3D or the stored spacing is
            neither 2D nor 3D.
        """
        assert seg.ndim == 3, 'segmentation must be 3d. If you are exporting a 2d segmentation, please provide it as shape 1,x,y'
        sitk_stuff = properties['sitk_stuff']
        output_dimension = len(sitk_stuff['spacing'])
        assert 1 < output_dimension < 4
        if output_dimension == 2:
            seg = seg[0]

        if seg.dtype == np.float16:
            # float16 data is written as float32 instead of being cast to an
            # integer label type.
            itk_image = sitk.GetImageFromArray(seg.astype(np.float32))
        else:
            # A fixed uint8 cast silently wraps label values above 255, so the
            # integer type is chosen from the largest label that is present.
            itk_image = sitk.GetImageFromArray(seg.astype(self._smallest_uint_dtype_for(seg)))

        itk_image.SetSpacing(sitk_stuff['spacing'])
        itk_image.SetOrigin(sitk_stuff['origin'])
        itk_image.SetDirection(sitk_stuff['direction'])

        # Third positional argument of sitk.WriteImage is useCompression.
        sitk.WriteImage(itk_image, output_fname, True)
|
|
|