Spaces:
Sleeping
Sleeping
| from functools import cache | |
| import torch | |
| from einops import reduce | |
| from jaxtyping import Float | |
| # from lpips import LPIPS | |
| # from skimage.metrics import structural_similarity | |
| from torch import Tensor | |
| from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure | |
| from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity | |
| from tqdm import tqdm | |
def compute_psnr(
    ground_truth: Float[Tensor, "batch channel height width"],
    predicted: Float[Tensor, "batch channel height width"],
) -> Float[Tensor, ""]:
    """Return the peak signal-to-noise ratio (dB), averaged over the batch.

    Both images are clamped to [0, 1], so the peak value is 1 and PSNR reduces
    to ``-10 * log10(MSE)``. The MSE is computed per image over the channel,
    height and width dimensions (dims 1-3). The mean of the per-image
    ``log10(MSE)`` values is then scaled by -10.
    An exact match (MSE == 0) produces ``inf``.
    """
    lower, upper = 0, 1
    gt = ground_truth.clamp(min=lower, max=upper)
    pred = predicted.clamp(min=lower, max=upper)
    # Native torch reductions are used here rather than einops for speed.
    per_image_mse = (gt - pred).square().mean(dim=(1, 2, 3))  # shape [b]
    return -10 * torch.log10(per_image_mse).mean()
def get_alex_lpips(device: torch.device) -> LearnedPerceptualImagePatchSimilarity:
    """Build a new AlexNet-backbone LPIPS metric and place it on ``device``.

    ``normalize=True`` selects the torchmetrics input convention of images in
    [0, 1]. ``reduction='none'`` is passed through to torchmetrics unchanged.
    """
    lpips_kwargs = dict(net_type="alex", normalize=True, reduction='none')
    metric = LearnedPerceptualImagePatchSimilarity(**lpips_kwargs)
    return metric.to(device)
def get_vgg_lpips(device: torch.device) -> LearnedPerceptualImagePatchSimilarity:
    """Build a new VGG-backbone LPIPS metric and place it on ``device``.

    ``normalize=True`` selects the torchmetrics input convention of images in
    [0, 1]. ``reduction='none'`` is passed through to torchmetrics unchanged.
    """
    lpips_kwargs = dict(net_type="vgg", normalize=True, reduction='none')
    metric = LearnedPerceptualImagePatchSimilarity(**lpips_kwargs)
    return metric.to(device)
@cache
def _cached_vgg_lpips(device: torch.device) -> LearnedPerceptualImagePatchSimilarity:
    """Return a VGG LPIPS metric for ``device``, constructing it only on first use.

    Building the metric instantiates the pretrained VGG backbone. That is far
    too expensive to repeat on every call, and ``compute_rgb_metrics`` may call
    ``compute_lpips`` once per batch. One instance per device is kept alive
    (and resident on that device) for the lifetime of the process.
    """
    return get_vgg_lpips(device)


def compute_lpips(
    ground_truth: Float[Tensor, "batch channel height width"],
    predicted: Float[Tensor, "batch channel height width"],
) -> tuple[Float[Tensor, ""], Float[Tensor, ""]]:
    """Return ``(alex_lpips, vgg_lpips)``, each averaged over the batch.

    Inputs are clamped to [0, 1], matching the ``normalize=True`` setting the
    metric is built with. Only the VGG variant is actually computed. The
    AlexNet slot is a zero placeholder, kept so callers can still unpack two
    values.
    """
    predicted = torch.clamp(predicted, 0.0, 1.0)
    ground_truth = torch.clamp(ground_truth, 0.0, 1.0)
    metric = _cached_vgg_lpips(predicted.device)
    try:
        vgg_value = metric(ground_truth, predicted).mean()
    finally:
        # The metric object is shared across calls. Clear the state it
        # accumulates in forward() so that state does not grow over time or
        # carry over between calls.
        metric.reset()
    # AlexNet LPIPS is skipped for efficiency; always report 0 in its slot.
    alex_value = torch.zeros_like(vgg_value)
    return alex_value, vgg_value
def get_ssim(device: torch.device) -> StructuralSimilarityIndexMeasure:
    """Build a new SSIM metric for images in [0, 1] and place it on ``device``.

    ``data_range=1.0`` declares the dynamic range of the inputs.
    ``reduction='none'`` asks torchmetrics not to reduce across the batch.
    """
    ssim_metric = StructuralSimilarityIndexMeasure(
        data_range=1.0,
        reduction='none',
    )
    return ssim_metric.to(device)
def compute_ssim(
    ground_truth: Float[Tensor, "batch channel height width"],
    predicted: Float[Tensor, "batch channel height width"],
) -> Float[Tensor, ""]:
    """Return the structural similarity index (SSIM), averaged over the batch.

    Both images are clamped to [0, 1], matching the metric's ``data_range=1.0``.
    The metric is called as ``(preds, target)``, i.e. ``(predicted, ground_truth)``.
    A new metric object is created on the inputs' device for each call.
    """
    ssim_metric = get_ssim(predicted.device)
    pred_clamped = predicted.clamp(0.0, 1.0)
    gt_clamped = ground_truth.clamp(0.0, 1.0)
    return ssim_metric(pred_clamped, gt_clamped).mean()
