"""
Mamba-2 SSD — intra-chunk parallelism via matrix multiply.

The key Mamba-2 insight (State Space Duality): within each chunk of size T,
the SSM can be computed as a MATRIX MULTIPLY:

    Y_chunk = (L ⊙ (C B^T)) @ (Δ ⊙ X)

where L is a lower-triangular mask holding cumulative A products. This
replaces the T sequential steps with a single matmul of size T×T.

For L=256, T=16, num_chunks=16:
  - Within chunk: parallel matmul (T×T = 16×16)
  - Across chunks: 16 sequential state carries (unavoidable, but trivial)

NO in-place ops. Fully autograd safe. Works on CPU and GPU.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Mamba2SSD(nn.Module):
    """
    Mamba-2 SSD with intra-chunk matrix-multiply parallelism.

    Args:
        dim: Input/output dimension
        d_state: SSM state dimension (default 16)
        d_conv: Conv1d kernel size (default 4)
        expand: Inner dimension expansion (default 2)
        chunk_size: Chunk size for scan (default 64 — larger = more parallel)

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2, chunk_size=64):
        super().__init__()
        # Fail early with a clear message instead of a division error at
        # the first forward pass.
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        self.dim = dim
        self.d_state = d_state
        self.chunk_size = chunk_size
        self.inner_dim = dim * expand

        # Input projection: produces the SSM input and the gate in one matmul
        self.in_proj = nn.Linear(dim, self.inner_dim * 2, bias=False)

        # Short depthwise conv for local context. It pads both sides by
        # d_conv - 1; the forward pass keeps only the first L outputs, which
        # makes it causal.
        self.conv1d = nn.Conv1d(
            self.inner_dim,
            self.inner_dim,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.inner_dim,
            bias=True,
        )

        # SSM parameter projections
        self.dt_proj = nn.Linear(self.inner_dim, self.inner_dim, bias=True)
        self.B_proj = nn.Linear(self.inner_dim, d_state, bias=False)
        self.C_proj = nn.Linear(self.inner_dim, d_state, bias=False)

        # A: learnable decay rates stored in log-space. The forward pass uses
        # A = -exp(A_log), which is negative for any parameter value.
        A = torch.arange(1, d_state + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(torch.log(A))

        # D: residual skip
        self.D = nn.Parameter(torch.ones(self.inner_dim))

        # Output
        self.norm = nn.LayerNorm(self.inner_dim)
        self.out_proj = nn.Linear(self.inner_dim, dim, bias=False)

        self._init_weights()

    def _init_weights(self):
        """Initialise the dt bias and the input/output projection weights."""
        nn.init.constant_(self.dt_proj.bias, -4.0)  # softplus(-4) ≈ 0.018
        nn.init.xavier_uniform_(self.in_proj.weight, gain=0.1)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, L, dim] → [B, L, dim]"""
        return self._process(x)

    def _process(self, x: torch.Tensor) -> torch.Tensor:
        """Run projection → conv → chunked SSM → skip/norm/gate → projection.

        Args:
            x: Tensor of shape [B, L, dim].

        Returns:
            Tensor of shape [B, L, dim].
        """
        seq_len = x.shape[1]

        # Input projection, split into the SSM branch and the gate branch
        x_inner, z = self.in_proj(x).chunk(2, dim=-1)

        # Causal conv: drop the trailing outputs created by the padding
        x_conv = self.conv1d(x_inner.transpose(1, 2))[:, :, :seq_len]
        x_conv = F.silu(x_conv.transpose(1, 2))

        # SSM params
        dt = F.softplus(self.dt_proj(x_conv))  # [B, L, inner_dim], positive
        B_mat = self.B_proj(x_conv)            # [B, L, d_state]
        C_mat = self.C_proj(x_conv)            # [B, L, d_state]
        A = -torch.exp(self.A_log)             # [d_state], negative

        # Chunk-parallel scan
        y = self._chunk_ssm(x_conv, dt, A, B_mat, C_mat)

        # Skip + norm + gate (self.D broadcasts over batch and sequence)
        y = y + x_conv * self.D
        y = self.norm(y) * F.silu(z)

        return self.out_proj(y)

    def _chunk_ssm(self, u, dt, A, B, C):
        """
        Chunk-parallel SSM computation.

        Evaluates the recurrence (n indexes the state, d the channel)

            h_t[n, d] = exp(mean_d(dt_t) * A[n]) * h_{t-1}[n, d]
                        + B_t[n] * dt_t[d] * u_t[d]
            y_t[d]    = sum_n C_t[n] * h_t[n, d]

        with h_{-1} = 0, without materialising h_t for every position.
        Within each chunk the output is Y = (L ⊙ (C B^T)) @ (Δ ⊙ X);
        across chunks only one [d_state, d_inner] state per chunk is
        carried sequentially (n_chunks - 1 steps).

        Args:
            u:  [B, L, d_inner] SSM input.
            dt: [B, L, d_inner] step sizes.
            A:  [d_state] decay rates.
            B:  [B, L, d_state] input projection.
            C:  [B, L, d_state] output projection.

        Returns:
            [B, L, d_inner] tensor. An empty sequence (L == 0) yields an
            empty result instead of raising.
        """
        batch, seq_len, d_inner = u.shape
        d_state = A.shape[0]

        # Nothing to scan; also avoids a modulo-by-zero below.
        if seq_len == 0:
            return u.new_zeros(batch, 0, d_inner)

        T = min(self.chunk_size, seq_len)

        # Pad to a multiple of T. Zero padding is neutral: dt = 0 gives a
        # decay of exp(0) = 1 and a zero state input, and the padded outputs
        # are sliced off before returning.
        pad = (T - seq_len % T) % T
        if pad > 0:
            u = F.pad(u, (0, 0, 0, pad))
            dt = F.pad(dt, (0, 0, 0, pad))
            B = F.pad(B, (0, 0, 0, pad))
            C = F.pad(C, (0, 0, 0, pad))

        padded_len = u.shape[1]
        n_chunks = padded_len // T

        # Reshape: [B, n_chunks, T, ...]
        u_c = u.reshape(batch, n_chunks, T, d_inner)
        dt_c = dt.reshape(batch, n_chunks, T, d_inner)
        B_c = B.reshape(batch, n_chunks, T, d_state)
        C_c = C.reshape(batch, n_chunks, T, d_state)

        # Discretised input (Δ ⊙ X): [B, nc, T, d_inner]
        x_dt = dt_c * u_c

        # Mean dt per position drives the state decay (simplification:
        # one decay per position and state, shared by all channels).
        dt_mean = dt_c.mean(dim=-1)                   # [B, nc, T]
        log_dA = dt_mean.unsqueeze(-1) * A            # [B, nc, T, d_state]
        log_dA_cumsum = torch.cumsum(log_dA, dim=2)   # [B, nc, T, d_state]

        # Lower-triangular decay: decay[t, s] = exp(sum_{k=s+1}^{t} log_dA_k)
        # for t >= s, else 0. Masking with -inf BEFORE the exp keeps the
        # (positive) upper-triangle exponents from overflowing.
        log_decay = log_dA_cumsum.unsqueeze(3) - log_dA_cumsum.unsqueeze(2)
        pos = torch.arange(T, device=u.device)
        causal = pos.unsqueeze(1) >= pos.unsqueeze(0)          # [T(t), T(s)]
        log_decay = log_decay.masked_fill(
            ~causal.view(1, 1, T, T, 1), float("-inf")
        )
        decay = torch.exp(log_decay)          # [B, nc, T, T, d_state]

        # Intra-chunk output: (L ⊙ (C B^T)) @ (Δ ⊙ X)
        scores = torch.einsum("bctn,bctsn,bcsn->bcts", C_c, decay, B_c)
        y_intra = torch.einsum("bcts,bcsd->bctd", scores, x_dt)

        # State each chunk contributes at its own last position:
        # sum_s decay[T-1, s] * B_s ⊗ (Δ ⊙ X)_s
        decay_to_end = decay[:, :, -1]        # [B, nc, T, d_state]
        chunk_state = torch.einsum(
            "bcsn,bcsd->bcnd", decay_to_end * B_c, x_dt
        )                                      # [B, nc, d_state, d_inner]

        # Total decay across a full chunk, and decay from the chunk start
        # up to (and including) each position.
        chunk_decay = torch.exp(log_dA_cumsum[:, :, -1])   # [B, nc, d_state]
        position_decay = torch.exp(log_dA_cumsum)          # [B, nc, T, d_state]

        # Sequential part: the state ENTERING each chunk. The carry is
        # created in the dtype the update produces so every list entry
        # matches (also under autocast, where the operands may differ).
        carry_dtype = torch.promote_types(chunk_decay.dtype, chunk_state.dtype)
        carry = torch.zeros(
            batch, d_state, d_inner, device=u.device, dtype=carry_dtype
        )
        entering = [carry]
        for c_idx in range(n_chunks - 1):
            carry = (
                chunk_decay[:, c_idx].unsqueeze(-1) * carry
                + chunk_state[:, c_idx]
            )
            entering.append(carry)
        h_in = torch.stack(entering, dim=1)   # [B, nc, d_state, d_inner]

        # Inter-chunk output: the entering state decayed to each position,
        # then read out with C.
        y_inter = torch.einsum(
            "bctn,bcnd->bctd", C_c * position_decay, h_in
        )

        # Reshape back and drop the padding
        y = (y_intra + y_inter).reshape(batch, padded_len, d_inner)
        return y[:, :seq_len, :]


class Mamba2Block(nn.Module):
    """
    Mamba-2 block with bidirectional scanning for 2D images.
    Forward + backward raster scan, merged via learned projection.

    Args:
        dim: Channel dimension of the input and output.
        d_state: SSM state dimension (default 16)
        d_conv: Conv1d kernel size (default 4)
        expand: Expansion factor for the SSD inner dimension and for the
            feed-forward hidden width (default 2)
        dropout: Dropout probability inside the feed-forward (default 0.0)
        chunk_size: Chunk size passed to both SSD scans (default 64)
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0,
                 chunk_size=64):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Independent parameters for each scan direction
        self.ssd_fwd = Mamba2SSD(dim, d_state, d_conv, expand, chunk_size)
        self.ssd_bwd = Mamba2SSD(dim, d_state, d_conv, expand, chunk_size)
        self.merge = nn.Linear(dim * 2, dim, bias=False)

        ff_dim = dim * expand
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, C, H, W] or [B, L, C]; the output has the same shape."""
        is_2d = x.dim() == 4
        if is_2d:
            batch, channels, height, width = x.shape
            # Raster-scan the image into a sequence: [B, H*W, C]
            x = x.flatten(2).transpose(1, 2)

        residual = x
        x_norm = self.norm1(x)

        fwd = self.ssd_fwd(x_norm)
        # Backward scan: reverse the sequence, scan, then restore the order
        bwd = torch.flip(self.ssd_bwd(torch.flip(x_norm, [1])), [1])

        merged = self.merge(torch.cat([fwd, bwd], dim=-1))
        x = residual + merged
        x = x + self.ff(self.norm2(x))

        if is_2d:
            x = x.transpose(1, 2).reshape(batch, channels, height, width)

        return x