zaydzuhri commited on
Commit
f04fa4f
·
verified ·
1 Parent(s): 016ba03

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla/models/abc/__pycache__/modeling_abc.cpython-312.pyc +0 -0
  2. fla/models/bitnet/__pycache__/__init__.cpython-312.pyc +0 -0
  3. fla/models/bitnet/__pycache__/modeling_bitnet.cpython-312.pyc +0 -0
  4. fla/models/bitnet/modeling_bitnet.py +441 -0
  5. fla/models/delta_net/__pycache__/__init__.cpython-312.pyc +0 -0
  6. fla/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc +0 -0
  7. fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc +0 -0
  8. fla/models/forgetting_transformer/configuration_forgetting_transformer.py +68 -0
  9. fla/models/gated_deltanet/__pycache__/modeling_gated_deltanet.cpython-312.pyc +0 -0
  10. fla/models/gated_deltanet/configuration_gated_deltanet.py +83 -0
  11. fla/models/gated_deltanet/modeling_gated_deltanet.py +412 -0
  12. fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-312.pyc +0 -0
  13. fla/models/gated_deltaproduct/__pycache__/modeling_gated_deltaproduct.cpython-312.pyc +0 -0
  14. fla/models/gsa/__init__.py +13 -0
  15. fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc +0 -0
  16. fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc +0 -0
  17. fla/models/hgrn/__pycache__/__init__.cpython-312.pyc +0 -0
  18. fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc +0 -0
  19. fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc +0 -0
  20. fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc +0 -0
  21. fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc +0 -0
  22. fla/models/lightnet/modeling_lightnet.py +410 -0
  23. fla/models/linear_attn/configuration_linear_attn.py +91 -0
  24. fla/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc +0 -0
  25. fla/models/mamba2/__pycache__/modeling_mamba2.cpython-312.pyc +0 -0
  26. fla/models/mamba2/modeling_mamba2.py +1093 -0
  27. fla/models/nsa/__pycache__/__init__.cpython-312.pyc +0 -0
  28. fla/models/nsa/__pycache__/configuration_nsa.cpython-312.pyc +0 -0
  29. fla/models/nsa/configuration_nsa.py +75 -0
  30. fla/models/retnet/__pycache__/__init__.cpython-312.pyc +0 -0
  31. fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc +0 -0
  32. fla/models/rwkv6/__pycache__/modeling_rwkv6.cpython-312.pyc +0 -0
  33. fla/models/rwkv6/modeling_rwkv6.py +480 -0
  34. fla/models/rwkv7/__pycache__/modeling_rwkv7.cpython-312.pyc +0 -0
  35. fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc +0 -0
  36. fla/models/samba/modeling_samba.py +413 -0
  37. fla/models/transformer/__init__.py +13 -0
  38. fla/models/transformer/__pycache__/__init__.cpython-312.pyc +0 -0
  39. fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  40. fla/models/transformer/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
  41. fla/models/transformer_mtp/__pycache__/__init__.cpython-312.pyc +0 -0
  42. fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  43. fla/models/transformer_top/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
  44. fla/modules/__pycache__/activations.cpython-312.pyc +0 -0
  45. fla/modules/__pycache__/convolution.cpython-312.pyc +0 -0
  46. fla/modules/__pycache__/feature_map.cpython-312.pyc +0 -0
  47. fla/modules/__pycache__/fused_bitlinear.cpython-312.pyc +0 -0
  48. fla/modules/__pycache__/fused_kl_div.cpython-312.pyc +0 -0
  49. fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc +0 -0
  50. fla/modules/__pycache__/fused_linear_listnet_loss.cpython-312.pyc +0 -0
fla/models/abc/__pycache__/modeling_abc.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
fla/models/bitnet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (682 Bytes). View file
 
fla/models/bitnet/__pycache__/modeling_bitnet.cpython-312.pyc ADDED
Binary file (18.6 kB). View file
 
fla/models/bitnet/modeling_bitnet.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.bitattn import BitAttention
19
+ from fla.models.bitnet.configuration_bitnet import BitNetConfig
20
+ from fla.models.utils import Cache
21
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
22
+ from fla.modules.activations import swiglu
23
+ from fla.modules.fused_bitlinear import FusedBitLinear
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
class BitNetMLP(nn.Module):
    """Gated (SwiGLU) feed-forward network used by BitNet decoder blocks."""

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'swish',
        fuse_swiglu: bool = True
    ) -> None:
        """
        Args:
            hidden_size: model embedding dimension.
            hidden_ratio: expansion ratio; defaults to 4 when unset.
            intermediate_size: explicit inner dimension; when unset it is
                derived from ``hidden_ratio`` (see below).
            hidden_act: activation name; only ``'swish'`` is supported.
            fuse_swiglu: whether ``forward`` uses the fused swiglu kernel.

        Raises:
            ValueError: if ``hidden_act`` is not ``'swish'``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        # the final number of params is `hidden_ratio * hidden_size^2`
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
            # Round up to the next multiple of 256.
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.fuse_swiglu = fuse_swiglu

        if hidden_act != 'swish':
            raise ValueError(f'Unsupported hidden_act: {hidden_act}')

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        **kwargs: Unpack[Any]
    ) -> torch.Tensor:
        """Apply the SwiGLU MLP: ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""
        gate, y = self.gate_proj(x), self.up_proj(x)
        # Fix: `fuse_swiglu` was previously stored but ignored — the fused
        # kernel was always used. Honor the flag and fall back to the
        # numerically-equivalent unfused form when it is disabled.
        if self.fuse_swiglu:
            return self.down_proj(swiglu(gate, y))
        return self.down_proj(nn.functional.silu(gate) * y)
70
+
71
+
72
class BitNetBlock(nn.Module):
    """One BitNet decoder layer: pre-norm BitAttention followed by a pre-norm
    MLP, each wrapped in a residual connection.

    When ``config.fuse_norm`` is set, the residual addition is folded into the
    fused RMSNorm call (see ``forward``).
    """

    def __init__(self, config: BitNetConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Pick the fused fla RMSNorm or the stock PyTorch one, per config.
        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.attn = BitAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            window_size=config.window_size,
            rope_theta=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            layer_idx=layer_idx
        )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = BitNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs: Unpack[Any]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one attention + MLP layer.

        Returns a tuple of ``(hidden_states,)``, extended with ``attentions``
        if ``output_attentions`` and ``past_key_values`` if ``use_cache`` —
        mirroring the index logic callers use to unpack it.
        """

        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            # NOTE(review): the fused RMSNorm is called with (x, residual, True)
            # and unpacked into (normed, new_residual) — presumably it adds the
            # residual and normalizes in a single kernel; confirm against fla's
            # RMSNorm prenorm interface.
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            # Unfused path: explicit residual add, then re-normalize.
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attentions,)

        if use_cache:
            outputs += (past_key_values,)

        return outputs
137
+
138
+
139
class BitNetPreTrainedModel(PreTrainedModel):
    """Base class wiring BitNet models into the transformers framework:
    config class, weight initialization, and feature flags."""

    config_class = BitNetConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['BitNetBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = False,
        num_residuals_per_layer: int = 2,
    ):
        """Initialize one submodule's weights.

        Linear/Conv1d/FusedBitLinear weights and embeddings are drawn from
        N(0, initializer_range); biases are zeroed; anything else that exposes
        ``reset_parameters`` is delegated to it. When
        ``rescale_prenorm_residual`` is set, residual-output projections are
        additionally down-scaled (GPT-2 scheme, see below).
        """
        if isinstance(module, (nn.Linear, nn.Conv1d, FusedBitLinear)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
187
+
188
+
189
class BitNetModel(BitNetPreTrainedModel):
    """The bare BitNet decoder stack: token embeddings, ``num_hidden_layers``
    BitNetBlocks, and a final norm — no LM head."""

    def __init__(
        self,
        config: BitNetConfig
    ) -> BitNetModel:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([BitNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder stack and return hidden states (plus optional
        cache / per-layer hidden states), as a tuple or
        ``BaseModelOutputWithPast`` depending on ``return_dict``.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be provided.
        """
        # Attention-weight output is unsupported: force it off (with a warning
        # if the caller explicitly asked for it).
        # NOTE(review): if the caller passes None but the config enables
        # output_attentions, the fallback on the next line re-enables it —
        # verify whether that is intended.
        if output_attentions:
            warnings.warn(
                "`BitNetModel` does not support output attention weights now, so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        # Caching is disabled during training regardless of config.
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Wrap a legacy (list-of-tensors) cache into the fla Cache class.
        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        # embed positions
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        next_cache = None

        for layer in self.layers:
            # Collect the input of each layer (pre-layer hidden states).
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    **kwargs
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    **kwargs
                )

            hidden_states = layer_outputs[0]

            # Tuple layout from BitNetBlock: (hidden_states, [attentions],
            # [past_key_values]) — hence the index juggling below.
            if use_cache:
                next_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
