|
|
import os |
|
|
import numpy as np |
|
|
import nibabel as nib |
|
|
|
|
|
import torch |
|
|
from torch.nn.functional import conv3d |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
from scipy.io.matlab import loadmat |
|
|
|
|
|
|
|
|
import time, datetime |
|
|
|
|
|
from ShapeID.DiffEqs.adjoint import odeint_adjoint as odeint |
|
|
from ShapeID.perlin3d import generate_velocity_3d , generate_shape_3d |
|
|
|
|
|
|
|
|
class ConcatDataset(Dataset):
    """Virtual concatenation of several datasets with probabilistic sampling.

    Each __getitem__ first picks one member dataset (according to `probs`,
    uniform by default) and then indexes into it modulo its length, so any
    index up to len(self) is valid for every member.
    """

    def __init__(self, dataset_list, probs=None):
        self.datasets = dataset_list
        # Uniform sampling unless explicit probabilities are given.
        self.probs = probs if probs else [1 / len(self.datasets)] * len(self.datasets)

    def __getitem__(self, i):
        # Sample a dataset *index* rather than handing the dataset objects to
        # np.random.choice: choice() coerces its argument with np.asarray,
        # which misbehaves on sequence-like Dataset objects (it may build a
        # multi-dimensional array and raise "a must be 1-dimensional").
        idx = np.random.choice(len(self.datasets), p=self.probs)
        chosen_dataset = self.datasets[idx]
        return chosen_dataset[i % len(chosen_dataset)]

    def __len__(self):
        # Length of the longest member, so every member can be fully covered.
        return max(len(d) for d in self.datasets)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resolution_sampler(low_res_only=False):
    """Sample per-axis (resolution, thickness) for a synthetic acquisition.

    Four regimes, each drawn with probability 1/4 (with `low_res_only` the
    draw is restricted to the last two):
      r < 0.25 : isotropic 1 mm
      r < 0.5  : isotropic 1 mm with one random thick / low-res axis
      r < 0.75 : clinical-style axial stack (~1.3 x 1.3 x 5 mm)
      else     : random isotropic low resolution (2-5 mm)

    Returns:
        (resolution, thickness): two float np.ndarray of shape (3,).
    """
    u = np.random.rand()
    # Restricting the draw to [0.5, 1) selects only the low-res regimes.
    r = (0.5 + 0.5 * u) if low_res_only else u

    if r < 0.25:
        # High-resolution isotropic scan.
        return np.array([1.0, 1.0, 1.0]), np.array([1.0, 1.0, 1.0])

    if r < 0.5:
        # 1 mm scan with a single randomly chosen thick / low-res axis.
        resolution = np.array([1.0, 1.0, 1.0])
        thickness = np.array([1.0, 1.0, 1.0])
        axis = np.random.randint(3)
        resolution[axis] = 2.5 + 6 * np.random.rand()
        # Thickness never exceeds the spacing along that axis.
        thickness[axis] = np.min([resolution[axis], 4.0 + 2.0 * np.random.rand()])
        return resolution, thickness

    if r < 0.75:
        # Clinical-style axial stack with jittered spacing.
        resolution = np.array([1.3, 1.3, 4.8]) + 0.4 * np.random.rand(3)
        return resolution, resolution.copy()

    # Random isotropic-ish low-resolution scan.
    resolution = 2.0 + 3.0 * np.random.rand(3)
    return resolution, resolution.copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def binarize(p, thres):
    """Threshold `p` at `thres` * p.max() into a hard 0/1 mask.

    Returns a tensor of the same shape and dtype as `p`, with 1 where the
    value reaches the (relative) threshold and 0 elsewhere.
    """
    cutoff = thres * p.max()
    # Comparison gives a bool mask; cast back to the input dtype.
    return (p >= cutoff).to(p.dtype)
|
|
|
|
|
def make_gaussian_kernel(sigma, device):
    """Build a normalized 1-D Gaussian kernel with radius ceil(3 * sigma).

    Returns a float tensor of length 2*ceil(3*sigma)+1 on `device`, summing
    to 1.
    """
    radius = int(np.ceil(3 * sigma))
    # Integer sample positions centred on zero: -radius ... +radius.
    coords = torch.linspace(-radius, radius, 2 * radius + 1, dtype=torch.float, device=device)
    weights = torch.exp(-0.5 * (coords / sigma) ** 2)
    return weights / weights.sum()