# Registry mapping each metric name accepted by ``compute_rgb_metrics`` to the
# function that computes it. Every function is called as
# ``fn(ground_truth, predicted)``. "lpips" returns an ``(alex, vgg)`` tuple;
# the other functions return a single tensor.
metric_fn_dict = dict(
    psnr=compute_psnr,
    ssim=compute_ssim,
    lpips=compute_lpips,
)
def _weighted_batch_mean(batch_scores: list, batch_sizes: list[int]):
    """Combine per-batch mean scores into one per-image mean, weighting by batch size.

    Each batch score is the mean over ``bs`` images. Weighting by ``bs``
    therefore recovers the true mean over all images, even when the last batch
    is smaller than the rest. Entries are either scalar tensors or tuples of
    scalar tensors (for example the ``(alex, vgg)`` pair from LPIPS). The
    result has the same structure and lives on the CPU.

    Raises:
        TypeError: if the entries are neither tensors nor tuples.
    """
    weights = torch.tensor(batch_sizes, dtype=torch.float32)
    total = weights.sum()
    first = batch_scores[0]
    # Case 1: scalar tensors.
    if isinstance(first, torch.Tensor):
        vals = torch.stack(batch_scores).cpu()
        return (vals * weights).sum() / total
    # Case 2: tuples of tensors -- average each position independently.
    if isinstance(first, tuple):
        cols = [
            torch.stack([scores[k] for scores in batch_scores]).cpu()
            for k in range(len(first))
        ]
        return tuple((col * weights).sum() / total for col in cols)
    raise TypeError("Unexpected element type: must be torch.Tensor or tuple of torch.Tensors.")


def compute_rgb_metrics(
    rgb,
    rgb_gt,
    metrics: list[str],
    iter_batch_size: int = -1,
    device: "torch.device | str" = "cuda",
) -> dict:
    """Compute the requested image metrics between predicted and ground-truth images.

    Args:
        rgb: Predicted images. Presumably ``(N, C, H, W)`` in [0, 1], matching
            the metric functions' annotations (TODO confirm with callers).
        rgb_gt: Ground-truth images, with the same layout as ``rgb``.
        metrics: Names from ``metric_fn_dict`` (e.g. "psnr", "ssim", "lpips").
        iter_batch_size: -1 computes each metric on all images at once. A
            positive value computes the metric in chunks of that many images
            to save memory, then combines the chunk means weighted by chunk size.
        device: Device the images are moved to before scoring. The default,
            "cuda", matches the previously hard-coded behaviour.

    Returns:
        A dict mapping each metric name to its score. The score is a scalar
        tensor, or a tuple of scalar tensors for "lpips". In batched mode the
        scores are on the CPU; with ``iter_batch_size == -1`` they stay on
        ``device``.

    Raises:
        ValueError: for an invalid ``iter_batch_size``, empty input in batched
            mode, or an unrecognized metric name. All of these are checked
            before any metric is computed.
    """
    if iter_batch_size == 0 or iter_batch_size < -1:
        raise ValueError(
            f"iter_batch_size must be -1 (all at once) or a positive integer, got {iter_batch_size}."
        )
    if metrics and iter_batch_size != -1 and rgb.shape[0] == 0:
        raise ValueError("No batch scores computed: input contains no images.")
    # Reject unknown names before spending time on any metric.
    for m in metrics:
        if m not in metric_fn_dict:
            raise ValueError(f"Metric {m} not recognized. Available metrics: {list(metric_fn_dict.keys())}")

    metric_scores = {}
    for m in metrics:
        metric_fn = metric_fn_dict[m]
        if iter_batch_size == -1:
            # Compute all at once. Re-binding makes later iterations' .to() a no-op.
            rgb = rgb.to(device)
            rgb_gt = rgb_gt.to(device)
            metric_scores[m] = metric_fn(rgb_gt, rgb)
            continue

        # Batchify to save memory.
        n_images = rgb.shape[0]
        n_batches = (n_images + iter_batch_size - 1) // iter_batch_size
        batch_scores = []
        batch_sizes = []
        for start in tqdm(range(0, n_images, iter_batch_size), disable=n_batches < 20,
                          desc=f"Computing {m} in batches"):
            stop = min(start + iter_batch_size, n_images)
            rgb_batch = rgb[start:stop].to(device)
            rgb_gt_batch = rgb_gt[start:stop].to(device)
            batch_scores.append(metric_fn(rgb_gt_batch, rgb_batch))
            batch_sizes.append(stop - start)
        metric_scores[m] = _weighted_batch_mean(batch_scores, batch_sizes)
    return metric_scores