# Page header recovered from the scrape: Spaces status "Sleeping"; file size 5,140 bytes; id 78d2329.
from functools import cache
import torch
from einops import reduce
from jaxtyping import Float
# from lpips import LPIPS
# from skimage.metrics import structural_similarity
from torch import Tensor
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from tqdm import tqdm
@torch.no_grad()
def compute_psnr(
    ground_truth: Float[Tensor, "batch channel height width"],
    predicted: Float[Tensor, "batch channel height width"],
) -> Float[Tensor, ""]:
    """Return the batch-averaged PSNR between two image batches.

    Both inputs are clipped to [0, 1]. The mean squared error is taken per
    image (over dims 1-3), mapped to ``-10 * log10(mse)``, and the resulting
    per-image values are averaged into a single scalar tensor.
    """
    target = ground_truth.clip(min=0, max=1)
    estimate = predicted.clip(min=0, max=1)
    # Plain torch ops (rather than einops.reduce) are used here for speed.
    squared_error = (target - estimate).pow(2)
    per_image_mse = squared_error.mean(dim=(1, 2, 3))  # [b]
    per_image_log = per_image_mse.log10()
    return -10 * per_image_log.mean()
@cache
def get_alex_lpips(device: torch.device) -> LearnedPerceptualImagePatchSimilarity:
    """Build an AlexNet-backed LPIPS metric and move it to ``device``.

    ``functools.cache`` memoizes the result, so calls with an equal ``device``
    argument reuse the same metric instance instead of rebuilding it.
    """
    metric = LearnedPerceptualImagePatchSimilarity(
        net_type="alex",
        normalize=True,
        reduction='none',
    )
    return metric.to(device)
@cache
def get_vgg_lpips(device: torch.device) -> LearnedPerceptualImagePatchSimilarity:
    """Build a VGG-backed LPIPS metric and move it to ``device``.

    ``functools.cache`` memoizes the result, so calls with an equal ``device``
    argument reuse the same metric instance instead of rebuilding it.
    """
    metric = LearnedPerceptualImagePatchSimilarity(
        net_type="vgg",
        normalize=True,
        reduction='none',
    )
    return metric.to(device)
@torch.no_grad()
def compute_lpips(
    ground_truth: Float[Tensor, "batch channel height width"],
    predicted: Float[Tensor, "batch channel height width"],
) -> tuple[Float[Tensor, ""], Float[Tensor, ""]]:
    """Return ``(alex_lpips, vgg_lpips)`` for two image batches.

    Both inputs are clamped to [0, 1]. Only the VGG variant is evaluated and
    averaged with ``.mean()``; the AlexNet slot is filled with a zero
    placeholder so the return value is still a pair.
    """
    pred_clamped = torch.clamp(predicted, 0.0, 1.0)
    gt_clamped = torch.clamp(ground_truth, 0.0, 1.0)

    vgg_metric = get_vgg_lpips(pred_clamped.device)
    vgg_value = vgg_metric(gt_clamped, pred_clamped).mean()

    # AlexNet LPIPS is skipped for efficiency: a zero matching the VGG
    # value's dtype and device is returned in its place.
    alex_value = torch.zeros_like(vgg_value).mean()
    return alex_value, vgg_value
@cache
def get_ssim(device: torch.device) -> StructuralSimilarityIndexMeasure:
    """Build an SSIM metric (``data_range=1.0``) and move it to ``device``.

    ``functools.cache`` memoizes the result, so calls with an equal ``device``
    argument reuse the same metric instance instead of rebuilding it.
    """
    metric = StructuralSimilarityIndexMeasure(
        data_range=1.0,
        reduction='none',
    )
    return metric.to(device)
@torch.no_grad()
def compute_ssim(
    ground_truth: Float[Tensor, "batch channel height width"],
    predicted: Float[Tensor, "batch channel height width"],
) -> Float[Tensor, ""]:
    """Return the mean SSIM between two image batches.

    Both inputs are clamped to [0, 1] before being passed to the cached SSIM
    metric for the prediction's device; the metric output is averaged with
    ``.mean()``.
    """
    pred_clamped = torch.clamp(predicted, 0.0, 1.0)
    gt_clamped = torch.clamp(ground_truth, 0.0, 1.0)

    ssim_metric = get_ssim(pred_clamped.device)
    # The metric is called with the prediction first, then the ground truth.
    return ssim_metric(pred_clamped, gt_clamped).mean()
