# BrainFM / utils / test_utils.py
# peirong26's picture
# Upload 187 files
# 2571f24 verified
import os
import numpy as np
import torch
from Generator.utils import fast_3D_interp_torch, myzoom_torch
from Trainer.models import build_model, build_inpaint_model
from utils.checkpoint import load_checkpoint
import utils.misc as utils
# Module-level device used for the atlas tensors below: the current CUDA device
# index when CUDA is available, otherwise the string 'cpu'.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
# default & gpu cfg #
# Previous (MTBrainID) locations of the same files, kept for reference:
#submit_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/submit.yaml'
#default_gen_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/default.yaml'
#default_train_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/default_train.yaml'
#default_val_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/default_val.yaml'
#gen_cfg_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/test'
#train_cfg_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/test'
#atlas_path = '/autofs/vast/lemon/temp_stuff/peirong/data/gca.mgz'
# Absolute paths of the default configuration files; the generator/trainer
# defaults and the validation defaults are read by the evaluate_* functions.
submit_cfg_file = '/autofs/vast/lemon/temp_stuff/brainfm/cfg/defaults/submit.yaml'
default_gen_cfg_file = '/autofs/vast/lemon/temp_stuff/brainfm/cfg/defaults/default_gen.yaml'
default_train_cfg_file = '/autofs/vast/lemon/temp_stuff/brainfm/cfg/defaults/default_train.yaml'
default_val_file = '/autofs/vast/lemon/temp_stuff/brainfm/cfg/defaults/default_val.yaml'
# Passed as `cfg_dir` to utils.preprocess_cfg by the evaluate_* functions.
gen_cfg_dir = ''
train_cfg_dir = ''
# Atlas volume used by get_deformed_atlas.
# NOTE: the atlas is read from disk when this module is imported, so importing
# fails if the path is not reachable.
atlas_path = '/autofs/vast/lemon/temp_stuff/brainfm/files/gca.mgz'
MNI, aff2 = utils.MRIread(atlas_path)
# A is the inverse of the atlas affine returned by MRIread
# (presumably world -> atlas voxel coordinates; confirm against utils.MRIread).
A = np.linalg.inv(aff2)
# Both the inverse affine and the atlas volume are kept as float32 tensors on `device`.
A = torch.tensor(A, device=device, dtype=torch.float32)
MNI = torch.tensor(MNI, device = device, dtype = torch.float32)
def get_deformed_atlas(brain_labels, regx, regy, regz):
    """Sample the module-level atlas at registered coordinates inside the label mask.

    Args:
        brain_labels: tensor whose positive entries select the voxels to fill.
        regx, regy, regz: coordinate tensors indexable by the mask of
            ``brain_labels``; they are multiplied by 100 before the inverse
            atlas affine ``A`` is applied (presumably a unit rescaling; confirm
            against the code that produces them).

    Returns:
        Tensor shaped like ``regx``: linearly interpolated atlas values where
        ``brain_labels > 0`` and zeros elsewhere.
    """
    inside = brain_labels > 0
    px, py, pz = (100 * reg[inside] for reg in (regx, regy, regz))

    # Apply the first three rows of the inverse atlas affine to get atlas indices.
    atlas_idx = [A[row, 0] * px + A[row, 1] * py + A[row, 2] * pz + A[row, 3] for row in range(3)]
    sampled = fast_3D_interp_torch(MNI, atlas_idx[0], atlas_idx[1], atlas_idx[2], 'linear')

    deformed = torch.zeros_like(regx)
    deformed[inside] = sampled
    return deformed
def zero_crop(orig, tol = 0, crop_range_lst = None, save_path = None):
    """Crop a 3-D volume to a bounding box.

    Args:
        orig: 3-D tensor to crop.
        tol: voxels with value ``> tol`` count as foreground when the box is
            computed from the data.
        crop_range_lst: optional explicit box ``[[x0, y0, z0], [x1, y1, z1]]``
            (upper bounds exclusive). When given, ``tol`` is ignored.
        save_path: accepted but not used by this function.

    Returns:
        The cropped view ``orig[x0:x1, y0:y1, z0:z1]``.
    """
    if crop_range_lst is not None:
        (x0, y0, z0), (x1, y1, z1) = crop_range_lst
    else:
        # Indices of every foreground voxel, one row per voxel.
        foreground_idx = torch.argwhere(orig > tol)
        x0, y0, z0 = foreground_idx.min(dim=0).values
        # +1 because slice upper bounds are exclusive.
        x1, y1, z1 = foreground_idx.max(dim=0).values + 1
    return orig[x0 : x1, y0 : y1, z0 : z1]
