YongganFu commited on
Commit
5b6cbb5
·
verified ·
1 Parent(s): f6866e2

Upload JambaForCausalLM

Browse files
config.json ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "JambaForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "attn_hidden_size": -1,
7
+ "attn_implementation": "flash_attention_2",
8
+ "attn_implementation_new": "flash_attention_2",
9
+ "attn_layer_offset": 4,
10
+ "attn_layer_period": 8,
11
+ "attn_reuse_every_i_layer": -1,
12
+ "auto_map": {
13
+ "AutoConfig": "configuration_jamba.JambaConfig",
14
+ "AutoModelForCausalLM": "modeling_jamba.JambaForCausalLM"
15
+ },
16
+ "bos_token_id": 1,
17
+ "calc_logits_for_entire_prompt": false,
18
+ "compact_gating": false,
19
+ "compute_attn_mat": false,
20
+ "d_conv": 4,
21
+ "dense_public_ffn_structure": false,
22
+ "double_v_dim": false,
23
+ "enable_mod": false,
24
+ "eos_token_id": 2,
25
+ "expert_layer_offset": 1,
26
+ "expert_layer_period": 2,
27
+ "ffn_expand_ratio": 3,
28
+ "ffn_reuse_every_i_layer": -1,
29
+ "ffn_sharing_config": null,
30
+ "fully_parallel_jamba": false,
31
+ "fused_multihead_config": null,
32
+ "global_attn_idx": [],
33
+ "gradient_checkpoint_layer": null,
34
+ "hash_grid_config": null,
35
+ "hash_grid_config_mlp": null,
36
+ "hidden_act": "silu",
37
+ "hidden_size": 3072,
38
+ "hybrid_block_indices": [],
39
+ "hybrid_decoder_layer": "mamba",
40
+ "initializer_range": 0.02,
41
+ "intermediate_size": 0,
42
+ "kq_head_dim": -1,
43
+ "kq_norm": "none",
44
+ "kv_reuse_every_i_layer": -1,
45
+ "kv_reuse_group": null,
46
+ "kv_weight_reuse": false,
47
+ "layer_type": [
48
+ "m",
49
+ "a",
50
+ "m",
51
+ "a",
52
+ "a",
53
+ "a",
54
+ "m",
55
+ "a",
56
+ "m",
57
+ "a",
58
+ "m",
59
+ "a",
60
+ "a",
61
+ "a",
62
+ "m",
63
+ "a",
64
+ "m",
65
+ "a",
66
+ "m",
67
+ "a",
68
+ "a",
69
+ "a",
70
+ "m",
71
+ "a",
72
+ "m",
73
+ "a",
74
+ "m",
75
+ "a",
76
+ "m",
77
+ "a",
78
+ "m",
79
+ "a",
80
+ "m",
81
+ "a",
82
+ "m",
83
+ "a"
84
+ ],
85
+ "layer_types": [
86
+ "deltanet",
87
+ "f",
88
+ "m2",
89
+ "f",
90
+ "a",
91
+ "f",
92
+ "m2",
93
+ "f",
94
+ "deltanet",
95
+ "f",
96
+ "m2",
97
+ "f",
98
+ "a",
99
+ "f",
100
+ "m2",
101
+ "f",
102
+ "deltanet",
103
+ "f",
104
+ "m2",
105
+ "f",
106
+ "a",
107
+ "f",
108
+ "m2",
109
+ "f",
110
+ "deltanet",
111
+ "f",
112
+ "m2",
113
+ "f",
114
+ "deltanet",
115
+ "f",
116
+ "m2",
117
+ "f",
118
+ "deltanet",
119
+ "f",
120
+ "m2",
121
+ "f"
122
+ ],
123
+ "layerwise_memory_token": false,
124
+ "local_expand_ratio": 1,
125
+ "local_global_dual_branch": false,
126
+ "local_global_dual_branch_merge_op": "mean",
127
+ "lookback_mode": "",
128
+ "macro_arch": "",
129
+ "mamba2_headdim": 64,
130
+ "mamba_attnaug_config": null,
131
+ "mamba_conv_bias": true,
132
+ "mamba_d_conv": 4,
133
+ "mamba_d_state": 16,
134
+ "mamba_dt_rank": 192,
135
+ "mamba_expand": 2,
136
+ "mamba_inner_layernorms": true,
137
+ "mamba_latent_size": null,
138
+ "mamba_multihead_config": null,
139
+ "mamba_proj_bias": false,
140
+ "mamba_reuse_every_i_layer": -1,
141
+ "max_position_embeddings": 2048,
142
+ "memory_tokens_interspersed_every": 0,
143
+ "mlp_hidden_act": "silu",
144
+ "mod_topk": 2,
145
+ "model_type": "jamba",
146
+ "moe_config": null,
147
+ "nGPT_config": {
148
+ "extra_grad": false,
149
+ "gate_scaling": false,
150
+ "init_norm": false,
151
+ "learned_scaling": false,
152
+ "norm_bc": false,
153
+ "norm_gating": false,
154
+ "norm_ssm_input": false,
155
+ "post_norm": false,
156
+ "qk_norm": false,
157
+ "weight_norm": true
158
+ },
159
+ "nGPT_mode": null,
160
+ "new_seq_length": 2048,
161
+ "no_dt_bias": false,
162
+ "num_attention_heads": 24,
163
+ "num_attn_per_ffn": 3,
164
+ "num_experts": 1,
165
+ "num_experts_per_tok": 1,
166
+ "num_ffn": 1,
167
+ "num_hidden_layers": 36,
168
+ "num_key_value_heads": 6,
169
+ "num_mamba": 1,
170
+ "num_memory_tokens": 256,
171
+ "orig_max_position_embeddings": 2048,
172
+ "other_args": null,
173
+ "output_router_logits": false,
174
+ "pad_token_id": 0,
175
+ "public_ffn_structure": false,
176
+ "pure_linear_attn": false,
177
+ "reduce_attn_ratio": 0.5,
178
+ "reduce_method": "mean",
179
+ "repeat_ffn": null,
180
+ "rms_norm_eps": 1e-06,
181
+ "rope": true,
182
+ "rope_theta": 10000.0,
183
+ "rope_type": null,
184
+ "router_aux_loss_coef": 0.001,
185
+ "save_input_output": false,
186
+ "self_attn_type": null,
187
+ "seq_length": 2048,
188
+ "sequential_jamba": false,
189
+ "share_kv": false,
190
+ "shared_module_attn": "",
191
+ "shared_module_mamba": "",
192
+ "sliding_window": null,
193
+ "sliding_window_size": null,
194
+ "supernet_config": null,
195
+ "swa_full_head": false,
196
+ "tie_word_embeddings": true,
197
+ "torch_dtype": "bfloat16",
198
+ "transformers_version": "4.48.2",
199
+ "use_cache": false,
200
+ "use_mamba2": false,
201
+ "use_mamba_kernels": true,
202
+ "use_nGPT": true,
203
+ "use_nemotron5": false,
204
+ "v_head_dim": -1,
205
+ "visual_attn": false,
206
+ "visual_entropy": false,
207
+ "vocab_size": 131072
208
+ }
configuration_jamba.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Jamba model configuration"""
16
+ import math
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class JambaConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a
28
+ Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of the jamba-small architecture.
30
+
31
+ [ai21labs/jamba-small](https://huggingface.co/ai21labs/Jamba-v0.1)
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+
37
+ Args:
38
+ vocab_size (`int`, *optional*, defaults to 65536):
39
+ Vocabulary size of the Jamba model. Defines the number of different tokens that can be represented by the
40
+ `inputs_ids` passed when calling [`JambaModel`]
41
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
42
+ Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
43
+ model has a output word embedding layer.
44
+ hidden_size (`int`, *optional*, defaults to 4096):
45
+ Dimension of the hidden representations.
46
+ intermediate_size (`int`, *optional*, defaults to 14336):
47
+ Dimension of the MLP representations.
48
+ num_hidden_layers (`int`, *optional*, defaults to 32):
49
+ Number of hidden layers in the Transformer encoder.
50
+ num_attention_heads (`int`, *optional*, defaults to 32):
51
+ Number of attention heads for each attention layer in the Transformer encoder.
52
+ num_key_value_heads (`int`, *optional*, defaults to 8):
53
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
54
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
55
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
56
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
57
+ by meanpooling all the original heads within that group. For more details checkout [this
58
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
59
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
60
+ The non-linear activation function (function or string) in the decoder.
61
+ initializer_range (`float`, *optional*, defaults to 0.02):
62
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
63
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
64
+ The epsilon used by the rms normalization layers.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
67
+ relevant if `config.is_decoder=True`.
68
+ calc_logits_for_entire_prompt (`bool`, *optional*, defaults to `False`):
69
+ Whether or not to calculate logits for entire prompt during generation. If `False`, only the logits of the
70
+ last prompt token will be calculated, which are the only logits needed for generation. For long sequences,
71
+ the logits for the entire sequence may use a lot of memory so setting `calc_logits_for_entire_prompt=False`
72
+ will reduce memory footprint significantly.
73
+ Note: some generation features may not be available if this is set to `False`.
74
+ output_router_logits (`bool`, *optional*, defaults to `False`):
75
+ Whether or not the router logits should be returned by the model. Enabling this will also
76
+ allow the model to output the auxiliary loss. See [here]() for more details
77
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
78
+ The aux loss factor for the total loss.
79
+ pad_token_id (`int`, *optional*, defaults to 0):
80
+ The id of the padding token.
81
+ bos_token_id (`int`, *optional*, defaults to 1):
82
+ The id of the "beginning-of-sequence" token.
83
+ eos_token_id (`int`, *optional*, defaults to 2):
84
+ The id of the "end-of-sequence" token.
85
+ sliding_window (`int`, *optional*):
86
+ Sliding window attention window size. If not specified, will default to `None`.
87
+ n_ctx (`int`, *optional*, defaults to 262144):
88
+ This value doesn't have any real effect. The maximum sequence length that this model is intended to be
89
+ used with. It can be used with longer sequences, but performance may degrade.
90
+ attention_dropout (`float`, *optional*, defaults to 0.0):
91
+ The dropout ratio for the attention probabilities.
92
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
93
+ The number of experts to root per-token, can be also interpreted as the `top-p` routing
94
+ parameter
95
+ num_experts (`int`, *optional*, defaults to 16):
96
+ Number of experts per Sparse MLP layer.
97
+ expert_layer_period (`int`, *optional*, defaults to 2):
98
+ Once in this many layers, we will have an expert layer
99
+ expert_layer_offset (`int`, *optional*, defaults to 1):
100
+ The first layer index that contains an expert mlp layer
101
+ attn_layer_period (`int`, *optional*, defaults to 8):
102
+ Once in this many layers, we will have a vanilla attention layer
103
+ attn_layer_offset (`int`, *optional*, defaults to 4):
104
+ The first layer index that contains a vanilla attention mlp layer
105
+ use_mamba_kernels (`bool`, *optional*, defaults to `True`):
106
+ Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
107
+ `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if
108
+ `True` and kernels are not available
109
+ mamba_d_state (`int`, *optional*, defaults to 16):
110
+ The dimension the mamba state space latents
111
+ mamba_d_conv (`int`, *optional*, defaults to 4):
112
+ The size of the mamba convolution kernel
113
+ mamba_expand (`int`, *optional*, defaults to 2):
114
+ Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
115
+ mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
116
+ Rank of the the mamba discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
117
+ mamba_conv_bias (`bool`, *optional*, defaults to `True`):
118
+ Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
119
+ mamba_proj_bias (`bool`, *optional*, defaults to `False`):
120
+ Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
121
+ mamba_inner_layernorms (`bool`, *optional*, defaults to `True`):
122
+ Flag indicating whether or not to apply layernorms to internal mamba activations
123
+
124
+ """
125
+
126
+ model_type = "jamba"
127
+ keys_to_ignore_at_inference = ["past_key_values"]
128
+
129
+ def __init__(
130
+ self,
131
+ vocab_size=65536,
132
+ tie_word_embeddings=False,
133
+ hidden_size=4096,
134
+ intermediate_size=14336,
135
+ num_hidden_layers=32,
136
+ num_attention_heads=32,
137
+ num_key_value_heads=8,
138
+ hidden_act="silu",
139
+ initializer_range=0.02,
140
+ rms_norm_eps=1e-6,
141
+ use_cache=True,
142
+ calc_logits_for_entire_prompt=False,
143
+ output_router_logits=False,
144
+ router_aux_loss_coef=0.001,
145
+ pad_token_id=0,
146
+ bos_token_id=1,
147
+ eos_token_id=2,
148
+ sliding_window=None,
149
+ max_position_embeddings=262144,
150
+ orig_max_position_embeddings=None,
151
+ attention_dropout=0.0,
152
+ num_experts_per_tok=2,
153
+ num_experts=16,
154
+ expert_layer_period=2,
155
+ expert_layer_offset=1,
156
+ attn_layer_period=8,
157
+ attn_layer_offset=4,
158
+ use_mamba_kernels=True,
159
+ mamba_d_state=16,
160
+ mamba_d_conv=4,
161
+ mamba_expand=2,
162
+ mamba_dt_rank="auto",
163
+ mamba_conv_bias=True,
164
+ mamba_proj_bias=False,
165
+ mamba_inner_layernorms=True,
166
+
167
+ hybrid_decoder_layer='mamba',
168
+ share_kv=False,
169
+ double_v_dim=False,
170
+ compact_gating=False,
171
+ kv_reuse_every_i_layer=-1,
172
+ kv_reuse_group=None,
173
+ kv_weight_reuse=False,
174
+
175
+ num_ffn=1,
176
+ ffn_reuse_every_i_layer=-1,
177
+ attn_reuse_every_i_layer=-1,
178
+ mamba_reuse_every_i_layer=-1,
179
+
180
+ macro_arch='',
181
+
182
+ lookback_mode='',
183
+
184
+ shared_module_attn='',
185
+ shared_module_mamba='',
186
+
187
+ ffn_sharing_config=None,
188
+
189
+ sliding_window_size=None,
190
+ global_attn_idx=None,
191
+
192
+ num_mamba=1,
193
+ mamba_latent_size=None,
194
+
195
+ public_ffn_structure=False,
196
+ num_attn_per_ffn=3,
197
+ dense_public_ffn_structure=False,
198
+
199
+ local_global_dual_branch=False,
200
+ local_expand_ratio=1,
201
+ local_global_dual_branch_merge_op='mean',
202
+
203
+ mamba_multihead_config=None,
204
+
205
+ moe_config=None,
206
+
207
+ enable_mod=False,
208
+ mod_topk=2,
209
+
210
+ sequential_jamba=False,
211
+ fully_parallel_jamba=False,
212
+
213
+ attn_implementation_new='sdpa',
214
+
215
+ fused_multihead_config=None,
216
+
217
+ compute_attn_mat=False,
218
+ visual_attn=False,
219
+ save_input_output=False,
220
+
221
+ use_mamba2=False,
222
+ mamba2_headdim=64,
223
+
224
+ swa_full_head=False,
225
+
226
+ gradient_checkpoint_layer=None,
227
+
228
+ rope_type=None,
229
+
230
+ visual_entropy=False,
231
+
232
+ use_nemotron5=False,
233
+
234
+ use_nGPT=False,
235
+ nGPT_mode=None,
236
+ nGPT_config=None,
237
+
238
+ mamba_attnaug_config=None,
239
+
240
+ no_dt_bias=False,
241
+
242
+ hash_grid_config=None,
243
+
244
+ hash_grid_config_mlp=None,
245
+
246
+ repeat_ffn=None,
247
+
248
+ layer_types=None,
249
+
250
+ supernet_config=None,
251
+
252
+ pure_linear_attn=False,
253
+
254
+ self_attn_type=None,
255
+
256
+ other_args=None,
257
+
258
+ ffn_expand_ratio=None,
259
+
260
+ d_conv=4,
261
+
262
+ layerwise_memory_token=False,
263
+
264
+ **kwargs,
265
+ ):
266
+ self.vocab_size = vocab_size
267
+ self.tie_word_embeddings = tie_word_embeddings
268
+ self.hidden_size = hidden_size
269
+ self.intermediate_size = intermediate_size
270
+ self.num_hidden_layers = num_hidden_layers
271
+ self.num_attention_heads = num_attention_heads
272
+ self.sliding_window = sliding_window
273
+ self.max_position_embeddings = max_position_embeddings
274
+ self.orig_max_position_embeddings = orig_max_position_embeddings
275
+ self.attention_dropout = attention_dropout
276
+
277
+ # for backward compatibility
278
+ if num_key_value_heads is None:
279
+ num_key_value_heads = num_attention_heads
280
+
281
+ self.num_key_value_heads = num_key_value_heads
282
+ self.hidden_act = hidden_act
283
+ self.initializer_range = initializer_range
284
+ self.rms_norm_eps = rms_norm_eps
285
+
286
+ self.use_cache = use_cache
287
+ self.calc_logits_for_entire_prompt = calc_logits_for_entire_prompt
288
+ self.output_router_logits = output_router_logits
289
+ self.router_aux_loss_coef = router_aux_loss_coef
290
+
291
+ self.num_experts_per_tok = num_experts_per_tok
292
+ self.num_experts = num_experts
293
+ self.expert_layer_period = expert_layer_period
294
+ self.expert_layer_offset = expert_layer_offset
295
+ self.attn_layer_period = attn_layer_period
296
+ self.attn_layer_offset = attn_layer_offset
297
+
298
+ self.use_mamba_kernels = use_mamba_kernels
299
+ self.mamba_d_state = mamba_d_state
300
+ self.mamba_d_conv = mamba_d_conv
301
+ self.mamba_expand = mamba_expand
302
+ self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
303
+ self.mamba_conv_bias = mamba_conv_bias
304
+ self.mamba_proj_bias = mamba_proj_bias
305
+ self.mamba_inner_layernorms = mamba_inner_layernorms
306
+
307
+ # added by Xin
308
+ self.reduce_method = kwargs.pop("reduce_method", "mean")
309
+ self.hybrid_block_indices = kwargs.pop("hybrid_block_indices", [])
310
+ self.reduce_attn_ratio = kwargs.pop("reduce_attn_ratio", 0.5)
311
+ self.attn_hidden_size = kwargs.pop("attn_hidden_size", -1)
312
+ self.kq_head_dim = kwargs.pop("kq_head_dim", -1)
313
+ self.v_head_dim = kwargs.pop("v_head_dim", -1)
314
+ self.kq_norm = kwargs.pop("kq_norm", None)
315
+ self.rope = kwargs.pop("rope", False)
316
+ self.rope_theta = kwargs.pop("rope_theta", 10000.0)
317
+ self.num_memory_tokens = kwargs.pop("num_memory_tokens", 0)
318
+ self.memory_tokens_interspersed_every = kwargs.pop("memory_tokens_interspersed_every", 0)
319
+
320
+ #! adhoc change
321
+ self.new_seq_length = 2048
322
+ self.visual_entropy = kwargs.pop("visual_entropy", False)
323
+
324
+ self.hybrid_decoder_layer = hybrid_decoder_layer
325
+ self.share_kv = share_kv
326
+ self.double_v_dim = double_v_dim
327
+ self.compact_gating = compact_gating
328
+ self.kv_reuse_every_i_layer = kv_reuse_every_i_layer
329
+ self.kv_reuse_group = kv_reuse_group
330
+ self.kv_weight_reuse = kv_weight_reuse
331
+
332
+ self.num_ffn = num_ffn
333
+ self.ffn_reuse_every_i_layer = ffn_reuse_every_i_layer
334
+ self.attn_reuse_every_i_layer = attn_reuse_every_i_layer
335
+ self.mamba_reuse_every_i_layer = mamba_reuse_every_i_layer
336
+
337
+ self.macro_arch = macro_arch
338
+
339
+ self.lookback_mode = lookback_mode
340
+
341
+ self.shared_module_attn = shared_module_attn
342
+ self.shared_module_mamba = shared_module_mamba
343
+
344
+ self.ffn_sharing_config = ffn_sharing_config
345
+
346
+ self.sliding_window_size = sliding_window_size
347
+ self.global_attn_idx = global_attn_idx
348
+
349
+ self.num_mamba = num_mamba
350
+
351
+ self.mamba_latent_size = mamba_latent_size
352
+
353
+ self.public_ffn_structure = public_ffn_structure
354
+ self.num_attn_per_ffn = num_attn_per_ffn
355
+ self.dense_public_ffn_structure = dense_public_ffn_structure
356
+
357
+ self.local_global_dual_branch = local_global_dual_branch
358
+ self.local_expand_ratio = local_expand_ratio
359
+ self.local_global_dual_branch_merge_op = local_global_dual_branch_merge_op
360
+
361
+ self.mamba_multihead_config = mamba_multihead_config
362
+
363
+ self.moe_config = moe_config
364
+
365
+ self.enable_mod = enable_mod
366
+ self.mod_topk = mod_topk
367
+
368
+ self.sequential_jamba = sequential_jamba
369
+ self.fully_parallel_jamba = fully_parallel_jamba
370
+
371
+ self.attn_implementation_new = attn_implementation_new
372
+
373
+ self.fused_multihead_config = fused_multihead_config
374
+
375
+ self.compute_attn_mat = compute_attn_mat
376
+ self.visual_attn = visual_attn
377
+ self.save_input_output = save_input_output
378
+
379
+ self.use_mamba2 = use_mamba2
380
+ self.mamba2_headdim = mamba2_headdim
381
+
382
+ self.swa_full_head = swa_full_head
383
+
384
+ self.gradient_checkpoint_layer = gradient_checkpoint_layer
385
+
386
+ self.rope_type = rope_type
387
+
388
+ self.visual_entropy = visual_entropy
389
+
390
+ self.use_nemotron5 = use_nemotron5
391
+
392
+ self.use_nGPT = use_nGPT
393
+ self.nGPT_mode = nGPT_mode
394
+
395
+ self.mamba_attnaug_config = mamba_attnaug_config
396
+
397
+ self.no_dt_bias = no_dt_bias
398
+
399
+ self.nGPT_config = nGPT_config
400
+
401
+ self.hash_grid_config = hash_grid_config
402
+
403
+ self.hash_grid_config_mlp = hash_grid_config_mlp
404
+
405
+ self.repeat_ffn = repeat_ffn
406
+
407
+ self.layer_types = layer_types
408
+
409
+ self.supernet_config = supernet_config
410
+
411
+ self.pure_linear_attn = pure_linear_attn
412
+
413
+ self.self_attn_type = self_attn_type
414
+
415
+ self.other_args = other_args
416
+
417
+ self.ffn_expand_ratio = ffn_expand_ratio
418
+
419
+ self.d_conv = d_conv
420
+
421
+ self.layerwise_memory_token = layerwise_memory_token
422
+
423
+ super().__init__(
424
+ pad_token_id=pad_token_id,
425
+ bos_token_id=bos_token_id,
426
+ eos_token_id=eos_token_id,
427
+ tie_word_embeddings=tie_word_embeddings,
428
+ **kwargs,
429
+ )
delta_net.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from torch.nn import functional as F
12
+
13
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
14
+ from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
15
+
16
+ if TYPE_CHECKING:
17
+ from transformers.processing_utils import Unpack
18
+
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ def elu_p1(x):
23
+ return (F.elu(x, 1., False) + 1.).to(x)
24
+
25
+
26
+ def sum_norm(x):
27
+ return (x / x.sum(-1, keepdim=True)).to(x)
28
+
29
+
30
+ class DeltaNet(nn.Module):
31
+ r"""
32
+ The layer implementaion for [Parallelizing Linear Transformers with the Delta Rule over Sequence Length](https://arxiv.org/abs/2406.06484). # noqa:
33
+ DeltaNet was originally proposed in [Linear Transformers Are Secretly Fast Weight Programmers](https://arxiv.org/abs/2102.11174). # noqa
34
+
35
+ Args:
36
+ mode (str, Optional):
37
+ Which DeltaNet kernel to use.
38
+ Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
39
+ Default: `chunk`.
40
+ hidden_size (int, Optional):
41
+ The hidden size of the input. Default: 1024.
42
+ expand_k (float, Optional):
43
+ The expansion ratio for the key dim. Default: 1.0.
44
+ expand_v (float, Optional):
45
+ The expansion ratio for the value dim. Default: 1.0.
46
+ num_heads (int, Optional):
47
+ The number of heads. Default: 4.
48
+ use_beta (bool, Optional):
49
+ Whether to use beta. Default: `True`.
50
+ use_gate (bool, Optional):
51
+ Whether to use output gate. Default: `False`.
52
+ use_short_conv (bool, Optional):
53
+ Whether to use short convolutions. Default: `True`.
54
+ conv_size (int, Optional):
55
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
56
+ conv_bias (bool, Optional):
57
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
58
+ allow_neg_eigval (bool, Optional):
59
+ Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2.
60
+ See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537)
61
+ layer_idx (int, Optional):
62
+ The index of the layer. Default: None.
63
+ norm_eps (float, Optional):
64
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
65
+ qk_activation (str, Optional):
66
+ The activation function for the query and key. Default: `silu`.
67
+ qk_norm (str, Optional):
68
+ The normalization method for the query and key. Default: `l2`.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ mode: str = 'chunk',
74
+ d_model: int = None,
75
+ hidden_size: int = 1024,
76
+ expand_k: float = 1.0,
77
+ expand_v: float = 1.0,
78
+ num_heads: int = 4,
79
+ use_beta: bool = True,
80
+ use_gate: bool = False,
81
+ use_short_conv: bool = True,
82
+ conv_size: int = 4,
83
+ conv_bias: bool = False,
84
+ allow_neg_eigval: bool = False,
85
+ layer_idx: int = None,
86
+ qk_activation: str = 'silu',
87
+ qk_norm: str = 'l2',
88
+ norm_eps: float = 1e-5,
89
+ config = None,
90
+ **kwargs
91
+ ) -> DeltaNet:
92
+ super().__init__()
93
+
94
+ self.mode = mode
95
+ self.qk_activation = qk_activation
96
+ self.qk_norm = qk_norm
97
+
98
+ assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
99
+ assert self.qk_norm in ['l2', 'sum']
100
+
101
+ self.config = config
102
+ if self.config is not None and self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']:
103
+ self.weight_norm = True
104
+ else:
105
+ self.weight_norm = False
106
+
107
+ if d_model is not None:
108
+ hidden_size = d_model
109
+ self.hidden_size = hidden_size
110
+ self.expand_k = expand_k
111
+ self.expand_v = expand_v
112
+ self.num_heads = num_heads
113
+ self.use_gate = use_gate
114
+ self.use_short_conv = use_short_conv
115
+ self.conv_size = conv_size
116
+ self.conv_bias = conv_bias
117
+ self.allow_neg_eigval = allow_neg_eigval
118
+
119
+ self.key_dim = int(hidden_size * expand_k)
120
+ self.value_dim = int(hidden_size * expand_v)
121
+ self.head_k_dim = self.key_dim // num_heads
122
+ self.head_v_dim = self.value_dim // num_heads
123
+ self.layer_idx = layer_idx
124
+
125
+ self.silu = nn.SiLU()
126
+ if mode == 'fused_chunk':
127
+ raise NotImplementedError("fused_chunk_delta_rule is now deprecated. Please use `chunk_delta_rule` instead.")
128
+ assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
129
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
130
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
131
+
132
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
133
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
134
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
135
+
136
+ self.use_beta = use_beta
137
+ if self.use_beta:
138
+ self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
139
+ if use_short_conv:
140
+ self.conv_size = conv_size
141
+ self.q_conv1d = ShortConvolution(
142
+ hidden_size=self.key_dim,
143
+ kernel_size=conv_size,
144
+ activation='silu' if qk_activation == 'silu' else None
145
+ )
146
+ self.k_conv1d = ShortConvolution(
147
+ hidden_size=self.key_dim,
148
+ kernel_size=conv_size,
149
+ activation='silu' if qk_activation == 'silu' else None
150
+ )
151
+ self.v_conv1d = ShortConvolution(
152
+ hidden_size=self.value_dim,
153
+ kernel_size=conv_size,
154
+ activation='silu'
155
+ )
156
+ else:
157
+ raise UserWarning(
158
+ "ShortConvolution is crucial to the performance. "
159
+ "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
160
+ )
161
+ if use_gate:
162
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
163
+ self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
164
+ else:
165
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
166
+
167
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
168
+
169
+ self.apply(self._initialize_weights)
170
+
171
+ def _initialize_weights(self, module: nn.Module):
172
+ if getattr(module, "_is_hf_initialized", False):
173
+ return
174
+ if isinstance(module, nn.Linear):
175
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
176
+ if module.bias is not None:
177
+ nn.init.zeros_(module.bias)
178
+ module._is_hf_initialized = True
179
+
180
+ def forward(
181
+ self,
182
+ hidden_states: torch.Tensor,
183
+ attention_mask: Optional[torch.Tensor] = None,
184
+ past_key_values: Optional[Cache] = None,
185
+ use_cache: Optional[bool] = False,
186
+ output_attentions: Optional[bool] = False,
187
+ **kwargs: Unpack[Dict]
188
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
189
+ if attention_mask is not None:
190
+ assert len(attention_mask.shape) == 2, (
191
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
192
+ "for padding purposes (0 indicating padding). "
193
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
194
+ )
195
+
196
+ # change to inference mode.
197
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
198
+
199
+ last_state = None
200
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
201
+ last_state = past_key_values[self.layer_idx]
202
+
203
+ if self.use_short_conv:
204
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
205
+ if last_state is not None:
206
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
207
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
208
+ position_ids = kwargs.get('position_ids', None)
209
+
210
+ q = self.q_proj(hidden_states)
211
+ if self.weight_norm:
212
+ q = q / self.q_proj.weight.norm(p=2, dim=1)
213
+
214
+ q, conv_state_q = self.q_conv1d(x=q,
215
+ mask=conv_mask,
216
+ cache=conv_state_q,
217
+ output_final_state=use_cache,
218
+ seq_idx=position_ids)
219
+
220
+ k = self.k_proj(hidden_states)
221
+ if self.weight_norm:
222
+ k = k / self.k_proj.weight.norm(p=2, dim=1)
223
+ k, conv_state_k = self.k_conv1d(x=k,
224
+ mask=conv_mask,
225
+ cache=conv_state_k,
226
+ output_final_state=use_cache,
227
+ seq_idx=position_ids)
228
+
229
+ v = self.v_proj(hidden_states)
230
+ if self.weight_norm:
231
+ v = v / self.v_proj.weight.norm(p=2, dim=1)
232
+ v, conv_state_v = self.v_conv1d(x=v,
233
+ mask=conv_mask,
234
+ cache=conv_state_v,
235
+ output_final_state=use_cache,
236
+ seq_idx=position_ids)
237
+ else:
238
+ q = self.q_proj(hidden_states)
239
+ k = self.k_proj(hidden_states)
240
+ v = self.v_proj(hidden_states)
241
+
242
+ if self.weight_norm:
243
+ q = q / self.q_proj.weight.norm(p=2, dim=1)
244
+ k = k / self.k_proj.weight.norm(p=2, dim=1)
245
+ v = v / self.v_proj.weight.norm(p=2, dim=1)
246
+
247
+ if self.qk_activation == 'silu':
248
+ q, k = self.silu(q), self.silu(k)
249
+
250
+ v = self.silu(v)
251
+
252
+ q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
253
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
254
+ if self.qk_activation != 'silu':
255
+ if self.qk_activation == 'relu':
256
+ q, k = q.relu(), k.relu()
257
+ elif self.qk_activation == 'elu':
258
+ q, k = elu_p1(q), elu_p1(k)
259
+ elif self.qk_activation == 'identity':
260
+ pass
261
+ else:
262
+ raise NotImplementedError
263
+
264
+ if self.qk_norm == 'sum':
265
+ q = sum_norm(q).to(q)
266
+ k = sum_norm(k).to(k)
267
+
268
+ if self.use_beta:
269
+ beta = self.b_proj(hidden_states)
270
+
271
+ if self.weight_norm:
272
+ beta = beta / self.b_proj.weight.norm(p=2, dim=1)
273
+
274
+ beta = beta.sigmoid()
275
+ else:
276
+ beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
277
+
278
+ if self.allow_neg_eigval:
279
+ beta = beta * 2.
280
+
281
+ # dealing with padding
282
+ if attention_mask is not None:
283
+ beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
284
+
285
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
286
+ cu_seqlens = kwargs.get('cu_seqlens', None)
287
+ if mode == 'fused_recurrent':
288
+ o, recurrent_state = fused_recurrent_delta_rule(
289
+ q=q,
290
+ k=k,
291
+ v=v,
292
+ beta=beta,
293
+ initial_state=recurrent_state,
294
+ output_final_state=use_cache,
295
+ cu_seqlens=cu_seqlens,
296
+ head_first=False,
297
+ use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
298
+ )
299
+ elif mode == 'chunk':
300
+ o, recurrent_state = chunk_delta_rule(
301
+ q=q,
302
+ k=k,
303
+ v=v,
304
+ beta=beta,
305
+ initial_state=recurrent_state,
306
+ output_final_state=use_cache,
307
+ cu_seqlens=cu_seqlens,
308
+ head_first=False,
309
+ use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
310
+ )
311
+ else:
312
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
313
+
314
+ if past_key_values is not None:
315
+ past_key_values.update(
316
+ recurrent_state=recurrent_state,
317
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
318
+ layer_idx=self.layer_idx,
319
+ offset=q.shape[1]
320
+ )
321
+
322
+ if self.use_gate:
323
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
324
+ o = self.o_norm(o, g)
325
+ else:
326
+ o = self.o_norm(o)
327
+ o = rearrange(o, 'b t h d -> b t (h d)')
328
+ o = self.o_proj(o)
329
+
330
+ if self.weight_norm:
331
+ o = o / self.o_proj.weight.norm(p=2, dim=0)
332
+
333
+ return o, None, past_key_values
gated_deltanet.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from torch.nn import functional as F
13
+
14
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
15
+ from fla.ops.gated_delta_rule import (chunk_gated_delta_rule,
16
+ fused_recurrent_gated_delta_rule)
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.processing_utils import Unpack
20
+
21
+ from fla.models.utils import Cache
22
+
23
+
24
+ def elu_p1(x):
25
+ return (F.elu(x, 1., False) + 1.).to(x)
26
+
27
+
28
+ def sum_norm(x):
29
+ return (x / x.sum(-1, keepdim=True)).to(x)
30
+
31
+ # https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1
32
+
33
+
34
+ class GatedDeltaNet(nn.Module):
35
+ """
36
+ The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa
37
+
38
+ Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters.
39
+ Parameter alloation when use_gate=True:
40
+ - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each
41
+ - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each
42
+ - Others are ignorably small.
43
+ - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size
44
+ NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim.
45
+
46
+ Parameter allocation when use_gate=False:
47
+ - 1 * hidden_size * hidden_size for the q_proj and k_proj each
48
+ - 2 * hidden_size * hidden_size for the v_proj and o_proj each
49
+ - Others are ignorably small.
50
+ - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size
51
+
52
+ Args:
53
+ hidden_size (int, Optional):
54
+ The hidden size of the input. Default: 2048.
55
+ expand_v (float, Optional):
56
+ The expansion ratio for the value dim. Default: 2.0.
57
+ head_dim (int, Optional):
58
+ The dimension of each head. Default: 256.
59
+ num_heads (int, Optional):
60
+ The number of heads. Default: 4.
61
+ mode (str, Optional):
62
+ Which Gated DeltaNet kernel to use.
63
+ Currently available: `chunk` and `fused_recurrent`.
64
+ Default: `chunk`.
65
+ use_beta (bool, Optional):
66
+ Whether to use beta. Default: `True`.
67
+ use_gate (bool, Optional):
68
+ Whether to use output gate. Default: `True`.
69
+ use_short_conv (bool, Optional):
70
+ Whether to use short convolutions. Default: `True`.
71
+ conv_size (int, Optional):
72
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
73
+ conv_bias (bool, Optional):
74
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
75
+ layer_idx (int, Optional):
76
+ The index of the layer. Default: None.
77
+ norm_eps (float, Optional):
78
+ The epsilon value for the normalization layer. Default: 1e-5.
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ hidden_size: int = 2048,
84
+ expand_v: float = 2,
85
+ head_dim: int = 256,
86
+ num_heads: int = 6,
87
+ mode: str = 'chunk',
88
+ use_gate: bool = True,
89
+ use_short_conv: bool = True,
90
+ conv_size: int = 4,
91
+ conv_bias: bool = False,
92
+ layer_idx: int = None,
93
+ norm_eps: float = 1e-5,
94
+ config = None,
95
+ **kwargs
96
+ ) -> GatedDeltaNet:
97
+ super().__init__()
98
+
99
+ self.config = config
100
+ if self.config is not None and self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']:
101
+ self.weight_norm = True
102
+ else:
103
+ self.weight_norm = False
104
+
105
+ self.mode = mode
106
+
107
+ self.hidden_size = hidden_size
108
+ self.expand_v = expand_v
109
+
110
+ self.use_gate = use_gate
111
+ self.use_short_conv = use_short_conv
112
+ self.conv_size = conv_size
113
+ self.conv_bias = conv_bias
114
+
115
+ self.head_dim = head_dim
116
+ self.num_heads = num_heads
117
+
118
+ self.key_dim = self.num_heads * self.head_dim
119
+ self.value_dim = self.key_dim * self.expand_v
120
+ self.head_k_dim = head_dim
121
+ self.head_v_dim = head_dim * self.expand_v
122
+ self.layer_idx = layer_idx
123
+ self.silu = nn.SiLU()
124
+
125
+ assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
126
+
127
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
128
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
129
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
130
+ self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
131
+ self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
132
+ A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
133
+ A_log = torch.log(A)
134
+ self.A_log = nn.Parameter(A_log)
135
+ self.A_log._no_weight_decay = True
136
+ self.D = nn.Parameter(torch.ones(self.num_heads))
137
+ self.D._no_weight_decay = True
138
+ # hard coded for now
139
+ dt_min = 0.001
140
+ dt_max = 0.1
141
+ dt_init_floor = 1e-4
142
+ dt = torch.exp(
143
+ torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
144
+ + math.log(dt_min)
145
+ )
146
+ dt = torch.clamp(dt, min=dt_init_floor)
147
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
148
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
149
+ self.dt_bias = nn.Parameter(inv_dt)
150
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
151
+ # name.endswith("bias") in param_grouping.py
152
+ self.dt_bias._no_weight_decay = True
153
+
154
+ if use_short_conv:
155
+ self.conv_size = conv_size
156
+ self.q_conv1d = ShortConvolution(
157
+ hidden_size=self.key_dim,
158
+ kernel_size=conv_size,
159
+ activation='silu'
160
+ )
161
+ self.k_conv1d = ShortConvolution(
162
+ hidden_size=self.key_dim,
163
+ kernel_size=conv_size,
164
+ activation='silu'
165
+ )
166
+ self.v_conv1d = ShortConvolution(
167
+ hidden_size=self.value_dim,
168
+ kernel_size=conv_size,
169
+ activation='silu'
170
+ )
171
+ else:
172
+ raise UserWarning(
173
+ "ShortConvolution is crucial to the performance. "
174
+ "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
175
+ )
176
+ if use_gate:
177
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
178
+ self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
179
+ else:
180
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
181
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
182
+ self.apply(self._initialize_weights)
183
+
184
+ def _initialize_weights(self, module: nn.Module):
185
+ if getattr(module, "_is_hf_initialized", False):
186
+ return
187
+ if isinstance(module, nn.Linear):
188
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
189
+ if module.bias is not None:
190
+ nn.init.zeros_(module.bias)
191
+ module._is_hf_initialized = True
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states: torch.Tensor,
196
+ attention_mask: Optional[torch.Tensor] = None,
197
+ past_key_values: Optional[Cache] = None,
198
+ use_cache: Optional[bool] = False,
199
+ output_attentions: Optional[bool] = False,
200
+ **kwargs: Unpack[Dict]
201
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
202
+ if attention_mask is not None:
203
+ assert len(attention_mask.shape) == 2, (
204
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
205
+ "for padding purposes (0 indicating padding). "
206
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
207
+ )
208
+
209
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
210
+ if self.training:
211
+ assert mode == 'chunk', "Only chunk mode is supported in training."
212
+
213
+ last_state = None
214
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
215
+ last_state = past_key_values[self.layer_idx]
216
+
217
+ if self.use_short_conv:
218
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
219
+ if last_state is not None:
220
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
221
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
222
+ position_ids = kwargs.get('position_ids', None)
223
+
224
+ q = self.q_proj(hidden_states)
225
+ if self.weight_norm:
226
+ q = q / self.q_proj.weight.norm(p=2, dim=1)
227
+ q, conv_state_q = self.q_conv1d(x=q,
228
+ mask=conv_mask,
229
+ cache=conv_state_q,
230
+ output_final_state=use_cache,
231
+ seq_idx=position_ids)
232
+
233
+ k = self.k_proj(hidden_states)
234
+ if self.weight_norm:
235
+ k = k / self.k_proj.weight.norm(p=2, dim=1)
236
+ k, conv_state_k = self.k_conv1d(x=k,
237
+ mask=conv_mask,
238
+ cache=conv_state_k,
239
+ output_final_state=use_cache,
240
+ seq_idx=position_ids)
241
+
242
+ v = self.v_proj(hidden_states)
243
+ if self.weight_norm:
244
+ v = v / self.v_proj.weight.norm(p=2, dim=1)
245
+ v, conv_state_v = self.v_conv1d(x=v,
246
+ mask=conv_mask,
247
+ cache=conv_state_v,
248
+ output_final_state=use_cache,
249
+ seq_idx=position_ids)
250
+
251
+ else:
252
+ q = self.q_proj(hidden_states)
253
+ k = self.k_proj(hidden_states)
254
+ v = self.v_proj(hidden_states)
255
+
256
+ if self.weight_norm:
257
+ q = q / self.q_proj.weight.norm(p=2, dim=1)
258
+ k = k / self.k_proj.weight.norm(p=2, dim=1)
259
+ v = v / self.v_proj.weight.norm(p=2, dim=1)
260
+
261
+ q, k, v = self.silu(q), self.silu(k), self.silu(v)
262
+
263
+ q, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (q, k))
264
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
265
+
266
+ beta = self.b_proj(hidden_states)
267
+ if self.weight_norm:
268
+ beta = beta / self.b_proj.weight.norm(p=2, dim=1)
269
+ beta = beta.sigmoid()
270
+
271
+ a_val = self.a_proj(hidden_states)
272
+ if self.weight_norm:
273
+ a_val = a_val / self.a_proj.weight.norm(p=2, dim=1)
274
+ g = -self.A_log.float().exp() * F.softplus(a_val.float() + self.dt_bias)
275
+
276
+ # dealing with padding
277
+ if attention_mask is not None:
278
+ beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
279
+ g = g.mul(attention_mask[:, -g.shape[-2]:, None])
280
+
281
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
282
+ cu_seqlens = kwargs.get('cu_seqlens', None)
283
+ if mode == 'chunk':
284
+ o, recurrent_state = chunk_gated_delta_rule(
285
+ q=q,
286
+ k=k,
287
+ v=v,
288
+ g=g,
289
+ beta=beta,
290
+ initial_state=recurrent_state,
291
+ output_final_state=use_cache,
292
+ cu_seqlens=cu_seqlens,
293
+ head_first=False,
294
+ use_qk_l2norm_in_kernel=True
295
+ )
296
+ elif mode == 'fused_recurrent':
297
+ o, recurrent_state = fused_recurrent_gated_delta_rule(
298
+ q=q,
299
+ k=k,
300
+ v=v,
301
+ g=g,
302
+ beta=beta,
303
+ initial_state=recurrent_state,
304
+ output_final_state=use_cache,
305
+ cu_seqlens=cu_seqlens,
306
+ head_first=False,
307
+ use_qk_l2norm_in_kernel=True
308
+ )
309
+ if past_key_values is not None:
310
+ past_key_values.update(
311
+ recurrent_state=recurrent_state,
312
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
313
+ layer_idx=self.layer_idx,
314
+ offset=q.shape[1]
315
+ )
316
+
317
+ if self.use_gate:
318
+ gate_val = self.g_proj(hidden_states)
319
+
320
+ # if self.weight_norm:
321
+ # gate_val = gate_val / self.g_proj.weight.norm(p=2, dim=1)
322
+
323
+ g = rearrange(gate_val, '... (h d) -> ... h d', d=self.head_v_dim)
324
+ o = self.o_norm(o, g)
325
+ else:
326
+ o = self.o_norm(o)
327
+ o = rearrange(o, 'b t h d -> b t (h d)')
328
+ o = self.o_proj(o)
329
+
330
+ if self.weight_norm:
331
+ o = o / self.o_proj.weight.norm(p=2, dim=0)
332
+
333
+ return o, None, past_key_values
generation_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.48.2",
7
+ "use_cache": false
8
+ }
mamba2.py ADDED
@@ -0,0 +1,1427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, repeat, pack, unpack
10
+
11
+ try:
12
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
13
+ except ImportError:
14
+ causal_conv1d_fn, causal_conv1d_update = None, None
15
+
16
+ try:
17
+ from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
18
+ except ImportError:
19
+ causal_conv1d_varlen_states = None
20
+
21
+ import sys
22
+ # sys.path.insert(0, '/lustre/fsw/portfolios/nvr/users/yongganf/TLM/')
23
+
24
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
25
+ from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
26
+
27
+
28
+ from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
29
+ from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter
30
+
31
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
32
+ from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
33
+
34
+
35
+ class Mamba2(nn.Module):
36
+ def __init__(
37
+ self,
38
+ config,
39
+ conv_init=None,
40
+ d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
41
+ ngroups=1,
42
+ A_init_range=(1, 16),
43
+ D_has_hdim=False,
44
+ rmsnorm=True,
45
+ norm_before_gate=False,
46
+ dt_min=0.001,
47
+ dt_max=0.1,
48
+ dt_init_floor=1e-4,
49
+ dt_limit=(0.0, float("inf")),
50
+ bias=False,
51
+ conv_bias=True,
52
+ # Fused kernel and sharding options
53
+ chunk_size=256,
54
+ use_mem_eff_path=False, # True,
55
+ layer_idx=None, # Absorb kwarg for general module
56
+ process_group=None,
57
+ sequence_parallel=True,
58
+ device=None,
59
+ dtype=None,
60
+ ):
61
+ factory_kwargs = {"device": device, "dtype": dtype}
62
+ super().__init__()
63
+
64
+ self.config = config
65
+ self.d_model = config.hidden_size
66
+ self.d_state = config.mamba_d_state
67
+ self.d_conv = config.mamba_d_conv
68
+
69
+ self.conv_init = conv_init
70
+ self.expand = config.mamba_expand
71
+ self.process_group = process_group
72
+ self.sequence_parallel = sequence_parallel
73
+ self.world_size = 1 if process_group is None else process_group.size()
74
+ self.local_rank = 0 if process_group is None else process_group.rank()
75
+ self.d_inner = (self.expand * self.d_model) // self.world_size
76
+ assert self.d_inner * self.world_size == self.expand * self.d_model
77
+ self.headdim = config.mamba2_headdim
78
+ self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
79
+ assert ngroups % self.world_size == 0
80
+ self.ngroups = ngroups // self.world_size
81
+ assert self.d_ssm % self.headdim == 0
82
+ self.nheads = self.d_ssm // self.headdim
83
+ self.D_has_hdim = D_has_hdim
84
+ self.rmsnorm = rmsnorm
85
+ self.norm_before_gate = norm_before_gate
86
+ self.dt_limit = dt_limit
87
+ self.activation = "silu"
88
+ self.chunk_size = chunk_size
89
+ self.use_mem_eff_path = use_mem_eff_path
90
+ self.layer_idx = layer_idx
91
+
92
+ assert (self.d_model * self.expand / self.headdim) % 8 == 0
93
+
94
+ # Order: [z, x, B, C, dt]
95
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
96
+ if self.process_group is None:
97
+ self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
98
+ else:
99
+ self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,
100
+ process_group=self.process_group, sequence_parallel=self.sequence_parallel,
101
+ **factory_kwargs)
102
+
103
+ conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
104
+ self.conv1d = nn.Conv1d(
105
+ in_channels=conv_dim,
106
+ out_channels=conv_dim,
107
+ bias=conv_bias,
108
+ kernel_size=self.d_conv,
109
+ groups=conv_dim,
110
+ padding=self.d_conv - 1,
111
+ **factory_kwargs,
112
+ )
113
+ if self.conv_init is not None:
114
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
115
+
116
+ self.act = nn.SiLU()
117
+
118
+ # Initialize log dt bias
119
+ dt = torch.exp(
120
+ torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
121
+ + math.log(dt_min)
122
+ )
123
+ dt = torch.clamp(dt, min=dt_init_floor)
124
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
125
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
126
+
127
+ if config.no_dt_bias:
128
+ self.dt_bias = None
129
+ else:
130
+ self.dt_bias = nn.Parameter(inv_dt)
131
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
132
+ # name.endswith("bias") in param_grouping.py
133
+ self.dt_bias._no_weight_decay = True
134
+
135
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
136
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
137
+ A_log = torch.log(A).to(dtype=dtype)
138
+ self.A_log = nn.Parameter(A_log)
139
+ self.A_log._no_weight_decay = True
140
+
141
+ # D "skip" parameter
142
+ self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
143
+ self.D._no_weight_decay = True
144
+
145
+ if self.rmsnorm:
146
+ assert RMSNormGated is not None
147
+ self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
148
+ group_size=self.d_ssm // ngroups, **factory_kwargs)
149
+
150
+ if self.process_group is None:
151
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
152
+ else:
153
+ self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,
154
+ process_group=self.process_group, sequence_parallel=self.sequence_parallel,
155
+ **factory_kwargs)
156
+
157
+ self.mamba_multihead_config = config.mamba_multihead_config
158
+ if self.mamba_multihead_config is not None:
159
+ assert self.mamba_multihead_config['alpha_mode'] == 'sparsity' or self.mamba_multihead_config['alpha_mode'] == 'cummax'
160
+
161
+ if self.mamba_multihead_config['alpha_mode'] == 'cummax':
162
+ self.learned_dt_scale = nn.Parameter(torch.ones(1, device=device))
163
+
164
+ if self.mamba_multihead_config['alpha_mode'] == 'sparsity':
165
+ if 'use_learned_thres' in self.mamba_multihead_config and self.mamba_multihead_config['use_learned_thres']:
166
+ self.learned_thres = nn.Parameter(torch.zeros(self.nheads, device=device))
167
+ self.smooth_factor = self.mamba_multihead_config['smooth_factor']
168
+ self.detach_dt = self.mamba_multihead_config['detach_dt']
169
+
170
+ if 'use_cummax' in self.mamba_multihead_config and self.mamba_multihead_config['use_cummax']:
171
+ self.use_cummax = True
172
+ self.cummax_lower_bound = self.mamba_multihead_config['cummax_lower_bound']
173
+ else:
174
+ self.use_cummax = False
175
+
176
+ else:
177
+ self.learned_thres = None
178
+ self.smooth_factor = None
179
+ self.detach_dt = None
180
+
181
+ self.sparsity_split = self.mamba_multihead_config['sparsity_split']
182
+ self.sparsity_ratio = self.mamba_multihead_config['sparsity_ratio']
183
+
184
+ if self.config.layerwise_memory_token:
185
+ assert self.config.num_memory_tokens > 0
186
+ self.memory_tokens = nn.Parameter(torch.randn(self.config.num_memory_tokens, self.config.hidden_size))
187
+ else:
188
+ self.memory_tokens = None
189
+
190
+
191
+ def forward(self, hidden_states, attention_mask=None, past_key_value=None, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
192
+ """
193
+ hidden_states: (batch, seqlen, hidden_dim) if seqlen=None.
194
+ If seqlen is not None, hidden_states is (batch * seqlen, hidden_dim). This is so that when we
195
+ split hidden_states during sequence parallel, we split the batch * seqlen dimension
196
+ (in case batch is small).
197
+ Returns: same shape as u
198
+ """
199
+ # assert past_key_value is None, "Not implemented yet!!!"
200
+
201
+ if self.memory_tokens is not None:
202
+ hidden_states = hidden_states[:,self.config.num_memory_tokens:,...]
203
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b = hidden_states.shape[0]) # prepend the memory to every segment of m by repeating the memory tokens
204
+ hidden_states, mem_packed_shape = pack((mem, hidden_states), 'b * d')
205
+
206
+ seqlen_og = seqlen
207
+ if seqlen is None:
208
+ batch, seqlen, dim = hidden_states.shape
209
+ else:
210
+ batch_seqlen, dim = hidden_states.shape
211
+ batch = batch_seqlen // seqlen
212
+
213
+ conv_state, ssm_state = None, None
214
+ if inference_params is not None:
215
+ inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
216
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
217
+ if inference_params.seqlen_offset > 0:
218
+ # The states are updated inplace
219
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
220
+ return out
221
+
222
+ zxbcdt = self.in_proj(hidden_states) # (B, L, d_in_proj) or (B * L, d_in_proj)
223
+
224
+ if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']:
225
+ zxbcdt = zxbcdt / self.in_proj.weight.norm(p=2, dim=1)
226
+
227
+ if seqlen_og is not None:
228
+ zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
229
+ # If the model is loaded in fp16, without the .float() here, A might be -inf
230
+ A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
231
+ dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
232
+ if self.use_mem_eff_path and inference_params is None:
233
+ out = mamba_split_conv1d_scan_combined(
234
+ zxbcdt,
235
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
236
+ self.conv1d.bias,
237
+ self.dt_bias,
238
+ A,
239
+ D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
240
+ chunk_size=self.chunk_size,
241
+ seq_idx=seq_idx,
242
+ activation=self.activation,
243
+ rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
244
+ rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
245
+ outproj_weight=self.out_proj.weight,
246
+ outproj_bias=self.out_proj.bias,
247
+ headdim=None if self.D_has_hdim else self.headdim,
248
+ ngroups=self.ngroups,
249
+ norm_before_gate=self.norm_before_gate,
250
+ **dt_limit_kwargs,
251
+ )
252
+ if seqlen_og is not None:
253
+ out = rearrange(out, "b l d -> (b l) d")
254
+ if self.process_group is not None:
255
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
256
+ out = reduce_fn(out, self.process_group)
257
+ else:
258
+ d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
259
+ z0, x0, z, xBC, dt = torch.split(
260
+ zxbcdt,
261
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
262
+ dim=-1
263
+ )
264
+ if conv_state is not None:
265
+ if cu_seqlens is None:
266
+ # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
267
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
268
+ xBC_t = rearrange(xBC, "b l d -> b d l")
269
+ conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
270
+ else:
271
+ assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
272
+ assert batch == 1, "varlen inference only supports batch dimension 1"
273
+ conv_varlen_states = causal_conv1d_varlen_states(
274
+ xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
275
+ )
276
+ conv_state.copy_(conv_varlen_states)
277
+ assert self.activation in ["silu", "swish"]
278
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
279
+ assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
280
+ xBC = self.act(
281
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):]
282
+ ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
283
+ else:
284
+ xBC = causal_conv1d_fn(
285
+ xBC.transpose(1, 2),
286
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
287
+ bias=self.conv1d.bias,
288
+ activation=self.activation,
289
+ # seq_idx=seq_idx,
290
+ ).transpose(1, 2)
291
+ x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
292
+
293
+ no_dt_bias = False
294
+ if self.mamba_multihead_config is not None and self.mamba_multihead_config['alpha_mode'] == 'cummax': ### todo: implement this in the fused kernel
295
+ dt = dt + self.dt_bias
296
+ dt = torch.nn.functional.softmax(dt, dim=-1)
297
+ dt = torch.cumsum(dt, dim=-1)
298
+ dt = dt * self.learned_dt_scale
299
+
300
+ no_dt_bias = True
301
+
302
+ if self.mamba_multihead_config is not None and self.mamba_multihead_config['alpha_mode'] == 'sparsity':
303
+ dt = dt + self.dt_bias
304
+
305
+ if self.learned_thres is not None:
306
+ dt = self.sparsify_learned_thres(dt)
307
+ else:
308
+ dt = self.split_and_sparsify(dt, self.sparsity_split, self.sparsity_ratio)
309
+
310
+ no_dt_bias = True
311
+
312
+
313
+ y = mamba_chunk_scan_combined(
314
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
315
+ dt,
316
+ A,
317
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
318
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
319
+ chunk_size=self.chunk_size,
320
+ D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
321
+ z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
322
+ dt_bias=self.dt_bias if not no_dt_bias else None,
323
+ dt_softplus=True,
324
+ seq_idx=seq_idx,
325
+ cu_seqlens=cu_seqlens,
326
+ **dt_limit_kwargs,
327
+ return_final_states=ssm_state is not None,
328
+ return_varlen_states=cu_seqlens is not None and inference_params is not None,
329
+ )
330
+ if ssm_state is not None:
331
+ y, last_state, *rest = y
332
+ if cu_seqlens is None:
333
+ ssm_state.copy_(last_state)
334
+ else:
335
+ varlen_states = rest[0]
336
+ ssm_state.copy_(varlen_states)
337
+ y = rearrange(y, "b l h p -> b l (h p)")
338
+ if self.rmsnorm:
339
+ y = self.norm(y, z)
340
+ if d_mlp > 0:
341
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
342
+ if seqlen_og is not None:
343
+ y = rearrange(y, "b l d -> (b l) d")
344
+
345
+ if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']:
346
+ y = y / self.out_proj.weight.norm(p=2, dim=0)
347
+
348
+ out = self.out_proj(y)
349
+
350
+ return out, past_key_value
351
+
352
+
353
+ def sparsify_learned_thres(self, dt):
354
+ """
355
+ Args:
356
+ dt: Tensor of shape [bs, seq_len, nheads]
357
+ Returns:
358
+ pruned_dt: Pruned tensor with the same shape as dt
359
+ """
360
+ # Compute sigmoid scores
361
+
362
+ if self.use_cummax:
363
+ learned_thres = torch.nn.functional.softmax(self.learned_thres, dim=-1)
364
+ learned_thres = torch.cumsum(learned_thres, dim=-1) - self.cummax_lower_bound ## keep the dt_normalized larger than 1 - self.cummax_lower_bound
365
+
366
+ dt_normalized = (dt - dt.min(dim=-1, keepdim=True)[0]) / (dt.max(dim=-1, keepdim=True)[0] - dt.min(dim=-1, keepdim=True)[0])
367
+
368
+ scores = torch.sigmoid((dt_normalized.detach() - self.learned_thres) / self.smooth_factor)
369
+
370
+ else:
371
+ if self.detach_dt:
372
+ scores = torch.sigmoid((dt.detach() - self.learned_thres) / self.smooth_factor)
373
+ else:
374
+ scores = torch.sigmoid((dt - self.learned_thres) / self.smooth_factor)
375
+
376
+ # Generate binary mask for pruning (forward pass)
377
+ mask = (scores >= 0.5).float()
378
+
379
+ # Apply mask in the forward pass and backward using sigmoid
380
+ pruned_dt = (dt * mask - dt * scores).detach() + dt * scores
381
+
382
+ # print(pruned_dt.mean())
383
+
384
+ return pruned_dt
385
+
386
+
387
+ def split_and_sparsify(self, dt, sparsity_split, sparsity_ratio):
388
+ """
389
+ dt: a torch.Tensor of shape [bs, seq_len, dim]
390
+ sparsity_split: list of ratios (e.g., [0.4, 0.3, 0.3]) that sum to 1
391
+ and define how to split dt along the last dimension
392
+ sparsity_ratio: list of ratios (e.g., [0.2, 0.5, 0.3]) that sum to 1
393
+ and define how many time steps (along seq_len) to keep
394
+ """
395
+ bs, seq_len, dim = dt.shape
396
+
397
+ assert sum(sparsity_split) == 1
398
+
399
+ # Compute the exact split sizes (watching out for integer rounding)
400
+ split_sizes = [int(r * dim) for r in sparsity_split]
401
+ # Fix potential off-by-one rounding in the last split
402
+ split_sizes[-1] = dim - sum(split_sizes[:-1])
403
+
404
+ # Split the original tensor along the last dimension
405
+ splitted_tensors = torch.split(dt, split_sizes, dim=-1)
406
+
407
+ results = []
408
+ for i, sub_tensor in enumerate(splitted_tensors):
409
+ # sub_tensor has shape [bs, seq_len, split_dim_i]
410
+ k = int(sparsity_ratio[i] * seq_len)
411
+
412
+ ### Strategy 1: keep at least one token
413
+ k = max(k, 1)
414
+
415
+ ### Strategy 2: the #tokens is the same as training
416
+ # if self.config.orig_max_position_embeddings is not None:
417
+ # k = int(self.config.orig_max_position_embeddings * self.sparsity_ratio[i])
418
+ # else:
419
+ # assert self.config.max_position_embeddings is not None
420
+ # k = int(self.config.max_position_embeddings * self.sparsity_ratio[i])
421
+
422
+ # k = min(seq_len, k)
423
+
424
+ # print(self.config.max_position_embeddings, sparsity_ratio[i], seq_len, k)
425
+
426
+ # 1) Average over the feature dimension (the last dim),
427
+ # resulting in shape [bs, seq_len]
428
+ averaged_values = sub_tensor.mean(dim=-1)
429
+
430
+ # 2) Get top-k indices (along seq_len = dim=1)
431
+ topk_values, _ = torch.topk(averaged_values, k=k, dim=1)
432
+ # The smallest value among the top-k per batch element
433
+ threshold = topk_values[:, -1].unsqueeze(-1) # shape [bs, 1]
434
+
435
+ # 3) Create a mask of shape [bs, seq_len] => True if >= threshold
436
+ averaged_mask = (averaged_values >= threshold)
437
+
438
+ # 4) Expand that mask back to [bs, seq_len, split_dim_i]
439
+ mask_3d = averaged_mask.unsqueeze(-1).expand_as(sub_tensor)
440
+
441
+ # 5) Zero out everything that is not in top-k
442
+ sparsified_sub = sub_tensor * mask_3d
443
+
444
+ # print((sparsified_sub == 0).float().mean().item())
445
+ # input()
446
+
447
+ results.append(sparsified_sub)
448
+
449
+ # Concatenate the results back along the last dimension
450
+ output = torch.cat(results, dim=-1)
451
+ return output
452
+
453
+ def step(self, hidden_states, conv_state, ssm_state):
454
+ dtype = hidden_states.dtype
455
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
456
+ zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
457
+ d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
458
+ z0, x0, z, xBC, dt = torch.split(
459
+ zxbcdt,
460
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
461
+ dim=-1
462
+ )
463
+
464
+ # Conv step
465
+ if causal_conv1d_update is None:
466
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
467
+ conv_state[:, :, -1] = xBC
468
+ xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
469
+ if self.conv1d.bias is not None:
470
+ xBC = xBC + self.conv1d.bias
471
+ xBC = self.act(xBC).to(dtype=dtype)
472
+ else:
473
+ xBC = causal_conv1d_update(
474
+ xBC,
475
+ conv_state,
476
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
477
+ self.conv1d.bias,
478
+ self.activation,
479
+ )
480
+
481
+ x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
482
+ A = -torch.exp(self.A_log.float()) # (nheads,)
483
+
484
+ # SSM step
485
+ if selective_state_update is None:
486
+ assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
487
+ # Discretize A and B
488
+ dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
489
+ dA = torch.exp(dt * A) # (batch, nheads)
490
+ x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
491
+ dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
492
+ ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
493
+ y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
494
+ y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
495
+ y = rearrange(y, "b h p -> b (h p)")
496
+ if not self.rmsnorm:
497
+ y = y * self.act(z) # (B D)
498
+ else:
499
+ A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
500
+ dt = repeat(dt, "b h -> b h p", p=self.headdim)
501
+ dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
502
+ D = repeat(self.D, "h -> h p", p=self.headdim)
503
+ B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
504
+ C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
505
+ x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
506
+ if not self.rmsnorm:
507
+ z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
508
+ y = selective_state_update(
509
+ ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
510
+ dt_bias=dt_bias, dt_softplus=True
511
+ )
512
+ y = rearrange(y, "b h p -> b (h p)")
513
+ if self.rmsnorm:
514
+ y = self.norm(y, z)
515
+ if d_mlp > 0:
516
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
517
+ out = self.out_proj(y)
518
+ return out.unsqueeze(1), conv_state, ssm_state
519
+
520
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
521
+ device = self.out_proj.weight.device
522
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
523
+ conv_state = torch.zeros(
524
+ batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
525
+ ).transpose(1, 2)
526
+ ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
527
+ ssm_state = torch.zeros(
528
+ batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
529
+ )
530
+ return conv_state, ssm_state
531
+
532
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
533
+ assert self.layer_idx is not None
534
+ if self.layer_idx not in inference_params.key_value_memory_dict:
535
+ batch_shape = (batch_size,)
536
+ conv_state = torch.zeros(
537
+ batch_size,
538
+ self.d_conv,
539
+ self.conv1d.weight.shape[0],
540
+ device=self.conv1d.weight.device,
541
+ dtype=self.conv1d.weight.dtype,
542
+ ).transpose(1, 2)
543
+ ssm_state = torch.zeros(
544
+ batch_size,
545
+ self.nheads,
546
+ self.headdim,
547
+ self.d_state,
548
+ device=self.in_proj.weight.device,
549
+ dtype=self.in_proj.weight.dtype,
550
+ )
551
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
552
+ else:
553
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
554
+ # TODO: What if batch size changes between generation, and we reuse the same states?
555
+ if initialize_states:
556
+ conv_state.zero_()
557
+ ssm_state.zero_()
558
+ return conv_state, ssm_state
559
+
560
+
561
+ class Mamba2_Fused(nn.Module):
562
+ def __init__(
563
+ self,
564
+ config,
565
+ layer_idx=None, # Absorb kwarg for general module
566
+ reuse_kv=False,
567
+ conv_init=None,
568
+ d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
569
+ ngroups=1,
570
+ A_init_range=(1, 16),
571
+ D_has_hdim=False,
572
+ rmsnorm=True,
573
+ norm_before_gate=False,
574
+ dt_min=0.001,
575
+ dt_max=0.1,
576
+ dt_init_floor=1e-4,
577
+ dt_limit=(0.0, float("inf")),
578
+ bias=False,
579
+ conv_bias=True,
580
+ # Fused kernel and sharding options
581
+ chunk_size=256,
582
+ use_mem_eff_path=False, # True,
583
+ process_group=None,
584
+ sequence_parallel=True,
585
+ device=None,
586
+ dtype=None,
587
+ ):
588
+ factory_kwargs = {"device": device, "dtype": dtype}
589
+ super().__init__()
590
+
591
+ self.config = config
592
+ self.d_model = config.hidden_size
593
+ self.d_state = config.mamba_d_state
594
+ self.d_conv = config.mamba_d_conv
595
+
596
+ self.conv_init = conv_init
597
+ self.expand = config.mamba_expand
598
+ self.process_group = process_group
599
+ self.sequence_parallel = sequence_parallel
600
+ self.world_size = 1 if process_group is None else process_group.size()
601
+ self.local_rank = 0 if process_group is None else process_group.rank()
602
+ self.d_inner = (self.expand * self.d_model) // self.world_size
603
+ assert self.d_inner * self.world_size == self.expand * self.d_model
604
+ self.headdim = config.mamba2_headdim
605
+ self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
606
+ assert ngroups % self.world_size == 0
607
+ self.ngroups = ngroups // self.world_size
608
+ assert self.d_ssm % self.headdim == 0
609
+ self.nheads = self.d_ssm // self.headdim
610
+ self.D_has_hdim = D_has_hdim
611
+ self.rmsnorm = rmsnorm
612
+ self.norm_before_gate = norm_before_gate
613
+ self.dt_limit = dt_limit
614
+ self.activation = "silu"
615
+ self.chunk_size = chunk_size
616
+ self.use_mem_eff_path = use_mem_eff_path
617
+ self.layer_idx = layer_idx
618
+
619
+ assert (self.d_model * self.expand / self.headdim) % 8 == 0
620
+
621
+ self.fused_multihead_config = config.fused_multihead_config
622
+ assert self.fused_multihead_config['expand_v'], "Only implemented Hymba for Mamba"
623
+
624
+ self.reuse_kv = reuse_kv
625
+
626
+ self.hidden_size = config.hidden_size
627
+ self.attn_hidden_size = config.hidden_size
628
+ self.num_attention_heads = config.num_attention_heads
629
+ self.num_key_value_heads = config.num_key_value_heads
630
+
631
+ self.k_hidden_size = int(self.num_key_value_heads/self.num_attention_heads * self.attn_hidden_size)
632
+ self.v_hidden_size = int(self.num_key_value_heads/self.num_attention_heads * self.attn_hidden_size * self.expand) if self.fused_multihead_config['expand_v'] else int(self.num_key_value_heads/self.num_attention_heads * self.attn_hidden_size)
633
+
634
+ if self.fused_multihead_config['expand_v']:
635
+ config.v_head_dim = self.d_inner // self.num_attention_heads
636
+
637
+ self.self_attn = config.attn_op(config, layer_idx, attn_only_wo_proj=True, reuse_kv=reuse_kv)
638
+
639
+ if self.reuse_kv: # Order: [q, z, x, B, C, dt]
640
+ d_in_proj = self.attn_hidden_size + 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
641
+ else: # Order: [q, k, v, z, x, B, C, dt]
642
+ d_in_proj = self.attn_hidden_size + self.k_hidden_size + self.v_hidden_size + 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
643
+
644
+ if self.process_group is None:
645
+ self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
646
+ else:
647
+ self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,
648
+ process_group=self.process_group, sequence_parallel=self.sequence_parallel,
649
+ **factory_kwargs)
650
+
651
+ self.pre_avg_layernorm1 = JambaRMSNorm(self.d_inner, eps=config.rms_norm_eps)
652
+ self.pre_avg_layernorm2 = JambaRMSNorm(self.d_inner, eps=config.rms_norm_eps)
653
+
654
+ conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
655
+ self.conv1d = nn.Conv1d(
656
+ in_channels=conv_dim,
657
+ out_channels=conv_dim,
658
+ bias=conv_bias,
659
+ kernel_size=self.d_conv,
660
+ groups=conv_dim,
661
+ padding=self.d_conv - 1,
662
+ **factory_kwargs,
663
+ )
664
+ if self.conv_init is not None:
665
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
666
+
667
+ self.act = nn.SiLU()
668
+
669
+ # Initialize log dt bias
670
+ dt = torch.exp(
671
+ torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
672
+ + math.log(dt_min)
673
+ )
674
+ dt = torch.clamp(dt, min=dt_init_floor)
675
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
676
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
677
+ self.dt_bias = nn.Parameter(inv_dt)
678
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
679
+ # name.endswith("bias") in param_grouping.py
680
+ self.dt_bias._no_weight_decay = True
681
+
682
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
683
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
684
+ A_log = torch.log(A).to(dtype=dtype)
685
+ self.A_log = nn.Parameter(A_log)
686
+ self.A_log._no_weight_decay = True
687
+
688
+ # D "skip" parameter
689
+ self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
690
+ self.D._no_weight_decay = True
691
+
692
+ if self.rmsnorm:
693
+ assert RMSNormGated is not None
694
+ self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
695
+ group_size=self.d_ssm // ngroups, **factory_kwargs)
696
+
697
+ if self.process_group is None:
698
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
699
+ else:
700
+ self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,
701
+ process_group=self.process_group, sequence_parallel=self.sequence_parallel,
702
+ **factory_kwargs)
703
+
704
+ def forward(self, hidden_states, attention_mask=None, past_key_value=None, position_ids=None, kv_last_layer=None, use_cache=False, use_swa=False, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
705
+ """
706
+ hidden_states: (batch, seqlen, hidden_dim) if seqlen=None.
707
+ If seqlen is not None, hidden_states is (batch * seqlen, hidden_dim). This is so that when we
708
+ split hidden_states during sequence parallel, we split the batch * seqlen dimension
709
+ (in case batch is small).
710
+ Returns: same shape as u
711
+ """
712
+ # assert past_key_value is None, "Not implemented yet!!!"
713
+
714
+ seqlen_og = seqlen
715
+ if seqlen is None:
716
+ batch, seqlen, dim = hidden_states.shape
717
+ else:
718
+ batch_seqlen, dim = hidden_states.shape
719
+ batch = batch_seqlen // seqlen
720
+
721
+ conv_state, ssm_state = None, None
722
+ if inference_params is not None:
723
+ inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
724
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
725
+ if inference_params.seqlen_offset > 0:
726
+ # The states are updated inplace
727
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
728
+ return out
729
+
730
+ zxbcdt = self.in_proj(hidden_states) # (B, L, d_in_proj) or (B * L, d_in_proj)
731
+
732
+ if self.reuse_kv:
733
+ query_states, zxbcdt = zxbcdt.tensor_split((self.attn_hidden_size,), dim=-1)
734
+ # query_states = query_states.transpose(1,2)
735
+ else:
736
+ query_states, key_states, value_states, zxbcdt = zxbcdt.tensor_split((self.attn_hidden_size, self.attn_hidden_size + self.k_hidden_size, self.attn_hidden_size + self.k_hidden_size + self.v_hidden_size), dim=-1)
737
+
738
+ # query_states = query_states.transpose(1,2)
739
+ # key_states = key_states.transpose(1,2)
740
+ # value_states = value_states.transpose(1,2)
741
+
742
+ if self.reuse_kv:
743
+ assert kv_last_layer is not None
744
+ attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, kv_last_layer=kv_last_layer, use_swa=use_swa, use_cache=use_cache, past_key_value=past_key_value)
745
+ else:
746
+ if 'use_linear_attn' in self.fused_multihead_config and self.fused_multihead_config['use_linear_attn'] and self.linear_attn_op == 'gla':
747
+ attn_outputs, _, attn_key_value = self.self_attn(hidden_states=value_states, position_ids=position_ids, attention_mask=attention_mask, Q=query_states, K=key_states, V=value_states, past_key_value=past_key_value)
748
+ else:
749
+ attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, key_states=key_states, value_states=value_states, use_swa=use_swa, use_cache=use_cache, past_key_value=past_key_value)
750
+
751
+
752
+ if seqlen_og is not None:
753
+ zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
754
+ # If the model is loaded in fp16, without the .float() here, A might be -inf
755
+ A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
756
+ dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
757
+ if self.use_mem_eff_path and inference_params is None:
758
+ out = mamba_split_conv1d_scan_combined(
759
+ zxbcdt,
760
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
761
+ self.conv1d.bias,
762
+ self.dt_bias,
763
+ A,
764
+ D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
765
+ chunk_size=self.chunk_size,
766
+ seq_idx=seq_idx,
767
+ activation=self.activation,
768
+ rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
769
+ rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
770
+ outproj_weight=self.out_proj.weight,
771
+ outproj_bias=self.out_proj.bias,
772
+ headdim=None if self.D_has_hdim else self.headdim,
773
+ ngroups=self.ngroups,
774
+ norm_before_gate=self.norm_before_gate,
775
+ **dt_limit_kwargs,
776
+ )
777
+ if seqlen_og is not None:
778
+ out = rearrange(out, "b l d -> (b l) d")
779
+ if self.process_group is not None:
780
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
781
+ out = reduce_fn(out, self.process_group)
782
+ else:
783
+ d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
784
+
785
+ z0, x0, z, xBC, dt = torch.split(
786
+ zxbcdt,
787
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
788
+ dim=-1
789
+ )
790
+ if conv_state is not None:
791
+ if cu_seqlens is None:
792
+ # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
793
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
794
+ xBC_t = rearrange(xBC, "b l d -> b d l")
795
+ conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
796
+ else:
797
+ assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
798
+ assert batch == 1, "varlen inference only supports batch dimension 1"
799
+ conv_varlen_states = causal_conv1d_varlen_states(
800
+ xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
801
+ )
802
+ conv_state.copy_(conv_varlen_states)
803
+ assert self.activation in ["silu", "swish"]
804
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
805
+ assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
806
+ xBC = self.act(
807
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):]
808
+ ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
809
+ else:
810
+ xBC = causal_conv1d_fn(
811
+ xBC.transpose(1, 2),
812
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
813
+ bias=self.conv1d.bias,
814
+ activation=self.activation,
815
+ # seq_idx=seq_idx,
816
+ ).transpose(1, 2)
817
+ x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
818
+
819
+ y = mamba_chunk_scan_combined(
820
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
821
+ dt,
822
+ A,
823
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
824
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
825
+ chunk_size=self.chunk_size,
826
+ D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
827
+ z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
828
+ dt_bias=self.dt_bias,
829
+ dt_softplus=True,
830
+ seq_idx=seq_idx,
831
+ cu_seqlens=cu_seqlens,
832
+ **dt_limit_kwargs,
833
+ return_final_states=ssm_state is not None,
834
+ return_varlen_states=cu_seqlens is not None and inference_params is not None,
835
+ )
836
+ if ssm_state is not None:
837
+ y, last_state, *rest = y
838
+ if cu_seqlens is None:
839
+ ssm_state.copy_(last_state)
840
+ else:
841
+ varlen_states = rest[0]
842
+ ssm_state.copy_(varlen_states)
843
+ y = rearrange(y, "b l h p -> b l (h p)")
844
+ if self.rmsnorm:
845
+ y = self.norm(y, z)
846
+ if d_mlp > 0:
847
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
848
+ if seqlen_og is not None:
849
+ y = rearrange(y, "b l d -> (b l) d")
850
+
851
+ scan_outputs = y
852
+ if 'repeat_v' in self.fused_multihead_config and self.fused_multihead_config['repeat_v']:
853
+ num_repeat = scan_outputs.shape[-1] // attn_outputs.shape[-1]
854
+ attn_outputs = attn_outputs.repeat(1, 1, num_repeat)
855
+
856
+ hidden_states = (self.pre_avg_layernorm1(attn_outputs) + self.pre_avg_layernorm2(scan_outputs)) / 2
857
+ out = self.out_proj(hidden_states)
858
+
859
+ return out, attn_key_value, past_key_value
860
+
861
+
862
+ def step(self, hidden_states, conv_state, ssm_state):
863
+ dtype = hidden_states.dtype
864
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
865
+ zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
866
+ d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
867
+ z0, x0, z, xBC, dt = torch.split(
868
+ zxbcdt,
869
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
870
+ dim=-1
871
+ )
872
+
873
+ # Conv step
874
+ if causal_conv1d_update is None:
875
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
876
+ conv_state[:, :, -1] = xBC
877
+ xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
878
+ if self.conv1d.bias is not None:
879
+ xBC = xBC + self.conv1d.bias
880
+ xBC = self.act(xBC).to(dtype=dtype)
881
+ else:
882
+ xBC = causal_conv1d_update(
883
+ xBC,
884
+ conv_state,
885
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
886
+ self.conv1d.bias,
887
+ self.activation,
888
+ )
889
+
890
+ x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
891
+ A = -torch.exp(self.A_log.float()) # (nheads,)
892
+
893
+ # SSM step
894
+ if selective_state_update is None:
895
+ assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
896
+ # Discretize A and B
897
+ dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
898
+ dA = torch.exp(dt * A) # (batch, nheads)
899
+ x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
900
+ dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
901
+ ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
902
+ y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
903
+ y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
904
+ y = rearrange(y, "b h p -> b (h p)")
905
+ if not self.rmsnorm:
906
+ y = y * self.act(z) # (B D)
907
+ else:
908
+ A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
909
+ dt = repeat(dt, "b h -> b h p", p=self.headdim)
910
+ dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
911
+ D = repeat(self.D, "h -> h p", p=self.headdim)
912
+ B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
913
+ C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
914
+ x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
915
+ if not self.rmsnorm:
916
+ z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
917
+ y = selective_state_update(
918
+ ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
919
+ dt_bias=dt_bias, dt_softplus=True
920
+ )
921
+ y = rearrange(y, "b h p -> b (h p)")
922
+ if self.rmsnorm:
923
+ y = self.norm(y, z)
924
+ if d_mlp > 0:
925
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
926
+ out = self.out_proj(y)
927
+
928
+ print(out)
929
+ input()
930
+ return out.unsqueeze(1), conv_state, ssm_state
931
+
932
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
933
+ device = self.out_proj.weight.device
934
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
935
+ conv_state = torch.zeros(
936
+ batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
937
+ ).transpose(1, 2)
938
+ ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
939
+ ssm_state = torch.zeros(
940
+ batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
941
+ )
942
+ return conv_state, ssm_state
943
+
944
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
945
+ assert self.layer_idx is not None
946
+ if self.layer_idx not in inference_params.key_value_memory_dict:
947
+ batch_shape = (batch_size,)
948
+ conv_state = torch.zeros(
949
+ batch_size,
950
+ self.d_conv,
951
+ self.conv1d.weight.shape[0],
952
+ device=self.conv1d.weight.device,
953
+ dtype=self.conv1d.weight.dtype,
954
+ ).transpose(1, 2)
955
+ ssm_state = torch.zeros(
956
+ batch_size,
957
+ self.nheads,
958
+ self.headdim,
959
+ self.d_state,
960
+ device=self.in_proj.weight.device,
961
+ dtype=self.in_proj.weight.dtype,
962
+ )
963
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
964
+ else:
965
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
966
+ # TODO: What if batch size changes between generation, and we reuse the same states?
967
+ if initialize_states:
968
+ conv_state.zero_()
969
+ ssm_state.zero_()
970
+ return conv_state, ssm_state
971
+
972
+
973
+ class Mamba2_Multihead(nn.Module):
974
+ def __init__(
975
+ self,
976
+ config,
977
+ conv_init=None,
978
+ headdim=64,
979
+ d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
980
+ ngroups=1,
981
+ A_init_range=(1, 16),
982
+ D_has_hdim=False,
983
+ rmsnorm=True,
984
+ norm_before_gate=False,
985
+ dt_min=0.001,
986
+ dt_max=0.1,
987
+ dt_init_floor=1e-4,
988
+ dt_limit=(0.0, float("inf")),
989
+ bias=False,
990
+ conv_bias=True,
991
+ # Fused kernel and sharding options
992
+ chunk_size=256,
993
+ use_mem_eff_path=False, # True,
994
+ layer_idx=None, # Absorb kwarg for general module
995
+ process_group=None,
996
+ sequence_parallel=True,
997
+ device=None,
998
+ dtype=None,
999
+ ):
1000
+ factory_kwargs = {"device": device, "dtype": dtype}
1001
+ super().__init__()
1002
+
1003
+ self.config = config
1004
+ self.d_model = config.hidden_size
1005
+ self.d_state = config.mamba_d_state
1006
+ self.d_conv = config.mamba_d_conv
1007
+
1008
+ self.conv_init = conv_init
1009
+ self.expand = config.mamba_expand
1010
+ self.process_group = process_group
1011
+ self.sequence_parallel = sequence_parallel
1012
+ self.world_size = 1 if process_group is None else process_group.size()
1013
+ self.local_rank = 0 if process_group is None else process_group.rank()
1014
+ self.d_inner = (self.expand * self.d_model) // self.world_size
1015
+ assert self.d_inner * self.world_size == self.expand * self.d_model
1016
+ self.headdim = config.mamba2_headdim
1017
+ self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
1018
+ assert ngroups % self.world_size == 0
1019
+ self.ngroups = ngroups // self.world_size
1020
+ assert self.d_ssm % self.headdim == 0
1021
+ self.nheads = self.d_ssm // self.headdim
1022
+ self.D_has_hdim = D_has_hdim
1023
+ self.rmsnorm = rmsnorm
1024
+ self.norm_before_gate = norm_before_gate
1025
+ self.dt_limit = dt_limit
1026
+ self.activation = "silu"
1027
+ self.chunk_size = chunk_size
1028
+ self.use_mem_eff_path = use_mem_eff_path
1029
+ self.layer_idx = layer_idx
1030
+
1031
+ assert (self.d_model * self.expand / self.headdim) % 8 == 0
1032
+
1033
+ self.mamba_multihead_config = config.mamba_multihead_config
1034
+ self.share_ratio = self.mamba_multihead_config['share_ratio']
1035
+
1036
+ self.reuse_ssm = self.mamba_multihead_config['reuse_ssm']
1037
+ self.num_ssm_param = 1 if self.reuse_ssm else self.share_ratio
1038
+
1039
+ if self.reuse_ssm:
1040
+ if self.mamba_multihead_config['alpha_mode'] == 'learnable':
1041
+ self.alpha = nn.Parameter(torch.ones(self.share_ratio))
1042
+ elif self.mamba_multihead_config['alpha_mode'] == 'manual':
1043
+ manual_alpha_base = self.mamba_multihead_config['manual_alpha_base']
1044
+ self.alpha = [1 / manual_alpha_base ** k for k in range(self.share_ratio)]
1045
+ else:
1046
+ raise ValueError(f"No such alpha_mode: {self.mamba_multihead_config['alpha_mode']}")
1047
+
1048
+ # Order: [z, x, B, C, dt]
1049
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads * self.num_ssm_param
1050
+ if self.process_group is None:
1051
+ self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
1052
+ else:
1053
+ self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,
1054
+ process_group=self.process_group, sequence_parallel=self.sequence_parallel,
1055
+ **factory_kwargs)
1056
+
1057
+ conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
1058
+ self.conv1d = nn.Conv1d(
1059
+ in_channels=conv_dim,
1060
+ out_channels=conv_dim,
1061
+ bias=conv_bias,
1062
+ kernel_size=self.d_conv,
1063
+ groups=conv_dim,
1064
+ padding=self.d_conv - 1,
1065
+ **factory_kwargs,
1066
+ )
1067
+ if self.conv_init is not None:
1068
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
1069
+
1070
+ self.act = nn.SiLU()
1071
+
1072
+ # Initialize log dt bias
1073
+ dt = torch.exp(
1074
+ torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
1075
+ + math.log(dt_min)
1076
+ )
1077
+ dt = torch.clamp(dt, min=dt_init_floor)
1078
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
1079
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
1080
+ self.dt_bias = nn.ParameterList([nn.Parameter(inv_dt) for _ in range(self.num_ssm_param)])
1081
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
1082
+ # name.endswith("bias") in param_grouping.py
1083
+ self.dt_bias._no_weight_decay = True
1084
+
1085
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
1086
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
1087
+ A_log = torch.log(A).to(dtype=dtype)
1088
+ self.A_log = nn.ParameterList([nn.Parameter(A_log) for _ in range(self.num_ssm_param)])
1089
+ self.A_log._no_weight_decay = True
1090
+
1091
+ # D "skip" parameter
1092
+ self.D = nn.ParameterList([nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)) for _ in range(self.num_ssm_param)])
1093
+ self.D._no_weight_decay = True
1094
+
1095
+ if self.rmsnorm:
1096
+ assert RMSNormGated is not None
1097
+ self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
1098
+ group_size=self.d_ssm // ngroups, **factory_kwargs)
1099
+
1100
+ if self.process_group is None:
1101
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
1102
+ else:
1103
+ self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,
1104
+ process_group=self.process_group, sequence_parallel=self.sequence_parallel,
1105
+ **factory_kwargs)
1106
+
1107
+
1108
+ if self.mamba_multihead_config['merge_op'] == 'norm':
1109
+ self.multihead_layernorm = nn.ModuleList([JambaRMSNorm(self.d_ssm, eps=config.rms_norm_eps) for _ in range(self.share_ratio)])
1110
+ elif self.mamba_multihead_config['merge_op'] == 'scalar_gate':
1111
+ self.multi_head_selection_layer = nn.Linear(self.d_ssm, self.share_ratio)
1112
+ elif self.mamba_multihead_config['merge_op'] == 'concat':
1113
+ assert self.d_ssm % self.share_ratio == 0
1114
+ self.multihead_layernorm = nn.ModuleList([JambaRMSNorm(self.d_ssm, eps=config.rms_norm_eps) for _ in range(self.share_ratio)])
1115
+ self.reduction_layer = nn.Linear(self.d_ssm, self.d_ssm//self.share_ratio)
1116
+
1117
+
1118
+ def forward(self, hidden_states, attention_mask=None, past_key_value=None, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
1119
+ """
1120
+ hidden_states: (batch, seqlen, hidden_dim) if seqlen=None.
1121
+ If seqlen is not None, hidden_states is (batch * seqlen, hidden_dim). This is so that when we
1122
+ split hidden_states during sequence parallel, we split the batch * seqlen dimension
1123
+ (in case batch is small).
1124
+ Returns: same shape as u
1125
+ """
1126
+ assert past_key_value is None, "Not implemented yet!!!"
1127
+
1128
+ seqlen_og = seqlen
1129
+ if seqlen is None:
1130
+ batch, seqlen, dim = hidden_states.shape
1131
+ else:
1132
+ batch_seqlen, dim = hidden_states.shape
1133
+ batch = batch_seqlen // seqlen
1134
+
1135
+ conv_state, ssm_state = None, None
1136
+ if inference_params is not None:
1137
+ inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
1138
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
1139
+ if inference_params.seqlen_offset > 0:
1140
+ # The states are updated inplace
1141
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
1142
+ return out
1143
+
1144
+ zxbcdt = self.in_proj(hidden_states) # (B, L, d_in_proj) or (B * L, d_in_proj)
1145
+ if seqlen_og is not None:
1146
+ zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
1147
+
1148
+ dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
1149
+ if self.use_mem_eff_path and inference_params is None:
1150
+ # If the model is loaded in fp16, without the .float() here, A might be -inf
1151
+ A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
1152
+
1153
+ out = mamba_split_conv1d_scan_combined(
1154
+ zxbcdt,
1155
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
1156
+ self.conv1d.bias,
1157
+ self.dt_bias,
1158
+ A,
1159
+ D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
1160
+ chunk_size=self.chunk_size,
1161
+ seq_idx=seq_idx,
1162
+ activation=self.activation,
1163
+ rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
1164
+ rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
1165
+ outproj_weight=self.out_proj.weight,
1166
+ outproj_bias=self.out_proj.bias,
1167
+ headdim=None if self.D_has_hdim else self.headdim,
1168
+ ngroups=self.ngroups,
1169
+ norm_before_gate=self.norm_before_gate,
1170
+ **dt_limit_kwargs,
1171
+ )
1172
+ if seqlen_og is not None:
1173
+ out = rearrange(out, "b l d -> (b l) d")
1174
+ if self.process_group is not None:
1175
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
1176
+ out = reduce_fn(out, self.process_group)
1177
+ else:
1178
+ d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads * self.num_ssm_param) // 2
1179
+ z0, x0, z, xBC, dt = torch.split(
1180
+ zxbcdt,
1181
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads * self.num_ssm_param],
1182
+ dim=-1
1183
+ )
1184
+
1185
+ if conv_state is not None:
1186
+ if cu_seqlens is None:
1187
+ # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
1188
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
1189
+ xBC_t = rearrange(xBC, "b l d -> b d l")
1190
+ conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
1191
+ else:
1192
+ assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
1193
+ assert batch == 1, "varlen inference only supports batch dimension 1"
1194
+ conv_varlen_states = causal_conv1d_varlen_states(
1195
+ xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
1196
+ )
1197
+ conv_state.copy_(conv_varlen_states)
1198
+ assert self.activation in ["silu", "swish"]
1199
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
1200
+ assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
1201
+ xBC = self.act(
1202
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):]
1203
+ ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
1204
+ else:
1205
+ xBC = causal_conv1d_fn(
1206
+ xBC.transpose(1, 2),
1207
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
1208
+ bias=self.conv1d.bias,
1209
+ activation=self.activation,
1210
+ seq_idx=seq_idx,
1211
+ ).transpose(1, 2)
1212
+ x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
1213
+
1214
+ x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim)
1215
+ B = rearrange(B, "b l (g n) -> b l g n", g=self.ngroups)
1216
+ C = rearrange(C, "b l (g n) -> b l g n", g=self.ngroups)
1217
+
1218
+ outputs_list = []
1219
+ dt_list = dt
1220
+ for i in range(self.num_ssm_param):
1221
+ dt = dt_list[..., self.nheads*i:self.nheads*(i+1)]
1222
+ A = -torch.exp(self.A_log[i].float()) # (nheads) or (d_inner, d_state)
1223
+ D = rearrange(self.D[i], "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D[i]
1224
+ dt_bias = self.dt_bias[i]
1225
+
1226
+ if self.reuse_ssm:
1227
+ #### duplicate heads with different decays
1228
+ if self.mamba_multihead_config['alpha_mode'] == 'learnable':
1229
+ decay = self.alpha # [share_ratio]
1230
+ elif self.mamba_multihead_config['alpha_mode'] == 'manual':
1231
+ decay = torch.tensor(self.alpha).to(dt) # [share_ratio]
1232
+
1233
+ dt = dt.repeat(1, 1, self.share_ratio) # [bs, seq_len, self.nheads * share_ratio]
1234
+ decay = decay.view(-1, 1).repeat(1, self.nheads).view(-1) # [self.nheads * share_ratio]
1235
+ dt = dt * decay # [bs, seq_len, nheads * share_ratio]
1236
+
1237
+ dt_bias = dt_bias.repeat(self.share_ratio) * decay # [nheads * share_ratio]
1238
+
1239
+ x = x.repeat(1,1,self.share_ratio,1) # [bs, seq_len, nheads * share_ratio, head_dim]
1240
+ D = D.repeat(self.share_ratio,1) if self.D_has_hdim else D.repeat(self.share_ratio) # [nheads * share_ratio]
1241
+ A = A.repeat(self.share_ratio) # [nheads * share_ratio]
1242
+
1243
+ y = mamba_chunk_scan_combined(
1244
+ x,
1245
+ dt,
1246
+ A,
1247
+ B,
1248
+ C,
1249
+ chunk_size=self.chunk_size,
1250
+ D=D,
1251
+ z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim).repeat(1,1,self.share_ratio,1) if not self.rmsnorm else None,
1252
+ dt_bias=dt_bias,
1253
+ dt_softplus=True,
1254
+ seq_idx=seq_idx,
1255
+ cu_seqlens=cu_seqlens,
1256
+ **dt_limit_kwargs,
1257
+ return_final_states=ssm_state is not None,
1258
+ return_varlen_states=cu_seqlens is not None and inference_params is not None,
1259
+ )
1260
+ if ssm_state is not None:
1261
+ y, last_state, *rest = y
1262
+ if cu_seqlens is None:
1263
+ ssm_state.copy_(last_state)
1264
+ else:
1265
+ varlen_states = rest[0]
1266
+ ssm_state.copy_(varlen_states)
1267
+
1268
+ outputs_list.append(y)
1269
+
1270
+ if len(outputs_list) > 1:
1271
+ y = torch.cat(outputs_list, dim=2)
1272
+
1273
+ #### merge heads
1274
+ num_repeat = y.shape[2] // self.nheads
1275
+ head_outputs = torch.chunk(y, num_repeat, dim=2)
1276
+ head_outputs = [rearrange(item, "b l h p -> b l (h p)") for item in head_outputs]
1277
+
1278
+ if self.mamba_multihead_config['merge_op'] == 'norm':
1279
+ y = sum([self.multihead_layernorm[k](item) for k, item in enumerate(head_outputs)])
1280
+
1281
+ elif self.mamba_multihead_config['merge_op'] == 'concat':
1282
+ head_outputs = [self.reduction_layer(self.multihead_layernorm[k](item)) for k, item in enumerate(head_outputs)]
1283
+ y = torch.cat(head_outputs, dim=-1)
1284
+ else:
1285
+ raise ValueError(f"No such merge_op: {self.mamba_multihead_config['merge_op']}")
1286
+
1287
+ if self.rmsnorm:
1288
+ y = self.norm(y, z)
1289
+ if d_mlp > 0:
1290
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
1291
+ if seqlen_og is not None:
1292
+ y = rearrange(y, "b l d -> (b l) d")
1293
+ out = self.out_proj(y)
1294
+ return out, past_key_value
1295
+
1296
+ def step(self, hidden_states, conv_state, ssm_state):
1297
+ dtype = hidden_states.dtype
1298
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
1299
+ zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
1300
+ d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
1301
+ z0, x0, z, xBC, dt = torch.split(
1302
+ zxbcdt,
1303
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
1304
+ dim=-1
1305
+ )
1306
+
1307
+ # Conv step
1308
+ if causal_conv1d_update is None:
1309
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
1310
+ conv_state[:, :, -1] = xBC
1311
+ xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
1312
+ if self.conv1d.bias is not None:
1313
+ xBC = xBC + self.conv1d.bias
1314
+ xBC = self.act(xBC).to(dtype=dtype)
1315
+ else:
1316
+ xBC = causal_conv1d_update(
1317
+ xBC,
1318
+ conv_state,
1319
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
1320
+ self.conv1d.bias,
1321
+ self.activation,
1322
+ )
1323
+
1324
+ x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
1325
+ A = -torch.exp(self.A_log.float()) # (nheads,)
1326
+
1327
+ # SSM step
1328
+ if selective_state_update is None:
1329
+ assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
1330
+ # Discretize A and B
1331
+ dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
1332
+ dA = torch.exp(dt * A) # (batch, nheads)
1333
+ x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
1334
+ dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
1335
+ ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
1336
+ y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
1337
+ y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
1338
+ y = rearrange(y, "b h p -> b (h p)")
1339
+ if not self.rmsnorm:
1340
+ y = y * self.act(z) # (B D)
1341
+ else:
1342
+ A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
1343
+ dt = repeat(dt, "b h -> b h p", p=self.headdim)
1344
+ dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
1345
+ D = repeat(self.D, "h -> h p", p=self.headdim)
1346
+ B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
1347
+ C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
1348
+ x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
1349
+ if not self.rmsnorm:
1350
+ z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
1351
+ y = selective_state_update(
1352
+ ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
1353
+ dt_bias=dt_bias, dt_softplus=True
1354
+ )
1355
+ y = rearrange(y, "b h p -> b (h p)")
1356
+ if self.rmsnorm:
1357
+ y = self.norm(y, z)
1358
+ if d_mlp > 0:
1359
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
1360
+ out = self.out_proj(y)
1361
+ return out.unsqueeze(1), conv_state, ssm_state
1362
+
1363
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
1364
+ device = self.out_proj.weight.device
1365
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
1366
+ conv_state = torch.zeros(
1367
+ batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
1368
+ ).transpose(1, 2)
1369
+ ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
1370
+ ssm_state = torch.zeros(
1371
+ batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
1372
+ )
1373
+ return conv_state, ssm_state
1374
+
1375
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
1376
+ assert self.layer_idx is not None
1377
+ if self.layer_idx not in inference_params.key_value_memory_dict:
1378
+ batch_shape = (batch_size,)
1379
+ conv_state = torch.zeros(
1380
+ batch_size,
1381
+ self.d_conv,
1382
+ self.conv1d.weight.shape[0],
1383
+ device=self.conv1d.weight.device,
1384
+ dtype=self.conv1d.weight.dtype,
1385
+ ).transpose(1, 2)
1386
+ ssm_state = torch.zeros(
1387
+ batch_size,
1388
+ self.nheads,
1389
+ self.headdim,
1390
+ self.d_state,
1391
+ device=self.in_proj.weight.device,
1392
+ dtype=self.in_proj.weight.dtype,
1393
+ )
1394
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
1395
+ else:
1396
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
1397
+ # TODO: What if batch size changes between generation, and we reuse the same states?
1398
+ if initialize_states:
1399
+ conv_state.zero_()
1400
+ ssm_state.zero_()
1401
+ return conv_state, ssm_state
1402
+
1403
+
1404
+
1405
+
1406
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Jamba
1407
+ class JambaRMSNorm(nn.Module):
1408
+ def __init__(self, hidden_size, eps=1e-6):
1409
+ """
1410
+ JambaRMSNorm is equivalent to T5LayerNorm
1411
+ """
1412
+ super().__init__()
1413
+ self.weight = nn.Parameter(torch.ones(hidden_size))
1414
+ self.variance_epsilon = eps
1415
+
1416
+ def forward(self, hidden_states):
1417
+ input_dtype = hidden_states.dtype
1418
+ hidden_states = hidden_states.to(torch.float32)
1419
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
1420
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
1421
+ return self.weight * hidden_states.to(input_dtype)
1422
+
1423
+
1424
+
1425
+
1426
+
1427
+
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8fc337e437ea93d1d9a6c5d93793512da4b571843edfa616516630418244947b
3
+ size 4995785984
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46319b16d9db944033225295486a090859d5113e4873eca20c0d836fa38a3f09
3
+ size 491849664
model.safetensors.index.json ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 5487609216
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
7
+ "model.final_layernorm.weight": "model-00002-of-00002.safetensors",
8
+ "model.layers.0.gla.b_proj.weight": "model-00001-of-00002.safetensors",
9
+ "model.layers.0.gla.k_conv1d.weight": "model-00001-of-00002.safetensors",
10
+ "model.layers.0.gla.k_proj.weight": "model-00001-of-00002.safetensors",
11
+ "model.layers.0.gla.o_norm.weight": "model-00001-of-00002.safetensors",
12
+ "model.layers.0.gla.o_proj.weight": "model-00001-of-00002.safetensors",
13
+ "model.layers.0.gla.q_conv1d.weight": "model-00001-of-00002.safetensors",
14
+ "model.layers.0.gla.q_proj.weight": "model-00001-of-00002.safetensors",
15
+ "model.layers.0.gla.v_conv1d.weight": "model-00001-of-00002.safetensors",
16
+ "model.layers.0.gla.v_proj.weight": "model-00001-of-00002.safetensors",
17
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
18
+ "model.layers.1.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
19
+ "model.layers.1.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
20
+ "model.layers.1.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
21
+ "model.layers.1.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
22
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
23
+ "model.layers.10.mamba.A_log": "model-00001-of-00002.safetensors",
24
+ "model.layers.10.mamba.D": "model-00001-of-00002.safetensors",
25
+ "model.layers.10.mamba.conv1d.bias": "model-00001-of-00002.safetensors",
26
+ "model.layers.10.mamba.conv1d.weight": "model-00001-of-00002.safetensors",
27
+ "model.layers.10.mamba.dt_bias": "model-00001-of-00002.safetensors",
28
+ "model.layers.10.mamba.in_proj.weight": "model-00001-of-00002.safetensors",
29
+ "model.layers.10.mamba.norm.weight": "model-00001-of-00002.safetensors",
30
+ "model.layers.10.mamba.out_proj.weight": "model-00001-of-00002.safetensors",
31
+ "model.layers.11.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
32
+ "model.layers.11.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
33
+ "model.layers.11.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
34
+ "model.layers.11.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
35
+ "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
36
+ "model.layers.12.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
37
+ "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
38
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
39
+ "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
40
+ "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
41
+ "model.layers.13.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
42
+ "model.layers.13.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
43
+ "model.layers.13.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
44
+ "model.layers.13.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
45
+ "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
46
+ "model.layers.14.mamba.A_log": "model-00001-of-00002.safetensors",
47
+ "model.layers.14.mamba.D": "model-00001-of-00002.safetensors",
48
+ "model.layers.14.mamba.conv1d.bias": "model-00001-of-00002.safetensors",
49
+ "model.layers.14.mamba.conv1d.weight": "model-00001-of-00002.safetensors",
50
+ "model.layers.14.mamba.dt_bias": "model-00001-of-00002.safetensors",
51
+ "model.layers.14.mamba.in_proj.weight": "model-00001-of-00002.safetensors",
52
+ "model.layers.14.mamba.norm.weight": "model-00001-of-00002.safetensors",
53
+ "model.layers.14.mamba.out_proj.weight": "model-00001-of-00002.safetensors",
54
+ "model.layers.15.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
55
+ "model.layers.15.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
56
+ "model.layers.15.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
57
+ "model.layers.15.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
58
+ "model.layers.16.gla.b_proj.weight": "model-00001-of-00002.safetensors",
59
+ "model.layers.16.gla.k_conv1d.weight": "model-00001-of-00002.safetensors",
60
+ "model.layers.16.gla.k_proj.weight": "model-00001-of-00002.safetensors",
61
+ "model.layers.16.gla.o_norm.weight": "model-00001-of-00002.safetensors",
62
+ "model.layers.16.gla.o_proj.weight": "model-00001-of-00002.safetensors",
63
+ "model.layers.16.gla.q_conv1d.weight": "model-00001-of-00002.safetensors",
64
+ "model.layers.16.gla.q_proj.weight": "model-00001-of-00002.safetensors",
65
+ "model.layers.16.gla.v_conv1d.weight": "model-00001-of-00002.safetensors",
66
+ "model.layers.16.gla.v_proj.weight": "model-00001-of-00002.safetensors",
67
+ "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
68
+ "model.layers.17.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
69
+ "model.layers.17.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
70
+ "model.layers.17.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
71
+ "model.layers.17.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
72
+ "model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
73
+ "model.layers.18.mamba.A_log": "model-00001-of-00002.safetensors",
74
+ "model.layers.18.mamba.D": "model-00001-of-00002.safetensors",
75
+ "model.layers.18.mamba.conv1d.bias": "model-00001-of-00002.safetensors",
76
+ "model.layers.18.mamba.conv1d.weight": "model-00001-of-00002.safetensors",
77
+ "model.layers.18.mamba.dt_bias": "model-00001-of-00002.safetensors",
78
+ "model.layers.18.mamba.in_proj.weight": "model-00001-of-00002.safetensors",
79
+ "model.layers.18.mamba.norm.weight": "model-00001-of-00002.safetensors",
80
+ "model.layers.18.mamba.out_proj.weight": "model-00001-of-00002.safetensors",
81
+ "model.layers.19.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
82
+ "model.layers.19.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
83
+ "model.layers.19.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
84
+ "model.layers.19.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
85
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
86
+ "model.layers.2.mamba.A_log": "model-00001-of-00002.safetensors",
87
+ "model.layers.2.mamba.D": "model-00001-of-00002.safetensors",
88
+ "model.layers.2.mamba.conv1d.bias": "model-00001-of-00002.safetensors",
89
+ "model.layers.2.mamba.conv1d.weight": "model-00001-of-00002.safetensors",
90
+ "model.layers.2.mamba.dt_bias": "model-00001-of-00002.safetensors",
91
+ "model.layers.2.mamba.in_proj.weight": "model-00001-of-00002.safetensors",
92
+ "model.layers.2.mamba.norm.weight": "model-00001-of-00002.safetensors",
93
+ "model.layers.2.mamba.out_proj.weight": "model-00001-of-00002.safetensors",
94
+ "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors",
95
+ "model.layers.20.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
96
+ "model.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
97
+ "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
98
+ "model.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
99
+ "model.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
100
+ "model.layers.21.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
101
+ "model.layers.21.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
102
+ "model.layers.21.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
103
+ "model.layers.21.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
104
+ "model.layers.22.input_layernorm.weight": "model-00001-of-00002.safetensors",
105
+ "model.layers.22.mamba.A_log": "model-00001-of-00002.safetensors",
106
+ "model.layers.22.mamba.D": "model-00001-of-00002.safetensors",
107
+ "model.layers.22.mamba.conv1d.bias": "model-00001-of-00002.safetensors",
108
+ "model.layers.22.mamba.conv1d.weight": "model-00001-of-00002.safetensors",
109
+ "model.layers.22.mamba.dt_bias": "model-00001-of-00002.safetensors",
110
+ "model.layers.22.mamba.in_proj.weight": "model-00001-of-00002.safetensors",
111
+ "model.layers.22.mamba.norm.weight": "model-00001-of-00002.safetensors",
112
+ "model.layers.22.mamba.out_proj.weight": "model-00001-of-00002.safetensors",
113
+ "model.layers.23.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
114
+ "model.layers.23.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
115
+ "model.layers.23.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
116
+ "model.layers.23.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
117
+ "model.layers.24.gla.b_proj.weight": "model-00001-of-00002.safetensors",
118
+ "model.layers.24.gla.k_conv1d.weight": "model-00001-of-00002.safetensors",
119
+ "model.layers.24.gla.k_proj.weight": "model-00001-of-00002.safetensors",
120
+ "model.layers.24.gla.o_norm.weight": "model-00001-of-00002.safetensors",
121
+ "model.layers.24.gla.o_proj.weight": "model-00001-of-00002.safetensors",
122
+ "model.layers.24.gla.q_conv1d.weight": "model-00001-of-00002.safetensors",
123
+ "model.layers.24.gla.q_proj.weight": "model-00001-of-00002.safetensors",
124
+ "model.layers.24.gla.v_conv1d.weight": "model-00001-of-00002.safetensors",
125
+ "model.layers.24.gla.v_proj.weight": "model-00001-of-00002.safetensors",
126
+ "model.layers.24.input_layernorm.weight": "model-00001-of-00002.safetensors",
127
+ "model.layers.25.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
128
+ "model.layers.25.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
129
+ "model.layers.25.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
130
+ "model.layers.25.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
131
+ "model.layers.26.input_layernorm.weight": "model-00001-of-00002.safetensors",
132
+ "model.layers.26.mamba.A_log": "model-00001-of-00002.safetensors",
133
+ "model.layers.26.mamba.D": "model-00001-of-00002.safetensors",
134
+ "model.layers.26.mamba.conv1d.bias": "model-00001-of-00002.safetensors",
135
+ "model.layers.26.mamba.conv1d.weight": "model-00001-of-00002.safetensors",
136
+ "model.layers.26.mamba.dt_bias": "model-00001-of-00002.safetensors",
137
+ "model.layers.26.mamba.in_proj.weight": "model-00001-of-00002.safetensors",
138
+ "model.layers.26.mamba.norm.weight": "model-00001-of-00002.safetensors",
139
+ "model.layers.26.mamba.out_proj.weight": "model-00001-of-00002.safetensors",
140
+ "model.layers.27.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
141
+ "model.layers.27.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
142
+ "model.layers.27.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
143
+ "model.layers.27.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
144
+ "model.layers.28.gla.b_proj.weight": "model-00001-of-00002.safetensors",
145
+ "model.layers.28.gla.k_conv1d.weight": "model-00001-of-00002.safetensors",
146
+ "model.layers.28.gla.k_proj.weight": "model-00001-of-00002.safetensors",
147
+ "model.layers.28.gla.o_norm.weight": "model-00001-of-00002.safetensors",
148
+ "model.layers.28.gla.o_proj.weight": "model-00001-of-00002.safetensors",
149
+ "model.layers.28.gla.q_conv1d.weight": "model-00001-of-00002.safetensors",
150
+ "model.layers.28.gla.q_proj.weight": "model-00001-of-00002.safetensors",
151
+ "model.layers.28.gla.v_conv1d.weight": "model-00001-of-00002.safetensors",
152
+ "model.layers.28.gla.v_proj.weight": "model-00001-of-00002.safetensors",
153
+ "model.layers.28.input_layernorm.weight": "model-00001-of-00002.safetensors",
154
+ "model.layers.29.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
155
+ "model.layers.29.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
156
+ "model.layers.29.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
157
+ "model.layers.29.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
158
+ "model.layers.3.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
159
+ "model.layers.3.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
160
+ "model.layers.3.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
161
+ "model.layers.3.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
162
+ "model.layers.30.input_layernorm.weight": "model-00001-of-00002.safetensors",
163
+ "model.layers.30.mamba.A_log": "model-00001-of-00002.safetensors",
164
+ "model.layers.30.mamba.D": "model-00001-of-00002.safetensors",
165
+ "model.layers.30.mamba.conv1d.bias": "model-00001-of-00002.safetensors",
166
+ "model.layers.30.mamba.conv1d.weight": "model-00001-of-00002.safetensors",
167
+ "model.layers.30.mamba.dt_bias": "model-00001-of-00002.safetensors",
168
+ "model.layers.30.mamba.in_proj.weight": "model-00001-of-00002.safetensors",
169
+ "model.layers.30.mamba.norm.weight": "model-00001-of-00002.safetensors",
170
+ "model.layers.30.mamba.out_proj.weight": "model-00001-of-00002.safetensors",
171
+ "model.layers.31.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
172
+ "model.layers.31.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
173
+ "model.layers.31.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
174
+ "model.layers.31.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
175
+ "model.layers.32.gla.b_proj.weight": "model-00002-of-00002.safetensors",
176
+ "model.layers.32.gla.k_conv1d.weight": "model-00002-of-00002.safetensors",
177
+ "model.layers.32.gla.k_proj.weight": "model-00001-of-00002.safetensors",
178
+ "model.layers.32.gla.o_norm.weight": "model-00002-of-00002.safetensors",
179
+ "model.layers.32.gla.o_proj.weight": "model-00002-of-00002.safetensors",
180
+ "model.layers.32.gla.q_conv1d.weight": "model-00002-of-00002.safetensors",
181
+ "model.layers.32.gla.q_proj.weight": "model-00001-of-00002.safetensors",
182
+ "model.layers.32.gla.v_conv1d.weight": "model-00002-of-00002.safetensors",
183
+ "model.layers.32.gla.v_proj.weight": "model-00002-of-00002.safetensors",
184
+ "model.layers.32.input_layernorm.weight": "model-00002-of-00002.safetensors",
185
+ "model.layers.33.moe.experts.0.down_proj.weight": "model-00002-of-00002.safetensors",
186
+ "model.layers.33.moe.experts.0.gate_proj.weight": "model-00002-of-00002.safetensors",
187
+ "model.layers.33.moe.experts.0.up_proj.weight": "model-00002-of-00002.safetensors",
188
+ "model.layers.33.pre_moe_layernorm.weight": "model-00002-of-00002.safetensors",
189
+ "model.layers.34.input_layernorm.weight": "model-00002-of-00002.safetensors",
190
+ "model.layers.34.mamba.A_log": "model-00002-of-00002.safetensors",
191
+ "model.layers.34.mamba.D": "model-00002-of-00002.safetensors",
192
+ "model.layers.34.mamba.conv1d.bias": "model-00002-of-00002.safetensors",
193
+ "model.layers.34.mamba.conv1d.weight": "model-00002-of-00002.safetensors",
194
+ "model.layers.34.mamba.dt_bias": "model-00002-of-00002.safetensors",
195
+ "model.layers.34.mamba.in_proj.weight": "model-00002-of-00002.safetensors",
196
+ "model.layers.34.mamba.norm.weight": "model-00002-of-00002.safetensors",
197
+ "model.layers.34.mamba.out_proj.weight": "model-00002-of-00002.safetensors",
198
+ "model.layers.35.moe.experts.0.down_proj.weight": "model-00002-of-00002.safetensors",
199
+ "model.layers.35.moe.experts.0.gate_proj.weight": "model-00002-of-00002.safetensors",
200
+ "model.layers.35.moe.experts.0.up_proj.weight": "model-00002-of-00002.safetensors",
201
+ "model.layers.35.pre_moe_layernorm.weight": "model-00002-of-00002.safetensors",
202
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
203
+ "model.layers.4.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
204
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
205
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
206
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
207
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
208
+ "model.layers.5.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
209
+ "model.layers.5.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
210
+ "model.layers.5.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
211
+ "model.layers.5.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
212
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
213
+ "model.layers.6.mamba.A_log": "model-00001-of-00002.safetensors",
214
+ "model.layers.6.mamba.D": "model-00001-of-00002.safetensors",
215
+ "model.layers.6.mamba.conv1d.bias": "model-00001-of-00002.safetensors",
216
+ "model.layers.6.mamba.conv1d.weight": "model-00001-of-00002.safetensors",
217
+ "model.layers.6.mamba.dt_bias": "model-00001-of-00002.safetensors",
218
+ "model.layers.6.mamba.in_proj.weight": "model-00001-of-00002.safetensors",
219
+ "model.layers.6.mamba.norm.weight": "model-00001-of-00002.safetensors",
220
+ "model.layers.6.mamba.out_proj.weight": "model-00001-of-00002.safetensors",
221
+ "model.layers.7.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
222
+ "model.layers.7.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
223
+ "model.layers.7.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
224
+ "model.layers.7.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
225
+ "model.layers.8.gla.b_proj.weight": "model-00001-of-00002.safetensors",
226
+ "model.layers.8.gla.k_conv1d.weight": "model-00001-of-00002.safetensors",
227
+ "model.layers.8.gla.k_proj.weight": "model-00001-of-00002.safetensors",
228
+ "model.layers.8.gla.o_norm.weight": "model-00001-of-00002.safetensors",
229
+ "model.layers.8.gla.o_proj.weight": "model-00001-of-00002.safetensors",
230
+ "model.layers.8.gla.q_conv1d.weight": "model-00001-of-00002.safetensors",
231
+ "model.layers.8.gla.q_proj.weight": "model-00001-of-00002.safetensors",
232
+ "model.layers.8.gla.v_conv1d.weight": "model-00001-of-00002.safetensors",
233
+ "model.layers.8.gla.v_proj.weight": "model-00001-of-00002.safetensors",
234
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
235
+ "model.layers.9.moe.experts.0.down_proj.weight": "model-00001-of-00002.safetensors",
236
+ "model.layers.9.moe.experts.0.gate_proj.weight": "model-00001-of-00002.safetensors",
237
+ "model.layers.9.moe.experts.0.up_proj.weight": "model-00001-of-00002.safetensors",
238
+ "model.layers.9.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
239
+ "model.memory_tokens": "model-00001-of-00002.safetensors"
240
+ }
241
+ }
modeling_jamba.py ADDED
The diff for this file is too large to render. See raw diff