File size: 5,140 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from functools import cache

import torch
from einops import reduce
from jaxtyping import Float
# from lpips import LPIPS
# from skimage.metrics import structural_similarity
from torch import Tensor
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from tqdm import tqdm


@torch.no_grad()
def compute_psnr(
        ground_truth: Float[Tensor, "batch channel height width"],
        predicted: Float[Tensor, "batch channel height width"],
) -> Float[Tensor, ""]:
    """Return the batch-averaged PSNR (in dB) for images with peak value 1.

    Both inputs are clamped to [0, 1]. For each image, the MSE is taken over
    dims 1, 2 and 3. It is converted to PSNR as ``-10 * log10(mse)``, and the
    per-image PSNRs are averaged over the batch.

    Identical images give ``mse == 0`` and therefore an infinite PSNR. A single
    such image makes the returned mean infinite.
    """
    gt = torch.clamp(ground_truth, min=0, max=1)
    pred = torch.clamp(predicted, min=0, max=1)
    # Native torch reductions (rather than einops) for speed; one MSE per image.
    per_image_mse = torch.square(gt - pred).mean(dim=(1, 2, 3))
    per_image_log = torch.log10(per_image_mse)
    return -10 * per_image_log.mean()


@cache
def get_alex_lpips(device: torch.device) -> LearnedPerceptualImagePatchSimilarity:
    """Build the AlexNet-backed LPIPS metric for ``device`` once, then reuse it.

    ``normalize=True`` is set, so inputs are expected in [0, 1].
    """
    lpips_metric = LearnedPerceptualImagePatchSimilarity(
        net_type="alex",
        normalize=True,
        reduction='none',
    )
    return lpips_metric.to(device)


@cache
def get_vgg_lpips(device: torch.device) -> LearnedPerceptualImagePatchSimilarity:
    """Build the VGG-backed LPIPS metric for ``device`` once, then reuse it.

    ``normalize=True`` is set, so inputs are expected in [0, 1].
    """
    lpips_metric = LearnedPerceptualImagePatchSimilarity(
        net_type="vgg",
        normalize=True,
        reduction='none',
    )
    return lpips_metric.to(device)


@torch.no_grad()
def compute_lpips(
        ground_truth: Float[Tensor, "batch channel height width"],
        predicted: Float[Tensor, "batch channel height width"],
) -> tuple[Float[Tensor, ""], Float[Tensor, ""]]:
    """Return ``(alex_lpips, vgg_lpips)``, each averaged over the batch.

    Inputs are clamped to [0, 1], which matches the ``normalize=True`` metric
    returned by ``get_vgg_lpips``. Only the VGG variant is evaluated. The
    AlexNet slot is a zero tensor with the same dtype and device as the VGG
    score.
    """
    gt = torch.clamp(ground_truth, 0.0, 1.0)
    pred = torch.clamp(predicted, 0.0, 1.0)
    vgg_metric = get_vgg_lpips(pred.device)
    vgg_value = vgg_metric(gt, pred).mean()
    # AlexNet LPIPS is skipped for efficiency; report 0 in its place.
    alex_value = torch.zeros_like(vgg_value).mean()
    return alex_value, vgg_value


@cache
def get_ssim(device: torch.device) -> StructuralSimilarityIndexMeasure:
    """Build the SSIM metric for ``device`` once, then reuse it.

    ``data_range=1.0`` assumes images in [0, 1]. ``reduction='none'`` asks
    torchmetrics not to reduce over the batch.
    """
    ssim_metric = StructuralSimilarityIndexMeasure(
        data_range=1.0,
        reduction='none',
    )
    return ssim_metric.to(device)


@torch.no_grad()
def compute_ssim(
        ground_truth: Float[Tensor, "batch channel height width"],
        predicted: Float[Tensor, "batch channel height width"],
) -> Float[Tensor, ""]:
    """Return the batch-averaged SSIM between ``predicted`` and ``ground_truth``.

    Inputs are clamped to [0, 1], matching the ``data_range=1.0`` metric from
    ``get_ssim``. The metric's un-reduced output is averaged with ``.mean()``.

    The metric instance is cached per device. Calling a torchmetrics ``Metric``
    also merges the batch into its accumulated global state, which this
    function never reads. For non-reducing SSIM, torchmetrics keeps that state
    as a growing list. The metric is therefore reset after every call so the
    cached instance stays stateless and does not accumulate memory.
    """
    predicted = torch.clamp(predicted, 0.0, 1.0)
    ground_truth = torch.clamp(ground_truth, 0.0, 1.0)
    metric = get_ssim(predicted.device)
    try:
        ssim_value = metric(predicted, ground_truth).mean()
    finally:
        # Discard the accumulated state (also after a failed call) so it cannot
        # grow across calls.
        metric.reset()
    return ssim_value


