File size: 5,193 Bytes
3050f1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""PSNR-HVS-M and PSNR-HVS metrics (Ponomarenko et al., 2006/2007).

Direct Python translation of the MATLAB reference implementation at
https://www.ponomarenko.info/psnrhvsm.m

`psnr_hvsm(img1, img2)` returns the tuple (p_hvs_m, p_hvs).
The computation runs on CUDA when available and otherwise falls back to the CPU.
"""

import math

import numpy as np
import torch

# Side length of the square DCT basis built by _make_dct_matrix() and of the
# DC/AC selection mask (_AC_MASK) defined further down in this module.
_N = 8


def _make_dct_matrix() -> torch.Tensor:
    """Build the (_N, _N) orthonormal DCT-II basis as a float64 tensor.

    Row k, column n holds cos(pi * k * (2n + 1) / (2 * _N)), with row 0 scaled
    by 1/sqrt(_N) and every other row scaled by sqrt(2/_N).
    """
    freq = torch.arange(_N, dtype=torch.float64)[:, None]  # column vector of k
    pos = torch.arange(_N, dtype=torch.float64)[None, :]  # row vector of n
    basis = torch.cos(math.pi * freq * (2 * pos + 1) / (2 * _N))

    # Normalise in place: the constant (k = 0) row gets its own factor.
    basis[0] /= math.sqrt(_N)
    basis[1:] *= math.sqrt(2.0 / _N)
    return basis


# DCT basis shared by every call; moved to the compute device inside psnr_hvsm().
_DCT8 = _make_dct_matrix()  # (8, 8), CPU float64

# Per-coefficient weights: psnr_hvsm() multiplies the absolute DCT differences
# by this table before squaring, for both returned metrics.
# NOTE(review): presumably the contrast-sensitivity table of the reference
# implementation — confirm against psnrhvsm.m before editing any value.
_CSF = torch.tensor(
    [
        [1.608443, 2.339554, 2.573509, 1.608443, 1.072295, 0.643377, 0.504610, 0.421887],
        [2.144591, 2.144591, 1.838221, 1.354478, 0.989811, 0.443708, 0.428918, 0.467911],
        [1.838221, 1.979622, 1.608443, 1.072295, 0.643377, 0.451493, 0.372972, 0.459555],
        [1.838221, 1.513829, 1.169777, 0.887417, 0.504610, 0.295806, 0.321689, 0.415082],
        [1.429727, 1.169777, 0.695543, 0.459555, 0.378457, 0.236102, 0.249855, 0.334222],
        [1.072295, 0.735288, 0.467911, 0.402111, 0.317717, 0.247453, 0.227744, 0.279729],
        [0.525206, 0.402111, 0.329937, 0.295806, 0.249855, 0.212687, 0.214459, 0.254803],
        [0.357432, 0.279729, 0.270896, 0.262603, 0.229778, 0.257351, 0.249855, 0.259950],
    ],
    dtype=torch.float64,
)
# Per-coefficient masking weights, used in two places: _maskeff_batch() weights
# the squared AC coefficients with them, and psnr_hvsm() divides the per-block
# mask by them to obtain a threshold for each coefficient.
_MASKCOF = torch.tensor(
    [
        [0.390625, 0.826446, 1.000000, 0.390625, 0.173611, 0.062500, 0.038447, 0.026874],
        [0.694444, 0.694444, 0.510204, 0.277008, 0.147929, 0.029727, 0.027778, 0.033058],
        [0.510204, 0.591716, 0.390625, 0.173611, 0.062500, 0.030779, 0.021004, 0.031888],
        [0.510204, 0.346021, 0.206612, 0.118906, 0.038447, 0.013212, 0.015625, 0.026015],
        [0.308642, 0.206612, 0.073046, 0.031888, 0.021626, 0.008417, 0.009426, 0.016866],
        [0.173611, 0.081633, 0.033058, 0.024414, 0.015242, 0.009246, 0.007831, 0.011815],
        [0.041649, 0.024414, 0.016437, 0.013212, 0.009426, 0.006830, 0.006944, 0.009803],
        [0.019290, 0.011815, 0.011080, 0.010412, 0.007972, 0.010000, 0.009426, 0.010203],
    ],
    dtype=torch.float64,
)

# True everywhere except the DC coefficient at (0, 0)
_AC_MASK = torch.ones((_N, _N), dtype=torch.bool)
_AC_MASK[0, 0] = False


def _vari_batch(blocks: torch.Tensor) -> torch.Tensor:
    """Unbiased variance * N for a batch of blocks. (B, H, W) -> (B,)"""
    flat = blocks.reshape(blocks.shape[0], -1)
    return flat.var(dim=-1, correction=1) * flat.shape[-1]


def _maskeff_batch(blocks: torch.Tensor, dct_blocks: torch.Tensor) -> torch.Tensor:
    """Compute the perceptual masking strength of each 8x8 block.

    Args:
        blocks: Pixel-domain blocks, shape (B, 8, 8).
        dct_blocks: DCT coefficients of the same blocks, shape (B, 8, 8).

    Returns:
        Tensor of shape (B,): sqrt(weighted AC energy * quadrant-variance
        ratio) / 32 for every block.
    """
    device = blocks.device
    ac_sel = _AC_MASK.to(device)
    weights = _MASKCOF.to(device)

    # Energy of the 63 AC coefficients, each weighted by its masking weight.
    ac_energy = dct_blocks[:, ac_sel].pow(2).mul(weights[ac_sel]).sum(dim=-1)

    # Variance measure of the whole block versus the sum over its 4x4 quadrants.
    whole = _vari_batch(blocks)
    top_left = _vari_batch(blocks[:, :4, :4])
    top_right = _vari_batch(blocks[:, :4, 4:])
    bottom_left = _vari_batch(blocks[:, 4:, :4])
    bottom_right = _vari_batch(blocks[:, 4:, 4:])
    quadrants = top_left + top_right + bottom_left + bottom_right

    # Blocks with zero overall variance get a ratio of 0 rather than 0/0.
    ratio = torch.where(whole > 0, quadrants / whole, torch.zeros_like(whole))
    return (ac_energy * ratio).sqrt() / 32.0


def psnr_hvsm(img1: np.ndarray, img2: np.ndarray) -> tuple[float, float]:
    """Return (PSNR-HVS-M, PSNR-HVS) for two equally sized grayscale images.

    Translation of the MATLAB reference (Ponomarenko et al.). Each image is
    converted to float64, cropped to a whole number of 8x8 blocks (partial
    edge blocks are skipped) and transformed block-wise with the orthonormal
    DCT. The PSNR peak value is fixed at 255, so 8-bit intensities are assumed.

    Args:
        img1: 2-D array (height, width) of intensities.
        img2: 2-D array with the same shape as ``img1``.

    Returns:
        ``(p_hvs_m, p_hvs)`` in dB. A metric whose weighted error is exactly
        zero is reported as 100000.0; both values are 100000.0 when the images
        are too small to contain a single full block.

    Raises:
        ValueError: If either input is not 2-D or the two shapes differ.
    """
    arr1 = np.asarray(img1)
    arr2 = np.asarray(img2)
    if arr1.ndim != 2 or arr2.ndim != 2:
        raise ValueError(
            f"expected 2-D grayscale arrays, got shapes {arr1.shape} and {arr2.shape}"
        )
    # Without this check a larger img2 was silently cropped to img1's size and
    # a single-block img2 was broadcast against every block of img1.
    if arr1.shape != arr2.shape:
        raise ValueError(
            f"img1 and img2 must have the same shape, got {arr1.shape} and {arr2.shape}"
        )

    n = _DCT8.shape[0]  # block edge length, taken from the DCT basis
    h = (arr1.shape[0] // n) * n
    w = (arr1.shape[1] // n) * n
    num_blocks = (h // n) * (w // n)
    if num_blocks == 0:
        # No full block to compare, so nothing has to be moved to the device.
        return 100000.0, 100000.0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    D = _DCT8.to(device)
    csf = _CSF.to(device)
    maskcof = _MASKCOF.to(device)
    ac_mask = _AC_MASK.to(device)

    a = torch.from_numpy(arr1[:h, :w].astype(np.float64)).to(device)
    b = torch.from_numpy(arr2[:h, :w].astype(np.float64)).to(device)

    # Non-overlapping blocks: (h, w) -> (h/n, w/n, n, n) -> (B, n, n)
    ba = a.unfold(0, n, n).unfold(1, n, n).reshape(-1, n, n)
    bb = b.unfold(0, n, n).unfold(1, n, n).reshape(-1, n, n)

    # 2D DCT-II (ortho) via separable matrix product: D @ block @ D.T
    da = D @ ba @ D.t()
    db = D @ bb @ D.t()

    # Per block, use the stronger masking of the two images.
    mask = torch.maximum(_maskeff_batch(ba, da), _maskeff_batch(bb, db))  # (B,)

    diff = torch.abs(da - db)  # (B, n, n)

    # PSNR-HVS: CSF-weighted squared error (no masking)
    S2 = float(((diff * csf) ** 2).sum())

    # PSNR-HVS-M: soft-threshold AC coefficients by local mask, keep DC as-is
    thresh = mask[:, None, None] / maskcof[None, :, :]
    u = torch.where(ac_mask[None, :, :], torch.clamp(diff - thresh, min=0.0), diff)
    S1 = float(((u * csf) ** 2).sum())

    # Mean over every coefficient of every block.
    denom = num_blocks * n * n
    S1 /= denom
    S2 /= denom
    p_hvs_m = 100000.0 if S1 == 0 else float(10.0 * np.log10(255.0**2 / S1))
    p_hvs = 100000.0 if S2 == 0 else float(10.0 * np.log10(255.0**2 / S2))
    return p_hvs_m, p_hvs