|
|
|
|
|
def gaussian_blur_3d(input, stds, device):
    """Separable 3-D Gaussian blur of a (X, Y, Z) volume.

    One 1-D convolution per axis; axes whose std is <= 0 are left untouched.
    Returns a tensor of the same spatial shape as `input`.
    """
    # conv3d wants (batch, channel, X, Y, Z).
    vol = input[None, None, :, :, :]
    for axis in range(3):
        if stds[axis] <= 0:
            continue
        k = make_gaussian_kernel(stds[axis], device=device)
        # Reshape the 1-D kernel so it lies along the current spatial axis.
        shape = [1, 1, 1, 1, 1]
        shape[2 + axis] = len(k)
        pad = [0, 0, 0]
        pad[axis] = len(k) // 2  # "same" padding for an odd-length kernel
        vol = conv3d(vol, k.reshape(shape), stride=1, padding=tuple(pad))
    return torch.squeeze(vol)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_affine_matrix(rot, sh, s):
    """Compose a 3x3 affine: per-axis scaling applied to SHx@SHy@SHz@Rx@Ry@Rz.

    Args:
        rot: rotation angles (radians) about x, y, z.
        sh: shear parameters. NOTE(review): the shear matrices reuse the three
            entries in an unusual pattern (SHy uses sh[0]/sh[2], SHz uses
            sh[0]/sh[1]); this preserves the original convention.
        s: per-axis scale factors applied to the rows of the product.
    """
    cr = [np.cos(a) for a in rot]
    sr = [np.sin(a) for a in rot]

    Rx = np.array([[1, 0, 0], [0, cr[0], -sr[0]], [0, sr[0], cr[0]]])
    Ry = np.array([[cr[1], 0, sr[1]], [0, 1, 0], [-sr[1], 0, cr[1]]])
    Rz = np.array([[cr[2], -sr[2], 0], [sr[2], cr[2], 0], [0, 0, 1]])

    SHx = np.array([[1, 0, 0], [sh[1], 1, 0], [sh[2], 0, 1]])
    SHy = np.array([[1, sh[0], 0], [0, 1, 0], [0, sh[2], 1]])
    SHz = np.array([[1, 0, sh[0]], [0, 1, sh[1]], [0, 0, 1]])

    A = SHx @ SHy @ SHz @ Rx @ Ry @ Rz
    # Row-wise scaling, equivalent to diag(s) @ A.
    for axis in range(3):
        A[axis, :] = A[axis, :] * s[axis]
    return A
|
|
|
|
|
|
|
|
def fast_3D_interp_torch(X, II, JJ, KK, mode='linear', default_value_linear=0.0):
    """Sample a 3-D (or 3-D + channels) volume at fractional voxel coordinates.

    Args:
        X: tensor of shape (I, J, K) or (I, J, K, C).
        II, JJ, KK: coordinate tensors of identical shape giving the voxel
            locations to sample; if II is None, X is returned unchanged
            (identity deformation shortcut).
        mode: 'linear' (trilinear interpolation) or 'nearest'.
        default_value_linear: fill value written at out-of-bounds locations
            in 'linear' mode ('nearest' clamps to the border instead).

    Returns:
        Tensor shaped like II (with a trailing channel axis when C > 1).

    Raises:
        Exception: if mode is neither 'linear' nor 'nearest'.
    """

    # No sampling grid supplied: identity mapping.
    if II is None:
        return X

    if mode=='nearest':
        # Round to the closest voxel and clamp indices into the volume.
        IIr = torch.round(II).long()
        JJr = torch.round(JJ).long()
        KKr = torch.round(KK).long()
        IIr[IIr < 0] = 0
        JJr[JJr < 0] = 0
        KKr[KKr < 0] = 0
        IIr[IIr > (X.shape[0] - 1)] = (X.shape[0] - 1)
        JJr[JJr > (X.shape[1] - 1)] = (X.shape[1] - 1)
        KKr[KKr > (X.shape[2] - 1)] = (X.shape[2] - 1)
        # Treat a 3-D volume as single-channel so the gather is uniform.
        if len(X.shape)==3:
            X = X[..., None]
        Y = X[IIr, JJr, KKr]
        # Drop the channel axis again for single-channel input.
        if Y.shape[3] == 1:
            Y = Y[:, :, :, 0]

    elif mode=='linear':
        # In-bounds mask. NOTE(review): the lower bound is strict (> 0), so
        # coordinates exactly at 0 are treated as out of bounds — confirm
        # this is intended.
        ok = (II>0) & (JJ>0) & (KK>0) & (II<=X.shape[0]-1) & (JJ<=X.shape[1]-1) & (KK<=X.shape[2]-1)

        # Flattened in-bounds coordinates.
        IIv = II[ok]
        JJv = JJ[ok]
        KKv = KK[ok]

        # Floor/ceil voxel indices and trilinear weights along x.
        fx = torch.floor(IIv).long()
        cx = fx + 1
        cx[cx > (X.shape[0] - 1)] = (X.shape[0] - 1)
        wcx = (IIv - fx)[..., None]
        wfx = 1 - wcx

        # ... along y.
        fy = torch.floor(JJv).long()
        cy = fy + 1
        cy[cy > (X.shape[1] - 1)] = (X.shape[1] - 1)
        wcy = (JJv - fy)[..., None]
        wfy = 1 - wcy

        # ... along z.
        fz = torch.floor(KKv).long()
        cz = fz + 1
        cz[cz > (X.shape[2] - 1)] = (X.shape[2] - 1)
        wcz = (KKv - fz)[..., None]
        wfz = 1 - wcz

        # Treat a 3-D volume as single-channel so the gathers are uniform.
        if len(X.shape)==3:
            X = X[..., None]

        # Gather the 8 corner neighbours of every sample point.
        c000 = X[fx, fy, fz]
        c100 = X[cx, fy, fz]
        c010 = X[fx, cy, fz]
        c110 = X[cx, cy, fz]
        c001 = X[fx, fy, cz]
        c101 = X[cx, fy, cz]
        c011 = X[fx, cy, cz]
        c111 = X[cx, cy, cz]

        # Reduce along x...
        c00 = c000 * wfx + c100 * wcx
        c01 = c001 * wfx + c101 * wcx
        c10 = c010 * wfx + c110 * wcx
        c11 = c011 * wfx + c111 * wcx

        # ... then along y...
        c0 = c00 * wfy + c10 * wcy
        c1 = c01 * wfy + c11 * wcy

        # ... then along z.
        c = c0 * wfz + c1 * wcz

        # Scatter interpolated values back; out-of-bounds gets the default.
        Y = torch.zeros([*II.shape, X.shape[3]], device=X.device)
        Y[ok] = c.float()
        Y[~ok] = default_value_linear

        # Drop the channel axis again for single-channel input.
        if Y.shape[-1]==1:
            Y = Y[...,0]
    else:
        raise Exception('mode must be linear or nearest')

    return Y
