File size: 4,105 Bytes
dc2b9f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from __future__ import annotations

"""Write operations for the WrinkleBrane memory tensor.

This module implements two small helper functions used by the tests and
example scripts.  The functions are intentionally written in a fully
vectorised style so that they are easy to reason about and do not hide any
state.  Both functions expect all tensors to share the same device and dtype
(except for ``keys`` which must be ``torch.long``) – any mismatch results in a
``ValueError`` rather than silently converting the inputs.
"""

from typing import Iterable

import torch

__all__ = ["store_pairs", "energy_clamp"]


def _check_device_dtype(reference: torch.Tensor, tensors: Iterable[torch.Tensor]) -> None:
    """Raise ``ValueError`` if any tensor differs in device or dtype.

    ``reference`` supplies the expected device and dtype; every tensor in
    ``tensors`` is compared against it.  No conversion is attempted — a
    mismatch is always an error.
    """

    for t in tensors:
        if t.device != reference.device:
            raise ValueError("all tensors must reside on the same device")
        if t.dtype != reference.dtype:
            raise ValueError("all tensors must share the same dtype")


# ---------------------------------------------------------------------------
# write operations

def store_pairs(
    M: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    alphas: torch.Tensor,
) -> torch.Tensor:
    """Return ``M`` with key–value pairs written to it.

    Parameters
    ----------
    M:
        Current membranes with shape ``[B, L, H, W]``.
    C:
        Codebook tensor of shape ``[L, K]``.  Column ``k`` contains the code
        used when storing a pair whose key is ``k``.
    keys:
        Long tensor of shape ``[T]`` with key indices in ``[0, K)``.  Must
        live on the same device as ``M``.
    values:
        Real-valued tensor of shape ``[T, H, W]`` containing the maps to be
        written.
    alphas:
        Tensor of shape ``[T]`` specifying a gain for each pair.

    Returns
    -------
    torch.Tensor
        Updated membrane tensor.  The update is performed without mutating
        ``M`` – a new tensor containing the result is returned.

    Raises
    ------
    ValueError
        If any shape, dtype, device, or key-range constraint is violated.
    """

    if M.ndim != 4:
        raise ValueError("M must have shape [B, L, H, W]")
    B, L, H, W = M.shape

    if C.shape[0] != L:
        raise ValueError("codebook C must have shape [L, K]")
    K = C.shape[1]

    if keys.ndim != 1:
        raise ValueError("keys must be one-dimensional")
    T = keys.shape[0]

    if values.shape != (T, H, W):
        raise ValueError("values must have shape [T, H, W]")
    if alphas.shape != (T,):
        raise ValueError("alphas must have shape [T]")

    if keys.dtype != torch.long:
        raise ValueError("keys must be of dtype torch.long")
    # ``keys`` is excluded from _check_device_dtype (its dtype legitimately
    # differs), so its device must be validated separately — otherwise a
    # cross-device ``keys`` would surface as an opaque RuntimeError from the
    # indexing below instead of the ValueError promised by the module docs.
    if keys.device != M.device:
        raise ValueError("all tensors must reside on the same device")

    _check_device_dtype(M, (C, values, alphas))

    if torch.any((keys < 0) | (keys >= K)):
        raise ValueError("keys contain indices outside the valid range")

    # Select the relevant columns from the codebook and scale by alphas
    codes = C[:, keys] * alphas.unsqueeze(0)  # [L, T]

    # Compute the sum over outer products in a vectorised fashion:
    #   ΔM = Σ_t codes[:, t] ⊗ values[t]
    delta = torch.einsum("lt,thw->lhw", codes, values)

    # Broadcast the update across the batch dimension and return the result
    return M + delta.unsqueeze(0)


def energy_clamp(M: torch.Tensor, max_per_layer_energy: float) -> torch.Tensor:
    """Clamp the L2 energy of each layer to ``max_per_layer_energy``.

    ``energy`` refers to the L2 norm over the spatial dimensions ``H`` and
    ``W`` for each ``[B, L]`` slice.  If a layer's norm exceeds the supplied
    maximum it is scaled down so that its energy equals the threshold.  Layers
    below the threshold remain unchanged.  The function returns a new tensor
    and does not modify ``M`` in-place.

    A non-positive ``max_per_layer_energy`` disables clamping entirely and
    ``M`` itself is returned unchanged (not a copy).

    Raises
    ------
    ValueError
        If ``M`` is not four-dimensional.
    """

    if M.ndim != 4:
        raise ValueError("M must have shape [B, L, H, W]")
    if max_per_layer_energy <= 0:
        return M

    B, L, H, W = M.shape
    # ``reshape`` instead of ``view``: the latter raises on non-contiguous
    # inputs (e.g. a permuted tensor) while behaving identically otherwise.
    flat = M.reshape(B, L, -1)
    norms = torch.linalg.norm(flat, dim=2)  # [B, L]

    # Guard against division by zero for all-zero layers; clamping the scale
    # at 1.0 leaves layers at or below the threshold untouched.
    eps = torch.finfo(M.dtype).eps
    scales = (max_per_layer_energy / norms.clamp_min(eps)).clamp(max=1.0)
    scales = scales.view(B, L, 1, 1)

    return M * scales