from __future__ import annotations
"""Telemetry helpers.
The functions in this module provide small wrappers used by experiments to
compute a set of diagnostic metrics for a batch of data. The return value
of :func:`batch_telemetry` is a flat dictionary that can directly be fed to
logging utilities such as :mod:`tensorboard` or :func:`print`.
"""
from typing import Dict, Sequence
import numpy as np
import torch
from .metrics import (
gzip_ratio,
interference_index,
spectral_entropy_2d,
symbiosis,
)
# ---------------------------------------------------------------------------
# public API
# ---------------------------------------------------------------------------
def batch_telemetry(
    Y: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    fidelity_scores: Sequence[float],
    orthogonality_scores: Sequence[float],
    energy_scores: Sequence[float],
) -> Dict[str, float]:
    """Return telemetry metrics for a batch.

    Parameters
    ----------
    Y, keys, values:
        See :func:`wrinklebrane.metrics.interference_index` for a description
        of these tensors.  ``values`` must be a ``[B, H, W]`` tensor.
    fidelity_scores, orthogonality_scores, energy_scores:
        Per-item measurements of length ``B`` that describe the batch.
        ``fidelity_scores`` is typically a list of PSNR/SSIM values.  These
        are combined with the internally computed ``K``/``C`` statistics to
        form the symbiosis score ``S``.

    Returns
    -------
    dict
        Dictionary with the mean ``K`` and ``C`` values, the symbiosis score
        ``S`` and the interference index ``I``.

    Raises
    ------
    ValueError
        If ``values`` is not 3-dimensional, the batch is empty, or any of
        the per-item score sequences does not have length ``B``.
    """
    if values.ndim != 3:
        raise ValueError("values must have shape [B,H,W]")
    B = values.shape[0]
    if B == 0:
        # np.mean([]) would emit a RuntimeWarning and yield NaN; fail loudly
        # instead so callers notice the empty batch.
        raise ValueError("batch must contain at least one item")
    # Mismatched per-item score lists would silently mis-pair items inside
    # `symbiosis`; validate the lengths up front.
    for name, scores in (
        ("fidelity_scores", fidelity_scores),
        ("orthogonality_scores", orthogonality_scores),
        ("energy_scores", energy_scores),
    ):
        if len(scores) != B:
            raise ValueError(f"{name} must have length {B}, got {len(scores)}")

    # K (negentropy) and C (complexity) are computed per item and then
    # averaged to obtain a batch level statistic.
    K_scores = [spectral_entropy_2d(v) for v in values]
    C_scores = [gzip_ratio(v) for v in values]

    S = symbiosis(
        fidelity_scores,
        orthogonality_scores,
        energy_scores,
        K_scores,
        C_scores,
    )
    I = interference_index(Y, keys, values)

    return {
        "K": float(np.mean(K_scores)),
        "C": float(np.mean(C_scores)),
        "S": S,
        "I": I,
    }