File size: 10,038 Bytes
c0f89d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
"""
Mamba-2 block based on the Structured State Space Duality (SSD) formulation.

Reference: "Transformers are SSMs: Generalized Models and Efficient Algorithms
Through Structured State Space Duality" (Dao & Gu, 2024).

This implements a pure-PyTorch sequential scan for correctness and generality.
A chunked SSD kernel can be swapped in later for speed.
"""

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import RMSNorm


# ---------------------------------------------------------------------------
# Selective Scan (sequential, numerically stable in float32)
# ---------------------------------------------------------------------------

def selective_scan(
    x: torch.Tensor,
    dt: torch.Tensor,
    A_log: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D: torch.Tensor,
    n_groups: int,
) -> torch.Tensor:
    """Run the SSM recurrence sequentially over the time axis.

    For every head the recurrence is::

        dA_t = exp(-exp(A_log) * dt_t)
        h_t  = dA_t * h_{t-1} + (dt_t * x_t) outer B_t
        y_t  = h_t @ C_t + D * x_t

    Heads are mapped onto B/C groups contiguously: head ``i`` uses group
    ``i // (n_heads // n_groups)``.

    All recurrence arithmetic (decay, input scaling, state update, readout and
    the D skip) is done in float32 regardless of the input dtypes. Decay factors
    close to 1 (small ``dt``) cannot be represented in bf16/fp16 and would
    otherwise round to exactly 1.0, i.e. no decay at all.

    Args:
        x:     (B, L, n_heads, head_dim) -- input after conv + activation.
        dt:    (B, L, n_heads)           -- discretisation time-steps (after softplus).
        A_log: (n_heads,)                -- log(-A), learnable diagonal decay.
        B:     (B, L, n_groups, d_state) -- input-to-state projection per step.
        C:     (B, L, n_groups, d_state) -- state-to-output projection per step.
        D:     (n_heads,)                -- skip/residual connection per head.
        n_groups: number of B/C groups (consecutive heads within a group share B/C).

    Returns:
        y: (B, L, n_heads, head_dim) -- SSM output, in the dtype obtained by
        promoting ``x.dtype`` with ``D.dtype``.

    Raises:
        ValueError: if ``n_groups < 1``, if ``n_heads`` is not divisible by
            ``n_groups``, or if B/C do not have ``n_groups`` groups.
    """
    batch, seq_len, n_heads, head_dim = x.shape
    if n_groups < 1 or n_heads % n_groups != 0:
        raise ValueError(
            f"n_heads ({n_heads}) must be divisible by n_groups ({n_groups})"
        )
    if B.shape[-2] != n_groups or C.shape[-2] != n_groups:
        raise ValueError(
            f"B and C must have n_groups ({n_groups}) groups, "
            f"got {B.shape[-2]} and {C.shape[-2]}"
        )
    d_state = B.shape[-1]
    heads_per_group = n_heads // n_groups

    # Same result dtype the original `y + D * x` produced.
    out_dtype = torch.promote_types(x.dtype, D.dtype)

    # Upcast once so the whole recurrence runs in float32.
    x_f = x.float()
    dt_f = dt.float()
    B_f = B.float()
    C_f = C.float()

    # Decay: dA = exp(-exp(A_log) * dt) -- (B, L, n_heads), via broadcasting.
    # Must be float32: bf16 cannot resolve values within ~4e-3 of 1.0.
    neg_A = A_log.float().exp()                  # (n_heads,)
    dA = torch.exp(-neg_A * dt_f)                # (B, L, n_heads)

    # dt-scaled input that is accumulated into the state each step.
    dt_x = dt_f.unsqueeze(-1) * x_f              # (B, L, n_heads, head_dim)

    # group_idx[head] -> index of the B/C group this head reads.
    group_idx = torch.arange(n_heads, device=x.device) // heads_per_group

    # State (B, n_heads, head_dim, d_state) and output buffer, both float32.
    h = torch.zeros(
        batch, n_heads, head_dim, d_state,
        dtype=torch.float32, device=x.device,
    )
    y = torch.zeros(
        batch, seq_len, n_heads, head_dim,
        dtype=torch.float32, device=x.device,
    )

    for t in range(seq_len):
        # --- Decay state ---
        h = h * dA[:, t, :].unsqueeze(-1).unsqueeze(-1)

        # --- Input contribution: outer product (head_dim) x (d_state) ---
        B_t = B_f[:, t][:, group_idx]            # (B, n_heads, d_state)
        h = h + dt_x[:, t].unsqueeze(-1) * B_t.unsqueeze(-2)

        # --- Output: contract the state dimension with C_t ---
        C_t = C_f[:, t][:, group_idx]            # (B, n_heads, d_state)
        y[:, t] = torch.einsum("bnhd,bnd->bnh", h, C_t)

    # Skip connection: D * x
    y = y + D.float().view(1, 1, n_heads, 1) * x_f

    return y.to(out_dtype)


# ---------------------------------------------------------------------------
# Mamba-2 Block
# ---------------------------------------------------------------------------