# Registry of metric name -> function called as fn(ground_truth, predicted).
# "lpips" yields a tuple of two tensors; "psnr" and "ssim" yield one tensor.
metric_fn_dict = dict(
    psnr=compute_psnr,
    ssim=compute_ssim,
    lpips=compute_lpips,
)
def _weighted_batch_mean(batch_scores: list, batch_sizes: list[int]):
    """Combine per-batch mean scores into one mean weighted by batch size."""
    weights = torch.tensor(batch_sizes, dtype=torch.float32)
    total = weights.sum()

    def combine(values):
        # Scores are gathered on the CPU so they line up with the CPU weights.
        return (torch.stack(values).cpu() * weights).sum() / total

    first = batch_scores[0]
    # Case 1: every batch produced a single tensor.
    if isinstance(first, torch.Tensor):
        return combine(batch_scores)
    # Case 2: every batch produced a tuple of tensors (e.g. lpips); each
    # tuple position is averaged independently.
    if isinstance(first, tuple):
        return tuple(combine(list(column)) for column in zip(*batch_scores))
    raise TypeError("Unexpected element type: must be torch.Tensor or tuple of torch.Tensors.")


def compute_rgb_metrics(rgb, rgb_gt, metrics: list[str], iter_batch_size: int = -1) -> dict:
    """Compute the requested image metrics between ``rgb`` and ``rgb_gt``.

    Args:
        rgb: Predicted images; indexed and sliced along dim 0.
        rgb_gt: Ground-truth images, sliced with the same indices as ``rgb``.
        metrics: Names of the metrics to compute; each must be a key of
            ``metric_fn_dict``.
        iter_batch_size: ``-1`` evaluates all images in one call. A positive
            value evaluates slices of at most that many images and merges the
            per-slice scores with a batch-size-weighted mean.

    Returns:
        A dict mapping each metric name to its score: a tensor, or a tuple of
        tensors when the metric function returns a tuple. Scores from the
        batched path are CPU tensors; scores from the all-at-once path are
        returned exactly as the metric function produced them.

    Raises:
        ValueError: If ``iter_batch_size`` is neither ``-1`` nor positive, if
            a metric name is not recognized, or if batched evaluation finds
            no images to score.
        TypeError: If a metric function returns something other than a tensor
            or a tuple of tensors in batched mode.
    """
    if iter_batch_size != -1 and iter_batch_size <= 0:
        raise ValueError(
            f"iter_batch_size must be -1 or a positive integer, got {iter_batch_size}."
        )

    # Reject unknown names before doing any (potentially expensive) work.
    for m in metrics:
        if m not in metric_fn_dict:
            raise ValueError(f"Metric {m} not recognized. Available metrics: {list(metric_fn_dict.keys())}")

    # Prefer CUDA, but stay on the input's device when CUDA is unavailable
    # instead of failing on the transfer.
    device = "cuda" if torch.cuda.is_available() else rgb.device

    metric_scores = {}
    for m in metrics:
        metric_fn = metric_fn_dict[m]

        if iter_batch_size == -1:
            # Compute all at once; the transfer is a no-op after the first metric.
            rgb = rgb.to(device)
            rgb_gt = rgb_gt.to(device)
            # Can be a tuple (for lpips) or a single tensor.
            metric_scores[m] = metric_fn(rgb_gt, rgb)
            continue

        # Batchify to save memory: only one slice lives on the device at a time.
        num_images = rgb.shape[0]
        batch_num = (num_images + iter_batch_size - 1) // iter_batch_size
        all_batches_scores = []
        batch_sizes = []
        for start in tqdm(range(0, num_images, iter_batch_size), disable=batch_num < 20,
                          desc=f"Computing {m} in batches"):
            bs = min(iter_batch_size, num_images - start)
            rgb_batch = rgb[start:start + bs].to(device)
            rgb_gt_batch = rgb_gt[start:start + bs].to(device)
            all_batches_scores.append(metric_fn(rgb_gt_batch, rgb_batch))
            batch_sizes.append(bs)

        if not all_batches_scores:
            raise ValueError("No batch scores computed.")

        # Each batch score is a mean over `bs` images, so weighting by bs
        # recovers the true per-image mean even when the last batch is smaller.
        metric_scores[m] = _weighted_batch_mean(all_batches_scores, batch_sizes)

    return metric_scores
|