yagizdevre commited on
Commit
9991887
·
1 Parent(s): 4d1df97

added configs

Browse files
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .configuration_minimamba import MiniMambaConfig
2
+ from .modeling_minimamba import MiniMamba
added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<|endofprompt|>": 200018
3
+ }
causal_conv1d_compilable.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+ import torch
3
+ import causal_conv1d_cuda
4
+
5
# Causal Conv1D Forward Function
@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_fwd",
    mutates_args=(),
    device_types="cuda",
)
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    """Depthwise causal 1D convolution forward pass, as a CUDA-only custom op.

    Args:
        x: Input tensor; indexing below implies at least 3 dims — assumed
            (batch, dim, seqlen), TODO confirm against the CUDA kernel.
        weight: Convolution weights; presumably (dim, width) depthwise filters.
        bias: Optional per-channel bias.
        seq_idx: Optional per-token sequence indices for variable-length batches.
        activation: One of None, "silu", "swish"; anything else raises.

    Returns:
        Output tensor of the same shape as ``x``.

    Raises:
        NotImplementedError: If ``activation`` is not a supported value.
    """
    # Ensure activation is valid
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")

    # Only force a copy when neither the seqlen stride nor the channel stride
    # is 1 — the kernel accepts either channel-first or channel-last layouts.
    if x.stride(2) != 1 and x.stride(1) != 1:
        x = x.contiguous()

    # Make bias and seq_idx contiguous if they exist
    bias = bias.contiguous() if bias is not None else None
    seq_idx = seq_idx.contiguous() if seq_idx is not None else None

    # Translate activation to bool for custom CUDA kernel
    use_activation = activation in ["silu", "swish"]

    # Call custom CUDA kernel for forward pass
    out = causal_conv1d_cuda.causal_conv1d_fwd(
        x, weight, bias, seq_idx, None, None, use_activation
    )
    return out
38
+
39
# Register a fake forward pass for tracing
@causal_conv1d_fwd.register_fake
def _causal_conv1d_fwd_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    """Shape-only (meta) implementation used by torch.compile tracing.

    Produces an output with the same metadata as the real op without
    touching the GPU.
    """
    # Channel count of x must match the number of depthwise filters.
    torch._check(x.shape[-2] == weight.shape[0])
    return torch.empty_like(x)
50
+
51
# Causal Conv1D Backward Function
@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_bwd",
    mutates_args=(),
    device_types="cuda",
)
def causal_conv1d_bwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    dout: torch.Tensor,
    seq_idx: Optional[torch.Tensor],
    activation: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward pass for the causal conv1d custom op (CUDA only).

    Args:
        x, weight, bias, seq_idx: Tensors saved from the forward pass.
        dout: Upstream gradient w.r.t. the forward output.
        activation: Whether SiLU/Swish was applied in the forward pass.

    Returns:
        ``(dx, dweight, dbias)``. When ``bias`` is None, ``dbias`` is an
        empty (0,)-shaped placeholder, because custom ops must return
        Tensors rather than None.
    """
    # Ensure dout is contiguous
    if dout.stride(2) != 1 and dout.stride(1) != 1:
        dout = dout.contiguous()

    # Call custom CUDA kernel for backward pass
    dx, dweight, dbias, _ = causal_conv1d_cuda.causal_conv1d_bwd(
        x, weight, bias, dout, seq_idx, None, None, None, False, activation
    )

    # Handle optional bias gradient
    dbias = dbias if bias is not None else torch.empty((0,), device=dout.device)

    return dx, dweight, dbias
78
+
79
# Register a fake backward pass for tracing
@causal_conv1d_bwd.register_fake
def _causal_conv1d_bwd_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    dout: torch.Tensor,
    seq_idx: Optional[torch.Tensor],
    activation: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape-only (meta) backward pass mirroring the real op's metadata.

    The real ``causal_conv1d_bwd`` returns an empty (0,)-shaped tensor for
    ``dbias`` when ``bias`` is None (custom ops must return Tensors, not
    None). The fake must produce the same metadata; returning None here —
    as the original did — makes the fake diverge from the real op and
    breaks tracing whenever bias is absent.
    """
    dbias = (
        torch.empty_like(bias)
        if bias is not None
        else torch.empty((0,), device=dout.device)
    )
    return (
        torch.empty_like(x),
        torch.empty_like(weight),
        dbias,
    )
94
+
95
# Save autograd context for the custom forward op.
def causal_conv1d_setup_context(ctx, inputs, output):
    """Stash the tensors and activation flag needed by the backward pass.

    ``inputs`` is the positional-argument tuple of ``causal_conv1d_fwd``:
    (x, weight, bias, seq_idx, activation).
    """
    x, weight, bias, seq_idx, activation = inputs
    # Reduce the activation string to the boolean the backward kernel expects.
    use_activation = activation in ("silu", "swish")
    ctx.activation = use_activation
    ctx.save_for_backward(x, weight, bias, seq_idx)
100
+
101
# Autograd bridge: route upstream gradients through the custom backward op.
def causal_conv1d_bwd_bridge(ctx, dout):
    """Return gradients for (x, weight, bias, seq_idx, activation).

    seq_idx and activation are non-differentiable, hence the trailing Nones.
    """
    x, weight, bias, seq_idx = ctx.saved_tensors
    dx, dweight, dbias = causal_conv1d_bwd(
        x, weight, bias, dout, seq_idx, ctx.activation
    )

    # The custom op returns an empty placeholder for dbias when bias is
    # absent; autograd expects None in that slot instead.
    if bias is None:
        dbias = None

    return dx, dweight, dbias, None, None
109
+
110
# Register custom autograd function: wires the bridge (backward) and the
# context-saving hook onto the forward custom op so it is differentiable.
torch.library.register_autograd(
    "mamba_causal_conv1d::causal_conv1d_fwd",
    causal_conv1d_bwd_bridge,
    setup_context=causal_conv1d_setup_context,
)
116
+
117
# Public entry point matching the reference causal_conv1d interface.
def causal_conv1d_fn(x, weight, bias=None, seq_idx=None, activation=None):
    """Run the compilable causal conv1d custom op on ``x``."""
    return causal_conv1d_fwd(
        x, weight, bias=bias, seq_idx=seq_idx, activation=activation
    )
