"""tilelli.core.ternary_ssm — the State pathway of Tilelli.

From ARCHITECTURE.md:
    State path: small Mamba-style SSM. Long-range topic carry. O(n).

Day-0 scope: a **diagonal** state-space model — one independent scalar
recurrence per channel — which is the S4D / HiPPO-diag skeleton that
Mamba is built on. We skip Mamba's data-dependent selection for now;
that's a refinement on top of a working diagonal SSM, not the core idea.

The per-channel recurrence:

    h_t[c] = a[c] · h_{t-1}[c] + b[c] · x_t[c]
    y_t[c] = c[c] · h_t[c]

Three learnable per-channel scalars: decay `a`, input gain `b`, output
scale `c`. Stability demands |a| < 1; we enforce that with `tanh(a_raw)`.

Training uses the **convolutional mode**: because the recurrence is
linear and diagonal, y_t unrolls to a causal 1-D convolution with kernel

    K[c, i] = c[c] · a[c]^i · b[c]      for i = 0 … L-1

so a single depthwise `F.conv1d` produces the whole output sequence in
one parallel call. This is the convolutional view that S4 exploits; S4
evaluates it with an FFT, whereas the direct conv here does O(L² · C)
work. Inference uses the recurrent mode, a simple per-step state update
costing O(C) per token (O(L · C) over a sequence, sequential), which is
what Tilelli will actually run on CPU one token at a time.
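
Worked example (illustrative numbers, single channel): take a = 0.5,
b = 2, c = 1 and the impulse input x = (1, 0, 0). The recurrence gives

    h = (2, 1, 0.5)    and    y = c · h = (2, 1, 0.5)

which matches the kernel K = (c·b, c·a·b, c·a²·b) = (2, 1, 0.5): the
kernel is the channel's impulse response.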

A note on ternary weights here:
  The per-channel scalars are only O(C) parameters, vs O(C²) for the
  Linear layers. Ternarizing them saves almost nothing and makes the
  decay dynamics much harder to learn (the decay must vary continuously
  within |a| < 1, which ternary {-α, 0, +α} can't cleanly express). We
  keep these few parameters in FP32 and are honest about it: the SSM is
  the one place in Tilelli where a little floating point lives. The big
  consumers, Linear and Conv, remain pure ternary.
"""
from __future__ import annotations

import math

import torch
from torch import Tensor, nn
from torch.nn import functional as F


class DiagonalSSM(nn.Module):
    """Per-channel diagonal state-space model. Input/output shape (B, L, C).

    Parameters are three per-channel vectors:
      - ``a_raw``  : pre-tanh decay;  effective a = tanh(a_raw) ∈ (-1, 1)
      - ``b``      : input gain
      - ``c_out``  : output scale

    The state dimension equals the channel count (one scalar state per
    channel). For a wider state per channel, stack multiple DiagonalSSMs
    or move to a non-diagonal variant.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.channels = channels
        # Init decay near 0.9 so early training has long-ish memory.
        # tanh(1.5) ≈ 0.905.
        self.a_raw = nn.Parameter(torch.full((channels,), 1.5))
        self.b = nn.Parameter(torch.randn(channels) * (1.0 / math.sqrt(channels)))
        self.c_out = nn.Parameter(torch.randn(channels) * (1.0 / math.sqrt(channels)))

    # ------------------------------------------------------------------ #
    # Training forward — convolutional mode
    # ------------------------------------------------------------------ #

    def forward(self, x: Tensor) -> Tensor:
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, C), got shape {tuple(x.shape)}")
        B, L, C = x.shape
        if C != self.channels:
            raise ValueError(f"channel mismatch: module has {self.channels}, input has {C}")

        a = torch.tanh(self.a_raw)              # (C,), in (-1, 1)
        b = self.b                              # (C,)
        c_out = self.c_out                      # (C,)

        # Build the per-channel causal kernel. We want
        #     y_t = sum_{d=0}^{L-1} (c_out * a^d * b) * x_{t-d}
        # torch.conv1d is cross-correlation: with left-pad L-1, the
        # LAST kernel element is delay 0, so the powers must run from
        # (L-1) down to 0 across the kernel's spatial axis.
        i = torch.arange(L - 1, -1, -1, device=x.device, dtype=x.dtype)  # (L,)
        powers = a.unsqueeze(-1) ** i.unsqueeze(0)                       # (C, L)
        kernel = (c_out * b).unsqueeze(-1) * powers                      # (C, L)
        kernel = kernel.unsqueeze(1)                                     # (C, 1, L)

        # Depthwise causal conv: left-pad L-1, groups=C
        x_ = x.transpose(1, 2)                                         # (B, C, L)
        x_ = F.pad(x_, (L - 1, 0))
        y = F.conv1d(x_, kernel, groups=C)
        return y.transpose(1, 2)                                       # (B, L, C)

    # ------------------------------------------------------------------ #
    # Inference — recurrent mode, O(L·C) sequential
    # ------------------------------------------------------------------ #

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
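        """Step-by-step recurrence. Agrees with `forward` up to
        floating-point rounding.

        This mirrors the path Tilelli runs at CPU inference time: one
        token in, one token out, with a state of shape (B, C) carried
        across steps. No length-L kernel to build and no O(L²)
        convolution. Note that the state starts at zero on every call
        and is not returned, so `x` must hold the full sequence.
        """
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, C), got shape {tuple(x.shape)}")
        B, L, C = x.shape
        # Same guard as `forward`: without it a C == 1 input would
        # silently broadcast against the (C,) parameters.
        if C != self.channels:
            raise ValueError(f"channel mismatch: module has {self.channels}, input has {C}")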
        a = torch.tanh(self.a_raw)
        b = self.b
        c_out = self.c_out
        h = torch.zeros(B, C, dtype=x.dtype, device=x.device)
        ys = []
        for t in range(L):
            h = a * h + b * x[:, t]
            ys.append(c_out * h)
        return torch.stack(ys, dim=1)                                  # (B, L, C)
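

# ---------------------------------------------------------------------- #
# Streaming sketch (hypothetical helper, not part of the Tilelli API)
# ---------------------------------------------------------------------- #
# `DiagonalSSM.infer` above consumes a whole (B, L, C) sequence and starts
# from a zero state on every call. For one-token-at-a-time decoding the
# caller has to own the state between calls. A minimal sketch of that,
# assuming the caller passes the effective decay a = tanh(a_raw) together
# with b and c_out, might look like the function below. The name `ssm_step`
# and its signature are illustrative assumptions, not an existing interface.

def ssm_step(a: Tensor, b: Tensor, c_out: Tensor, h: Tensor, x_t: Tensor) -> tuple[Tensor, Tensor]:
    """Advance the diagonal recurrence by one token.

    a, b, c_out : (C,) per-channel scalars; `a` is already tanh-squashed.
    h           : (B, C) state carried in from the previous step.
    x_t         : (B, C) current input token.
    Returns (y_t, h_next), both of shape (B, C).
    """
    h_next = a * h + b * x_t          # h_t = a · h_{t-1} + b · x_t
    return c_out * h_next, h_next     # y_t = c · h_t


# Usage sketch, with `m` a DiagonalSSM instance and `x_t` of shape (B, C):
#     h = torch.zeros(B, C)
#     y_t, h = ssm_step(torch.tanh(m.a_raw), m.b, m.c_out, h, x_t)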