| | import os |
| | import sys |
| | import torch |
| | import numpy as np |
| |
|
| | sys.path.append(os.path.dirname(os.getcwd())) |
| | sys.path.append(os.path.dirname(os.path.dirname(os.getcwd()))) |
| | from util import pytorch_ssim |
| |
|
class Metric:
    """Base class for all metrics.

    From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py

    Every method here is a no-op returning ``None``; subclasses override
    them to provide actual accumulation behaviour.
    """

    def reset(self):
        """Clear any accumulated state."""

    def add(self, *args, **kwargs):
        """Register a new observation.

        Arguments are accepted and ignored so that the base class is
        call-compatible with subclasses whose ``add`` takes a sample
        (e.g. a dict of metric values), as well as with legacy
        zero-argument calls.
        """

    def value(self):
        """Return the current value of the metric."""
| |
|
| |
|
def img_metrics(target, pred, var=None, pixelwise=True):
    """Compute scalar (and optionally pixelwise) error metrics for a prediction.

    Args:
        target: reference tensor.
        pred: predicted tensor, element-wise compatible with ``target``.
        var: optional tensor of predicted variances. When given, error and
            variance statistics are added to the result.
        pixelwise: only used when ``var`` is given; if true, per-pixel maps
            of the error statistics are added as flattened NumPy arrays.

    Returns:
        dict with Python-float entries ``'RMSE'``, ``'MAE'``, ``'PSNR'``,
        ``'SAM'`` (degrees) and ``'SSIM'``. With ``var`` it additionally holds
        ``'error'``, ``'mean ae'``, ``'mean se'``, ``'mean var'`` (NaN-ignoring
        means) and, if ``pixelwise``, the ``'pixelwise ...'`` arrays.
    """
    # These values are for reporting only. Detaching lets the function be
    # called on tensors that still require grad; otherwise converting the
    # pixelwise maps to NumPy would raise a RuntimeError.
    target = target.detach()
    pred = pred.detach()

    error = target - pred

    rmse = torch.sqrt(torch.mean(torch.square(error)))
    # The numerator 1 is the assumed peak signal value
    # (presumably data scaled to [0, 1] -- TODO confirm with callers).
    psnr = 20 * torch.log10(1 / rmse)
    mae = torch.mean(torch.abs(error))

    # Spectral angle: angle between target and pred vectors taken along dim 1
    # (presumably the channel/band axis -- TODO confirm). A zero-norm vector
    # along that dim yields 0/0 = NaN, which propagates into the mean.
    dot = torch.sum(target * pred, 1)
    cos_angle = torch.div(dot, torch.sqrt(torch.sum(target * target, 1)))
    cos_angle = torch.div(cos_angle, torch.sqrt(torch.sum(pred * pred, 1)))
    # Clamp guards acos against round-off pushing the cosine outside [-1, 1].
    sam = torch.mean(torch.acos(torch.clamp(cos_angle, -1, 1)) * 180 / torch.pi)

    ssim = pytorch_ssim.ssim(target, pred)

    metric_dict = {
        'RMSE': rmse.item(),
        'MAE': mae.item(),
        'PSNR': psnr.item(),
        'SAM': sam.item(),
        'SSIM': ssim.item(),
    }

    if var is not None:
        var = var.detach()
        se = torch.square(error)
        ae = torch.abs(error)

        # Sample-wise statistics; nanmean ignores NaN entries.
        metric_dict.update({
            'error': error.nanmean().item(),
            'mean ae': ae.nanmean().item(),
            'mean se': se.nanmean().item(),
            'mean var': var.nanmean().item(),
        })

        if pixelwise:
            # Average over the two leading dimensions, keep the rest as a
            # flat NumPy vector.
            pixel_maps = (
                ('pixelwise error', error),
                ('pixelwise ae', ae),
                ('pixelwise se', se),
                ('pixelwise var', var),
            )
            for name, tensor in pixel_maps:
                metric_dict[name] = (
                    tensor.nanmean(0).nanmean(0).flatten().cpu().numpy()
                )

    return metric_dict
| |
|
class avg_img_metrics(Metric):
    """Running mean of a fixed set of scalar metrics, skipping NaN samples.

    For each tracked metric name the class keeps the current mean and the
    number of non-NaN samples that contributed to it.
    """

    def __init__(self):
        super().__init__()
        self.n_samples = 0
        quality_names = ['RMSE', 'MAE', 'PSNR', 'SAM', 'SSIM']
        errvar_names = ['error', 'mean se', 'mean ae', 'mean var']
        self.metrics = quality_names + errvar_names

        self.running_img_metrics = {}
        self.running_nonan_count = {}
        self.reset()

    def reset(self):
        """Set every tracked mean to NaN and every sample count to zero."""
        self.running_nonan_count.update(dict.fromkeys(self.metrics, 0))
        self.running_img_metrics.update(dict.fromkeys(self.metrics, np.nan))

    def add(self, metrics_dict):
        """Fold one dict of metric values into the running means.

        Entries are ignored when their key is not tracked, when the value is
        a torch tensor, or when the value is NaN. For tuple values only the
        first element is used.
        """
        counts = self.running_nonan_count
        means = self.running_img_metrics

        for name, sample in metrics_dict.items():
            if name not in self.metrics or torch.is_tensor(sample):
                continue
            if isinstance(sample, tuple):
                sample = sample[0]
            if np.isnan(sample):
                continue

            if counts[name]:
                counts[name] += 1
                n = counts[name]
                # Incremental mean: weight the old mean by (n-1)/n and the
                # new sample by 1/n.
                means[name] = (n - 1) / n * means[name] + 1 / n * sample
            else:
                # First valid sample replaces the NaN placeholder directly
                # (0 * NaN would otherwise stay NaN).
                counts[name] = 1
                means[name] = sample

    def value(self):
        """Return the dict mapping metric name to its current running mean."""
        return self.running_img_metrics