120
+
121
+
122
@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_update",
    mutates_args=(),
    device_types="cuda",
)
def causal_conv1d_update_fwd(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Incremental (single-step) causal conv1d used during generation.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len.

    out: (batch, dim) or (batch, dim, seqlen)

    Raises:
        NotImplementedError: If activation is not None, "silu", or "swish".

    NOTE(review): the docstring says conv_state is updated in place by the
    CUDA kernel, yet ``mutates_args=()`` declares no mutation — confirm
    whether conv_state should be listed in mutates_args.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    # Collapse the activation string to the boolean the kernel expects.
    activation = activation in ["silu", "swish"]
    # The kernel wants a 3-D input; add a seqlen axis for (batch, dim) inputs
    # and strip it again afterwards.
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    out = causal_conv1d_cuda.causal_conv1d_update(
        x, conv_state, weight, bias, activation, cache_seqlens
    )
    if unsqueeze:
        out = out.squeeze(-1)
    return out
159
+
160
@causal_conv1d_update_fwd.register_fake
def _causal_conv1d_update_fwd(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Shape-only (meta) implementation for tracing.

    The real op's unsqueeze/squeeze round-trip preserves x's shape, so
    ``empty_like(x)`` matches its output metadata for both 2-D and 3-D inputs.
    """
    return torch.empty_like(x)
170
+
171
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """Public wrapper over the incremental causal conv1d custom op."""
    return causal_conv1d_update_fwd(
        x, conv_state, weight, bias=bias, activation=activation, cache_seqlens=cache_seqlens
    )
173
+
174
# Test the implementation: smoke-compare eager, torch.compile'd, and the
# reference causal_conv1d package. Requires a CUDA device.
if __name__ == "__main__":
    from causal_conv1d import causal_conv1d_fn as causal_conv1d_fn_ref

    torch.manual_seed(0)

    # (batch=8, dim=32, seqlen=16) input; depthwise kernel of width 3.
    x = torch.randn(8, 32, 16, device="cuda", requires_grad=True)
    weight = torch.randn(32, 3, device="cuda", requires_grad=True)
    bias = None  # optionally: torch.randn(32, device="cuda", requires_grad=True)

    # Test the forward and backward pass
    print("Custom Implementation")
    out = causal_conv1d_fn(x, weight, bias, activation="silu")
    out.sum().backward()

    # Print summary statistics rather than full tensors for eyeball comparison.
    print(out.min(), out.max(), out.mean(), out.std())
    print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
    print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())

    # Try compiling the function using torch.compile
    x.grad.zero_(), weight.grad.zero_()
    compiled_conv1d = torch.compile(causal_conv1d_fn)
    print(compiled_conv1d)

    # Run the compiled function
    print("Compiled Implementation")
    out = compiled_conv1d(x, weight, bias, activation="silu")
    out.sum().backward()

    print(out.min(), out.max(), out.mean(), out.std())
    print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
    print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())

    print("Reference Implementation")
    x.grad.zero_(), weight.grad.zero_()
    out = causal_conv1d_fn_ref(x, weight, bias, activation="silu")
    out.sum().backward()

    print(out.min(), out.max(), out.mean(), out.std())
    print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
    print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())
config.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "minimamba",
3
+ "_name_or_path": "Mamba_500M",
4
+ "architectures": ["MiniMamba"],
5
+ "dim": 1024,
6
+ "num_layers": 54,
7
+ "num_heads": 32,
8
+ "state_dim": 128,
9
+ "num_groups": 1,
10
+ "conv_size": 4,
11
+ "use_mem_eff_path": true,
12
+ "dt_bias": true,
13
+ "D_has_head_dim": true,
14
+ "learnable_init_states": false,
15
+ "ssm_chunk_size": 256,
16
+ "vocab_size": 200064,
17
+ "ffn_dim_multiplier": 2.0,
18
+ "multiple_of": 256,
19
+ "norm_eps": 1e-05,
20
+ "init_use_depth": false,
21
+ "init_base_std": null,
22
+ "init_std_factor": "disabled",
23
+ "hidden_act": "silu",
24
+ "bias": false,
25
+ "torch_dtype": "bfloat16",
26
+ "seed": 1337,
27
+ "init_args": {
28
+ "dt_max": 0.1,
29
+ "dt_min": 0.001,
30
+ "dt_init_floor": 0.0001,
31
+ "A_init_min": 0.01,
32
+ "A_init_max": 16
33
+ },
34
+ "seq_len": 8192,
35
+ "weight_tying": true,
36
+ "dropout": 0.0,
37
+ "num_epochs": 1,
38
+ "global_bsz": 524288,
39
+ "bsz": 1,
40
+ "warmup_steps": 1907,
41
+ "eval_period": 50,
42
+ "save_period": 500,
43
+ "max_lr": 0.0003,
44
+ "min_lr": 3e-05,
45
+ "max_norm": 1.0,
46
+ "dilation": 1,
47
+ "fsdp": true,
48
+ "ddp": false,
49
+ "mixed_precision": true,
50
+ "cpu_offload": false,
51
+ "sharding_strategy": "full_shard",
52
+ "state_dict_type": "full",
53
+ "auto_wrap_policy": "partial",
54
+ "backward_prefetch": "backward_pre",
55
+ "forward_prefetch": false,
56
+ "sync_module_states": true,
57
+ "use_orig_params": true,
58
+ "device_id": null,
59
+ "precision": {
60
+ "param": "bfloat16",
61
+ "reduce": "bfloat16",
62
+ "buffer": "bfloat16"
63
+ },
64
+ "fsdp_modules": [
65
+ "MambaBlock"
66
+ ],
67
+ "use_activation_checkpointing": true,
68
+ "use_attn": false,
69
+ "softcap": 50.0,
70
+ "torch_compile": true
71
+ }
72
+
configuration_minimamba.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class MiniMambaConfig(PretrainedConfig):
    """
    Minimal or extended config class for MiniMamba.
    Inherits from HF's PretrainedConfig so we can do:
        model = MiniMamba.from_pretrained(...)
    and it will load this config automatically.

    This config includes all fields from the provided config.json — both
    model-architecture hyperparameters and training/FSDP settings (the
    latter appear to be kept here for reproducibility rather than being
    consumed by the model itself — TODO confirm).
    """
    model_type = "minimamba"

    def __init__(
        self,
        # Standard HF fields:
        model_type="minimamba",
        _name_or_path="Mamba_500M",
        architectures=["MiniMamba"],

        # Key Mamba architecture hyperparameters:
        dim=1024,
        num_layers=54,
        num_heads=32,
        state_dim=128,
        num_groups=1,
        conv_size=4,
        use_mem_eff_path=True,
        dt_bias=True,
        D_has_head_dim=True,
        learnable_init_states=False,
        ssm_chunk_size=256,
        vocab_size=200064,
        ffn_dim_multiplier=2.0,
        multiple_of=256,
        norm_eps=1e-5,
        init_use_depth=False,
        init_base_std=None,
        init_std_factor="disabled",
        hidden_act="silu",
        bias=False,

        # Torch / training:
        torch_dtype="bfloat16",
        seed=1337,

        # The init_config block nested in JSON:
        init_args=None,  # e.g. dict with dt_max, dt_min, dt_init_floor, ...

        # Additional Mamba or training fields:
        seq_len=8192,
        weight_tying=True,
        dropout=0.0,
        num_epochs=1,
        global_bsz=524288,
        bsz=1,
        warmup_steps=1907,
        eval_period=50,
        save_period=500,
        max_lr=0.0003,
        min_lr=3e-5,
        max_norm=1.0,
        dilation=1,
        fsdp=True,
        ddp=False,
        mixed_precision=True,
        cpu_offload=False,
        sharding_strategy="full_shard",
        state_dict_type="full",
        auto_wrap_policy="partial",
        backward_prefetch="backward_pre",
        forward_prefetch=False,
        sync_module_states=True,
        use_orig_params=True,
        device_id=None,
        precision=None,  # e.g. dict with param="bfloat16", reduce="bfloat16", buffer="bfloat16"
        fsdp_modules=None,  # e.g. ["MambaBlock"]
        use_activation_checkpointing=True,
        use_attn=False,
        softcap=50.0,
        torch_compile=True,

        # Now accept arbitrary additional kwargs, to remain flexible:
        **kwargs
    ):
        super().__init__(
            # In HF, these common keys are typically passed to the parent:
            model_type=model_type,
            _name_or_path=_name_or_path,
            architectures=architectures,
            **kwargs
        )

        # Architecture hyperparameters.
        self.dim = dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.state_dim = state_dim
        self.num_groups = num_groups
        self.conv_size = conv_size
        self.use_mem_eff_path = use_mem_eff_path
        self.dt_bias = dt_bias
        self.D_has_head_dim = D_has_head_dim
        self.learnable_init_states = learnable_init_states
        self.ssm_chunk_size = ssm_chunk_size
        self.vocab_size = vocab_size
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.init_use_depth = init_use_depth
        self.init_base_std = init_base_std
        self.init_std_factor = init_std_factor
        self.hidden_act = hidden_act
        self.bias = bias

        self.torch_dtype = torch_dtype
        self.seed = seed

        # Nested init_args (dt_max, dt_min, etc.).
        # Could store it as a dict, or parse out the fields individually:
        self.init_args = init_args or {}

        # Training / distributed-training settings.
        self.seq_len = seq_len
        self.weight_tying = weight_tying
        self.dropout = dropout
        self.num_epochs = num_epochs
        self.global_bsz = global_bsz
        self.bsz = bsz
        self.warmup_steps = warmup_steps
        self.eval_period = eval_period
        self.save_period = save_period
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.max_norm = max_norm
        self.dilation = dilation
        self.fsdp = fsdp
        self.ddp = ddp
        self.mixed_precision = mixed_precision
        self.cpu_offload = cpu_offload
        self.sharding_strategy = sharding_strategy
        self.state_dict_type = state_dict_type
        self.auto_wrap_policy = auto_wrap_policy
        self.backward_prefetch = backward_prefetch
        self.forward_prefetch = forward_prefetch
        self.sync_module_states = sync_module_states
        self.use_orig_params = use_orig_params
        self.device_id = device_id
        self.precision = precision
        self.fsdp_modules = fsdp_modules
        self.use_activation_checkpointing = use_activation_checkpointing
        self.use_attn = use_attn
        self.softcap = softcap
        self.torch_compile = torch_compile

        # If you want to store any leftover kwargs:
        self.extra_args = kwargs
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.py ADDED
@@ -0,0 +1,848 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from enum import Enum
8
+ from dataclasses import dataclass, field
9
+ from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
10
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
11
+ from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
12
+
13
+ from .causal_conv1d_compilable import causal_conv1d_fn, causal_conv1d_update
14
+ from .ssm_compilable import mamba_chunk_scan_combined
15
+ from .norms import build_norm
16
+
17
+
18
class InitStdFactor(Enum):
    """Strategy for scaling the weight-init standard deviation."""
    DISABLED = "disabled"  # Init std is divided by 1.0
    GLOBAL_DEPTH = "global_depth"  # Init std is divided by sqrt(2*num_layers)
    CURRENT_DEPTH = "current_depth"  # Init std is divided by sqrt(2*depth)
    DIM_RATIO = "dim_ratio"  # Init std is divided by model_dim/4096
23
+
24
+
25
@dataclass
class InitConfig:
    """Hyperparameters for SSM parameter initialization (dt and A)."""

    # dt (time-step) sampling range: log_dt ~ Uniform[log(dt_min), log(dt_max)].
    dt_max: float = 0.1
    dt_min: float = 0.001

    # Lower clamp applied to sampled dt before inverting softplus.
    dt_init_floor: float = 1e-4

    # A_log is initialized as log(Uniform(A_init_min, A_init_max)).
    A_init_min: float = 1
    A_init_max: float = 16


# Fallback used when a model config carries no init_config of its own.
DEFAULT_INIT_CONFIG = InitConfig()
37
+
38
+
39
@dataclass
class BaseMambaConfig:
    """
    Configuration for the Mamba family of models.
    """
    # Core dimensions.
    dim: int = 512
    num_layers: int = 8
    num_heads: int = 8

    state_dim: int = 128
    num_groups: int = 1
    conv_size: int | None = 4  # None disables the causal-conv path entirely

    bias: bool = False  # Linear bias
    conv_bias: bool = True  # Convolutional bias
    dt_bias: bool = False
    D_has_head_dim: bool = False
    learnable_init_states: bool = False

    ffn_dim_multiplier: float = 2.0
    multiple_of: int = 256  # Enforce that MLP hidden layer size is multiple of a large power of 2

    norm_eps: float = 1e-6
    norm_type: str = "rmsnorm"

    # CUDA-related items
    ssm_chunk_size: int = 256
    use_mem_eff_path: bool = False

    # Initialization-related items
    init_use_depth: bool = False
    init_base_std: float | None = None
    init_std_factor: str = "disabled"  # e.g. "global_depth"
    init_config: InitConfig = field(default_factory=InitConfig)
73
+
74
+
75
class SSM(nn.Module):
    """
    State Space Model (SSM) implementation with selective state updates and convolution.

    Implements the core SSM computation with support for both training and inference modes.
    During inference, uses cached states for efficient token-by-token generation.
    """
    def __init__(self, config: BaseMambaConfig) -> None:
        """Initialize SSM parameters and layers.
        Args:
            config: Configuration containing model hyperparameters
        """
        super().__init__()
        self.config = config
        # Copy every config field onto self for terse attribute access below.
        vars(self).update(vars(config))

        assert self.dim > 0, "Model dimension (config.dim) must be positive"
        assert self.num_heads > 0, "Number of heads (config.num_heads) must be positive"
        assert self.state_dim > 0, "State dimension (config.state_dim) must be positive"

        if self.ffn_dim_multiplier is None:
            raise ValueError(
                "ffn_dim_multiplier must be set to a valid float (e.g. 2.0) "
                "to determine hidden_dim in SSM."
            )
        assert self.ffn_dim_multiplier > 0, "ffn_dim_multiplier must be > 0"

        self.hidden_dim = int(self.ffn_dim_multiplier * self.dim)
        self.hidden_dim = config.multiple_of * (  # Round up to multiple_of
            (self.hidden_dim + self.multiple_of - 1) // self.multiple_of
        )

        assert self.hidden_dim % self.num_heads == 0, (
            f"Hidden dim {self.hidden_dim} not divisible by num_heads={self.num_heads}."
        )

        self.head_dim = self.hidden_dim // self.num_heads

        # Only pass dt_limit to the kernels when it differs from the
        # kernel default of (0.0, inf).
        self.dt_limit_kwargs = {}
        dt_limit = (self.init_config.dt_min, self.init_config.dt_max)
        if dt_limit != (0.0, float("inf")):
            self.dt_limit_kwargs = dict(dt_limit=dt_limit)

        # Order: [z, x, B, C, dt]
        d_input = (
            2 * self.hidden_dim
            + 2 * self.num_groups * self.state_dim
            + self.num_heads
        )

        self.input = nn.Linear(self.dim, d_input, bias=self.bias)

        # Only create Conv1d if self.conv_size is specified
        if self.conv_size is not None:
            conv_dim = self.hidden_dim + 2 * self.num_groups * self.state_dim

            # Depthwise-ish conv (groups = out_channels)
            # TODO: Check that this is used if causal_conv1d_fn and causal_conv1d_update cannot be imported
            self.conv1d = nn.Conv1d(
                in_channels=conv_dim,
                out_channels=conv_dim,
                kernel_size=self.conv_size,
                groups=conv_dim,
                bias=self.conv_bias,  # boolean from the config
                padding=self.conv_size - 1  # for "causal" style
            )

        if config.dt_bias:
            self.dt_bias = nn.Parameter(torch.empty(self.num_heads))
        else:
            # Frozen zero bias keeps the kernel call signature uniform.
            self.dt_bias = nn.Parameter(torch.zeros(self.num_heads), requires_grad=False)

        self.A_log = nn.Parameter(torch.empty(self.num_heads))

        if config.D_has_head_dim:
            self.D = nn.Parameter(torch.ones(self.num_heads, self.head_dim))
        else:
            self.D = nn.Parameter(torch.ones(self.num_heads))

        if self.learnable_init_states:
            self.init_states = nn.Parameter(torch.zeros(self.num_heads, self.head_dim, self.state_dim))

        # Can also just use nn.RMSNorm
        self.norm = build_norm(config.norm_type, dim=self.hidden_dim, eps=self.norm_eps)

        self.output = nn.Linear(self.hidden_dim, self.dim, bias=self.bias)

    def _causal_conv(
        self,
        zxbcdt: torch.Tensor,
        tok_idx: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        ssm_impl: str = "ssm"
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Processes input through causal convolution path, handling both full sequence and incremental cases.

        This function implements two processing modes:
        1. Full sequence ("ssm"): Used during training and initial prompt processing.
        2. Incremental ("ssm_update"): Used during token-by-token generation.

        Args:
            zxbcdt: Input tensor containing concatenated [z, x, B, C, dt] components
            tok_idx: Token indices for sequence processing. Required for "ssm" mode.
                Defaults to None.
            cu_seqlens: Cumulative sequence lengths for variable length processing.
                Used only in "ssm" mode with caching. Defaults to None.
            ssm_impl: Implementation mode, either "ssm" for full sequence processing
                or "ssm_update" for incremental generation. Defaults to "ssm".

        Returns:
            Tuple containing separated components (z, x, B, C, dt), where:
                - z: Gating branch
                - x: Main branch
                - B, C: SSM state matrices (analogous to K, Q in attention)
                - dt: Time delta values

        Raises:
            NotImplementedError: If ``ssm_impl`` is not "ssm" or "ssm_update".

        Notes:
            - When using "ssm" mode during inference, a cache should be pre-initialized
              externally. This design allows for flexible caching strategies without
              modifying model code.
            - The "ssm_update" mode requires a cache to exist and will use it for
              incremental state updates during generation.
            - B, C components correspond to Key, Query in the SSM/attention duality.
        """
        # Split input into components
        z, xBC, dt = torch.split(
            zxbcdt,
            [
                self.hidden_dim,
                self.hidden_dim + 2 * self.num_groups * self.state_dim,
                self.num_heads,
            ],
            dim=-1,
        )

        if ssm_impl == "ssm":
            if hasattr(self, "cache"):
                # Seed the conv cache with the tail of each (varlen) sequence
                # so generation can continue from this prompt.
                conv_varlen_states = causal_conv1d_varlen_states(
                    xBC.squeeze(0),
                    cu_seqlens,
                    state_len=self.cache.conv_cache.shape[-1],
                )
                self.cache.conv_cache.copy_(conv_varlen_states)

            xBC = causal_conv1d_fn(
                x=xBC.transpose(1, 2),
                weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias,
                activation="silu",
                seq_idx=tok_idx,
            ).transpose(1, 2)
        elif ssm_impl == "ssm_update":
            xBC = causal_conv1d_update(
                x=xBC.squeeze(0),
                conv_state=self.cache.conv_cache,
                weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias,
                activation="silu",
            ).unsqueeze(0)
        else:
            raise NotImplementedError(f"SSM implementation {ssm_impl} not supported")

        # Split processed tensor into components
        x, B, C = torch.split(
            xBC,
            [
                self.hidden_dim,
                self.num_groups * self.state_dim,
                self.num_groups * self.state_dim,
            ],
            dim=-1,
        )

        return z, x, B, C, dt

    def _non_causal_conv(
        self, zxbcdt: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the projected input into (z, x, B, C, dt) without a conv step."""
        z, x, B, C, dt = torch.split(
            zxbcdt,
            [
                self.hidden_dim,
                self.hidden_dim,
                self.num_groups * self.state_dim,
                self.num_groups * self.state_dim,
                self.num_heads,
            ],
            dim=-1,
        )
        return z, x, B, C, dt

    def _fwd(self, x, dt, A, B, C, tok_idx, cu_seqlens, initial_states):
        """
        For training

        Returns:
            (bsz, seq_len, num_heads, head_dim)
        """
        y = mamba_chunk_scan_combined(
            x,
            dt,
            A,
            B,
            C,
            dt_bias=self.dt_bias,
            dt_softplus=True,
            chunk_size=self.ssm_chunk_size,
            D=self.D,
            z=None,
            seq_idx=tok_idx,
            cu_seqlens=cu_seqlens,
            initial_states=initial_states,
            **self.dt_limit_kwargs,
        )

        # With a cache present the kernel also returns per-sequence final
        # states, which we persist for incremental decoding.
        if hasattr(self, "cache"):
            y, varlen_states = y
            self.cache.state_cache.copy_(varlen_states)

        return y

    def _step(self, x, seq_len, dt, A, B, C):
        """
        For inference / generation: single-token selective state update
        against the cached SSM state.
        """
        x = x.squeeze(0)
        A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_dim)
        dt = dt.permute(1, 2, 0).expand(seq_len, self.num_heads, self.head_dim)
        D = self.D
        if D is not None and D.dim() == 1:
            D = D.unsqueeze(1).expand(self.num_heads, self.head_dim)
        B, C = B.squeeze(0), C.squeeze(0)
        y = selective_state_update(
            self.cache.state_cache,
            x,
            dt,
            A,
            B,
            C,
            D,
            z=None,
            dt_bias=(
                torch.zeros(self.num_heads, self.head_dim).to(x)
                if self.dt_bias is None
                else self.dt_bias.unsqueeze(1).expand(self.num_heads, self.head_dim)
            ),
            dt_softplus=True,
        ).unsqueeze(0)

        return y

    def forward(
        self,
        x: torch.Tensor,
        tok_idx: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        """Run the SSM mixer over ``x`` of shape (bsz, seq_len, dim)."""
        bsz, seq_len, _ = x.shape

        zxbcdt = self.input(x)

        A = -torch.exp(self.A_log.float())
        initial_states = (
            self.init_states.expand(bsz, -1, -1, -1)
            if self.learnable_init_states else None
        )

        # Causal conv path
        if self.conv_size is not None:

            # Memory-efficient Triton kernel path
            if self.use_mem_eff_path:
                out = mamba_split_conv1d_scan_combined(
                    zxbcdt,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.ssm_chunk_size,
                    seq_idx=tok_idx,
                    activation="silu",
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.eps,
                    outproj_weight=self.output.weight,
                    outproj_bias=self.output.bias,
                    headdim=self.head_dim,
                    ngroups=self.num_groups,
                    norm_before_gate=False,  # Post-norm, y = self.norm(y * F.silu(z))
                    initial_states=initial_states,
                    **self.dt_limit_kwargs,
                )
                return out
            else:
                # CUDA kernel path.
                # BUG FIX: the original called self._causal_conv(zxbcdt) with
                # no further arguments, silently dropping tok_idx/cu_seqlens
                # and always running in full-sequence "ssm" mode — so the
                # incremental causal_conv1d_update path could never be
                # reached during generation. Forward all three explicitly.
                z, x, B, C, dt = self._causal_conv(
                    zxbcdt,
                    tok_idx=tok_idx,
                    cu_seqlens=cu_seqlens,
                    ssm_impl=ssm_impl,
                )
        else:
            # Non-causal conv path
            z, x, B, C, dt = self._non_causal_conv(zxbcdt)

        x = x.view(bsz, seq_len, self.num_heads, self.head_dim)
        B = B.view(bsz, seq_len, self.num_groups, self.state_dim)
        C = C.view(bsz, seq_len, self.num_groups, self.state_dim)

        # Chunked SSM scan
        if ssm_impl == "ssm":
            # (bsz, seq_len, num_heads, head_dim)
            y = self._fwd(x, dt, A, B, C, tok_idx, cu_seqlens, initial_states)
        elif ssm_impl == "ssm_update":
            y = self._step(x, seq_len, dt, A, B, C)
        else:
            raise NotImplementedError(f"SSM implementation {ssm_impl} not supported")

        y = y.view(bsz, seq_len, self.hidden_dim)

        # Could be different activation function, including None.
        # Mamba people post_norm here also (sometimes norm(z)*y or norm(z*y))
        # y = self.norm(y) * F.silu(z)
        y = self.norm(y * F.silu(z))
        out = self.output(y)

        return out

    @torch.inference_mode()
    def reset_parameters(self, init_std, factor) -> None:
        """Re-initialize all learnable parameters in place.

        Args:
            init_std: Explicit std for the linear layers; None derives it
                from the layer fan-in.
            factor: Divisor applied to the output projection std (depth scaling).
        """
        config = self.config
        init_config = config.init_config
        if init_config is None:
            init_config = DEFAULT_INIT_CONFIG

        # Linear layers
        in_init_std = init_std or (self.dim ** (-0.5))
        out_init_std = init_std or (self.hidden_dim ** (-0.5))
        out_init_std = out_init_std / factor

        nn.init.trunc_normal_(
            self.input.weight,
            mean=0.0,
            std=in_init_std,
            a=-3 * in_init_std,
            b=3 * in_init_std,
        )

        nn.init.trunc_normal_(
            self.output.weight,
            mean=0.0,
            std=out_init_std,
            a=-3 * out_init_std,
            b=3 * out_init_std,
        )

        # SSM
        if self.dt_bias is not None and self.dt_bias.requires_grad:
            log_dt_min = math.log(init_config.dt_min)
            log_dt_max = math.log(init_config.dt_max)

            # Sample log_dt ~ Uniform[log_dt_min, log_dt_max]
            log_dt = torch.rand(self.num_heads, device=self.dt_bias.device) * (log_dt_max - log_dt_min) + log_dt_min
            dt = torch.exp(log_dt)
            dt = torch.clamp(dt, min=init_config.dt_init_floor)

            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            self.dt_bias.copy_(inv_dt)

        elif self.dt_bias is not None:
            # If dt_bias is not trainable, we can just keep it zero or set to any constant
            self.dt_bias.fill_(0.0)

        # Convolution
        if self.conv_size is not None:
            conv_std = init_std or (self.conv_size ** (-0.5))
            nn.init.trunc_normal_(
                self.conv1d.weight,
                mean=0.0,
                std=conv_std,
                a=-3 * conv_std,
                b=3 * conv_std,
            )
            if self.conv1d.bias is not None:
                nn.init.zeros_(self.conv1d.bias)

        # Learnable init states
        if self.learnable_init_states:
            self.init_states.zero_()

        # Initialize A_log ~ log( Uniform(A_init_min, A_init_max) )
        self.A_log.uniform_(init_config.A_init_min, init_config.A_init_max)
        self.A_log.log_()

        if self.D is not None:
            self.D.data.fill_(1.0)

        # Reset norm parameters
        self.norm.reset_parameters()
471
+
472
+
473
class MambaBlock(nn.Module):
    """Pre-norm residual block: x + SSM(norm(x))."""

    def __init__(self, config: BaseMambaConfig):
        super().__init__()
        self.norm = build_norm(config.norm_type, dim=config.dim, eps=config.norm_eps)
        self.ssm = SSM(config)

    def forward(
        self,
        x: torch.Tensor,
        tok_idx: torch.Tensor | None,
        cu_seqlens: torch.Tensor | None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        """Apply norm -> SSM and add the result back onto the residual stream."""
        mixed = self.ssm(self.norm(x), tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl)
        return x + mixed

    @torch.inference_mode()
    def init_weights(self, init_std=None, factor=1.0):
        """Re-initialize the norm first, then the SSM parameters."""
        self.norm.reset_parameters()
        self.ssm.reset_parameters(init_std, factor)
493
+
494
+
495
class BaseMamba(nn.Module):
    """Trunk of stacked MambaBlocks — no embedding and no output head."""

    def __init__(self, config: BaseMambaConfig):
        super().__init__()
        self.model_dim = config.dim
        self.init_base_std = config.init_base_std

        self.init_config = config.init_config
        self.init_std_factor = InitStdFactor(config.init_std_factor)

        # One block per layer; every block shares the same config object.
        self.layers = nn.ModuleList(MambaBlock(config) for _ in range(config.num_layers))

    def forward(
        self,
        h: torch.Tensor,
        tok_idx: torch.Tensor | None,
        cu_seqlens: torch.Tensor | None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        """Thread the hidden states through every block in order."""
        for block in self.layers:
            h = block(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl)
        return h

    @torch.inference_mode()
    def reset_parameters(self):
        """Hook for subclasses; the trunk itself owns no direct parameters."""
        pass

    @torch.inference_mode()
    def init_weights(self):
        """Initialize each block with a depth-dependent std scaling factor."""
        self.reset_parameters()
        total = len(self.layers)
        for depth, block in enumerate(self.layers):
            # Same mapping for every layer; only CURRENT_DEPTH varies with depth.
            scale = {
                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
                InitStdFactor.GLOBAL_DEPTH: (2 * (total + 1)) ** 0.5,
                InitStdFactor.DIM_RATIO: self.model_dim / 4096,
                InitStdFactor.DISABLED: 1.0,
            }[self.init_std_factor]

            block.init_weights(self.init_base_std, scale)
535
+
536
+
537
@dataclass
class Mamba2Config(BaseMambaConfig):
    """Language-model-level config: adds vocab/embedding and training knobs on top of BaseMambaConfig."""

    # RNG seed for reproducible initialization.
    seed: int = 1337

    vocab_size: int = -1  # Will error if unchanged, makes you double check!
    # Share the token-embedding matrix with the output projection when True.
    weight_tying: bool = False
    # Default parameter dtype.
    torch_dtype: torch.dtype = torch.bfloat16

    # Loss reduction mode; stored on the model as self.loss_reduction.
    loss_reduction: str = "mean"

    # NOTE(review): use_attn and softcap are not referenced elsewhere in this
    # file — confirm they are consumed by the training/attention code path.
    use_attn: bool = False
    softcap: float = 50.0
549
+
550
+
551
class Mamba2(BaseMamba):
    """Mamba-2 language model: token embedding + BaseMamba trunk + RMSNorm + LM head."""

    def __init__(self, config: Mamba2Config) -> None:
        super().__init__(config)
        self.weight_tying = config.weight_tying
        self.loss_reduction = config.loss_reduction

        # vocab_size defaults to -1 in the config precisely to force this check.
        assert config.vocab_size > 0, "vocab_size must be set and > 0"

        self.tok_emb = torch.nn.Embedding(config.vocab_size, config.dim)

        # Final pre-logits normalization.
        self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)

        self.output = nn.Linear(
            config.dim,
            config.vocab_size,
            bias=False,
        )

        # Tie LM head to the embedding matrix when requested.
        if config.weight_tying:
            self.output.weight = self.tok_emb.weight

        print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def _get_num_params(self):
        """Return the parameter count used for reporting."""
        n_params = sum(p.numel() for p in self.parameters())
        if hasattr(self, "pos_emb") and self.pos_emb is not None:
            n_params -= self.pos_emb.weight.numel()
        # NOTE(review): this subtracts the token embedding only when weights are
        # NOT tied; the usual convention subtracts duplicates when tied — confirm intent.
        if self.tok_emb.weight is not self.output.weight:
            n_params -= self.tok_emb.weight.numel()
        return n_params

    def forward(
        self,
        x: torch.Tensor,
        target: torch.Tensor | None = None,
        tok_idx: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        """Embed token ids, run the trunk, and return logits.

        Note: `target` is accepted for interface compatibility but unused here;
        loss computation is left to the caller.
        """
        h = self.tok_emb(x)
        h = super().forward(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl)
        logits = self.output(self.norm(h))
        return logits

    @torch.inference_mode()
    def reset_parameters(self, init_std=None):
        """Re-initialize embedding, norm, and (if untied) the output head."""
        # Either use fixed base std or sqrt model dim
        super().reset_parameters()
        init_std = init_std or (self.model_dim ** (-0.5))
        self.norm.reset_parameters()
        nn.init.trunc_normal_(
            self.tok_emb.weight,
            mean=0.0,
            std=init_std,
            a=-3 * init_std,
            b=3 * init_std,
        )
        # When tied, re-initializing the embedding already covered the head.
        if not self.weight_tying:
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

    @torch.inference_mode()
    def init_weights(self, buffer_device: torch.device = None):
        """
        Initialize model parameters and optionally compute buffers on a specific device.

        Args:
            buffer_device (torch.device, optional): If provided, any large or precomputed
                buffers (like RoPE frequency tensors) will be allocated or re-created on
                this device during initialization. This can avoid overhead from transferring
                buffers between CPU and GPU after creation. If None, buffers default to the
                device of the first parameter or CPU.

        NOTE(review): `buffer_device` is currently unused — this class precomputes
        no such buffers; the parameter exists for interface parity. Confirm before
        relying on it.
        """
        super().init_weights()

    @classmethod
    def from_model_args(cls, config: Mamba2Config) -> "Mamba2":
        """
        Initialize a Mamba model from a MambaConfig object.

        Args:
            config (MambaConfig): Mamba configuration arguments.

        Returns:
            Mamba: Mamba-2 model.
        """
        return cls(config)
658
+
659
+
660
def get_mamba2_flops(
    seq_len: int,
    dim: int,
    num_layers: int,
    vocab_size: int,
    ffn_multiplier: float = 2.0,
    state_dim: int = 128,
    conv_size: int = 4,
    num_heads: int = 8,
    num_groups: int = 1,
    multiple_of: int = 256,
    include_input_embedding: bool = True,
    include_output_logits: bool = True,
    forward_backward_multiplier: float = 1.0,
) -> int:
    """
    Shape-based ("Chinchilla-style") FLOP estimate for a Mamba-2 model.

    Returns the forward-pass cost by default; pass
    `forward_backward_multiplier=3.0` for a rough forward+backward estimate.

    Per layer (times seq_len * num_layers) it counts:
      1) input linear: dim -> 2*hidden_dim + 2*(num_groups*state_dim) + num_heads
      2) depthwise conv1d over hidden_dim + 2*num_groups*state_dim channels
      3) selective scan: ~9*(state_dim*dim) (Mamba dev-discussion estimate)
      4) output linear: hidden_dim -> dim
    hidden_dim is ffn_multiplier*dim rounded up to a multiple of `multiple_of`.
    Optionally adds the embedding matmul and the final [dim -> vocab_size]
    projection, then scales everything by `forward_backward_multiplier`.

    Args:
        seq_len: Number of tokens.
        dim: Model (embedding) dimension.
        num_layers: Number of Mamba layers.
        vocab_size: Vocabulary size for the logits projection.
        ffn_multiplier: FFN expansion ratio before rounding.
        state_dim: SSM state dimension.
        conv_size: Depthwise conv1d kernel size.
        num_heads: Number of heads (adds to the input-linear output width).
        num_groups: Grouped-state count (usually 1).
        multiple_of: Rounding granularity for hidden_dim.
        include_input_embedding: Count the embedding matmul if True.
        include_output_logits: Count the final projection if True.
        forward_backward_multiplier: 1.0 forward-only; 2.0-3.0 fwd+bwd.

    Returns:
        int: Approximate total FLOPs for the selected pass(es).
    """
    # Round hidden size up to a multiple of `multiple_of`, as Mamba does.
    hidden = -(-int(ffn_multiplier * dim) // multiple_of) * multiple_of
    group_state = num_groups * state_dim

    # Per-token, per-layer cost of the four sub-modules.
    per_token_layer = 2 * (dim * (2 * hidden + 2 * group_state + num_heads))  # input linear
    per_token_layer += 2 * ((hidden + 2 * group_state) * conv_size)           # depthwise conv
    per_token_layer += 9 * state_dim * dim                                    # selective scan
    per_token_layer += 2 * (hidden * dim)                                     # output linear

    total = per_token_layer * num_layers * seq_len

    # Optional embedding matmul: seq_len x vocab_size times vocab_size x dim.
    if include_input_embedding:
        total += 2 * (seq_len * vocab_size * dim)

    # Optional final projection dim -> vocab_size.
    if include_output_logits:
        total += 2 * (seq_len * dim * vocab_size)

    return int(total * forward_backward_multiplier)
746
+
747
def get_mamba2_flops_per_token(
    **kwargs
) -> float:
    """
    Estimate FLOPs per token for a Mamba-2 style model.

    Accepts the same configuration keys as `get_mamba2_flops`; `seq_len`,
    `dim`, `num_layers`, and `vocab_size` are mandatory, all other keys fall
    back to the defaults below.

    Returns:
        float: Approximate FLOPs per token (total FLOPs / seq_len).

    Raises:
        ValueError: If any mandatory parameter is missing.
    """
    defaults = {
        'ffn_dim_multiplier': 2.0,
        'state_dim': 128,
        'conv_size': 4,
        'num_heads': 8,
        'num_groups': 1,
        'multiple_of': 256,
        'include_input_embedding': True,
        'include_output_logits': True,
        'forward_backward_multiplier': 1.0,
    }
    # Caller-provided keys win over defaults.
    params = {**defaults, **kwargs}

    # Fail fast, in a fixed order, when a mandatory key is absent.
    for required in ('seq_len', 'dim', 'num_layers', 'vocab_size'):
        if required not in params:
            raise ValueError(f"Missing required parameter: {required}")

    total_flops = get_mamba2_flops(
        seq_len=params['seq_len'],
        dim=params['dim'],
        num_layers=params['num_layers'],
        vocab_size=params['vocab_size'],
        ffn_multiplier=params['ffn_dim_multiplier'],
        state_dim=params['state_dim'],
        conv_size=params['conv_size'],
        num_heads=params['num_heads'],
        num_groups=params['num_groups'],
        multiple_of=params['multiple_of'],
        include_input_embedding=params['include_input_embedding'],
        include_output_logits=params['include_output_logits'],
        forward_backward_multiplier=params['forward_backward_multiplier'],
    )
    return total_flops / params['seq_len']
798
+
799
+
800
# Optional policy for activation checkpointing. With None, we stick to the default (defined distributed.py: default_no_recompute_ops)
def get_no_recompute_ops():
    """Ops whose outputs should be saved rather than recomputed under activation checkpointing."""
    return {
        torch.ops.aten.mm.default,
        torch.ops.aten._scaled_mm.default,
        torch.ops.c10d_functional.reduce_scatter_tensor.default,
        # Custom chunked-scan op; requires the mamba_ssm op library to be registered.
        torch.ops.mamba_ssm.ssm_chunk_scan_combined_fwd.default,

        # For low-precision training, it's useful to always save the result of max(abs(tensor))
        torch.ops.aten.abs.default,
        torch.ops.aten.max.default,
    }
812
+
813
+
814
def main():
    """Smoke-test: run the reference mamba_ssm model and the local Mamba2 side by side."""
    from mamba_ssm import Mamba2 as MambaRef

    dense_input = torch.randn(2, 64, 192).cuda()

    # Reference implementation consumes dense features directly.
    ref_model = MambaRef(
        d_model=192,
        expand=2,
        d_conv=4,
        d_state=64,
        headdim=48,
    ).cuda()
    ref_out = ref_model(dense_input)
    print("Mamba reference output: ", ref_out)
    print("Mean of MambaRef output: ", ref_out.mean().item())
    print("Stddev of MambaRef output: ", ref_out.std().item())

    # Local model embeds token ids itself, so feed LongTensor indices instead.
    config = Mamba2Config(vocab_size=200064, use_mem_eff_path=True)
    local_model = Mamba2(
        config=config,
    ).cuda()

    token_ids = torch.randint(0, config.vocab_size, (2, 64), dtype=torch.long).cuda()

    local_out = local_model(token_ids)
    print("Mamba output: ", local_out)
    print("Mean of Mamba output: ", local_out.mean().item())
    print("Stddev of Mamba output: ", local_out.std().item())

if __name__ == "__main__":
    main()
848
+
modeling_minimamba.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from transformers import PreTrainedModel
7
+ from transformers.modeling_outputs import CausalLMOutput
8
+
9
+ from .configuration_minimamba import MiniMambaConfig
10
+ from .model import Mamba2, Mamba2Config
11
+
12
+
13
+
14
class MiniMamba(PreTrainedModel):
    """
    A Hugging Face–style wrapper around a Mamba2 model, providing:
      • forward(...) returning a CausalLMOutput
      • support for HF training loops
      • a naive generate(...) method with top-k/top-p sampling
    """
    config_class = MiniMambaConfig  # Tells HF which config class to use

    def __init__(self, config: MiniMambaConfig) -> None:
        """
        Initialize the MiniMamba model, bridging Mamba2 with HF's PreTrainedModel.
        """
        super().__init__(config)

        # Translate the HF config fields into the internal Mamba2Config.
        mamba2_args = Mamba2Config(
            dim=config.dim,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            state_dim=config.state_dim,
            num_groups=config.num_groups,
            conv_size=config.conv_size,
            use_mem_eff_path=config.use_mem_eff_path,
            dt_bias=config.dt_bias,
            D_has_head_dim=config.D_has_head_dim,
            learnable_init_states=config.learnable_init_states,
            ssm_chunk_size=config.ssm_chunk_size,
            vocab_size=config.vocab_size,
            ffn_dim_multiplier=config.ffn_dim_multiplier,
            multiple_of=config.multiple_of,
            norm_eps=config.norm_eps,
            init_use_depth=config.init_use_depth,
            init_base_std=config.init_base_std,
            init_std_factor=config.init_std_factor,
            bias=config.bias,

            # Torch / training:
            seed=config.seed,

            # weight_tying may be absent on older configs; default to False.
            weight_tying=config.weight_tying if hasattr(config, "weight_tying") else False,
            # torch_dtype may arrive as a string (e.g. "bfloat16") from JSON.
            torch_dtype=getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype,
        )

        # Internally hold a Mamba2 model
        self.mamba = Mamba2(config=mamba2_args)

        # Mamba2 already owns the token embedding (tok_emb) and LM head (output),
        # including optional weight tying when config.weight_tying is True, so no
        # separate embedding or lm_head is created at this level.

        # This is optional: store any device or dtype you might want
        self.device_ = 'cuda' if torch.cuda.is_available() else 'cpu'
        if isinstance(config.torch_dtype, str):
            self.dtype_ = getattr(torch, config.torch_dtype)
        else:
            self.dtype_ = config.torch_dtype

        # Parameter initialization (HF calls them with self._init_weights in some flows).
        self.apply(self._init_weights)

        print("MiniMamba Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor = None,
        **kwargs
    ) -> CausalLMOutput:
        """
        Forward pass for causal language modeling.
        Returns a CausalLMOutput that includes loss (if labels is provided) and logits.
        """
        # Mamba2's forward expects (x: torch.Tensor, target: torch.Tensor|None, ...)
        # but we only need the logits from the simple call:
        logits = self.mamba(input_ids)  # shape: [batch, seq_len, vocab_size]

        loss = None
        if labels is not None:
            # By default, huggingface GPT-like models shift the logits by one
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 0.5,
        top_k: int = 50,
        top_p: float = 0.95,
        eos_token_id: int = None,
        pad_token_id: int = 0,
        **kwargs
    ):
        """
        A naive token-by-token generation loop (greedy + top-k/top-p + temperature).

        NOTE: each step re-runs the full forward pass over the whole sequence
        (no KV/state caching), so cost grows quadratically with output length.
        """
        # We'll accumulate new tokens in generated_ids
        generated_ids = input_ids.clone()

        for _ in range(max_new_tokens):
            # Forward pass to get logits for the last token
            outputs = self.forward(generated_ids)
            logits = outputs.logits[:, -1, :]  # shape: (batch_size, vocab_size)

            # Scale by temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Restrict the distribution before sampling.
            logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # shape: (batch, 1)

            # Append
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # If we have an EOS token, we can break early if all sequences have ended
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated_ids

    @staticmethod
    def top_k_top_p_filtering(
        logits: torch.Tensor,
        top_k: int = 50,
        top_p: float = 0.95,
        filter_value: float = float("-inf"),
    ):
        """
        Filters logits using top-k and/or nucleus (top-p) filtering.

        Note: mutates `logits` in place (masked entries are set to filter_value)
        and also returns it.
        """
        # top_k: drop everything below the k-th largest logit per row.
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[:, -1, None]
            logits[indices_to_remove] = filter_value

        # top_p (nucleus)
        if 0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p

            # Shift right to keep also the first token above threshold
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False

            # Scatter to get back to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = filter_value

        return logits

    def _init_weights(self, module):
        """
        HF calls _init_weights to initialize parameters.
        If you prefer Mamba’s own init approach, you can call model.mamba.init_weights().
        """
        # As an example, we just call Mamba2's init routine for the entire submodel,
        # or do some standard PyTorch inits for linear layers, embeddings, etc.
        # NOTE(review): nn.Module.apply visits children before the parent, so the
        # Mamba2.init_weights call below runs AFTER the per-Linear/Embedding inits
        # and overrides them — confirm this ordering is intended.
        if isinstance(module, Mamba2):
            module.init_weights()  # Mamba2’s internal init
        elif isinstance(module, nn.Linear):
            # e.g. standard xavier or normal init
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # If needed, do your specialized inits for other modules

    def _get_num_params(self):
        # Count trainable params, subtract duplicates if tying weights, etc.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
norms.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapted from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py"""
2
+
3
+ import math
4
+
5
+ from functools import partial
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from torch.distributed._tensor import Partial, Replicate, Shard
14
+ from torch.distributed._tensor.experimental import local_map
15
+ from torch._utils import _get_available_device_type, _get_device_module
16
+
17
+
18
+ def get_device_info():
19
+ device_type = _get_available_device_type()
20
+
21
+ if device_type is None:
22
+ device_type = "cuda" # Default to CUDA
23
+
24
+ device_module = _get_device_module(device_type)
25
+ return device_type, device_module
26
+
27
+ device_type, device_module = get_device_info()
28
+
29
+ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
30
+ """
31
+ Builds the specified normalization layer based on the norm_type.
32
+
33
+ Args:
34
+ norm_type (str): The type of normalization layer to build.
35
+ Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm
36
+ dim (int): The dimension of the normalization layer.
37
+ eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
38
+
39
+ Returns:
40
+ The built normalization layer.
41
+
42
+ Raises:
43
+ NotImplementedError: If an unknown norm_type is provided.
44
+ """
45
+ norm_type = norm_type.lower() # Normalize to lowercase
46
+
47
+ if norm_type == "layernorm":
48
+ return nn.LayerNorm(dim, eps=eps, bias=False)
49
+ elif norm_type == "np_layernorm":
50
+ return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
51
+ elif norm_type == "rmsnorm":
52
+ return RMSNorm(dim, eps=eps)
53
+ elif norm_type == "fused_rmsnorm":
54
+ return FusedRMSNorm(dim, eps=eps)
55
+ else:
56
+ raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")
57
+
58
+
59
class FusedRMSNorm(nn.Module):
    """RMSNorm whose forward dispatches to the fused Triton kernel."""

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        # Keep a reference so the callable can be swapped/inspected per-instance.
        self.fused_rms_norm_fn = fused_rms_norm_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize `x` via the Triton fused RMSNorm kernel."""
        return self.fused_rms_norm_fn(x, self.weight, eps=self.eps)

    def reset_parameters(self):
        """Reset the learnable scale back to ones."""
        torch.nn.init.ones_(self.weight)  # type: ignore
82
+
83
+
84
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization with a learnable scale.

    Each vector along the last dimension is divided by its RMS (no mean
    subtraction) and then multiplied by a learned per-channel weight.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim (int): Size of the last (normalized) dimension.
            eps (float, optional): Stability constant added under the square
                root. Default is 1e-6.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide `x` by its per-vector RMS along the last dimension."""
        # rsqrt of the mean square is 1/RMS; broadcasting rescales each vector.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalize in float32 for stability, cast back, then apply the scale."""
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight

    def reset_parameters(self):
        """Reset the learnable scale back to ones."""
        torch.nn.init.ones_(self.weight)  # type: ignore
