Commit ·
4d7d25c
1
Parent(s): 13bfa0f
fixes
Browse files- attn.py +206 -0
- causal_conv1d_compilable.py +1 -2
- model.py +129 -191
- norms.py +1 -1
- ssm_compilable.py +1 -2
attn.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from flash_attn import flash_attn_func
|
| 10 |
+
except ImportError as e:
|
| 11 |
+
print(
|
| 12 |
+
f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    """Return the power of two nearest to ``x`` from below or above.

    Args:
        x: Positive integer to round.
        round_up: If True, round up to the next power of two
            (``2 ** ceil(log2 x)``); otherwise round down
            (``2 ** floor(log2 x)``).

    Returns:
        The requested power of two. If ``x`` is already a power of two,
        both modes return ``x`` itself.

    Raises:
        ValueError: If ``x`` is not a positive integer.
    """
    if x < 1:
        raise ValueError(f"x must be a positive integer, got {x}")
    # int.bit_length() is exact integer arithmetic; math.log2 goes through a
    # float and can misround for very large integers (e.g. 2**53 + 1).
    if round_up:
        return 1 << (x - 1).bit_length()
    return 1 << (x.bit_length() - 1)
| 20 |
+
|
| 21 |
+
def _generate_slopes(self, n: int):
|
| 22 |
+
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
| 23 |
+
return [start * (start**i) for i in range(n)]
|
| 24 |
+
|
| 25 |
+
def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
    """Build one ALiBi slope per attention head.

    NOTE(review): module-level function that nevertheless takes ``self`` and
    reads ``self._generate_slopes`` / ``self.device`` — it duplicates
    ``Attention._get_alibi_slopes`` and looks like copy/paste residue;
    confirm whether anything still calls it.

    Args:
        n_heads: Number of attention heads to produce slopes for.
        interpolation_factor: Scale applied uniformly to all slopes
            (https://arxiv.org/pdf/2310.13017).

    Returns:
        Tensor of ``n_heads`` slopes on ``self.device``.
    """
    # If n_heads is a power of 2, generate slopes directly
    if math.log2(n_heads).is_integer():
        slopes = self._generate_slopes(n_heads)
    else:
        # Get slopes for the nearest power of two
        n = nearest_power_of_two(n_heads, round_up=False)
        slopes_power_of_two = self._generate_slopes(n)

        # Generate extra slopes
        extra_slopes = self._generate_slopes(2 * n)
        # Standard ALiBi recipe for non-power-of-two head counts: take every
        # other slope from the doubled sequence to fill the remaining heads.
        extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
        slopes = slopes_power_of_two + extra_slopes_trunc
    slopes = torch.tensor(slopes, device=self.device)
    slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
    return slopes
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    """Precompute the complex rotary-embedding (RoPE) frequency table.

    Args:
        head_dim: Per-head dimension; one frequency is built per pair of
            dimensions (head_dim // 2 total).
        max_seq_len: Number of positions to precompute rotations for.
        theta: RoPE base frequency.

    Returns:
        Complex tensor of shape [max_seq_len, head_dim // 2] holding the
        unit-magnitude rotations e^{i * angle}.
    """
    # Per-pair exponent, then the inverse-frequency scale 1 / theta^exponent.
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)

    # One rotation angle for every (position, frequency) pair.
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)

    # Unit-magnitude complex exponentials e^{i * angle}.
    return torch.polar(torch.ones_like(angles), angles)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    x is [B, n_heads, seq_len, head_dim_as_complex],
    so we want to broadcast freqs_cis from [max_seq_len, half_dim]
    to [1, 1, seq_len, half_dim].
    """
    current_len = x.shape[2]
    # Keep only the rotations for positions actually present, then add two
    # leading singleton axes so the table broadcasts over batch and heads.
    return freqs_cis[:current_len].view(1, 1, current_len, -1)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the precomputed complex frequencies.

    Pairs up the last dimension, reinterprets each pair as a complex number,
    multiplies by the unit rotations in ``freqs_cis``, and converts back to
    the original real layout and dtype.
    """
    # Real -> complex: [..., head_dim] becomes [..., head_dim // 2] complex.
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Align the frequency table with the complex views for broadcasting.
    rot = reshape_for_broadcast(freqs_cis, q_complex)

    # Complex multiplication performs the rotation; then unpack back to real
    # pairs and restore the callers' shapes and dtypes.
    q_rotated = torch.view_as_real(q_complex * rot).reshape(*xq.shape)
    k_rotated = torch.view_as_real(k_complex * rot).reshape(*xk.shape)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class Attention(nn.Module):
    """Multi-head causal attention backed by FlashAttention-2.

    Projects the input with a fused QKV linear layer, optionally applies
    RoPE (only when ALiBi slopes are disabled), and runs flash attention
    with sliding-window causal masking and optional logit softcapping.
    """

    def __init__(self, config):
        super(Attention, self).__init__()
        self.dim, self.num_heads = config.dim, config.num_heads
        assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible num_heads ({self.num_heads})"
        self.head_dim = config.dim // config.num_heads

        # Fused projection producing Q, K, V in a single matmul.
        self.c_attn = nn.Linear(self.dim, 3*self.dim, bias=config.bias)
        self.c_proj = nn.Linear(config.dim, config.dim, bias=config.bias)
        # Marker attribute, presumably consumed by the model's weight-init
        # code to rescale this projection — confirm against the init logic.
        self.c_proj.SCALE_INIT = 1

        # ALiBi and RoPE are mutually exclusive: slopes present => ALiBi.
        self.alibi_slopes = self._get_alibi_slopes(self.num_heads) if config.use_alibi else None
        self.window_size = config.window_size
        self.softcap = config.softcap

        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def _generate_slopes(self, n: int):
        # Geometric ALiBi slope sequence: start * start**i for i in [0, n).
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, num_heads: int, interpolation_factor: float = 0.25):
        """Build one ALiBi slope per head (https://arxiv.org/pdf/2108.12409).

        NOTE(review): the slopes tensor is created on a hard-coded CUDA
        device, so constructing this module with use_alibi=True on a
        CPU-only machine will fail — confirm this is intended.
        """
        # If n_heads is a power of 2, generate slopes directly
        if math.log2(num_heads).is_integer():
            slopes = self._generate_slopes(num_heads)
        else:
            # Get slopes for the nearest power of two
            n = nearest_power_of_two(num_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)

            # Generate extra slopes
            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: num_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        slopes = torch.tensor(slopes, device=torch.device("cuda"))
        slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
        return slopes

    def forward(
        self,
        x: torch.Tensor = None,
        q: torch.Tensor = None,
        k: torch.Tensor = None,
        v: torch.Tensor = None,
        freqs_cis: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run attention over [bsz, seq_len, dim] inputs.

        NOTE(review): despite the q/k/v parameters, the projection below is
        qkv = self.c_attn(x); in the cross-attention path (x=None with q/k/v
        supplied) this passes None into the linear layer and will crash —
        confirm whether cross-attention is actually supported.
        """
        if x is not None:
            q = k = v = x
        if any(t is None for t in [q, k, v]):
            raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")

        bsz, q_len, dim = q.shape
        _, k_len, _ = k.shape
        _, v_len, _ = v.shape

        qkv = self.c_attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        # flash_attn_func expects the [bsz, seq_len, num_heads, head_dim] layout.
        q = q.view(bsz, q_len, self.num_heads, self.head_dim)
        k = k.view(bsz, k_len, self.num_heads, self.head_dim)
        v = v.view(bsz, v_len, self.num_heads, self.head_dim)

        if self.alibi_slopes is None:  # Use either ALiBi or RoPE
            # NOTE(review): reshape_for_broadcast reads x.shape[2], which in
            # this [B, S, H, D] layout is num_heads rather than seq_len —
            # verify the RoPE rotation is applied along the intended axis.
            q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        y = flash_attn_func(  # https://arxiv.org/pdf/2307.08691
            q=q, k=k, v=v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True,
            window_size=(self.window_size, 0),  # Set to config.seq_len if full attention
            alibi_slopes=self.alibi_slopes,  # https://arxiv.org/pdf/2108.12409
            softcap=self.softcap,  # https://arxiv.org/pdf/2408.00118
        )

        # Merge heads back into [bsz, seq_len, num_heads * head_dim].
        y = y.contiguous().view(bsz, q_len, -1)
        y = self.resid_dropout(self.c_proj(y))
        return y
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class MLP(nn.Module):
    """Gated feed-forward block (GLU variant, https://arxiv.org/pdf/2002.05202).

    Computes down_proj(gelu(gate_proj(x)) * up_proj(x)) followed by dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.dim
        self.intermediate_size = config.dim * config.mlp_scale
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Gate branch goes through a tanh-approximated GELU before the
        # elementwise product with the linear "up" branch.
        gated = F.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class AttentionLayer(nn.Module):
    """Pre-norm transformer block: attention then gated MLP, each residual."""

    def __init__(self, config) -> None:
        super(AttentionLayer, self).__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = Attention(config=config)
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor = None) -> torch.Tensor:
        """Apply the two sub-blocks, normalizing before each and adding the
        residual stream back after each."""
        hidden = x + self.attn(x=self.attn_norm(x), freqs_cis=freqs_cis)
        hidden = hidden + self.mlp(self.mlp_norm(hidden))
        return hidden
|
causal_conv1d_compilable.py
CHANGED
|
@@ -211,5 +211,4 @@ if __name__ == "__main__":
|
|
| 211 |
|
| 212 |
print(out.min(), out.max(), out.mean(), out.std())
|
| 213 |
print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
|
| 214 |
-
print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())
|
| 215 |
-
|
|
|
|
| 211 |
|
| 212 |
print(out.min(), out.max(), out.mean(), out.std())
|
| 213 |
print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
|
| 214 |
+
print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())
|
|
|
model.py
CHANGED
|
@@ -1,3 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import math
|
| 2 |
|
| 3 |
import torch
|
|
@@ -6,13 +15,19 @@ import torch.nn.functional as F
|
|
| 6 |
|
| 7 |
from enum import Enum
|
| 8 |
from dataclasses import dataclass, field
|
|
|
|
| 9 |
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 11 |
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
|
| 12 |
|
| 13 |
-
from .causal_conv1d_compilable import causal_conv1d_fn, causal_conv1d_update
|
| 14 |
from .ssm_compilable import mamba_chunk_scan_combined
|
| 15 |
from .norms import build_norm
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class InitStdFactor(Enum):
|
|
@@ -154,9 +169,7 @@ class SSM(nn.Module):
|
|
| 154 |
if self.learnable_init_states:
|
| 155 |
self.init_states = nn.Parameter(torch.zeros(self.num_heads, self.head_dim, self.state_dim))
|
| 156 |
|
| 157 |
-
# Can also just use nn.RMSNorm
|
| 158 |
self.norm = build_norm(config.norm_type, dim=self.hidden_dim, eps=self.norm_eps)
|
| 159 |
-
|
| 160 |
self.output = nn.Linear(self.hidden_dim, self.dim, bias=self.bias)
|
| 161 |
|
| 162 |
def _causal_conv(
|
|
@@ -320,7 +333,7 @@ class SSM(nn.Module):
|
|
| 320 |
),
|
| 321 |
dt_softplus=True,
|
| 322 |
).unsqueeze(0)
|
| 323 |
-
|
| 324 |
return y
|
| 325 |
|
| 326 |
def forward(
|
|
@@ -502,8 +515,16 @@ class BaseMamba(nn.Module):
|
|
| 502 |
self.init_std_factor = InitStdFactor(config.init_std_factor)
|
| 503 |
|
| 504 |
self.layers = nn.ModuleList()
|
| 505 |
-
for
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
def forward(
|
| 509 |
self,
|
|
@@ -536,6 +557,7 @@ class BaseMamba(nn.Module):
|
|
| 536 |
|
| 537 |
@dataclass
|
| 538 |
class Mamba2Config(BaseMambaConfig):
|
|
|
|
| 539 |
seed: int = 1337
|
| 540 |
|
| 541 |
vocab_size: int = -1 # Will error if unchanged, makes you double check!
|
|
@@ -573,10 +595,10 @@ class Mamba2(BaseMamba):
|
|
| 573 |
|
| 574 |
def _get_num_params(self):
|
| 575 |
n_params = sum(p.numel() for p in self.parameters())
|
|
|
|
| 576 |
if hasattr(self, "pos_emb") and self.pos_emb is not None:
|
| 577 |
n_params -= self.pos_emb.weight.numel()
|
| 578 |
-
|
| 579 |
-
n_params -= self.tok_emb.weight.numel()
|
| 580 |
return n_params
|
| 581 |
|
| 582 |
def forward(
|
|
@@ -657,192 +679,108 @@ class Mamba2(BaseMamba):
|
|
| 657 |
return cls(config)
|
| 658 |
|
| 659 |
|
| 660 |
-
def
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
""
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
return int(flops_forward * forward_backward_multiplier)
|
| 746 |
-
|
| 747 |
-
def get_mamba2_flops_per_token(
|
| 748 |
-
**kwargs
|
| 749 |
-
) -> float:
|
| 750 |
-
"""
|
| 751 |
-
Estimate FLOPs per token for a Mamba-2 style model.
|
| 752 |
-
|
| 753 |
-
This function extracts necessary parameters from kwargs and calculates the FLOPs per token.
|
| 754 |
-
|
| 755 |
-
Args:
|
| 756 |
-
**kwargs: Dictionary containing model configuration parameters.
|
| 757 |
-
|
| 758 |
-
Returns:
|
| 759 |
-
float: Approximate FLOPs per token.
|
| 760 |
-
"""
|
| 761 |
-
defaults = {
|
| 762 |
-
'ffn_dim_multiplier': 2.0,
|
| 763 |
-
'state_dim': 128,
|
| 764 |
-
'conv_size': 4,
|
| 765 |
-
'num_heads': 8,
|
| 766 |
-
'num_groups': 1,
|
| 767 |
-
'multiple_of': 256,
|
| 768 |
-
'include_input_embedding': True,
|
| 769 |
-
'include_output_logits': True,
|
| 770 |
-
'forward_backward_multiplier': 1.0,
|
| 771 |
-
}
|
| 772 |
-
# Merge defaults
|
| 773 |
-
for k, v in defaults.items():
|
| 774 |
-
kwargs.setdefault(k, v)
|
| 775 |
-
# Mandatory keys
|
| 776 |
-
for required in ['seq_len', 'dim', 'num_layers', 'vocab_size']:
|
| 777 |
-
if required not in kwargs:
|
| 778 |
-
raise ValueError(f"Missing required parameter: {required}")
|
| 779 |
-
|
| 780 |
-
total_flops = get_mamba2_flops(
|
| 781 |
-
seq_len=kwargs['seq_len'],
|
| 782 |
-
dim=kwargs['dim'],
|
| 783 |
-
num_layers=kwargs['num_layers'],
|
| 784 |
-
vocab_size=kwargs['vocab_size'],
|
| 785 |
-
ffn_multiplier=kwargs['ffn_dim_multiplier'],
|
| 786 |
-
state_dim=kwargs['state_dim'],
|
| 787 |
-
conv_size=kwargs['conv_size'],
|
| 788 |
-
num_heads=kwargs['num_heads'],
|
| 789 |
-
num_groups=kwargs['num_groups'],
|
| 790 |
-
multiple_of=kwargs['multiple_of'],
|
| 791 |
-
include_input_embedding=kwargs['include_input_embedding'],
|
| 792 |
-
include_output_logits=kwargs['include_output_logits'],
|
| 793 |
-
forward_backward_multiplier=kwargs['forward_backward_multiplier'],
|
| 794 |
)
|
| 795 |
-
flops_per_token = total_flops / kwargs['seq_len']
|
| 796 |
-
|
| 797 |
-
return flops_per_token
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
# Optional policy for activation checkpointing. With None, we stick to the default (defined distributed.py: default_no_recompute_ops)
|
| 801 |
-
def get_no_recompute_ops():
|
| 802 |
-
return {
|
| 803 |
-
torch.ops.aten.mm.default,
|
| 804 |
-
torch.ops.aten._scaled_mm.default,
|
| 805 |
-
torch.ops.c10d_functional.reduce_scatter_tensor.default,
|
| 806 |
-
torch.ops.mamba_ssm.ssm_chunk_scan_combined_fwd.default,
|
| 807 |
-
|
| 808 |
-
# For low-precision training, it's useful to always save the result of max(abs(tensor))
|
| 809 |
-
torch.ops.aten.abs.default,
|
| 810 |
-
torch.ops.aten.max.default,
|
| 811 |
-
}
|
| 812 |
-
|
| 813 |
-
|
| 814 |
-
def main():
|
| 815 |
-
from mamba_ssm import Mamba2 as MambaRef
|
| 816 |
-
|
| 817 |
-
x = torch.randn(2, 64, 192).cuda()
|
| 818 |
-
|
| 819 |
-
# Create and run the first model
|
| 820 |
-
model = MambaRef(
|
| 821 |
-
d_model=192,
|
| 822 |
-
expand=2,
|
| 823 |
-
d_conv=4,
|
| 824 |
-
d_state=64,
|
| 825 |
-
headdim=48,
|
| 826 |
-
).cuda()
|
| 827 |
-
y = model(x)
|
| 828 |
-
print("Mamba reference output: ", y)
|
| 829 |
-
print("Mean of MambaRef output: ", y.mean().item())
|
| 830 |
-
print("Stddev of MambaRef output: ", y.std().item())
|
| 831 |
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
config=config,
|
| 836 |
-
).cuda()
|
| 837 |
|
| 838 |
-
|
| 839 |
-
x_indices = torch.randint(0, config.vocab_size, (2, 64), dtype=torch.long).cuda()
|
| 840 |
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
|
|
|
| 845 |
|
| 846 |
-
|
| 847 |
-
main()
|
| 848 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
|
| 3 |
+
Adapted from Meta's Lingua repository:
|
| 4 |
+
- https://github.com/facebookresearch/lingua/blob/main/apps/mamba/core_mamba.py
|
| 5 |
+
- https://github.com/facebookresearch/lingua/blob/main/apps/mamba/mamba.py
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
import math
|
| 11 |
|
| 12 |
import torch
|
|
|
|
| 15 |
|
| 16 |
from enum import Enum
|
| 17 |
from dataclasses import dataclass, field
|
| 18 |
+
|
| 19 |
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
|
| 20 |
+
from .causal_conv1d_compilable import (
|
| 21 |
+
causal_conv1d_fn as causal_conv1d_fn,
|
| 22 |
+
causal_conv1d_update as causal_conv1d_update
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 26 |
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
|
| 27 |
|
|
|
|
| 28 |
from .ssm_compilable import mamba_chunk_scan_combined
|
| 29 |
from .norms import build_norm
|
| 30 |
+
from .attn import AttentionLayer
|
| 31 |
|
| 32 |
|
| 33 |
class InitStdFactor(Enum):
|
|
|
|
| 169 |
if self.learnable_init_states:
|
| 170 |
self.init_states = nn.Parameter(torch.zeros(self.num_heads, self.head_dim, self.state_dim))
|
| 171 |
|
|
|
|
| 172 |
self.norm = build_norm(config.norm_type, dim=self.hidden_dim, eps=self.norm_eps)
|
|
|
|
| 173 |
self.output = nn.Linear(self.hidden_dim, self.dim, bias=self.bias)
|
| 174 |
|
| 175 |
def _causal_conv(
|
|
|
|
| 333 |
),
|
| 334 |
dt_softplus=True,
|
| 335 |
).unsqueeze(0)
|
| 336 |
+
|
| 337 |
return y
|
| 338 |
|
| 339 |
def forward(
|
|
|
|
| 515 |
self.init_std_factor = InitStdFactor(config.init_std_factor)
|
| 516 |
|
| 517 |
self.layers = nn.ModuleList()
|
| 518 |
+
for layer_idx in range(config.num_layers):
|
| 519 |
+
# For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
|
| 520 |
+
if layer_idx % 2 == 0:
|
| 521 |
+
self.layers.append(MambaBlock(config))
|
| 522 |
+
else:
|
| 523 |
+
self.layers.append(
|
| 524 |
+
AttentionLayer(config)
|
| 525 |
+
if config.use_attn
|
| 526 |
+
else (MambaBlock(config))
|
| 527 |
+
)
|
| 528 |
|
| 529 |
def forward(
|
| 530 |
self,
|
|
|
|
| 557 |
|
| 558 |
@dataclass
|
| 559 |
class Mamba2Config(BaseMambaConfig):
|
| 560 |
+
bsz: int = 2
|
| 561 |
seed: int = 1337
|
| 562 |
|
| 563 |
vocab_size: int = -1 # Will error if unchanged, makes you double check!
|
|
|
|
| 595 |
|
| 596 |
def _get_num_params(self):
|
| 597 |
n_params = sum(p.numel() for p in self.parameters())
|
| 598 |
+
|
| 599 |
if hasattr(self, "pos_emb") and self.pos_emb is not None:
|
| 600 |
n_params -= self.pos_emb.weight.numel()
|
| 601 |
+
|
|
|
|
| 602 |
return n_params
|
| 603 |
|
| 604 |
def forward(
|
|
|
|
| 679 |
return cls(config)
|
| 680 |
|
| 681 |
|
| 682 |
+
# def main():
|
| 683 |
+
# x = torch.randn(2, 64, 192).cuda()
|
| 684 |
+
|
| 685 |
+
# config = Mamba2Config(vocab_size=200064, use_mem_eff_path=True)
|
| 686 |
+
# model2 = Mamba2(
|
| 687 |
+
# config=config,
|
| 688 |
+
# ).cuda()
|
| 689 |
+
|
| 690 |
+
# x_indices = torch.randint(0, config.vocab_size, (2, 64), dtype=torch.long).cuda()
|
| 691 |
+
|
| 692 |
+
# y2 = model2(x_indices)
|
| 693 |
+
# print("Mamba output: ", y2)
|
| 694 |
+
# print("Mean of Mamba output: ", y2.mean().item())
|
| 695 |
+
# print("Stddev of Mamba output: ", y2.std().item())
|
| 696 |
+
|
| 697 |
+
# if __name__ == "__main__":
|
| 698 |
+
# main()
|
| 699 |
+
|
| 700 |
+
if __name__ == '__main__':
|
| 701 |
+
x = torch.randn(2, 64, 192).cuda() # Removing this produces NaNs lol
|
| 702 |
+
|
| 703 |
+
config_path = "/scratch/gpfs/mn4560/hazan-lab/tensorized_filters/tensorized_filters/models/mamba/config.json"
|
| 704 |
+
|
| 705 |
+
with open(config_path, "r") as f:
|
| 706 |
+
config_data = json.load(f)
|
| 707 |
+
|
| 708 |
+
if torch.cuda.is_available():
|
| 709 |
+
device = torch.device("cuda")
|
| 710 |
+
elif torch.backends.mps.is_available():
|
| 711 |
+
device = torch.device("mps")
|
| 712 |
+
else:
|
| 713 |
+
device = torch.device("cpu")
|
| 714 |
+
print("Device:", device)
|
| 715 |
+
|
| 716 |
+
torch_dtype = getattr(torch, config_data["torch_dtype"])
|
| 717 |
+
print("Torch dtype:", torch_dtype)
|
| 718 |
+
|
| 719 |
+
dim = config_data["dim"]
|
| 720 |
+
num_heads = config_data["num_heads"]
|
| 721 |
+
num_layers = config_data["num_layers"]
|
| 722 |
+
vocab_size = config_data["vocab_size"]
|
| 723 |
+
bias = config_data["bias"]
|
| 724 |
+
state_dim = config_data["state_dim"]
|
| 725 |
+
num_groups = config_data["num_groups"]
|
| 726 |
+
conv_size = config_data.get("conv_size")
|
| 727 |
+
use_mem_eff_path = config_data.get("use_mem_eff_path")
|
| 728 |
+
dt_bias = config_data["dt_bias"]
|
| 729 |
+
D_has_head_dim = config_data["D_has_head_dim"]
|
| 730 |
+
learnable_init_states = config_data["learnable_init_states"]
|
| 731 |
+
ssm_chunk_size = config_data["ssm_chunk_size"]
|
| 732 |
+
weight_tying = config_data["weight_tying"]
|
| 733 |
+
ffn_dim_multiplier = config_data.get("ffn_dim_multiplier")
|
| 734 |
+
multiple_of = config_data["multiple_of"]
|
| 735 |
+
norm_eps = config_data["norm_eps"]
|
| 736 |
+
init_use_depth = config_data["init_use_depth"]
|
| 737 |
+
init_base_std = config_data.get("init_base_std")
|
| 738 |
+
init_std_factor = config_data["init_std_factor"]
|
| 739 |
+
use_attn = config_data["use_attn"]
|
| 740 |
+
softcap = config_data["softcap"]
|
| 741 |
+
torch_compile = config_data["torch_compile"]
|
| 742 |
+
|
| 743 |
+
configs = Mamba2Config(
|
| 744 |
+
dim=dim,
|
| 745 |
+
num_layers=num_layers,
|
| 746 |
+
num_heads=num_heads,
|
| 747 |
+
vocab_size=vocab_size,
|
| 748 |
+
bias=bias,
|
| 749 |
+
torch_dtype=torch_dtype,
|
| 750 |
+
state_dim=state_dim,
|
| 751 |
+
num_groups=num_groups,
|
| 752 |
+
conv_size=conv_size,
|
| 753 |
+
use_mem_eff_path=use_mem_eff_path,
|
| 754 |
+
dt_bias=dt_bias,
|
| 755 |
+
D_has_head_dim=D_has_head_dim,
|
| 756 |
+
learnable_init_states=learnable_init_states,
|
| 757 |
+
ssm_chunk_size=ssm_chunk_size,
|
| 758 |
+
weight_tying=weight_tying,
|
| 759 |
+
ffn_dim_multiplier=ffn_dim_multiplier,
|
| 760 |
+
multiple_of=multiple_of,
|
| 761 |
+
norm_eps=norm_eps,
|
| 762 |
+
init_use_depth=init_use_depth,
|
| 763 |
+
init_base_std=init_base_std,
|
| 764 |
+
init_std_factor=init_std_factor,
|
| 765 |
+
use_attn=use_attn,
|
| 766 |
+
softcap=softcap,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 767 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 768 |
|
| 769 |
+
print("Configs:")
|
| 770 |
+
for key, value in vars(configs).items():
|
| 771 |
+
print(f" {key}: {value}")
|
|
|
|
|
|
|
| 772 |
|
| 773 |
+
model = Mamba2(configs).to(device=device)
|
|
|
|
| 774 |
|
| 775 |
+
x = torch.randint(
|
| 776 |
+
0, configs.vocab_size,
|
| 777 |
+
(config_data["bsz"], config_data["seq_len"]),
|
| 778 |
+
dtype=torch.long
|
| 779 |
+
).to(device)
|
| 780 |
|
| 781 |
+
outputs = model(x)
|
|
|
|
| 782 |
|
| 783 |
+
print("Output shape:", outputs.shape)
|
| 784 |
+
print("Sample output:", outputs[0, 0, :10])
|
| 785 |
+
print("Mean of Mamba output: ", outputs.mean().item())
|
| 786 |
+
print("Stddev of Mamba output: ", outputs.std().item())
|
norms.py
CHANGED
|
@@ -354,4 +354,4 @@ def fused_rms_norm_fn(
|
|
| 354 |
x,
|
| 355 |
weight,
|
| 356 |
eps,
|
| 357 |
-
)
|
|
|
|
| 354 |
x,
|
| 355 |
weight,
|
| 356 |
eps,
|
| 357 |
+
)
|
ssm_compilable.py
CHANGED
|
@@ -218,5 +218,4 @@ if __name__ == "__main__":
|
|
| 218 |
|
| 219 |
out_ref = mamba_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)
|
| 220 |
|
| 221 |
-
print(out_ref.min(), out_ref.max(), out_ref.mean(), out_ref.std())
|
| 222 |
-
|
|
|
|
| 218 |
|
| 219 |
out_ref = mamba_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)
|
| 220 |
|
| 221 |
+
print(out_ref.min(), out_ref.max(), out_ref.mean(), out_ref.std())
|
|
|