File size: 2,181 Bytes
dc2b9f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Persistence operations for the WrinkleBrane memory tensor.

This module provides a small helper implementing a leaky integrator update
with an optional energy clamp.  The function mirrors the philosophy of the
rest of the code base: operate purely on tensors, avoid side effects and
refuse silent device/dtype conversions.  The implementation is intentionally
minimal so that unit tests can reason about its behaviour without depending on
hidden state.
"""

from __future__ import annotations

import torch

from .write_ops import energy_clamp

__all__ = ["leaky_update"]


def _check_device_dtype(reference: torch.Tensor, other: torch.Tensor) -> None:
    """Validate that ``other`` matches ``reference`` in device and dtype.

    Raises
    ------
    ValueError
        On any device or dtype mismatch; silent conversion is refused.
    """

    # Device is checked before dtype so callers see the most fundamental
    # mismatch first.
    checks = (
        (other.device != reference.device,
         "all tensors must reside on the same device"),
        (other.dtype != reference.dtype,
         "all tensors must share the same dtype"),
    )
    for failed, message in checks:
        if failed:
            raise ValueError(message)


def leaky_update(
    M: torch.Tensor,
    delta_M: torch.Tensor,
    lam: float = 0.99,
    max_energy: float | None = None,
) -> torch.Tensor:
    """Compute ``lam * M + delta_M``, optionally clamping the result's energy.

    Parameters
    ----------
    M:
        Membrane state tensor of shape ``[B, L, H, W]``.
    delta_M:
        Update tensor; must match ``M`` in shape, device and dtype.
    lam:
        Decay factor applied to the previous state before ``delta_M`` is
        added.
    max_energy:
        When not ``None``, the result is passed through ``energy_clamp`` so
        the L2 energy of each ``[B, L]`` slice over the spatial dimensions
        stays within this bound; ``None`` skips clamping entirely.

    Returns
    -------
    torch.Tensor
        A freshly allocated tensor; neither input is mutated.

    Raises
    ------
    ValueError
        If either input is not 4-D, if the shapes differ, or if the tensors
        differ in device or dtype.
    """

    expected_ndim = 4
    if not (M.ndim == expected_ndim and delta_M.ndim == expected_ndim):
        raise ValueError("M and delta_M must have shape [B, L, H, W]")
    if M.shape != delta_M.shape:
        raise ValueError("M and delta_M must have matching shapes")

    # Device/dtype guard (inlined): refuse silent conversions, matching the
    # module's no-implicit-cast philosophy.
    if delta_M.device != M.device:
        raise ValueError("all tensors must reside on the same device")
    if delta_M.dtype != M.dtype:
        raise ValueError("all tensors must share the same dtype")

    result = M.mul(lam).add(delta_M)
    if max_energy is None:
        return result
    return energy_clamp(result, max_energy)