131
+
132
+
133
+ # FusedRMSNorm in Triton
134
+
135
+ # Credit
136
+ # Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
137
+ # Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
138
+
139
+
140
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_fwd_kernel(
    X,            # input pointer, logical shape (M, N)
    stride_x,     # row stride of X
    Y,            # output pointer, logical shape (M, N)
    stride_y,     # row stride of Y
    W,            # weight pointer, shape (N,)
    Rstd,         # out: per-row reciprocal std, shape (M,) — reused by backward
    eps,          # numerical-stability epsilon
    M,  # num rows
    N,  # num cols
    block_N: tl.constexpr,  # compile-time block width; must cover N
):
    # One program instance normalizes exactly one row.
    row = tl.program_id(0)
    cols = tl.arange(0, block_N)

    # Load input data and weights
    mask = cols < N
    x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Mean square of the row (RMSNorm: no mean subtraction)
    xbar = tl.where(cols < N, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Store the reciprocal standard deviation
    tl.store(Rstd + row, rstd)

    # Normalize and apply linear transformation
    x_hat = x * rstd
    y = x_hat * w

    # Write output
    tl.store(Y + row * stride_y + cols, y, mask=mask)
186
+
187
+
188
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_bwd_kernel_sm(
    X,            # forward input pointer, (M, N)
    stride_x,     # row stride of X
    W,            # weight pointer, (N,)
    DY,           # upstream gradient pointer, (M, N)
    stride_dy,    # row stride of DY
    DX,           # out: input gradient, (M, N)
    stride_dx,    # row stride of DX
    Rstd,         # per-row reciprocal std saved by the forward kernel, (M,)
    DW,           # out: one partial weight-gradient row per program, (num_programs, N)
    eps,          # not used in the backward math (rstd is precomputed); kept for signature symmetry
    M,  # num rows
    N,  # num cols
    rows_per_program,  # how many consecutive rows each program reduces
    block_N: tl.constexpr,  # compile-time block width; must cover N
):
    # Each program handles a contiguous chunk of rows and accumulates its own
    # partial dW; the host sums the partials afterwards.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, block_N)
    mask = cols < N

    # Load weights
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Accumulate gradients for weights
    dw = tl.zeros((block_N,), dtype=tl.float32)

    row_end = min(row_start + rows_per_program, M)
    for row in range(row_start, row_end):
        # Load input, output gradient, and reciprocal standard deviation
        x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)
        rstd = tl.load(Rstd + row)

        # Compute normalized input and gradients
        x_hat = x * rstd
        wdy = w * dy
        dw += dy * x_hat
        # c1 projects out the component along x_hat (chain rule through the RMS).
        c1 = tl.sum(x_hat * wdy, axis=0) / N
        dx = (wdy - x_hat * c1) * rstd

        # Store input gradient
        tl.store(DX + row * stride_dx + cols, dx, mask=mask)

    # Store weight gradients
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
246
+
247
+
248
class TritonFusedRMSNorm(torch.autograd.Function):
    """Autograd wrapper around the fused Triton RMSNorm kernels.

    The local_map decorators describe DTensor placements (sequence-sharded
    activations, replicated weight) so the op composes with tensor parallelism.
    NOTE(review): local_map is applied outside @staticmethod, mirroring the
    torchtitan original — confirm decorator ordering before changing it.
    """

    @partial(
        local_map,
        out_placements=[Shard(1)],
        in_placements=(None, [Shard(1)], [Replicate()], None),
    )
    @staticmethod
    def forward(ctx, x, weight, eps):
        """Fused forward: row-wise RMS-normalize `x` and scale by `weight`."""
        x_shape_start = x.shape

        # Flatten input
        x = x.view(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if weight.stride(-1) != 1:
            weight = weight.contiguous()

        M, N = x.shape
        y = torch.empty_like(x)
        # Per-row 1/std, stored in fp32 for the backward pass.
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # Each program processes one row; cap the block at 64 KiB of elements.
        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))

        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        grid = lambda meta: (M,)
        _rms_norm_fwd_kernel[grid](
            x,
            x.stride(0),
            y,
            y.stride(0),
            weight,
            rstd,
            eps,
            M,
            N,
            block_N,
        )

        ctx.eps = eps
        ctx.save_for_backward(x, weight, rstd)
        ctx.x_shape_start = x_shape_start

        y = y.reshape(x_shape_start)
        return y

    @partial(
        local_map,
        out_placements=([Shard(1)], [Partial()], None),
        in_placements=(None, [Shard(1)]),
    )
    @staticmethod
    def backward(ctx, dy):
        """Fused backward: returns (dx, dw, None); dw is reduced from per-SM partials."""
        x, weight, rstd = ctx.saved_tensors
        eps = ctx.eps
        x_shape_start = ctx.x_shape_start

        # Flatten input and output gradients
        dy = dy.view(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()

        M, N = dy.shape
        dx = torch.empty_like(x)

        # Launch one program per SM; each accumulates a partial dW row.
        sm_count = device_module.get_device_properties(x.device).multi_processor_count
        _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)

        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))
        rows_per_sm = math.ceil(M / sm_count)

        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        grid = lambda meta: (sm_count,)
        _rms_norm_bwd_kernel_sm[grid](
            x,
            x.stride(0),
            weight,
            dy,
            dy.stride(0),
            dx,
            dx.stride(0),
            rstd,
            _dw,
            eps,
            M,
            N,
            rows_per_sm,
            block_N,
        )
        # Reduce the per-program partials into the final weight gradient.
        dw = _dw.sum(0).to(weight.dtype)
        dx = dx.view(x_shape_start)
        return dx, dw, None
