File size: 14,170 Bytes
19c1f58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import SimpleITK
import numpy as np
from typing import Optional
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.util.arraycrop import crop
from scipy.signal import fftconvolve
from scipy.ndimage import uniform_filter
import torch

class ImageMetrics():
    def __init__(self, debug=False):
        # Use fixed wide dynamic range
        self.dynamic_range = [-1024., 3000.]
        self.debug = debug
    def score_patient(self, gt_img, synthetic_ct, mask):        
        assert gt_img.shape == synthetic_ct.shape 
        if mask is not None:
            assert mask.shape == synthetic_ct.shape 

        # perform masking on the images
        ground_truth = gt_img if mask is None else np.where(mask == 0, -1024, gt_img)
        prediction = synthetic_ct if mask is None else np.where(mask == 0, -1024, synthetic_ct)
        
        # Compute image similarity within the mask
        mae_value = self.mae(ground_truth,
                             prediction,
                             mask)
        
        psnr_value = self.psnr(ground_truth,
                               prediction,
                               mask,
                               use_population_range=True)
        
        ms_ssim_value, ms_ssim_mask_value = self.ms_ssim(ground_truth,
                               prediction, 
                               mask)

        return {
            'mae': mae_value,
            'psnr': psnr_value,
            'ms_ssim': ms_ssim_mask_value,
        }
    
    def mae(self,

            gt: np.ndarray, 

            pred: np.ndarray,

            mask: Optional[np.ndarray] = None) -> float:
        """

        Compute Mean Absolute Error (MAE)

    

        Parameters

        ----------

        gt : np.ndarray

            Ground truth

        pred : np.ndarray

            Prediction

        mask : np.ndarray, optional

            Mask for voxels to include. The default is None (including all voxels).

    

        Returns

        -------

        mae : float

            mean absolute error.

    

        """
        if mask is None:
            mask = np.ones(gt.shape)
        else:
            #binarize mask
            mask = np.where(mask>0, 1., 0.)
            
        mae_value = np.sum(np.abs(gt*mask - pred*mask))/mask.sum() 
        return float(mae_value)
    
    
    def psnr(self,

             gt: np.ndarray, 

             pred: np.ndarray,

             mask: Optional[np.ndarray] = None,

             use_population_range: Optional[bool] = False) -> float:
        """

        Compute Peak Signal to Noise Ratio metric (PSNR)

    

        Parameters

        ----------

        gt : np.ndarray

            Ground truth

        pred : np.ndarray

            Prediction

        mask : np.ndarray, optional

            Mask for voxels to include. The default is None (including all voxels).

        use_population_range : bool, optional

            When a predefined population wide dynamic range should be used.

            gt and pred will also be clipped to these values.

    

        Returns

        -------

        psnr : float

            Peak signal to noise ratio..

    

        """
        if mask is None:
            mask = np.ones(gt.shape)
        else:
            #binarize mask
            mask = np.where(mask>0, 1., 0.)
            
        if use_population_range:            
            # Clip gt and pred to the dynamic range
            gt = np.clip(gt, a_min=self.dynamic_range[0], a_max=self.dynamic_range[1])
            pred = np.clip(pred, a_min=self.dynamic_range[0], a_max=self.dynamic_range[1])
            dynamic_range = self.dynamic_range[1]  - self.dynamic_range[0]
        else:
            dynamic_range = gt.max()-gt.min()
            pred = np.clip(pred, a_min=gt.min(), a_max=gt.max())
            
        # apply mask
        gt = gt[mask==1]
        pred = pred[mask==1]
        psnr_value = peak_signal_noise_ratio(gt, pred, data_range=dynamic_range)
        return float(psnr_value)

    # Compute the luminance, contrast and structure components of the SSIM between two images    
    def structural_similarity_at_scale(self, im1,

        im2,

        *,

        luminance_weight=1,

        contrast_weight=1,

        structure_weight=1,

        win_size=None,

        gradient=False,

        data_range=None,

        channel_axis=None,

        gaussian_weights=False,

        full=False,

        **kwargs,):
            K1 = kwargs.pop('K1', 0.01)
            K2 = kwargs.pop('K2', 0.03)
            sigma = kwargs.pop('sigma', 1.5)
            if K1 < 0:
                raise ValueError("K1 must be positive")
            if K2 < 0:
                raise ValueError("K2 must be positive")
            if sigma < 0:
                raise ValueError("sigma must be positive")
            use_sample_covariance = kwargs.pop('use_sample_covariance', True)
            if gaussian_weights:
                # Set to give an 11-tap filter with the default sigma of 1.5 to match
                # Wang et. al. 2004.
                truncate = 3.5

            if win_size is None:
                if gaussian_weights:
                    # set win_size used by crop to match the filter size
                    r = int(truncate * sigma + 0.5)  # radius as in ndimage
                    win_size = 2 * r + 1
                else:
                    win_size = 7  # backwards compatibility
            if gaussian_weights:
                filter_func = gaussian
                filter_args = {'sigma': sigma, 'truncate': truncate, 'mode': 'reflect'}
            else:
                filter_func = uniform_filter
                filter_args = {'size': win_size}

            ndim = im1.ndim
            NP = win_size**ndim

            # filter has already normalized by NP
            if use_sample_covariance:
                cov_norm = NP / (NP - 1)  # sample covariance
            else:
                cov_norm = 1.0  # population covariance to match Wang et. al. 2004
            # compute (weighted) means
            ux = filter_func(im1, **filter_args)
            uy = filter_func(im2, **filter_args)


            # compute (weighted) variances and covariances
            uxx = filter_func(im1 * im1, **filter_args)
            uyy = filter_func(im2 * im2, **filter_args)
            uxy = filter_func(im1 * im2, **filter_args)
            vx = cov_norm * (uxx - ux * ux)
            vxsqrt = np.clip(vx, a_min=0, a_max=None) ** 0.5 # TODO: this is very ugly
            vy = cov_norm * (uyy - uy * uy)
            vysqrt = np.clip(vy, a_min=0, a_max=None) ** 0.5 # TODO: this is very ugly
            vxy = cov_norm * (uxy - ux * uy)

            R = data_range
            C1 = (K1 * R) ** 2
            C2 = (K2 * R) ** 2
            C3 = C2 / 2

            L = np.clip((2 * ux * uy + C1) / (ux * ux + uy * uy + C1), a_min=0, a_max=None) # TODO is this clipping necessary or do we increase K1 and K2?

            C = np.clip((2 * vxsqrt * vysqrt + C2) / (vx + vy + C2), a_min=0, a_max=None)
            S = np.clip((vxy + C3) / (vxsqrt * vysqrt + C3), a_min=0, a_max=None)

            result = (L ** luminance_weight) * (C ** contrast_weight) * (S ** structure_weight)
            # to avoid edge effects will ignore filter radius strip around edges
            pad = (win_size - 1) // 2

            # compute (weighted) mean of ssim. Use float64 for accuracy.
            mssim = crop(result, pad).mean(dtype=np.float64)

            if full:
                return mssim, result
            return mssim


    # Compute the masked MS-SSIM by masking the SSIM at every resolution level
    def ms_ssim(self, gt: np.ndarray, pred: np.ndarray, mask: Optional[np.ndarray] = None, scale_weights: Optional[np.ndarray] = None) -> float:
        # Clip gt and pred to the dynamic range
        gt = np.clip(gt, min(self.dynamic_range), max(self.dynamic_range))
        pred = np.clip(pred, min(self.dynamic_range), max(self.dynamic_range))

        if mask is not None:
            #binarize mask
            mask = np.where(mask>0, 1., 0.)
            
            # Mask gt and pred
            gt = np.where(mask==0, min(self.dynamic_range), gt)
            pred = np.where(mask==0, min(self.dynamic_range), pred)

        # Make values non-negative
        if min(self.dynamic_range) < 0:
            gt = gt - min(self.dynamic_range)
            pred = pred - min(self.dynamic_range)

        # Set dynamic range for ssim calculation and calculate ssim_map per pixel
        dynamic_range = self.dynamic_range[1] - self.dynamic_range[0]


        # see Eq. 7 in https://live.ece.utexas.edu/publications/2003/zw_asil2003_msssim.pdf
        # Also, the final sentence of section 3.2 (Results)
        scale_weights = np.array([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) if scale_weights is None else scale_weights
        luminance_weights = np.array([0, 0, 0, 0, 0, 0.1333]) if scale_weights is None else scale_weights 
        levels = len(scale_weights)

        downsample_filter = np.ones((2, 2, 2)) / 8

        gtx, gty, gtz = gt.shape

        # Due to the downsampling in the MS-SSIM, the minimum matrix size must be 97 in every dimension
        target_size = 97

        pad_values = [
            (np.clip((target_size - dim)//2, a_min=0, a_max=None), 
            np.clip(target_size - dim - (target_size - dim)//2, a_min=0, a_max=None)) 
            for dim in [gtx, gty, gtz]]

        gt   = np.pad(gt,   pad_values, mode='edge')
        pred = np.pad(pred, pad_values, mode='edge')
        mask = np.pad(mask, pad_values, mode='edge')

        min_size = (downsample_filter.shape[-1] - 1) * 2 ** (levels - 1) + 1

        ms_ssim_vals, ms_ssim_maps = [], []
        for level in range(levels):
            ssim_value_full, ssim_map = self.structural_similarity_at_scale(gt, pred, 
                luminance_weight=luminance_weights[level], 
                contrast_weight=scale_weights[level],
                structure_weight=scale_weights[level],
                data_range=dynamic_range, full=True)
            pad = 3
            # at every level, we get the ssim_value_full, which is just mean SSIM at a level L, and the 
            # SSIM map. The masked SSIM is the mean SSIM within this mask
            ssim_value_masked  = (crop(ssim_map, pad)[crop(mask, pad).astype(bool)]).mean(dtype=np.float64)

            ms_ssim_vals.append(ssim_value_full)
            ms_ssim_maps.append(ssim_value_masked)

            # The images are cleverly downsampled using an uniform filter
            # the mask is just downsampled by selecting every other line in every dimension
            filtered = [fftconvolve(im, downsample_filter, mode='same')
                for im in [gt, pred]]
            gt, pred, mask = [x[::2, ::2, ::2] for x in [*filtered, mask]]

        ms_ssim_val = np.prod([np.clip(x, a_min=0, a_max=1) for x in ms_ssim_vals])
        ms_ssim_mask_val = np.prod([np.clip(x, a_min=0, a_max=1) for x in ms_ssim_maps])

        return float(ms_ssim_val), float(ms_ssim_mask_val)


# compute image metrics for the predition folders
class ImageMetricsCompute(ImageMetrics):
    def __init__(self):
        super().__init__()
        self.names = ["mae", "psnr", "ms_ssim"]
    
    def init_storage(self, names: list):
        self.storage = dict()
        self.storage_id = []
        self.names = names
        for name in names:
            self.storage[name] = []

    def add(self, res: dict, patient_id=None):
        for key, value in res.items():
            self.storage[key].append(value)
        if patient_id:
            self.storage_id.append(patient_id)

    def aggregate(self):
        res = dict()
        for name in self.names:
            res[name] = dict()

        for key, value in self.storage.items():
            res[key]['mean'] = np.mean(value)
            res[key]['std'] = np.std(value)
            res[key]['max'] = np.max(value)
            res[key]['min'] = np.min(value)
            res[key]['25pc'] = np.percentile(value, 25)
            res[key]['50pc'] = np.percentile(value, 50)
            res[key]['75pc'] = np.percentile(value, 75)
            res[key]['count'] = len(value)
        return res

    def reset(self):
        for key, value in self.storage.items():
            self.storage[key] = []

    def score_array(self, gt_array, pred_array, mask_array=None):
        if torch.is_tensor(gt_array):
            gt_array = gt_array.cpu().numpy().squeeze()
        if torch.is_tensor(pred_array):
            pred_array = pred_array.cpu().numpy().squeeze()
        if torch.is_tensor(mask_array):
            mask_array = mask_array.cpu().numpy().squeeze()

        # Calculate image metrics
        res = dict()
        if self.names and 'mae' in self.names:
            mae_value = self.mae(gt_array,
                                 pred_array,
                                 mask_array)
            res['mae'] = mae_value

        if self.names and 'psnr' in self.names:
            psnr_value = self.psnr(gt_array,
                                   pred_array,
                                   mask_array, use_population_range=True)
            res['psnr'] = psnr_value

        if self.names and 'ms_ssim' in self.names:
            ms_ssim_value, ms_ssim_mask_value = self.ms_ssim(gt_array,
                                   pred_array, 
                                   mask_array)
            res['ms_ssim'] = ms_ssim_mask_value


        return res


if __name__=='__main__':
    metrics = ImageMetrics()
    ground_truth_path = "path/to/ground_truth.mha"
    predicted_path = "path/to/prediction.mha"
    mask_path = "path/to/mask.mha"
    print(metrics.score_patient(ground_truth_path, predicted_path, mask_path))