# Uploaded to the Hugging Face Hub via huggingface_hub (commit 19c1f58)
from collections import OrderedDict
from typing import Union, Tuple, List
import numpy as np
import pandas as pd
import torch
from batchgenerators.augmentations.utils import resize_segmentation
from scipy.ndimage.interpolation import map_coordinates
from skimage.transform import resize
from nnunetv2.configuration import ANISO_THRESHOLD
def get_do_separate_z(spacing: Union[Tuple[float, ...], List[float], np.ndarray], anisotropy_threshold=ANISO_THRESHOLD):
    """Return True when the spacing is anisotropic enough to warrant resampling the low-res axis separately.

    Anisotropy is measured as the ratio of the coarsest to the finest spacing;
    it must exceed ``anisotropy_threshold`` for separate-z resampling.
    """
    spacing_ratio = np.max(spacing) / np.min(spacing)
    return spacing_ratio > anisotropy_threshold
def get_lowres_axis(new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]):
    """Return the indices of the axes whose spacing equals the maximum (the low-resolution axes).

    Normally this is a single out-of-plane axis; if several axes tie for the
    coarsest spacing, all of their indices are returned.
    """
    spacing_arr = np.array(new_spacing)
    # an axis is low-resolution exactly when its spacing equals the maximum,
    # i.e. max(spacing) / spacing == 1 for that entry
    lowres = np.where(max(new_spacing) / spacing_arr == 1)[0]
    return lowres
def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                      old_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                      new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray:
    """Compute the voxel shape after resampling from ``old_spacing`` to ``new_spacing``.

    Each output dimension is ``round(old_spacing / new_spacing * old_shape)`` so
    that the physical extent of the volume is preserved.
    """
    assert len(old_spacing) == len(old_shape)
    assert len(old_shape) == len(new_spacing)
    resized_dims = []
    for spacing_old, spacing_new, dim in zip(old_spacing, new_spacing, old_shape):
        resized_dims.append(int(round(spacing_old / spacing_new * dim)))
    return np.array(resized_dims)
def resample_data_or_seg_to_spacing(data: np.ndarray,
                                    current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                    new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                    is_seg: bool = False,
                                    order: int = 3, order_z: int = 0,
                                    force_separate_z: Union[bool, None] = False,
                                    separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
    """
    Resample a 4D array (c, x, y, z) from ``current_spacing`` to ``new_spacing``.

    :param data: 4D array, channel first (c, x, y, z)
    :param current_spacing: spacing of the input, one entry per spatial axis
    :param new_spacing: target spacing, same length as current_spacing
    :param is_seg: True for segmentation maps (label-preserving resampling)
    :param order: spline order for the in-plane resize
    :param order_z: spline order along the out-of-plane axis (only if separate-z is used)
    :param force_separate_z: True/False to force the separate-z decision;
        None to auto-detect from the anisotropy of either spacing
    :param separate_z_anisotropy_threshold: max/min spacing ratio above which
        a spacing counts as anisotropic
    :return: resampled array with the same channel count and dtype as the input
    """
    if force_separate_z is not None:
        # caller made the decision for us
        do_separate_z = force_separate_z
        axis = get_lowres_axis(current_spacing) if force_separate_z else None
    else:
        # auto-detect: prefer the source spacing's anisotropy, fall back to the target's
        if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold):
            do_separate_z = True
            axis = get_lowres_axis(current_spacing)
        elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold):
            do_separate_z = True
            axis = get_lowres_axis(new_spacing)
        else:
            do_separate_z = False
            axis = None

    if axis is not None and len(axis) != 1:
        # 2 axes tie (e.g. (0.24, 1.25, 1.25)) or all 3 are equal: there is no
        # unique out-of-plane axis, so do not resample separately
        do_separate_z = False

    if data is not None:
        assert data.ndim == 4, "data must be c x y z"

    # BUG FIX: data[0].shape already IS the 3-entry spatial shape (x, y, z).
    # The previous code passed shape[1:], dropping the first spatial axis and
    # tripping compute_new_shape's length assertion for 3D spacings
    # (the `[1:]` idiom belongs with data.shape, not data[0].shape).
    shape = np.array(data[0].shape)
    new_shape = compute_new_shape(shape, current_spacing, new_spacing)
    data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z)
    return data_reshaped
def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],
                                  new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                                  current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                  new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                  is_seg: bool = False,
                                  order: int = 3, order_z: int = 0,
                                  force_separate_z: Union[bool, None] = False,
                                  separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
    """
    needed for segmentation export. Stupid, I know

    Resamples a 4D array (c, x, y, z) to a given target spatial shape. The
    spacings are only used to decide whether the out-of-plane axis should be
    resampled separately (and which axis that is).
    """
    # torch tensors are converted to numpy up front; everything below is numpy-only
    if isinstance(data, torch.Tensor):
        data = data.cpu().numpy()

    # determine whether (and along which axis) to resample separately
    if force_separate_z is not None:
        do_separate_z = force_separate_z
        axis = get_lowres_axis(current_spacing) if force_separate_z else None
    elif get_do_separate_z(current_spacing, separate_z_anisotropy_threshold):
        do_separate_z, axis = True, get_lowres_axis(current_spacing)
    elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold):
        do_separate_z, axis = True, get_lowres_axis(new_spacing)
    else:
        do_separate_z, axis = False, None

    if axis is not None and len(axis) != 1:
        # two axes tie for the coarsest spacing (e.g. (0.24, 1.25, 1.25)), or
        # all three spacings are equal: no unique out-of-plane axis exists, so
        # we do not resample separately
        do_separate_z = False

    if data is not None:
        assert data.ndim == 4, "data must be c x y z"

    return resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z)
