|
|
import os |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from Generator.utils import fast_3D_interp_torch, myzoom_torch |
|
|
from Trainer.models import build_model, build_inpaint_model |
|
|
from utils.checkpoint import load_checkpoint |
|
|
import utils.misc as utils |
|
|
|
|
|
# Pick the active CUDA device index when a GPU is available (torch accepts a
# bare int as a CUDA device spec); fall back to the CPU otherwise.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Site-specific default YAML configs (absolute paths on the lab cluster).
# NOTE(review): these will not resolve outside this filesystem — consider
# environment variables or package-relative paths.
submit_cfg_file = '/autofs/vast/lemon/temp_stuff/brainfm/cfg/defaults/submit.yaml'
default_gen_cfg_file = '/autofs/vast/lemon/temp_stuff/brainfm/cfg/defaults/default_gen.yaml'
default_train_cfg_file = '/autofs/vast/lemon/temp_stuff/brainfm/cfg/defaults/default_train.yaml'
default_val_file = '/autofs/vast/lemon/temp_stuff/brainfm/cfg/defaults/default_val.yaml'

# Directories prepended to caller-supplied config names (empty = use as given).
gen_cfg_dir = ''
train_cfg_dir = ''

# Atlas volume consumed by get_deformed_atlas below.
atlas_path = '/autofs/vast/lemon/temp_stuff/brainfm/files/gca.mgz'
|
|
|
|
|
|
|
|
# Load the atlas once at import time (module-level file I/O side effect).
MNI, aff2 = utils.MRIread(atlas_path)
# A = inverse of the atlas affine: maps world coordinates -> atlas voxel indices.
A = np.linalg.inv(aff2)
A = torch.tensor(A, device=device, dtype=torch.float32)
MNI = torch.tensor(MNI, device = device, dtype = torch.float32)
|
|
|
|
|
def get_deformed_atlas(brain_labels, regx, regy, regz):
    """Warp the module-level MNI atlas into subject space.

    Voxels where ``brain_labels > 0`` receive atlas intensities sampled
    trilinearly at the coordinates given by the (regx, regy, regz) fields
    (scaled by 100 — presumably undoing a normalization applied by the
    registration network; confirm against the generator). All other voxels
    stay zero.
    """
    fg = brain_labels > 0
    # World-space coordinates of the foreground voxels.
    wx = 100 * regx[fg]
    wy = 100 * regy[fg]
    wz = 100 * regz[fg]
    # World -> atlas voxel indices via the inverse atlas affine A.
    vi = A[0, 0] * wx + A[0, 1] * wy + A[0, 2] * wz + A[0, 3]
    vj = A[1, 0] * wx + A[1, 1] * wy + A[1, 2] * wz + A[1, 3]
    vk = A[2, 0] * wx + A[2, 1] * wy + A[2, 2] * wz + A[2, 3]

    warped = torch.zeros_like(regx)
    warped[fg] = fast_3D_interp_torch(MNI, vi, vj, vk, 'linear')
    return warped
|
|
|
|
|
|
|
|
def zero_crop(orig, tol = 0, crop_range_lst = None, save_path = None):
    '''
    Crop a 3D volume to the bounding box of its above-threshold voxels.

    Parameters
    ----------
    orig : 3D torch.Tensor.
    tol : voxels with value > tol count as foreground.
    crop_range_lst : [[x0, y0, z0], [x1, y1, z1]] — explicit crop bounds;
        when given, tol is ignored and the box is applied as-is.
    save_path : unused; kept for backward compatibility.

    Returns
    -------
    The cropped view orig[x0:x1, y0:y1, z0:z1].

    Raises
    ------
    ValueError if no voxel exceeds tol (the bounding box is undefined).
    '''
    if crop_range_lst is None:
        mask = orig > tol
        coords = torch.argwhere(mask)
        if coords.numel() == 0:
            # Previously this crashed with an opaque "min() of empty tensor"
            # runtime error; fail with an explicit message instead.
            raise ValueError('zero_crop: no voxel above tol=%s, cannot crop' % tol)
        # Inclusive lower corner, exclusive upper corner of the bounding box.
        x0, y0, z0 = coords.min(dim=0)[0]
        x1, y1, z1 = coords.max(dim=0)[0] + 1
    else:
        [[x0, y0, z0], [x1, y1, z1]] = crop_range_lst

    cropped = orig[x0 : x1, y0 : y1, z0 : z1]

    return cropped