345
+
346
+
347
+ # expose fusedRMSNorm as a function
348
+ def fused_rms_norm_fn(
349
+ x,
350
+ weight,
351
+ eps=1e-6,
352
+ ):
353
+ return TritonFusedRMSNorm.apply(
354
+ x,
355
+ weight,
356
+ eps,
357
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
ssm_compilable.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple
2
+ import torch
3
+
4
+ from mamba_ssm.ops.triton.ssd_combined import _mamba_chunk_scan_combined_fwd, _mamba_chunk_scan_combined_bwd
5
+
6
+
7
# Compiled pass-through around the raw Triton forward entry point:
# CUDA-graph capture enabled, fullgraph=True forbids graph breaks.
# NOTE(review): not referenced by the custom ops below (they call the raw
# function directly) — appears to be kept as an alternative entry point.
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def _compiled_mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=None):
    return _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
10
+
11
# Compiled pass-through around the raw Triton backward entry point (see the
# forward twin above). NOTE(review): also unreferenced by the custom ops below.
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def _compiled_mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False, dt_limit=None):
    return _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit)
14
+
15
+
16
@torch.library.custom_op(
    "mamba_ssm::ssm_chunk_scan_combined_fwd",
    mutates_args=(),
    device_types="cuda",
)
def ssm_chunk_scan_combined_fwd(
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom-op wrapper for the Mamba chunked-scan forward kernel.

    Custom ops need a static, all-tensor output schema, so this always
    returns a 3-tuple ``(out, out_x, varlen_states)``: ``out_x`` is only
    meaningful when ``z`` is provided, and the last element only when
    ``cu_seqlens`` is provided; otherwise size-0 placeholder tensors are
    returned in their slots.
    """
    # The underlying kernel also returns dt_out / dA_cumsum / states /
    # final_states; they are not exposed here (the backward op recomputes
    # from the saved primal inputs instead).
    out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)

    # NOTE(review): rest[0] is assumed to be the per-sequence final states
    # produced by the varlen (cu_seqlens) path — confirm against
    # _mamba_chunk_scan_combined_fwd's return signature.
    return out, out_x if out_x is not None else out.new_empty(0), rest[0] if cu_seqlens is not None else out.new_empty(0)
40
+
41
@ssm_chunk_scan_combined_fwd.register_fake
def _ssm_chunk_scan_combined_fwd_fake(
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
):
    """Fake (meta) implementation used by torch.compile to infer shapes.

    Must mirror the real op's output metadata exactly. The real op has an
    all-tensor schema (``Tuple[Tensor, Tensor, Tensor]``) and returns size-0
    placeholder tensors — never ``None`` — when ``z`` / ``cu_seqlens`` are
    absent, so this fake does the same. (The previous version returned
    ``None`` in those slots, which mismatches the registered schema.)
    """
    _, _, n_heads, head_dim = x.shape
    return (
        torch.empty_like(x),
        # Pre-gate output: only a real tensor when z is given (matches the
        # real op, where out_x is non-None iff z is provided).
        torch.empty_like(x) if z is not None else x.new_empty(0),
        # Varlen final states: (num_sequences, nheads, headdim, dstate).
        # dstate is B's LAST dim — B is (batch, seqlen, ngroups, dstate) —
        # not B.size(0), which is the batch size.
        x.new_empty((cu_seqlens.size(0) - 1, n_heads, head_dim, B.size(-1))) if cu_seqlens is not None else x.new_empty(0),
    )
64
+
65
@torch.library.custom_op(
    "mamba_ssm::ssm_chunk_scan_combined_bwd",
    mutates_args=(),
    device_types="cuda",
)
def ssm_chunk_scan_combined_bwd(
    dout: torch.Tensor,
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    out: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
)-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom-op wrapper for the Mamba chunked-scan backward kernel.

    Returns gradients (dx, ddt, dA, dB, dC, dD, dz, ddt_bias,
    dinitial_states). The custom-op schema requires every output to be a
    tensor, so gradients for optional inputs that were not supplied come
    back as size-0 placeholders instead of None; the autograd bridge maps
    them back to None.
    """
    # dfinal_states is always None here: no gradient flows in through the
    # final states in this integration.
    dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=None, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit)
    return (
        dx,
        ddt,
        dA,
        dB,
        dC,
        dD if dD is not None else dx.new_empty(0),
        dz if dz is not None else dx.new_empty(0),
        ddt_bias if ddt_bias is not None else dx.new_empty(0),
        dinitial_states if dinitial_states is not None else dx.new_empty(0)
    )
