|
|
""" |
|
|
Misc helper functions for reading, writing, cropping and padding image volumes.
|
|
|
|
|
Mostly copy-paste from torchvision references. |
|
|
""" |
|
|
import os |
|
|
import shutil |
|
|
import numpy as np |
|
|
import nibabel as nib |
|
|
from pathlib import Path |
|
|
import SimpleITK as sitk |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): this is a bare string literal, not code. It holds a disabled
# compatibility shim for torchvision < 0.7 (torchvision is not imported in this
# module) and has no runtime effect; consider deleting it.
'''if float(torchvision.__version__[:3]) < 0.7:
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size'''
|
|
|
|
|
|
|
|
def make_dir(dir_name, parents=True, exist_ok=True, reset=False):
    """Create a directory (optionally wiping it first) and return it as a Path.

    Args:
        dir_name: Directory to create (str or path-like).
        parents: Forwarded to ``Path.mkdir``; create missing parent directories.
        exist_ok: Forwarded to ``Path.mkdir``; do not fail if it already exists.
        reset: When True and ``dir_name`` is an existing directory, delete it
            and everything inside it before re-creating it.

    Returns:
        pathlib.Path: The directory path.
    """
    target = Path(dir_name)
    # Only wipe an existing tree when the caller explicitly asks for it.
    if reset and target.is_dir():
        shutil.rmtree(target)
    target.mkdir(parents=parents, exist_ok=exist_ok)
    return target
|
|
|
|
|
|
|
|
def read_image(img_path, save_path=None):
    """Load an image with nibabel and return its data and affine.

    Args:
        img_path: Path of the image file to load.
        save_path: If truthy, the loaded data is re-written to this path as a
            ``Nifti1Image`` carrying the same affine.

    Returns:
        tuple: ``(data, affine)``. ``data`` is the array from ``get_fdata()``
        with singleton axes squeezed out; ``affine`` is the image's affine.
    """
    loaded = nib.load(img_path)
    data, affine = loaded.get_fdata(), loaded.affine
    if save_path:
        # The un-squeezed data is written, so the saved shape matches the input.
        nib.save(nib.Nifti1Image(data, affine), save_path)
    return np.squeeze(data), affine
|
|
|
|
|
def save_image(nda, affine, save_path):
    """Write ``nda`` to ``save_path`` as a NIfTI-1 image with the given affine.

    Args:
        nda: Voxel array to save.
        affine: Affine matrix stored in the file.
        save_path: Destination path passed to ``nib.save``.

    Returns:
        The ``save_path`` argument, unchanged.
    """
    nib.save(nib.Nifti1Image(nda, affine), save_path)
    return save_path
|
|
|
|
|
def img2nda(img_path, save_path=None):
    """Read an image with SimpleITK and return its voxel array plus geometry.

    Args:
        img_path: Path of the image to read with ``sitk.ReadImage``.
        save_path: If truthy, the voxel array is also stored with ``np.save``.

    Returns:
        tuple: ``(array, origin, spacing, direction)``. ``array`` comes from
        ``sitk.GetArrayFromImage``, whose axis order is the reverse of
        SimpleITK's index order. The other three are the image's
        ``GetOrigin()``, ``GetSpacing()`` and ``GetDirection()`` tuples.
    """
    image = sitk.ReadImage(img_path)
    array = sitk.GetArrayFromImage(image)
    if save_path:
        np.save(save_path, array)
    origin, spacing, direction = image.GetOrigin(), image.GetSpacing(), image.GetDirection()
    return array, origin, spacing, direction
|
|
|
|
|
def to3d(img_path, save_path=None):
    """Re-write an image with its geometry reduced to three dimensions.

    If the image has more than three dimensions (judged by the length of its
    origin), the voxel array is re-wrapped with the first three origin and
    spacing entries and the upper-left 3x3 block of the (row-major) direction
    matrix, then written to ``save_path``. If the image is already 3-D and
    ``save_path`` names a different file, the image is copied there unchanged,
    so the returned path always refers to an existing file.

    Args:
        img_path: Path of the input image.
        save_path: Output path; defaults to overwriting ``img_path``.

    Returns:
        The path the (possibly unchanged) image lives at.
    """
    nda, origin, spacing, direction = img2nda(img_path)
    save_path = img_path if save_path is None else save_path

    if len(origin) > 3:
        dim = len(origin)
        # Direction is a row-major dim x dim matrix: keep rows 0-2, columns 0-2.
        direction3 = tuple(direction[row * dim + col] for row in range(3) for col in range(3))
        # NOTE(review): nda2img squeezes the array. If the extra dimension has
        # size > 1, the result becomes a vector image whose components come from
        # the LAST numpy axis — confirm that matches the intended data layout.
        nda2img(nda, origin[:3], spacing[:3], direction3, save_path)
    elif os.path.abspath(save_path) != os.path.abspath(img_path):
        # Already 3-D: still honour the requested output path instead of
        # returning a path that was never written.
        sitk.WriteImage(sitk.ReadImage(img_path), save_path)

    return save_path