def tiling(img, stride = [40, 40, 40], win_size = [160, 160, 160], zero_crop_first = False):
    """Split a 5-D volume into contiguous patches along its last three axes.

    Along each axis the first span is ``[0, min(win_size, extent))``; every
    following span has length ``stride`` and starts where the previous one
    ended, except that a start is clamped to ``extent - stride`` so the last
    span may overlap the one before it.

    Args:
        img: tensor indexed as (1, d, s, r, c).
        stride: per-axis length of every span after the first.
        win_size: per-axis length of the first span.
        zero_crop_first: if True, replace ``img`` by the zero-cropped first
            channel of the first batch item before tiling.

    Returns:
        (patches, coverage): ``patches`` is a list of ``(patch, ranges)`` with
        ``ranges = [(x0, x1), (y0, y1), (z0, z1)]``; ``coverage`` counts, per
        voxel of ``img[0, 0]``, how many patches include it.
    """
    if zero_crop_first:
        img = zero_crop(img[0, 0])[None, None]
        print('shape after zero croppping:', img.shape[2:])
    extents = img.shape[2:] # (1, d, s, r, c)

    def _axis_spans(extent, first_len, step):
        # First span uses the window length, later ones the stride length.
        stop = min(first_len, extent)
        spans = [(0, stop)]
        while stop < extent:
            start = min(stop, extent - step)
            stop = min(start + step, extent)
            spans.append((start, stop))
        return spans

    x_spans, y_spans, z_spans = [_axis_spans(extents[ax], win_size[ax], stride[ax]) for ax in range(3)]

    patches = []
    coverage = torch.zeros_like(img[0, 0])
    for x_span in x_spans:
        for y_span in y_spans:
            for z_span in z_spans:
                (x0, x1), (y0, y1), (z0, z1) = x_span, y_span, z_span
                patch = img[:, :, x0 : x1, y0 : y1, z0 : z1]
                ranges = [(x0, x1), (y0, y1), (z0, z1)]
                patches.append((patch, ranges))
                coverage[x0 : x1, y0 : y1, z0 : z1] += 1
                print('Patch #'+str(len(patches)), patch.shape[2:], ranges)
    return patches, coverage