99
+
100
@ssm_chunk_scan_combined_bwd.register_fake
def _ssm_chunk_scan_combined_bwd_fake(
    dout: torch.Tensor,
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    out: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
):
    """Fake (meta) implementation of the backward op for torch.compile.

    Must mirror the real op's output metadata: the registered schema is a
    9-tuple of tensors, and the real op substitutes ``dx.new_empty(0)``
    placeholders for gradients of absent optional inputs — it never returns
    ``None``. (The previous version returned ``None`` in those slots, which
    mismatches the schema.)
    """
    return (
        torch.empty_like(x),
        torch.empty_like(dt),
        torch.empty_like(A),
        torch.empty_like(B),
        torch.empty_like(C),
        # Placeholders for optional-input gradients, matching the real op.
        torch.empty_like(D) if D is not None else x.new_empty(0),
        torch.empty_like(z) if z is not None else x.new_empty(0),
        torch.empty_like(dt_bias) if dt_bias is not None else x.new_empty(0),
        torch.empty_like(initial_states) if initial_states is not None else x.new_empty(0),
    )
129
+
130
+
131
def ssm_chunk_scan_combined_setup_context(ctx, inputs, output):
    """Stash the tensors and hyper-parameters the backward bridge needs."""
    (x, dt, A, B, C, chunk_size, D, z, dt_bias,
     initial_states, seq_idx, _cu_seqlens, dt_softplus, dt_limit) = inputs
    out, out_x, _state_varlen = output

    # When z is supplied, the pre-gate output (out_x) is what backward
    # differentiates through; otherwise the plain output is saved.
    saved_out = out_x if z is not None else out
    ctx.save_for_backward(saved_out, x, dt, A, B, C, D, z, dt_bias,
                          initial_states, seq_idx)

    # Non-tensor hyper-parameters ride on ctx attributes.
    ctx.dt_softplus = dt_softplus
    ctx.chunk_size = chunk_size
    ctx.dt_limit = dt_limit