|
|
|
|
|
|
|
|
|
|
|
def myzoom_torch(X, factor, aff=None):
    """Resize a 3-D (or 3-D + channels) volume by per-axis zoom `factor`
    using separable linear interpolation.

    Args:
        X: tensor of shape (I, J, K) or (I, J, K, C).
        factor: per-axis zoom factors; assumed indexable with 3 elements
            (a length-3 array-like — TODO confirm a scalar is never passed).
        aff: optional 4x4 voxel-to-world affine; when given, a rescaled
            affine matching the new grid is also returned.

    Returns:
        Resized tensor, or (resized tensor, updated affine) when aff is
        not None.
    """

    # Treat a 3-D volume as single-channel so the code below is uniform.
    if len(X.shape)==3:
        X = X[..., None]

    # Half-voxel offset so old and new grids share their centres.
    delta = (1.0 - factor) / (2.0 * factor)
    newsize = np.round(X.shape[:-1] * factor).astype(int)

    # Source-voxel sampling coordinates along each axis (trimmed to newsize).
    vx = torch.arange(delta[0], delta[0] + newsize[0] / factor[0], 1 / factor[0], dtype=torch.float, device=X.device)[:newsize[0]]
    vy = torch.arange(delta[1], delta[1] + newsize[1] / factor[1], 1 / factor[1], dtype=torch.float, device=X.device)[:newsize[1]]
    vz = torch.arange(delta[2], delta[2] + newsize[2] / factor[2], 1 / factor[2], dtype=torch.float, device=X.device)[:newsize[2]]

    # Clamp coordinates into the source volume.
    vx[vx < 0] = 0
    vy[vy < 0] = 0
    vz[vz < 0] = 0
    vx[vx > (X.shape[0]-1)] = (X.shape[0]-1)
    vy[vy > (X.shape[1] - 1)] = (X.shape[1] - 1)
    vz[vz > (X.shape[2] - 1)] = (X.shape[2] - 1)

    # Floor/ceil indices and linear weights along x.
    fx = torch.floor(vx).int()
    cx = fx + 1
    cx[cx > (X.shape[0]-1)] = (X.shape[0]-1)
    wcx = (vx - fx)
    wfx = 1 - wcx

    # ... along y.
    fy = torch.floor(vy).int()
    cy = fy + 1
    cy[cy > (X.shape[1]-1)] = (X.shape[1]-1)
    wcy = (vy - fy)
    wfy = 1 - wcy

    # ... along z.
    fz = torch.floor(vz).int()
    cz = fz + 1
    cz[cz > (X.shape[2]-1)] = (X.shape[2]-1)
    wcz = (vz - fz)
    wfz = 1 - wcz

    Y = torch.zeros([newsize[0], newsize[1], newsize[2], X.shape[3]], dtype=torch.float, device=X.device)

    # Separable interpolation: first along x ...
    tmp1 = torch.zeros([newsize[0], X.shape[1], X.shape[2], X.shape[3]], dtype=torch.float, device=X.device)
    for i in range(newsize[0]):
        tmp1[i, :, :] = wfx[i] * X[fx[i], :, :] + wcx[i] * X[cx[i], :, :]
    # ... then along y ...
    tmp2 = torch.zeros([newsize[0], newsize[1], X.shape[2], X.shape[3]], dtype=torch.float, device=X.device)
    for j in range(newsize[1]):
        tmp2[:, j, :] = wfy[j] * tmp1[:, fy[j], :] + wcy[j] * tmp1[:, cy[j], :]
    # ... then along z.
    for k in range(newsize[2]):
        Y[:, :, k] = wfz[k] * tmp2[:, :, fz[k]] + wcz[k] * tmp2[:, :, cz[k]]

    # Drop the channel axis again for single-channel input.
    if Y.shape[3] == 1:
        Y = Y[:,:,:, 0]

    if aff is not None:
        # Rescale the affine for the new grid: rows divided by the zoom
        # factor, origin shifted by the half-voxel centring offset.
        aff_new = aff.copy()
        aff_new[:-1] = aff_new[:-1] / factor
        aff_new[:-1, -1] = aff_new[:-1, -1] - aff[:-1, :-1] @ (0.5 - 0.5 / (factor * np.ones(3)))
        return Y, aff_new
    else:
        return Y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_image(file_name):
    """Load a NIfTI image; return (image object, affine, per-axis voxel size).

    NOTE(review): voxel size is usually sqrt of the *squared* affine column
    entries; this uses summed absolute values — confirm this is intended
    (the same formula is used elsewhere in this module).
    """
    img = nib.load(file_name)
    aff = img.affine
    res = np.sqrt(np.abs(aff[:-1, :-1]).sum(axis=0))
    return img, aff, res
