Create core/mamba_block.py
core/mamba_block.py (ADDED, +116 -0)
@@ -0,0 +1,116 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class MambaBlock(nn.Module):
    """
    Production-ready Mamba block for graph processing.
    Based on the official Mamba implementation, with graph-oriented optimizations.
    """
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)

        if dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)
        else:
            self.dt_rank = dt_rank

        # Linear projections
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias)

        # Depthwise convolution for local patterns (causal via left padding)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            bias=True,
        )

        # SSM parameters
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # Initialize A (state evolution matrix)
        A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32), 'n -> d n', d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias)

        # Activation
        self.act = nn.SiLU()

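    # Parameter roles (added commentary, not part of the original commit):
    # in_proj produces both the SSM branch input x and the gate z; x_proj yields
    # the input-dependent delta, B, and C; A (stored as A_log) and D are learned
    # but input-independent, following the Mamba paper.
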
    def forward(self, x):
        """
        x: (batch, length, d_model)
        Returns: (batch, length, d_model)
        """
        batch, length, _ = x.shape

        # Input projection and split
        xz = self.in_proj(x)  # (batch, length, 2 * d_inner)
        x, z = xz.chunk(2, dim=-1)  # Each: (batch, length, d_inner)

        # Convolution (slice off the right padding to keep it causal)
        x = rearrange(x, 'b l d -> b d l')
        x = self.conv1d(x)[:, :, :length]
        x = rearrange(x, 'b d l -> b l d')
        x = self.act(x)

        # SSM
        y = self.selective_scan(x)

        # Gating
        y = y * self.act(z)

        # Output projection
        out = self.out_proj(y)

        return out

    def selective_scan(self, u):
        """Selective scan operation, the core of Mamba."""
        batch, length, d_inner = u.shape

        # Compute ∆, B, C
        x_dbl = self.x_proj(u)  # (batch, length, dt_rank + 2 * d_state)
        delta, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)

        # Softplus ensures delta > 0
        delta = F.softplus(self.dt_proj(delta))  # (batch, length, d_inner)

        return self._selective_scan_pytorch(u, delta, B, C)

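    # Note (added commentary): the loop below is a readable reference scan that
    # runs in O(length) sequential steps. For production speed, the official
    # mamba-ssm package ships a fused CUDA kernel (selective_scan_fn) that this
    # method could dispatch to when available.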
    def _selective_scan_pytorch(self, u, delta, B, C):
        """PyTorch implementation of the selective scan."""
        batch, length, d_inner = u.shape

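        # Discretization note (added commentary): deltaA = exp(delta * A) below
        # is the zero-order-hold rule with A = -exp(A_log) < 0, while
        # delta * B * u is the simplified Euler discretization of B used in the
        # Mamba paper.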
        # Discretize
        deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log)))  # (batch, length, d_inner, d_state)
        deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # (batch, length, d_inner, d_state)

        # Initialize state
        x = torch.zeros((batch, d_inner, self.d_state), device=u.device, dtype=u.dtype)
        ys = []

        # Recurrent update: x_t = deltaA_t * x_{t-1} + deltaB_u_t; y_t = C_t · x_t
        for i in range(length):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = torch.einsum('bdn,bn->bd', x, C[:, i])
            ys.append(y)

        y = torch.stack(ys, dim=1)  # (batch, length, d_inner)

        # Add skip connection
        y = y + u * self.D

        return y
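For reference, a minimal usage sketch (not part of the commit; it assumes torch and einops are installed and that this file is importable as core.mamba_block):

import torch
from core.mamba_block import MambaBlock

block = MambaBlock(d_model=64)   # d_state=16, d_conv=4, expand=2 defaults
x = torch.randn(2, 128, 64)      # (batch, length, d_model)
y = block(x)
print(y.shape)                   # torch.Size([2, 128, 64]); the block is shape-preserving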