import torch
import torch.nn as nn
import torch.nn.functional as F


def get_model_device(model):
    return next(iter(model.parameters())).device


class CausalConv1d(nn.Module):
    """Depthwise causal conv for single-step (recurrent) inference.

    The rolling state holds the last `kernel_size - 1` inputs per channel.
    """

    def __init__(self, hidden_size, kernel_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(
            hidden_size, hidden_size, kernel_size, groups=hidden_size, bias=True
        )

    def init_state(self, batch_size: int, device: torch.device | None = None):
        if device is None:
            device = get_model_device(self)
        return torch.zeros(
            batch_size, self.hidden_size, self.kernel_size - 1, device=device
        )

    def forward(self, x: torch.Tensor, state: torch.Tensor):
        # x: [batch_size, hidden_size], state: [batch_size, hidden_size, kernel_size - 1]
        x_with_state = torch.concat([state, x[:, :, None]], dim=-1)
        out = self.conv(x_with_state)
        # Shift the window: drop the oldest input, keep the most recent kernel_size - 1.
        new_state = x_with_state[:, :, 1:]
        return out.squeeze(-1), new_state


class Mamba2(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        inner_size: int | None = None,
        head_size: int = 64,
        bc_head_size: int = 128,
        conv_kernel_size: int = 4,
    ):
        super().__init__()
        self.head_size = head_size
        self.bc_head_size = bc_head_size
        if inner_size is None:
            inner_size = 2 * hidden_size
        assert inner_size % head_size == 0
        self.inner_size = inner_size
        self.num_heads = inner_size // head_size

        # Projections
        self.input_proj = nn.Linear(hidden_size, inner_size, bias=False)
        self.z_proj = nn.Linear(hidden_size, inner_size, bias=False)
        self.b_proj = nn.Linear(hidden_size, bc_head_size, bias=False)
        self.c_proj = nn.Linear(hidden_size, bc_head_size, bias=False)
        self.dt_proj = nn.Linear(hidden_size, self.num_heads, bias=True)

        # Convs
        self.input_conv = CausalConv1d(inner_size, conv_kernel_size)
        self.b_conv = CausalConv1d(bc_head_size, conv_kernel_size)
        self.c_conv = CausalConv1d(bc_head_size, conv_kernel_size)

        # Other parameters: a is negative so that exp(a * dt) decays; d is a skip gain.
        self.a = nn.Parameter(-torch.empty(self.num_heads).uniform_(1, 16))
        self.d = nn.Parameter(torch.ones(self.num_heads))

        # Output
        self.norm = nn.RMSNorm(inner_size, eps=1e-5)
        self.out_proj = nn.Linear(inner_size, hidden_size, bias=False)

    def init_state(self, batch_size: int, device: torch.device | None = None):
        if device is None:
            device = get_model_device(self)
        conv_states = [
            conv.init_state(batch_size, device)
            for conv in [self.input_conv, self.b_conv, self.c_conv]
        ]
        ssm_state = torch.zeros(
            batch_size, self.num_heads, self.head_size, self.bc_head_size, device=device
        )
        return conv_states + [ssm_state]

    def forward(self, t, state):
        # t: [batch_size, hidden_size] — one token's embedding per batch element.
        batch_size = t.shape[0]
        x = self.input_proj(t)
        z = self.z_proj(t)
        b = self.b_proj(t)
        c = self.c_proj(t)
        dt = self.dt_proj(t)

        x_conv_state, b_conv_state, c_conv_state, ssm_state = state
        x, x_conv_state = self.input_conv(x, x_conv_state)
        b, b_conv_state = self.b_conv(b, b_conv_state)
        c, c_conv_state = self.c_conv(c, c_conv_state)
        x = F.silu(x)
        b = F.silu(b)
        c = F.silu(c)
        x = x.view(batch_size, self.num_heads, self.head_size)
        dt = F.softplus(dt)

        # new_state computation: h[t] = exp(A * dt) * h[t-1] + dt * B * x[t]
        # decay is [batch_size, num_heads]
        decay = torch.exp(self.a[None] * dt)
        # dt is [batch_size, num_heads]
        # b is [batch_size, bc_head_size]
        # x is [batch_size, num_heads, head_size]
        new_state_contrib = dt[:, :, None, None] * b[:, None, None] * x[:, :, :, None]
        ssm_state = decay[:, :, None, None] * ssm_state + new_state_contrib

        # output computation: y[t] = C @ h[t] + D * x[t]
        # The contraction between C and h[t] is over the bc_head_size dimension.
        state_contrib = torch.einsum("bc,bnhc->bnh", c, ssm_state)
        # d has shape [num_heads]; broadcast it to the shape of x.
        y = state_contrib + self.d[None, :, None] * x

        # Combine heads
        y = y.view(batch_size, self.inner_size)

        # Gate, normalization and output projection
        y = y * F.silu(z)
        y = self.norm(y)
        output = self.out_proj(y)

        new_state = [x_conv_state, b_conv_state, c_conv_state, ssm_state]
        return output, new_state
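

# A minimal usage sketch (not part of the original module): the layer consumes one
# token at a time, so a sequence is processed by carrying the state across steps.
# The batch, sequence, and hidden sizes below are illustrative assumptions.
if __name__ == "__main__":
    batch_size, seq_len, hidden_size = 2, 10, 256

    layer = Mamba2(hidden_size)
    state = layer.init_state(batch_size)

    tokens = torch.randn(batch_size, seq_len, hidden_size)
    outputs = []
    for step in range(seq_len):
        # Each call takes one token's embedding and returns the updated recurrent state.
        out, state = layer(tokens[:, step], state)  # out: [batch_size, hidden_size]
        outputs.append(out)
    outputs = torch.stack(outputs, dim=1)  # [batch_size, seq_len, hidden_size]
    print(outputs.shape)  # torch.Size([2, 10, 256])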