|
|
|
|
|
def deform_image(I, deform_dict, device, default_value_linear_mode=None, deform_mode = 'linear'):
    """Crop an image (nibabel image or tensor) and resample it onto the
    deformation grid.

    Args:
        I: nibabel image, torch tensor, or None (returned as-is).
        deform_dict: dict with 'grid' = [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2]
            (sampling coordinates plus crop bounds).
        default_value_linear_mode: None (fill 0) or 'max' (fill with the
            volume maximum) for out-of-bounds voxels in linear mode.
        deform_mode: interpolation mode passed to fast_3D_interp_torch.

    Returns:
        The deformed tensor (or None if I is None).
    """
    if I is None:
        return I

    [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

    if not isinstance(I, torch.Tensor):
        I = torch.squeeze(torch.tensor(I.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=device))
    else:
        # Fixed: tensors have no .astype, and torch.squeeze takes no
        # dtype/device kwargs — cast and move with Tensor.to instead.
        I = torch.squeeze(I[x1:x2, y1:y2, z1:z2]).to(dtype=torch.float, device=device)
    I = torch.nan_to_num(I)

    # Out-of-bounds fill value for linear interpolation.
    if default_value_linear_mode is not None:
        if default_value_linear_mode == 'max':
            default_value_linear = torch.max(I)
        else:
            raise ValueError('Not support default_value_linear_mode:', default_value_linear_mode)
    else:
        default_value_linear = 0.
    Idef = fast_3D_interp_torch(I, xx2, yy2, zz2, deform_mode, default_value_linear)

    return Idef
|
|
|
|
|
|
|
|
def read_and_deform(file_name, dtype, deform_dict, device, mask, default_value_linear_mode=None, deform_mode = 'linear', mean = 0., scale = 1.):
    """Load a NIfTI volume, crop, normalize, mask and resample it onto the
    deformation grid.

    Args:
        file_name: path to the volume; falls back to file_name + '.gz'.
        dtype: torch dtype for the loaded data.
        deform_dict: dict with 'grid' = [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2].
        mask: optional tensor; voxels where mask == 0 are zeroed before
            resampling.
        default_value_linear_mode: None (fill 0) or 'max' (fill with the
            volume maximum) for out-of-bounds voxels in linear mode.
        mean, scale: affine intensity normalization (I = (I - mean) / scale).

    Returns:
        (deformed tensor, per-axis voxel size of the source volume).
    """
    [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

    # Fall back to the gzipped sibling when the plain file cannot be loaded.
    # (Narrowed from a bare `except:`, which also swallowed KeyboardInterrupt.)
    try:
        Iimg = nib.load(file_name)
    except Exception:
        Iimg = nib.load(file_name + '.gz')
    # NOTE(review): voxel size from summed |affine| entries — the usual
    # formula squares them; confirm intended (consistent with read_image).
    res = np.sqrt(np.sum(abs(Iimg.affine[:-1, :-1]), axis=0))
    I = torch.squeeze(torch.tensor(Iimg.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=dtype, device=device))
    I = torch.nan_to_num(I)

    # Affine intensity normalization.
    I -= mean
    I /= scale

    if mask is not None:
        I[mask == 0] = 0

    # Out-of-bounds fill value for linear interpolation.
    if default_value_linear_mode is not None:
        if default_value_linear_mode == 'max':
            default_value_linear = torch.max(I)
        else:
            raise ValueError('Not support default_value_linear_mode:', default_value_linear_mode)
    else:
        default_value_linear = 0.
    Idef = fast_3D_interp_torch(I, xx2, yy2, zz2, deform_mode, default_value_linear)
    return Idef, res
|
|
|
|
|
|
|
|
def read_and_deform_image(exist_keys, task_name, file_name, setups, deform_dict, device, mask, **kwargs):
    """Load and deform an intensity image (min-max normalized to [0, 1]);
    also loads the defacing mask stored next to it when present.

    Returns:
        {task_name: image[None]} plus optionally
        {task_name + '_DM': defacing_mask[None]}.
    """
    Idef, _ = read_and_deform(file_name, torch.float, deform_dict, device, mask)
    # Min-max normalize to [0, 1].
    Idef -= torch.min(Idef)
    Idef /= torch.max(Idef)
    if setups['flip']:
        Idef = torch.flip(Idef, [0])
    update_dict = {task_name: Idef[None]}

    # Optional defacing mask stored alongside the image.
    if os.path.isfile(file_name[:-4] + '.defacingmask.nii'):
        Idef_DM, _ = read_and_deform(file_name[:-4] + '.defacingmask.nii', torch.float, deform_dict, device, mask)
        Idef_DM = torch.clamp(Idef_DM, min = 0.)
        Idef_DM /= torch.max(Idef_DM)
        if setups['flip']:
            # Fixed: the flipped mask was previously assigned to Idef,
            # leaving the returned defacing mask unflipped.
            Idef_DM = torch.flip(Idef_DM, [0])
        update_dict.update({task_name + '_DM': Idef_DM[None]})

    return update_dict
|
|
|
|
|
def read_and_deform_CT(exist_keys, task_name, file_name, setups, deform_dict, device, mask, **kwargs):
    """Load and deform a CT volume (intensities divided by 1000); also loads
    the defacing mask stored next to it when present.

    Returns:
        {'CT': image[None]} plus optionally {task_name + '_DM': mask[None]}.
    """
    Idef, _ = read_and_deform(file_name, torch.float, deform_dict, device, mask, scale = 1000)

    if setups['flip']:
        Idef = torch.flip(Idef, [0])
    update_dict = {'CT': Idef[None]}

    # Optional defacing mask stored alongside the image.
    if os.path.isfile(file_name[:-4] + '.defacingmask.nii'):
        Idef_DM, _ = read_and_deform(file_name[:-4] + '.defacingmask.nii', torch.float, deform_dict, device, mask)
        Idef_DM = torch.clamp(Idef_DM, min = 0.)
        Idef_DM /= torch.max(Idef_DM)
        if setups['flip']:
            # Fixed: the flipped mask was previously assigned to Idef,
            # leaving the returned defacing mask unflipped.
            Idef_DM = torch.flip(Idef_DM, [0])
        update_dict.update({task_name + '_DM': Idef_DM[None]})

    return update_dict
|
|
|
|
|
def read_and_deform_distance(exist_keys, task_name, file_names, setups, deform_dict, device, mask, cfg, **kwargs):
    """Load, deform, rescale and clamp surface distance maps.

    With a mask only the left pial/white maps are used (2 channels);
    otherwise all four maps are stacked (4 channels), with a left/right
    swap-and-flip when setups['flip'] is set.

    Returns:
        {'distance': tensor of shape (2 or 4, X, Y, Z)}.
    """
    lp_dist_map, lw_dist_map, rp_dist_map, rw_dist_map = file_names

    # Shared loading options: out-of-bounds filled with the volume max,
    # intensities mapped via (I - 128) / 20.
    opts = dict(default_value_linear_mode='max', mean=128., scale=20)

    lp, _ = read_and_deform(lp_dist_map, torch.float, deform_dict, device, mask, **opts)
    lw, _ = read_and_deform(lw_dist_map, torch.float, deform_dict, device, mask, **opts)

    if mask is not None:
        Idef = torch.stack([lp, lw], dim=0)
    else:
        rp, _ = read_and_deform(rp_dist_map, torch.float, deform_dict, device, mask, **opts)
        rw, _ = read_and_deform(rw_dist_map, torch.float, deform_dict, device, mask, **opts)

        if setups['flip']:
            # Flip along x and exchange left/right hemispheres.
            lp, rp = torch.flip(rp, [0]), torch.flip(lp, [0])
            lw, rw = torch.flip(rw, [0]), torch.flip(lw, [0])

        Idef = torch.stack([lp, lw, rp, rw], dim=0)

    Idef /= deform_dict['scaling_factor_distances']
    Idef = torch.clamp(Idef, min=-cfg.max_surf_distance, max=cfg.max_surf_distance)

    return {'distance': Idef}
|
|
|
|
|
def read_and_deform_segmentation(exist_keys, task_name, file_name, setups, deform_dict, device, mask, cfg, onehotmatrix, lut, vflip, **kwargs):
    """Load a label volume, deform it, and return a one-hot encoding.

    Depending on cfg.generator.deform_one_hots, either the one-hot encoding
    is deformed with linear interpolation (soft labels), or the labels are
    deformed with nearest-neighbour first and then one-hot encoded.

    Args:
        file_name: path to the NIfTI label volume.
        deform_dict: dict with 'grid' = [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2].
        mask: optional tensor; labels where mask == 0 are set to background.
        onehotmatrix, lut, vflip: one-hot basis, label lookup table, and
            channel permutation applied for left/right flips.

    Returns:
        {'segmentation': one-hot tensor in channels-first (C, X, Y, Z) layout}.
    """
    [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

    Simg = nib.load(file_name)
    S = torch.squeeze(torch.tensor(Simg.get_fdata()[x1:x2, y1:y2, z1:z2].astype(int), dtype=torch.int, device=device))

    if mask is not None:
        S[mask == 0] = 0

    if cfg.generator.deform_one_hots:
        # Deform the (soft) one-hot encoding directly.
        Sonehot = onehotmatrix[lut[S.long()]]
        Sdef_OneHot = fast_3D_interp_torch(Sonehot, xx2, yy2, zz2)
    else:
        # Deform labels with nearest-neighbour, then one-hot encode.
        # (Hoisted into this branch: previously computed unconditionally and
        # discarded when deform_one_hots was set.)
        Sdef = fast_3D_interp_torch(S, xx2, yy2, zz2, 'nearest')
        Sdef_OneHot = onehotmatrix[lut[Sdef.long()]]

    if setups['flip']:
        # Spatial flip along x plus the matching left/right channel swap.
        Sdef_OneHot = torch.flip(Sdef_OneHot, [0])[:, :, :, vflip]

    # Channels-first layout (C, X, Y, Z).
    Sdef_OneHot = Sdef_OneHot.permute([3, 0, 1, 2])

    update_dict = {'segmentation': Sdef_OneHot}

    return update_dict
|
|
|
|
|
|
|
|
|
|
|
def read_and_deform_pathology(exist_keys, task_name, file_name, setups, deform_dict, device, mask = None,
                              augment = False, pde_func = None, t = None,
                              shape_gen_args = None, thres = 0., **kwargs):
    """Produce a binary pathology mask and its probability map on the grid.

    file_name may be None (no pathology), 'random_shape' (synthesize one
    from Perlin noise), or a path to a probability volume. If the binarized
    mask is (near-)empty, zero maps are returned instead.

    Returns:
        {'pathology': binary[None], 'pathology_prob': prob[None]}.
    """
    [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

    def _no_pathology():
        # Fresh zero tensors for both outputs.
        return {'pathology': torch.zeros(xx2.shape)[None].to(device),
                'pathology_prob': torch.zeros(xx2.shape)[None].to(device)}

    if file_name is None:
        return _no_pathology()

    if file_name == 'random_shape':
        # Synthesize a random pathology probability map with Perlin noise.
        percentile = np.random.uniform(shape_gen_args.mask_percentile_min, shape_gen_args.mask_percentile_max)
        _, Pdef = generate_shape_3d(xx2.shape, shape_gen_args.perlin_res, percentile, device)
    else:
        Pdef, _ = read_and_deform(file_name, torch.float, deform_dict, device)

    if augment:
        # Advect the probability map with a random velocity field.
        Pdef = augment_pathology(Pdef, pde_func, t, shape_gen_args, device)

    P = binarize(Pdef, thres)
    # Discard (near-)empty pathologies.
    if P.mean() <= shape_gen_args.pathol_tol:
        return _no_pathology()

    return {'pathology': P[None], 'pathology_prob': Pdef[None]}
|
|
|
|
|
|
|
|
def read_and_deform_registration(exist_keys, task_name, file_names, setups, deform_dict, device, mask, **kwargs):
    """Load and deform the three MNI registration coordinate maps.

    Intensities are divided by 10000; under a flip the x component is
    mirrored and negated while y and z are only mirrored.

    Returns:
        {'registration': tensor of shape (3, X, Y, Z)}.
    """
    components = []
    for axis, path in enumerate(file_names):
        comp, _ = read_and_deform(path, torch.float, deform_dict, device, mask, scale = 10000)
        if setups['flip']:
            comp = torch.flip(comp, [0])
            if axis == 0:
                # x coordinates change sign under a left/right flip.
                comp = -comp
        components.append(comp)

    return {'registration': torch.stack(components, dim = 0)}
|
|
|
|
|
def read_and_deform_bias_field(exist_keys, task_name, file_name, setups, deform_dict, device, mask, **kwargs):
    """Load and deform a bias-field volume.

    Returns:
        {'bias_field': field[None]}.
    """
    # Fixed: `device` and `mask` were passed swapped — read_and_deform's
    # signature is (file_name, dtype, deform_dict, device, mask, ...).
    Idef, _ = read_and_deform(file_name, torch.float, deform_dict, device, mask)
    if setups['flip']:
        Idef = torch.flip(Idef, [0])
    return {'bias_field': Idef[None]}
|
|
|
|
|
def _deform_surface_vertices(V, Fneg, Ainv, c2):
    """Warp one vertex array: recentre, apply the inverse affine, add the
    sampled displacement from Fneg, and restore the centre offset."""
    V = V - c2[None, :]
    V = V @ torch.transpose(Ainv, 0, 1)
    V = V + fast_3D_interp_torch(Fneg, V[:, 0] + c2[0], V[:, 1] + c2[1], V[:, 2] + c2[2])
    return V + c2[None, :]


def read_and_deform_surface(exist_keys, task_name, file_name, setups, deform_dict, device, mask, size):
    """Load the four cortical surfaces (left/right white and pial) from the
    .mat file next to `file_name` and warp their vertices with the inverse
    affine plus the backward deformation field Fneg.

    Under a flip, x coordinates are mirrored and left/right surfaces are
    exchanged. (Debug prints of the array shapes were removed, and the
    per-surface warp was factored into _deform_surface_vertices.)

    Returns:
        dict with vertex/face tensors: Vlw, Flw, Vrw, Frw, Vlp, Flp, Vrp, Frp.
    """
    Fneg, A, c2 = deform_dict['Fneg'], deform_dict['A'], deform_dict['c2']

    mat = loadmat(file_name.split('.nii')[0] + '.mat')

    Vlw = torch.tensor(mat['Vlw'], dtype=torch.float, device=device)
    Flw = torch.tensor(mat['Flw'], dtype=torch.int, device=device)
    Vrw = torch.tensor(mat['Vrw'], dtype=torch.float, device=device)
    Frw = torch.tensor(mat['Frw'], dtype=torch.int, device=device)
    Vlp = torch.tensor(mat['Vlp'], dtype=torch.float, device=device)
    Flp = torch.tensor(mat['Flp'], dtype=torch.int, device=device)
    Vrp = torch.tensor(mat['Vrp'], dtype=torch.float, device=device)
    Frp = torch.tensor(mat['Frp'], dtype=torch.int, device=device)

    Ainv = torch.inverse(A)
    Vlw = _deform_surface_vertices(Vlw, Fneg, Ainv, c2)
    Vrw = _deform_surface_vertices(Vrw, Fneg, Ainv, c2)
    Vlp = _deform_surface_vertices(Vlp, Fneg, Ainv, c2)
    Vrp = _deform_surface_vertices(Vrp, Fneg, Ainv, c2)

    if setups['flip']:
        # Mirror x coordinates, then exchange left and right surfaces.
        for V in (Vlw, Vrw, Vlp, Vrp):
            V[:, 0] = size[0] - 1 - V[:, 0]
        Vlw, Vrw = Vrw, Vlw
        Vlp, Vrp = Vrp, Vlp
        Flw, Frw = Frw, Flw
        Flp, Frp = Frp, Flp

    return {'Vlw': Vlw, 'Flw': Flw, 'Vrw': Vrw, 'Frw': Frw, 'Vlp': Vlp, 'Flp': Flp, 'Vrp': Vrp, 'Frp': Frp}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def augment_pathology(Pprob, pde_func, t, shape_gen_args, device):
    """Advect a pathology probability map with a random velocity field.

    A random number of integration steps nt in [1, max_nt] is drawn;
    nt == 1 is a no-op and the (squeezed) input is returned directly.
    """
    Pprob = torch.squeeze(Pprob)

    # Number of time points to integrate over.
    nt = np.random.randint(1, shape_gen_args.max_nt + 1)
    if nt <= 1:
        return Pprob

    # Fresh random Perlin velocity field for this sample.
    pde_func.V_dict = generate_velocity_3d(Pprob.shape, shape_gen_args.perlin_res, shape_gen_args.V_multiplier, device)

    # Integrate the advection PDE and keep the final state.
    evolved = odeint(pde_func, Pprob[None], t[:nt],
                     shape_gen_args.dt,
                     method=shape_gen_args.integ_method)[-1, 0]

    return evolved
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_gamma_transform(I, aux_dict, cfg, device, **kwargs):
    """Apply a random gamma (contrast) transform pivoted at intensity 300.

    The exponent is exp(gamma_std * N(0, 1)), so it is log-normal around 1.
    Returns (transformed image, unchanged aux_dict).
    """
    exponent = torch.tensor(np.exp(cfg.gamma_std * np.random.randn(1)[0]), dtype=float, device=device)
    transformed = 300.0 * (I / 300.0) ** exponent
    return transformed, aux_dict
|
|
|
|
|
def add_bias_field(I, aux_dict, cfg, input_mode, setups, size, device, **kwargs):
    """Multiply I by a smooth random multiplicative bias field.

    CT inputs are returned unchanged (no bias field). The field is sampled
    as low-resolution Gaussian noise, upsampled to `size`, and exponentiated.
    Stores 'BFlog' (log field) and 'high_res' (biased image) in aux_dict.
    """
    if input_mode == 'CT':
        # No bias field for CT.
        aux_dict.update({'high_res': I})
        return I, aux_dict

    # Low-resolution grid on which the random log-field is sampled.
    bf_scale = cfg.bf_scale_min + np.random.rand(1) * (cfg.bf_scale_max - cfg.bf_scale_min)
    size_BF_small = np.round(bf_scale * np.array(size)).astype(int).tolist()
    if setups['photo_mode']:
        # One control point per photo slice along axis 1.
        size_BF_small[1] = np.round(size[1] / setups['spac']).astype(int)

    field_std = torch.tensor(cfg.bf_std_min + (cfg.bf_std_max - cfg.bf_std_min) * np.random.rand(1), dtype=torch.float, device=device)
    small_field = field_std * torch.randn(size_BF_small, dtype=torch.float, device=device)
    BFlog = myzoom_torch(small_field, np.array(size) / size_BF_small)
    I_bf = I * torch.exp(BFlog)

    aux_dict.update({'BFlog': BFlog, 'high_res': I_bf})
    return I_bf, aux_dict
|
|
|
|
|
def resample_resolution(I, aux_dict, setups, res, size, device, **kwargs):
    """Simulate slice thickness with a Gaussian blur, then resample the
    volume to the target acquisition resolution.

    Stores the per-axis zoom 'factors' in aux_dict and returns the
    downsampled image.
    """
    # Per-axis blur std approximating the slice-thickness low-pass; axes
    # whose thickness does not exceed the native resolution are not blurred.
    stds = (0.85 + 0.3 * np.random.rand()) * np.log(5) / np.pi * setups['thickness'] / res
    stds[setups['thickness'] <= res] = 0.0
    I_blur = gaussian_blur_3d(I, stds, device)

    new_size = (np.array(size) * res / setups['resolution']).astype(int)
    factors = np.array(new_size) / np.array(size)
    # Half-voxel offset so old and new grids share their centres.
    delta = (1.0 - factors) / (2.0 * factors)

    axes = [np.arange(delta[d], delta[d] + new_size[d] / factors[d], 1 / factors[d])[:new_size[d]]
            for d in range(3)]
    grids = np.meshgrid(*axes, sparse=False, indexing='ij')
    II, JJ, KK = (torch.tensor(g, dtype=torch.float, device=device) for g in grids)

    I_small = fast_3D_interp_torch(I_blur, II, JJ, KK)
    aux_dict.update({'factors': factors})
    return I_small, aux_dict
|
|
|
|
|
|
|
|
def resample_resolution_photo(I, aux_dict, setups, res, size, device, **kwargs):
    """Photo-mode variant of the blur-and-resample step.

    The body was a byte-for-byte duplicate of resample_resolution, so it now
    delegates to it; the separate entry point is kept for callers that
    dispatch by name.
    """
    return resample_resolution(I, aux_dict, setups, res, size, device, **kwargs)
|
|
|
|
|
|
|
|
def add_noise(I, aux_dict, cfg, device, **kwargs):
    """Add zero-clipped Gaussian noise with a std drawn uniformly from
    [noise_std_min, noise_std_max]. Returns (noisy image, aux_dict)."""
    std = torch.tensor(cfg.noise_std_min + (cfg.noise_std_max - cfg.noise_std_min) * np.random.rand(1), dtype=torch.float, device=device)
    noisy = I + std * torch.randn(I.shape, dtype=torch.float, device=device)
    # Clip negative intensities to zero.
    noisy = torch.clamp(noisy, min=0)
    return noisy, aux_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Mapping from right-hemisphere to left-hemisphere segmentation labels,
# used to fold both hemispheres onto a single label set.
# NOTE(review): the numbering looks like a FreeSurfer-style aseg LUT
# (e.g. 41 -> 2 cerebral white matter) — confirm against the project's LUT.
right_to_left_dict = {
    41: 2,
    42: 3,
    43: 4,
    44: 5,
    46: 7,
    47: 8,
    49: 10,
    50: 11,
    51: 12,
    52: 13,
    53: 17,
    54: 18,
    58: 26,
    60: 28
}
|
|
|
|
|
|
|
|
# Segmentation labels grouped by their typical appearance on CT, from darkest
# to brightest. 'brighter' is intentionally left empty here.
# NOTE(review): label IDs appear to follow the same FreeSurfer-style
# numbering used above — confirm against the project's LUT.
ct_brightness_group = {
    'darker': [4, 5, 14, 15, 24, 31, 72],
    'dark': [2, 7, 16, 77, 30],
    'bright': [3, 8, 17, 18, 28, 10, 11, 12, 13, 26],
    'brighter': [],
}
|
|
|