"""Persistence operations for the WrinkleBrane memory tensor. This module provides a small helper implementing a leaky integrator update with an optional energy clamp. The function mirrors the philosophy of the rest of the code base: operate purely on tensors, avoid side effects and refuse silent device/dtype conversions. The implementation is intentionally minimal so that unit tests can reason about its behaviour without depending on hidden state. """ from __future__ import annotations import torch from .write_ops import energy_clamp __all__ = ["leaky_update"] def _check_device_dtype(reference: torch.Tensor, other: torch.Tensor) -> None: """Raise ``ValueError`` if ``other`` differs in device or dtype.""" if other.device != reference.device: raise ValueError("all tensors must reside on the same device") if other.dtype != reference.dtype: raise ValueError("all tensors must share the same dtype") def leaky_update( M: torch.Tensor, delta_M: torch.Tensor, lam: float = 0.99, max_energy: float | None = None, ) -> torch.Tensor: """Return ``lam * M + delta_M`` with optional energy clamping. Parameters ---------- M: Current membrane tensor with shape ``[B, L, H, W]``. delta_M: Tensor of the same shape containing the update to apply. lam: Leak factor for the previous state. ``lam`` is multiplied with ``M`` before ``delta_M`` is added. max_energy: If provided and positive, the result is clamped so that the L2 energy of each ``[B, L]`` slice over the spatial dimensions does not exceed this value. ``None`` disables clamping. Returns ------- torch.Tensor The updated tensor. Inputs are not modified in-place. """ if M.ndim != 4 or delta_M.ndim != 4: raise ValueError("M and delta_M must have shape [B, L, H, W]") if M.shape != delta_M.shape: raise ValueError("M and delta_M must have matching shapes") _check_device_dtype(M, delta_M) updated = lam * M + delta_M if max_energy is not None: updated = energy_clamp(updated, max_energy) return updated