|
|
|
|
|
def _as_float_list(values):
    """Flatten a sequence, nested sequence or array into a list of Python floats."""
    return np.asarray(values, dtype=float).ravel().tolist()


def nda2img(nda, origin=None, spacing=None, direction=None, save_path=None, isVector=None):
    """Wrap an array (or torch tensor) as a SimpleITK image, optionally saving it.

    Args:
        nda: numpy array, array-like or torch tensor. Tensors are detached and
            moved to CPU. Singleton axes are squeezed out.
        origin: Optional image origin (sequence or array); applied when given.
        spacing: Optional voxel spacing (sequence or array); applied when given.
        direction: Optional direction cosines, either flat row-major or as a
            square matrix; applied when given.
        save_path: If truthy, the image is written with ``sitk.WriteImage``.
        isVector: Forwarded to ``sitk.GetImageFromArray``. ``None`` (default)
            auto-detects: arrays with more than three axes after squeezing are
            treated as vector images (last axis = components). An explicit
            True/False is respected.

    Returns:
        SimpleITK.Image: The constructed image.
    """
    if isinstance(nda, torch.Tensor):
        nda = nda.detach().cpu().numpy()
    nda = np.squeeze(np.asarray(nda))

    # Only auto-detect when the caller did not decide; an explicit False was
    # previously overridden.
    if isVector is None:
        isVector = nda.ndim > 3
    img = sitk.GetImageFromArray(nda, isVector=isVector)

    # `is not None` + length instead of truthiness so numpy arrays are accepted.
    if origin is not None and len(origin) > 0:
        img.SetOrigin(_as_float_list(origin))
    if spacing is not None and len(spacing) > 0:
        img.SetSpacing(_as_float_list(spacing))
    if direction is not None and len(direction) > 0:
        img.SetDirection(_as_float_list(direction))

    if save_path:
        sitk.WriteImage(img, save_path)
    return img