class Mamba2Block(nn.Module):
    """Mamba-2 block with pre-norm residual connection.

    Implements:
        1. RMSNorm (pre-norm)
        2. Input projection -> (z, x, B, C, dt)
        3. Causal depth-wise Conv1d on x
        4. SiLU activation on x
        5. Selective scan (SSM recurrence)
        6. Gated output: y * SiLU(z)
        7. Output projection + residual

    NOTE(review): only x is convolved here, and the gated output goes straight
    to out_proj without a norm; confirm these simplifications relative to the
    reference Mamba-2 implementation are intentional.

    Args:
        d_model:     Model hidden dimension.
        d_state:     SSM state dimension N (default 128).
        head_dim:    Per-head dimension for SSD (default 64).
        expand:      Expansion factor for inner dimension (default 2).
        conv_kernel: Causal 1D convolution kernel size (default 4).
        n_groups:    Number of groups for B/C projections (default 1).
        chunk_size:  Chunk size for the SSD algorithm -- stored on the module
                     but not used by the sequential scan yet (default 256).

    Raises:
        ValueError: if ``head_dim < 1``, if ``expand * d_model`` is not
            divisible by ``head_dim``, if ``n_groups < 1``, or if the
            resulting number of heads is not divisible by ``n_groups``.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 128,
        head_dim: int = 64,
        expand: int = 2,
        conv_kernel: int = 4,
        n_groups: int = 1,
        chunk_size: int = 256,
    ) -> None:
        super().__init__()

        self.d_model = d_model
        self.d_state = d_state
        self.head_dim = head_dim
        self.expand = expand
        self.n_groups = n_groups
        self.chunk_size = chunk_size

        # Derived dimensions. Validation raises explicitly instead of using
        # `assert`, which is stripped when Python runs with -O.
        self.d_inner = expand * d_model
        if head_dim < 1:
            raise ValueError(f"head_dim must be >= 1, got {head_dim}")
        if self.d_inner % head_dim != 0:
            raise ValueError(
                f"d_inner ({self.d_inner}) must be divisible by head_dim ({head_dim})"
            )
        self.n_heads = self.d_inner // head_dim
        if n_groups < 1:
            raise ValueError(f"n_groups must be >= 1, got {n_groups}")
        if self.n_heads % n_groups != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by n_groups ({n_groups})"
            )

        # Pre-norm
        self.norm = RMSNorm(d_model)

        # Input projection: d_model -> z + x + B + C + dt
        self.d_proj = (
            self.d_inner          # z (gate)
            + self.d_inner        # x (input to conv + SSM)
            + n_groups * d_state  # B
            + n_groups * d_state  # C
            + self.n_heads        # dt (one per head)
        )
        self.in_proj = nn.Linear(d_model, self.d_proj, bias=False)

        # Causal depth-wise conv1d over x. Conv1d pads both ends by
        # (kernel - 1); forward() trims the trailing outputs so position t
        # only sees inputs <= t.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=conv_kernel,
            groups=self.d_inner,
            padding=conv_kernel - 1,
        )

        # SSM parameters
        # A_log: log(-A) where A is the diagonal decay -- init from log(U(1, 16)).
        A_init = torch.log(torch.rand(self.n_heads) * 15.0 + 1.0)
        self.A_log = nn.Parameter(A_init)

        # D: skip connection per head -- init to ones.
        self.D = nn.Parameter(torch.ones(self.n_heads))

        # dt_bias: added before softplus -- init from log(U(0.001, 0.1)).
        # For such negative values softplus(b) ~= exp(b), so the initial
        # softplus(dt_bias) lies roughly in [0.001, 0.1].
        dt_bias_init = torch.log(torch.rand(self.n_heads) * 0.099 + 0.001)
        self.dt_bias = nn.Parameter(dt_bias_init)

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _split_projection(
        self, proj: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the fused input projection into (z, x, B, C, dt).

        Args:
            proj: (B, L, d_proj), where d_proj = 2*d_inner + 2*n_groups*d_state + n_heads.

        Returns:
            z:  (B, L, d_inner)
            x:  (B, L, d_inner)
            B:  (B, L, n_groups, d_state)
            C:  (B, L, n_groups, d_state)
            dt: (B, L, n_heads)

        Raises:
            RuntimeError: (from ``torch.split``) if the last dim of ``proj``
                does not equal the sum of the section sizes.
        """
        batch, seq_len, _ = proj.shape
        bc_dim = self.n_groups * self.d_state

        z, x, B, C, dt = torch.split(
            proj,
            [self.d_inner, self.d_inner, bc_dim, bc_dim, self.n_heads],
            dim=-1,
        )
        B = B.reshape(batch, seq_len, self.n_groups, self.d_state)
        C = C.reshape(batch, seq_len, self.n_groups, self.d_state)
        return z, x, B, C, dt

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Args:
            x: (B, L, d_model) -- input hidden states.

        Returns:
            (B, L, d_model) -- output with residual connection applied.
        """
        residual = x
        x = self.norm(x)

        # --- Input projection ---
        proj = self.in_proj(x)                         # (B, L, d_proj)
        z, x_ssm, B, C, dt_raw = self._split_projection(proj)

        # --- Causal conv1d on x ---
        # Conv1d expects (B, C, L)
        x_conv = x_ssm.transpose(1, 2)                # (B, d_inner, L)
        x_conv = self.conv1d(x_conv)
        # Trim to causal: remove the (kernel-1) trailing outputs
        x_conv = x_conv[:, :, :x_ssm.shape[1]]        # (B, d_inner, L)
        x_conv = x_conv.transpose(1, 2)               # (B, L, d_inner)
        x_conv = F.silu(x_conv)

        # --- Discretise dt ---
        dt = F.softplus(dt_raw + self.dt_bias)         # (B, L, n_heads)

        # --- Reshape x for multi-head scan ---
        batch, seq_len, _ = x_conv.shape
        x_heads = x_conv.reshape(batch, seq_len, self.n_heads, self.head_dim)

        # --- Selective scan (SSM recurrence) ---
        y = selective_scan(
            x_heads, dt, self.A_log, B, C, self.D,
            n_groups=self.n_groups,
        )  # (B, L, n_heads, head_dim)

        # --- Flatten heads back ---
        y = y.reshape(batch, seq_len, self.d_inner)    # (B, L, d_inner)

        # --- Gated output ---
        y = y * F.silu(z)

        # --- Output projection + residual ---
        return residual + self.out_proj(y)