|
|
|
|
|
|
|
|
def tiling(img, stride = [40, 40, 40], win_size = [160, 160, 160], zero_crop_first = False):
    """Split a 5D volume (N, C, X, Y, Z) into overlapping patches.

    Along each axis the first window spans up to win_size; every following
    window advances by stride (NOTE(review): follow-up windows are
    stride-sized rather than win_size-sized — confirm this is intended).
    The last window is clamped against the volume boundary.

    Returns (img_list, cnt): img_list holds (patch, [(x0,x1),(y0,y1),(z0,z1)])
    pairs and cnt counts how many patches cover each voxel.
    """
    def axis_intervals(extent, first_len, step):
        # First window as large as win_size allows, then march by `step`,
        # pulling the final window back so it ends exactly at the boundary.
        intervals = [(0, min(first_len, extent))]
        end = intervals[0][1]
        while end < extent:
            start = min(end, extent - step)
            end = min(start + step, extent)
            intervals.append((start, end))
        return intervals

    if zero_crop_first:
        img = zero_crop(img[0, 0])[None, None]
        print('shape after zero croppping:', img.shape[2:])

    orig_shp = img.shape[2:]
    x_list = axis_intervals(orig_shp[0], win_size[0], stride[0])
    y_list = axis_intervals(orig_shp[1], win_size[1], stride[1])
    z_list = axis_intervals(orig_shp[2], win_size[2], stride[2])

    img_list = []
    cnt = torch.zeros_like(img[0, 0])
    for (x0, x1) in x_list:
        for (y0, y1) in y_list:
            for (z0, z1) in z_list:
                patch = img[:, :, x0 : x1, y0 : y1, z0 : z1]
                rng = [(x0, x1), (y0, y1), (z0, z1)]
                img_list.append((patch, rng))
                cnt[x0 : x1, y0 : y1, z0 : z1] += 1
                print('Patch #'+str(len(img_list)), patch.shape[2:], rng)

    return img_list, cnt
