# NOTE(review): the three lines below are web-page residue (Hugging Face file
# header) that leaked into the source and broke the Python syntax; kept as
# comments to preserve the provenance information.
# WCNegentropy's picture
# 📚 Updated with scientifically rigorous documentation
# dc2b9f3 verified
from __future__ import annotations
"""Utility metrics for WrinkleBrane.
This module collects small numerical helpers used throughout the
project. The functions are intentionally lightweight wrappers around
well known libraries so that they remain easy to test and reason about.
"""
from typing import Sequence
import gzip
import math
import numpy as np
import torch
from skimage.metrics import (
mean_squared_error as sk_mse,
peak_signal_noise_ratio as sk_psnr,
structural_similarity as sk_ssim,
)
# ---------------------------------------------------------------------------
# Basic fidelity metrics
# ---------------------------------------------------------------------------
def mse(A: np.ndarray, B: np.ndarray) -> float:
    """Mean squared error between ``A`` and ``B``.

    Delegates to :func:`skimage.metrics.mean_squared_error`; the wrapper
    exists so the whole project imports its fidelity metrics from a single
    module.
    """
    error = sk_mse(A, B)
    return float(error)
def psnr(A: np.ndarray, B: np.ndarray, data_range: float = 1.0) -> float:
    """Peak signal-to-noise ratio between ``A`` and ``B``.

    ``data_range`` is forwarded unchanged to
    :func:`skimage.metrics.peak_signal_noise_ratio` and should match the
    dynamic range of the inputs (defaults to ``1.0``).
    """
    ratio = sk_psnr(A, B, data_range=data_range)
    return float(ratio)
def ssim(A: np.ndarray, B: np.ndarray, data_range: float | None = None) -> float:
    """Return the structural similarity index between ``A`` and ``B``.

    Parameters
    ----------
    A, B:
        Images to compare.
    data_range:
        Dynamic range of the data, forwarded to
        :func:`skimage.metrics.structural_similarity` (mirroring the
        ``data_range`` parameter of :func:`psnr`). When ``None`` (the
        default) it is estimated as ``max(A) - min(A)``, falling back to
        ``1`` for a constant image — this preserves the module's historical
        behaviour and keeps skimage from raising on float inputs.
    """
    if data_range is None:
        # A constant image has zero range; substitute 1 so skimage does
        # not divide by zero.
        data_range = float(np.max(A) - np.min(A)) or 1.0
    return float(sk_ssim(A, B, data_range=data_range))
# ---------------------------------------------------------------------------
# Information theoretic helpers
# ---------------------------------------------------------------------------
def spectral_entropy_2d(img: torch.Tensor) -> float:
    """Return the normalised spectral entropy of a 2-D image.

    The power spectrum of the two-dimensional FFT is normalised into a
    discrete probability distribution ``p`` and the Shannon entropy
    ``H(p)`` is computed. The result is divided by ``log(N)`` (``N`` =
    number of frequency bins) so the value lies in ``[0, 1]``: ``0`` means
    all energy sits in a single frequency, ``1`` means a flat spectrum.

    Parameters
    ----------
    img:
        A 2-D tensor; it is converted to ``float32`` before the FFT.

    Returns
    -------
    float
        Normalised entropy in ``[0, 1]``. ``0.0`` is returned for an
        all-zero image (no spectral energy) and for a single-pixel image,
        where the ``log(1) == 0`` normaliser would otherwise yield NaN.

    Raises
    ------
    ValueError
        If ``img`` is not two dimensional.
    """
    if img.ndim != 2:
        raise ValueError("expected a 2-D image tensor")
    F = torch.fft.fft2(img.to(torch.float32))
    power = torch.abs(F) ** 2
    flat = power.flatten()
    n_bins = flat.numel()
    # A single frequency bin carries zero entropy; guard against the
    # log(1) == 0 normaliser below, which would produce 0/0 = NaN.
    if n_bins <= 1:
        return 0.0
    total = flat.sum()
    if total <= 0:
        # All-zero image: there is no spectral energy to distribute.
        return 0.0
    p = flat / total
    eps = torch.finfo(p.dtype).eps
    # clamp_min avoids log(0); bins with p == 0 contribute 0 regardless.
    entropy = -torch.sum(p * torch.log(p.clamp_min(eps)))
    entropy /= math.log(n_bins)
    return float(entropy)
