File size: 1,591 Bytes
4c62147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np
import torch

import skimage
import skimage.metrics

import torchtask

def task_func():
    """Return the task-function class used by this harmonization task."""
    func_class = HarmonizationFunc
    return func_class


class HarmonizationFunc(torchtask.func_template.TaskFunc):
    """Task functions for image harmonization.

    Extends ``torchtask.func_template.TaskFunc`` with ``metrics``, which scores a
    single harmonized prediction against its ground truth (MSE, foreground MSE,
    PSNR and SSIM) and records each value in a meter collection.
    """

    def __init__(self, args):
        super(HarmonizationFunc, self).__init__(args)

    @staticmethod
    def _multichannel_ssim(pred_image, gt_image, data_range):
        """Return the mean of the per-channel SSIM of two (H, W, C) arrays.

        This is the same computation as
        ``structural_similarity(..., multichannel=True)`` (SSIM per channel,
        then averaged), but it does not depend on the ``multichannel`` /
        ``channel_axis`` keywords, whose availability varies across
        scikit-image versions.
        """
        channel_scores = [
            skimage.metrics.structural_similarity(
                pred_image[..., ch], gt_image[..., ch], data_range=data_range)
            for ch in range(pred_image.shape[-1])
        ]
        return np.mean(channel_scores)

    def metrics(self, pred_image, gt_image, mask, meters, id_str=''):
        """Compute harmonization metrics for one image and record them.

        Args:
            pred_image: 4-D tensor ``(n, c, h, w)`` holding the prediction;
                presumably in [0, 1] (it is scaled by 255 and clamped to
                [0, 255]) — TODO confirm with callers.
            gt_image: ground-truth tensor with the same layout as ``pred_image``.
            mask: 4-D tensor ``(n, k, h, w)``; channel 0 of the first sample is
                summed to get the foreground area, so it is assumed to be a
                0/1 foreground mask — TODO confirm.
            meters: collection exposing ``update(name, value)``.
            id_str: prefix for the recorded meter names.

        Records ``'{id_str}_{METRIC_STR}_{mse|fmse|psnr|ssim}'`` in ``meters``.

        Raises:
            ValueError: if the batch size ``n`` is not 1.
        """
        n, _, h, w = pred_image.shape
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if n != 1:
            raise ValueError(
                'metrics expects a batch of exactly one image, got {0}'.format(n))

        def record(metric_name, value):
            meters.update(
                '{0}_{1}_{2}'.format(id_str, self.METRIC_STR, metric_name), value)

        total_pixels = h * w
        fg_pixels = int(mask[0, 0].sum().item())

        # Detach first so the scaling/clamping does not build an autograd graph;
        # then convert to (H, W, C) numpy arrays on the [0, 255] scale.
        pred_np = torch.clamp(pred_image.detach()[0] * 255, 0, 255).permute(1, 2, 0).cpu().numpy()
        gt_np = torch.clamp(gt_image.detach()[0] * 255, 0, 255).permute(1, 2, 0).cpu().numpy()
        mask_np = mask.detach()[0].permute(1, 2, 0).cpu().numpy()

        record('mse', skimage.metrics.mean_squared_error(pred_np, gt_np))

        # Foreground MSE: MSE of the masked images rescaled from "per image pixel"
        # to "per foreground pixel". An empty mask gives a masked MSE of 0, so
        # clamping the divisor to 1 yields 0 instead of NaN (0 / 0).
        masked_mse = skimage.metrics.mean_squared_error(pred_np * mask_np, gt_np * mask_np)
        record('fmse', masked_mse * total_pixels / max(fg_pixels, 1))

        # Peak value for PSNR/SSIM: the prediction's own dynamic range (as in the
        # original protocol). A flat prediction has zero range, which would make
        # PSNR -inf/NaN and SSIM divide by zero, so fall back to the full range.
        data_range = float(pred_np.max() - pred_np.min())
        if data_range <= 0:
            data_range = 255.0

        record('psnr', skimage.metrics.peak_signal_noise_ratio(pred_np, gt_np, data_range=data_range))
        record('ssim', self._multichannel_ssim(pred_np, gt_np, data_range))