| """ |
| Simplified selective SSM block for image tokens. |
| O(N) complexity, O(1) memory per token. |
| """ |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class SimplifiedMambaBlock(nn.Module):
    """Minimal selective SSM block without cuda-specific selective_scan.

    The block layer-normalises its input, projects it into a content branch
    and a gate branch, passes the content branch through a causal depthwise
    convolution and a sequential selective scan, multiplies the result by the
    SiLU of the gate branch, projects back to ``d_model`` and adds the
    original input as a residual.

    Args:
        d_model: Size of the last dimension of the input and of the output.
        d_state: Number of state channels kept per inner channel in the scan.
        d_conv: Kernel size of the depthwise convolution.
        expand: Multiplier giving the inner width, ``int(expand * d_model)``.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2) -> None:
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)
        # Width of the low-rank bottleneck the step size is projected through.
        self.dt_rank = math.ceil(d_model / 16)
        self.d_conv = d_conv

        # NOTE: submodules are created in a fixed order so that seeded
        # initialisation stays reproducible.

        # A single projection yields both the content and the gate branch.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Depthwise (groups == channels) convolution. It pads d_conv - 1 on
        # both sides; forward() keeps only the first L outputs, so position t
        # only sees inputs at positions <= t.
        self.conv1d = nn.Conv1d(
            self.d_inner,
            self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
            bias=True,
        )
        # Produces the low-rank step size plus the per-token B and C vectors.
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
        # A is stored as log(1..d_state) per inner channel; forward() uses
        # -exp(A_log), which keeps the state matrix strictly negative.
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        # Per-channel weight of the direct input-to-output skip path.
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)

    def _selective_scan(
        self,
        x: torch.Tensor,
        dt: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        D: torch.Tensor,
    ) -> torch.Tensor:
        """Run the state-space recurrence one time step at a time.

        For every step ``t`` the state is updated as
        ``h = exp(dt_t * A) * h + dt_t * B_t * x_t`` (starting from zeros) and
        read out as ``y_t = sum_over_state(h * C_t) + D * x_t``.

        Args:
            x: Input of shape ``(batch, length, d_inner)``.
            dt: Step sizes, same shape as ``x``.
            A: State matrix of shape ``(d_inner, d_state)``.
            B: Input weights of shape ``(batch, length, d_state)``.
            C: Readout weights of shape ``(batch, length, d_state)``.
            D: Skip weights of shape ``(d_inner,)``.

        Returns:
            Tensor of shape ``(batch, length, d_inner)``. For ``length == 0``
            the (empty) skip term is returned.
        """
        batch, seq_len, d_inner = x.shape
        skip = D.unsqueeze(0).unsqueeze(0) * x
        if seq_len == 0:
            # torch.stack rejects an empty list, so there is nothing to scan;
            # the skip term already has the right (empty) shape.
            return skip

        d_state = A.shape[1]
        # Per-step decay factor, shape (batch, length, d_inner, d_state).
        decay = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))
        # Per-step input contribution, computed once for all steps instead of
        # inside the loop; same shape as ``decay``.
        drive = dt.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)

        state = torch.zeros(batch, d_inner, d_state, device=x.device, dtype=x.dtype)
        outputs = []
        for t in range(seq_len):
            state = decay[:, t] * state + drive[:, t]
            # Contract the state dimension against this step's readout vector.
            outputs.append(torch.sum(state * C[:, t].unsqueeze(1), dim=-1))
        return torch.stack(outputs, dim=1) + skip

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a sequence of tokens.

        Args:
            x: Tensor of shape ``(batch, length, d_model)``.

        Returns:
            Tensor with the same shape as ``x`` (block output plus ``x``).
        """
        seq_len = x.shape[1]
        content, gate = self.in_proj(self.norm(x)).chunk(2, dim=-1)

        # Conv1d expects (batch, channels, length). Dropping the trailing
        # padded outputs keeps the convolution causal.
        content = self.conv1d(content.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        content = F.silu(content)

        dt_low, B, C = torch.split(
            self.x_proj(content),
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1,
        )
        # softplus keeps every step size positive.
        dt = F.softplus(self.dt_proj(dt_low))
        # Strictly negative A makes exp(dt * A) a decay factor in (0, 1).
        A = -torch.exp(self.A_log.float())

        y = self._selective_scan(content, dt, A, B, C, self.D)
        y = y * F.silu(gate)
        return self.out_proj(y) + x
|
|