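"""Mamba2 block with single-direction and bidirectional ("bi") scans.

Built on the fused Triton kernels from the mamba_ssm package
(mamba_split_conv1d_scan_combined plus the gated RMSNorm). The bidirectional
variant runs one Mamba2 scan over the sequence and a second one over its
time-reversed copy, then sums (optionally averages) the two outputs before
the output projection.
"""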
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

# The gated RMSNorm kernel is optional at import time; the assert in
# Mamba2SingleScan.__init__ checks that it is actually available.
try:
    from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
except ImportError:
    RMSNormGated, LayerNorm = None, None
from mamba_ssm.ops.triton.layer_norm import RMSNorm
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined


class Mamba2SingleScan(nn.Module):
    def __init__(
        self,
        d_model,
        d_state,
        d_conv,
        conv_init,
        expand,
        headdim,
        ngroups,
        A_init_range,
        dt_min,
        dt_max,
        dt_init_floor,
        dt_limit,
        learnable_init_states,
        activation,
        bias,
        conv_bias,
        chunk_size,
        device,
        dtype,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.d_inner = self.expand * self.d_model
        self.headdim = headdim
        self.ngroups = ngroups
        assert self.d_inner % self.headdim == 0
        self.nheads = self.d_inner // self.headdim
        self.dt_limit = dt_limit
        self.learnable_init_states = learnable_init_states
        self.activation = activation
        self.chunk_size = chunk_size

        # Depthwise causal conv over the concatenated (x, B, C) channels
        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

        if self.learnable_init_states:
            self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
            self.init_states._no_weight_decay = True

        self.act = nn.SiLU()

        # Initialize dt_bias so that softplus(dt_bias) falls in [dt_min, dt_max]
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        self.dt_bias._no_weight_decay = True

        # A is stored as A_log so that A = -exp(A_log) is always negative
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # Per-head skip connection ("D" term)
        self.D = nn.Parameter(torch.ones(self.nheads, device=device))
        self.D._no_weight_decay = True

        assert RMSNormGated is not None
        self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)

    def forward(self, zxbcdt):
        """
        zxbcdt: (B, L, d_in_proj) with d_in_proj = 2 * d_inner + 2 * ngroups * d_state + nheads,
            packed as [z | x | B | C | dt] along the last dimension.
        Returns: (B, L, d_inner)
        """
        A = -torch.exp(self.A_log)  # (nheads,)
        initial_states = None
        dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)

        # Fused conv1d + selective-scan + gated RMSNorm kernel
        out = mamba_split_conv1d_scan_combined(
            zxbcdt,
            rearrange(self.conv1d.weight, "d 1 w -> d w"),
            self.conv1d.bias,
            self.dt_bias,
            A,
            D=self.D,
            chunk_size=self.chunk_size,
            activation=self.activation,
            rmsnorm_weight=self.norm.weight,
            rmsnorm_eps=self.norm.eps,
            headdim=self.headdim,
            ngroups=self.ngroups,
            norm_before_gate=False,
            initial_states=initial_states,
            **dt_limit_kwargs,
        )
        return out

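# For reference (the fused kernel handles this internally): the projection that
# Mamba2SingleScan.forward consumes is packed as [z | x | B | C | dt] along the
# last dimension, so it could be split manually as
#
#   z, xBC, dt = torch.split(
#       zxbcdt, [d_inner, d_inner + 2 * ngroups * d_state, nheads], dim=-1
#   )
#
# where xBC is split into x, B, and C after the causal convolution.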


class Mamba2MultiScan(nn.Module):
    def __init__(
        self,
        d_model,
        d_state,
        d_conv,
        conv_init,
        expand,
        headdim,
        ngroups,
        A_init_range,
        dt_min,
        dt_max,
        dt_init_floor,
        dt_limit,
        learnable_init_states,
        activation,
        bias,
        conv_bias,
        chunk_size,
        scan_type,
        device,
        dtype,
        if_divide_out,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.expand = expand
        self.d_inner = self.expand * self.d_model
        self.headdim = headdim
        self.ngroups = ngroups
        assert self.d_inner % self.headdim == 0
        self.nheads = self.d_inner // self.headdim
        assert scan_type in ["single", "bi"]
        self.scan_type = scan_type
        self.if_divide_out = if_divide_out

        # Single input projection shared by all scan directions
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)

        # One scan module per direction ("single" = forward only, "bi" = forward + backward)
        self.mamba_scans = nn.ModuleList()
        self.scan_num = 1
        if scan_type == "bi":
            self.scan_num = 2
        for _ in range(self.scan_num):
            self.mamba_scans.append(
                Mamba2SingleScan(
                    d_model,
                    d_state,
                    d_conv,
                    conv_init,
                    expand,
                    headdim,
                    ngroups,
                    A_init_range,
                    dt_min,
                    dt_max,
                    dt_init_floor,
                    dt_limit,
                    learnable_init_states,
                    activation,
                    bias,
                    conv_bias,
                    chunk_size,
                    device,
                    dtype,
                )
            )

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, hidden_states):
        """
        hidden_states: (B, L, D)
        Returns: same shape as input
        """
        batch, seqlen, dim = hidden_states.shape

        xz = self.in_proj(hidden_states)

        # Forward-direction input, plus a time-reversed copy for the backward scan
        xzs = [xz]
        if self.scan_type == "bi":
            xzs.append(xz.flip([1]))

        outs = []
        for i in range(self.scan_num):
            out = self.mamba_scans[i](xzs[i])
            if i == 0:
                outs.append(out)
            elif i == 1:
                # Flip the backward-scan output back to the original time order
                outs.append(out.flip([1]))

        # Combine directions by summation (optionally averaged)
        out = sum(outs)
        if self.if_divide_out:
            out = out / self.scan_num

        out = self.out_proj(out)

        return out


class Mamba2Block(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=256,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=64,
        ngroups=1,
        A_init_range=(1, 16),
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        learnable_init_states=False,
        activation="swish",
        bias=False,
        conv_bias=True,
        chunk_size=256,
        scan_type="bi",
        device=None,
        dtype=None,
        if_divide_out=False,
        norm_cls="rms_norm",
    ):
        super().__init__()
        assert norm_cls in ["rms_norm", "layer_norm"]
        if norm_cls == "rms_norm":
            self.norm = RMSNorm(d_model)
        elif norm_cls == "layer_norm":
            self.norm = nn.LayerNorm(d_model)
        self.mamba = Mamba2MultiScan(d_model, d_state, d_conv, conv_init, expand, headdim, ngroups, A_init_range, dt_min,
                                     dt_max, dt_init_floor, dt_limit, learnable_init_states, activation, bias, conv_bias,
                                     chunk_size, scan_type, device, dtype, if_divide_out)

    def forward(self, x):
        """
        x: (B, L, D)
        Returns: same shape as input
        """
        # Pre-norm residual block: x + Mamba(Norm(x))
        x = x + self.mamba(self.norm(x))
        return x


if __name__ == "__main__":
    batch_size = 4
    seq_len = 128
    input_dim = 256

    model = Mamba2Block(d_model=input_dim, device="cuda").to("cuda")
    x = torch.randn(batch_size, seq_len, input_dim, device="cuda")
    output = model(x)

    print("Input shape:", x.shape)
    print("Output shape:", output.shape)
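
    # Sanity check: the residual block preserves the input shape
    assert output.shape == x.shape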