|
|
|
|
|
|
|
|
|
|
|
def cropping(img_path, tol=0, crop_range_lst=None, spare=0, save_path=None):
    """Crop an image file to its foreground bounding box, or to a given range.

    The image is read with SimpleITK and cropped on its numpy array, whose
    axis order is the reverse of SimpleITK's index order. Unless
    ``crop_range_lst`` is given, the bounding box of voxels strictly greater
    than ``tol`` is computed. For 4-D arrays, only index 0 of the last axis is
    inspected. The box is then enlarged by ``spare`` voxels on every side,
    clamped to the array bounds.

    NOTE(review): assumes a 3-D image whose array may carry a trailing
    component axis (vector image); true 4-D scalar images are not supported.

    Args:
        img_path: Path of the image to read.
        tol: Threshold; voxels ``> tol`` count as foreground.
        crop_range_lst: Optional ``[[x0, y0, z0], [x1, y1, z1]]`` half-open
            range over the first three numpy axes. When given, ``tol`` and
            ``spare`` are ignored.
        spare: Margin in voxels added around the detected bounding box.
        save_path: If truthy, the cropped image is written here.

    Returns:
        tuple: ``(cropped_img, [[x0, y0, z0], [x1, y1, z1]], new_origin)``,
        where ``new_origin`` is a list with the physical position of the
        cropped corner.

    Raises:
        ValueError: If no voxel exceeds ``tol`` (when ``crop_range_lst`` is None).
    """
    img = sitk.ReadImage(img_path)
    orig_nda = sitk.GetArrayFromImage(img)
    is_vector = len(orig_nda.shape) > 3
    # A view is enough: the data is only read to build the mask.
    nda = orig_nda[..., 0] if is_vector else orig_nda

    if crop_range_lst is None:
        coords = np.argwhere(nda > tol)
        if coords.size == 0:
            raise ValueError('cropping: no voxel above tol=%s in %s' % (tol, img_path))
        # Clamp the margin to the array bounds. Previously, a box closer than
        # `spare` to an edge got no margin at all on that side.
        x0, y0, z0 = np.maximum(coords.min(axis=0) - spare, 0)
        x1, y1, z1 = np.minimum(coords.max(axis=0) + 1 + spare, orig_nda.shape[:3])
    else:
        [[x0, y0, z0], [x1, y1, z1]] = crop_range_lst

    cropped_nda = orig_nda[x0:x1, y0:y1, z0:z1]

    # numpy corner (x0, y0, z0) is SimpleITK index (z0, y0, x0). Mapping it
    # through the image geometry also honours the direction matrix. For an
    # identity direction this equals origin + spacing * index.
    new_origin = list(img.TransformIndexToPhysicalPoint((int(z0), int(y0), int(x0))))

    cropped_img = sitk.GetImageFromArray(cropped_nda, isVector=is_vector)
    cropped_img.SetOrigin(new_origin)
    cropped_img.SetSpacing(img.GetSpacing())
    cropped_img.SetDirection(img.GetDirection())

    if save_path:
        sitk.WriteImage(cropped_img, save_path)

    return cropped_img, [[x0, y0, z0], [x1, y1, z1]], new_origin
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def crop_and_pad(orig_nda, crop_idx=(), tol=1e-7, pad_size=(224, 224, 224), to_print=True):
    """Crop an array to a bounding box and center it in a zero-padded volume.

    Args:
        orig_nda: Input array (first three axes are cropped).
        crop_idx: ``[[x0, y0, z0], [x1, y1, z1]]`` half-open crop range. If
            empty or None, the range is computed with ``crop`` using ``tol``.
        tol: Foreground threshold forwarded to ``crop``. It was previously
            ignored.
        pad_size: Target size of the padded output's first three axes.
        to_print: Forwarded to ``crop`` and ``pad`` to print diagnostics.

    Returns:
        tuple: ``(padded_array, [[x0, y0, z0], [x1, y1, z1]])``.
    """
    if crop_idx is None or len(crop_idx) < 2:
        bbox = crop(orig_nda, tol=tol, to_print=to_print)
    else:
        bbox = crop_idx
    [[x0, y0, z0], [x1, y1, z1]] = bbox

    nda = pad(orig_nda[x0:x1, y0:y1, z0:z1], pad_size, to_print=to_print)
    return nda, [[x0, y0, z0], [x1, y1, z1]]
|
|
|
|
|
|
|
|
def crop(orig_nda, tol=1e-7, to_print=True):
    """Compute the bounding box of the voxels strictly greater than ``tol``.

    Args:
        orig_nda: Array with at least three axes. When it has more than three,
            only index 0 of the last axis is inspected.
        tol: Threshold; voxels ``> tol`` count as foreground.
        to_print: Print the half-open range found along each axis.

    Returns:
        list: ``[[x0, y0, z0], [x1, y1, z1]]`` bounds along the first three
        axes (upper bounds exclusive).

    Raises:
        ValueError: If no voxel exceeds ``tol``.
    """
    # A view is enough (no whole-volume copy): the data is only read.
    nda = orig_nda[..., 0] if len(orig_nda.shape) > 3 else orig_nda

    coords = np.argwhere(nda > tol)
    if coords.size == 0:
        raise ValueError('crop: no voxel above tol=%g; cannot compute a bounding box' % tol)

    x0, y0, z0 = coords.min(axis=0)
    x1, y1, z1 = coords.max(axis=0) + 1

    if to_print:
        print(' Cropping Slice [%d, %d)' % (x0, x1))
        print(' Cropping Row [%d, %d)' % (y0, y1))
        print(' Cropping Column [%d, %d)' % (z0, z1))

    return [[x0, y0, z0], [x1, y1, z1]]
