| """Persistence operations for the WrinkleBrane memory tensor. | |
| This module provides a small helper implementing a leaky integrator update | |
| with an optional energy clamp. The function mirrors the philosophy of the | |
| rest of the code base: operate purely on tensors, avoid side effects and | |
| refuse silent device/dtype conversions. The implementation is intentionally | |
| minimal so that unit tests can reason about its behaviour without depending on | |
| hidden state. | |
| """ | |
| from __future__ import annotations | |
| import torch | |
| from .write_ops import energy_clamp | |
| __all__ = ["leaky_update"] | |
def _check_device_dtype(reference: torch.Tensor, other: torch.Tensor) -> None:
    """Ensure ``other`` lives on the same device and has the same dtype as ``reference``.

    The device is compared first, then the dtype; the first mismatch found
    aborts the check.

    Raises
    ------
    ValueError
        If ``other.device`` differs from ``reference.device`` or, failing
        that, if ``other.dtype`` differs from ``reference.dtype``.
    """
    checks = (
        ("device", "all tensors must reside on the same device"),
        ("dtype", "all tensors must share the same dtype"),
    )
    for attr_name, failure_message in checks:
        if getattr(other, attr_name) != getattr(reference, attr_name):
            raise ValueError(failure_message)
def leaky_update(
    M: torch.Tensor,
    delta_M: torch.Tensor,
    lam: float = 0.99,
    max_energy: float | None = None,
) -> torch.Tensor:
    """Return ``lam * M + delta_M`` with optional energy clamping.

    Parameters
    ----------
    M:
        Current membrane tensor with shape ``[B, L, H, W]``.
    delta_M:
        Tensor with the same shape, device and dtype as ``M`` containing the
        update to apply.
    lam:
        Leak factor for the previous state. ``lam`` is multiplied with ``M``
        before ``delta_M`` is added.
    max_energy:
        If provided and positive, the result is passed to ``energy_clamp``
        with this limit, which is intended to keep the L2 energy of each
        ``[B, L]`` slice over the spatial dimensions at or below this value.
        ``None`` or a non-positive value disables clamping.

    Returns
    -------
    torch.Tensor
        The updated tensor. Inputs are not modified in-place.

    Raises
    ------
    ValueError
        If either input is not 4-D, if their shapes differ, or if
        ``delta_M`` differs from ``M`` in device or dtype.
    """
    if M.ndim != 4 or delta_M.ndim != 4:
        raise ValueError("M and delta_M must have shape [B, L, H, W]")
    if M.shape != delta_M.shape:
        raise ValueError("M and delta_M must have matching shapes")
    # Refuse silent device/dtype promotion rather than letting torch broadcast.
    _check_device_dtype(M, delta_M)

    # ``lam * M`` allocates a fresh tensor, so neither input is mutated.
    updated = lam * M + delta_M
    # Only a positive limit enables clamping, as documented above; forwarding
    # 0 or a negative limit would not express a meaningful energy bound.
    if max_energy is not None and max_energy > 0:
        updated = energy_clamp(updated, max_energy)
    return updated