# Lyra/src/models/utils/mamba2.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted for Long-LRM by Ziwen 2024
# from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2_simple.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
from mamba_ssm.ops.triton.layer_norm import RMSNorm
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
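

# This module provides three building blocks:
#   Mamba2SingleScan -- one fused Mamba-2 scan over a pre-projected [z, x, B, C, dt] tensor.
#   Mamba2MultiScan  -- a shared input projection feeding one ("single") or two ("bi") scans.
#   Mamba2Block      -- a pre-norm residual block: x + Mamba2MultiScan(Norm(x)).
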
class Mamba2SingleScan(nn.Module):
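    """One fused Mamba-2 scan.

    Consumes the already-projected tensor zxbcdt (channels ordered [z, x, B, C, dt]),
    runs the depthwise conv, the SSD scan, and the gated RMSNorm in a single fused
    Triton kernel, and returns the (B, L, d_inner) result before any output projection.
    """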
def __init__(
self,
d_model,
d_state,
d_conv,
conv_init,
expand,
headdim,
ngroups,
A_init_range,
dt_min,
dt_max,
dt_init_floor,
dt_limit,
learnable_init_states,
activation,
bias,
conv_bias,
# Fused kernel and sharding options
chunk_size,
device,
dtype,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.conv_init = conv_init
self.expand = expand
self.d_inner = self.expand * self.d_model
self.headdim = headdim
self.ngroups = ngroups
assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
self.dt_limit = dt_limit
self.learnable_init_states = learnable_init_states
self.activation = activation
self.chunk_size = chunk_size
conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
bias=conv_bias,
kernel_size=d_conv,
groups=conv_dim,
padding=d_conv - 1,
**factory_kwargs,
)
if self.conv_init is not None:
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
# self.conv1d.weight._no_weight_decay = True
if self.learnable_init_states:
self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
self.init_states._no_weight_decay = True
self.act = nn.SiLU()
# Initialize log dt bias
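        # dt is sampled log-uniformly in [dt_min, dt_max]; the bias stores
        # softplus^{-1}(dt), so softplus(dt_bias) reproduces the sampled dt at init.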
dt = torch.exp(
torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
)
dt = torch.clamp(dt, min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
self.dt_bias = nn.Parameter(inv_dt)
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True
# A parameter
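        # A is stored as log(A); forward() uses A = -exp(A_log), so the state
        # transition stays negative (a decaying SSM) for any parameter value.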
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
A_log = torch.log(A).to(dtype=dtype)
self.A_log = nn.Parameter(A_log)
# self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
self.A_log._no_weight_decay = True
# D "skip" parameter
self.D = nn.Parameter(torch.ones(self.nheads, device=device))
self.D._no_weight_decay = True
# Extra normalization layer right before output projection
assert RMSNormGated is not None
self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)
def forward(self, zxbcdt):
"""
zxbcdt: (B, L, D)
Returns: same shape as input
"""
        A = -torch.exp(self.A_log)  # (nheads,)
initial_states = None
dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
# Fully fused path
out = mamba_split_conv1d_scan_combined(
zxbcdt,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.dt_bias,
A,
D=self.D,
chunk_size=self.chunk_size,
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.eps,
headdim=self.headdim,
ngroups=self.ngroups,
norm_before_gate=False,
initial_states=initial_states,
**dt_limit_kwargs,
)
return out
class Mamba2MultiScan(nn.Module):
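    """Single- or bi-directional Mamba-2.

    One shared input projection produces the [z, x, B, C, dt] tensor. With
    scan_type="bi", a second Mamba2SingleScan consumes the time-reversed sequence;
    its output is flipped back, the two directions are summed (optionally averaged
    via if_divide_out), and a shared output projection maps back to d_model.
    """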
def __init__(
self,
d_model,
d_state,
d_conv,
conv_init,
expand,
headdim,
ngroups,
A_init_range,
dt_min,
dt_max,
dt_init_floor,
dt_limit,
learnable_init_states,
activation,
bias,
conv_bias,
# Fused kernel and sharding options
chunk_size,
scan_type, # single, bi
device,
dtype,
if_divide_out,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.expand = expand
self.d_inner = self.expand * self.d_model
self.headdim = headdim
self.ngroups = ngroups
assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
assert scan_type in ["single", "bi"]
self.scan_type = scan_type
self.if_divide_out = if_divide_out
# Order: [z, x, B, C, dt]
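        #   z: d_inner (gate), x: d_inner, B: ngroups * d_state, C: ngroups * d_state, dt: nheads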
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
self.mamba_scans = nn.ModuleList()
self.scan_num = 1
if scan_type == "bi":
self.scan_num = 2
for _ in range(self.scan_num):
self.mamba_scans.append(
Mamba2SingleScan(
d_model,
d_state,
d_conv,
conv_init,
expand,
headdim,
ngroups,
A_init_range,
dt_min,
dt_max,
dt_init_floor,
dt_limit,
learnable_init_states,
activation,
bias,
conv_bias,
chunk_size,
device,
dtype,
)
)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
def forward(self, hidden_states):
"""
hidden_states: (B, L, D)
Returns: same shape as input
"""
batch, seqlen, dim = hidden_states.shape
xz = self.in_proj(hidden_states) # (B, L, d_in_proj), [z,x,B,C,dt]
xzs = [xz]
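        # For the bi-directional variant, the second scan sees the time-reversed sequence.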
if self.scan_type == "bi":
xzs.append(xz.flip([1]))
outs = []
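        # Run each directional scan; flip the reverse scan's output back to the
        # original time order before the two directions are combined.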
for i in range(self.scan_num):
out = self.mamba_scans[i](xzs[i])
if i == 0:
outs.append(out)
elif i == 1:
outs.append(out.flip([1]))
out = sum(outs)
if self.if_divide_out:
out = out / self.scan_num
out = self.out_proj(out)
return out
class Mamba2Block(nn.Module):
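    """Pre-norm residual block: x + Mamba2MultiScan(Norm(x))."""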
def __init__(
self,
d_model,
d_state=256,
d_conv=4,
conv_init=None,
expand=2,
headdim=64,
ngroups=1,
A_init_range=(1, 16),
dt_min=0.001,
dt_max=0.1,
dt_init_floor=1e-4,
dt_limit=(0.0, float("inf")),
learnable_init_states=False,
activation="swish",
bias=False,
conv_bias=True,
# Fused kernel and sharding options
chunk_size=256,
scan_type="bi", # single, bi
device=None,
dtype=None,
if_divide_out=False,
norm_cls="rms_norm",
):
super().__init__()
assert norm_cls in ["rms_norm", "layer_norm"]
        if norm_cls == "rms_norm":
            self.norm = RMSNorm(d_model)
        elif norm_cls == "layer_norm":
            self.norm = nn.LayerNorm(d_model)
self.mamba = Mamba2MultiScan(d_model, d_state, d_conv, conv_init, expand, headdim, ngroups, A_init_range, dt_min,
dt_max, dt_init_floor, dt_limit, learnable_init_states, activation, bias, conv_bias,
chunk_size, scan_type, device, dtype, if_divide_out)
def forward(self, x):
"""
x: (B, L, D)
Returns: same shape as input
"""
x = x + self.mamba(self.norm(x))
return x
if __name__ == "__main__":
# Test Mamba2Block
batch_size = 4
seq_len = 128
input_dim = 256
model = Mamba2Block(d_model=input_dim, device="cuda").to("cuda")
    x = torch.randn(batch_size, seq_len, input_dim, device="cuda")
    output = model(x)
    print("Input shape:", x.shape)
    print("Output shape:", output.shape)