"""
Simplified selective SSM (Mamba-style) block for image tokens.

The recurrence is evaluated with a plain Python loop over the token sequence,
so compute grows linearly with the number of tokens N. The recurrent state
carried from one token to the next has a fixed size (d_inner x d_state per
batch element) that does not depend on N.
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimplifiedMambaBlock(nn.Module):
    """Minimal selective SSM block without cuda-specific selective_scan.

    Pre-norm residual block:
    LayerNorm -> in_proj -> (causal depthwise conv + SiLU) -> selective SSM
    scan -> SiLU(z) gating -> out_proj -> residual add.

    Args:
        d_model: Size of the last (channel) dimension of the input.
        d_state: Size of the SSM hidden state per inner channel.
        d_conv: Kernel size of the causal depthwise convolution.
        expand: Width multiplier; the inner width is ``int(expand * d_model)``.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)
        # Rank of the low-rank projection that produces the step size dt.
        self.dt_rank = math.ceil(d_model / 16)
        self.d_conv = d_conv

        # Produces the SSM branch and the gating branch (z) in one matmul.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Depthwise conv; padding d_conv-1 plus trimming the trailing outputs
        # in forward() makes it causal.
        self.conv1d = nn.Conv1d(
            self.d_inner,
            self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
            bias=True,
        )
        # Input-dependent ("selective") parameters: dt (low rank), B and C.
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
        # NOTE(review): default nn.Linear init is used for dt_proj; the reference
        # Mamba implementation initialises its bias specially — confirm intended.
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # A is stored as log(1..d_state) per channel; forward() uses -exp(A_log),
        # so the continuous-time decay rates are strictly negative.
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        # Per-channel skip weight applied to the scan input.
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)

    def _selective_scan(self, x, dt, A, B, C, D):
        """Run the discretised selective SSM recurrence sequentially.

        For each step t (with h starting at zero):
            h = exp(dt_t * A) * h + (dt_t * B_t) * x_t
            y_t = sum_s(h * C_t) + D * x_t

        The recurrence is accumulated in at least float32. Half-precision
        inputs are upcast; float64 stays float64. The result is cast back to
        ``x.dtype`` so downstream layers in the same precision keep working.

        Args:
            x: Scan input, shape (batch, L, d_inner).
            dt: Step sizes, shape (batch, L, d_inner).
            A: State matrix (diagonal per inner channel), shape (d_inner, d_state).
            B: Input matrix per token, shape (batch, L, d_state).
            C: Output matrix per token, shape (batch, L, d_state).
            D: Skip weights, shape (d_inner,).

        Returns:
            Tensor of shape (batch, L, d_inner) with dtype ``x.dtype``.
        """
        batch, seq_len, d_inner = x.shape
        d_state = A.shape[1]
        in_dtype = x.dtype
        # Accumulating the recurrence in fp16/bf16 both loses precision and,
        # because A arrives as float32, would silently promote the output to
        # float32 (breaking the half-precision out_proj). Compute in a wide
        # dtype and cast back at the end.
        work_dtype = torch.promote_types(in_dtype, torch.float32)
        x_w = x.to(work_dtype)
        dt_w = dt.to(work_dtype)
        A_w = A.to(work_dtype)
        B_w = B.to(work_dtype)
        C_w = C.to(work_dtype)
        D_w = D.to(work_dtype)

        h = torch.zeros(batch, d_inner, d_state, device=x.device, dtype=work_dtype)
        ys = []
        for t in range(seq_len):
            # Discretise per step rather than materialising
            # (batch, L, d_inner, d_state) tensors for the whole sequence.
            dt_t = dt_w[:, t].unsqueeze(-1)                  # (batch, d_inner, 1)
            A_bar = torch.exp(dt_t * A_w)                    # (batch, d_inner, d_state)
            B_bar = dt_t * B_w[:, t].unsqueeze(1)            # (batch, d_inner, d_state)
            h = A_bar * h + B_bar * x_w[:, t].unsqueeze(-1)
            ys.append(torch.sum(h * C_w[:, t].unsqueeze(1), dim=-1))

        if ys:
            y = torch.stack(ys, dim=1)
        else:
            # torch.stack rejects an empty list; an empty sequence yields an
            # empty output.
            y = h.new_zeros(batch, 0, d_inner)
        y = y + D_w * x_w
        return y.to(in_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a token sequence.

        Args:
            x: Tokens of shape (batch, seq_len, d_model). Image tokens are
               presumably flattened into a sequence by the caller — confirm.

        Returns:
            Tensor of the same shape as ``x``: ``x`` plus the block's output.
        """
        seq_len = x.shape[1]
        xz = self.in_proj(self.norm(x))
        x_branch, z_branch = xz.chunk(2, dim=-1)

        # Conv1d expects (batch, channels, length). With padding d_conv-1 on
        # both sides, keeping the first seq_len outputs means position i only
        # sees inputs at positions <= i.
        x_conv = self.conv1d(x_branch.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_conv = F.silu(x_conv)

        dt_low, B, C = torch.split(
            self.x_proj(x_conv), [self.dt_rank, self.d_state, self.d_state], dim=-1
        )
        # softplus keeps step sizes positive.
        dt = F.softplus(self.dt_proj(dt_low))
        A = -torch.exp(self.A_log.float())

        y = self._selective_scan(x_conv, dt, A, B, C, self.D)
        y = y * F.silu(z_branch)
        return self.out_proj(y) + x