File size: 1,591 Bytes
4c62147 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import numpy as np
import torch
import skimage
import skimage.metrics
import torchtask
def task_func():
    """Return the task-function class exposed by this module (``HarmonizationFunc``)."""
    task_cls = HarmonizationFunc
    return task_cls
class HarmonizationFunc(torchtask.func_template.TaskFunc):
    """Task function for image harmonization.

    Extends ``torchtask.func_template.TaskFunc`` with per-image evaluation
    metrics: MSE, foreground MSE (fMSE), PSNR and SSIM.
    """

    # Value range of the images scored by ``metrics`` after they are scaled by
    # 255 and clamped to [0, 255]; passed to SSIM as ``data_range``.
    DATA_RANGE = 255.0

    def __init__(self, args):
        super().__init__(args)

    @staticmethod
    def _to_hwc_numpy(tensor):
        """Detach item 0 of an (N, C, H, W) tensor and return it as an (H, W, C) NumPy array."""
        return tensor.detach()[0].permute(1, 2, 0).cpu().numpy()

    @staticmethod
    def _ssim_channel_kwargs():
        """Return the keyword telling ``structural_similarity`` that the last axis holds channels.

        scikit-image 0.19 deprecated ``multichannel=True`` in favour of
        ``channel_axis``; use the new keyword when the installed version has it
        and fall back to the old one otherwise.
        """
        import inspect

        params = inspect.signature(skimage.metrics.structural_similarity).parameters
        if 'channel_axis' in params:
            return {'channel_axis': -1}
        return {'multichannel': True}

    def metrics(self, pred_image, gt_image, mask, meters, id_str=''):
        """Compute harmonization metrics for a single image and record them in ``meters``.

        Args:
            pred_image: predicted image tensor of shape (1, C, H, W). It is
                multiplied by 255 and clamped to [0, 255] before scoring
                (presumably the input is in [0, 1] -- TODO confirm).
            gt_image: ground-truth image tensor, same shape as ``pred_image``;
                rescaled the same way.
            mask: foreground mask tensor of shape (1, M, H, W). Channel 0 is used
                to count foreground pixels, and M must broadcast against C.
            meters: object exposing ``update(name, value)``.
            id_str: prefix for the recorded metric names.

        Records ``'<id_str>_<METRIC_STR>_mse'``, ``..._fmse``, ``..._psnr`` and
        ``..._ssim`` in that order. ``self.METRIC_STR`` is not defined in this
        class (presumably provided by ``TaskFunc`` or a subclass -- TODO confirm).

        Raises:
            AssertionError: if the batch size is not 1.
        """
        n, c, h, w = pred_image.shape
        assert n == 1
        total_pixels = h * w

        # Foreground area taken from the first mask channel. Clamp to at least 1
        # so an empty mask yields fMSE == 0 instead of nan/inf in the meters.
        fg_pixels = max(int(mask[0, 0].sum().item()), 1)

        pred_np = self._to_hwc_numpy(torch.clamp(pred_image * 255, 0, 255))
        gt_np = self._to_hwc_numpy(torch.clamp(gt_image * 255, 0, 255))
        mask_np = self._to_hwc_numpy(mask)

        prefix = '{0}_{1}'.format(id_str, self.METRIC_STR)

        batch_mse = skimage.metrics.mean_squared_error(pred_np, gt_np)
        meters.update(prefix + '_mse', batch_mse)

        # The masked error is averaged over every element of the (H, W, C)
        # arrays; rescaling by H*W / foreground area turns it into a mean over
        # foreground pixels only.
        batch_fmse = (
            skimage.metrics.mean_squared_error(pred_np * mask_np, gt_np * mask_np)
            * total_pixels / fg_pixels
        )
        meters.update(prefix + '_fmse', batch_fmse)

        # NOTE(review): data_range comes from the prediction's own min/max rather
        # than the fixed 0-255 scale (self.DATA_RANGE). Kept as-is so reported
        # PSNR values do not change, but a constant prediction gives a zero
        # range and hence a degenerate (-inf/nan) PSNR.
        batch_psnr = skimage.metrics.peak_signal_noise_ratio(
            pred_np, gt_np, data_range=pred_np.max() - pred_np.min())
        meters.update(prefix + '_psnr', batch_psnr)

        # The inputs are floats in [0, 255]; without an explicit data_range,
        # scikit-image would assume the float dtype range instead.
        batch_ssim = skimage.metrics.structural_similarity(
            pred_np, gt_np, data_range=self.DATA_RANGE, **self._ssim_channel_kwargs())
        meters.update(prefix + '_ssim', batch_ssim)
|