# Source: NisabaRelief / dev_scripts / util / psnr_hvsm.py
# (initial release by boatbomber, commit 3050f1b)
"""PSNR-HVS-M and PSNR-HVS metrics (Ponomarenko et al., 2006/2007).
Direct Python translation of the MATLAB reference implementation at
https://www.ponomarenko.info/psnrhvsm.m
Returns (p_hvs_m, p_hvs) as a tuple.
Uses CUDA if available, otherwise falls back to CPU.
"""
import math
import numpy as np
import torch
_N = 8
def _make_dct_matrix() -> torch.Tensor:
"""8x8 orthonormal DCT-II matrix: D[0,n]=1/√N, D[k>0,n]=√(2/N)·cos(π·k·(2n+1)/(2N))."""
k = torch.arange(_N, dtype=torch.float64).unsqueeze(1)
n = torch.arange(_N, dtype=torch.float64).unsqueeze(0)
D = torch.cos(math.pi * k * (2 * n + 1) / (2 * _N))
D[0] = D[0] / math.sqrt(_N)
D[1:] = D[1:] * math.sqrt(2.0 / _N)
return D
_DCT8 = _make_dct_matrix()  # (8, 8), CPU float64
# Contrast-sensitivity weights ("CSFCof" in the MATLAB reference): per-DCT-
# coefficient perceptual importance, used to weight coefficient differences.
_CSF = torch.tensor(
    [
        [1.608443, 2.339554, 2.573509, 1.608443, 1.072295, 0.643377, 0.504610, 0.421887],
        [2.144591, 2.144591, 1.838221, 1.354478, 0.989811, 0.443708, 0.428918, 0.467911],
        [1.838221, 1.979622, 1.608443, 1.072295, 0.643377, 0.451493, 0.372972, 0.459555],
        [1.838221, 1.513829, 1.169777, 0.887417, 0.504610, 0.295806, 0.321689, 0.415082],
        [1.429727, 1.169777, 0.695543, 0.459555, 0.378457, 0.236102, 0.249855, 0.334222],
        [1.072295, 0.735288, 0.467911, 0.402111, 0.317717, 0.247453, 0.227744, 0.279729],
        [0.525206, 0.402111, 0.329937, 0.295806, 0.249855, 0.212687, 0.214459, 0.254803],
        [0.357432, 0.279729, 0.270896, 0.262603, 0.229778, 0.257351, 0.249855, 0.259950],
    ],
    dtype=torch.float64,
)
# Masking weights ("mask_coef" in the MATLAB reference): per-DCT-coefficient
# weights used when accumulating a block's self-masking energy, and as the
# divisor when thresholding coefficient differences in psnr_hvsm().
_MASKCOF = torch.tensor(
    [
        [0.390625, 0.826446, 1.000000, 0.390625, 0.173611, 0.062500, 0.038447, 0.026874],
        [0.694444, 0.694444, 0.510204, 0.277008, 0.147929, 0.029727, 0.027778, 0.033058],
        [0.510204, 0.591716, 0.390625, 0.173611, 0.062500, 0.030779, 0.021004, 0.031888],
        [0.510204, 0.346021, 0.206612, 0.118906, 0.038447, 0.013212, 0.015625, 0.026015],
        [0.308642, 0.206612, 0.073046, 0.031888, 0.021626, 0.008417, 0.009426, 0.016866],
        [0.173611, 0.081633, 0.033058, 0.024414, 0.015242, 0.009246, 0.007831, 0.011815],
        [0.041649, 0.024414, 0.016437, 0.013212, 0.009426, 0.006830, 0.006944, 0.009803],
        [0.019290, 0.011815, 0.011080, 0.010412, 0.007972, 0.010000, 0.009426, 0.010203],
    ],
    dtype=torch.float64,
)
# True everywhere except the DC coefficient at (0, 0)
_AC_MASK = torch.ones((_N, _N), dtype=torch.bool)
_AC_MASK[0, 0] = False
def _vari_batch(blocks: torch.Tensor) -> torch.Tensor:
"""Unbiased variance * N for a batch of blocks. (B, H, W) -> (B,)"""
flat = blocks.reshape(blocks.shape[0], -1)
return flat.var(dim=-1, correction=1) * flat.shape[-1]
def _maskeff_batch(blocks: torch.Tensor, dct_blocks: torch.Tensor) -> torch.Tensor:
    """Perceptual masking strength for a batch of 8x8 blocks. Returns (B,)."""
    dev = blocks.device
    ac = _AC_MASK.to(dev)
    weights = _MASKCOF.to(dev)
    # Masking energy: mask_coef-weighted sum of squared AC coefficients.
    energy = (dct_blocks[:, ac].square() * weights[ac]).sum(dim=-1)  # (B,)
    whole_var = _vari_batch(blocks)
    # Sum of the four 4x4 quadrant variances of each block.
    lo, hi = slice(0, 4), slice(4, 8)
    quad_var = (
        _vari_batch(blocks[:, lo, lo])
        + _vari_batch(blocks[:, lo, hi])
        + _vari_batch(blocks[:, hi, lo])
        + _vari_batch(blocks[:, hi, hi])
    )
    # Guard the perfectly flat block: zero variance -> zero masking.
    ratio = torch.where(whole_var > 0, quad_var / whole_var, torch.zeros_like(whole_var))
    return torch.sqrt(energy * ratio) / 32.0