|
|
|
|
|
|
|
|
|
|
|
def center_crop(img, win_size = [220, 220, 220], zero_crop_first = False, aff=None):
    """Center-crop a volume to win_size, shifting the affine accordingly.

    Parameters
    ----------
    img : 3D tensor (X, Y, Z) or 4D tensor (X, Y, Z, C); internally promoted
        to 5D (N, C, X, Y, Z), and for 4D input permuted back to
        (N, X, Y, Z, C) on return.
    win_size : target spatial size, or None to skip cropping.
    zero_crop_first : strip the all-zero background before center-cropping.
    aff : 4x4 voxel-to-world affine; defaults to a fresh identity. A
        caller-supplied array is still updated IN PLACE when a crop happens
        (existing callers observe the shifted affine through their own
        reference).

    Returns
    -------
    (cropped_img, crop_start, orig_shp, aff). For 4D input crop_start is
    reported as [0, 0, 0] (NOTE(review): the permuted branch discards the
    real crop offset — confirm whether callers need it).
    """
    if aff is None:
        # Fresh identity per call: the previous `aff=np.eye(4)` default was
        # a single shared mutable array that got shifted in place on every
        # crop, silently corrupting later calls that relied on the default.
        aff = np.eye(4)

    if len(img.shape) == 4:
        # Channel-last 4D input -> (N=1, C, X, Y, Z).
        img = torch.permute(img, (3, 0, 1, 2))
        img = img[None]
        permuted = True
    else:
        assert len(img.shape) == 3
        img = img[None, None]
        permuted = False

    orig_shp = img.shape[2:]

    if zero_crop_first:
        print(' before zero croppping:', orig_shp)
        img = zero_crop(img[0, 0])[None, None]
        orig_shp = img.shape[2:]
        print(' after zero croppping:', orig_shp)

    if win_size is None:
        if permuted:
            # This branch previously returned only 3 values (aff missing),
            # breaking callers that unpack 4; return aff like every other path.
            return torch.permute(img, (0, 2, 3, 4, 1)), [0, 0, 0], orig_shp, aff
        return img, [0, 0, 0], orig_shp, aff

    elif orig_shp[0] > win_size[0] or orig_shp[1] > win_size[1] or orig_shp[2] > win_size[2]:
        crop_start = [ max((orig_shp[i] - win_size[i]), 0) // 2 for i in range(3) ]
        # Shift the world-space origin so the cropped volume keeps its position.
        aff[:-1,-1] = aff[:-1,-1] + aff[:-1,:-1] @ np.array(crop_start)
        crop_img = img[ :, :, crop_start[0] : crop_start[0] + win_size[0],
                        crop_start[1] : crop_start[1] + win_size[1],
                        crop_start[2] : crop_start[2] + win_size[2]]

        if permuted:
            return torch.permute(crop_img, (0, 2, 3, 4, 1)), [0, 0, 0], orig_shp, aff
        return crop_img, crop_start, orig_shp, aff

    else:
        # Already within win_size along every axis: nothing to crop.
        if permuted:
            return torch.permute(img, (0, 2, 3, 4, 1)), [0, 0, 0], orig_shp, aff
        return img, [0, 0, 0], orig_shp, aff
|
|
|
|
|
|
|
|
|
|
|
def add_bias_field(I, bf_scale_min = 0.02, bf_scale_max = 0.04, bf_std_min = 0.1, bf_std_max = 0.6, device = 'cpu'):
    """Multiply image I by a random smooth multiplicative bias field.

    A small Gaussian-noise grid (its size a random fraction of I's shape,
    drawn from [bf_scale_min, bf_scale_max]) is upsampled to full
    resolution and exponentiated, giving a smooth strictly-positive field.

    Returns (I * BF, BF).
    """
    # Random fraction of the image size for the low-resolution noise grid.
    scale = bf_scale_min + np.random.rand(1) * (bf_scale_max - bf_scale_min)
    small_shape = np.round(scale * np.array(I.shape)).astype(int).tolist()
    # Random per-call std in [bf_std_min, bf_std_max] scales the noise grid.
    std = torch.tensor(bf_std_min + (bf_std_max - bf_std_min) * np.random.rand(1), dtype=torch.float, device=device)
    log_field_small = std * torch.randn(small_shape, dtype=torch.float, device=device)
    # Upsample the log-field to the image grid, then exponentiate.
    log_field = myzoom_torch(log_field_small, np.array(I.shape) / small_shape)
    field = torch.exp(log_field)
    return I * field, field
|
|
|
|
|
def resample(I, orig_res = [1., 1., 1.], new_res = [1., 1., 1.]):
    """Degrade I (at orig_res mm) to new_res, then zoom back up.

    The volume is linearly resampled down to the new_res grid (discarding
    fine detail) and then zoomed by 1/factors so the output is back at
    approximately the original array size. Scalars are accepted for either
    resolution and broadcast to all three axes.
    """
    if not isinstance(orig_res, list):
        orig_res = [orig_res, orig_res, orig_res]
    if not isinstance(new_res, list):
        new_res = [new_res, new_res, new_res]

    resolution = np.array(new_res)
    # Intermediate low-resolution grid size (truncated, not rounded).
    new_size = (np.array(I.shape) * orig_res / resolution).astype(int)

    factors = np.array(new_size) / np.array(I.shape)
    # Half-voxel offset so sample centers align between the two grids.
    delta = (1.0 - factors) / (2.0 * factors)
    # Sampling coordinates in the original voxel grid; the [:new_size[i]]
    # slice guards against np.arange overshooting by one element.
    vx = np.arange(delta[0], delta[0] + new_size[0] / factors[0], 1 / factors[0])[:new_size[0]]
    vy = np.arange(delta[1], delta[1] + new_size[1] / factors[1], 1 / factors[1])[:new_size[1]]
    vz = np.arange(delta[2], delta[2] + new_size[2] / factors[2], 1 / factors[2])[:new_size[2]]
    II, JJ, KK = np.meshgrid(vx, vy, vz, sparse=False, indexing='ij')
    II = torch.tensor(II, dtype=torch.float, device=I.device)
    JJ = torch.tensor(JJ, dtype=torch.float, device=I.device)
    KK = torch.tensor(KK, dtype=torch.float, device=I.device)

    # Downsample by trilinear interpolation, then zoom back up by 1/factors.
    # NOTE(review): this calls utils.myzoom_torch while add_bias_field uses
    # the directly-imported myzoom_torch — confirm they are the same helper.
    I_resize = fast_3D_interp_torch(I, II, JJ, KK, 'linear')
    I_new = utils.myzoom_torch(I_resize, 1 / factors)

    return I_new
|
|
|
|
|
|
|
|
def read_image(img_path, is_label = False, device = 'cpu'):
    """Load a volume from disk as a torch tensor (the affine is discarded).

    Label maps are read and stored as integers; intensity images as
    float32. NaNs are replaced with zeros.
    """
    vol, aff = utils.MRIread(img_path, im_only=False, dtype='int' if is_label else 'float')
    tensor_dtype = torch.int if is_label else torch.float32
    vol = torch.tensor(np.squeeze(vol), dtype=tensor_dtype, device=device)
    return torch.nan_to_num(vol)
|
|
|
|
|
|
|
|
def prepare_image(img_path, win_size = None, zero_crop_first = False, spacing = None,
                  add_bf = False, is_CT = False, is_label = False, rescale = True,
                  hemis_mask = None, im_only = False, device = 'cpu'):
    """Load, normalize, reorient and crop an image for model input.

    Pipeline: read -> average away extra dims -> clamp/rescale intensities
    -> resample to 1 mm -> align to the reference orientation -> center-crop
    to win_size; optionally apply a random bias field (add_bf, MR only)
    and/or resample the result to `spacing`.

    Returns `final` alone when im_only is True, otherwise
    (final, orig, high_res, bf, aff, crop_start, orig_shp).
    """
    im, aff = utils.MRIread(img_path, im_only=False, dtype='int' if is_label else 'float')
    im = torch.tensor(np.squeeze(im), dtype=torch.int if is_label else torch.float32, device=device)
    im = torch.nan_to_num(im)

    if len(im.shape) > 3:
        # Collapse any trailing dimension (e.g. frames/channels) by averaging.
        im = im.mean(dim = -1)

    if is_CT and rescale:
        # Clamp CT values to [0, 80] (presumably a soft-tissue window — confirm).
        im = torch.clamp(im, min = 0., max = 80.)

    if not is_label and rescale:
        # Min-max normalize intensities to [0, 1].
        im -= torch.min(im)
        im /= torch.max(im)

    # Resample to 1 mm isotropic resolution.
    im, aff = utils.torch_resize(im, aff, 1.)

    orig = im
    orig, aff_before_crop = utils.align_volume_to_ref(orig, aff, aff_ref=np.eye(4), return_aff=True, n_dims=3)

    # NOTE(review): center_crop updates the passed affine in place, so the
    # later uses of aff_before_crop below see the shifted affine — confirm
    # this aliasing is intended.
    orig, crop_start, orig_shp, aff = center_crop(orig, win_size, zero_crop_first = zero_crop_first, aff = aff_before_crop)

    if add_bf and not is_CT:
        # Simulate a smooth multiplicative scanner bias field (skipped for CT).
        high_res, bf = add_bias_field(im, device = device)
        bf, _ = utils.align_volume_to_ref(bf, aff_before_crop, aff_ref=np.eye(4), return_aff=True, n_dims=3)
        bf, crop_start, orig_shp, _ = center_crop(bf, win_size, zero_crop_first = zero_crop_first, aff = aff_before_crop)
    else:
        high_res, bf = im, None

    if spacing is not None:
        # Degrade to the requested voxel spacing.
        final = resample(high_res, new_res = spacing)
    else:
        final = high_res

    high_res, _ = utils.align_volume_to_ref(high_res, aff_before_crop, aff_ref=np.eye(4), return_aff=True, n_dims=3)
    high_res, crop_start, orig_shp, _ = center_crop(high_res, win_size, zero_crop_first = zero_crop_first, aff = aff_before_crop)

    final, _ = utils.align_volume_to_ref(final, aff_before_crop, aff_ref=np.eye(4), return_aff=True, n_dims=3)
    final, crop_start, orig_shp, _ = center_crop(final, win_size, zero_crop_first = zero_crop_first, aff = aff_before_crop)

    if hemis_mask is not None:
        # Zero out everything outside the provided hemisphere mask.
        final[hemis_mask ==0] = 0

    if im_only:
        return final

    return final, orig, high_res, bf, aff, crop_start, orig_shp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def evaluate_image(inputs, ckp_path, feature_only = True, device = 'cpu', gen_cfg = None, model_cfg = None):
    """Run the single-stage feature model on one pre-loaded input tensor.

    inputs : tensor already shaped as the model expects (wrapped below into
        a one-element sample list).
    ckp_path : checkpoint loaded into the freshly built model.
    feature_only : return only the last feature map of the first sample;
        otherwise return its full output dict.
    """
    # Merge default config files with the caller-provided overrides.
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, gen_cfg], cfg_dir = gen_cfg_dir)
    train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, model_cfg], cfg_dir = train_cfg_dir)

    samples = [ { 'input': inputs } ]

    # Build model + processors from config, then restore the checkpoint.
    gen_args, train_args, feat_model, processors, criterion, postprocessor = build_model(gen_args, train_args, device)
    load_checkpoint(ckp_path, [feat_model], model_keys = ['model'], to_print = False)
    outputs, _ = feat_model(samples)

    # Per-task processors first, then the optional final postprocessor.
    for processor in processors:
        outputs = processor(outputs, samples)
    if postprocessor is not None:
        outputs, _, _ = postprocessor(gen_args, train_args, outputs, samples, target = None, feats = None, tasks = gen_args.tasks)

    if feature_only:
        return outputs[0]['feat'][-1]
    else:
        return outputs[0]
