# WCNegentropy's picture
# 📚 Updated with scientifically rigorous documentation
# dc2b9f3 verified
from __future__ import annotations
"""Write operations for the WrinkleBrane memory tensor.
This module implements two small helper functions used by the tests and
example scripts. The functions are intentionally written in a fully
vectorised style so that they are easy to reason about and do not hide any
state. Both functions expect all tensors to share the same device and dtype
(except for ``keys`` which must be ``torch.long``) – any mismatch results in a
``ValueError`` rather than silently converting the inputs.
"""
from typing import Iterable
import torch
__all__ = ["store_pairs", "energy_clamp"]
def _check_device_dtype(reference: torch.Tensor, tensors: Iterable[torch.Tensor]) -> None:
"""Raise ``ValueError`` if any tensor differs in device or dtype."""
for t in tensors:
if t.device != reference.device:
raise ValueError("all tensors must reside on the same device")
if t.dtype != reference.dtype:
raise ValueError("all tensors must share the same dtype")
# ---------------------------------------------------------------------------
# write operations
def store_pairs(
    M: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    alphas: torch.Tensor,
) -> torch.Tensor:
    """Return ``M`` with key–value pairs written to it.

    Parameters
    ----------
    M:
        Current membranes with shape ``[B, L, H, W]``.
    C:
        Codebook tensor of shape ``[L, K]``. Column ``k`` contains the code
        used when storing a pair whose key is ``k``.
    keys:
        Long tensor of shape ``[T]`` with key indices in ``[0, K)``. Must
        live on the same device as ``M`` (its dtype is exempt from the
        shared-dtype rule and must be ``torch.long``).
    values:
        Real-valued tensor of shape ``[T, H, W]`` containing the maps to be
        written.
    alphas:
        Tensor of shape ``[T]`` specifying a gain for each pair.

    Returns
    -------
    torch.Tensor
        Updated membrane tensor. The update is performed without mutating
        ``M`` – a new tensor containing the result is returned.

    Raises
    ------
    ValueError
        If any shape, dtype, device, or key-range constraint is violated.
    """
    if M.ndim != 4:
        raise ValueError("M must have shape [B, L, H, W]")
    B, L, H, W = M.shape
    if C.shape[0] != L:
        raise ValueError("codebook C must have shape [L, K]")
    K = C.shape[1]
    if keys.ndim != 1:
        raise ValueError("keys must be one-dimensional")
    T = keys.shape[0]
    if values.shape != (T, H, W):
        raise ValueError("values must have shape [T, H, W]")
    if alphas.shape != (T,):
        raise ValueError("alphas must have shape [T]")
    if keys.dtype != torch.long:
        raise ValueError("keys must be of dtype torch.long")
    # keys are exempt from the dtype rule but NOT from the device rule
    # (the module contract requires all inputs on one device); without
    # this check a mismatched device would raise an opaque RuntimeError
    # from the indexing below instead of the documented ValueError.
    if keys.device != M.device:
        raise ValueError("all tensors must reside on the same device")
    _check_device_dtype(M, (C, values, alphas))
    if torch.any((keys < 0) | (keys >= K)):
        raise ValueError("keys contain indices outside the valid range")
    # Select the relevant columns from the codebook and scale by alphas
    codes = C[:, keys] * alphas.unsqueeze(0)  # [L, T]
    # Compute the sum over outer products in a vectorised fashion:
    # ΔM = Σ_t codes[:, t] ⊗ values[t]
    delta = torch.einsum("lt,thw->lhw", codes, values)
    # Broadcast the update across the batch dimension and return the result
    return M + delta.unsqueeze(0)
def energy_clamp(M: torch.Tensor, max_per_layer_energy: float) -> torch.Tensor:
    """Limit the per-layer L2 energy of ``M`` to ``max_per_layer_energy``.

    The "energy" of a layer is the L2 norm taken over its spatial
    dimensions ``H`` and ``W`` for every ``[B, L]`` slice. Layers whose
    norm exceeds the threshold are rescaled so their norm equals it;
    layers at or below the threshold are left untouched. A new tensor is
    returned and ``M`` itself is never modified in-place. A non-positive
    threshold disables clamping and ``M`` is returned unchanged.
    """
    if M.ndim != 4:
        raise ValueError("M must have shape [B, L, H, W]")
    if max_per_layer_energy <= 0:
        # Clamping disabled; hand back the input untouched.
        return M
    batch, layers, _, _ = M.shape
    # Per-(batch, layer) L2 norm over the flattened spatial extent.
    layer_norms = torch.linalg.norm(M.view(batch, layers, -1), dim=2)
    # Guard against division by zero for all-zero layers.
    tiny = torch.finfo(M.dtype).eps
    factors = torch.clamp(max_per_layer_energy / layer_norms.clamp_min(tiny), max=1.0)
    return M * factors.view(batch, layers, 1, 1)