def center_crop(img, win_size = [220, 220, 220], zero_crop_first = False, aff=np.eye(4)):
    """Center-crop a volume to at most ``win_size`` and shift its affine to match.

    Args:
        img: 3-D tensor, or 4-D tensor whose last dimension is temporarily
            moved to the channel position and moved back before returning.
        win_size: per-axis maximum size, or None to skip cropping. Axes that
            are already smaller than the window are left as they are (no
            padding is applied).
        zero_crop_first: if True, first crop to the bounding box of the
            non-zero voxels via ``zero_crop``.
            NOTE(review): this uses ``img[0, 0]`` only, so for 4-D input just
            the first channel survives; confirm whether that is intended.
        aff: 4x4 affine of ``img``. It is copied, never modified in place.

    Returns:
        ``(cropped, crop_start, orig_shp, aff)`` in every branch, where
        ``cropped`` is 5-D (batch and channel dims added; channels last for
        4-D input), ``crop_start`` is the per-axis offset that was removed,
        ``orig_shp`` is the spatial shape before the window crop, and ``aff``
        is a new affine whose translation accounts for ``crop_start``.
    """
    # Work on a private copy. Updating the argument in place would corrupt the
    # caller's affine and, for the default argument, leak the shift into every
    # later call that relies on the default.
    aff = np.eye(4) if aff is None else np.array(aff)

    channels_last = len(img.shape) == 4
    if channels_last:
        img = torch.permute(img, (3, 0, 1, 2))[None] # (move last dim to first)
    else:
        assert len(img.shape) == 3
        img = img[None, None]

    orig_shp = img.shape[2:] # (1, d, s, r, c)
    # first, crop all zeros -> get *actual* shape
    if zero_crop_first:
        print(' before zero croppping:', orig_shp)
        img = zero_crop(img[0, 0])[None, None]
        orig_shp = img.shape[2:]
        print(' after zero croppping:', orig_shp)

    crop_start = [0, 0, 0]
    if win_size is not None and any(orig_shp[i] > win_size[i] for i in range(3)):
        crop_start = [max(orig_shp[i] - win_size[i], 0) // 2 for i in range(3)]
        # Move the affine origin to the first voxel that is kept.
        aff[:-1, -1] = aff[:-1, -1] + aff[:-1, :-1] @ np.array(crop_start)
        img = img[:, :, crop_start[0] : crop_start[0] + win_size[0],
                        crop_start[1] : crop_start[1] + win_size[1],
                        crop_start[2] : crop_start[2] + win_size[2]]

    if channels_last:
        img = torch.permute(img, (0, 2, 3, 4, 1))
    return img, crop_start, orig_shp, aff
def add_bias_field(I, bf_scale_min = 0.02, bf_scale_max = 0.04, bf_std_min = 0.1, bf_std_max = 0.6, device = 'cpu'):
    """Multiply a volume by a random smooth multiplicative field.

    A coarse Gaussian-noise grid is drawn at a random fraction of the input
    size, scaled by a random standard deviation, upsampled to the input shape
    with ``myzoom_torch`` and exponentiated.

    Args:
        I: input tensor.
        bf_scale_min, bf_scale_max: range of the coarse-grid size as a
            fraction of ``I.shape``.
        bf_std_min, bf_std_max: range of the noise standard deviation.
        device: device on which the random field is created.

    Returns:
        (I_bf, BF): the modulated volume and the field it was multiplied by.
    """
    full_shape = np.array(I.shape)

    # Random coarse resolution for the field.
    scale = bf_scale_min + np.random.rand(1) * (bf_scale_max - bf_scale_min)
    coarse_shape = np.round(scale * full_shape).astype(int).tolist()

    # Random amplitude, then the coarse noise itself.
    amplitude = torch.tensor(bf_std_min + (bf_std_max - bf_std_min) * np.random.rand(1), dtype=torch.float, device=device)
    coarse_noise = torch.randn(coarse_shape, dtype=torch.float, device=device)

    log_field = myzoom_torch(amplitude * coarse_noise, full_shape / coarse_shape)
    field = torch.exp(log_field)
    return I * field, field
def resample(I, orig_res = [1., 1., 1.], new_res = [1., 1., 1.]):
    """Resample a 3-D volume at ``new_res`` and zoom the result by the inverse factors.

    The volume is linearly interpolated on a grid whose size follows from
    ``orig_res / new_res`` and is then passed through ``utils.myzoom_torch``
    with ``1 / factors``.

    Args:
        I: 3-D tensor.
        orig_res: resolution of ``I``; a non-list value is repeated for all
            three axes.
        new_res: target resolution; a non-list value is repeated for all
            three axes.

    Returns:
        The tensor returned by ``utils.myzoom_torch``.
    """
    orig_res = orig_res if isinstance(orig_res, list) else [orig_res] * 3
    new_res = new_res if isinstance(new_res, list) else [new_res] * 3

    in_shape = np.array(I.shape)
    out_shape = (in_shape * orig_res / np.array(new_res)).astype(int)
    factors = np.array(out_shape) / in_shape
    # Offset of the first sample of the resampled grid.
    delta = (1.0 - factors) / (2.0 * factors)

    axes = [
        np.arange(delta[ax], delta[ax] + out_shape[ax] / factors[ax], 1 / factors[ax])[:out_shape[ax]]
        for ax in range(3)
    ]
    grids = np.meshgrid(axes[0], axes[1], axes[2], sparse=False, indexing='ij')
    II, JJ, KK = [torch.tensor(g, dtype=torch.float, device=I.device) for g in grids]

    coarse = fast_3D_interp_torch(I, II, JJ, KK, 'linear')
    return utils.myzoom_torch(coarse, 1 / factors)
def read_image(img_path, is_label = False, device = 'cpu'):
    """Read a volume with ``utils.MRIread`` and return it as a NaN-free tensor.

    Args:
        img_path: path handed to ``utils.MRIread``.
        is_label: read as integer labels (``torch.int``) instead of
            ``torch.float32`` intensities.
        device: device of the returned tensor.

    Returns:
        The squeezed volume after ``torch.nan_to_num``; the affine returned by
        the reader is discarded.
    """
    if is_label:
        read_dtype, tensor_dtype = 'int', torch.int
    else:
        read_dtype, tensor_dtype = 'float', torch.float32
    volume, _ = utils.MRIread(img_path, im_only=False, dtype=read_dtype)
    volume = torch.tensor(np.squeeze(volume), dtype=tensor_dtype, device=device)
    return torch.nan_to_num(volume)
def prepare_image(img_path, win_size = None, zero_crop_first = False, spacing = None,
                  add_bf = False, is_CT = False, is_label = False, rescale = True,
                  hemis_mask = None, im_only = False, device = 'cpu'):
    """Load a volume and produce the aligned, center-cropped model input.

    Args:
        img_path: path handed to ``utils.MRIread``.
        win_size: window for ``center_crop`` (None disables cropping).
        zero_crop_first: forwarded to ``center_crop``.
        spacing: if not None, ``resample`` is applied with ``new_res=spacing``.
        add_bf: apply ``add_bias_field`` (skipped when ``is_CT``).
        is_CT: with ``rescale``, clamp intensities to [0, 80] first.
        is_label: read as integer labels and skip intensity rescaling.
        rescale: min-max rescale intensities to [0, 1].
        hemis_mask: optional mask; voxels where it equals 0 are zeroed in the
            final volume.
        im_only: return only the final volume.
        device: device of the created tensors.

    Returns:
        ``final`` when ``im_only`` is True, otherwise
        ``(final, orig, high_res, bf, aff, crop_start, orig_shp)`` where
        ``aff`` is the affine of the aligned and cropped ``orig`` and ``bf``
        is None when no bias field was applied.
    """
    im, aff = utils.MRIread(img_path, im_only=False, dtype='int' if is_label else 'float')
    im = torch.tensor(np.squeeze(im), dtype=torch.int if is_label else torch.float32, device=device)
    im = torch.nan_to_num(im)
    if len(im.shape) > 3:
        im = im.mean(dim = -1) # averaging the RGB
    if is_CT and rescale: # for CT as input
        im = torch.clamp(im, min = 0., max = 80.)
    if not is_label and rescale:
        im -= torch.min(im)
        peak = torch.max(im)
        # A constant image has peak == 0; dividing would fill it with NaN.
        if peak > 0:
            im /= peak
    im, aff = utils.torch_resize(im, aff, 1.)

    orig = im
    orig, aff_before_crop = utils.align_volume_to_ref(orig, aff, aff_ref=np.eye(4), return_aff=True, n_dims=3)

    def _crop(vol):
        # Hand center_crop its own copy of the pre-crop affine so that repeated
        # crops cannot stack the crop offset onto the affine returned below.
        return center_crop(vol, win_size, zero_crop_first = zero_crop_first, aff = np.array(aff_before_crop))

    orig, crop_start, orig_shp, aff = _crop(orig)

    # NOTE(review): the volumes below are aligned using `aff_before_crop` (the
    # affine produced by the alignment above) rather than the affine `im` had
    # before alignment; confirm against utils.align_volume_to_ref.
    if add_bf and not is_CT:
        high_res, bf = add_bias_field(im, device = device)
        bf, _ = utils.align_volume_to_ref(bf, aff_before_crop, aff_ref=np.eye(4), return_aff=True, n_dims=3)
        bf, crop_start, orig_shp, _ = _crop(bf)
    else:
        high_res, bf = im, None

    if spacing is not None:
        final = resample(high_res, new_res = spacing)
    else:
        final = high_res

    high_res, _ = utils.align_volume_to_ref(high_res, aff_before_crop, aff_ref=np.eye(4), return_aff=True, n_dims=3)
    high_res, crop_start, orig_shp, _ = _crop(high_res)
    final, _ = utils.align_volume_to_ref(final, aff_before_crop, aff_ref=np.eye(4), return_aff=True, n_dims=3)
    final, crop_start, orig_shp, _ = _crop(final)

    if hemis_mask is not None:
        final[hemis_mask ==0] = 0
    if im_only:
        return final
    return final, orig, high_res, bf, aff, crop_start, orig_shp
@torch.no_grad()
def evaluate_image(inputs, ckp_path, feature_only = True, device = 'cpu', gen_cfg = None, model_cfg = None):
    """Run the single-stage model on an input tensor.

    Args:
        inputs: Torch.Tensor -- (batch_size, 1, s, r, c)
        ckp_path: checkpoint loaded into the model under the key 'model'.
        feature_only: return only the last entry of the 'feat' output.
        device: device handed to ``build_model``.
        gen_cfg, model_cfg: optional configs merged on top of the module defaults.

    Returns:
        ``outputs[0]['feat'][-1]`` when ``feature_only`` is True, otherwise
        the full output dict of the first sample.
    """
    # ============ prepare ... ============
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, gen_cfg], cfg_dir = gen_cfg_dir)
    train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, model_cfg], cfg_dir = train_cfg_dir)
    batch = [ { 'input': inputs } ]

    # ============ testing ... ============
    gen_args, train_args, feat_model, processors, _criterion, postprocessor = build_model(gen_args, train_args, device)
    load_checkpoint(ckp_path, [feat_model], model_keys = ['model'], to_print = False)

    predictions, _ = feat_model(batch) # dict with features
    for proc in processors:
        predictions = proc(predictions, batch)
    if postprocessor is not None:
        predictions, _, _ = postprocessor(gen_args, train_args, predictions, batch, target = None, feats = None, tasks = gen_args.tasks)

    first = predictions[0]
    if not feature_only:
        return first
    return first['feat'][-1] # (batch_size, 64, s, r, c)
