File size: 15,446 Bytes
2571f24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334

"""
Visualization modules
"""
import os
import numpy as np  
from math import ceil
import torch
import torch.nn.functional as F
from PIL import Image
from collections import defaultdict

from utils.misc import make_dir


def match_shape(array, shape):
    """Resize a feature volume to a target spatial shape.

    Args:
        array: tensor of shape (channel_dim, *spatial_shape).
        shape: desired spatial shape (list/tuple of ints).

    Returns:
        Tensor of shape (channel_dim, *shape); returned unchanged when the
        spatial dimensions already match (no interpolation performed).
    """
    target = list(shape)
    batched = array.unsqueeze(0)  # add a batch axis for F.interpolate
    if list(batched.shape[2:]) != target:
        batched = F.interpolate(batched, size=target)
    return batched.squeeze(0)

def pad_shape(array_list):
    """Zero-pad every array to the elementwise-maximum shape, centered.

    Generalizes the original implementation, which special-cased 2-D/3-D/4-D
    arrays and raised UnboundLocalError for any other rank; this version
    handles any rank shared by all arrays in the list.

    Args:
        array_list: non-empty list of numpy arrays, all with the same ndim.

    Returns:
        List of float64 arrays (np.zeros default dtype, as before), each of
        the common maximum shape, with the original content centered.
    """
    ndim = len(array_list[0].shape)
    # Elementwise maximum extent across all arrays.
    max_shape = [max(a.shape[dim] for a in array_list) for dim in range(ndim)]

    pad_array_list = []
    for array in array_list:
        # Offset that centers this array inside the padded canvas.
        start = [(max_shape[dim] - array.shape[dim]) // 2 for dim in range(ndim)]
        pad_array = np.zeros(max_shape)
        region = tuple(slice(s, s + n) for s, n in zip(start, array.shape))
        pad_array[region] = array
        pad_array_list.append(pad_array)
    return pad_array_list


def even_sample(orig_len, num):
    """Return `num` indices spread evenly over range(orig_len).

    Index i maps to ceil(i * orig_len / num), so the first index is always 0
    and indices stay below orig_len whenever num <= orig_len.

    Args:
        orig_len: length of the sequence being sampled.
        num: number of indices to produce.

    Returns:
        List of `num` ints (idiomatic comprehension replaces the original
        manual append loop; results are identical).
    """
    length = float(orig_len)
    return [int(ceil(i * length / num)) for i in range(num)]


def normalize(nda, channel = None):
    """Linearly rescale an array toward the [0, 1] range.

    Args:
        nda: numpy array to rescale.
        channel: axis to reduce min/max over (with keepdims) when given;
            when None, global min/max of the whole array are used.

    Returns:
        (nda - min) / (max - min + 1e-7); the epsilon guards against
        division by zero on constant input.
    """
    if channel is None:
        lo = np.min(nda)
        hi = np.max(nda)
    else:
        lo = np.min(nda, axis = channel, keepdims = True)
        hi = np.max(nda, axis = channel, keepdims = True)
    return (nda - lo) / (hi - lo + 1e-7)


##############################################


class BaseVisualizer(object):
    """Base class for saving 3-D volumes as tiled 2-D PNG previews.

    Provides reorientation (`prepare_for_itk`), orthogonal-view slicing
    (`prepare_for_png`), and tiling helpers (`create_image_row`,
    `create_image_grid`) shared by the concrete visualizers.
    """

    def __init__(self, gen_args, train_args, draw_border=False):
        """
        Args:
            gen_args: config whose `task` attribute holds boolean task flags.
            train_args: config providing `visualizer.spacing` (per-axis slice
                sampling stride).
            draw_border: paint white borders between tiles in image rows.
        """
        # Names of the enabled tasks (truthy flags only).
        self.tasks = [key for (key, value) in vars(gen_args.task).items() if value]

        self.args = train_args
        self.draw_border = draw_border
        self.vis_spacing = self.args.visualizer.spacing

    def create_image_row(self, images):
        """Concatenate a stack of tiles (n, rows, cols, *) into one wide image."""
        if self.draw_border:
            # Whiten the first/last column of every tile as a separator.
            # (The original performed this identical assignment twice; once
            # is sufficient — behavior unchanged.)
            images = np.copy(images)
            images[:, :, [0, -1]] = (1, 1, 1)
        return np.concatenate(list(images), axis=1)

    def create_image_grid(self, *args):
        """Stack several tile rows vertically; each row is normalized to [0, 1] independently."""
        out = []
        for arg in args:
            out.append(normalize(self.create_image_row(arg)))
        return np.concatenate(out, axis=0)

    def prepare_for_itk(self, array): # (s, r, c, *)
        # Flip the row axis — presumably to match ITK orientation; confirm
        # against the loading pipeline.
        return array[:, ::-1, :]

    def prepare_for_png(self, array, normalize = False): # (s, r, c, *)
        """Extract the three orthogonal slice stacks, subsampled by vis_spacing.

        Returns:
            (slc, row, col): axial (s', r, c, *), coronal (r', s, c, *) and
            sagittal (c', s, r, *) stacks; the latter two are flipped along
            their second axis for display.
        """
        slc = array[::self.vis_spacing[0]] # (s', r, c, *)
        row = array[:, ::self.vis_spacing[1]].transpose((1, 0, 2, 3))[:, ::-1] # (s, r', c, *) -> (r', s, c, *)
        col = array[:, :, ::self.vis_spacing[2]].transpose((2, 0, 1, 3))[:, ::-1] # (s, r, c', *) -> (c', s, r, *)

        if normalize:
            # BUGFIX: row/col were previously rescaled with slc's values and a
            # mix of slc/row/col statistics (copy-paste error); each view now
            # uses its own min/max.
            slc = (slc - np.min(slc)) / (np.max(slc) - np.min(slc))
            row = (row - np.min(row)) / (np.max(row) - np.min(row))
            col = (col - np.min(col)) / (np.max(col) - np.min(col))
        return slc, row, col



class FeatVisualizer(BaseVisualizer):
    """Saves PNG montages of intermediate network feature maps next to the input volume."""
    
    def __init__(self, gen_args, train_args, draw_border=False):
        BaseVisualizer.__init__(self, gen_args, train_args, draw_border)
        # Number of feature channels visualized per level (evenly sub-sampled).
        self.feat_vis_num = train_args.visualizer.feat_vis_num

    def visualize_all_multi(self, subjects, multi_inputs, multi_outputs, out_dir):
        """
        For med-id student input samples: n_samples * [ (batch_size, channel_dim, *img_shp) ]
        For med-id student output features: n_samples * [ n_levels * (batch_size, channel_dim, *img_shp) ]
        """

        # Strip '.nii' (and anything after it) from the subject file names.
        names = [name.split('.nii')[0] for name in subjects['name']] 
        multi_inputs = [x['input'] for x in multi_inputs] # n_samples * (b, d, s, r, c)
        # Visualize every output entry whose key mentions 'feat'.
        for k in multi_outputs[0].keys():
            if 'feat' in k:
                multi_features = [x[k] for x in multi_outputs]
                self.visualize_all_multi_features(names , multi_features, multi_inputs, out_dir, prefix = k)
    
    def visualize_all_multi_features(self, names, multi_features, multi_inputs, out_dir, prefix = 'feat'):
        """Save one PNG per subject per feature level, tiling all samples side by side.

        Args:
            names: batch_size subject names (used in the output file names).
            multi_features: n_samples * [ n_levels * (batch_size, channel_dim, *img_shp) ].
            multi_inputs: n_samples * (batch_size, d, s, r, c) input tensors.
            out_dir: root output directory; a per-subject sub-directory is created.
            prefix: key name embedded in the saved file name.
        """
        n_samples = len(multi_inputs)
        n_levels = len(multi_features[0])
        
        # Re-group from per-sample lists into per-subject lists.
        multi_inputs_reorg = [] # batch_size * [ n_samples * (channel_dim, *img_shp) ]
        multi_features_reorg = [] # batch_size * [ n_samples * [ n_levels * (channel_dim, *img_shp) ] ]
        for i_name, _ in enumerate(names):
            multi_features_reorg.append([[multi_features[i_sample][i_level][i_name] for i_level in range(n_levels)] for i_sample in range(n_samples)]) 
            multi_inputs_reorg.append([multi_inputs[i_sample][i_name] for i_sample in range(n_samples)])

        for i_name, name in enumerate(names): 

            inputs = multi_inputs_reorg[i_name]
            features = multi_features_reorg[i_name]

            # level index -> list of (image, gap, image, gap, ...) across samples.
            all_sample_results = defaultdict(list)
            for i_sample in range(n_samples):

                curr_input = inputs[i_sample].data.cpu().numpy() # ( d=1, s, r, c)  
                curr_input = self.prepare_for_itk(curr_input.transpose(3, 2, 1, 0)) # (d, x, y, z) -> (z, y, x, d)

                curr_feat = features[i_sample] # n_levels * (channel_dim, s, r, c)
                curr_level_feats = []

                for l in range(n_levels):
                    curr_level_feat = curr_feat[l] # (channel_dim, s, r, c)

                    # Keep only feat_vis_num evenly-spaced channels.
                    sub_idx = even_sample(curr_level_feat.shape[0], self.feat_vis_num)
                    curr_level_feat = torch.stack([curr_level_feat[idx] for idx in sub_idx], dim = 0) # (sub_channel_dim, s, r, c)
    
                    # Resize features to the input's spatial shape so tiles line up.
                    curr_level_feat = match_shape(curr_level_feat, list(curr_input.shape[:-1]))
                    curr_level_feats.append(self.prepare_for_itk((curr_level_feat.data.cpu().numpy().transpose((3, 2, 1, 0))))) 
                
                all_results = self.gather(curr_input, curr_level_feats) 
                
                for l, result in enumerate(all_results): # n_level * (r, c)
                    # Blank vertical strip separating samples; its width scales
                    # with the number of axial slices shown per tile row.
                    gap = np.zeros_like(result[:, :int( result.shape[1] / (curr_input.shape[0] / self.vis_spacing[0]) )]) 
                    all_sample_results[l] += [result] + [gap] 

            for l in all_sample_results.keys():
                # [:-1] drops the trailing gap before horizontal concatenation.
                curr_level_all_sample_feats = np.concatenate(list(all_sample_results[l][:-1]), axis=1) # (s, n_samples * c)
                Image.fromarray(curr_level_all_sample_feats).save(os.path.join(make_dir(os.path.join(out_dir, name)), name + '_%s_l%s.png' % (prefix, str(l))))


    def visualize_all(self, names, inputs, features): 
        """
        For general (single-sample) inputs: (batch_size, channel_dim, *img_shp)
        For general (single-sample) output features: n_levels * (batch_size, channel_dim, *img_shp)

        NOTE(review): the images produced by gather() are neither saved nor
        returned here — confirm whether discarding them is intentional.
        """
 
        inputs = inputs.data.cpu().numpy() # (b, d=1, s, r, c)
        n_levels = len(features) # n_levels * (b, channel_dim, s, r, c)
        
        for i_name, name in enumerate(names): 
            curr_input = self.prepare_for_itk(inputs[i_name].transpose((3, 2, 1, 0))) # (d, x, y, z) -> (z, y, x, d) 
            curr_level_feats = []
            for l in range(n_levels):
                curr_feat = features[l][i_name] # (channel_dim, s, r, c)

                # Keep only feat_vis_num evenly-spaced channels.
                sub_idx = even_sample(curr_feat.shape[0], self.feat_vis_num)
                curr_feat = torch.stack([curr_feat[idx] for idx in sub_idx], dim = 0) # (sub_channel_dim, s, r, c)
 
                curr_feat = match_shape(curr_feat, list(curr_input.shape[:-1]))
                curr_level_feats.append(self.prepare_for_itk((curr_feat.data.cpu().numpy().transpose((3, 2, 1, 0))))) 
            
            self.gather(curr_input, curr_level_feats) 
        

    def gather(self, input, feats):
        """Tile the input's axial slices beside each level's feature-channel slices.

        Args:
            input: (s, r, c, d=1) ITK-oriented input volume.
            feats: n_levels * (s, r, c, sub_channel_dim) resized feature volumes.

        Returns:
            n_levels list of (rows, cols) uint8 grid images (axial view only).
        """
        input_slc = self.prepare_for_png(input, normalize = False)[0][..., 0] # (sub_s, r, c)
        all_images = []
        for l, feat in enumerate(feats):
            slc_images = [input_slc] # only plot along axial (dead store: overwritten two lines below)
            slc_feat = normalize(feat[::self.vis_spacing[0]].transpose(3, 0, 1, 2), channel = 1) # (sub_s, r, c, sub_channel_dim) -> (sub_channel_dim, sub_s, r, c)
            # Input stack, a blank separator, then one row per feature channel.
            slc_images = [input_slc, np.zeros_like(input_slc)] + list(slc_feat) # (1 + 1 + s', r, c *)
            slc_images = pad_shape(slc_images)

            slc_image = self.create_image_grid(*slc_images)
            slc_image = (255 * slc_image).astype(np.uint8)  
            all_images.append(slc_image)
            
        return all_images



class TaskVisualizer(BaseVisualizer):
    """Saves side-by-side PNG comparisons of inputs, targets, and model outputs.

    One PNG per subject is written to `out_dir`, tiling the axial / coronal /
    sagittal views of every requested target and output volume across samples.
    """

    def __init__(self, gen_args, train_args, draw_border=False):
        BaseVisualizer.__init__(self, gen_args, train_args, draw_border)

    def visualize_all(self, subjects, samples, outputs, out_dir, output_names = ['image'], target_names = ['image']):
        """Render one comparison PNG per subject.

        Args:
            subjects: dict with 'name' (list of file names) and, optionally,
                target tensors keyed by entries of `target_names`.
            samples: n_samples list of dicts with 'input' tensors
                (b, d, s, r, c); may also carry 'input_flip', 'bias_field',
                'high_res'.
            outputs: n_samples list of dicts of output tensors keyed by
                entries of `output_names`.
            out_dir: destination directory (assumed to exist).
            output_names / target_names: which outputs/targets to plot.
                (Defaults are never mutated, so the shared-list pitfall does
                not apply here.)
        """
        if len(output_names) == 0:
            return

        n_samples = len(samples)

        # Strip '.nii' (and anything after it) from the subject file names.
        names = [name.split('.nii')[0] for name in subjects['name']]

        inputs = [x['input'].data.cpu().numpy() for x in samples] # n_samples * (b, d, s, r, c)
        if 'input_flip' in samples[0].keys():
            inputs_flip = [x['input_flip'].data.cpu().numpy() for x in samples] # n_samples * (b, d, s, r, c)

        # Keep only the outputs that were actually produced.
        out_images = {}
        for output_name in output_names:
            if output_name in outputs[0].keys():
                out_images[output_name] = [x[output_name].data.cpu().numpy() for x in outputs] # n_samples * (b, d, s, r, c)

        for i, name in enumerate(names):
            # n_samples * (d, x, y, z) -> n_samples * (z, y, x, d)
            curr_inputs = [self.prepare_for_itk(inputs[i_sample][i].transpose((3, 2, 1, 0))) for i_sample in range(n_samples)]
            if 'input_flip' in samples[0].keys():
                curr_inputs_flip = [self.prepare_for_itk(inputs_flip[i_sample][i].transpose((3, 2, 1, 0))) for i_sample in range(n_samples)]

            if len(out_images) > 0:
                # Targets: bias_field / high_res vary per sample; the rest are
                # shared across samples and stored once.
                curr_target = {}
                if 'bias_field' in samples[0]:
                    curr_target['bias_field'] = [self.prepare_for_itk(samples[i_sample]['bias_field'][i].data.cpu().numpy().transpose((3, 2, 1, 0))) for i_sample in range(n_samples)]
                if 'high_res' in samples[0]:
                    curr_target['high_res'] = [self.prepare_for_itk(samples[i_sample]['high_res'][i].data.cpu().numpy().transpose((3, 2, 1, 0))) for i_sample in range(n_samples)]

                for target_name in target_names:
                    if target_name in subjects and target_name not in curr_target.keys():
                        try:
                            curr_target[target_name] = self.prepare_for_itk(subjects[target_name][i].data.cpu().numpy().transpose((3, 2, 1, 0))) # (d=1, s, r, c) -> (z, y, x, d)
                        except Exception:
                            # Best-effort: skip targets that cannot be converted
                            # (narrowed from a bare `except:` which also
                            # swallowed KeyboardInterrupt/SystemExit).
                            pass

                curr_outputs = {}
                for output_name in output_names:
                    if output_name in outputs[0].keys():
                        curr_outputs[output_name] = [self.prepare_for_itk(out_images[output_name][i_sample][i].transpose((3, 2, 1, 0))) for i_sample in range(n_samples)] # n_samples * (z, y, x, d)

                all_images = []

                for i_sample, curr_input in enumerate(curr_inputs):
                    target_list = [curr_input]
                    if 'input_flip' in samples[0].keys():
                        target_list.append(curr_inputs_flip[i_sample])
                    for target_name in target_names:
                        if target_name in curr_target:
                            # bias_field / high_res are stored per sample.
                            if 'bias_field' in target_name or 'high_res' in target_name:
                                target_list.append(curr_target[target_name][i_sample])
                            else:
                                target_list.append(curr_target[target_name])

                    output_list = []
                    for output_name in output_names: # fixed local typo: was `ouput_name`
                        if output_name in curr_outputs.keys():
                            output_list.append(curr_outputs[output_name][i_sample])

                    all_image = self.gather(target_list, output_list) # (row, col)
                    all_images.append(all_image) # n_sample * (row, col)
                all_images = np.concatenate(all_images, axis=1).astype(np.uint8) # (row, n_sample * col)
                Image.fromarray(all_images).save(os.path.join(out_dir, name + '_all_outputs.png'))

    def visualize_sample(self, name, input, out_dir, postfix = '_input'):
        """Save a PNG tiling the three orthogonal views of every sample volume in `input`."""
        n_samples = len(input)

        slc_images, row_images, col_images = [], [], []
        for i_sample in range(n_samples):
            input_slc, input_row, input_col = self.prepare_for_png(input[i_sample], normalize = False)

            slc_images.append(input_slc)
            row_images.append(input_row)
            col_images.append(input_col)

        # Blank rows used as visual separators between the three views.
        gap = [np.zeros_like(slc_images[0])]
        all_images = slc_images + gap + row_images + gap + col_images
        all_images = pad_shape(all_images)
        all_image = self.create_image_grid(*all_images)
        all_image = (255 * all_image).astype(np.uint8)
        Image.fromarray(all_image[:, :, 0]).save(os.path.join(out_dir, name + '_all' + postfix + '.png')) # grey scale image last channel == 1
        return

    def gather(self, target_list = None, output_list = None):
        """Tile target and output volumes into a single (row, col) uint8 image.

        Args:
            target_list: volumes (s, r, c, d) shown first in each view row.
            output_list: volumes appended after the targets.

        Returns:
            2-D uint8 array (last channel dimension, d=1, dropped).
        """
        # BUGFIX: mutable-list defaults replaced with None; also, the gap
        # separator previously referenced the loop variable `add_target_slc`,
        # raising NameError when target_list was empty.
        target_list = [] if target_list is None else target_list
        output_list = [] if output_list is None else output_list

        slc_images, row_images, col_images = [], [], []
        for volume in list(target_list) + list(output_list):
            v_slc, v_row, v_col = self.prepare_for_png(volume, normalize = False)
            slc_images += [v_slc]
            row_images += [v_row]
            col_images += [v_col]

        # Separator shaped like the last target's axial stack (matching the
        # original behavior); falls back to the first entry when no targets
        # were given.
        reference = slc_images[len(target_list) - 1] if target_list else slc_images[0]
        gap = [np.zeros_like(reference)]
        all_images = slc_images + gap + row_images + gap + col_images
        all_images = pad_shape(all_images)
        all_image = self.create_image_grid(*all_images)

        all_image = (255 * all_image).astype(np.uint8)
        return all_image[:, :, 0] # shrink last channel dimension (d=1)