| """ |
| Mamba-2 block based on the Structured State Space Duality (SSD) formulation. |
| |
| Reference: "Transformers are SSMs: Generalized Models and Efficient Algorithms |
| Through Structured State Space Duality" (Dao & Gu, 2024). |
| |
| This implements a pure-PyTorch sequential scan for correctness and generality. |
| A chunked SSD kernel can be swapped in later for speed. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .layers import RMSNorm |
|
|
|
|
| |
| |
| |
|
|
def selective_scan(
    x: torch.Tensor,
    dt: torch.Tensor,
    A_log: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D: torch.Tensor,
    n_groups: int,
) -> torch.Tensor:
    """Evaluate the diagonal SSM recurrence one time-step at a time.

    Args:
        x: (B, L, n_heads, head_dim) — input after conv + activation.
        dt: (B, L, n_heads) — discretisation time-steps (after softplus).
        A_log: (n_heads,) — log(-A), learnable diagonal decay.
        B: (B, L, n_groups, d_state) — input-to-state projection per step.
        C: (B, L, n_groups, d_state) — state-to-output projection per step.
        D: (n_heads,) — skip/residual connection per head.
        n_groups: int — number of B/C groups (heads per group share B/C).

    Returns:
        y: (B, L, n_heads, head_dim) — SSM output.
    """
    batch, seq_len, n_heads, head_dim = x.shape
    d_state = B.shape[-1]

    # Map each head to the B/C group it shares: head h -> h // heads_per_group.
    group_of_head = torch.arange(n_heads, device=x.device) // (n_heads // n_groups)

    # Discrete decay per step: dA = exp(A * dt) with A = -exp(A_log) < 0,
    # so 0 < dA < 1 and the state contracts.
    decay = torch.exp(-A_log.exp().view(1, 1, -1) * dt)

    # Pre-scale the input by its time-step once for all steps: dt * x.
    scaled_x = dt.unsqueeze(-1) * x

    # Broadcast B/C from one row per group to one row per head, hoisted out
    # of the time loop, and cast to fp32 to match the state precision.
    B_heads = B[:, :, group_of_head, :].float()
    C_heads = C[:, :, group_of_head, :].float()

    # Recurrent state, kept in fp32 for numerical stability.
    state = x.new_zeros(batch, n_heads, head_dim, d_state, dtype=torch.float32)

    outputs = []
    for t in range(seq_len):
        # h_t = dA_t * h_{t-1} + (dt_t * x_t) B_t^T  (outer product over d_state)
        state = state * decay[:, t].float().view(batch, n_heads, 1, 1)
        state = state + scaled_x[:, t].float().unsqueeze(-1) * B_heads[:, t].unsqueeze(-2)
        # y_t = h_t C_t — contract over the state dimension.
        outputs.append(
            torch.einsum("bnhd,bnd->bnh", state, C_heads[:, t]).to(x.dtype)
        )

    y = torch.stack(outputs, dim=1)

    # Per-head skip connection.
    return y + D.view(1, 1, n_heads, 1) * x
|
|
|
|
| |
| |
| |
|
|
class Mamba2Block(nn.Module):
    """Mamba-2 block with pre-norm residual connection.

    Implements:
        1. RMSNorm (pre-norm)
        2. Input projection -> (z, x, B, C, dt)
        3. Causal depth-wise Conv1d on x
        4. SiLU activation on x
        5. Selective scan (SSM recurrence)
        6. Gated output: y * SiLU(z)
        7. Output projection + residual

    Args:
        d_model: Model hidden dimension.
        d_state: SSM state dimension N (default 128).
        head_dim: Per-head dimension for SSD (default 64).
        expand: Expansion factor for inner dimension (default 2).
        conv_kernel: Causal 1D convolution kernel size (default 4).
        n_groups: Number of groups for B/C projections (default 1).
        chunk_size: Chunk size for SSD algorithm — reserved for future use (default 256).

    Raises:
        ValueError: If ``expand * d_model`` is not divisible by ``head_dim``,
            or the resulting head count is not divisible by ``n_groups``.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 128,
        head_dim: int = 64,
        expand: int = 2,
        conv_kernel: int = 4,
        n_groups: int = 1,
        chunk_size: int = 256,
    ) -> None:
        super().__init__()

        self.d_model = d_model
        self.d_state = d_state
        self.head_dim = head_dim
        self.expand = expand
        self.n_groups = n_groups
        self.chunk_size = chunk_size

        self.d_inner = expand * d_model
        # Validate with explicit exceptions rather than `assert`: asserts are
        # stripped under `python -O`, which would let a bad config through.
        # The divisibility check also runs *before* n_heads is derived from it.
        if self.d_inner % head_dim != 0:
            raise ValueError(
                f"d_inner ({self.d_inner}) must be divisible by head_dim ({head_dim})"
            )
        self.n_heads = self.d_inner // head_dim
        if self.n_heads % n_groups != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by n_groups ({n_groups})"
            )

        # Pre-norm applied to the block input before all projections.
        self.norm = RMSNorm(d_model)

        # Fused input projection producing (z, x, B, C, dt) in one matmul.
        self.d_proj = (
            self.d_inner           # z  (output gate)
            + self.d_inner         # x  (SSM input)
            + n_groups * d_state   # B  (input-to-state)
            + n_groups * d_state   # C  (state-to-output)
            + self.n_heads         # dt (time-steps)
        )
        self.in_proj = nn.Linear(d_model, self.d_proj, bias=False)

        # Depth-wise (groups == channels) 1D conv. Padding of (kernel - 1) is
        # applied on both sides; forward() trims the right side so the conv
        # remains strictly causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=conv_kernel,
            groups=self.d_inner,
            padding=conv_kernel - 1,
        )

        # A initialised uniformly in [1, 16]; stored as log so that the
        # effective decay A = -exp(A_log) is always negative (stable SSM).
        A_init = torch.log(torch.rand(self.n_heads) * 15.0 + 1.0)
        self.A_log = nn.Parameter(A_init)

        # Per-head skip/residual scale applied after the scan.
        self.D = nn.Parameter(torch.ones(self.n_heads))

        # dt bias initialised so softplus(dt_bias) lands roughly in
        # [0.001, 0.1]; log(x) is used as a cheap approximation of
        # softplus^{-1}(x), which is accurate for small x.
        dt_bias_init = torch.log(torch.rand(self.n_heads) * 0.099 + 0.001)
        self.dt_bias = nn.Parameter(dt_bias_init)

        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def _split_projection(
        self, proj: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the fused input projection into (z, x, B, C, dt).

        Args:
            proj: (B, L, d_proj)

        Returns:
            z: (B, L, d_inner)
            x: (B, L, d_inner)
            B: (B, L, n_groups, d_state)
            C: (B, L, n_groups, d_state)
            dt: (B, L, n_heads)
        """
        batch, seq_len, _ = proj.shape
        i = 0

        z = proj[:, :, i : i + self.d_inner]
        i += self.d_inner

        x = proj[:, :, i : i + self.d_inner]
        i += self.d_inner

        bc_dim = self.n_groups * self.d_state
        B = proj[:, :, i : i + bc_dim].reshape(batch, seq_len, self.n_groups, self.d_state)
        i += bc_dim

        C = proj[:, :, i : i + bc_dim].reshape(batch, seq_len, self.n_groups, self.d_state)
        i += bc_dim

        dt = proj[:, :, i : i + self.n_heads]
        return z, x, B, C, dt

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, L, d_model) — input hidden states.

        Returns:
            (B, L, d_model) — output with residual connection applied.
        """
        residual = x
        x = self.norm(x)

        # One fused matmul, then split into the five streams.
        proj = self.in_proj(x)
        z, x_ssm, B, C, dt_raw = self._split_projection(proj)

        # Depth-wise conv expects (B, C, L). The conv pads (kernel - 1) on
        # both sides; trimming the tail to seq_len keeps it causal.
        x_conv = x_ssm.transpose(1, 2)
        x_conv = self.conv1d(x_conv)
        x_conv = x_conv[:, :, :x_ssm.shape[1]]
        x_conv = x_conv.transpose(1, 2)
        x_conv = F.silu(x_conv)

        # softplus keeps the time-steps strictly positive.
        dt = F.softplus(dt_raw + self.dt_bias)

        # Fold the inner dimension into (n_heads, head_dim) for the scan.
        batch, seq_len, _ = x_conv.shape
        x_heads = x_conv.reshape(batch, seq_len, self.n_heads, self.head_dim)

        y = selective_scan(
            x_heads, dt, self.A_log, B, C, self.D,
            n_groups=self.n_groups,
        )

        y = y.reshape(batch, seq_len, self.d_inner)

        # Gated output: the z stream modulates the SSM result.
        y = y * F.silu(z)

        return residual + self.out_proj(y)
|
|