| import numpy as np | |
| import torch | |
| import skimage | |
| import torchtask | |
| def task_func(): | |
| return HarmonizationFunc | |
| class HarmonizationFunc(torchtask.func_template.TaskFunc): | |
| def __init__(self, args): | |
| super(HarmonizationFunc, self).__init__(args) | |
| def metrics(self, pred_image, gt_image, mask, meters, id_str=''): | |
| n, c, h, w = pred_image.shape | |
| assert n == 1 | |
| total_pixels = h * w | |
| fg_pixels = int(torch.sum(mask, dim=(2, 3))[0][0].cpu().numpy()) | |
| pred_image = torch.clamp(pred_image * 255, 0, 255) | |
| gt_image = torch.clamp(gt_image * 255, 0, 255) | |
| pred_image = pred_image[0].permute(1, 2, 0).detach().cpu().numpy() | |
| gt_image = gt_image[0].permute(1, 2, 0).detach().cpu().numpy() | |
| mask = mask[0].permute(1, 2, 0).detach().cpu().numpy() | |
| batch_mse = skimage.metrics.mean_squared_error(pred_image, gt_image) | |
| meters.update('{0}_{1}_mse'.format(id_str, self.METRIC_STR), batch_mse) | |
| batch_fmse = skimage.metrics.mean_squared_error(pred_image * mask, gt_image * mask) * total_pixels / fg_pixels | |
| meters.update('{0}_{1}_fmse'.format(id_str, self.METRIC_STR), batch_fmse) | |
| batch_psnr = skimage.metrics.peak_signal_noise_ratio(pred_image, gt_image, data_range=pred_image.max() - pred_image.min()) | |
| meters.update('{0}_{1}_psnr'.format(id_str, self.METRIC_STR), batch_psnr) | |
| batch_ssim = skimage.metrics.structural_similarity(pred_image, gt_image, multichannel=True) | |
| meters.update('{0}_{1}_ssim'.format(id_str, self.METRIC_STR), batch_ssim) | |