File size: 2,229 Bytes
dc2b9f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from __future__ import annotations

"""Telemetry helpers.

The functions in this module provide small wrappers used by experiments to
compute a set of diagnostic metrics for a batch of data.  The return value
of :func:`batch_telemetry` is a flat dictionary that can directly be fed to
logging utilities such as :mod:`tensorboard` or :func:`print`.
"""

from typing import Dict, Sequence

import numpy as np
import torch

from .metrics import (
    gzip_ratio,
    interference_index,
    spectral_entropy_2d,
    symbiosis,
)


# ---------------------------------------------------------------------------
# public API
# ---------------------------------------------------------------------------

def batch_telemetry(
    Y: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    fidelity_scores: Sequence[float],
    orthogonality_scores: Sequence[float],
    energy_scores: Sequence[float],
) -> Dict[str, float]:
    """Compute the ``K``, ``C``, ``S`` and ``I`` diagnostics for one batch.

    Parameters
    ----------
    Y, keys, values:
        Passed through unchanged to :func:`interference_index`; see
        :func:`wrinklebrane.metrics.interference_index` for what they
        represent.  ``values`` must be three-dimensional (``[B,H,W]``); it
        is iterated along its first axis so that every item is scored on
        its own.
    fidelity_scores, orthogonality_scores, energy_scores:
        Per-item measurements describing the batch (``fidelity_scores``
        typically holds PSNR/SSIM values).  They are handed to
        :func:`symbiosis` together with the per-item ``K``/``C`` scores
        computed here.

    Returns
    -------
    dict
        Mapping with four entries, in this order: ``"K"`` and ``"C"`` (the
        batch means of the per-item scores, as Python floats), ``"S"`` (the
        value returned by :func:`symbiosis`) and ``"I"`` (the value returned
        by :func:`interference_index`).

    Raises
    ------
    ValueError
        If ``values.ndim`` is not 3.
    """

    if values.ndim != 3:
        raise ValueError("values must have shape [B,H,W]")

    # Per-item K (negentropy) and C (complexity).  The two passes are kept
    # separate so every K score is computed before the first C score.
    negentropy_per_item = list(map(spectral_entropy_2d, values))
    complexity_per_item = list(map(gzip_ratio, values))

    # The symbiosis score consumes the raw per-item lists, not their means.
    symbiosis_score = symbiosis(
        fidelity_scores,
        orthogonality_scores,
        energy_scores,
        negentropy_per_item,
        complexity_per_item,
    )
    interference = interference_index(Y, keys, values)

    # Only K and C are reduced to a batch-level mean; S and I are reported
    # exactly as the metric helpers returned them.
    telemetry = {}
    telemetry["K"] = float(np.mean(negentropy_per_item))
    telemetry["C"] = float(np.mean(complexity_per_item))
    telemetry["S"] = symbiosis_score
    telemetry["I"] = interference
    return telemetry