yagizdevre commited on
Commit
4d7d25c
·
1 Parent(s): 13bfa0f
Files changed (5) hide show
  1. attn.py +206 -0
  2. causal_conv1d_compilable.py +1 -2
  3. model.py +129 -191
  4. norms.py +1 -1
  5. ssm_compilable.py +1 -2
attn.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ try:
9
+ from flash_attn import flash_attn_func
10
+ except ImportError as e:
11
+ print(
12
+ f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
13
+ )
14
+
15
+
16
def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    """Return the power of two nearest to ``x``.

    Args:
        x: Positive integer to round.
        round_up: If True, round up to the next power of two;
            otherwise round down (default). Exact powers of two
            are returned unchanged either way.
    """
    exponent = math.ceil(math.log2(x)) if round_up else math.floor(math.log2(x))
    return 1 << exponent
20
+
21
+ def _generate_slopes(self, n: int):
22
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
23
+ return [start * (start**i) for i in range(n)]
24
+
25
def _get_alibi_slopes(n_heads: int, interpolation_factor: float = 0.25, device=None):
    """Compute ALiBi slopes for ``n_heads`` heads, scaled for interpolation.

    Fixes over the original: the stray ``self`` parameter is removed
    (this is a module-level function), the sibling module-level
    ``_generate_slopes`` is called instead of ``self._generate_slopes``,
    and the device is an explicit parameter instead of ``self.device``.

    Args:
        n_heads: Number of attention heads.
        interpolation_factor: Slope scale factor, see
            https://arxiv.org/pdf/2310.13017
        device: Optional torch device for the returned tensor
            (default: current default device, i.e. CPU).

    Returns:
        1-D float tensor of ``n_heads`` slopes.
    """
    if math.log2(n_heads).is_integer():
        # Power of two: slopes come straight from the geometric recipe.
        slopes = _generate_slopes(n_heads)
    else:
        # Otherwise: take the slopes for the nearest lower power of two...
        n = nearest_power_of_two(n_heads, round_up=False)
        slopes_power_of_two = _generate_slopes(n)

        # ...and interleave extra slopes drawn from the 2n sequence.
        extra_slopes = _generate_slopes(2 * n)
        extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
        slopes = slopes_power_of_two + extra_slopes_trunc
    slopes = torch.tensor(slopes, device=device)
    slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
    return slopes
41
+
42
+
43
def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    """Precompute rotary-embedding complex exponentials e^{i * t * freq}.

    Args:
        head_dim: Per-head dimension (frequencies cover half of it).
        max_seq_len: Number of positions to precompute.
        theta: RoPE base frequency.

    Returns:
        Complex tensor of shape [max_seq_len, head_dim // 2].
    """
    # Inverse frequency for each even dimension index (half the head dim).
    freq_seq = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta**freq_seq)

    # One rotation angle per (position, frequency) pair.
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)

    # Unit-magnitude complex numbers carrying the rotation angles.
    return torch.polar(torch.ones_like(angles), angles)
58
+
59
+
60
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Slice ``freqs_cis`` to x's sequence length and shape it for broadcast.

    ``x`` is [B, n_heads, seq_len, head_dim_as_complex]; the result takes
    freqs_cis from [max_seq_len, half_dim] to [1, 1, seq_len, half_dim]
    so it broadcasts over the batch and head dimensions.
    """
    current_len = x.shape[2]
    # Keep only the first current_len positions, then add singleton dims.
    return freqs_cis[:current_len].view(1, 1, current_len, -1)
69
+
70
+
71
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key tensors by the precomputed complex exponentials.

    Pairs of real values in the last dimension are treated as complex
    numbers ([B, n_heads, seq_len, head_dim] -> [B, n_heads, seq_len,
    head_dim//2] complex), multiplied by ``freqs_cis`` (a complex multiply
    is a 2-D rotation of each pair), then converted back. Outputs keep
    the input shape and dtype.
    """
    # Real -> complex: group the last dim into (head_dim // 2) pairs.
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # [max_seq_len, half] -> [1, 1, seq_len, half] for broadcasting.
    rotations = reshape_for_broadcast(freqs_cis, q_complex)

    # Apply the rotation and convert back to the real representation.
    q_rotated = torch.view_as_real(q_complex * rotations).reshape(*xq.shape)
    k_rotated = torch.view_as_real(k_complex * rotations).reshape(*xk.shape)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
