|
|
from __future__ import annotations |
|
|
|
|
|
"""Write operations for the WrinkleBrane memory tensor. |
|
|
|
|
|
This module implements two small helper functions used by the tests and |
|
|
example scripts. The functions are intentionally written in a fully |
|
|
vectorised style so that they are easy to reason about and do not hide any |
|
|
state. Both functions expect all tensors to share the same device and dtype |
|
|
(except for ``keys`` which must be ``torch.long``) – any mismatch results in a |
|
|
``ValueError`` rather than silently converting the inputs. |
|
|
""" |
|
|
|
|
|
from typing import Iterable |
|
|
|
|
|
import torch |
|
|
|
|
|
__all__ = ["store_pairs", "energy_clamp"] |
|
|
|
|
|
|
|
|
def _check_device_dtype(reference: torch.Tensor, tensors: Iterable[torch.Tensor]) -> None: |
|
|
"""Raise ``ValueError`` if any tensor differs in device or dtype.""" |
|
|
|
|
|
for t in tensors: |
|
|
if t.device != reference.device: |
|
|
raise ValueError("all tensors must reside on the same device") |
|
|
if t.dtype != reference.dtype: |
|
|
raise ValueError("all tensors must share the same dtype") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def store_pairs(
    M: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    alphas: torch.Tensor,
) -> torch.Tensor:
    """Superpose key–value pairs onto the membranes and return the result.

    Parameters
    ----------
    M:
        Current membranes with shape ``[B, L, H, W]``.
    C:
        Codebook tensor of shape ``[L, K]``; column ``k`` holds the code used
        when storing a pair whose key is ``k``.
    keys:
        Long tensor of shape ``[T]`` holding key indices in ``[0, K)``.
    values:
        Real-valued tensor of shape ``[T, H, W]`` with the maps to write.
    alphas:
        Tensor of shape ``[T]`` giving a per-pair gain.

    Returns
    -------
    torch.Tensor
        A fresh membrane tensor containing the superposed writes; ``M`` is
        never mutated.

    Raises
    ------
    ValueError
        On any shape, dtype, device, or key-range violation.
    """
    # --- shape validation -------------------------------------------------
    if M.ndim != 4:
        raise ValueError("M must have shape [B, L, H, W]")
    _, num_layers, height, width = M.shape

    if C.shape[0] != num_layers:
        raise ValueError("codebook C must have shape [L, K]")
    num_codes = C.shape[1]

    if keys.ndim != 1:
        raise ValueError("keys must be one-dimensional")
    num_pairs = keys.shape[0]

    if values.shape != (num_pairs, height, width):
        raise ValueError("values must have shape [T, H, W]")
    if alphas.shape != (num_pairs,):
        raise ValueError("alphas must have shape [T]")

    # --- dtype / device validation ---------------------------------------
    if keys.dtype != torch.long:
        raise ValueError("keys must be of dtype torch.long")
    _check_device_dtype(M, (C, values, alphas))

    # --- key-range validation --------------------------------------------
    out_of_range = (keys < 0) | (keys >= num_codes)
    if torch.any(out_of_range):
        raise ValueError("keys contain indices outside the valid range")

    # Select one code column per pair and apply the per-pair gain: [L, T].
    scaled_codes = C.index_select(1, keys) * alphas.unsqueeze(0)

    # Collapse all pairs into a single per-layer delta: [L, H, W].
    delta = torch.einsum("lt,thw->lhw", scaled_codes, values)

    # Broadcast the delta over the batch dimension; ``M`` stays untouched.
    return M + delta.unsqueeze(0)
|
|
|
|
|
|
|
|
def energy_clamp(M: torch.Tensor, max_per_layer_energy: float) -> torch.Tensor:
    """Clamp the L2 energy of each layer to ``max_per_layer_energy``.

    ``energy`` refers to the L2 norm over the spatial dimensions ``H`` and
    ``W`` for each ``[B, L]`` slice. If a layer's norm exceeds the supplied
    maximum it is scaled down so that its energy equals the threshold; layers
    below the threshold remain unchanged. ``M`` is never modified in-place.

    Parameters
    ----------
    M:
        Membrane tensor of shape ``[B, L, H, W]``.
    max_per_layer_energy:
        Positive energy ceiling. Non-positive values disable clamping and
        ``M`` itself is returned unchanged.

    Returns
    -------
    torch.Tensor
        The clamped tensor (a new tensor whenever clamping is active).

    Raises
    ------
    ValueError
        If ``M`` is not four-dimensional.
    """
    if M.ndim != 4:
        raise ValueError("M must have shape [B, L, H, W]")
    if max_per_layer_energy <= 0:
        # A non-positive threshold means "no clamping"; return M as-is.
        return M

    B, L, _, _ = M.shape
    # ``reshape`` (not ``view``) so non-contiguous inputs — e.g. permuted or
    # transposed tensors — are accepted; for contiguous M it is a free view.
    flat = M.reshape(B, L, -1)
    norms = torch.linalg.norm(flat, dim=2)

    # clamp_min(eps) guards against division by zero for all-zero layers;
    # capping the scale at 1.0 leaves layers already under the threshold
    # untouched.
    eps = torch.finfo(M.dtype).eps
    scales = (max_per_layer_energy / norms.clamp_min(eps)).clamp(max=1.0)
    scales = scales.view(B, L, 1, 1)

    return M * scales
|
|
|
|
|
|