""" Visualization modules """ import os import numpy as np from math import ceil import torch import torch.nn.functional as F from PIL import Image from collections import defaultdict from utils.misc import make_dir def match_shape(array, shape): # array: (channel_dim, *orig_shape) array = array[None] if list(array.shape[2:]) != list(shape): array = F.interpolate(array, size=shape) return array[0] def pad_shape(array_list): max_shape = [0] * len(array_list[0].shape) for array in array_list: max_shape = [max(max_shape[dim], array.shape[dim]) for dim in range(len(max_shape))] pad_array_list = [] for array in array_list: start = [(max_shape[dim] - array.shape[dim]) // 2 for dim in range(len(max_shape))] if len(start) == 2: pad_array = np.zeros((max_shape[0], max_shape[1])) pad_array[start[0] : start[0] + array.shape[0], start[1] : start[1] + array.shape[1]] = array elif len(start) == 3: pad_array = np.zeros((max_shape[0], max_shape[1], max_shape[2])) pad_array[start[0] : start[0] + array.shape[0], start[1] : start[1] + array.shape[1], start[2] : start[2] + array.shape[2]] = array elif len(start) == 4: pad_array = np.zeros((max_shape[0], max_shape[1], max_shape[2], max_shape[3])) pad_array[start[0] : start[0] + array.shape[0], start[1] : start[1] + array.shape[1], start[2] : start[2] + array.shape[2], start[3] : start[3] + array.shape[3]] = array pad_array_list.append(pad_array) return pad_array_list def even_sample(orig_len, num): idx = [] length = float(orig_len) for i in range(num): idx.append(int(ceil(i * length / num))) return idx def normalize(nda, channel = None): if channel is not None: nda_max = np.max(nda, axis = channel, keepdims = True) nda_min = np.min(nda, axis = channel, keepdims = True) else: nda_max = np.max(nda) nda_min = np.min(nda) return (nda - nda_min) / (nda_max - nda_min + 1e-7) ############################################## class BaseVisualizer(object): def __init__(self, gen_args, train_args, draw_border=False): self.tasks = [key for (key, value) in vars(gen_args.task).items() if value] self.args = train_args self.draw_border = draw_border self.vis_spacing = self.args.visualizer.spacing def create_image_row(self, images): if self.draw_border: images = np.copy(images) images[:, :, [0, -1]] = (1, 1, 1) images[:, :, [0, -1]] = (1, 1, 1) return np.concatenate(list(images), axis=1) def create_image_grid(self, *args): out = [] for arg in args: out.append(normalize(self.create_image_row(arg))) return np.concatenate(out, axis=0) def prepare_for_itk(self, array): # (s, r, c, *) return array[:, ::-1, :] def prepare_for_png(self, array, normalize = False): # (s, r, c, *) slc = array[::self.vis_spacing[0]] # (s', r, c *) row = array[:, ::self.vis_spacing[1]].transpose((1, 0, 2, 3))[:, ::-1] # (s, r', c, *) -> (r', s, c, *) col = array[:, :, ::self.vis_spacing[2]].transpose((2, 0, 1, 3))[:, ::-1] # (s, r, c', *) -> (c', s, r, *) if normalize: slc = (slc - np.min(slc)) / (np.max(slc) - np.min(slc)) row = (slc - np.min(slc)) / (np.max(slc) - np.min(row)) col = (slc - np.min(slc)) / (np.max(slc) - np.min(col)) return slc, row, col class FeatVisualizer(BaseVisualizer): def __init__(self, gen_args, train_args, draw_border=False): BaseVisualizer.__init__(self, gen_args, train_args, draw_border) self.feat_vis_num = train_args.visualizer.feat_vis_num def visualize_all_multi(self, subjects, multi_inputs, multi_outputs, out_dir): """ For med-id student input samples: n_samples * [ (batch_size, channel_dim, *img_shp) ] For med-id student output features: n_samples * [ n_levels * (batch_size, channel_dim, *img_shp) ] """ names = [name.split('.nii')[0] for name in subjects['name']] multi_inputs = [x['input'] for x in multi_inputs] # n_samples * (b, d, s, r, c) for k in multi_outputs[0].keys(): if 'feat' in k: multi_features = [x[k] for x in multi_outputs] self.visualize_all_multi_features(names , multi_features, multi_inputs, out_dir, prefix = k) def visualize_all_multi_features(self, names, multi_features, multi_inputs, out_dir, prefix = 'feat'): n_samples = len(multi_inputs) n_levels = len(multi_features[0]) multi_inputs_reorg = [] # batch_size * [ n_samples * (channel_dim, *img_shp) ] multi_features_reorg = [] # batch_size * [ n_samples * [ n_levels * (channel_dim, *img_shp) ] ] for i_name, _ in enumerate(names): multi_features_reorg.append([[multi_features[i_sample][i_level][i_name] for i_level in range(n_levels)] for i_sample in range(n_samples)]) multi_inputs_reorg.append([multi_inputs[i_sample][i_name] for i_sample in range(n_samples)]) for i_name, name in enumerate(names): inputs = multi_inputs_reorg[i_name] features = multi_features_reorg[i_name] all_sample_results = defaultdict(list) for i_sample in range(n_samples): curr_input = inputs[i_sample].data.cpu().numpy() # ( d=1, s, r, c) curr_input = self.prepare_for_itk(curr_input.transpose(3, 2, 1, 0)) # (d, x, y, z) -> (z, y, x, d) curr_feat = features[i_sample] # n_levels * (channel_dim, s, r, c) curr_level_feats = [] for l in range(n_levels): curr_level_feat = curr_feat[l] # (channel_dim, s, r, c) sub_idx = even_sample(curr_level_feat.shape[0], self.feat_vis_num) curr_level_feat = torch.stack([curr_level_feat[idx] for idx in sub_idx], dim = 0) # (sub_channel_dim, s, r, c) curr_level_feat = match_shape(curr_level_feat, list(curr_input.shape[:-1])) curr_level_feats.append(self.prepare_for_itk((curr_level_feat.data.cpu().numpy().transpose((3, 2, 1, 0))))) all_results = self.gather(curr_input, curr_level_feats) for l, result in enumerate(all_results): # n_level * (r, c) gap = np.zeros_like(result[:, :int( result.shape[1] / (curr_input.shape[0] / self.vis_spacing[0]) )]) all_sample_results[l] += [result] + [gap] for l in all_sample_results.keys(): curr_level_all_sample_feats = np.concatenate(list(all_sample_results[l][:-1]), axis=1) # (s, n_samples * c) Image.fromarray(curr_level_all_sample_feats).save(os.path.join(make_dir(os.path.join(out_dir, name)), name + '_%s_l%s.png' % (prefix, str(l)))) def visualize_all(self, names, inputs, features): """ For general (single-sample) inputs: (batch_size, channel_dim, *img_shp) For general (single-sample) output features: n_levels * (batch_size, channel_dim, *img_shp) """ inputs = inputs.data.cpu().numpy() # (b, d=1, s, r, c) n_levels = len(features) # n_levels * (b, channel_dim, s, r, c) for i_name, name in enumerate(names): curr_input = self.prepare_for_itk(inputs[i_name].transpose((3, 2, 1, 0))) # (d, x, y, z) -> (z, y, x, d) curr_level_feats = [] for l in range(n_levels): curr_feat = features[l][i_name] # (channel_dim, s, r, c) sub_idx = even_sample(curr_feat.shape[0], self.feat_vis_num) curr_feat = torch.stack([curr_feat[idx] for idx in sub_idx], dim = 0) # (sub_channel_dim, s, r, c) curr_feat = match_shape(curr_feat, list(curr_input.shape[:-1])) curr_level_feats.append(self.prepare_for_itk((curr_feat.data.cpu().numpy().transpose((3, 2, 1, 0))))) self.gather(curr_input, curr_level_feats) def gather(self, input, feats): input_slc = self.prepare_for_png(input, normalize = False)[0][..., 0] # (sub_s, r, c) all_images = [] for l, feat in enumerate(feats): slc_images = [input_slc] # only plot along axial slc_feat = normalize(feat[::self.vis_spacing[0]].transpose(3, 0, 1, 2), channel = 1) # (sub_s, r, c, sub_channel_dim) -> (sub_channel_dim, sub_s, r, c) slc_images = [input_slc, np.zeros_like(input_slc)] + list(slc_feat) # (1 + 1 + s', r, c *) slc_images = pad_shape(slc_images) slc_image = self.create_image_grid(*slc_images) slc_image = (255 * slc_image).astype(np.uint8) all_images.append(slc_image) return all_images class TaskVisualizer(BaseVisualizer): def __init__(self, gen_args, train_args, draw_border=False): BaseVisualizer.__init__(self, gen_args, train_args, draw_border) def visualize_all(self, subjects, samples, outputs, out_dir, output_names = ['image'], target_names = ['image']): if len(output_names) == 0: return n_samples = len(samples) names = [name.split('.nii')[0] for name in subjects['name']] inputs = [x['input'].data.cpu().numpy() for x in samples] # n_samples * (b, d, s, r, c) if 'input_flip' in samples[0].keys(): inputs_flip = [x['input_flip'].data.cpu().numpy() for x in samples] # n_samples * (b, d, s, r, c) out_images = {} for output_name in output_names: if output_name in outputs[0].keys(): out_images[output_name] = [x[output_name].data.cpu().numpy() for x in outputs] # n_samples * (b, d, s, r, c) for i, name in enumerate(names): #case_out_dir = make_dir(os.path.join(out_dir, name)) curr_inputs = [self.prepare_for_itk(inputs[i_sample][i].transpose((3, 2, 1, 0))) for i_sample in range(n_samples)] # n_samples * (d, x, y, z) -> n_samples (z, y, x, d) if 'input_flip' in samples[0].keys(): curr_inputs_flip = [self.prepare_for_itk(inputs_flip[i_sample][i].transpose((3, 2, 1, 0))) for i_sample in range(n_samples)] # n_samples * (d, x, y, z) -> n_samples (z, y, x, d) # Plot all inputs #self.visualize_sample(name, curr_inputs, out_dir, postfix = '_input') if len(out_images) > 0: curr_target = {} if 'bias_field' in samples[0]: curr_target['bias_field'] = [self.prepare_for_itk(samples[i_sample]['bias_field'][i].data.cpu().numpy().transpose((3, 2, 1, 0))) for i_sample in range(n_samples)] if 'high_res' in samples[0]: curr_target['high_res'] = [self.prepare_for_itk(samples[i_sample]['high_res'][i].data.cpu().numpy().transpose((3, 2, 1, 0))) for i_sample in range(n_samples)] for target_name in target_names: if target_name in subjects and target_name not in curr_target.keys(): try: curr_target[target_name] = self.prepare_for_itk(subjects[target_name][i].data.cpu().numpy().transpose((3, 2, 1, 0))) # (d=1, s, r, c) -> (z, y, x, d) except: pass #print(target_name, 'failed in visualization') curr_outputs = {} for output_name in output_names: if output_name in outputs[0].keys(): #print('output name', output_name) curr_outputs[output_name] = [self.prepare_for_itk(out_images[output_name][i_sample][i].transpose((3, 2, 1, 0))) for i_sample in range(n_samples)] # n_samples * (d, x, y, z) -> n_samples (z, y, x, d) all_images = [] for i_sample, curr_input in enumerate(curr_inputs): target_list = [curr_input] if 'input_flip' in samples[0].keys(): target_list.append(curr_inputs_flip[i_sample]) for target_name in target_names: if target_name in curr_target: #print('target name', target_name) if 'bias_field' in target_name or 'high_res' in target_name: target_list.append(curr_target[target_name][i_sample]) else: target_list.append(curr_target[target_name]) output_list = [] for ouput_name in output_names: if ouput_name in curr_outputs.keys(): output_list.append(curr_outputs[ouput_name][i_sample]) all_image = self.gather(target_list, output_list) # (row, col) all_images.append(all_image) # n_sample * (row, col) all_images = np.concatenate(all_images, axis=1).astype(np.uint8) # (row, n_sample * col) Image.fromarray(all_images).save(os.path.join(out_dir, name + '_all_outputs.png')) def visualize_sample(self, name, input, out_dir, postfix = '_input'): n_samples = len(input) slc_images, row_images, col_images = [], [], [] for i_sample in range(n_samples): input_slc, input_row, input_col = self.prepare_for_png(input[i_sample], normalize = False) slc_images.append(input_slc) row_images.append(input_row) col_images.append(input_col) # add row gap gap = [np.zeros_like(slc_images[0])] all_images = slc_images + gap + row_images + gap + col_images all_images = pad_shape(all_images) all_image = self.create_image_grid(*all_images) all_image = (255 * all_image).astype(np.uint8) Image.fromarray(all_image[:, :, 0]).save(os.path.join(out_dir, name + '_all' + postfix + '.png')) # grey scale image last channel == 1 return def gather(self, target_list = [], output_list = []): slc_images, row_images, col_images = [], [], [] for add_target in target_list: add_target_slc, add_target_row, add_target_col = self.prepare_for_png(add_target, normalize = False) slc_images += [add_target_slc] row_images += [add_target_row] col_images += [add_target_col] for add_output in output_list: add_output_slc, add_output_row, add_output_col = self.prepare_for_png(add_output, normalize = False) slc_images += [add_output_slc] row_images += [add_output_row] col_images += [add_output_col] # add row gap gap = [np.zeros_like(add_target_slc)] all_images = slc_images + gap + row_images + gap + col_images all_images = pad_shape(all_images) all_image = self.create_image_grid(*all_images) all_image = (255 * all_image).astype(np.uint8) return all_image[:, :, 0] # shrink last channel dimension (d=1)