"""Utility metrics for WrinkleBrane.

This module collects small numerical helpers used throughout the project.
The functions are intentionally lightweight wrappers around well known
libraries so that they remain easy to test and reason about.
"""

from __future__ import annotations

import gzip
import math
from typing import Sequence

import numpy as np
import torch

# scikit-image is only needed by the three thin wrappers below.  Treat it as
# an optional dependency so the torch/numpy helpers in this module remain
# importable without it; the wrappers raise a clear error on first use.
try:
    from skimage.metrics import (
        mean_squared_error as sk_mse,
        peak_signal_noise_ratio as sk_psnr,
        structural_similarity as sk_ssim,
    )
except ImportError:  # pragma: no cover - depends on the environment
    sk_mse = sk_psnr = sk_ssim = None  # type: ignore[assignment]


def _require_skimage(fn) -> None:
    """Raise ``ImportError`` if the scikit-image backed ``fn`` is unavailable."""
    if fn is None:
        raise ImportError("scikit-image is required for this metric")


# ---------------------------------------------------------------------------
# Basic fidelity metrics
# ---------------------------------------------------------------------------


def mse(A: np.ndarray, B: np.ndarray) -> float:
    """Return the mean squared error between ``A`` and ``B``.

    This is a thin wrapper around :func:`skimage.metrics.mean_squared_error`
    so that the project has a single place from which to import it.
    """
    _require_skimage(sk_mse)
    return float(sk_mse(A, B))


def psnr(A: np.ndarray, B: np.ndarray, data_range: float = 1.0) -> float:
    """Return the peak signal to noise ratio between ``A`` and ``B``."""
    _require_skimage(sk_psnr)
    return float(sk_psnr(A, B, data_range=data_range))


def ssim(A: np.ndarray, B: np.ndarray) -> float:
    """Return the structural similarity index between ``A`` and ``B``."""
    _require_skimage(sk_ssim)
    # ``or 1`` guards against a zero data range when ``A`` is constant.
    return float(sk_ssim(A, B, data_range=float(np.max(A) - np.min(A) or 1)))


# ---------------------------------------------------------------------------
# Information theoretic helpers
# ---------------------------------------------------------------------------


def spectral_entropy_2d(img: torch.Tensor) -> float:
    """Return the spectral entropy of a 2-D image.

    The entropy is computed over the power spectrum of the two dimensional
    FFT.  The power is normalised to form a discrete probability
    distribution ``p`` and the Shannon entropy ``H(p)`` is returned.  The
    result is further normalised by ``log(N)`` (``N`` = number of
    frequencies) so that the value lies in ``[0, 1]``.

    Raises
    ------
    ValueError
        If ``img`` is not two dimensional.
    """
    if img.ndim != 2:
        raise ValueError("expected a 2-D image tensor")
    F = torch.fft.fft2(img.to(torch.float32))
    power = torch.abs(F) ** 2
    flat = power.flatten()
    # A single frequency carries no entropy; this also avoids dividing by
    # ``log(1) == 0`` in the normalisation below.
    if flat.numel() <= 1:
        return 0.0
    total = flat.sum()
    if total <= 0:
        return 0.0
    p = flat / total
    # Clamp before the log so that zero-probability bins contribute exactly
    # zero (0 * log(eps)) instead of NaN (0 * log(0)).
    eps = torch.finfo(p.dtype).eps
    entropy = -torch.sum(p * torch.log(p.clamp_min(eps)))
    entropy /= math.log(flat.numel())
    return float(entropy)