92
+
93
+
94
class Attention(nn.Module):
    """Multi-head attention backed by FlashAttention-2.

    Positional information comes from either ALiBi slopes
    (``config.use_alibi``) or rotary embeddings (``freqs_cis`` passed to
    ``forward``) — never both. Q, K, V share one fused projection.
    """

    def __init__(self, config):
        super(Attention, self).__init__()
        self.dim, self.num_heads = config.dim, config.num_heads
        assert config.dim % config.num_heads == 0, (
            f"dim ({self.dim}) must be divisible by num_heads ({self.num_heads})"
        )
        self.head_dim = config.dim // config.num_heads

        # Fused QKV projection and output projection.
        self.c_attn = nn.Linear(self.dim, 3 * self.dim, bias=config.bias)
        self.c_proj = nn.Linear(config.dim, config.dim, bias=config.bias)
        self.c_proj.SCALE_INIT = 1

        self.alibi_slopes = self._get_alibi_slopes(self.num_heads) if config.use_alibi else None
        self.window_size = config.window_size
        self.softcap = config.softcap

        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def _generate_slopes(self, n: int):
        """Geometric sequence of ALiBi slopes for ``n`` heads."""
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, num_heads: int, interpolation_factor: float = 0.25):
        """ALiBi slopes for ``num_heads`` heads, scaled for interpolation."""
        # If num_heads is a power of 2, generate slopes directly.
        if math.log2(num_heads).is_integer():
            slopes = self._generate_slopes(num_heads)
        else:
            # Get slopes for the nearest lower power of two...
            n = nearest_power_of_two(num_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)

            # ...and interleave extra slopes drawn from the 2n sequence.
            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: num_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        # NOTE(review): hard-coded CUDA device breaks CPU-only construction;
        # flash-attn itself requires CUDA, but consider deferring via .to().
        slopes = torch.tensor(slopes, device=torch.device("cuda"))
        slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
        return slopes

    def forward(
        self,
        x: torch.Tensor = None,
        q: torch.Tensor = None,
        k: torch.Tensor = None,
        v: torch.Tensor = None,
        freqs_cis: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run attention. Pass ``x`` for self-attention, or ``q``/``k``/``v``
        (each [B, seq, dim]) for cross-attention. Returns [B, q_len, dim]."""
        if x is not None:
            q = k = v = x
        if any(t is None for t in [q, k, v]):
            raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")

        bsz, q_len, dim = q.shape
        _, k_len, _ = k.shape
        _, v_len, _ = v.shape

        if x is not None:
            # Self-attention: single fused projection.
            qkv = self.c_attn(x)
            q, k, v = torch.chunk(qkv, 3, dim=2)
        else:
            # Fix: the original unconditionally called self.c_attn(x), which
            # crashed with x=None on the advertised cross-attention path.
            # Apply the matching third of the fused projection to each input.
            w_q, w_k, w_v = self.c_attn.weight.chunk(3, dim=0)
            b_q = b_k = b_v = None
            if self.c_attn.bias is not None:
                b_q, b_k, b_v = self.c_attn.bias.chunk(3, dim=0)
            q = F.linear(q, w_q, b_q)
            k = F.linear(k, w_k, b_k)
            v = F.linear(v, w_v, b_v)

        # flash_attn_func expects [B, seq_len, num_heads, head_dim].
        q = q.view(bsz, q_len, self.num_heads, self.head_dim)
        k = k.view(bsz, k_len, self.num_heads, self.head_dim)
        v = v.view(bsz, v_len, self.num_heads, self.head_dim)

        if self.alibi_slopes is None:  # Use either ALiBi or RoPE
            # Fix: apply_rotary_emb/reshape_for_broadcast read the sequence
            # length from dim 2, i.e. they expect [B, num_heads, seq_len,
            # head_dim]. The original passed [B, seq_len, num_heads,
            # head_dim], rotating along the heads axis instead of positions.
            q, k = apply_rotary_emb(q.transpose(1, 2), k.transpose(1, 2), freqs_cis=freqs_cis)
            q, k = q.transpose(1, 2), k.transpose(1, 2)

        y = flash_attn_func(  # https://arxiv.org/pdf/2307.08691
            q=q, k=k, v=v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True,
            window_size=(self.window_size, 0),  # Set to config.seq_len if full attention
            alibi_slopes=self.alibi_slopes,  # https://arxiv.org/pdf/2108.12409
            softcap=self.softcap,  # https://arxiv.org/pdf/2408.00118
        )

        # Merge heads back into the model dimension, then project out.
        y = y.contiguous().view(bsz, q_len, -1)
        y = self.resid_dropout(self.c_proj(y))
        return y
172
+
173
+
174
class MLP(nn.Module):
    """Gated feed-forward block (GELU-gated GLU variant).

    https://arxiv.org/pdf/2002.05202
    """

    def __init__(self, config):
        super().__init__()
        # Width expansion controlled by config.mlp_scale.
        self.hidden_size = config.dim
        self.intermediate_size = config.dim * config.mlp_scale
        # Two parallel up projections (gate + value) and one down projection.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply gelu(gate(x)) * up(x), project back down, then dropout."""
        gated = F.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
193
+
194
+
195
class AttentionLayer(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention -> residual,
    then RMSNorm -> gated MLP -> residual."""

    def __init__(self, config) -> None:
        super(AttentionLayer, self).__init__()
        # Separate norms for the attention and MLP sub-blocks (pre-norm).
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = Attention(config=config)
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor=None) -> torch.Tensor:
        # Residual connections wrap each pre-normalized sub-block.
        # freqs_cis is forwarded for RoPE; unused when attention runs ALiBi.
        x = x + self.attn(x=self.attn_norm(x), freqs_cis=freqs_cis)
        x = x + self.mlp(self.mlp_norm(x))
        return x
causal_conv1d_compilable.py CHANGED
@@ -211,5 +211,4 @@ if __name__ == "__main__":
211
 
212
  print(out.min(), out.max(), out.mean(), out.std())
213
  print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
214
- print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())
215
-
 
211
 
212
  print(out.min(), out.max(), out.mean(), out.std())
213
  print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
214
+ print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())
 
model.py CHANGED
@@ -1,3 +1,12 @@
 
 
 
 
 
 
 
 
 
1
  import math
2
 
3
  import torch
@@ -6,13 +15,19 @@ import torch.nn.functional as F
6
 
7
  from enum import Enum
8
  from dataclasses import dataclass, field
 
9
  from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
 
 
 
 
 
10
  from mamba_ssm.ops.triton.selective_state_update import selective_state_update
11
  from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
12
 
13
- from .causal_conv1d_compilable import causal_conv1d_fn, causal_conv1d_update
14
  from .ssm_compilable import mamba_chunk_scan_combined
15
  from .norms import build_norm
 
16
 
17
 
18
  class InitStdFactor(Enum):
@@ -154,9 +169,7 @@ class SSM(nn.Module):
154
  if self.learnable_init_states:
155
  self.init_states = nn.Parameter(torch.zeros(self.num_heads, self.head_dim, self.state_dim))
156
 
157
- # Can also just use nn.RMSNorm
158
  self.norm = build_norm(config.norm_type, dim=self.hidden_dim, eps=self.norm_eps)
159
-
160
  self.output = nn.Linear(self.hidden_dim, self.dim, bias=self.bias)
161
 
162
  def _causal_conv(
@@ -320,7 +333,7 @@ class SSM(nn.Module):
320
  ),
321
  dt_softplus=True,
322
  ).unsqueeze(0)
323
-
324
  return y
325
 
326
  def forward(
@@ -502,8 +515,16 @@ class BaseMamba(nn.Module):
502
  self.init_std_factor = InitStdFactor(config.init_std_factor)
503
 
504
  self.layers = nn.ModuleList()
505
- for _ in range(config.num_layers):
506
- self.layers.append(MambaBlock(config))
 
 
 
 
 
 
 
 
507
 
508
  def forward(
509
  self,
@@ -536,6 +557,7 @@ class BaseMamba(nn.Module):
536
 
537
  @dataclass
538
  class Mamba2Config(BaseMambaConfig):
 
539
  seed: int = 1337
540
 
541
  vocab_size: int = -1 # Will error if unchanged, makes you double check!
@@ -573,10 +595,10 @@ class Mamba2(BaseMamba):
573
 
574
  def _get_num_params(self):
575
  n_params = sum(p.numel() for p in self.parameters())
 
576
  if hasattr(self, "pos_emb") and self.pos_emb is not None:
577
  n_params -= self.pos_emb.weight.numel()
578
- if self.tok_emb.weight is not self.output.weight:
579
- n_params -= self.tok_emb.weight.numel()
580
  return n_params
581
 
582
  def forward(
@@ -657,192 +679,108 @@ class Mamba2(BaseMamba):
657
  return cls(config)
658
 
659
 
660
- def get_mamba2_flops(
661
- seq_len: int,
662
- dim: int,
663
- num_layers: int,
664
- vocab_size: int,
665
- ffn_multiplier: float = 2.0,
666
- state_dim: int = 128,
667
- conv_size: int = 4,
668
- num_heads: int = 8,
669
- num_groups: int = 1,
670
- multiple_of: int = 256,
671
- include_input_embedding: bool = True,
672
- include_output_logits: bool = True,
673
- forward_backward_multiplier: float = 1.0,
674
- ) -> int:
675
- """
676
- Estimate the FLOPs for a Mamba-2 style model using a "Chinchilla-like" shape-based approach.
677
-
678
- By default, this returns the forward-pass cost. If you want a rough
679
- forward+backward estimate, set `forward_backward_multiplier=3.0` (common
680
- rule-of-thumb for these models).
681
-
682
- What gets counted:
683
- Hidden dimension is rounded up to 'multiple_of' = 256 (as in Mamba).
684
- Per-layer:
685
- 1) Input Linear: [dim → 2*hidden_dim + 2*(groups*state_dim) + num_heads]
686
- 2) Depthwise Conv1D: 2*(conv_dim * conv_size), where conv_dim=hidden_dim + 2*groups*state_dim
687
- 3) SSM selective scan: ~9*(dim*state_dim) (from Mamba dev discussion)
688
- 4) Output Linear: [hidden_dim → dim]
689
- Each layer’s cost is multiplied by (seq_len * num_layers).
690
- • Optionally adds:
691
- - The cost of the input embedding (treating it as a matmul: seq_len×vocab_size × vocab_size×dim).
692
- - The cost of the final projection [dim → vocab_size].
693
- • Finally scaled by `forward_backward_multiplier` if desired.
694
-
695
- Args:
696
- seq_len (int): Sequence length (number of tokens).
697
- dim (int): Model (embedding) dimension.
698
- num_layers (int): Number of Mamba layers.
699
- vocab_size (int): Vocabulary size for final logits projection.
700
- ffn_multiplier (float): FFN expansion ratio, e.g. 2.0 => hidden_dim=2×dim (rounded up).
701
- state_dim (int): SSM state dimension (commonly 128).
702
- conv_size (int): Kernel size for the depthwise conv1d (default=4).
703
- num_heads (int): Number of heads (slightly affects input-lin out_dim).
704
- num_groups (int): For "grouped" states in some Mamba variants (usually 1).
705
- multiple_of (int): Round hidden_dim up to this multiple (commonly 256).
706
- include_input_embedding (bool): If True, count the cost of an “embedding matmul”
707
- for the input tokens => shape-based approach.
708
- include_output_logits (bool): If True, count the cost of final [dim → vocab_size].
709
- forward_backward_multiplier (float): E.g. 1.0 for forward only, 2.0 or 3.0 for forward+backward.
710
-
711
- Returns:
712
- int: Approximate total FLOPs (multiply-adds) for the selected pass(es),
713
- as an integer.
714
- """
715
- # 0) Input embedding (optional)
716
- flops_embedding = 0
717
- if include_input_embedding:
718
- flops_embedding = 2 * (seq_len * vocab_size * dim)
719
-
720
- # 1) Round up hidden_dim
721
- raw_hidden_dim = int(ffn_multiplier * dim)
722
- hidden_dim = multiple_of * ((raw_hidden_dim + multiple_of - 1) // multiple_of)
723
-
724
- # 2) Per-layer forward cost
725
- out_dim_input = 2*hidden_dim + 2*(num_groups*state_dim) + num_heads
726
- flops_input_linear = 2 * (dim * out_dim_input)
727
- conv_dim = hidden_dim + 2*(num_groups*state_dim)
728
- flops_conv = 2 * (conv_dim * conv_size)
729
- flops_ssm = 9 * state_dim * dim
730
- flops_output_linear = 2 * (hidden_dim * dim)
731
- flops_layer = (flops_input_linear + flops_conv + flops_ssm + flops_output_linear)
732
-
733
- # Multiply by #layers and sequence length
734
- flops_layers = flops_layer * num_layers * seq_len
735
-
736
- # 3) Final projection [dim → vocab_size] (optional)
737
- flops_vocab = 0
738
- if include_output_logits:
739
- flops_vocab = 2 * (seq_len * dim * vocab_size)
740
-
741
- # 4) Total forward FLOPs
742
- flops_forward = flops_embedding + flops_layers + flops_vocab
743
-
744
- # 5) Scale for forward+backward if desired
745
- return int(flops_forward * forward_backward_multiplier)
746
-
747
- def get_mamba2_flops_per_token(
748
- **kwargs
749
- ) -> float:
750
- """
751
- Estimate FLOPs per token for a Mamba-2 style model.
752
-
753
- This function extracts necessary parameters from kwargs and calculates the FLOPs per token.
754
-
755
- Args:
756
- **kwargs: Dictionary containing model configuration parameters.
757
-
758
- Returns:
759
- float: Approximate FLOPs per token.
760
- """
761
- defaults = {
762
- 'ffn_dim_multiplier': 2.0,
763
- 'state_dim': 128,
764
- 'conv_size': 4,
765
- 'num_heads': 8,
766
- 'num_groups': 1,
767
- 'multiple_of': 256,
768
- 'include_input_embedding': True,
769
- 'include_output_logits': True,
770
- 'forward_backward_multiplier': 1.0,
771
- }
772
- # Merge defaults
773
- for k, v in defaults.items():
774
- kwargs.setdefault(k, v)
775
- # Mandatory keys
776
- for required in ['seq_len', 'dim', 'num_layers', 'vocab_size']:
777
- if required not in kwargs:
778
- raise ValueError(f"Missing required parameter: {required}")
779
-
780
- total_flops = get_mamba2_flops(
781
- seq_len=kwargs['seq_len'],
782
- dim=kwargs['dim'],
783
- num_layers=kwargs['num_layers'],
784
- vocab_size=kwargs['vocab_size'],
785
- ffn_multiplier=kwargs['ffn_dim_multiplier'],
786
- state_dim=kwargs['state_dim'],
787
- conv_size=kwargs['conv_size'],
788
- num_heads=kwargs['num_heads'],
789
- num_groups=kwargs['num_groups'],
790
- multiple_of=kwargs['multiple_of'],
791
- include_input_embedding=kwargs['include_input_embedding'],
792
- include_output_logits=kwargs['include_output_logits'],
793
- forward_backward_multiplier=kwargs['forward_backward_multiplier'],
794
  )
795
- flops_per_token = total_flops / kwargs['seq_len']
796
-
797
- return flops_per_token
798
-
799
-
800
- # Optional policy for activation checkpointing. With None, we stick to the default (defined distributed.py: default_no_recompute_ops)
801
- def get_no_recompute_ops():
802
- return {
803
- torch.ops.aten.mm.default,
804
- torch.ops.aten._scaled_mm.default,
805
- torch.ops.c10d_functional.reduce_scatter_tensor.default,
806
- torch.ops.mamba_ssm.ssm_chunk_scan_combined_fwd.default,
807
-
808
- # For low-precision training, it's useful to always save the result of max(abs(tensor))
809
- torch.ops.aten.abs.default,
810
- torch.ops.aten.max.default,
811
- }
812
-
813
-
814
- def main():
815
- from mamba_ssm import Mamba2 as MambaRef
816
-
817
- x = torch.randn(2, 64, 192).cuda()
818
-
819
- # Create and run the first model
820
- model = MambaRef(
821
- d_model=192,
822
- expand=2,
823
- d_conv=4,
824
- d_state=64,
825
- headdim=48,
826
- ).cuda()
827
- y = model(x)
828
- print("Mamba reference output: ", y)
829
- print("Mean of MambaRef output: ", y.mean().item())
830
- print("Stddev of MambaRef output: ", y.std().item())
831
 
832
- # Create and run the second model
833
- config = Mamba2Config(vocab_size=200064, use_mem_eff_path=True)
834
- model2 = Mamba2(
835
- config=config,
836
- ).cuda()
837
 
838
- # Fix: Convert x to torch.LongTensor
839
- x_indices = torch.randint(0, config.vocab_size, (2, 64), dtype=torch.long).cuda()
840
 
841
- y2 = model2(x_indices)
842
- print("Mamba output: ", y2)
843
- print("Mean of Mamba output: ", y2.mean().item())
844
- print("Stddev of Mamba output: ", y2.std().item())
 
845
 
846
- if __name__ == "__main__":
847
- main()
848
 
 
 
 
 
 
1
+ """
2
+
3
+ Adapted from Meta's Lingua repository:
4
+ - https://github.com/facebookresearch/lingua/blob/main/apps/mamba/core_mamba.py
5
+ - https://github.com/facebookresearch/lingua/blob/main/apps/mamba/mamba.py
6
+
7
+ """
8
+
9
+ import json
10
  import math
11
 
12
  import torch
 
15
 
16
  from enum import Enum
17
  from dataclasses import dataclass, field
18
+
19
  from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
20
+ from .causal_conv1d_compilable import (
21
+ causal_conv1d_fn as causal_conv1d_fn,
22
+ causal_conv1d_update as causal_conv1d_update
23
+ )
24
+
25
  from mamba_ssm.ops.triton.selective_state_update import selective_state_update
26
  from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
27
 
 
28
  from .ssm_compilable import mamba_chunk_scan_combined
29
  from .norms import build_norm
30
+ from .attn import AttentionLayer
31
 
32
 
33
  class InitStdFactor(Enum):
 
169
  if self.learnable_init_states:
170
  self.init_states = nn.Parameter(torch.zeros(self.num_heads, self.head_dim, self.state_dim))
171
 
 
172
  self.norm = build_norm(config.norm_type, dim=self.hidden_dim, eps=self.norm_eps)
 
173
  self.output = nn.Linear(self.hidden_dim, self.dim, bias=self.bias)
174
 
175
  def _causal_conv(
 
333
  ),
334
  dt_softplus=True,
335
  ).unsqueeze(0)
336
+
337
  return y
338
 
339
  def forward(
 
515
  self.init_std_factor = InitStdFactor(config.init_std_factor)
516
 
517
  self.layers = nn.ModuleList()
518
+ for layer_idx in range(config.num_layers):
519
+ # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
520
+ if layer_idx % 2 == 0:
521
+ self.layers.append(MambaBlock(config))
522
+ else:
523
+ self.layers.append(
524
+ AttentionLayer(config)
525
+ if config.use_attn
526
+ else (MambaBlock(config))
527
+ )
528
 
529
  def forward(
530
  self,
 
557
 
558
  @dataclass
559
  class Mamba2Config(BaseMambaConfig):
560
+ bsz: int = 2
561
  seed: int = 1337
562
 
563
  vocab_size: int = -1 # Will error if unchanged, makes you double check!
 
595
 
596
  def _get_num_params(self):
597
  n_params = sum(p.numel() for p in self.parameters())
598
+
599
  if hasattr(self, "pos_emb") and self.pos_emb is not None:
600
  n_params -= self.pos_emb.weight.numel()
601
+
 
602
  return n_params
603
 
604
  def forward(
 
679
  return cls(config)
680
 
681
 
682
+ # def main():
683
+ # x = torch.randn(2, 64, 192).cuda()
684
+
685
+ # config = Mamba2Config(vocab_size=200064, use_mem_eff_path=True)
686
+ # model2 = Mamba2(
687
+ # config=config,
688
+ # ).cuda()
689
+
690
+ # x_indices = torch.randint(0, config.vocab_size, (2, 64), dtype=torch.long).cuda()
691
+
692
+ # y2 = model2(x_indices)
693
+ # print("Mamba output: ", y2)
694
+ # print("Mean of Mamba output: ", y2.mean().item())
695
+ # print("Stddev of Mamba output: ", y2.std().item())
696
+
697
+ # if __name__ == "__main__":
698
+ # main()
699
+
700
+ if __name__ == '__main__':
701
+ x = torch.randn(2, 64, 192).cuda() # Removing this produces NaNs lol
702
+
703
+ config_path = "/scratch/gpfs/mn4560/hazan-lab/tensorized_filters/tensorized_filters/models/mamba/config.json"
704
+
705
+ with open(config_path, "r") as f:
706
+ config_data = json.load(f)
707
+
708
+ if torch.cuda.is_available():
709
+ device = torch.device("cuda")
710
+ elif torch.backends.mps.is_available():
711
+ device = torch.device("mps")
712
+ else:
713
+ device = torch.device("cpu")
714
+ print("Device:", device)
715
+
716
+ torch_dtype = getattr(torch, config_data["torch_dtype"])
717
+ print("Torch dtype:", torch_dtype)
718
+
719
+ dim = config_data["dim"]
720
+ num_heads = config_data["num_heads"]
721
+ num_layers = config_data["num_layers"]
722
+ vocab_size = config_data["vocab_size"]
723
+ bias = config_data["bias"]
724
+ state_dim = config_data["state_dim"]
725
+ num_groups = config_data["num_groups"]
726
+ conv_size = config_data.get("conv_size")
727
+ use_mem_eff_path = config_data.get("use_mem_eff_path")
728
+ dt_bias = config_data["dt_bias"]
729
+ D_has_head_dim = config_data["D_has_head_dim"]
730
+ learnable_init_states = config_data["learnable_init_states"]
731
+ ssm_chunk_size = config_data["ssm_chunk_size"]
732
+ weight_tying = config_data["weight_tying"]
733
+ ffn_dim_multiplier = config_data.get("ffn_dim_multiplier")
734
+ multiple_of = config_data["multiple_of"]
735
+ norm_eps = config_data["norm_eps"]
736
+ init_use_depth = config_data["init_use_depth"]
737
+ init_base_std = config_data.get("init_base_std")
738
+ init_std_factor = config_data["init_std_factor"]
739
+ use_attn = config_data["use_attn"]
740
+ softcap = config_data["softcap"]
741
+ torch_compile = config_data["torch_compile"]
742
+
743
+ configs = Mamba2Config(
744
+ dim=dim,
745
+ num_layers=num_layers,
746
+ num_heads=num_heads,
747
+ vocab_size=vocab_size,
748
+ bias=bias,
749
+ torch_dtype=torch_dtype,
750
+ state_dim=state_dim,
751
+ num_groups=num_groups,
752
+ conv_size=conv_size,
753
+ use_mem_eff_path=use_mem_eff_path,
754
+ dt_bias=dt_bias,
755
+ D_has_head_dim=D_has_head_dim,
756
+ learnable_init_states=learnable_init_states,
757
+ ssm_chunk_size=ssm_chunk_size,
758
+ weight_tying=weight_tying,
759
+ ffn_dim_multiplier=ffn_dim_multiplier,
760
+ multiple_of=multiple_of,
761
+ norm_eps=norm_eps,
762
+ init_use_depth=init_use_depth,
763
+ init_base_std=init_base_std,
764
+ init_std_factor=init_std_factor,
765
+ use_attn=use_attn,
766
+ softcap=softcap,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
767
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
768
 
769
+ print("Configs:")
770
+ for key, value in vars(configs).items():
771
+ print(f" {key}: {value}")
 
 
772
 
773
+ model = Mamba2(configs).to(device=device)
 
774
 
775
+ x = torch.randint(
776
+ 0, configs.vocab_size,
777
+ (config_data["bsz"], config_data["seq_len"]),
778
+ dtype=torch.long
779
+ ).to(device)
780
 
781
+ outputs = model(x)
 
782
 
783
+ print("Output shape:", outputs.shape)
784
+ print("Sample output:", outputs[0, 0, :10])
785
+ print("Mean of Mamba output: ", outputs.mean().item())
786
+ print("Stddev of Mamba output: ", outputs.std().item())
norms.py CHANGED
@@ -354,4 +354,4 @@ def fused_rms_norm_fn(
354
  x,
355
  weight,
356
  eps,
357
- )
 
354
  x,
355
  weight,
356
  eps,
357
+ )
ssm_compilable.py CHANGED
@@ -218,5 +218,4 @@ if __name__ == "__main__":
218
 
219
  out_ref = mamba_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)
220
 
221
- print(out_ref.min(), out_ref.max(), out_ref.mean(), out_ref.std())
222
-
 
218
 
219
  out_ref = mamba_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)
220
 
221
+ print(out_ref.min(), out_ref.max(), out_ref.mean(), out_ref.std())