def psnr_hvsm(img1: np.ndarray, img2: np.ndarray) -> tuple[float, float]:
    """Return (PSNR-HVS-M, PSNR-HVS) in dB for two grayscale images.

    Direct translation of the MATLAB reference (Ponomarenko et al.,
    https://www.ponomarenko.info/psnrhvsm.m). Identical images yield
    100000.0, the reference's sentinel for "infinite" PSNR.

    Args:
        img1: 2-D array of pixel intensities in [0, 255] (typically uint8).
        img2: 2-D array of the same shape as img1.

    Returns:
        Tuple (p_hvs_m, p_hvs) of floats.

    Raises:
        ValueError: if either input is not 2-D or their shapes differ.

    Notes:
        Partial edge blocks are skipped (the images are truncated to the
        nearest multiple of 8), matching the reference implementation.
        Runs on CUDA when available, otherwise on CPU.
    """
    # Without this check a larger img2 would be silently cropped to img1's
    # size (wrong answer, no error); other mismatches died with a cryptic
    # RuntimeError from unfold/matmul.
    if img1.ndim != 2 or img2.ndim != 2:
        raise ValueError(f"expected 2-D grayscale arrays, got {img1.ndim}-D and {img2.ndim}-D")
    if img1.shape != img2.shape:
        raise ValueError(f"shape mismatch: {img1.shape} vs {img2.shape}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    D = _DCT8.to(device)
    csf = _CSF.to(device)
    maskcof = _MASKCOF.to(device)
    ac_mask = _AC_MASK.to(device)
    a = torch.from_numpy(img1.astype(np.float64)).to(device)
    b = torch.from_numpy(img2.astype(np.float64)).to(device)
    h, w = a.shape
    # Truncate to whole 8x8 blocks; partial edge blocks are skipped.
    h = (h // 8) * 8
    w = (w // 8) * 8
    a = a[:h, :w]
    b = b[:h, :w]
    num_blocks = (h // 8) * (w // 8)
    if num_blocks == 0:
        # Image smaller than one block: nothing measurable, return sentinel.
        return 100000.0, 100000.0
    # Extract all non-overlapping 8x8 blocks: (B, 8, 8)
    ba = a.unfold(0, 8, 8).unfold(1, 8, 8).contiguous().reshape(-1, 8, 8)
    bb = b.unfold(0, 8, 8).unfold(1, 8, 8).contiguous().reshape(-1, 8, 8)
    # 2D DCT-II (ortho) via separable matrix product: D @ block @ D.T
    da = D @ ba @ D.t()
    db = D @ bb @ D.t()
    # Per-block masking: the stronger of the two images' masking effects.
    mask = torch.maximum(_maskeff_batch(ba, da), _maskeff_batch(bb, db))  # (B,)
    diff = torch.abs(da - db)  # (B, 8, 8)
    # PSNR-HVS: CSF-weighted squared error (no masking)
    S2 = float(((diff * csf) ** 2).sum())
    # PSNR-HVS-M: soft-threshold AC coefficients by local mask, keep DC as-is
    thresh = mask[:, None, None] / maskcof[None, :, :]
    u = torch.where(ac_mask[None, :, :], torch.clamp(diff - thresh, min=0.0), diff)
    S1 = float(((u * csf) ** 2).sum())
    denom = num_blocks * 64
    S1 /= denom
    S2 /= denom
    p_hvs_m = 100000.0 if S1 == 0 else float(10.0 * np.log10(255.0**2 / S1))
    p_hvs = 100000.0 if S2 == 0 else float(10.0 * np.log10(255.0**2 / S2))
    return p_hvs_m, p_hvs