139
+
140
def ssm_chunk_scan_combined_bridge(ctx, dout, dout_x, dout_state_varlen):
    """Autograd bridge: run the backward custom op and lay out one gradient
    per forward input, with None for inputs that take no gradient."""
    out, x, dt, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors

    grads = ssm_chunk_scan_combined_bwd(
        dout, x, dt, A, B, C, out, ctx.chunk_size,
        D=D, z=z, dt_bias=dt_bias, initial_states=initial_states,
        seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit,
    )
    dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = grads

    # Slot order mirrors the forward op's inputs:
    # (x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states,
    #  seq_idx, cu_seqlens, dt_softplus, dt_limit).
    # Placeholder gradients for optional inputs that were absent map back to None.
    return (
        dx,
        ddt,
        dA,
        dB,
        dC,
        None,  # chunk_size (int, non-differentiable)
        dD if D is not None else None,
        dz if z is not None else None,
        ddt_bias if dt_bias is not None else None,
        dinitial_states if initial_states is not None else None,
        None,  # seq_idx
        None,  # cu_seqlens
        None,  # dt_softplus
        None,  # dt_limit
    )
161
+
162
# Register custom autograd function
# Wires the bridge (backward) and setup_context (what to save) into the
# forward custom op so gradients flow through torch.compile'd graphs.
torch.library.register_autograd(
    "mamba_ssm::ssm_chunk_scan_combined_fwd",
    ssm_chunk_scan_combined_bridge,
    setup_context=ssm_chunk_scan_combined_setup_context,
)
168
+
169
def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
    """Chunked selective-scan (Mamba-2 SSD), via the compilable custom op.

    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        chunk_size: int
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen)
        cu_seqlens: (num_sequences + 1) or None
        dt_softplus: Whether to apply softplus to dt
        dt_limit: (low, high) clamp applied to dt
    Return:
        out: (batch, seqlen, nheads, headdim); when ``cu_seqlens`` is given,
        a ``(out, varlen_states)`` pair is returned instead.
    """
    out, _, varlen_states = ssm_chunk_scan_combined_fwd(
        x, dt, A, B, C, chunk_size,
        D=D, z=z, dt_bias=dt_bias, initial_states=initial_states,
        seq_idx=seq_idx, cu_seqlens=cu_seqlens,
        dt_softplus=dt_softplus, dt_limit=dt_limit,
    )
    # The middle output (out_x) is internal to the backward pass.
    return (out, varlen_states) if cu_seqlens is not None else out
