"""Telemetry helpers.

The functions in this module provide small wrappers used by experiments to
compute a set of diagnostic metrics for a batch of data.  The return value of
:func:`batch_telemetry` is a flat dictionary that can directly be fed to
logging utilities such as :mod:`tensorboard` or :func:`print`.
"""

# NOTE: the module docstring must be the first statement for it to be bound to
# ``__doc__``; a ``__future__`` import is still legal directly after it.
from __future__ import annotations

from typing import Dict, Sequence

import numpy as np
import torch

from .metrics import (
    gzip_ratio,
    interference_index,
    spectral_entropy_2d,
    symbiosis,
)


# ---------------------------------------------------------------------------
# public API
# ---------------------------------------------------------------------------
def batch_telemetry(
    Y: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    fidelity_scores: Sequence[float],
    orthogonality_scores: Sequence[float],
    energy_scores: Sequence[float],
) -> Dict[str, float]:
    """Return telemetry metrics for a batch.

    Parameters
    ----------
    Y, keys, values:
        See :func:`wrinklebrane.metrics.interference_index` for a description
        of these tensors.  ``values`` must be 3-dimensional (``[B,H,W]``); it
        is iterated along its first dimension to obtain the per-item inputs
        of the ``K`` and ``C`` statistics.
    fidelity_scores, orthogonality_scores, energy_scores:
        Per-item measurements that describe the batch.  ``fidelity_scores``
        is typically a list of PSNR/SSIM values.  These are combined with the
        internally computed ``K``/``C`` statistics to form the symbiosis
        score ``S``.

    Returns
    -------
    dict
        Dictionary with the mean ``K`` and ``C`` values, the symbiosis score
        ``S`` and the interference index ``I``.  ``S`` and ``I`` are passed
        through exactly as returned by :func:`symbiosis` and
        :func:`interference_index`.

    Raises
    ------
    ValueError
        If ``values`` does not have exactly three dimensions.
    """
    if values.ndim != 3:
        raise ValueError("values must have shape [B,H,W]")

    # K (negentropy) and C (complexity) are computed per item and then
    # averaged to obtain a batch level statistic.
    K_scores = [spectral_entropy_2d(item) for item in values]
    C_scores = [gzip_ratio(item) for item in values]

    # The symbiosis score consumes the raw per-item lists, not their means.
    S = symbiosis(
        fidelity_scores,
        orthogonality_scores,
        energy_scores,
        K_scores,
        C_scores,
    )
    interference = interference_index(Y, keys, values)

    return {
        "K": float(np.mean(K_scores)),
        "C": float(np.mean(C_scores)),
        "S": S,
        "I": interference,
    }