WrinkleBrane / src /wrinklebrane /persistence.py
WCNegentropy's picture
📚 Updated with scientifically rigorous documentation
dc2b9f3 verified
"""Persistence operations for the WrinkleBrane memory tensor.
This module provides a small helper implementing a leaky integrator update
with an optional energy clamp. The function mirrors the philosophy of the
rest of the code base: operate purely on tensors, avoid side effects and
refuse silent device/dtype conversions. The implementation is intentionally
minimal so that unit tests can reason about its behaviour without depending on
hidden state.
"""
from __future__ import annotations
import torch
from .write_ops import energy_clamp
# Public API of this module: only ``leaky_update`` is exported by a star
# import; ``_check_device_dtype`` is not listed here.
__all__ = ["leaky_update"]
def _check_device_dtype(reference: torch.Tensor, other: torch.Tensor) -> None:
"""Raise ``ValueError`` if ``other`` differs in device or dtype."""
if other.device != reference.device:
raise ValueError("all tensors must reside on the same device")
if other.dtype != reference.dtype:
raise ValueError("all tensors must share the same dtype")
def leaky_update(
    M: torch.Tensor,
    delta_M: torch.Tensor,
    lam: float = 0.99,
    max_energy: float | None = None,
) -> torch.Tensor:
    """Apply one leaky-integrator step and return ``lam * M + delta_M``.

    The previous state ``M`` is scaled by the leak factor ``lam`` and the
    update ``delta_M`` is added. The result is a new tensor; neither input
    is modified in place.

    Parameters
    ----------
    M:
        Current membrane tensor. It must be 4-D, laid out as ``[B, L, H, W]``.
    delta_M:
        Update to add. It must have exactly the same shape, device and dtype
        as ``M``.
    lam:
        Leak factor multiplied with ``M`` before ``delta_M`` is added.
    max_energy:
        Optional energy cap. When it is not ``None``, the result is passed
        through ``energy_clamp(result, max_energy)``. ``None`` skips
        clamping entirely. How non-positive caps are treated is decided by
        ``energy_clamp``.

    Returns
    -------
    torch.Tensor
        The decayed-and-updated tensor, energy-clamped if requested.

    Raises
    ------
    ValueError
        If either input is not 4-D, the shapes differ, or the tensors
        differ in device or dtype.
    """
    # Rank is checked before exact shape, and shape before device/dtype.
    # This ordering decides which message a caller sees.
    if not (M.ndim == 4 and delta_M.ndim == 4):
        raise ValueError("M and delta_M must have shape [B, L, H, W]")
    if delta_M.shape != M.shape:
        raise ValueError("M and delta_M must have matching shapes")
    _check_device_dtype(M, delta_M)

    result = lam * M + delta_M
    # Only ``None`` disables the clamp; any other value is forwarded as-is.
    if max_energy is None:
        return result
    return energy_clamp(result, max_energy)