"""PSNR-HVS-M and PSNR-HVS metrics (Ponomarenko et al., 2006/2007).

Direct Python translation of the MATLAB reference implementation at
https://www.ponomarenko.info/psnrhvsm.m

psnr_hvsm() returns (p_hvs_m, p_hvs) as a tuple of floats.
Uses CUDA if available, otherwise falls back to CPU.
"""
import math
import numpy as np
import torch
# Edge length of the square blocks transformed by the 8x8 DCT below.
_N = 8


def _make_dct_matrix() -> torch.Tensor:
    """Build the N x N orthonormal DCT-II basis, one frequency per row.

    Entry (k, n) is s_k * cos(pi * k * (2n + 1) / (2N)) with s_0 = 1/sqrt(N)
    and s_k = sqrt(2/N) for k > 0, so ``D @ X @ D.T`` is the 2-D orthonormal
    DCT-II of an N x N block ``X``.
    """
    freq = torch.arange(_N, dtype=torch.float64)[:, None]
    pos = torch.arange(_N, dtype=torch.float64)[None, :]
    basis = torch.cos(math.pi * freq * (2 * pos + 1) / (2 * _N))
    # Per-row normalisation: the DC row gets 1/sqrt(N), every other row sqrt(2/N).
    row_scale = torch.full((_N, 1), math.sqrt(2.0 / _N), dtype=torch.float64)
    row_scale[0] = 1.0 / math.sqrt(_N)
    return basis * row_scale


# Built once at import time; callers move it to their working device.
_DCT8 = _make_dct_matrix()  # (8, 8), CPU float64
# Contrast-sensitivity (CSF) weights from the MATLAB reference, one per DCT
# coefficient. Entry [k, l] weights coefficient (k, l) of D @ block @ D.T,
# so [0, 0] is the DC term. Both PSNR-HVS and PSNR-HVS-M multiply each
# coefficient difference by this table before squaring.
_CSF = torch.tensor(
    [
        [1.608443, 2.339554, 2.573509, 1.608443, 1.072295, 0.643377, 0.504610, 0.421887],
        [2.144591, 2.144591, 1.838221, 1.354478, 0.989811, 0.443708, 0.428918, 0.467911],
        [1.838221, 1.979622, 1.608443, 1.072295, 0.643377, 0.451493, 0.372972, 0.459555],
        [1.838221, 1.513829, 1.169777, 0.887417, 0.504610, 0.295806, 0.321689, 0.415082],
        [1.429727, 1.169777, 0.695543, 0.459555, 0.378457, 0.236102, 0.249855, 0.334222],
        [1.072295, 0.735288, 0.467911, 0.402111, 0.317717, 0.247453, 0.227744, 0.279729],
        [0.525206, 0.402111, 0.329937, 0.295806, 0.249855, 0.212687, 0.214459, 0.254803],
        [0.357432, 0.279729, 0.270896, 0.262603, 0.229778, 0.257351, 0.249855, 0.259950],
    ],
    dtype=torch.float64,
)
# Per-coefficient masking coefficients from the MATLAB reference, laid out
# like the CSF table (entry [0, 0] is DC). They weight the AC energy in the
# block masking estimate, and the PSNR-HVS-M threshold for coefficient
# (k, l) is mask / entry[k, l].
# NOTE(review): the entries appear to equal (CSF / CSF.max()) ** 2 rounded
# to 6 decimals (e.g. [0, 2] = 1.0 where the CSF peaks). They are kept as
# literal reference values rather than recomputed.
_MASKCOF = torch.tensor(
    [
        [0.390625, 0.826446, 1.000000, 0.390625, 0.173611, 0.062500, 0.038447, 0.026874],
        [0.694444, 0.694444, 0.510204, 0.277008, 0.147929, 0.029727, 0.027778, 0.033058],
        [0.510204, 0.591716, 0.390625, 0.173611, 0.062500, 0.030779, 0.021004, 0.031888],
        [0.510204, 0.346021, 0.206612, 0.118906, 0.038447, 0.013212, 0.015625, 0.026015],
        [0.308642, 0.206612, 0.073046, 0.031888, 0.021626, 0.008417, 0.009426, 0.016866],
        [0.173611, 0.081633, 0.033058, 0.024414, 0.015242, 0.009246, 0.007831, 0.011815],
        [0.041649, 0.024414, 0.016437, 0.013212, 0.009426, 0.006830, 0.006944, 0.009803],
        [0.019290, 0.011815, 0.011080, 0.010412, 0.007972, 0.010000, 0.009426, 0.010203],
    ],
    dtype=torch.float64,
)
# Boolean (N, N) selector of the AC coefficients: True at every position
# except the DC term at (0, 0), which is the only flat index equal to 0.
_AC_MASK = torch.arange(_N * _N).reshape(_N, _N) != 0
def _vari_batch(blocks: torch.Tensor) -> torch.Tensor:
    """Per-block sample variance (N-1 denominator) multiplied by the pixel count.

    Batched form of the reference's ``vari``: ``var(x(:)) * numel(x)``.

    Args:
        blocks: Stack of 2-D blocks, shape (B, H, W).

    Returns:
        One value per block, shape (B,).
    """
    pixels = blocks.reshape(len(blocks), -1)
    count = pixels.size(1)
    return torch.var(pixels, dim=1, correction=1) * count