def gzip_ratio(tensor: torch.Tensor) -> float:
    """Return the gzip compression ratio of ``tensor``.

    The tensor is min–max normalised to ``[0, 255]`` and cast to ``uint8``
    before being compressed with :func:`gzip.compress`. The ratio
    ``len(compressed) / len(raw)`` is returned; lower values indicate a
    more compressible (less complex) tensor.

    Parameters
    ----------
    tensor:
        Any tensor. It is detached, moved to CPU and *copied*, so the
        caller's data is never modified.

    Returns
    -------
    float
        Compression ratio, or ``0.0`` for an empty tensor.
    """
    # .clone() is essential: for a float32 CPU tensor, .detach()/.cpu()/
    # .float() all return views sharing storage with the input, and the
    # in-place ops below would otherwise clobber the caller's data.
    arr = tensor.detach().cpu().float().clone()
    # min()/max() raise on empty tensors, so bail out before normalising.
    if arr.numel() == 0:
        return 0.0
    arr -= arr.min()
    maxv = arr.max()
    if maxv > 0:
        arr /= maxv
    arr = (arr * 255).round().clamp(0, 255).to(torch.uint8)
    raw = arr.numpy().tobytes()
    comp = gzip.compress(raw)
    return float(len(comp) / len(raw))
def interference_index(
    Y: torch.Tensor, keys: torch.Tensor, values: torch.Tensor
) -> float:
    """Return the RMS amplitude at channels that do not match ``keys``.

    The expected output is zero everywhere except at each batch item's
    correct channel, so any signal on the remaining channels is pure
    interference; its root mean square is returned.

    Parameters
    ----------
    Y:
        Retrieved tensor of shape ``[B, K, H, W]``.
    keys:
        Index of the correct channel for each batch item, shape ``[B]``.
    values:
        Ground-truth values of shape ``[B, H, W]``. Used only for shape
        validation: the expected off-key output is zero regardless of the
        stored value, so the result does not depend on ``values``.

    Returns
    -------
    float
        RMS over all off-key entries, or ``0.0`` when ``K == 1`` and no
        off-key channel exists.

    Raises
    ------
    ValueError
        If any input has an unexpected shape.
    """
    if Y.ndim != 4:
        raise ValueError("Y must have shape [B,K,H,W]")
    B, K, H, W = Y.shape
    if keys.shape != (B,):
        raise ValueError("keys must have shape [B]")
    if values.shape != (B, H, W):
        raise ValueError("values must have shape [B,H,W]")
    # True everywhere except each item's correct channel.
    mask = torch.ones_like(Y, dtype=torch.bool)
    mask[torch.arange(B), keys] = False
    # The target is zero at every off-key position, so the error there is
    # simply Y itself — no need to materialise the full target tensor.
    off_key = Y[mask]
    # With K == 1 there are no off-key channels; the mean of an empty
    # tensor would be NaN, so report zero interference instead.
    if off_key.numel() == 0:
        return 0.0
    return float(torch.sqrt((off_key ** 2).mean()))
def symbiosis(
    fidelity_scores: Sequence[float],
    orthogonality_scores: Sequence[float],
    energy_scores: Sequence[float],
    K_scores: Sequence[float],
    C_scores: Sequence[float],
) -> float:
    """Return a composite of Pearson correlations against fidelity.

    For equal-length score arrays ``F`` (fidelity), ``O`` (orthogonality),
    ``E`` (energy), ``K`` (spectral entropy) and ``C`` (gzip ratio) the
    metric is::

        S = mean([corr(F, O), corr(F, E), corr(F, K), corr(F, C)])

    where ``corr`` is the sample Pearson correlation coefficient. A
    correlation involving a constant series is treated as zero.

    Raises
    ------
    ValueError
        If the sequences are empty or their lengths differ.
    """
    fidelity = np.asarray(fidelity_scores, dtype=float)
    others = [
        np.asarray(scores, dtype=float)
        for scores in (orthogonality_scores, energy_scores, K_scores, C_scores)
    ]
    n = len(fidelity)
    if n == 0 or any(len(arr) != n for arr in others):
        raise ValueError("all score sequences must have the same non-zero length")

    def _pearson(a: np.ndarray, b: np.ndarray) -> float:
        # A constant series has zero variance, which would make the
        # coefficient undefined; define it as 0 instead.
        if np.allclose(a, a[0]) or np.allclose(b, b[0]):
            return 0.0
        return float(np.corrcoef(a, b)[0, 1])

    return float(np.mean([_pearson(fidelity, arr) for arr in others]))