308
+
309
+
310
class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
    """BitNet decoder with a tied-weight language-modeling head and the
    generation plumbing required by ``GenerationMixin``."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = BitNetModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Loss is constructed lazily in `forward` (choice depends on config
        # and training mode).
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        """Assemble the model inputs for one generation step."""
        # only last token for `input_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        # Fix: previously this evaluated `len(past_key_values)` unguarded and
        # raised TypeError when `past_key_values` was None.
        if inputs_embeds is not None and (past_key_values is None or len(past_key_values) == 0):
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone, project to vocabulary logits, and (when `labels`
        are given) compute a shifted next-token cross-entropy loss.

        When `fuse_cross_entropy` is enabled and the model is training, logits
        are never materialized: the fused linear+CE kernel consumes hidden
        states and the lm_head weights directly.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            # `logits_to_keep == 0` slices `[:, -0:]`, i.e. the full sequence.
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            # Shift labels left by one and pad the tail with ignore_index so
            # position t predicts token t+1.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/delta_net/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (701 Bytes). View file
 
fla/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc ADDED
Binary file (3.59 kB). View file
 
fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc ADDED
Binary file (2.5 kB). View file
 
fla/models/forgetting_transformer/configuration_forgetting_transformer.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class ForgettingTransformerConfig(PretrainedConfig):
    """Configuration class for Forgetting Transformer models.

    Records the model hyper-parameters on the instance and forwards the
    special-token / tying options to ``PretrainedConfig``.
    """

    model_type = 'forgetting_transformer'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        use_output_gate: bool = False,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.006,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs,
    ):
        # Store every model hyper-parameter on the instance, preserving the
        # original attribute order.
        for attr_name, attr_value in (
            ('hidden_size', hidden_size),
            ('num_hidden_layers', num_hidden_layers),
            ('num_heads', num_heads),
            ('num_kv_heads', num_kv_heads),
            ('qkv_bias', qkv_bias),
            ('qk_norm', qk_norm),
            ('window_size', window_size),
            ('use_output_gate', use_output_gate),
            ('hidden_ratio', hidden_ratio),
            ('intermediate_size', intermediate_size),
            ('hidden_act', hidden_act),
            ('initializer_range', initializer_range),
            ('elementwise_affine', elementwise_affine),
            ('norm_eps', norm_eps),
            ('use_cache', use_cache),
            ('fuse_norm', fuse_norm),
            ('fuse_swiglu', fuse_swiglu),
            ('fuse_cross_entropy', fuse_cross_entropy),
            ('vocab_size', vocab_size),
        ):
            setattr(self, attr_name, attr_value)

        # Token ids and embedding tying are handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/gated_deltanet/__pycache__/modeling_gated_deltanet.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/models/gated_deltanet/configuration_gated_deltanet.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class GatedDeltaNetConfig(PretrainedConfig):
    """Configuration class for GatedDeltaNet models.

    Besides the usual hyper-parameters, ``attn`` optionally describes a set of
    hybrid softmax-attention layers as a dict with keys ``layers`` (layer
    indices) and ``num_heads`` (required), plus optional ``num_kv_heads``,
    ``qkv_bias``, ``window_size`` and ``rope_theta`` which are filled with
    defaults in place.
    """

    model_type = 'gated_deltanet'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        expand_v: int = 2,
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        head_dim: int = 256,
        num_heads: int = 6,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        num_hidden_layers: int = 21,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        # Fix: was annotated `int = None`; `None` is a legal (and default) value.
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.expand_v = expand_v
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_hidden_layers = num_hidden_layers
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            # Fix: check against the builtin `dict`, not the typing alias.
            if not isinstance(attn, dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            # Fill optional keys with defaults. Note this intentionally
            # mutates the dict (also referenced by `self.attn`) in place.
            attn.setdefault('num_kv_heads', attn['num_heads'])
            attn.setdefault('qkv_bias', False)
            attn.setdefault('window_size', None)
            attn.setdefault('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/gated_deltanet/modeling_gated_deltanet.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gated_deltanet import GatedDeltaNet
20
+ from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GatedDeltaNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
class GatedDeltaNetBlock(nn.Module):
    """One GatedDeltaNet decoder layer: pre-norm token mixer (softmax
    Attention on layers listed in ``config.attn['layers']``, GatedDeltaNet
    otherwise) followed by a pre-norm gated MLP, each with a residual."""

    def __init__(self, config: GatedDeltaNetConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Pick the fused fla RMSNorm or the stock PyTorch one, per config.
        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        # Hybrid architecture: selected layer indices use standard softmax
        # attention instead of the GatedDeltaNet mixer.
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = GatedDeltaNet(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_v=config.expand_v,
                head_dim=config.head_dim,
                num_heads=config.num_heads,
                use_gate=config.use_gate,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                norm_eps=config.norm_eps,
                layer_idx=layer_idx
            )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = GatedDeltaNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one mixer + MLP layer.

        Always returns a fixed 3-tuple ``(hidden_states, attentions,
        past_key_values)`` (unlike BitNetBlock's variable-length tuple).
        """
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            # NOTE(review): the fused RMSNorm is called with (x, residual, True)
            # and unpacked into (normed, new_residual) — presumably it adds the
            # residual and normalizes in a single kernel; confirm against fla's
            # RMSNorm prenorm interface.
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            # Unfused path: explicit residual add, then re-normalize.
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs
105
+
106
+
107
+ class GatedDeltaNetPreTrainedModel(PreTrainedModel):
108
+
109
+ config_class = GatedDeltaNetConfig
110
+ base_model_prefix = 'model'
111
+ supports_gradient_checkpointing = True
112
+ _no_split_modules = ['GatedDeltaNetBlock']
113
+ _supports_cache_class = True
114
+
115
+ def __init__(self, *inputs, **kwargs):
116
+ super().__init__(*inputs, **kwargs)
117
+
118
+ def _init_weights(
119
+ self,
120
+ module: nn.Module,
121
+ prenorm_residual_strategy: Optional[str] = 'rescale',
122
+ num_residuals_per_layer: int = 2,
123
+ ):
124
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
125
+ # Slightly different from the TF version which uses truncated_normal for initialization
126
+ # cf https://github.com/pytorch/pytorch/pull/5617
127
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
128
+ if module.bias is not None:
129
+ nn.init.zeros_(module.bias)
130
+ elif isinstance(module, nn.Embedding):
131
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
132
+ elif hasattr(module, 'reset_parameters'):
133
+ module.reset_parameters()
134
+
135
+ if prenorm_residual_strategy is not None:
136
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
137
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
138
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
139
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
140
+ #
141
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
142
+ p = None
143
+ if hasattr(module, 'o_proj'):
144
+ p = module.o_proj.weight
145
+ elif hasattr(module, 'down_proj'):
146
+ p = module.down_proj.weight
147
+ if p is not None:
148
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
149
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
150
+ # We need to reinit p since this code could be called multiple times
151
+ # Having just p *= scale would repeatedly scale it down
152
+ if prenorm_residual_strategy == 'rescale':
153
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
154
+ with torch.no_grad():
155
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
156
+ elif prenorm_residual_strategy == 'zero':
157
+ nn.init.zeros_(p)
158
+ else:
159
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
160
+
161
+
162
+ class GatedDeltaNetModel(GatedDeltaNetPreTrainedModel):
163
+
164
+ def __init__(self, config: GatedDeltaNetConfig):
165
+ super().__init__(config)
166
+ self.padding_idx = config.pad_token_id
167
+ self.vocab_size = config.vocab_size
168
+
169
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
170
+ self.layers = nn.ModuleList([GatedDeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
171
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
172
+
173
+ self.gradient_checkpointing = False
174
+
175
+ self.post_init()
176
+
177
+ def get_input_embeddings(self):
178
+ return self.embeddings
179
+
180
+ def set_input_embeddings(self, value):
181
+ self.embeddings = value
182
+
183
+ def forward(
184
+ self,
185
+ input_ids: Optional[torch.LongTensor] = None,
186
+ attention_mask: Optional[torch.Tensor] = None, # noqa
187
+ inputs_embeds: Optional[torch.FloatTensor] = None,
188
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
189
+ use_cache: Optional[bool] = None,
190
+ output_attentions: Optional[bool] = None,
191
+ output_hidden_states: Optional[bool] = None,
192
+ return_dict: Optional[bool] = None,
193
+ **kwargs: Unpack[Dict]
194
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
195
+ if output_attentions:
196
+ warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.")
197
+ output_attentions = False
198
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
199
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
200
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
201
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
202
+
203
+ # retrieve input_ids and inputs_embeds
204
+ if input_ids is not None and inputs_embeds is not None:
205
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
206
+ if input_ids is None and inputs_embeds is None:
207
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
208
+
209
+ if inputs_embeds is None:
210
+ inputs_embeds = self.embeddings(input_ids)
211
+ hidden_states = inputs_embeds
212
+
213
+ if use_cache and not isinstance(past_key_values, Cache):
214
+ past_key_values = Cache.from_legacy_cache(past_key_values)
215
+
216
+ if self.gradient_checkpointing and self.training and use_cache:
217
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
218
+ use_cache = False
219
+
220
+ all_hidden_states = () if output_hidden_states else None
221
+ all_attns = () if output_attentions else None
222
+ for layer in self.layers:
223
+ if output_hidden_states:
224
+ all_hidden_states += (hidden_states,)
225
+
226
+ if self.gradient_checkpointing and self.training:
227
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
228
+ layer.__call__,
229
+ hidden_states,
230
+ attention_mask,
231
+ past_key_values,
232
+ use_cache,
233
+ output_attentions,
234
+ **kwargs
235
+ )
236
+ else:
237
+ hidden_states, attentions, past_key_values = layer(
238
+ hidden_states,
239
+ attention_mask=attention_mask,
240
+ past_key_values=past_key_values,
241
+ use_cache=use_cache,
242
+ output_attentions=output_attentions,
243
+ **kwargs
244
+ )
245
+
246
+ if output_attentions:
247
+ all_attns += (attentions,)
248
+
249
+ hidden_states = self.norm(hidden_states)
250
+
251
+ # add hidden states from the last decoder layer
252
+ if output_hidden_states:
253
+ all_hidden_states += (hidden_states,)
254
+
255
+ if not return_dict:
256
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
257
+ return BaseModelOutputWithPast(
258
+ last_hidden_state=hidden_states,
259
+ past_key_values=past_key_values,
260
+ hidden_states=all_hidden_states,
261
+ attentions=all_attns
262
+ )
263
+
264
+
265
+ class GatedDeltaNetForCausalLM(GatedDeltaNetPreTrainedModel, GenerationMixin):
266
+
267
+ _tied_weights_keys = ["lm_head.weight"]
268
+
269
+ def __init__(self, config):
270
+ super().__init__(config)
271
+ self.model = GatedDeltaNetModel(config)
272
+ self.vocab_size = config.vocab_size
273
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
274
+ self.criterion = None
275
+
276
+ # Initialize weights and apply final processing
277
+ self.post_init()
278
+
279
+ def get_input_embeddings(self):
280
+ return self.model.embeddings
281
+
282
+ def set_input_embeddings(self, value):
283
+ self.model.embeddings = value
284
+
285
+ def get_output_embeddings(self):
286
+ return self.lm_head
287
+
288
+ def set_output_embeddings(self, new_embeddings):
289
+ self.lm_head = new_embeddings
290
+
291
+ def set_decoder(self, decoder):
292
+ self.model = decoder
293
+
294
+ def get_decoder(self):
295
+ return self.model
296
+
297
+ def generate(self, *args, **kwargs):
298
+ try:
299
+ return super().generate(*args, **kwargs)
300
+ except AttributeError as exception:
301
+ if 'past_key_values' in str(exception):
302
+ raise AttributeError(
303
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
304
+ f"which is not supported for {self.__class__.__name__}. "
305
+ f"Try another generation strategy instead. "
306
+ f"For the available generation strategies, check this doc: "
307
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
308
+ )
309
+ else:
310
+ raise exception
311
+
312
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
313
+ def prepare_inputs_for_generation(
314
+ self,
315
+ input_ids: torch.LongTensor = None,
316
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
317
+ attention_mask: Optional[torch.Tensor] = None,
318
+ inputs_embeds: Optional[torch.Tensor] = None,
319
+ use_cache: bool = True,
320
+ logits_to_keep: Optional[int] = None,
321
+ **kwargs
322
+ ):
323
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
324
+ if past_key_values is not None and len(past_key_values) > 0:
325
+ input_ids = input_ids[:, -1:]
326
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
327
+ if inputs_embeds is not None and len(past_key_values) == 0:
328
+ model_inputs = {'inputs_embeds': inputs_embeds}
329
+ else:
330
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
331
+ # recompiles graphs as the stride of the inputs is a guard.
332
+ # Ref: https://github.com/huggingface/transformers/pull/29114
333
+ # TODO: use `next_tokens` directly instead.
334
+ model_inputs = {'input_ids': input_ids.contiguous()}
335
+
336
+ if logits_to_keep is not None:
337
+ model_inputs['logits_to_keep'] = logits_to_keep
338
+
339
+ model_inputs.update({
340
+ 'past_key_values': past_key_values,
341
+ 'use_cache': use_cache,
342
+ 'attention_mask': attention_mask,
343
+ })
344
+ return model_inputs
345
+
346
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
347
+ def forward(
348
+ self,
349
+ input_ids: torch.LongTensor = None,
350
+ attention_mask: Optional[torch.Tensor] = None,
351
+ inputs_embeds: Optional[torch.Tensor] = None,
352
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
353
+ labels: Optional[torch.LongTensor] = None,
354
+ use_cache: Optional[bool] = None,
355
+ output_attentions: Optional[bool] = None,
356
+ output_hidden_states: Optional[bool] = None,
357
+ return_dict: Optional[bool] = None,
358
+ logits_to_keep: Optional[int] = 0,
359
+ **kwargs: Unpack[Dict]
360
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
361
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
362
+ output_hidden_states = (
363
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
364
+ )
365
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
366
+
367
+ outputs = self.model(
368
+ input_ids=input_ids,
369
+ attention_mask=attention_mask,
370
+ inputs_embeds=inputs_embeds,
371
+ past_key_values=past_key_values,
372
+ use_cache=use_cache,
373
+ output_attentions=output_attentions,
374
+ output_hidden_states=output_hidden_states,
375
+ return_dict=return_dict,
376
+ **kwargs
377
+ )
378
+
379
+ hidden_states = outputs[0]
380
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
381
+
382
+ loss, logits = None, None
383
+ if not fuse_linear_and_cross_entropy or labels is None:
384
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
385
+ if labels is not None:
386
+ if getattr(self, 'criterion', None) is None:
387
+ if fuse_linear_and_cross_entropy:
388
+ criterion = FusedLinearCrossEntropyLoss()
389
+ elif self.config.fuse_cross_entropy:
390
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
391
+ else:
392
+ criterion = nn.CrossEntropyLoss()
393
+ else:
394
+ criterion = self.criterion
395
+ labels = labels.to(hidden_states.device)
396
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
397
+ if fuse_linear_and_cross_entropy:
398
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
399
+ else:
400
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
401
+
402
+ if not return_dict:
403
+ output = (logits,) + outputs[1:]
404
+ return (loss,) + output if loss is not None else output
405
+
406
+ return CausalLMOutputWithPast(
407
+ loss=loss,
408
+ logits=logits,
409
+ past_key_values=outputs.past_key_values,
410
+ hidden_states=outputs.hidden_states,
411
+ attentions=outputs.attentions,
412
+ )
fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-312.pyc ADDED
Binary file (3.38 kB). View file
 
fla/models/gated_deltaproduct/__pycache__/modeling_gated_deltaproduct.cpython-312.pyc ADDED
Binary file (20.7 kB). View file
 
fla/models/gsa/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.gsa.configuration_gsa import GSAConfig
6
+ from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel
7
+
8
+ AutoConfig.register(GSAConfig.model_type, GSAConfig)
9
+ AutoModel.register(GSAConfig, GSAModel)
10
+ AutoModelForCausalLM.register(GSAConfig, GSAForCausalLM)
11
+
12
+
13
+ __all__ = ['GSAConfig', 'GSAForCausalLM', 'GSAModel']
fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc ADDED
Binary file (3.84 kB). View file
 
fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc ADDED
Binary file (18.7 kB). View file
 
fla/models/hgrn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (665 Bytes). View file
 
fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc ADDED
Binary file (18.8 kB). View file
 
fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc ADDED
Binary file (18.9 kB). View file
 
fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc ADDED
Binary file (3.36 kB). View file
 
fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc ADDED
Binary file (18.3 kB). View file
 
fla/models/lightnet/modeling_lightnet.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.lightnet import LightNetAttention
20
+ from fla.models.lightnet.configuration_lightnet import LightNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as LightNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class LightNetBlock(nn.Module):
33
+ def __init__(self, config: LightNetConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ max_position_embeddings=config.max_position_embeddings,
48
+ layer_idx=layer_idx
49
+ )
50
+ else:
51
+ self.attn = LightNetAttention(
52
+ mode=config.attn_mode,
53
+ hidden_size=config.hidden_size,
54
+ num_heads=config.num_heads,
55
+ expand_ratio=config.expand_ratio,
56
+ use_short_conv=config.use_short_conv,
57
+ conv_size=config.conv_size,
58
+ gate_low_rank_dim=config.gate_low_rank_dim,
59
+ elementwise_affine=config.elementwise_affine,
60
+ norm_eps=config.norm_eps,
61
+ layer_idx=layer_idx
62
+ )
63
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
64
+ self.mlp = LightNetMLP(
65
+ hidden_size=config.hidden_size,
66
+ hidden_ratio=config.hidden_ratio,
67
+ intermediate_size=config.intermediate_size,
68
+ hidden_act=config.hidden_act,
69
+ fuse_swiglu=config.fuse_swiglu
70
+ )
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.Tensor] = None,
76
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
77
+ use_cache: Optional[bool] = False,
78
+ output_attentions: Optional[bool] = False,
79
+ **kwargs: Unpack[Dict]
80
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
81
+ residual = hidden_states
82
+ hidden_states = self.attn_norm(hidden_states)
83
+ hidden_states, attentions, past_key_values = self.attn(
84
+ hidden_states=hidden_states,
85
+ attention_mask=attention_mask,
86
+ past_key_values=past_key_values,
87
+ use_cache=use_cache,
88
+ output_attentions=output_attentions,
89
+ **kwargs
90
+ )
91
+ if self.config.fuse_norm:
92
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
93
+ else:
94
+ hidden_states = residual + hidden_states
95
+ residual = hidden_states
96
+ hidden_states = self.mlp_norm(hidden_states)
97
+ hidden_states = self.mlp(hidden_states, **kwargs)
98
+ hidden_states = residual + hidden_states
99
+
100
+ outputs = (hidden_states, attentions, past_key_values)
101
+
102
+ return outputs
103
+
104
+
105
+ class LightNetPreTrainedModel(PreTrainedModel):
106
+
107
+ config_class = LightNetConfig
108
+ supports_gradient_checkpointing = True
109
+ _no_split_modules = ['LightNetBlock']
110
+ _supports_cache_class = True
111
+
112
+ def __init__(self, *inputs, **kwargs):
113
+ super().__init__(*inputs, **kwargs)
114
+
115
+ def _init_weights(
116
+ self,
117
+ module: nn.Module,
118
+ prenorm_residual_strategy: Optional[str] = 'rescale',
119
+ num_residuals_per_layer: int = 2,
120
+ ):
121
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
122
+ # Slightly different from the TF version which uses truncated_normal for initialization
123
+ # cf https://github.com/pytorch/pytorch/pull/5617
124
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
125
+ if module.bias is not None:
126
+ nn.init.zeros_(module.bias)
127
+ elif isinstance(module, nn.Embedding):
128
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
129
+ elif hasattr(module, 'reset_parameters'):
130
+ module.reset_parameters()
131
+
132
+ if prenorm_residual_strategy is not None:
133
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
134
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
135
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
136
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
137
+ #
138
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
139
+ p = None
140
+ if hasattr(module, 'o_proj'):
141
+ p = module.o_proj.weight
142
+ elif hasattr(module, 'down_proj'):
143
+ p = module.down_proj.weight
144
+ if p is not None:
145
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
146
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
147
+ # We need to reinit p since this code could be called multiple times
148
+ # Having just p *= scale would repeatedly scale it down
149
+ if prenorm_residual_strategy == 'rescale':
150
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
151
+ with torch.no_grad():
152
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
153
+ elif prenorm_residual_strategy == 'zero':
154
+ nn.init.zeros_(p)
155
+ else:
156
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
157
+
158
+
159
+ class LightNetModel(LightNetPreTrainedModel):
160
+
161
+ def __init__(self, config: LightNetConfig):
162
+ super().__init__(config)
163
+ self.padding_idx = config.pad_token_id
164
+ self.vocab_size = config.vocab_size
165
+
166
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
167
+ self.layers = nn.ModuleList([LightNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
168
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
169
+
170
+ self.gradient_checkpointing = False
171
+
172
+ self.post_init()
173
+
174
+ def get_input_embeddings(self):
175
+ return self.embeddings
176
+
177
+ def set_input_embeddings(self, value):
178
+ self.embeddings = value
179
+
180
+ def forward(
181
+ self,
182
+ input_ids: Optional[torch.LongTensor] = None,
183
+ attention_mask: Optional[torch.Tensor] = None, # noqa
184
+ inputs_embeds: Optional[torch.FloatTensor] = None,
185
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
186
+ use_cache: Optional[bool] = None,
187
+ output_attentions: Optional[bool] = None,
188
+ output_hidden_states: Optional[bool] = None,
189
+ return_dict: Optional[bool] = None,
190
+ **kwargs: Unpack[Dict]
191
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
192
+ if output_attentions:
193
+ warnings.warn("`LightNetModel` does not `output_attentions` now, setting it to `False`.")
194
+ output_attentions = False
195
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
196
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
197
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
198
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
199
+
200
+ # retrieve input_ids and inputs_embeds
201
+ if input_ids is not None and inputs_embeds is not None:
202
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
203
+ if input_ids is None and inputs_embeds is None:
204
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
205
+
206
+ if inputs_embeds is None:
207
+ inputs_embeds = self.embeddings(input_ids)
208
+ hidden_states = inputs_embeds
209
+
210
+ if use_cache and not isinstance(past_key_values, Cache):
211
+ past_key_values = Cache.from_legacy_cache(past_key_values)
212
+
213
+ if self.gradient_checkpointing and self.training and use_cache:
214
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
215
+ use_cache = False
216
+
217
+ all_hidden_states = () if output_hidden_states else None
218
+ all_attns = () if output_attentions else None
219
+
220
+ for i, layer in enumerate(self.layers):
221
+ if output_hidden_states:
222
+ all_hidden_states += (hidden_states,)
223
+
224
+ if self.gradient_checkpointing and self.training:
225
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
226
+ layer.__call__,
227
+ hidden_states,
228
+ attention_mask,
229
+ past_key_values,
230
+ use_cache,
231
+ output_attentions,
232
+ **kwargs
233
+ )
234
+ else:
235
+ hidden_states, attentions, past_key_values = layer(
236
+ hidden_states,
237
+ attention_mask=attention_mask,
238
+ past_key_values=past_key_values,
239
+ use_cache=use_cache,
240
+ output_attentions=output_attentions,
241
+ **kwargs
242
+ )
243
+
244
+ if output_attentions:
245
+ all_attns += (attentions,)
246
+
247
+ hidden_states = self.norm(hidden_states)
248
+
249
+ # add hidden states from the last decoder layer
250
+ if output_hidden_states:
251
+ all_hidden_states += (hidden_states,)
252
+
253
+ if not return_dict:
254
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
255
+ return BaseModelOutputWithPast(
256
+ last_hidden_state=hidden_states,
257
+ past_key_values=past_key_values,
258
+ hidden_states=all_hidden_states,
259
+ attentions=all_attns
260
+ )
261
+
262
+
263
+ class LightNetForCausalLM(LightNetPreTrainedModel, GenerationMixin):
264
+
265
+ _tied_weights_keys = ["lm_head.weight"]
266
+
267
+ def __init__(self, config):
268
+ super().__init__(config)
269
+ self.model = LightNetModel(config)
270
+ self.vocab_size = config.vocab_size
271
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
272
+ self.criterion = None
273
+
274
+ # Initialize weights and apply final processing
275
+ self.post_init()
276
+
277
+ def get_input_embeddings(self):
278
+ return self.model.embeddings
279
+
280
+ def set_input_embeddings(self, value):
281
+ self.model.embeddings = value
282
+
283
+ def get_output_embeddings(self):
284
+ return self.lm_head
285
+
286
+ def set_output_embeddings(self, new_embeddings):
287
+ self.lm_head = new_embeddings
288
+
289
+ def set_decoder(self, decoder):
290
+ self.model = decoder
291
+
292
+ def get_decoder(self):
293
+ return self.model
294
+
295
+ def generate(self, *args, **kwargs):
296
+ try:
297
+ return super().generate(*args, **kwargs)
298
+ except AttributeError as exception:
299
+ if 'past_key_values' in str(exception):
300
+ raise AttributeError(
301
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
302
+ f"which is not supported for {self.__class__.__name__}. "
303
+ f"Try another generation strategy instead. "
304
+ f"For the available generation strategies, check this doc: "
305
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
306
+ )
307
+ else:
308
+ raise exception
309
+
310
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
311
+ def prepare_inputs_for_generation(
312
+ self,
313
+ input_ids: torch.LongTensor = None,
314
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ inputs_embeds: Optional[torch.Tensor] = None,
317
+ use_cache: bool = True,
318
+ logits_to_keep: Optional[int] = None,
319
+ **kwargs: Unpack[Dict]
320
+ ):
321
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
322
+ if past_key_values is not None and len(past_key_values) > 0:
323
+ input_ids = input_ids[:, -1:]
324
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
325
+ if inputs_embeds is not None and len(past_key_values) == 0:
326
+ model_inputs = {'inputs_embeds': inputs_embeds}
327
+ else:
328
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
329
+ # recompiles graphs as the stride of the inputs is a guard.
330
+ # Ref: https://github.com/huggingface/transformers/pull/29114
331
+ # TODO: use `next_tokens` directly instead.
332
+ model_inputs = {'input_ids': input_ids.contiguous()}
333
+
334
+ if logits_to_keep is not None:
335
+ model_inputs['logits_to_keep'] = logits_to_keep
336
+
337
+ model_inputs.update({
338
+ 'past_key_values': past_key_values,
339
+ 'use_cache': use_cache,
340
+ 'attention_mask': attention_mask,
341
+ })
342
+ return model_inputs
343
+
344
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
345
+ def forward(
346
+ self,
347
+ input_ids: torch.LongTensor = None,
348
+ attention_mask: Optional[torch.Tensor] = None,
349
+ inputs_embeds: Optional[torch.Tensor] = None,
350
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
351
+ labels: Optional[torch.LongTensor] = None,
352
+ use_cache: Optional[bool] = None,
353
+ output_attentions: Optional[bool] = None,
354
+ output_hidden_states: Optional[bool] = None,
355
+ return_dict: Optional[bool] = None,
356
+ logits_to_keep: Optional[int] = 0,
357
+ **kwargs: Unpack[Dict]
358
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
359
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
360
+ output_hidden_states = (
361
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
362
+ )
363
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
364
+
365
+ outputs = self.model(
366
+ input_ids=input_ids,
367
+ attention_mask=attention_mask,
368
+ inputs_embeds=inputs_embeds,
369
+ past_key_values=past_key_values,
370
+ use_cache=use_cache,
371
+ output_attentions=output_attentions,
372
+ output_hidden_states=output_hidden_states,
373
+ return_dict=return_dict,
374
+ **kwargs
375
+ )
376
+
377
+ hidden_states = outputs[0]
378
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
379
+
380
+ loss, logits = None, None
381
+ if not fuse_linear_and_cross_entropy or labels is None:
382
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
383
+ if labels is not None:
384
+ if getattr(self, 'criterion', None) is None:
385
+ if fuse_linear_and_cross_entropy:
386
+ criterion = FusedLinearCrossEntropyLoss()
387
+ elif self.config.fuse_cross_entropy:
388
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
389
+ else:
390
+ criterion = nn.CrossEntropyLoss()
391
+ else:
392
+ criterion = self.criterion
393
+ labels = labels.to(hidden_states.device)
394
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
395
+ if fuse_linear_and_cross_entropy:
396
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
397
+ else:
398
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
399
+
400
+ if not return_dict:
401
+ output = (logits,) + outputs[1:]
402
+ return (loss,) + output if loss is not None else output
403
+
404
+ return CausalLMOutputWithPast(
405
+ loss=loss,
406
+ logits=logits,
407
+ past_key_values=outputs.past_key_values,
408
+ hidden_states=outputs.hidden_states,
409
+ attentions=outputs.attentions,
410
+ )
fla/models/linear_attn/configuration_linear_attn.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class LinearAttentionConfig(PretrainedConfig):
    """Configuration for linear-attention language models (``model_type='linear_attn'``).

    Stores the architecture hyperparameters (hidden sizes, head counts,
    feature-map options), kernel-fusion flags, and an optional ``attn`` dict
    describing hybrid standard-attention layers interleaved with the linear
    attention layers.
    """

    model_type = 'linear_attn'
    # Cache entries are runtime state, not part of the serialized config.
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "fused_chunk",
        hidden_size: int = 2048,
        expand_k: int = 1,
        expand_v: int = 1,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        feature_map: str = "elementwise_product",
        tie_feature_map_qk: bool = False,
        norm_q: bool = False,
        norm_k: bool = False,
        norm_feature_map: bool = False,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.feature_map = feature_map
        self.tie_feature_map_qk = tie_feature_map_qk
        self.norm_q = norm_q
        self.norm_k = norm_k
        self.norm_feature_map = norm_feature_map
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        # Kernel-fusion switches consumed by the modeling code.
        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        # Validate the hybrid-attention spec and fill in optional defaults.
        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc ADDED
Binary file (7.5 kB). View file
 
fla/models/mamba2/__pycache__/modeling_mamba2.cpython-312.pyc ADDED
Binary file (52.4 kB). View file
 
fla/models/mamba2/modeling_mamba2.py ADDED
@@ -0,0 +1,1093 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """PyTorch MAMBA2 model."""
15
+
16
+ import math
17
+ import warnings
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from transformers.activations import ACT2FN
25
+ from transformers.generation import GenerationMixin
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.utils import ModelOutput, logging
28
+ from transformers.utils.deprecation import deprecate_kwarg
29
+
30
+ from fla.models.mamba2.configuration_mamba2 import Mamba2Config
31
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
32
+ from fla.modules.layernorm_gated import RMSNormGated
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ with warnings.catch_warnings():
37
+ warnings.simplefilter('ignore')
38
+ try:
39
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
40
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
41
+ except ImportError:
42
+ (
43
+ selective_state_update,
44
+ mamba_chunk_scan_combined,
45
+ mamba_split_conv1d_scan_combined,
46
+ ) = (None, None, None)
47
+ try:
48
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
49
+ except ImportError:
50
+ causal_conv1d_update, causal_conv1d_fn = None, None
51
+ is_fast_path_available = all((
52
+ selective_state_update,
53
+ causal_conv1d_fn,
54
+ causal_conv1d_update
55
+ ))
56
+
57
+
58
def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
    """Zero-pad `input_tensor` with `pad_size` entries at the end of dim 1 (seq_len).

    Only rank-3 and rank-4 tensors are supported; every other dimension is
    left untouched.
    """
    # F.pad's tuple runs from the LAST dimension backwards in (left, right)
    # pairs, so the `pad_size` lands on the right side of dim 1.
    if input_tensor.dim() == 4:
        pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0)
    else:
        pad_shape = (0, 0, 0, pad_size, 0, 0)
    return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
67
+
68
+
69
def reshape_into_chunks(input_tensor, pad_size, chunk_size):
    """Pad the seq_len dim (dim=1) by `pad_size` and split it into `chunk_size` chunks.

    [bsz, seq_len, ...] -> [bsz, num_chunks, chunk_size, ...]; handles the
    rank-3 and rank-4 layouts used by the SSD implementation.
    """
    padded = pad_tensor_by_size(input_tensor, pad_size)
    bsz = padded.shape[0]
    if padded.dim() == 3:
        # [bsz, padded_len, num_heads] -> [bsz, -1, chunk_size, num_heads]
        return padded.reshape(bsz, -1, chunk_size, padded.shape[2])
    # [bsz, padded_len, num_heads, head_dim or state_size]
    # -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
    return padded.reshape(bsz, -1, chunk_size, padded.shape[2], padded.shape[3])
88
+
89
+
90
def segment_sum(input_tensor):
    """Numerically stable segment sum via cumulative sums and masking.

    For an input with trailing dim `T`, produces a `[..., T, T]` tensor whose
    entry (i, j) equals sum(input[j+1 .. i]) for j <= i, and -inf strictly
    above the diagonal (acts like a causal mask in log-space).
    """
    chunk_size = input_tensor.size(-1)
    # Broadcast along a fresh trailing axis: [..., T] -> [..., T, T],
    # element (i, j) = input[i].
    expanded = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
    # Zero out the diagonal and everything above it before accumulating.
    strict_lower = torch.tril(
        torch.ones(chunk_size, chunk_size, device=expanded.device, dtype=torch.bool),
        diagonal=-1,
    )
    expanded = expanded.masked_fill(~strict_lower, 0)
    # Accumulate over the row axis: (i, j) = sum_{j < k <= i} input[k].
    sums = torch.cumsum(expanded, dim=-2)
    # Keep the inclusive lower triangle; -inf elsewhere.
    inclusive_lower = torch.tril(
        torch.ones(chunk_size, chunk_size, device=expanded.device, dtype=torch.bool),
        diagonal=0,
    )
    return sums.masked_fill(~inclusive_lower, -torch.inf)
108
+
109
+
110
def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66

    No-op (returns the input unchanged) when there is no mask, a single-token
    mask, or a single-sequence batch.
    """
    no_masking_needed = (
        attention_mask is None
        or attention_mask.shape[1] <= 1
        or attention_mask.shape[0] <= 1
    )
    if no_masking_needed:
        return hidden_states
    original_dtype = hidden_states.dtype
    return (hidden_states * attention_mask[:, :, None]).to(original_dtype)
119
+
120
+
121
class Mamba2Cache:
    """
    Per-layer convolution and SSM state buffers used for incremental decoding.

    Arguments:
        config: Mamba2Config
        batch_size: int
        dtype: torch.dtype
        device: torch.device

    Attributes:
        dtype: (`torch.dtype`):
            The default `dtype` used to initializing the cache.
        conv_kernel_size: (`int`):
            Model's convolution kernel size taken from config.
        n_groups: (`int`):
            Model's number of groups taken from the config - similar to tensor parallel in Transformer.
        state_size: (`int`):
            Model's SSM state size taken from config.
        num_heads: (`int`):
            The number of heads used in the linear attention / SSM.
        head_dim: (`int`):
            The respective dimension of the heads used in the linear attention / SSM.
        intermediate_size: (`int`):
            Model's intermediate_size based on (expand * hidden_dim) from config.
        conv_states: (`torch.Tensor`):
            A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]`
            that holds convolutional states.
        ssm_states: (`torch.Tensor`):
            A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states.
    """

    def __init__(
        self,
        config: Mamba2Config,
        batch_size: int,
        dtype: torch.dtype = torch.float16,
        device: Optional[str] = None,
    ):
        self.dtype = dtype
        self.conv_kernel_size = config.conv_kernel
        self.n_groups = config.n_groups
        self.state_size = config.state_size
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.intermediate_size = int(config.expand * config.hidden_size)

        # One slot per layer; zero-initialized so the first forward pass sees
        # an "empty" rolling window / SSM state.
        self.conv_states = torch.zeros(
            config.num_hidden_layers,
            batch_size,
            self.intermediate_size + 2 * self.n_groups * self.state_size,
            self.conv_kernel_size,
            device=device,
            dtype=dtype,
        )
        self.ssm_states = torch.zeros(
            config.num_hidden_layers,
            batch_size,
            self.num_heads,
            self.head_dim,
            self.state_size,
            device=device,
            dtype=dtype,
        )

    def update_conv_state(
        self,
        layer_idx: int,
        new_conv_state: torch.Tensor,
        cache_init: bool = False
    ) -> torch.Tensor:
        """Write the convolution state for `layer_idx` and return it.

        With `cache_init=True` the whole state is replaced by `new_conv_state`
        (prefill). Otherwise the window is rolled left by one along the last
        (time) dim and the newest token slice `new_conv_state[:, 0, :]` is
        written into the freed last slot (single-step decoding).
        """
        if cache_init:
            self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device)
        else:
            self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
        return self.conv_states[layer_idx]

    def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
        """Replace the SSM state for `layer_idx` and return the stored tensor."""
        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
        return self.ssm_states[layer_idx]

    def reset(self):
        """Zero all cached states in place, e.g. before starting a new generation."""
        self.conv_states.zero_()
        self.ssm_states.zero_()
204
+
205
+
206
class Mamba2Mixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    """

    def __init__(self, config: Mamba2Config, layer_idx: int):
        super().__init__()
        self.num_heads = config.num_heads
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = int(config.expand * self.hidden_size)
        self.time_step_rank = int(config.time_step_rank)
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]

        self.layer_norm_epsilon = config.layer_norm_epsilon
        self.rms_norm = config.rms_norm

        self.n_groups = config.n_groups
        self.head_dim = config.head_dim
        self.chunk_size = config.chunk_size

        self.time_step_limit = config.time_step_limit
        self.time_step_min = config.time_step_min
        self.time_step_max = config.time_step_max

        # Depthwise conv over the concatenated (x, B, C) channels.
        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.conv_dim,
            padding=config.conv_kernel - 1,
        )

        # projection of the input hidden states
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.in_proj = nn.Linear(
            self.hidden_size,
            projection_size,
            bias=config.use_bias,
        )
        # selective projection used to make dt, B and C input dependant

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.num_heads + 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        self.norm = RMSNormGated(
            self.intermediate_size, eps=self.layer_norm_epsilon, norm_before_gate=False
        )
        self.D = nn.Parameter(torch.ones(self.num_heads))
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of "
                "`(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. "
                "Falling back to the naive implementation. "
                "To install follow https://github.com/state-spaces/mamba/#installation and"
                "https://github.com/Dao-AILab/causal-conv1d"
            )

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Forward pass using the fused `mamba_ssm`/`causal_conv1d` CUDA kernels.

        Three sub-paths: single-token decoding via `selective_state_update`
        (when a cache with position > 0 exists), the fully fused
        conv+scan+projection kernel (training, no cache), and the chunked
        scan (`mamba_chunk_scan_combined`) otherwise.
        """
        # 1. Gated MLP's linear projection
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
        projected_states = self.in_proj(hidden_states)

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size
        d_mlp = (
            projected_states.shape[-1]
            - 2 * self.intermediate_size
            - 2 * self.n_groups * self.ssm_state_size
            - self.num_heads
        ) // 2

        # Single step calculations via cache
        if cache_params is not None and cache_position is not None and cache_position[0] > 0:
            _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
                [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
            )

            # 2. Convolution sequence transformation
            hidden_states_B_C = causal_conv1d_update(
                hidden_states_B_C,
                cache_params.conv_states[self.layer_idx],
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )

            hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [
                    self.intermediate_size,
                    groups_time_state_size,
                    groups_time_state_size,
                ],
                dim=-1,
            )

            # 3. SSM transformation
            A = -torch.exp(self.A_log.float())  # (nheads,)
            A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
            C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
            hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)

            # Updates cache_params.ssm_states in place and returns the output.
            hidden_states = selective_state_update(
                cache_params.ssm_states[self.layer_idx],
                hidden_states_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
            hidden_states = self.norm(hidden_states, gate)

            # 4. Final linear projection
            out = self.out_proj(hidden_states)[:, None, ...]

        # Fused calculations or step by step if no initialized cache is found
        else:
            A = -torch.exp(self.A_log.float())  # (num_heads) or (intermediate_size, state_size)
            dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}

            # 2-4. Fused kernel for conv1d, SSM, and the final projection
            if self.training and cache_params is None:
                out = mamba_split_conv1d_scan_combined(
                    projected_states,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.chunk_size,
                    seq_idx=None,  # was seq_idx
                    activation=self.activation,
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.eps,
                    outproj_weight=self.out_proj.weight,
                    outproj_bias=self.out_proj.bias,
                    headdim=self.head_dim,
                    ngroups=self.n_groups,
                    norm_before_gate=False,
                    return_final_states=False,
                    **dt_limit_kwargs,
                )

            else:
                _, _, gate, hidden_states_B_C, dt = projected_states.split(
                    [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
                )

                # 2. Convolution sequence transformation
                # Init cache
                if cache_params is not None:
                    hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
                    # Left-pad to the full kernel window before storing.
                    conv_states = nn.functional.pad(
                        hidden_states_B_C_transposed,
                        (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
                    )
                    cache_params.update_conv_state(
                        layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True
                    )

                if self.activation not in ["silu", "swish"]:
                    # causal_conv1d only fuses silu/swish; fall back otherwise.
                    hidden_states_B_C = self.act(
                        self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
                    )
                else:
                    hidden_states_B_C = causal_conv1d_fn(
                        x=hidden_states_B_C.transpose(1, 2),
                        weight=self.conv1d.weight.squeeze(1),
                        bias=self.conv1d.bias,
                        activation=self.activation,
                    ).transpose(1, 2)

                hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
                hidden_states, B, C = torch.split(
                    hidden_states_B_C,
                    [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                    dim=-1,
                )

                # 3. SSM transformation
                scan_output, ssm_state = mamba_chunk_scan_combined(
                    hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                    dt,
                    A,
                    B.view(batch_size, seq_len, self.n_groups, -1),
                    C.view(batch_size, seq_len, self.n_groups, -1),
                    chunk_size=self.chunk_size,
                    D=self.D,
                    z=None,
                    seq_idx=None,
                    return_final_states=True,
                    dt_bias=self.dt_bias,
                    dt_softplus=True,
                    **dt_limit_kwargs,
                )

                # Init cache
                if ssm_state is not None and cache_params is not None:
                    cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)

                scan_output = scan_output.view(batch_size, seq_len, -1)
                # Multiply "gate" branch and apply extra normalization layer
                scan_output = self.norm(scan_output, gate)

                # 4. Final linear projection
                out = self.out_proj(scan_output)
        return out

    # fmt: off
    def torch_forward(
        self,
        input_states,
        cache_params: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None
    ):
        """Pure-PyTorch fallback: conv + SSD scan without the CUDA kernels.

        Mirrors `cuda_kernels_forward`: a single-step recurrence path when a
        cache with position > 0 exists, otherwise the naive chunked SSD
        computation over the whole sequence.
        """
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype

        # 1. Gated MLP's linear projection
        input_states = apply_mask_to_padding_states(input_states, attention_mask)
        projected_states = self.in_proj(input_states)
        d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size -
                 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2
        _, _, gate, hidden_states_B_C, dt = projected_states.split(
            [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
        )

        # 2. Convolution sequence transformation
        if cache_params is not None and cache_position is not None and cache_position[0] > 0:
            cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False)

            # We need to guarantee that anything regarding the cache is on the same device
            conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)

            # Manual depthwise conv step over the cached window.
            hidden_states_B_C = torch.sum(
                conv_states * self.conv1d.weight.squeeze(1), dim=-1
            )
            if self.use_conv_bias:
                hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
            hidden_states_B_C = self.act(hidden_states_B_C)
        else:
            # Init cache
            if cache_params is not None:
                hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
                conv_states = nn.functional.pad(
                    hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
                )
                cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True)

            hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))

        hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
        hidden_states, B, C = torch.split(
            hidden_states_B_C,
            [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
            dim=-1
        )

        # 3. SSM transformation
        A = -torch.exp(self.A_log.float())  # [num_heads]
        if cache_params is not None and cache_position is not None and cache_position[0] > 0:
            # We need to guarantee that anything regarding the cache is on the same device
            cache_device = cache_params.ssm_states.device

            # Note: there is no need to pad parameter matrices here, as there is just one new token
            # for batched generation
            dt = dt[:, 0, :][:, None, ...]
            dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
            # [num_heads] -> [num_heads, head_dim]
            dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)

            dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
            dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
            A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            # [bsz, num_heads, head_dim, state_size]
            dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)

            # Discretize B
            # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
            # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
            B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
            B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
            B = B.reshape(batch_size, -1, B.shape[-1])
            # [bsz, num_heads, head_dim, state_size]
            dB = dt[..., None] * B[..., None, :]

            # Discretize x into dB
            # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
            hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
            dBx = (dB * hidden_states[..., None]).to(device=cache_device)

            # State calculation
            cache_params.update_ssm_state(
                layer_idx=self.layer_idx,
                new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx
            )

            # Subsequent output
            # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
            C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
            C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
            C = C.reshape(batch_size, -1, C.shape[-1])
            # [bsz, num_heads, head_dim]

            ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype)  # Shape: [b, h, d, n]
            # Reshape ssm_states to merge the first two dimensions
            # Shape: [b*h, d, n]
            ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size)
            C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1)  # Shape: [b*h, n, 1]
            y = torch.bmm(ssm_states_reshaped, C_reshaped)
            y = y.view(batch_size, self.num_heads, self.head_dim)

            # D skip connection
            # [num_heads] -> [num_heads, head_dim]
            D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
            y = (y + hidden_states * D).to(y.dtype)

            # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
            y = y.reshape(batch_size, -1)[:, None, ...]
        else:
            # begin ssd naive implementation without einsums
            dt = nn.functional.softplus(dt + self.dt_bias)
            dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
            hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
            B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
            C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
            pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

            D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

            # Discretize x and A
            hidden_states = hidden_states * dt[..., None]
            A = A.to(hidden_states.dtype) * dt

            # Rearrange into blocks/chunks
            hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]

            # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
            A = A.permute(0, 3, 1, 2)
            A_cumsum = torch.cumsum(A, dim=-1)

            # 1. Compute the output for each intra-chunk (diagonal blocks)
            # This is the analog of a causal mask
            L = torch.exp(segment_sum(A))

            # Contraction of C and B to get G (attention-weights like)
            # shape: (b, c, l, s, h, n)
            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]
            G = G_intermediate.sum(dim=-1)  # shape: (b, c, l, s, h)

            # Compute M, equivalent to applying attention mask to weights
            M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
            M = M_intermediate.sum(dim=-1)

            # Compute Y_diag (apply to values)
            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)

            # 2. Compute the state for each intra-chunk
            # (right term of low-rank factorization of off-diagonal blocks; B terms)
            decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
            B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
            states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)

            # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
            # (middle term of factorization of off-diag blocks; A terms)
            if cache_params is not None and cache_position is not None and cache_position[0] > 0:
                previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
            else:
                previous_states = torch.zeros_like(states[:, :1])
            states = torch.cat([previous_states, states], dim=1)
            decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
            decay_chunk = decay_chunk.transpose(1, 3)
            new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
            states, ssm_state = new_states[:, :-1], new_states[:, -1]

            # 4. Compute state -> output conversion per chunk
            # (left term of low-rank factorization of off-diagonal blocks; C terms)
            state_decay_out = torch.exp(A_cumsum)
            C_times_states = (C[..., None, :] * states[:, :, None, ...])
            state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
            Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])

            # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
            y = Y_diag + Y_off
            # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
            y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)

            y = y + D_residual
            # Cutting off padded chunks
            if pad_size > 0:
                y = y[:, :seq_len, :, :]
            y = y.reshape(batch_size, seq_len, -1)

            # Init cache
            if ssm_state is not None and cache_params is not None:
                cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)

        scan_output = self.norm(y, gate)

        # end ssd naive

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output.to(dtype))  # [batch, seq_len, hidden_size]
        return contextualized_states
    # fmt: on

    def forward(
        self,
        hidden_states,
        cache_params: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Dispatch to the fused CUDA path when available, else the PyTorch fallback."""
        if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
        dtype = hidden_states.dtype
        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

        return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
667
+
668
+
669
class Mamba2Block(nn.Module):
    """A single Mamba2 decoder layer: pre-norm, SSM mixer, and a residual connection."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Optionally accumulate the residual stream in fp32 for numerical stability.
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = Mamba2Mixer(config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states,
        cache_params: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Apply `norm -> mixer -> residual add`, casting the residual to fp32 when configured."""
        shortcut = hidden_states
        normed = self.norm(hidden_states)
        if self.residual_in_fp32:
            shortcut = shortcut.to(torch.float32)

        mixed = self.mixer(
            normed,
            cache_params=cache_params,
            cache_position=cache_position,
            attention_mask=attention_mask,
        )
        out = shortcut + mixed
        # Cast back to the norm weight's dtype so downstream layers see a consistent dtype.
        if self.residual_in_fp32:
            out = out.to(dtype=self.norm.weight.dtype)
        return out
700
+
701
+
702
class Mamba2PreTrainedModel(PreTrainedModel, GenerationMixin):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Mamba2Config
    base_model_prefix = "backbone"
    _no_split_modules = ["Mamba2Block"]
    supports_gradient_checkpointing = True
    # Mamba2 carries recurrent conv/SSM state across generation steps.
    _is_stateful = True

    def _init_weights(
        self,
        module: nn.Module,
        num_residuals_per_layer: int = 1,
    ):
        """Initialize the weights.

        Called once per submodule by `post_init`. `num_residuals_per_layer` scales the
        prenorm-residual rescaling below (1 residual addition per Mamba2 block here).
        """
        if isinstance(module, Mamba2Mixer):

            # --- A_log ---
            # A is initialized to 1..num_heads and stored in log space.
            A = torch.arange(1, module.num_heads + 1)
            with torch.no_grad():
                # DTensor parameters (tensor-parallel sharded) cannot be copied into directly here.
                if not isinstance(module.A_log, torch.distributed.tensor.DTensor):
                    module.A_log.copy_(torch.log(A))
                else:
                    logger.warning_once("`A_log` is a DTensor, skipping initialization")
            # Flag consumed by optimizer setup to exclude this parameter from weight decay.
            module.A_log._no_weight_decay = True

            # --- D ---
            nn.init.ones_(module.D)
            module.D._no_weight_decay = True

            # --- dt_bias ---
            # Sample dt log-uniformly in [time_step_min, time_step_max], floored at time_step_floor.
            dt = torch.exp(
                torch.rand(self.config.num_heads)
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
                + math.log(self.config.time_step_min)
            ).clamp(min=self.config.time_step_floor)

            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                if not isinstance(module.dt_bias, torch.distributed.tensor.DTensor):
                    module.dt_bias.copy_(inv_dt)
                else:
                    logger.warning_once("`dt_bias` is a DTensor, skipping initialization")
            # Mark so later re-initialization passes leave this bias untouched.
            module.dt_bias._no_reinit = True

        elif isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
                # guard against deprecated behavior
                if hasattr(module.bias, "_no_reinit"):
                    raise ValueError("This is not supposed to happen")
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            # Fall back to the module's own initialization (e.g. custom norms).
            module.reset_parameters()

        if self.config.rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                # p = module.o_proj.weight
                # guard against deprecated behavior
                raise ValueError("This is not supposed to happen")
            elif hasattr(module, 'out_proj'):
                p = module.out_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
789
+
790
+
791
@dataclass
# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2
class Mamba2Output(ModelOutput):
    """
    Output container for the bare MAMBA2 model.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states produced by the final layer of the model.
        cache_params (`Mamba2Cache`):
            Recurrent state after this forward pass (both the post-scan SSM state matrices and the
            convolutional states). Pass it back with the next `input_ids` to continue without
            re-feeding the old tokens.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One tensor per layer (plus the embedding output, if the model has an embedding layer),
            each of shape `(batch_size, sequence_length, hidden_size)`.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    cache_params: Optional[Mamba2Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
816
+
817
+
818
@dataclass
# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2
class Mamba2CausalLMOutput(ModelOutput):
    """
    Output container for the Mamba2 causal language model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Next-token prediction (language modeling) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Pre-softmax prediction scores of the language modeling head.
        cache_params (`Mamba2Cache`):
            Recurrent state after this forward pass (both the post-scan SSM state matrices and the
            convolutional states). Pass it back with the next `input_ids` to continue without
            re-feeding the old tokens.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One tensor per layer (plus the embedding output, if the model has an embedding layer),
            each of shape `(batch_size, sequence_length, hidden_size)`.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[Mamba2Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
846
+
847
+
848
class Mamba2Model(Mamba2PreTrainedModel):
    """The bare Mamba2 backbone: token embeddings, a stack of `Mamba2Block`s, and a final RMSNorm."""

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False
        self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Initialize weights and apply final processing
        self._register_load_state_dict_pre_hook(self.load_hook)
        self.post_init()

    def load_hook(self, state_dict, prefix, *args):
        """Rename a legacy `embedding.` checkpoint key to `embeddings.` before loading.

        The pop-during-iteration is safe only because of the immediate `break`; at most one
        key is renamed.
        """
        for k in state_dict:
            if "embedding." in k:
                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
                break

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module."""
        self.embeddings = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        cache_params: Optional[Mamba2Cache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,  # unused; accepted for generation-API compatibility
    ) -> Union[Tuple, Mamba2Output]:
        """Run the backbone over `input_ids`/`inputs_embeds`, threading the recurrent cache.

        Exactly one of `input_ids` and `inputs_embeds` must be given. When `use_cache` is on
        and no `cache_params` is passed (prefill), a fresh cache and `cache_position` are
        created; when a cache is passed, `cache_position` must be supplied by the caller.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Caching is disabled by default during training.
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):  # ^ is python for xor
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        # Gradient checkpointing is incompatible with caching.
        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if use_cache:
            if cache_params is None:
                cache_params = Mamba2Cache(
                    self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
                )
                # At prefill, cache_position spans the full conv kernel window.
                cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
            elif cache_position is None:
                # cases when we do manual forward instead of using `model.generate` which will initiate
                # `cache_position` and makes sure it is not None, throw error here instead of doing some
                # hack to conjecture the current cache position
                raise ValueError(
                    "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
                    "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
                    "be initialized for you automatically"
                )
        else:
            cache_params = None

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        for mixer_block in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    mixer_block.__call__,
                    hidden_states,
                    cache_params,
                    cache_position,
                    attention_mask,
                )
            else:
                hidden_states = mixer_block(
                    hidden_states,
                    cache_params=cache_params,
                    cache_position=cache_position,
                    attention_mask=attention_mask,
                )

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)

        return Mamba2Output(
            last_hidden_state=hidden_states,
            cache_params=cache_params if use_cache else None,
            hidden_states=all_hidden_states,
        )
953
+
954
+
955
class Mamba2ForCausalLM(Mamba2PreTrainedModel):
    """Mamba2 backbone with an (untied) language-modeling head and optional fused losses."""

    _tied_weights_keys = []

    def __init__(self, config):
        super().__init__(config)
        self.backbone = Mamba2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Lazily resolved in `forward`; a custom criterion can be injected by assigning here.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the LM head projection."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head projection."""
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        """Return the backbone's token embeddings."""
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        """Replace the backbone's token embeddings."""
        return self.backbone.set_input_embeddings(new_embeddings)

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids,
        inputs_embeds=None,
        use_cache=None,
        cache_params: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        logits_to_keep: Optional[int] = None,
        **kwargs,
    ):
        """Assemble the kwargs for one generation step.

        During cached decoding (`cache_position[0] > 0`) only the last token is fed and the
        attention mask is dropped; at prefill, `cache_position` is (re)built to span the conv
        kernel window.
        """
        if use_cache:
            # `cache_position` should have been initialized in `generate`
            if cache_position is None:
                raise ValueError(
                    "`cache_position` should not be None as it should have been initialized in "
                    "`model.generate`, you are responsible for passing in a valid `cache_position` if "
                    "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
                )
            if cache_position[0] > 0:
                # Decoding step: only the newest token is needed, states carry the history.
                input_ids = input_ids[:, -1][..., None]

                if attention_mask is not None:
                    attention_mask = None
            else:
                # we initialize the `cache_position` to full size of `conv_states` at prefill stage
                # considering padding will be applied when input length is shorter, and truncation
                # will be applied when it is longer, so it will be equivalent to always have it match
                # the length of `cache_params.conv_states`, which is `config.conv_kernel`
                cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)

        if inputs_embeds is not None and cache_params is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # NOTE(review): `logits_to_keep` is inserted twice — this conditional insert is always
        # overwritten by the unconditional `update` below, so it appears redundant.
        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'attention_mask': attention_mask,
            'cache_params': cache_params,
            'use_cache': use_cache,
            'cache_position': cache_position,
            'logits_to_keep': logits_to_keep
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[Mamba2Cache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs,  # for now we need this for generation
    ) -> Union[Tuple, Mamba2CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
            cache_position=cache_position,
            attention_mask=attention_mask,
        )
        hidden_states = outputs[0]
        # The fused linear+CE loss skips materializing logits; only used in training.
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            # `logits_to_keep` trims to the trailing positions (0 keeps everything via `-0:`... 
            # NOTE(review): `hidden_states[:, -0:]` equals the full sequence — presumably intended).
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            # Shift labels left by one and pad the last position with ignore_index
            # so position t is trained to predict token t+1.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Mamba2CausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=outputs.cache_params,
            hidden_states=outputs.hidden_states,
        )