def gzip_ratio(tensor: torch.Tensor) -> float:
    """Return the gzip compression ratio of ``tensor``.

    The tensor is min–max normalised to ``[0, 255]`` and cast to ``uint8``
    before being compressed with :func:`gzip.compress`.  The ratio between
    compressed and raw byte lengths is returned.  Lower values therefore
    indicate a more compressible (less complex) tensor.
    """
    # Guard before calling ``min()``/``max()``, which raise on empty tensors.
    if tensor.numel() == 0:
        return 0.0
    # ``clone`` is essential: for a CPU float32 input, ``detach``/``cpu``/
    # ``float`` all return views of the caller's storage, so the in-place
    # normalisation below would otherwise mutate the argument.
    arr = tensor.detach().cpu().float().clone()
    arr -= arr.min()
    maxv = arr.max()
    if maxv > 0:
        arr /= maxv
    arr = (arr * 255).round().clamp(0, 255).to(torch.uint8)
    raw = arr.numpy().tobytes()
    comp = gzip.compress(raw)
    return float(len(comp) / len(raw))


def interference_index(
    Y: torch.Tensor, keys: torch.Tensor, values: torch.Tensor
) -> float:
    """Return the RMS error at channels that do not match ``keys``.

    Parameters
    ----------
    Y:
        Retrieved tensor of shape ``[B, K, H, W]``.
    keys:
        Index of the correct channel for each batch item ``[B]``.
    values:
        Ground truth values with shape ``[B, H, W]`` used to construct the
        expected output.

    Raises
    ------
    ValueError
        If any of the three tensors has an unexpected shape.
    """
    if Y.ndim != 4:
        raise ValueError("Y must have shape [B,K,H,W]")
    B, K, H, W = Y.shape
    if keys.shape != (B,):
        raise ValueError("keys must have shape [B]")
    if values.shape != (B, H, W):
        raise ValueError("values must have shape [B,H,W]")
    # Expected output: ``values`` on the keyed channel, zero elsewhere.
    target = torch.zeros_like(Y)
    target[torch.arange(B), keys] = values
    err = Y - target
    # Mask out the keyed channels; only off-key energy counts as interference.
    mask = torch.ones_like(Y, dtype=torch.bool)
    mask[torch.arange(B), keys] = False
    # With K == 1 every channel is keyed, so there is no interference to
    # measure; report zero rather than the NaN an empty mean would produce.
    if not mask.any():
        return 0.0
    # Renamed from ``mse`` to avoid shadowing the module-level function.
    mse_val = (err[mask] ** 2).mean()
    return float(torch.sqrt(mse_val))


def symbiosis(
    fidelity_scores: Sequence[float],
    orthogonality_scores: Sequence[float],
    energy_scores: Sequence[float],
    K_scores: Sequence[float],
    C_scores: Sequence[float],
) -> float:
    """Return a simple composite of Pearson correlations.

    ``symbiosis`` evaluates how the fidelity of retrieval correlates with
    four other quantities: orthogonality, energy, ``K`` (spectral entropy)
    and ``C`` (gzip ratio).  For arrays of equal length ``F``, ``O``, ``E``,
    ``K`` and ``C`` the metric is defined as::

        S = mean([ corr(F, O), corr(F, E), corr(F, K), corr(F, C) ])

    where ``corr`` denotes the sample Pearson correlation coefficient.  If
    any of the inputs is constant the corresponding correlation is treated
    as zero.  The final result is the arithmetic mean of the four
    coefficients.

    Raises
    ------
    ValueError
        If the score sequences are empty or have mismatched lengths.
    """
    F = np.asarray(fidelity_scores, dtype=float)
    O = np.asarray(orthogonality_scores, dtype=float)
    E = np.asarray(energy_scores, dtype=float)
    K_ = np.asarray(K_scores, dtype=float)
    C_ = np.asarray(C_scores, dtype=float)
    n = len(F)
    if not (n and len(O) == n and len(E) == n and len(K_) == n and len(C_) == n):
        raise ValueError("all score sequences must have the same non-zero length")

    def _corr(a: np.ndarray, b: np.ndarray) -> float:
        # Pearson correlation is undefined for a constant series; treat it
        # as zero contribution instead of propagating NaN.
        if np.allclose(a, a[0]) or np.allclose(b, b[0]):
            return 0.0
        return float(np.corrcoef(a, b)[0, 1])

    corrs = [_corr(F, O), _corr(F, E), _corr(F, K_), _corr(F, C_)]
    return float(np.mean(corrs))