"""PSNR-HVS-M and PSNR-HVS metrics (Ponomarenko et al., 2006/2007).

Direct Python translation of the MATLAB reference implementation at
https://www.ponomarenko.info/psnrhvsm.m, vectorised over all 8x8 blocks with
PyTorch. ``psnr_hvsm`` returns ``(p_hvs_m, p_hvs)`` as a tuple. Uses CUDA if
available, otherwise falls back to CPU.
"""

import math

import numpy as np
import torch

_N = 8

# Value the reference reports when the mean weighted error is exactly zero
# ("img1 and img2 are visually undistinguished").
_INF_PSNR = 100000.0


def _make_dct_matrix() -> torch.Tensor:
    """Return the 8x8 orthonormal DCT-II matrix (the basis MATLAB ``dct2`` uses).

    D[0, n] = 1/sqrt(N) and D[k>0, n] = sqrt(2/N) * cos(pi * k * (2n + 1) / (2N)).
    """
    k = torch.arange(_N, dtype=torch.float64).unsqueeze(1)
    n = torch.arange(_N, dtype=torch.float64).unsqueeze(0)
    D = torch.cos(math.pi * k * (2 * n + 1) / (2 * _N))
    D[0] = D[0] / math.sqrt(_N)
    D[1:] = D[1:] * math.sqrt(2.0 / _N)
    return D


_DCT8 = _make_dct_matrix()  # (8, 8), CPU float64

# Per-coefficient contrast-sensitivity weights (CSFCof in the reference).
_CSF = torch.tensor(
    [
        [1.608443, 2.339554, 2.573509, 1.608443, 1.072295, 0.643377, 0.504610, 0.421887],
        [2.144591, 2.144591, 1.838221, 1.354478, 0.989811, 0.443708, 0.428918, 0.467911],
        [1.838221, 1.979622, 1.608443, 1.072295, 0.643377, 0.451493, 0.372972, 0.459555],
        [1.838221, 1.513829, 1.169777, 0.887417, 0.504610, 0.295806, 0.321689, 0.415082],
        [1.429727, 1.169777, 0.695543, 0.459555, 0.378457, 0.236102, 0.249855, 0.334222],
        [1.072295, 0.735288, 0.467911, 0.402111, 0.317717, 0.247453, 0.227744, 0.279729],
        [0.525206, 0.402111, 0.329937, 0.295806, 0.249855, 0.212687, 0.214459, 0.254803],
        [0.357432, 0.279729, 0.270896, 0.262603, 0.229778, 0.257351, 0.249855, 0.259950],
    ],
    dtype=torch.float64,
)

# Per-coefficient masking weights (MaskCof in the reference).
_MASKCOF = torch.tensor(
    [
        [0.390625, 0.826446, 1.000000, 0.390625, 0.173611, 0.062500, 0.038447, 0.026874],
        [0.694444, 0.694444, 0.510204, 0.277008, 0.147929, 0.029727, 0.027778, 0.033058],
        [0.510204, 0.591716, 0.390625, 0.173611, 0.062500, 0.030779, 0.021004, 0.031888],
        [0.510204, 0.346021, 0.206612, 0.118906, 0.038447, 0.013212, 0.015625, 0.026015],
        [0.308642, 0.206612, 0.073046, 0.031888, 0.021626, 0.008417, 0.009426, 0.016866],
        [0.173611, 0.081633, 0.033058, 0.024414, 0.015242, 0.009246, 0.007831, 0.011815],
        [0.041649, 0.024414, 0.016437, 0.013212, 0.009426, 0.006830, 0.006944, 0.009803],
        [0.019290, 0.011815, 0.011080, 0.010412, 0.007972, 0.010000, 0.009426, 0.010203],
    ],
    dtype=torch.float64,
)

# True everywhere except the DC coefficient at (0, 0)
_AC_MASK = torch.ones((_N, _N), dtype=torch.bool)
_AC_MASK[0, 0] = False


def _vari_batch(blocks: torch.Tensor) -> torch.Tensor:
    """Unbiased variance times element count for a batch of blocks.

    Mirrors the reference ``vari(AA) = var(AA(:)) * length(AA(:))``.
    (B, H, W) -> (B,)
    """
    flat = blocks.reshape(blocks.shape[0], -1)
    return flat.var(dim=-1, correction=1) * flat.shape[-1]


