|
|
|
|
|
""" |
|
|
Visualization modules |
|
|
""" |
|
|
import os |
|
|
import numpy as np |
|
|
from math import ceil |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image |
|
|
from collections import defaultdict |
|
|
|
|
|
from utils.misc import make_dir |
|
|
|
|
|
|
|
|
def match_shape(array, shape):
    """Resize a channel-first tensor so that its spatial dims equal ``shape``.

    ``array`` is indexed as (channel, *spatial). A leading batch axis is added
    because ``F.interpolate`` expects (batch, channel, *spatial), and it is
    stripped again before returning. ``F.interpolate`` is called with its
    default mode. If the spatial dims already match, no interpolation happens.
    """
    batched = array[None]
    current_spatial = list(batched.shape[2:])
    if current_spatial == list(shape):
        return batched[0]
    resized = F.interpolate(batched, size=shape)
    return resized[0]
|
|
|
|
|
def pad_shape(array_list):
    """Zero-pad a list of arrays to a common shape, centring each array.

    The target shape is the element-wise maximum of all shapes; the rank is
    taken from the first array. Each array is copied into a float64 zero array
    of the target shape at offset ``(target - own) // 2`` along every axis.
    Any rank is supported.

    Args:
        array_list: sequence of array-likes with equal rank.

    Returns:
        A new list of float64 numpy arrays, all of the same shape. An empty
        input yields an empty list.
    """
    if len(array_list) == 0:
        return []

    ndim = len(array_list[0].shape)
    max_shape = [0] * ndim
    for array in array_list:
        max_shape = [max(max_shape[dim], array.shape[dim]) for dim in range(ndim)]

    pad_array_list = []
    for array in array_list:
        start = [(max_shape[dim] - array.shape[dim]) // 2 for dim in range(ndim)]
        # One slice per axis places the array in the centre of the zero canvas,
        # which works for any rank instead of a hand-written case per rank.
        region = tuple(slice(s, s + array.shape[dim]) for dim, s in enumerate(start))
        pad_array = np.zeros(max_shape)
        pad_array[region] = array
        pad_array_list.append(pad_array)
    return pad_array_list
|
|
|
|
|
|
|
|
def even_sample(orig_len, num):
    """Return ``num`` roughly evenly spaced indices into a sequence of length ``orig_len``.

    Index ``i`` is ``ceil(i * orig_len / num)``. When ``num`` exceeds
    ``orig_len``, that formula can reach ``orig_len`` itself (one past the
    end). Such indices are therefore clamped to the last valid position, so
    indices may repeat in that case. For ``num <= orig_len`` the clamp never
    triggers.

    Args:
        orig_len: length of the sequence being sampled.
        num: number of indices to return.

    Returns:
        List of ``num`` integer indices.
    """
    length = float(orig_len)
    last_valid = max(int(orig_len) - 1, 0)
    return [min(int(ceil(i * length / num)), last_valid) for i in range(num)]
|
|
|
|
|
|
|
|
def normalize(nda, channel = None):
    """Min-max rescale ``nda`` towards [0, 1].

    Args:
        nda: numpy array.
        channel: axis along which min/max are taken (kept as a singleton
            axis so they broadcast back); ``None`` uses the global min/max.

    Returns:
        ``(nda - min) / (max - min + 1e-7)``. The 1e-7 term avoids division
        by zero on constant input, which therefore maps to 0.
    """
    reduce_kwargs = {} if channel is None else {'axis': channel, 'keepdims': True}
    lowest = np.min(nda, **reduce_kwargs)
    highest = np.max(nda, **reduce_kwargs)
    return (nda - lowest) / (highest - lowest + 1e-7)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseVisualizer(object):
    """Shared helpers that turn volumes into PNG-ready image grids."""

    def __init__(self, gen_args, train_args, draw_border=False):
        # Names of the generator tasks whose flag in gen_args.task is truthy.
        self.tasks = [key for (key, value) in vars(gen_args.task).items() if value]

        self.args = train_args
        self.draw_border = draw_border
        # Per-axis subsampling step used by prepare_for_png.
        self.vis_spacing = self.args.visualizer.spacing

    def create_image_row(self, images):
        """Concatenate a stack of images side by side.

        ``images`` is indexed as (n, height, width, ...); the result has the
        n images placed next to each other along the width axis. When
        ``draw_border`` is set, the outermost rows and columns of every image
        are set to 1 on a copy, so the caller's array is left untouched.
        """
        if self.draw_border:
            images = np.copy(images)
            # Top/bottom rows, then left/right columns. A scalar broadcasts to
            # any trailing channel count, including images with no channel axis.
            images[:, [0, -1]] = 1
            images[:, :, [0, -1]] = 1
        return np.concatenate(list(images), axis=1)

    def create_image_grid(self, *args):
        """Stack one image row per argument vertically; each row is min-max normalized on its own."""
        rows = [normalize(self.create_image_row(arg)) for arg in args]
        return np.concatenate(rows, axis=0)

    def prepare_for_itk(self, array):
        """Return ``array`` flipped along axis 1."""
        return array[:, ::-1, :]

    @staticmethod
    def _minmax(view):
        """Rescale one view towards [0, 1]; the 1e-7 epsilon avoids 0/0 on constant views."""
        lowest = np.min(view)
        return (view - lowest) / (np.max(view) - lowest + 1e-7)

    def prepare_for_png(self, array, normalize = False):
        """Subsample a 4-D volume into slice, row and column views.

        ``array`` is indexed as (d0, d1, d2, channel). The returned tuple is
        ``(slc, row, col)``:

        - ``slc`` holds every ``vis_spacing[0]``-th plane along d0.
        - ``row`` holds every ``vis_spacing[1]``-th plane along d1, moved to the front.
        - ``col`` holds every ``vis_spacing[2]``-th plane along d2, moved to the front.

        ``row`` and ``col`` are also flipped along their second axis. With
        ``normalize`` set, each view is min-max rescaled independently.
        """
        slc = array[::self.vis_spacing[0]]
        row = array[:, ::self.vis_spacing[1]].transpose((1, 0, 2, 3))[:, ::-1]
        col = array[:, :, ::self.vis_spacing[2]].transpose((2, 0, 1, 3))[:, ::-1]

        if normalize:
            # Each view is rescaled from its own values.
            slc = self._minmax(slc)
            row = self._minmax(row)
            col = self._minmax(col)
        return slc, row, col
|
|
|
|
|
|
|
|
|
|
|
class FeatVisualizer(BaseVisualizer):
    """Visualize multi-level feature maps next to the input they were computed from."""

    def __init__(self, gen_args, train_args, draw_border=False):
        BaseVisualizer.__init__(self, gen_args, train_args, draw_border)
        # Number of feature channels shown per level (chosen with even_sample).
        self.feat_vis_num = train_args.visualizer.feat_vis_num

    def visualize_all_multi(self, subjects, multi_inputs, multi_outputs, out_dir):
        """
        For med-id student input samples: n_samples * [ (batch_size, channel_dim, *img_shp) ]
        For med-id student output features: n_samples * [ n_levels * (batch_size, channel_dim, *img_shp) ]

        Every output key containing 'feat' is rendered with
        visualize_all_multi_features, using that key as the file-name prefix.
        """
        names = [name.split('.nii')[0] for name in subjects['name']]
        multi_inputs = [x['input'] for x in multi_inputs]
        for k in multi_outputs[0].keys():
            if 'feat' in k:
                multi_features = [x[k] for x in multi_outputs]
                self.visualize_all_multi_features(names, multi_features, multi_inputs, out_dir, prefix = k)

    def _prepare_feat(self, feat, curr_input):
        """Select ``feat_vis_num`` channels of one feature map and align it with ``curr_input``.

        ``feat`` is a (channel, *img_shp) tensor. ``curr_input`` is the input
        after ``transpose((3, 2, 1, 0))`` and ``prepare_for_itk``, i.e. its
        spatial axes are in reversed order with channels last. The returned
        numpy array goes through the same transpose/flip.
        """
        sub_idx = even_sample(feat.shape[0], self.feat_vis_num)
        feat = torch.stack([feat[idx] for idx in sub_idx], dim = 0)
        # feat still has its spatial axes in img_shp order, while curr_input's
        # are reversed. Reverse the target size so each axis is resized
        # against its own extent; otherwise non-cubic volumes are
        # interpolated along the wrong axes.
        target_shape = list(curr_input.shape[:-1])[::-1]
        feat = match_shape(feat, target_shape)
        return self.prepare_for_itk(feat.data.cpu().numpy().transpose((3, 2, 1, 0)))

    def visualize_all_multi_features(self, names, multi_features, multi_inputs, out_dir, prefix = 'feat'):
        """Save one PNG per subject and feature level, with all samples side by side.

        ``multi_inputs[i_sample][i_name]`` is a (channel, *img_shp) tensor, and
        ``multi_features[i_sample][i_level][i_name]`` is a (channel, *img_shp)
        feature tensor. Files are written to
        ``out_dir/<name>/<name>_<prefix>_l<level>.png``.
        """
        n_samples = len(multi_inputs)
        n_levels = len(multi_features[0])

        for i_name, name in enumerate(names):
            all_sample_results = defaultdict(list)
            for i_sample in range(n_samples):
                curr_input = multi_inputs[i_sample][i_name].data.cpu().numpy()
                curr_input = self.prepare_for_itk(curr_input.transpose(3, 2, 1, 0))

                curr_level_feats = [
                    self._prepare_feat(multi_features[i_sample][l][i_name], curr_input)
                    for l in range(n_levels)
                ]
                all_results = self.gather(curr_input, curr_level_feats)

                for l, result in enumerate(all_results):
                    # Blank separator between samples, sized as
                    # width / (d0 / vis_spacing[0]) -- approximately one slice wide.
                    gap_width = int(result.shape[1] / (curr_input.shape[0] / self.vis_spacing[0]))
                    gap = np.zeros_like(result[:, :gap_width])
                    all_sample_results[l] += [result, gap]

            for l, level_results in all_sample_results.items():
                # Drop the trailing separator after the last sample.
                curr_level_all_sample_feats = np.concatenate(level_results[:-1], axis=1)
                save_dir = make_dir(os.path.join(out_dir, name))
                Image.fromarray(curr_level_all_sample_feats).save(os.path.join(save_dir, name + '_%s_l%s.png' % (prefix, str(l))))

    def visualize_all(self, names, inputs, features, out_dir = None, prefix = 'feat'):
        """
        For general (single-sample) inputs: (batch_size, channel_dim, *img_shp)
        For general (single-sample) output features: n_levels * (batch_size, channel_dim, *img_shp)

        Returns:
            dict mapping each name to its list of per-level uint8 grids. When
            ``out_dir`` is given, the grids are also saved to
            ``out_dir/<name>/<name>_<prefix>_l<level>.png``.
        """
        inputs = inputs.data.cpu().numpy()
        n_levels = len(features)

        all_results = {}
        for i_name, name in enumerate(names):
            curr_input = self.prepare_for_itk(inputs[i_name].transpose((3, 2, 1, 0)))
            curr_level_feats = [self._prepare_feat(features[l][i_name], curr_input) for l in range(n_levels)]
            results = self.gather(curr_input, curr_level_feats)
            all_results[name] = results

            if out_dir is not None:
                save_dir = make_dir(os.path.join(out_dir, name))
                for l, result in enumerate(results):
                    Image.fromarray(result).save(os.path.join(save_dir, name + '_%s_l%s.png' % (prefix, str(l))))
        return all_results

    def gather(self, input, feats):
        """Build one uint8 image grid per feature level.

        Each grid has these rows:

        - the subsampled input slices (first channel only);
        - a blank row;
        - one row per selected feature channel.

        Returns:
            list with one uint8 array per entry in ``feats``.
        """
        input_slc = self.prepare_for_png(input, normalize = False)[0][..., 0]

        all_images = []
        for feat in feats:
            # (d0, d1, d2, C) -> (C, n_slices, d1, d2).
            # NOTE(review): channel=1 normalizes along the slice axis, i.e. per
            # pixel across slices, not per channel -- confirm this is intended.
            slc_feat = normalize(feat[::self.vis_spacing[0]].transpose(3, 0, 1, 2), channel = 1)
            slc_images = pad_shape([input_slc, np.zeros_like(input_slc)] + list(slc_feat))

            slc_image = self.create_image_grid(*slc_images)
            all_images.append((255 * slc_image).astype(np.uint8))
        return all_images
|
|
|
|
|
|
|
|
|
|
|
class TaskVisualizer(BaseVisualizer):
    """Visualize inputs, targets and task outputs as slice/row/column image grids."""

    def __init__(self, gen_args, train_args, draw_border=False):
        BaseVisualizer.__init__(self, gen_args, train_args, draw_border)

    def _to_view(self, array):
        """Apply ``transpose((3, 2, 1, 0))`` then ``prepare_for_itk`` to a (channel, *img_shp) numpy array."""
        return self.prepare_for_itk(array.transpose((3, 2, 1, 0)))

    def _collect_targets(self, subjects, samples, target_names, i, n_samples):
        """Load the display volumes of subject ``i``'s targets.

        Returns:
            ``(curr_target, per_sample)``.

            - 'bias_field' / 'high_res' entries present in ``samples[0]`` map to
              a list with one volume per sample; their names are recorded in
              ``per_sample``.
            - Other names from ``target_names`` present in ``subjects`` map to a
              single volume.
        """
        curr_target, per_sample = {}, set()
        for key in ('bias_field', 'high_res'):
            if key in samples[0]:
                curr_target[key] = [self._to_view(samples[s][key][i].data.cpu().numpy()) for s in range(n_samples)]
                per_sample.add(key)

        for target_name in target_names:
            if target_name in subjects and target_name not in curr_target:
                try:
                    curr_target[target_name] = self._to_view(subjects[target_name][i].data.cpu().numpy())
                except Exception:
                    # Best-effort: skip entries that cannot be turned into a
                    # volume. Catching Exception (not a bare except) still lets
                    # KeyboardInterrupt/SystemExit through.
                    pass
        return curr_target, per_sample

    def visualize_all(self, subjects, samples, outputs, out_dir, output_names = ('image',), target_names = ('image',)):
        """Save ``<out_dir>/<name>_all_outputs.png`` for every subject in the batch.

        Each sample contributes one block containing, in order:

        - its input (plus 'input_flip' when present);
        - the requested targets;
        - the requested outputs.

        The blocks of all samples are placed side by side. Nothing is written
        when ``output_names`` is empty or none of its names occurs in ``outputs``.
        """
        if len(output_names) == 0:
            return

        n_samples = len(samples)
        names = [name.split('.nii')[0] for name in subjects['name']]
        has_flip = 'input_flip' in samples[0].keys()

        inputs = [x['input'].data.cpu().numpy() for x in samples]
        if has_flip:
            inputs_flip = [x['input_flip'].data.cpu().numpy() for x in samples]

        out_images = {}
        for output_name in output_names:
            if output_name in outputs[0].keys():
                out_images[output_name] = [x[output_name].data.cpu().numpy() for x in outputs]
        if not out_images:
            # None of the requested outputs exist, so there is nothing to render.
            return

        for i, name in enumerate(names):
            curr_inputs = [self._to_view(inputs[s][i]) for s in range(n_samples)]
            if has_flip:
                curr_inputs_flip = [self._to_view(inputs_flip[s][i]) for s in range(n_samples)]

            curr_target, per_sample = self._collect_targets(subjects, samples, target_names, i, n_samples)
            curr_outputs = {
                key: [self._to_view(out_images[key][s][i]) for s in range(n_samples)]
                for key in out_images
            }

            all_images = []
            for i_sample, curr_input in enumerate(curr_inputs):
                target_list = [curr_input]
                if has_flip:
                    target_list.append(curr_inputs_flip[i_sample])
                for target_name in target_names:
                    if target_name in curr_target:
                        target = curr_target[target_name]
                        # Only targets loaded from `samples` are per-sample lists.
                        target_list.append(target[i_sample] if target_name in per_sample else target)

                output_list = [curr_outputs[o][i_sample] for o in output_names if o in curr_outputs]
                all_images.append(self.gather(target_list, output_list))

            all_images = np.concatenate(all_images, axis=1).astype(np.uint8)
            Image.fromarray(all_images).save(os.path.join(out_dir, name + '_all_outputs.png'))

    def _compose(self, volumes):
        """Lay out slice, row and column views of ``volumes`` as one uint8 grid (first channel).

        The grid has one row of slice views, a blank row, one row of row views,
        a blank row, and one row of column views. Every image is zero-padded
        to a common shape first.

        Raises:
            ValueError: if ``volumes`` is empty.
        """
        if len(volumes) == 0:
            raise ValueError('nothing to visualize')
        views = [self.prepare_for_png(volume, normalize = False) for volume in volumes]
        slc_images = [view[0] for view in views]
        row_images = [view[1] for view in views]
        col_images = [view[2] for view in views]

        gap = [np.zeros_like(slc_images[0])]
        all_images = pad_shape(slc_images + gap + row_images + gap + col_images)
        all_image = self.create_image_grid(*all_images)
        return (255 * all_image).astype(np.uint8)[:, :, 0]

    def visualize_sample(self, name, input, out_dir, postfix = '_input'):
        """Save ``<out_dir>/<name>_all<postfix>.png`` with the views of every sample in ``input``."""
        all_image = self._compose([input[i_sample] for i_sample in range(len(input))])
        Image.fromarray(all_image).save(os.path.join(out_dir, name + '_all' + postfix + '.png'))
        return

    def gather(self, target_list = None, output_list = None):
        """Compose the targets followed by the outputs into one uint8 grid (first channel only)."""
        targets = [] if target_list is None else list(target_list)
        outputs = [] if output_list is None else list(output_list)
        return self._compose(targets + outputs)
|
|
|