Spaces:
Running
Running
| import torch | |
| import numpy as np | |
| from rstor.learning.metrics import compute_psnr, compute_ssim, compute_metrics, compute_lpips | |
| from rstor.properties import REDUCTION_AVERAGE, REDUCTION_SKIP, REDUCTION_SUM, DEVICE | |
| from rstor.properties import METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS | |
def test_compute_psnr():
    """Check ``compute_psnr`` on identical inputs and on a constant 0.25 error.

    Case 1 feeds identical tensors with ``clamp_mse=0`` and expects an
    infinite PSNR. Case 2 uses an all-zero prediction against a constant
    0.25 target (MSE = 1/16) and expects ``10 * log10(16)`` dB.
    """
    # Test case 1: identical values -> the result must be infinite.
    # NOTE(review): clamp_mse=0 presumably disables a lower bound on the MSE;
    # confirm against compute_psnr.
    predic = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
    target = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
    assert torch.isinf(compute_psnr(predic, target, clamp_mse=0)), "Test case 1 failed"

    # Test case 2: predic and target differ by 0.25 everywhere.
    predic = torch.tensor([[[[0., 0.], [0., 0.]]]])
    target = torch.tensor([[[[0.25, 0.25], [0.25, 0.25]]]])
    expected_db = (10. * torch.log10(torch.Tensor([4.**2]))).item()  # ~12.04 dB
    # Compare with a tolerance: both sides are float32 results obtained through
    # different operation orders, so bitwise equality is not guaranteed.
    assert np.isclose(compute_psnr(predic, target).item(), expected_db), "Test case 2 failed"
def test_compute_ssim():
    """Check ``compute_ssim`` output shape and reduction consistency.

    Two random batches of 8 images (3x256x256) are compared. With
    ``REDUCTION_SKIP`` the result must hold one value per image; the mean of
    those values must match the ``REDUCTION_AVERAGE`` result.
    """
    x = torch.rand(8, 3, 256, 256)
    y = torch.rand(8, 3, 256, 256)
    ssim = compute_ssim(x, y, reduction=REDUCTION_AVERAGE)
    ssim_per_unit = compute_ssim(x, y, reduction=REDUCTION_SKIP)
    assert ssim_per_unit.shape == (8,), "SSIM Test case 1 failed"
    # Tolerant comparison (as in the LPIPS test): averaging per-image values
    # and averaging inside the metric may round differently in float32.
    # float() accepts both a Python number and a single-element tensor.
    assert np.isclose(ssim_per_unit.mean().item(), float(ssim)), "SSIM Test case 2 failed"
def test_compute_lpips():
    """Check ``compute_lpips`` output shape and reduction consistency, twice.

    Each pass draws two fresh random batches of 8 images (3x256x256) on
    ``DEVICE`` and verifies that the per-image result has shape ``(8,)`` and
    that its mean is close to the averaged result.
    """
    # Two passes, so the metric is also exercised on a repeated call.
    for _ in range(2):
        first = torch.rand(8, 3, 256, 256).to(DEVICE)
        second = torch.rand(8, 3, 256, 256).to(DEVICE)
        averaged = compute_lpips(first, second, reduction=REDUCTION_AVERAGE)
        per_image = compute_lpips(first, second, reduction=REDUCTION_SKIP)
        assert per_image.shape == (8,), "LPIPS Test case 1 failed"
        assert torch.isclose(per_image.mean(), averaged), "LPIPS Test case 2 failed"
def test_compute_metrics():
    """Check ``compute_metrics`` shapes and PSNR consistency across reductions.

    A random batch ``x`` is compared with a slightly noisy copy ``y``. The
    per-image results (``REDUCTION_SKIP``) must have one entry per image for
    PSNR, SSIM and LPIPS. For PSNR, the default reduction must match the mean
    of the per-image values, ``REDUCTION_SUM`` must match their sum, and the
    sum divided by the batch size must match the default reduction.
    """
    batch_size = 8
    dims = (batch_size, 3, 256, 256)

    x = torch.rand(*dims)
    # torch.rand never produces negatives; it is the added Gaussian noise that
    # can push y slightly outside [0, 1). Per the original note, this is meant
    # to exercise clamping in the LPIPS path.
    y = x.clone() + torch.randn(*dims) * 0.01

    metrics = compute_metrics(x, y)
    print(metrics)
    per_image = compute_metrics(x, y, reduction=REDUCTION_SKIP)
    summed = compute_metrics(x, y, reduction=REDUCTION_SUM)

    # One value per image for every metric.
    expected_shape = (batch_size,)
    assert per_image[METRIC_PSNR].shape == expected_shape, "Metrics Test case 1 failed"
    assert per_image[METRIC_SSIM].shape == expected_shape, "Metrics Test case 2 failed"
    assert per_image[METRIC_LPIPS].shape == expected_shape, "Metrics Test case 3 failed"

    # PSNR: the three reductions must agree with each other.
    psnr_values = per_image[METRIC_PSNR]
    assert np.isclose(psnr_values.mean().item(), metrics[METRIC_PSNR]), "Metrics Test case 4 failed"
    assert np.isclose(psnr_values.sum().item(), summed[METRIC_PSNR]), "Metrics Test case 5 failed"
    assert np.isclose(metrics[METRIC_PSNR], summed[METRIC_PSNR] / batch_size), "Metrics Test case 6 failed"