def _maskeff_batch(blocks: torch.Tensor, dct_blocks: torch.Tensor) -> torch.Tensor:
    """Masking strength for a batch of 8x8 blocks (the reference ``maskeff``).

    Args:
        blocks: (B, 8, 8) pixel blocks.
        dct_blocks: (B, 8, 8) orthonormal 2-D DCT of ``blocks``.

    Returns:
        (B,) tensor equal to ``sqrt(m * pop_ratio) / 32``. ``m`` is the
        MaskCof-weighted AC energy, with DC excluded. ``pop_ratio`` is the sum
        of the four 4x4 quadrant variances divided by the whole-block variance;
        it is 0 for a flat block, as in the reference.
    """
    dev = blocks.device
    ac = _AC_MASK.to(dev)
    mc = _MASKCOF.to(dev)
    m = (dct_blocks[:, ac] ** 2 * mc[ac]).sum(dim=-1)  # (B,)
    pop = _vari_batch(blocks)
    quad = (
        _vari_batch(blocks[:, :4, :4])
        + _vari_batch(blocks[:, :4, 4:])
        + _vari_batch(blocks[:, 4:, :4])
        + _vari_batch(blocks[:, 4:, 4:])
    )
    # The reference divides only when pop != 0 and otherwise keeps pop = 0.
    # A unit denominator in the discarded branch avoids evaluating 0/0 there.
    nonflat = pop > 0
    safe_pop = torch.where(nonflat, pop, torch.ones_like(pop))
    pop_ratio = torch.where(nonflat, quad / safe_pop, torch.zeros_like(pop))
    return torch.sqrt(m * pop_ratio) / 32.0


def _extract_blocks(img: torch.Tensor, step: int) -> torch.Tensor:
    """Gather every 8x8 block whose top-left corner lies on a ``step`` grid.

    (H, W) -> (B, 8, 8). Blocks are ordered row-major over block positions.
    Windows that would cross the bottom/right border are dropped, matching the
    reference loop bounds ``Y <= LenY-7`` and ``X <= LenX-7``.
    """
    return img.unfold(0, _N, step).unfold(1, _N, step).reshape(-1, _N, _N)


def _to_db(mse: float) -> float:
    """Convert a mean weighted squared error to dB for an 8-bit (255) peak."""
    return _INF_PSNR if mse == 0 else 10.0 * math.log10(255.0**2 / mse)


def psnr_hvsm(img1: np.ndarray, img2: np.ndarray, step: int = 8) -> tuple[float, float]:
    """Return (PSNR-HVS-M, PSNR-HVS) for two grayscale images on a 0..255 scale.

    Direct translation of the MATLAB reference (Ponomarenko et al.). An 8x8
    block is taken every ``step`` pixels, starting at the top-left corner.
    Blocks that would extend past the image border are skipped, so partial edge
    blocks never contribute.

    Args:
        img1: First image, a 2-D array or array-like.
        img2: Second image, same shape as ``img1``.
        step: Block stride in pixels (``wstep`` in the reference). The default
            of 8 gives non-overlapping blocks; smaller values give overlapping
            blocks.

    Returns:
        ``(p_hvs_m, p_hvs)`` in dB, with the peak value 255 hard-coded as in
        the reference. A metric whose mean error is exactly zero yields
        100000.0. An image smaller than 8x8, which contains no blocks, also
        yields 100000.0 for both metrics.

    Raises:
        ValueError: If the images are not 2-D, their shapes differ, or ``step``
            is not a positive integer.
    """
    # Contiguous float64 copies/views: this accepts array-likes, and
    # torch.from_numpy rejects negative strides.
    a_np = np.ascontiguousarray(img1, dtype=np.float64)
    b_np = np.ascontiguousarray(img2, dtype=np.float64)
    if a_np.ndim != 2:
        raise ValueError(f"expected 2-D grayscale images, got shape {a_np.shape}")
    if a_np.shape != b_np.shape:
        raise ValueError(f"image shapes differ: {a_np.shape} vs {b_np.shape}")
    if not isinstance(step, (int, np.integer)) or step < 1:
        raise ValueError(f"step must be a positive integer, got {step!r}")
    step = int(step)

    h, w = a_np.shape
    if h < _N or w < _N:
        # No complete block fits; kept as "indistinguishable" for compatibility.
        return _INF_PSNR, _INF_PSNR

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    D = _DCT8.to(device)
    csf = _CSF.to(device)
    maskcof = _MASKCOF.to(device)
    ac_mask = _AC_MASK.to(device)

    ba = _extract_blocks(torch.from_numpy(a_np).to(device), step)  # (B, 8, 8)
    bb = _extract_blocks(torch.from_numpy(b_np).to(device), step)
    num_blocks = ba.shape[0]

    # 2-D orthonormal DCT-II of every block via separable products: D @ X @ D^T.
    da = D @ ba @ D.T
    db = D @ bb @ D.T

    # The local mask is the stronger of the two images' masking levels.
    mask = torch.maximum(_maskeff_batch(ba, da), _maskeff_batch(bb, db))  # (B,)
    diff = torch.abs(da - db)  # (B, 8, 8)

    # PSNR-HVS: CSF-weighted squared error (no masking).
    s2 = float(((diff * csf) ** 2).sum())

    # PSNR-HVS-M: soft-threshold AC coefficients by the local mask; DC kept as-is.
    thresh = mask[:, None, None] / maskcof
    u = torch.where(ac_mask, torch.clamp(diff - thresh, min=0.0), diff)
    s1 = float(((u * csf) ** 2).sum())

    # The reference's Num counts every coefficient of every block.
    denom = num_blocks * _N * _N
    return _to_db(s1 / denom), _to_db(s2 / denom)