Spaces:
Runtime error
Runtime error
| import skimage | |
| import torch | |
| import numpy as np | |
| from pytorch_msssim import ssim | |
| import math | |
def calc_metrics(harmonized, real, mask_batch):
    """Compute per-image MSE, fMSE, PSNR and SSIM for a batch.

    Both image batches are multiplied by 255 and clamped to [0, 255] before
    scoring (presumably the inputs are in [0, 1] — confirm with the caller).
    Each sample is then scored with the module's ``MSE``, ``fMSE``, ``PSNR``
    and ``SSIM`` helpers using its own mask.

    Note: an earlier version computed these metrics with ``skimage.metrics``
    on NumPy copies of each sample; it was replaced by the torch helpers.

    Args:
        harmonized: predicted batch, unpacked as ``(n, c, h, w)``.
        real: ground-truth batch with the same layout as ``harmonized``.
        mask_batch: masks indexed per sample along the first axis.

    Returns:
        Tuple ``(mse, fmse, psnr, ssim)`` of lists with one entry per sample.
        The first three hold whatever the helpers return (Python floats via
        ``.item()``); SSIM entries are returned by ``SSIM`` unconverted.
    """
    # Unpacking also rejects inputs that are not 4-D; only ``n`` is used.
    n, _c, _h, _w = harmonized.shape

    # Scale/clamp once for the whole batch instead of re-clamping the same
    # sample separately for each of the four metrics.
    pred_batch = torch.clamp(harmonized * 255, 0, 255)
    real_batch = torch.clamp(real * 255, 0, 255)

    # Distinct names so the imported ``ssim`` function is not shadowed here.
    mse_scores, fmse_scores, psnr_scores, ssim_scores = [], [], [], []
    for idx in range(n):
        pred, gt, mask = pred_batch[idx], real_batch[idx], mask_batch[idx]
        mse_scores.append(MSE(pred, gt, mask))
        fmse_scores.append(fMSE(pred, gt, mask))
        psnr_scores.append(PSNR(pred, gt, mask))
        ssim_scores.append(SSIM(pred, gt, mask))
    return mse_scores, fmse_scores, psnr_scores, ssim_scores
def SSIM(pred, target_image, mask):
    """Structural similarity of ``pred`` against ``target_image`` inside ``mask``.

    ``pred`` is first composited over the target: where ``mask`` is 1 the
    prediction is kept, where it is 0 the target is used (fractional masks
    blend linearly). Outside the mask the composite therefore equals the
    target exactly. The composite and the target each get a leading batch
    axis and are scored with ``ssim``.

    NOTE(review): presumably ``pred``/``target_image`` are single ``(C, H, W)``
    images; relies on ``ssim``'s default ``data_range`` — confirm it matches
    the 0-255 scale that ``calc_metrics`` passes in.

    Returns:
        The value returned by ``ssim``, unconverted (no ``.item()``).
    """
    background = target_image * (1 - mask)
    composite = background + pred * mask
    return ssim(composite[None], target_image[None])
def MSE(pred, target_image, mask):
    """Mask-weighted squared error, averaged over every element.

    ``(pred - target_image) ** 2`` is multiplied by ``mask`` (with normal
    broadcasting) and the mean is taken over the whole broadcast result.
    Elements where ``mask`` is 0 contribute nothing to the numerator but are
    still counted in the denominator; see ``fMSE`` for a foreground-only mean.

    Returns:
        The mean as a Python scalar (via ``.item()``).
    """
    squared_error = (pred - target_image) ** 2
    masked_error = mask * squared_error
    return masked_error.mean().item()
def fMSE(pred, target_image, mask):
    """Foreground MSE: squared error averaged only over masked elements.

    ``mask * (pred - target_image) ** 2`` is summed and divided by the number
    of elements of that (broadcast) product that the mask covers. A ``1e-6``
    term keeps an empty mask from dividing by zero; it then returns 0.

    The covered-element count is taken by expanding ``mask`` to the broadcast
    shape. This gives the same result as the previous
    ``diff.size(0) * mask.sum()`` for a ``(C, H, W)`` image with a 1-channel
    mask. It also stays correct for a per-channel mask and for 2-D inputs,
    where the old formula divided by the wrong count.

    Returns:
        The foreground mean as a Python scalar (via ``.item()``).
    """
    diff = mask * ((pred - target_image) ** 2)
    # ``diff`` has the broadcast shape of mask and the error, so the mask is
    # always expandable to it; summing the expanded mask counts every
    # foreground element exactly once per channel.
    fg_count = mask.expand_as(diff).sum()
    return (diff.sum() / (fg_count + 1e-6)).item()
def PSNR(pred, target_image, mask, *, data_range=None):
    """Peak signal-to-noise ratio in dB, computed from the mask-weighted MSE.

    The error term matches ``MSE``: ``mask * (pred - target_image) ** 2``
    averaged over every element. ``1e-6`` is added to it, so identical inputs
    give a large finite value instead of a division by zero.

    Args:
        pred: predicted image tensor.
        target_image: reference image tensor, broadcast-compatible with pred.
        mask: weights multiplied into the squared error.
        data_range: peak signal value. ``None`` (default) keeps the original
            behaviour of using ``target_image.max()``. Pass a fixed value
            (e.g. ``255``) so scores are comparable across images.

    Returns:
        PSNR as a Python float.

    Raises:
        ValueError: if the peak value is 0 (e.g. an all-zero target with the
            default ``data_range``), where PSNR is undefined.
    """
    mse = (mask * (pred - target_image) ** 2).mean().item()
    peak = target_image.max().item() if data_range is None else float(data_range)
    squared_max = peak ** 2
    if squared_max == 0:
        # math.log10(0) would otherwise fail with an opaque "math domain error".
        raise ValueError("PSNR is undefined for a peak signal value of 0")
    return 10 * math.log10(squared_max / (mse + 1e-6))