193
+
194
if __name__ == "__main__":
    # Smoke test: compare eager, torch.compile'd, and reference mamba_ssm
    # implementations on small random CUDA tensors (requires a GPU).
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as mamba_chunk_scan_combined_ref

    # Fixed seeds so the three runs below see identical inputs.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Shapes: batch=2, seqlen=3, nheads=4, headdim=5 (see docstring above).
    x = torch.randn(2, 3, 4, 5).cuda()
    dt = torch.randn(2, 3, 4).cuda()
    A = torch.randn(4).cuda()
    B = torch.randn(2, 3, 4, 5).cuda()
    C = torch.randn(2, 3, 4, 5).cuda()
    chunk_size = 2
    D = torch.randn(4, 5).cuda()
    z = torch.randn(2, 3, 4, 5).cuda()
    dt_bias = torch.randn(4).cuda()

    # Eager custom-op path.
    out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)

    print(out.min(), out.max(), out.mean(), out.std())

    # Compiled path — summary statistics should match the eager run.
    compiled_mamba_chunk_scan_combined = torch.compile(mamba_chunk_scan_combined)
    out = compiled_mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)

    print(out.min(), out.max(), out.mean(), out.std())

    # Upstream mamba_ssm reference for comparison.
    out_ref = mamba_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)

    print(out_ref.min(), out_ref.max(), out_ref.mean(), out_ref.std())
222
+
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "199999": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "200018": {
13
+ "content": "<|endofprompt|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ }
20
+ },
21
+ "bos_token": "<|endoftext|>",
22
+ "clean_up_tokenization_spaces": false,
23
+ "eos_token": "<|endoftext|>",
24
+ "model_max_length": 128000,
25
+ "tokenizer_class": "GPT2Tokenizer",
26
+ "unk_token": "<|endoftext|>"
27
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff