| from __future__ import annotations |
|
|
| """Write operations for the WrinkleBrane memory tensor. |
| |
| This module implements two small helper functions used by the tests and |
| example scripts. The functions are intentionally written in a fully |
| vectorised style so that they are easy to reason about and do not hide any |
| state. Both functions expect all tensors to share the same device and dtype |
| (except for ``keys`` which must be ``torch.long``) – any mismatch results in a |
| ``ValueError`` rather than silently converting the inputs. |
| """ |
|
|
| from typing import Iterable |
|
|
| import torch |
|
|
| __all__ = ["store_pairs", "energy_clamp"] |
|
|
|
|
| def _check_device_dtype(reference: torch.Tensor, tensors: Iterable[torch.Tensor]) -> None: |
| """Raise ``ValueError`` if any tensor differs in device or dtype.""" |
|
|
| for t in tensors: |
| if t.device != reference.device: |
| raise ValueError("all tensors must reside on the same device") |
| if t.dtype != reference.dtype: |
| raise ValueError("all tensors must share the same dtype") |
|
|
|
|
| |
| |
|
|
def store_pairs(
    M: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    alphas: torch.Tensor,
) -> torch.Tensor:
    """Return ``M`` with key–value pairs written to it.

    Each pair ``t`` contributes ``alphas[t] * C[:, keys[t]]`` (a length-``L``
    code) multiplied with ``values[t]`` (an ``[H, W]`` map). The contributions
    of all pairs are summed into a single ``[L, H, W]`` update which is added
    to every batch entry of ``M``.

    Parameters
    ----------
    M:
        Current membranes with shape ``[B, L, H, W]``.
    C:
        Codebook tensor of shape ``[L, K]``. Column ``k`` contains the code
        used when storing a pair whose key is ``k``.
    keys:
        Long tensor of shape ``[T]`` with key indices in ``[0, K)``.
    values:
        Real-valued tensor of shape ``[T, H, W]`` containing the maps to be
        written.
    alphas:
        Tensor of shape ``[T]`` specifying a gain for each pair.

    Returns
    -------
    torch.Tensor
        Updated membrane tensor. The update is performed without mutating
        ``M`` – a new tensor containing the result is returned.

    Raises
    ------
    ValueError
        If ``M`` is not 4-D, ``C`` is not a 2-D tensor with ``L`` rows,
        ``keys`` is not a 1-D ``torch.long`` tensor, ``values`` or ``alphas``
        have the wrong shape, ``C``/``values``/``alphas`` differ from ``M`` in
        device or dtype, or any key lies outside ``[0, K)``.
    """
    if M.ndim != 4:
        raise ValueError("M must have shape [B, L, H, W]")
    _, L, H, W = M.shape

    # Check the rank before touching ``C.shape[1]``: a 1-D codebook would
    # otherwise raise IndexError and a 3-D one would only fail later inside
    # the broadcast/einsum instead of producing a clear ValueError here.
    if C.ndim != 2 or C.shape[0] != L:
        raise ValueError("codebook C must have shape [L, K]")
    K = C.shape[1]

    if keys.ndim != 1:
        raise ValueError("keys must be one-dimensional")
    T = keys.shape[0]

    if values.shape != (T, H, W):
        raise ValueError("values must have shape [T, H, W]")
    if alphas.shape != (T,):
        raise ValueError("alphas must have shape [T]")

    if keys.dtype != torch.long:
        raise ValueError("keys must be of dtype torch.long")

    # ``keys`` is excluded on purpose: it is an index tensor with its own
    # dtype requirement, checked above.
    _check_device_dtype(M, (C, values, alphas))

    if torch.any((keys < 0) | (keys >= K)):
        raise ValueError("keys contain indices outside the valid range")

    # Gather one code column per pair and apply its gain -> shape [L, T].
    codes = C[:, keys] * alphas.unsqueeze(0)

    # Sum the outer products code_t (x) value_t over all pairs -> [L, H, W].
    delta = torch.einsum("lt,thw->lhw", codes, values)

    # Broadcast the update across the batch dimension; ``M`` is not mutated.
    return M + delta.unsqueeze(0)
|
|
|
|
def energy_clamp(M: torch.Tensor, max_per_layer_energy: float) -> torch.Tensor:
    """Clamp the L2 energy of each layer to ``max_per_layer_energy``.

    ``energy`` refers to the L2 norm over the spatial dimensions ``H`` and
    ``W`` for each ``[B, L]`` slice. If a layer's norm exceeds the supplied
    maximum it is scaled down so that its energy equals the threshold. Layers
    below the threshold remain unchanged.

    Parameters
    ----------
    M:
        Membrane tensor of shape ``[B, L, H, W]``. It may be non-contiguous.
    max_per_layer_energy:
        Upper bound for the per-layer L2 norm. A value ``<= 0`` disables
        clamping.

    Returns
    -------
    torch.Tensor
        A new tensor with the clamped layers; ``M`` is never modified
        in-place. When clamping is disabled (non-positive limit) ``M`` itself
        is returned unchanged.

    Raises
    ------
    ValueError
        If ``M`` is not four-dimensional.
    """
    if M.ndim != 4:
        raise ValueError("M must have shape [B, L, H, W]")
    if max_per_layer_energy <= 0:
        return M

    # ``flatten`` (unlike ``view``) also works for non-contiguous inputs such
    # as transposed membranes, and for empty batch/layer dimensions where a
    # ``-1`` size would be ambiguous.
    norms = torch.linalg.norm(M.flatten(2), dim=2, keepdim=True)  # [B, L, 1]

    # Guard the division against zero-energy layers; the upper clamp at 1.0
    # leaves layers that are already within the limit untouched.
    eps = torch.finfo(M.dtype).eps
    scales = (max_per_layer_energy / norms.clamp_min(eps)).clamp(max=1.0)

    # [B, L, 1] -> [B, L, 1, 1] so the factor broadcasts over H and W.
    return M * scales.unsqueeze(-1)
|
|
|
|