fla/models/nsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes). View file
 
fla/models/nsa/__pycache__/configuration_nsa.cpython-312.pyc ADDED
Binary file (2.64 kB). View file
 
fla/models/nsa/configuration_nsa.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class NSAConfig(PretrainedConfig):
    """Configuration for the NSA model.

    Holds model-size hyper-parameters (`hidden_size`, `num_hidden_layers`, head counts),
    sparse-attention settings (`block_size`, `block_counts`, `window_size`), RoPE settings,
    MLP settings (`hidden_ratio`/`intermediate_size`/`hidden_act`), and kernel-fusion flags.
    Values are stored verbatim; when `intermediate_size` is None it is presumably derived
    from `hidden_ratio` by the modeling code — confirm there.
    """

    model_type = 'nsa'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 64,
        num_kv_heads: int = 4,
        head_dim: int = 32,
        qkv_bias: bool = False,
        block_size: int = 64,
        block_counts: Optional[int] = 16,
        window_size: Optional[int] = 512,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.006,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs,
    ):
        # Attention / model-size hyper-parameters.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.qkv_bias = qkv_bias
        self.block_size = block_size
        self.block_counts = block_counts
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Feed-forward network settings.
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        # Initialization / normalization / caching.
        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        # Kernel-fusion flags.
        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/retnet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (682 Bytes). View file
 
fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (687 Bytes). View file
 
fla/models/rwkv6/__pycache__/modeling_rwkv6.cpython-312.pyc ADDED
Binary file (21.2 kB). View file
 
fla/models/rwkv6/modeling_rwkv6.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.rwkv6 import LerpLinear, RWKV6Attention
20
+ from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, LayerNorm
23
+ from fla.modules.activations import ACT2FN
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
class RWKV6FeedForward(nn.Module):
    """RWKV6 channel-mixing block: token-shifted key/receptance projections with a gated value."""

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'sqrelu',
        layer_idx: Optional[int] = None
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        if hidden_ratio is None:
            hidden_ratio = 3.5
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio)
            # Round up to the nearest multiple of 32.
            intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size

        # Shifts the sequence right by one position (pads one step at the front, drops the last).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = LerpLinear(hidden_size, intermediate_size)
        self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.receptance = LerpLinear(hidden_size, hidden_size)
        self.act_fn = ACT2FN[hidden_act]

        self.layer_idx = layer_idx

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        state: Optional[Cache] = None
    ) -> Tuple[torch.Tensor, Optional[Cache]]:
        """Apply the channel-mixing FFN; returns `(output, state)`.

        NOTE(review): the mask is applied with in-place `mul_`, mutating the caller's tensor —
        confirm callers do not reuse `x` afterwards.
        """
        if attention_mask is not None:
            x = x.mul_(attention_mask[:, -x.shape[-2]:, None])
        # Single-token decode: the "shifted" input is the cached last token.
        if x.shape[1] == 1 and state is not None and state[self.layer_idx]['ffn_state'] is not None:
            shifted = state[self.layer_idx]['ffn_state'].unsqueeze(1)
        else:
            shifted = self.time_shift(x)
            # Splice the cached last token into position 0 when continuing from a cache.
            if state is not None and state[self.layer_idx]['ffn_state'] is not None:
                shifted[:, 0] = state[self.layer_idx]['ffn_state']
        # delta drives the LerpLinear token-shift interpolation.
        delta = shifted - x
        key = self.act_fn(self.key(x, delta))
        value = self.value(key)
        receptance = self.receptance(x, delta)

        if state is not None:
            # no need to update the offset twice
            state.update(ffn_state=x[:, -1], layer_idx=self.layer_idx, offset=0)
        return receptance.sigmoid() * value, state