@torch.no_grad()
def evaluate_image_twostage(inputs, pathol_ckp_path, task_ckp_path, feature_only = True, device = 'cpu', gen_cfg = None, model_cfg = None):
    """Run the two-stage (pathology, then mask-conditioned task) models on an input tensor.

    Args:
        inputs: Torch.Tensor -- (batch_size, 1, s, r, c)
        pathol_ckp_path: checkpoint for the pathology model.
        task_ckp_path: checkpoint for the task model.
        feature_only: return only the last 'feat_pathol' and 'feat_task' entries.
        device: device handed to ``build_inpaint_model``.
        gen_cfg, model_cfg: optional configs merged on top of the module defaults.

    Returns:
        ``(outputs[0]['feat_pathol'][-1], outputs[0]['feat_task'][-1])`` when
        ``feature_only`` is True, otherwise the merged output dict of the
        first sample.
    """
    # ============ prepare ... ============
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, gen_cfg], cfg_dir = gen_cfg_dir)
    train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, model_cfg], cfg_dir = train_cfg_dir)
    batch = [ { 'input': inputs } ]

    # ============ testing ... ============
    (gen_args, train_args, pathol_model, task_model,
     pathol_processors, task_processors, _criterion, postprocessor) = build_inpaint_model(gen_args, train_args, device)
    load_checkpoint(pathol_ckp_path, [pathol_model], model_keys = ['model'], to_print = False)
    load_checkpoint(task_ckp_path, [task_model], model_keys = ['model'], to_print = False)

    # stage-0: pathology segmentation prediction
    pathol_out, _ = pathol_model(batch)
    for proc in pathol_processors:
        pathol_out = proc(pathol_out, batch)

    # stage-1: pathology-mask-conditioned inpainting tasks prediction
    for idx, sample in enumerate(batch): # mask using predicted anomaly
        sample['input_masked'] = sample['input'] * (1 - pathol_out[idx]['pathology'])
    conditions = [o['pathology'] for o in pathol_out]
    task_out, _ = task_model(batch, input_name = 'input_masked', cond = conditions)
    for proc in task_processors:
        task_out = proc(task_out, batch)

    merged = utils.merge_list_of_dict(task_out, pathol_out)
    if postprocessor is not None:
        merged, _, _ = postprocessor(gen_args, train_args, merged, batch, target = None, feats = None, tasks = gen_args.tasks)

    first = merged[0]
    if not feature_only:
        return first
    return first['feat_pathol'][-1], first['feat_task'][-1] # (batch_size, 64, s, r, c)