def _maskeff_batch(blocks: torch.Tensor, dct_blocks: torch.Tensor) -> torch.Tensor:
    """Estimate the contrast-masking strength of each 8x8 block.

    This batches the reference's ``maskeff``. The AC energy of each block is
    weighted by ``_MASKCOF`` and scaled by the ratio of the summed quadrant
    variances to the whole-block variance. The ratio is 0 for a block with
    zero variance. The result is ``sqrt(energy * ratio) / 32``.

    Args:
        blocks: Pixel blocks, shape (B, 8, 8).
        dct_blocks: 2-D DCT of ``blocks``, shape (B, 8, 8).

    Returns:
        Masking strength per block, shape (B,).
    """
    device = blocks.device
    is_ac = _AC_MASK.to(device)
    ac_weights = _MASKCOF.to(device)[is_ac]
    ac_energy = (dct_blocks[:, is_ac] ** 2 * ac_weights).sum(dim=-1)  # (B,)

    total_var = _vari_batch(blocks)
    top, bottom = slice(None, 4), slice(4, None)
    left, right = slice(None, 4), slice(4, None)
    # Quadrants summed in the order TL, TR, BL, BR.
    quadrant_var = sum(
        _vari_batch(blocks[:, rows, cols])
        for rows in (top, bottom)
        for cols in (left, right)
    )
    # Flat blocks have zero variance: use a ratio of 0 instead of dividing by it.
    ratio = torch.where(
        total_var > 0, quadrant_var / total_var, torch.zeros_like(total_var)
    )
    return (ac_energy * ratio).sqrt() / 32.0
def psnr_hvsm(
    img1: np.ndarray, img2: np.ndarray, *, step: int = 8
) -> tuple[float, float]:
    """Return (PSNR-HVS-M, PSNR-HVS) for two grayscale images.

    Translation of the MATLAB reference (Ponomarenko et al.). Every complete
    8x8 window of both images is transformed with an orthonormal 2-D DCT-II.

    - PSNR-HVS sums the CSF-weighted squared coefficient differences.
    - PSNR-HVS-M first shrinks each AC difference towards zero by the
      stronger of the two windows' masking estimates. The DC difference is
      kept as-is.

    Args:
        img1: Reference image as a 2-D array. The peak value is fixed at 255,
            so 8-bit data is expected.
        img2: Distorted image with the same shape as ``img1``.
        step: Window stride in pixels (the reference's optional ``wstep``).
            The default of 8 tiles the image with non-overlapping blocks.
            Smaller strides use overlapping windows and cost roughly
            ``(8 / step) ** 2`` times the compute and memory.

    Returns:
        ``(p_hvs_m, p_hvs)`` in dB. A value is 100000.0 when its error term
        is exactly zero. Both are 100000.0 when the image is smaller than
        8x8, since no complete window fits. Trailing rows and columns that
        cannot hold a full window are ignored.

    Raises:
        ValueError: If the images differ in shape, are not 2-D, or ``step``
            is less than 1.
    """
    if img1.shape != img2.shape:
        # Comparing differently sized images would pair misaligned windows.
        # The MATLAB reference also rejects this case (it returns -Inf).
        raise ValueError(
            f"images must have the same shape, got {img1.shape} and {img2.shape}"
        )
    if img1.ndim != 2:
        raise ValueError(f"expected 2-D grayscale images, got shape {img1.shape}")
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")

    h, w = img1.shape
    if h < _N or w < _N:
        return 100000.0, 100000.0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    D = _DCT8.to(device)
    csf = _CSF.to(device)
    maskcof = _MASKCOF.to(device)
    ac_mask = _AC_MASK.to(device)

    a = torch.from_numpy(img1.astype(np.float64)).to(device)
    b = torch.from_numpy(img2.astype(np.float64)).to(device)

    # Every complete window whose top-left corner lies on the `step` grid, in
    # row-major order: (B, 8, 8). unfold never emits a partial window, which
    # matches the reference loop bounds (start <= size - 8).
    ba = a.unfold(0, _N, step).unfold(1, _N, step).contiguous().reshape(-1, _N, _N)
    bb = b.unfold(0, _N, step).unfold(1, _N, step).contiguous().reshape(-1, _N, _N)
    num_blocks = ba.shape[0]

    # 2-D orthonormal DCT-II of every window via the separable product D @ X @ D.T.
    da = D @ ba @ D.t()
    db = D @ bb @ D.t()

    # Masking strength per window: the larger of the two images' estimates.
    mask = torch.maximum(_maskeff_batch(ba, da), _maskeff_batch(bb, db))  # (B,)
    diff = torch.abs(da - db)  # (B, 8, 8)

    # PSNR-HVS: CSF-weighted squared error, no masking.
    s2 = float(((diff * csf) ** 2).sum())

    # PSNR-HVS-M: soft-threshold AC differences by mask / MaskCof; DC untouched.
    thresh = mask[:, None, None] / maskcof[None, :, :]
    u = torch.where(ac_mask, torch.clamp(diff - thresh, min=0.0), diff)
    s1 = float(((u * csf) ** 2).sum())

    # Mean over every coefficient of every window (the reference's Num).
    denom = num_blocks * _N * _N
    s1 /= denom
    s2 /= denom
    p_hvs_m = 100000.0 if s1 == 0 else float(10.0 * np.log10(255.0**2 / s1))
    p_hvs = 100000.0 if s2 == 0 else float(10.0 * np.log10(255.0**2 / s2))
    return p_hvs_m, p_hvs
|