def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray],
                         is_seg: bool = False, axis: Union[None, List[int], np.ndarray] = None, order: int = 3,
                         do_separate_z: bool = False, order_z: int = 0):
    """
    Resample a 4D, channel-first array (c, x, y, z) to ``new_shape`` (spatial dims only).

    separate_z=True will resample with order 0 along z
    :param data: 4D array (c, x, y, z)
    :param new_shape: target spatial shape, 3 entries
    :param is_seg: if True, use resize_segmentation (label-preserving); otherwise skimage's resize
    :param axis: one-element sequence holding the index of the anisotropic (out-of-plane)
        axis; only read when do_separate_z is True
    :param order: spline interpolation order for the in-plane resize
    :param do_separate_z: resample in-plane slices first, then the out-of-plane axis separately
    :param order_z: only applies if do_separate_z is True
    :return: resampled array cast back to the input dtype; the input is returned
        unchanged if the spatial shape already matches
    """
    assert data.ndim == 4, "data must be (c, x, y, z)"
    assert len(new_shape) == data.ndim - 1
    if is_seg:
        # label-preserving resize; resize_segmentation takes no extra kwargs
        resize_fn = resize_segmentation
        kwargs = OrderedDict()
    else:
        resize_fn = resize
        kwargs = {'mode': 'edge', 'anti_aliasing': False}
    dtype_data = data.dtype
    shape = np.array(data[0].shape)
    new_shape = np.array(new_shape)
    if np.any(shape != new_shape):
        data = data.astype(float)
        if do_separate_z:
            # print("separate z, order in z is", order_z, "order inplane is", order)
            assert len(axis) == 1, "only one anisotropic axis supported"
            axis = axis[0]
            # target in-plane shape: drop the out-of-plane axis from new_shape
            if axis == 0:
                new_shape_2d = new_shape[1:]
            elif axis == 1:
                new_shape_2d = new_shape[[0, 2]]
            else:
                new_shape_2d = new_shape[:-1]
            reshaped_final_data = []
            for c in range(data.shape[0]):
                # step 1: resize every slice along the anisotropic axis in-plane
                reshaped_data = []
                for slice_id in range(shape[axis]):
                    if axis == 0:
                        reshaped_data.append(resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs))
                    elif axis == 1:
                        reshaped_data.append(resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs))
                    else:
                        reshaped_data.append(resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs))
                reshaped_data = np.stack(reshaped_data, axis)
                if shape[axis] != new_shape[axis]:
                    # step 2: the out-of-plane axis still has its original length;
                    # resample it with order_z via map_coordinates
                    # The following few lines are blatantly copied and modified from sklearn's resize()
                    rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]
                    orig_rows, orig_cols, orig_dim = reshaped_data.shape
                    row_scale = float(orig_rows) / rows
                    col_scale = float(orig_cols) / cols
                    dim_scale = float(orig_dim) / dim
                    # half-pixel offsets keep the sampling grid centered, as in skimage
                    map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim]
                    map_rows = row_scale * (map_rows + 0.5) - 0.5
                    map_cols = col_scale * (map_cols + 0.5) - 0.5
                    map_dims = dim_scale * (map_dims + 0.5) - 0.5
                    coord_map = np.array([map_rows, map_cols, map_dims])
                    if not is_seg or order_z == 0:
                        # images, or segmentations with order 0 (nearest neighbour,
                        # which cannot invent new label values): interpolate directly
                        reshaped_final_data.append(map_coordinates(reshaped_data, coord_map, order=order_z,
                                                                   mode='nearest')[None])
                    else:
                        # segmentation with order_z > 0: interpolate each label's
                        # binary mask separately so blending never creates labels
                        # that were not present in the input
                        unique_labels = np.sort(pd.unique(reshaped_data.ravel()))  # np.unique(reshaped_data)
                        reshaped = np.zeros(new_shape, dtype=dtype_data)
                        # NOTE(review): later labels overwrite earlier ones where
                        # masks overlap; sorted order makes that deterministic
                        for i, cl in enumerate(unique_labels):
                            reshaped_multihot = np.round(
                                map_coordinates((reshaped_data == cl).astype(float), coord_map, order=order_z,
                                                mode='nearest'))
                            reshaped[reshaped_multihot > 0.5] = cl
                        reshaped_final_data.append(reshaped[None])
                else:
                    # out-of-plane length already matches: in-plane resize was enough
                    reshaped_final_data.append(reshaped_data[None])
            reshaped_final_data = np.vstack(reshaped_final_data)
        else:
            # print("no separate z, order", order)
            # isotropic case: one 3D resize per channel
            reshaped = []
            for c in range(data.shape[0]):
                reshaped.append(resize_fn(data[c], new_shape, order, **kwargs)[None])
            reshaped_final_data = np.vstack(reshaped)
        return reshaped_final_data.astype(dtype_data)
    else:
        # print("no resampling necessary")
        return data