@torch.no_grad()
def evaluate_path(input_paths, save_dir, ckp_path, win_size = [220, 220, 220],
                  save_input = False, aux_paths = {}, save_aux = False, exclude_keys = [],
                  mask_output = False, ext = '.nii.gz', device = 'cpu',
                  gen_cfg = None, model_cfg = None):
    """Run the model on each input file and save every tensor output as a volume.

    Args:
        input_paths: list of image paths to process.
        save_dir: root output directory; one sub-directory is made per input,
            named after the file name up to '.nii'.
        ckp_path: checkpoint loaded into the model under the key 'model'.
        win_size: crop window forwarded to ``prepare_image``.
        save_input: also save the prepared input volume.
        aux_paths: dict mapping a name to a list of auxiliary paths, indexed
            like ``input_paths``; names containing 'label' are read as labels.
        save_aux: also save the prepared auxiliary volumes.
        exclude_keys: output keys that are not saved.
        mask_output: multiply outputs by a binary mask, taken from the last
            auxiliary volume whose name contains 'mask', else from the
            non-zero voxels of the input.
        ext: file extension used for saved volumes.
        device: device for the model and tensors.
        gen_cfg, model_cfg: optional configs merged on top of the module defaults.

    Returns:
        None. Negative output values are set to 0 in place before saving.
    """
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, gen_cfg], cfg_dir = gen_cfg_dir)
    train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, model_cfg], cfg_dir = train_cfg_dir)

    # ============ loading ... ============
    gen_args, train_args, model, processors, _criterion, postprocessor = build_model(gen_args, train_args, device)
    load_checkpoint(ckp_path, [model], model_keys = ['model'], to_print = False)

    n_inputs = len(input_paths)
    for i, input_path in enumerate(input_paths):
        print('Now testing: %s (%d/%d)' % (input_path, i+1, n_inputs))
        print(' ckp:', ckp_path)
        case_name = os.path.basename(input_path).split('.nii')[0]
        curr_save_dir = utils.make_dir(os.path.join(save_dir, case_name))

        # ============ prepare ... ============
        mask = None
        im, _orig, _high_res, _bf, aff, _crop_start, _orig_shp = prepare_image(input_path, win_size, device = device)
        if save_input:
            print(' Input: saved in - %s' % (os.path.join(curr_save_dir, 'input' + ext)))
            utils.viewVolume(im, aff, names = ['input'], ext = ext, save_dir = curr_save_dir)

        for aux_name, aux_list in aux_paths.items():
            aux_im, _, _, _, _, _, _ = prepare_image(aux_list[i], win_size, is_label = 'label' in aux_name, device = device)
            if save_aux:
                print(' Aux input: %s - saved in - %s' % (aux_name, os.path.join(curr_save_dir, aux_name + ext)))
                utils.viewVolume(aux_im, aff, names = [aux_name], ext = ext, save_dir = curr_save_dir)
            if mask_output and 'mask' in aux_name:
                # Binarise the auxiliary mask.
                mask = aux_im.clone()
                mask[aux_im != 0.] = 1.

        batch = [ { 'input': im } ]

        # ============ testing ... ============
        predictions, _ = model(batch) # dict with features
        for proc in processors:
            predictions = proc(predictions, batch)
        if postprocessor is not None:
            predictions, _, _ = postprocessor(gen_args, train_args, predictions, batch, target = None, feats = None, tasks = gen_args.tasks)
        out = predictions[0]

        if mask_output and mask is None:
            # No auxiliary mask was supplied: fall back to the non-zero input voxels.
            mask = torch.zeros_like(im)
            mask[im != 0.] = 1.

        for key, value in out.items():
            if key in exclude_keys or not isinstance(value, torch.Tensor):
                continue
            print(' Output: %s - saved in - %s' % (key, os.path.join(curr_save_dir, 'out_' + key + ext)))
            value[value < 0.] = 0.
            to_save = value * mask if mask_output else value
            utils.viewVolume(to_save, aff, names = ['out_'+key], ext = ext, save_dir = curr_save_dir)