|
|
|
|
|
|
|
|
@torch.no_grad()
def evaluate_image_twostage(inputs, pathol_ckp_path, task_ckp_path, feature_only = True, device = 'cpu', gen_cfg = None, model_cfg = None):
    """Run the two-stage (pathology -> task) pipeline on one input tensor.

    Stage 1 predicts a pathology map; the input is then masked with
    (1 - pathology) and fed, with the pathology map as condition, to the
    task model. Outputs of both stages are merged per sample.
    """
    # Merge default config files with the caller-provided overrides.
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, gen_cfg], cfg_dir = gen_cfg_dir)
    train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, model_cfg], cfg_dir = train_cfg_dir)

    samples = [ { 'input': inputs } ]

    # Build both models, then restore their checkpoints independently.
    gen_args, train_args, pathol_model, task_model, pathol_processors, task_processors, criterion, postprocessor = build_inpaint_model(gen_args, train_args, device)
    load_checkpoint(pathol_ckp_path, [pathol_model], model_keys = ['model'], to_print = False)
    load_checkpoint(task_ckp_path, [task_model], model_keys = ['model'], to_print = False)

    # Stage 1: pathology prediction.
    outputs_pathol, _ = pathol_model(samples)
    for processor in pathol_processors:
        outputs_pathol = processor(outputs_pathol, samples)

    # Stage 2: mask out the predicted pathology, condition the task model on it.
    for i in range(len(samples)):
        samples[i]['input_masked'] = samples[i]['input'] * (1 - outputs_pathol[i]['pathology'])
    outputs_task, _ = task_model(samples, input_name = 'input_masked', cond = [o['pathology'] for o in outputs_pathol])
    for processor in task_processors:
        outputs_task = processor(outputs_task, samples)

    # Merge per-sample output dicts from both stages into a single list.
    outputs = utils.merge_list_of_dict(outputs_task, outputs_pathol)

    if postprocessor is not None:
        outputs, _, _ = postprocessor(gen_args, train_args, outputs, samples, target = None, feats = None, tasks = gen_args.tasks)

    if feature_only:
        return outputs[0]['feat_pathol'][-1], outputs[0]['feat_task'][-1]
    else:
        return outputs[0]
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def evaluate_path(input_paths, save_dir, ckp_path, win_size = [220, 220, 220],
                  save_input = False, aux_paths = {}, save_aux = False, exclude_keys = [],
                  mask_output = False, ext = '.nii.gz', device = 'cpu',
                  gen_cfg = None, model_cfg = None):
    """Run the single-stage model on each image path and save the outputs.

    For each input: load & preprocess via prepare_image, run the model,
    post-process, and write every tensor-valued output (except exclude_keys)
    under save_dir/<input basename>/. Aux volumes (keys containing 'label'
    are read as label maps) can be saved alongside; when mask_output is set,
    an aux volume whose key contains 'mask' (or, failing that, the input's
    nonzero support) gates the saved outputs.

    NOTE(review): the mutable defaults (lists/dict) are shared across
    calls — harmless while never mutated here, but worth confirming.
    """
    # Merge default config files with the caller-provided overrides.
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, gen_cfg], cfg_dir = gen_cfg_dir)
    train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, model_cfg], cfg_dir = train_cfg_dir)

    # Build the model once and reuse it for every input path.
    gen_args, train_args, model, processors, criterion, postprocessor = build_model(gen_args, train_args, device)
    load_checkpoint(ckp_path, [model], model_keys = ['model'], to_print = False)

    for i, input_path in enumerate(input_paths):
        print('Now testing: %s (%d/%d)' % (input_path, i+1, len(input_paths)))
        print('      ckp:', ckp_path)
        # One sub-directory per input volume, named after the file stem.
        curr_save_dir = utils.make_dir(os.path.join(save_dir, os.path.basename(input_path).split('.nii')[0]))

        mask = None
        im, orig, high_res, bf, aff, crop_start, orig_shp = prepare_image(input_path, win_size, device = device)
        if save_input:
            print('  Input: saved in - %s' % (os.path.join(curr_save_dir, 'input' + ext)))
            utils.viewVolume(im, aff, names = ['input'], ext = ext, save_dir = curr_save_dir)
        for k in aux_paths.keys():
            # Aux volumes whose key contains 'label' are read as label maps.
            im_k, _, _, _, _, _, _ = prepare_image(aux_paths[k][i], win_size, is_label = 'label' in k, device = device)
            if save_aux:
                print('  Aux input: %s - saved in - %s' % (k, os.path.join(curr_save_dir, k + ext)))
                utils.viewVolume(im_k, aff, names = [k], ext = ext, save_dir = curr_save_dir)
            if mask_output and 'mask' in k:
                # Binarize the aux 'mask' volume for gating the outputs.
                mask = im_k.clone()
                mask[im_k != 0.] = 1.
        samples = [ { 'input': im } ]

        outputs, _ = model(samples)

        # Per-task processors first, then the optional final postprocessor.
        for processor in processors:
            outputs = processor(outputs, samples)
        if postprocessor is not None:
            outputs, _, _ = postprocessor(gen_args, train_args, outputs, samples, target = None, feats = None, tasks = gen_args.tasks)

        out = outputs[0]
        if mask_output and mask is None:
            # No explicit mask aux: fall back to the input's nonzero support.
            mask = torch.zeros_like(im)
            mask[im != 0.] = 1.
        for key in out.keys():
            if key not in exclude_keys and isinstance(out[key], torch.Tensor):
                print('  Output: %s - saved in - %s' % (key, os.path.join(curr_save_dir, 'out_' + key + ext)))
                # Clamp negative values to zero before saving.
                out[key][out[key] < 0.] = 0.
                utils.viewVolume(out[key] * mask if mask_output else out[key], aff, names = ['out_'+key], ext = ext, save_dir = curr_save_dir)
|
|