# artigen / ssm_block.py
# Uploaded by krystv (commit 95534c2).
"""
Simplified selective SSM block for image tokens.
O(N) complexity, O(1) memory per token.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimplifiedMambaBlock(nn.Module):
    """Minimal selective SSM block without cuda-specific selective_scan.

    Pre-norm residual block: LayerNorm -> input projection split into a
    content branch and a gate branch -> causal depthwise conv + SiLU ->
    input-dependent (selective) state-space scan -> SiLU gating ->
    output projection -> residual add.

    The scan is a plain Python loop over the sequence, so its cost grows
    linearly with sequence length and the recurrent state has a fixed size
    of (batch, d_inner, d_state).

    Args:
        d_model: Width of the incoming/outgoing token features.
        d_state: Number of SSM state channels per inner channel.
        d_conv: Kernel size of the causal depthwise convolution.
        expand: Multiplier giving the inner width ``d_inner = expand * d_model``.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)
        # Rank of the low-rank bottleneck used to produce per-token step sizes.
        self.dt_rank = math.ceil(d_model / 16)
        self.d_conv = d_conv

        # Produces both the content branch and the gate branch in one matmul.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Depthwise (groups == channels) conv; padded by d_conv - 1 so that,
        # after trimming in forward(), it is causal.
        self.conv1d = nn.Conv1d(
            self.d_inner,
            self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
            bias=True,
        )
        # Per-token projections to (dt_low_rank, B, C).
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # A is initialised to -[1, 2, ..., d_state] for every inner channel;
        # it is stored as log(-A) so that forward() can keep it strictly negative.
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        # Per-channel skip weight added on top of the scan output.
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)

    def _selective_scan(self, x, dt, A, B, C, D):
        """Run the discretised selective-scan recurrence sequentially over time.

        For each step t::

            h_t = exp(dt_t * A) * h_{t-1} + (dt_t * B_t) * x_t
            y_t = sum_n(h_t * C_t) + D * x_t

        Args:
            x: Input sequence, shape (batch, L, d_inner).
            dt: Per-token, per-channel step sizes, shape (batch, L, d_inner).
            A: State decay coefficients, shape (d_inner, d_state).
            B: Per-token input projection, shape (batch, L, d_state).
            C: Per-token output projection, shape (batch, L, d_state).
            D: Per-channel skip weights, shape (d_inner,).

        Returns:
            Tensor of shape (batch, L, d_inner) with the same dtype as ``x``.
        """
        Bb, L, d_in = x.shape
        d_state = A.shape[1]

        # Accumulate in at least float32. Half/bfloat16 state loses precision
        # over long sequences, and mixing low-precision inputs with a float32
        # ``A`` would otherwise promote the result to float32. That float32
        # result would then clash with low-precision weights downstream
        # (out_proj).
        compute_dtype = torch.promote_types(torch.promote_types(x.dtype, A.dtype), torch.float32)
        out_dtype = x.dtype
        x_c = x.to(compute_dtype)
        dt_c = dt.to(compute_dtype)
        A_c = A.to(compute_dtype)
        B_c = B.to(compute_dtype)
        C_c = C.to(compute_dtype)
        D_c = D.to(compute_dtype)

        h = x_c.new_zeros(Bb, d_in, d_state)
        ys = []
        for t in range(L):
            # Discretise one step at a time instead of materialising
            # (batch, L, d_inner, d_state) tensors for the whole sequence.
            dt_t = dt_c[:, t].unsqueeze(-1)                 # (Bb, d_in, 1)
            A_bar = torch.exp(dt_t * A_c)                   # (Bb, d_in, d_state)
            dtB = dt_t * B_c[:, t].unsqueeze(1)             # (Bb, d_in, d_state)
            h = A_bar * h + dtB * x_c[:, t].unsqueeze(-1)
            ys.append(torch.sum(h * C_c[:, t].unsqueeze(1), dim=-1))  # (Bb, d_in)

        y = torch.stack(ys, dim=1)                          # (Bb, L, d_in)
        y = y + D_c * x_c
        return y.to(out_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pre-norm selective SSM block with a residual connection.

        Args:
            x: Token sequence; assumed (batch, L, d_model), matching the
                LayerNorm/Linear input width and the time axis sliced below.

        Returns:
            Tensor of the same shape as ``x``.
        """
        seq_len = x.shape[1]
        x_norm = self.norm(x)
        x_gate, z_gate = self.in_proj(x_norm).chunk(2, dim=-1)

        # The conv is padded by d_conv - 1 on both sides; keeping only the
        # first seq_len outputs makes position t see inputs t-d_conv+1 .. t.
        x_conv = self.conv1d(x_gate.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_conv = F.silu(x_conv)

        dt_un, B, C = torch.split(
            self.x_proj(x_conv), [self.dt_rank, self.d_state, self.d_state], dim=-1
        )
        dt = F.softplus(self.dt_proj(dt_un))  # strictly positive step sizes
        A = -torch.exp(self.A_log.float())    # strictly negative -> decaying state

        y = self._selective_scan(x_conv, dt, A, B, C, self.D)
        y = y * F.silu(z_gate)
        return self.out_proj(y) + x