84
+
85
+
86
class RWKV6Block(nn.Module):
    """One RWKV6 layer: (optional pre-norm,) attention/time-mixing, then the channel-mixing FFN."""

    def __init__(self, config: RWKV6Config, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Only the first layer gets an extra pre-norm when `norm_first` is set.
        if config.norm_first and layer_idx == 0:
            self.pre_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
                config.hidden_size,
                bias=config.norm_bias,
                eps=config.norm_eps
            )
        self.attn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
            config.hidden_size,
            bias=config.norm_bias,
            eps=config.norm_eps
        )
        # Layers listed in `config.attn['layers']` use softmax attention instead of RWKV6 time-mixing.
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = RWKV6Attention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_k=config.expand_k,
                expand_v=config.expand_v,
                num_heads=config.num_heads,
                proj_low_rank_dim=config.proj_low_rank_dim,
                gate_low_rank_dim=config.gate_low_rank_dim,
                norm_eps=config.norm_eps,
                fuse_norm=config.fuse_norm,
                layer_idx=layer_idx
            )
        self.ffn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
            config.hidden_size,
            bias=config.norm_bias,
            eps=config.norm_eps
        )
        self.ffn = RWKV6FeedForward(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            layer_idx=layer_idx
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run attention and FFN sub-layers with residuals; returns `(hidden_states, attentions, cache)`."""
        residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states
        hidden_states = self.attn_norm(residual)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            # Fused norm performs residual-add + norm in one call and returns the new residual.
            hidden_states, residual = self.ffn_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.ffn_norm(hidden_states)
        hidden_states, past_key_values = self.ffn(hidden_states, attention_mask, past_key_values)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs
172
+
173
+
174
class RWKV6PreTrainedModel(PreTrainedModel):
    """Base class wiring RWKV6 configs to weight initialization and HF loading utilities."""

    config_class = RWKV6Config
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['RWKV6Block']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = True,
        num_residuals_per_layer: int = 2,
    ):
        """Initialize weights; optionally rescale residual output projections by 1/sqrt(2N).

        `num_residuals_per_layer` is 2 here: one residual addition for attention, one for the FFN.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Parameter):
            # NOTE(review): `nn.Parameter` is not an `nn.Module`, so this branch looks unreachable
            # when `_init_weights` is driven by `Module.apply` — confirm intended call sites.
            nn.init.normal_(module, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            # Fall back to the module's own initialization (e.g. custom norms, LerpLinear).
            module.reset_parameters()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
225
+
226
class RWKV6Model(RWKV6PreTrainedModel):
    """Bare RWKV6 backbone: token embedding -> stack of ``RWKV6Block``s -> final norm.

    Produces hidden states only; the language-modeling head lives in
    ``RWKV6ForCausalLM``.
    """

    def __init__(self, config: RWKV6Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([RWKV6Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        # Fused-kernel LayerNorm when enabled in the config, stock PyTorch otherwise.
        self.norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
            config.hidden_size,
            bias=config.norm_bias,
            eps=config.norm_eps
        )

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module (used e.g. when resizing the vocab)."""
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the backbone over ``input_ids`` or ``inputs_embeds``.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be provided.
        ``output_attentions`` is not supported and is forced to ``False``.
        Returns a ``BaseModelOutputWithPast`` (or a tuple when
        ``return_dict=False``) with the final hidden states, the updated cache,
        and optionally all per-layer hidden states.
        """
        if output_attentions:
            warnings.warn("`RWKV6Model` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        # Caching defaults to the config value at inference time but is disabled during training.
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        # Normalize legacy tuple caches (or None) into the unified `Cache` container.
        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            # Hidden states are recorded *before* each layer; the post-norm
            # final states are appended after the loop.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
331
+
332
+
333
class RWKV6ForCausalLM(RWKV6PreTrainedModel, GenerationMixin):
    """RWKV6 backbone with a language-modeling head and generation support."""

    # lm_head may share weights with the input embeddings (standard HF tying).
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = RWKV6Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Loss module is resolved lazily in `forward` unless set externally.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        """Delegate to ``GenerationMixin.generate``, rewording cache-related
        ``AttributeError``s into an actionable message (decoding strategies
        that manipulate ``past_key_values`` are unsupported here)."""
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        """Assemble the kwargs for one decoding step of ``generate``."""
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Forward pass with optional next-token loss.

        When training with ``config.fuse_cross_entropy``, the LM head and the
        cross-entropy loss are fused (logits are never materialized and the
        returned ``logits`` is ``None``).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            # `logits_to_keep == 0` slices `[:, -0:]`, i.e. keeps every position.
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            # Shift labels left by one for next-token prediction; the final
            # position is padded with `ignore_index` so it contributes no loss.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/rwkv7/__pycache__/modeling_rwkv7.cpython-312.pyc ADDED
Binary file (22.3 kB). View file
 
fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc ADDED
Binary file (3.39 kB). View file
 
fla/models/samba/modeling_samba.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_utils import PreTrainedModel
14
+ from transformers.utils import ModelOutput, logging
15
+ from transformers.utils.deprecation import deprecate_kwarg
16
+
17
+ from fla.layers.attn import Attention
18
+ from fla.models.mamba.modeling_mamba import MambaCache, MambaMixer
19
+ from fla.models.samba.configuration_samba import SambaConfig
20
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
21
+ from fla.modules import GatedMLP as SambaMLP
22
+ from fla.modules import RMSNorm
23
+
24
+ if TYPE_CHECKING:
25
+ from transformers.processing_utils import Unpack
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
class SambaBlock(nn.Module):
    """One Samba layer: a token mixer (softmax attention or Mamba SSM,
    chosen per layer index) followed by a gated MLP, each in a pre-norm
    residual branch.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.mixer_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        # Layers listed in `config.attn['layers']` use softmax attention;
        # every other layer uses a Mamba mixer.
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.mixer = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.mixer = MambaMixer(config, layer_idx=layer_idx)
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = SambaMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[Tuple[torch.Tensor]] = None,
        **kwargs: Unpack[Dict]
    ) -> torch.Tensor:
        # NOTE(review): the original annotation promised a tuple, but only the
        # hidden-states tensor is returned.

        residual = hidden_states
        hidden_states = self.mixer_norm(hidden_states)
        if isinstance(self.mixer, MambaMixer):
            # MambaMixer returns only the mixed states; the cache object is
            # presumably updated in place — confirm against MambaMixer.forward.
            hidden_states = self.mixer(hidden_states, cache_params=cache_params, **kwargs)
        else:
            hidden_states, _, cache_params = self.mixer(hidden_states=hidden_states, past_key_values=cache_params, **kwargs)
        if self.config.fuse_norm:
            # Fused path: residual-add and norm in one kernel; it also returns
            # the new residual for the MLP branch.
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states
        return hidden_states
81
+
82
+
83
class SambaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SambaConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["SambaBlock"]
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                # Biases flagged `_no_reinit` (e.g. dt_proj.bias below) keep
                # their carefully constructed values.
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, MambaMixer):
            # A_log / D are state-space parameters that must not be decayed.
            module.A_log._no_weight_decay = True
            module.D._no_weight_decay = True

            dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
            if self.config.time_step_init_scheme == "constant":
                nn.init.constant_(module.dt_proj.weight, dt_init_std)
            elif self.config.time_step_init_scheme == "random":
                nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)

            # Sample dt log-uniformly in [time_step_min, time_step_max],
            # floored at time_step_floor.
            dt = torch.exp(
                torch.rand(self.config.intermediate_size)
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
                + math.log(self.config.time_step_min)
            ).clamp(min=self.config.time_step_floor)
            # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                # NOTE(review): assigning an nn.Parameter into `.data` is
                # unconventional (a plain tensor / copy_ would be the usual
                # form); it works because Parameter is a Tensor subclass.
                module.dt_proj.bias.data = nn.Parameter(inv_dt.to(module.dt_proj.bias.device))
            module.dt_proj.bias._no_reinit = True
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if self.config.rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        # NOTE(review): reads `config.num_layers` — verify the
                        # Samba config defines this (elsewhere this file uses
                        # `num_hidden_layers`).
                        p /= math.sqrt(self.config.num_layers)
142
+
143
+
144
@dataclass
class SambaOutput(ModelOutput):
    """
    Class for the Samba model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
168
+
169
+
170
@dataclass
class SambaCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
197
+
198
+
199
class SambaModel(SambaPreTrainedModel):
    """Bare Samba backbone: token embedding -> stack of ``SambaBlock``s -> final RMSNorm."""

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([SambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False
        self.norm_f = RMSNorm(config.hidden_size, eps=config.norm_eps)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        cache_params: Optional[MambaCache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, SambaOutput]:
        """Run the backbone; exactly one of ``input_ids`` / ``inputs_embeds`` is required.

        A fresh ``MambaCache`` is allocated when caching is requested and no
        cache is supplied; its ``seqlen_offset`` is advanced after the pass.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Caching defaults to the config value at inference time but is disabled during training.
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):  # ^ is python for xor
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if cache_params is None and use_cache:
            cache_params = MambaCache(
                self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        for mixer_block in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    mixer_block.__call__,
                    hidden_states,
                    cache_params,
                    **kwargs
                )
            else:
                hidden_states = mixer_block(
                    hidden_states,
                    cache_params=cache_params,
                    **kwargs
                )

            # Unlike the RWKV6 backbone, per-layer states are recorded *after*
            # each block here.
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if use_cache:
            cache_params.seqlen_offset += inputs_embeds.shape[1]

        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # NOTE(review): this tuple includes a caller-supplied `cache_params`
            # even when `use_cache` is False, whereas the dataclass path below
            # gates it on `use_cache` — confirm whether that asymmetry is intended.
            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)

        return SambaOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params if use_cache else None,
            hidden_states=all_hidden_states,
        )
285
+
286
+
287
class SambaForCausalLM(SambaPreTrainedModel, GenerationMixin):
    """Samba backbone with a language-modeling head and generation support."""

    # lm_head may share weights with the input embeddings (standard HF tying).
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.backbone = SambaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Loss module is resolved lazily in `forward` unless set externally.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)

    def _update_model_kwargs_for_generation(
        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
    ) -> Dict[str, Any]:
        """Carry the Mamba cache forward between decoding steps."""
        model_kwargs["cache_params"] = outputs.get("cache_params", None)
        return model_kwargs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids,
        cache_params: Optional[MambaCache] = None,
        inputs_embeds=None,
        attention_mask=None,
        use_cache: Optional[bool] = True,
        logits_to_keep: Optional[int] = None,
        **kwargs: Unpack[Dict]
    ):
        """Assemble the kwargs for one decoding step of ``generate``.

        ``logits_to_keep`` is only forwarded when explicitly provided, so that
        ``forward`` falls back to its own default otherwise.
        """
        # only last token for inputs_ids if the state is passed along.
        if cache_params is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        # `inputs_embeds` are only usable on the first step (no cache yet).
        if inputs_embeds is not None and cache_params is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        # Fix: the original also put 'logits_to_keep' unconditionally into this
        # update dict, overwriting the guarded assignment above with a possible
        # `None` and making the guard dead code. Matching RWKV6ForCausalLM, the
        # key is now only set when a value was actually given.
        model_inputs.update({
            'cache_params': cache_params,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[MambaCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, SambaCausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
            **kwargs
        )
        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            # `logits_to_keep == 0` slices `[:, -0:]`, i.e. keeps every position.
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            # Shift labels left by one for next-token prediction; the final
            # position is padded with `ignore_index` so it contributes no loss.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return SambaCausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=outputs.cache_params,
            hidden_states=outputs.hidden_states,
        )
fla/models/transformer/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.transformer.configuration_transformer import TransformerConfig
from fla.models.transformer.modeling_transformer import TransformerForCausalLM, TransformerModel

# Register the Transformer variant with the HF Auto* factories so that
# `AutoConfig`/`AutoModel`/`AutoModelForCausalLM` can resolve its
# `model_type` string to these classes.
AutoConfig.register(TransformerConfig.model_type, TransformerConfig)
AutoModel.register(TransformerConfig, TransformerModel)
AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)


__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel']
fla/models/transformer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (728 Bytes). View file
 
fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.52 kB). View file
 
fla/models/transformer/__pycache__/modeling_transformer.cpython-312.pyc ADDED
Binary file (17.1 kB). View file
 
fla/models/transformer_mtp/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (795 Bytes). View file
 
fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.69 kB). View file
 
fla/models/transformer_top/__pycache__/modeling_transformer.cpython-312.pyc ADDED
Binary file (18.7 kB). View file
 
fla/modules/__pycache__/activations.cpython-312.pyc ADDED
Binary file (23 kB). View file
 
fla/modules/__pycache__/convolution.cpython-312.pyc ADDED
Binary file (21 kB). View file
 
fla/modules/__pycache__/feature_map.cpython-312.pyc ADDED
Binary file (17.6 kB). View file
 
fla/modules/__pycache__/fused_bitlinear.cpython-312.pyc ADDED
Binary file (23.6 kB). View file
 
fla/modules/__pycache__/fused_kl_div.cpython-312.pyc ADDED
Binary file (11.7 kB). View file
 
fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc ADDED
Binary file (20.6 kB). View file
 
fla/modules/__pycache__/fused_linear_listnet_loss.cpython-312.pyc ADDED
Binary file (17.8 kB). View file