"""Write operations for the WrinkleBrane memory tensor.

This module implements two small helper functions used by the tests and
example scripts.  The functions are intentionally written in a fully
vectorised style so that they are easy to reason about and do not hide any
state.

Both functions expect all tensors to share the same device and dtype (except
for ``keys``, which must be ``torch.long``) – any mismatch results in a
``ValueError`` rather than silently converting the inputs.
"""

from __future__ import annotations

from typing import Iterable

import torch

__all__ = ["store_pairs", "energy_clamp"]


def _check_device_dtype(reference: torch.Tensor, tensors: Iterable[torch.Tensor]) -> None:
    """Raise ``ValueError`` if any tensor differs from ``reference`` in device or dtype."""
    for t in tensors:
        if t.device != reference.device:
            raise ValueError("all tensors must reside on the same device")
        if t.dtype != reference.dtype:
            raise ValueError("all tensors must share the same dtype")


# ---------------------------------------------------------------------------
# write operations


def store_pairs(
    M: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    alphas: torch.Tensor,
) -> torch.Tensor:
    """Return ``M`` with key–value pairs written to it.

    Parameters
    ----------
    M:
        Current membranes with shape ``[B, L, H, W]``.
    C:
        Codebook tensor of shape ``[L, K]``.  Column ``k`` contains the code
        used when storing a pair whose key is ``k``.
    keys:
        Long tensor of shape ``[T]`` with key indices in ``[0, K)``.
    values:
        Real-valued tensor of shape ``[T, H, W]`` containing the maps to be
        written.
    alphas:
        Tensor of shape ``[T]`` specifying a gain for each pair.

    Returns
    -------
    torch.Tensor
        Updated membrane tensor ``M + Σ_t alphas[t] · C[:, keys[t]] ⊗ values[t]``,
        with the same update broadcast to every batch entry.  The update is
        performed without mutating ``M`` – a new tensor containing the result
        is returned.

    Raises
    ------
    ValueError
        If any argument has the wrong rank or shape, ``keys`` is not
        ``torch.long``, any tensor (including ``keys``) lives on a different
        device than ``M``, ``C``/``values``/``alphas`` have a different dtype
        than ``M``, or a key lies outside ``[0, K)``.
    """
    if M.ndim != 4:
        raise ValueError("M must have shape [B, L, H, W]")
    _, L, H, W = M.shape

    # Check the rank first so a 1-D codebook is reported as a ValueError
    # instead of failing with an IndexError on ``C.shape[1]``.
    if C.ndim != 2 or C.shape[0] != L:
        raise ValueError("codebook C must have shape [L, K]")
    K = C.shape[1]

    if keys.ndim != 1:
        raise ValueError("keys must be one-dimensional")
    T = keys.shape[0]
    if values.shape != (T, H, W):
        raise ValueError("values must have shape [T, H, W]")
    if alphas.shape != (T,):
        raise ValueError("alphas must have shape [T]")
    if keys.dtype != torch.long:
        raise ValueError("keys must be of dtype torch.long")

    _check_device_dtype(M, (C, values, alphas))
    # ``keys`` is exempt from the dtype rule (it must be long) but not from
    # the device rule.
    if keys.device != M.device:
        raise ValueError("all tensors must reside on the same device")

    if torch.any((keys < 0) | (keys >= K)):
        raise ValueError("keys contain indices outside the valid range")

    # Select the relevant columns from the codebook and scale by alphas
    codes = C[:, keys] * alphas.unsqueeze(0)  # [L, T]

    # Compute the sum over outer products in a vectorised fashion:
    #   ΔM = Σ_t codes[:, t] ⊗ values[t]
    delta = torch.einsum("lt,thw->lhw", codes, values)

    # Broadcast the update across the batch dimension and return the result
    return M + delta.unsqueeze(0)


def energy_clamp(M: torch.Tensor, max_per_layer_energy: float) -> torch.Tensor:
    """Clamp the L2 energy of each layer to ``max_per_layer_energy``.

    ``energy`` refers to the L2 norm over the spatial dimensions ``H`` and
    ``W`` for each ``[B, L]`` slice.  If a layer's norm exceeds the supplied
    maximum it is scaled down so that its energy equals the threshold.
    Layers at or below the threshold (including all-zero layers) remain
    unchanged.

    Parameters
    ----------
    M:
        Membrane tensor of shape ``[B, L, H, W]``; need not be contiguous.
    max_per_layer_energy:
        Maximum allowed per-layer L2 norm.  A non-positive value disables
        clamping and ``M`` itself is returned.

    Returns
    -------
    torch.Tensor
        A new tensor with clamped layers (``M`` itself when clamping is
        disabled).  ``M`` is never modified in-place.

    Raises
    ------
    ValueError
        If ``M`` is not four-dimensional.
    """
    if M.ndim != 4:
        raise ValueError("M must have shape [B, L, H, W]")
    if max_per_layer_energy <= 0:
        # Non-positive threshold: clamping is disabled.
        return M

    # Reduce over (H, W) directly rather than flattening with ``view`` so that
    # non-contiguous inputs (e.g. transposed views) are accepted.
    norms = torch.linalg.vector_norm(M, ord=2, dim=(-2, -1), keepdim=True)  # [B, L, 1, 1]

    # Guard against division by zero for all-zero layers; their scale becomes
    # huge and is then clamped to 1, leaving them unchanged.
    eps = torch.finfo(M.dtype).eps
    scales = (max_per_layer_energy / norms.clamp_min(eps)).clamp(max=1.0)
    return M * scales