| | import torch |
| | from skimage.metrics import structural_similarity |
| | import numpy as np |
| |
|
| | @torch.no_grad() |
| | def compute_ssim(ground_truth, predicted, full=True): |
| | |
| | |
| | ssim = [ |
| | structural_similarity( |
| | gt.detach().cpu().numpy(), |
| | hat.detach().cpu().numpy(), |
| | win_size=11, |
| | gaussian_weights=True, |
| | channel_axis=0, |
| | data_range=1.0, |
| | full=full, |
| | ) |
| | for gt, hat in zip(ground_truth, predicted) |
| | ] |
| | if full: |
| | ssim = [spatial for _, spatial in ssim] |
| | ssim = np.array(ssim) |
| | ssim = torch.tensor(ssim, dtype=predicted.dtype, device=predicted.device) |
| | assert not torch.isnan(ssim).any(), "SSIM has NaNs" |
| | return ssim |
| |
|