# Registry used by compute_rgb_metrics: metric name -> scoring function.
# Each function is called as fn(ground_truth, predicted).
metric_fn_dict = dict(
    psnr=compute_psnr,
    ssim=compute_ssim,
    lpips=compute_lpips,
)


def _rgb_metric_in_chunks(metric_fn, rgb, rgb_gt, chunk_size: int, device, desc: str):
    """Evaluate ``metric_fn`` chunk by chunk and merge the results into one per-image mean."""
    num_images = rgb.shape[0]
    if num_images == 0:
        raise ValueError("Cannot compute metrics in batches: no images were given.")
    num_chunks = (num_images + chunk_size - 1) // chunk_size  # ceil division
    chunk_scores = []
    chunk_sizes = []
    # The progress bar is only shown for long runs (>= 20 chunks).
    for start in tqdm(range(0, num_images, chunk_size), disable=num_chunks < 20, desc=desc):
        stop = min(start + chunk_size, num_images)
        # Move one chunk at a time, so at most `chunk_size` images sit on `device`.
        chunk_scores.append(metric_fn(rgb_gt[start:stop].to(device), rgb[start:stop].to(device)))
        chunk_sizes.append(stop - start)
    return _size_weighted_mean(chunk_scores, chunk_sizes)


def _size_weighted_mean(chunk_scores: list, chunk_sizes: list[int]):
    """Average per-chunk mean scores weighted by chunk size (tensors or tuples of tensors)."""
    # Each chunk score is a mean over `size` images. Weighting by size recovers
    # the true per-image mean even when the last chunk is shorter.
    weights = torch.tensor(chunk_sizes, dtype=torch.float32)
    total = weights.sum()
    first = chunk_scores[0]
    if isinstance(first, torch.Tensor):
        return (torch.stack(chunk_scores).cpu() * weights).sum() / total
    if isinstance(first, tuple):
        # Transpose [(a0, b0), (a1, b1), ...] into columns (a0, a1, ...), (b0, b1, ...).
        return tuple(
            (torch.stack(column).cpu() * weights).sum() / total
            for column in zip(*chunk_scores)
        )
    raise TypeError("Unexpected element type: must be torch.Tensor or tuple of torch.Tensors.")


def compute_rgb_metrics(
        rgb,
        rgb_gt,
        metrics: list[str],
        iter_batch_size: int = -1,
        device: "torch.device | str" = "cuda",
) -> dict:
    """Compute image-quality metrics between predicted and ground-truth image batches.

    Args:
        rgb: Predicted images, one image per index along dim 0.
        rgb_gt: Ground-truth images with the same layout as ``rgb``.
        metrics: Metric names; each must be a key of ``metric_fn_dict``. Every
            function is called as ``fn(ground_truth, predicted)``.
        iter_batch_size: Images per chunk. ``-1`` (or any value <= 0) scores
            all images in one call. A positive value moves and scores chunks of
            that size to bound device memory.
        device: Device that images are moved to before scoring. Defaults to
            ``"cuda"``, as before.

    Returns:
        Dict mapping each metric name to its score. A score is a scalar tensor,
        or a tuple of scalar tensors when the metric function returns a tuple
        (e.g. ``compute_lpips``). Chunked scores are the image-count-weighted
        mean of the per-chunk means and live on the CPU. Unchunked scores are
        whatever the metric function returns on ``device``.

    Raises:
        ValueError: If a metric name is unknown, or chunking is requested for
            an empty batch.
        TypeError: If a metric function returns neither a tensor nor a tuple
            (chunked mode only).
    """
    # Validate every name first, so an unknown metric fails before any expensive work.
    for m in metrics:
        if m not in metric_fn_dict:
            raise ValueError(f"Metric {m} not recognized. Available metrics: {list(metric_fn_dict.keys())}")

    chunked = iter_batch_size > 0
    if not chunked and metrics:
        # Unchunked mode scores every image in one call, so move the data once.
        rgb = rgb.to(device)
        rgb_gt = rgb_gt.to(device)

    metric_scores = {}
    for m in metrics:
        metric_fn = metric_fn_dict[m]
        if chunked:
            metric_scores[m] = _rgb_metric_in_chunks(
                metric_fn, rgb, rgb_gt, iter_batch_size, device, desc=f"Computing {m} in batches"
            )
        else:
            # The result can be a tuple (e.g. lpips) or a single tensor.
            metric_scores[m] = metric_fn(rgb_gt, rgb)
    return metric_scores