|
|
|
|
|
def pad(orig_nda, pad_size=(224, 224, 224), to_print=True):
    """Center ``orig_nda`` inside a zero-filled float64 array.

    Only the first three axes are padded. Trailing axes of ``orig_nda`` not
    covered by ``pad_size`` (e.g. a channel axis) are kept at their size,
    consistent with ``crop`` accepting such arrays.

    Args:
        orig_nda: Array with at least three axes.
        pad_size: Target size of the output. The first three entries are the
            padded sizes; extra entries, if any, must match ``orig_nda``.
        to_print: Print the original shape and the start offsets.

    Returns:
        np.ndarray: float64 array of shape
        ``tuple(pad_size) + orig_nda.shape[len(pad_size):]``.

    Raises:
        ValueError: If any of the first three axes is larger than ``pad_size``.
    """
    orig_shape = orig_nda.shape
    to_pad_start = [int((pad_size[i] - orig_shape[i]) // 2) for i in range(3)]

    if to_print:
        print(' orig shape:', orig_shape)
        print(' pad start:', to_pad_start)

    for axis in range(3):
        if orig_shape[axis] > pad_size[axis]:
            raise ValueError('pad: axis %d has size %d, larger than pad size %d'
                             % (axis, orig_shape[axis], pad_size[axis]))

    out_shape = tuple(pad_size) + tuple(orig_shape[len(pad_size):])
    new_nda = np.zeros(out_shape)
    new_nda[to_pad_start[0]:to_pad_start[0] + orig_shape[0],
            to_pad_start[1]:to_pad_start[1] + orig_shape[1],
            to_pad_start[2]:to_pad_start[2] + orig_shape[2]] = orig_nda

    return new_nda
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def viewVolume(x, aff=None, prefix='', postfix='', names=None, ext='.nii.gz', save_dir='/tmp'):
    """Write one or more volumes to disk with ``MRIwrite`` for quick inspection.

    Args:
        x: A single volume, a list of volumes, or a dict mapping name -> volume.
            Volumes may be numpy arrays, array-likes or torch tensors. They are
            squeezed before saving; ``None`` entries are skipped.
        aff: Affine as numpy array or torch tensor; ``np.eye(4)`` when None.
        prefix: Prepended to each file name.
        postfix: Appended to each name before ``ext``.
        names: Per-volume names for list input. Missing or unusable names fall
            back to the volume's index. Replaced by the keys when ``x`` is a
            dict.
        ext: File extension, e.g. ``'.nii.gz'``.
        save_dir: Output directory.

    Returns:
        The path of the last file written, or None if nothing was written
        (previously this raised UnboundLocalError).
    """
    if aff is None:
        aff = np.eye(4)
    elif isinstance(aff, torch.Tensor):
        aff = aff.detach().cpu().numpy()

    if isinstance(x, dict):
        names = list(x.keys())
        x = list(x.values())
    elif names is None:
        names = []

    if not isinstance(x, list):
        x = [x]

    save_path = None
    for n, vol in enumerate(x):
        if vol is None:
            continue
        if isinstance(vol, torch.Tensor):
            vol = vol.detach().cpu().numpy()
        vol = np.squeeze(np.array(vol))
        try:
            save_path = os.path.join(save_dir, prefix + names[n] + postfix + ext)
        except (IndexError, KeyError, TypeError):
            # No usable (string) name for this entry: fall back to its index.
            save_path = os.path.join(save_dir, prefix + str(n) + postfix + ext)
        MRIwrite(vol, aff, save_path)

    return save_path
|
|
|
|
|
|
|
|
|
|
|
def MRIwrite(volume, aff, filename, dtype=None):
    """Save ``volume`` to ``filename`` as a NIfTI-1 image.

    Args:
        volume: Voxel array; must support ``astype`` when ``dtype`` is given.
        aff: Affine matrix; ``np.eye(4)`` is used when None.
        filename: Destination path passed to ``nib.save``.
        dtype: Optional dtype the volume is cast to before saving.
    """
    data = volume if dtype is None else volume.astype(dtype=dtype)
    affine = np.eye(4) if aff is None else aff
    nib.save(nib.Nifti1Image(data, affine, nib.Nifti1Header()), filename)
|
|
|
|
|
|
|
|
|
|
|
def MRIread(filename, dtype=None, im_only=False):
    """Load a ``.nii``, ``.nii.gz`` or ``.mgz`` volume with nibabel.

    Args:
        filename: Path (str or path-like) ending in ``.nii``, ``.nii.gz`` or
            ``.mgz``.
        dtype: Optional dtype to cast the ``get_fdata()`` array to.
        im_only: If True, return only the volume.

    Returns:
        ``volume`` if ``im_only`` is True, otherwise ``(volume, affine)``.

    Raises:
        ValueError: If the file extension is not one of the supported ones.
    """
    # Explicit check instead of `assert`, so validation survives `python -O`.
    # os.fspath lets pathlib.Path inputs through.
    if not os.fspath(filename).endswith(('.nii', '.nii.gz', '.mgz')):
        raise ValueError('Unknown data file: %s' % filename)

    x = nib.load(filename)
    volume = x.get_fdata()
    aff = x.affine

    if dtype is not None:
        volume = volume.astype(dtype=dtype)

    if im_only:
        return volume
    return volume, aff
|
|
|
|
|
|
|
|
|