zaydzuhri committed
Commit 1a7a6e1 · verified · 1 Parent(s): bd301da

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full list.
Files changed (50)
  1. fla/models/abc/__init__.py +13 -0
  2. fla/models/abc/__pycache__/__init__.cpython-312.pyc +0 -0
  3. fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc +0 -0
  4. fla/models/abc/__pycache__/modeling_abc.cpython-312.pyc +0 -0
  5. fla/models/abc/modeling_abc.py +418 -0
  6. fla/models/bitnet/__pycache__/__init__.cpython-312.pyc +0 -0
  7. fla/models/bitnet/__pycache__/configuration_bitnet.cpython-312.pyc +0 -0
  8. fla/models/bitnet/__pycache__/modeling_bitnet.cpython-312.pyc +0 -0
  9. fla/models/bitnet/configuration_bitnet.py +67 -0
  10. fla/models/bitnet/modeling_bitnet.py +441 -0
  11. fla/models/delta_net/__init__.py +12 -0
  12. fla/models/delta_net/__pycache__/__init__.cpython-312.pyc +0 -0
  13. fla/models/delta_net/configuration_delta_net.py +91 -0
  14. fla/models/delta_net/modeling_delta_net.py +415 -0
  15. fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc +0 -0
  16. fla/models/forgetting_transformer/modeling_forgetting_transformer.py +408 -0
  17. fla/models/gated_deltanet/configuration_gated_deltanet.py +83 -0
  18. fla/models/gated_deltanet/modeling_gated_deltanet.py +412 -0
  19. fla/models/gated_deltaproduct/__init__.py +14 -0
  20. fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py +90 -0
  21. fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py +520 -0
  22. fla/models/gla/__pycache__/__init__.cpython-312.pyc +0 -0
  23. fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc +0 -0
  24. fla/models/gla/modeling_gla.py +417 -0
  25. fla/models/gsa/__init__.py +13 -0
  26. fla/models/gsa/__pycache__/__init__.cpython-312.pyc +0 -0
  27. fla/models/gsa/configuration_gsa.py +97 -0
  28. fla/models/gsa/modeling_gsa.py +420 -0
  29. fla/models/hgrn/__init__.py +13 -0
  30. fla/models/hgrn/__pycache__/__init__.cpython-312.pyc +0 -0
  31. fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc +0 -0
  32. fla/models/hgrn/configuration_hgrn.py +81 -0
  33. fla/models/hgrn2/__init__.py +13 -0
  34. fla/models/hgrn2/__pycache__/__init__.cpython-312.pyc +0 -0
  35. fla/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc +0 -0
  36. fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc +0 -0
  37. fla/models/hgrn2/configuration_hgrn2.py +91 -0
  38. fla/models/hgrn2/modeling_hgrn2.py +421 -0
  39. fla/models/lightnet/configuration_lightnet.py +83 -0
  40. fla/models/lightnet/modeling_lightnet.py +410 -0
  41. fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-312.pyc +0 -0
  42. fla/models/linear_attn/configuration_linear_attn.py +91 -0
  43. fla/models/linear_attn/modeling_linear_attn.py +406 -0
  44. fla/models/mamba/__init__.py +13 -0
  45. fla/models/mamba/__pycache__/__init__.cpython-312.pyc +0 -0
  46. fla/models/mamba/configuration_mamba.py +166 -0
  47. fla/models/mamba/modeling_mamba.py +843 -0
  48. fla/models/mamba2/__init__.py +13 -0
  49. fla/models/mamba2/__pycache__/__init__.cpython-312.pyc +0 -0
  50. fla/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc +0 -0
fla/models/abc/__init__.py ADDED
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from fla.models.abc.configuration_abc import ABCConfig
+from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel
+
+AutoConfig.register(ABCConfig.model_type, ABCConfig)
+AutoModel.register(ABCConfig, ABCModel)
+AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM)
+
+
+__all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel']
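The registration calls above are what make the custom `abc` model type resolvable through the standard Auto* factories. A hypothetical usage sketch, not part of the commit (`ABCConfig`'s constructor defaults are assumed, since configuration_abc.py is not shown in this view):

from transformers import AutoModelForCausalLM

from fla.models.abc import ABCConfig  # importing the subpackage runs the register() calls above

config = ABCConfig()  # assumed to provide sensible defaults, as the other fla configs do
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # -> 'ABCForCausalLM'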
fla/models/abc/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes)
fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc ADDED
Binary file (3.61 kB)
fla/models/abc/__pycache__/modeling_abc.cpython-312.pyc ADDED
Binary file (18.4 kB)
fla/models/abc/modeling_abc.py ADDED
@@ -0,0 +1,418 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import math
+import warnings
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.utils.deprecation import deprecate_kwarg
+
+from fla.layers.abc import ABCAttention
+from fla.layers.attn import Attention
+from fla.models.abc.configuration_abc import ABCConfig
+from fla.models.utils import Cache
+from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
+from fla.modules import GatedMLP as ABCMLP
+from fla.modules import RMSNorm
+
+logger = logging.get_logger(__name__)
+
+if TYPE_CHECKING:
+    from transformers.processing_utils import Unpack
+
+
+class ABCBlock(nn.Module):
+    def __init__(self, config: ABCConfig, layer_idx: int):
+        super().__init__()
+
+        self.config = config
+        self.layer_idx = layer_idx
+
+        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        if config.attn is not None and layer_idx in config.attn['layers']:
+            self.attn = Attention(
+                hidden_size=config.hidden_size,
+                num_heads=config.attn['num_heads'],
+                num_kv_heads=config.attn['num_kv_heads'],
+                qkv_bias=config.attn['qkv_bias'],
+                window_size=config.attn['window_size'],
+                rope_theta=config.attn['rope_theta'],
+                max_position_embeddings=config.max_position_embeddings,
+                layer_idx=layer_idx
+            )
+        else:
+            self.attn = ABCAttention(
+                hidden_size=config.hidden_size,
+                expand_k=config.expand_k,
+                expand_v=config.expand_v,
+                num_heads=config.num_heads,
+                num_slots=config.num_slots,
+                use_short_conv=config.use_short_conv,
+                conv_size=config.conv_size,
+                gate_fn=config.hidden_act,
+                elementwise_affine=config.elementwise_affine,
+                norm_eps=config.norm_eps,
+                use_rope=config.use_rope,
+                clamp_min=config.clamp_min,
+                clamp_max=config.clamp_max,
+                fuse_norm=config.fuse_norm,
+                layer_idx=layer_idx
+            )
+        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        self.mlp = ABCMLP(
+            hidden_size=config.hidden_size,
+            hidden_ratio=config.hidden_ratio,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            fuse_swiglu=config.fuse_swiglu
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        **kwargs: Unpack[Dict]
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+
+        residual = hidden_states
+
+        hidden_states = self.attn_norm(hidden_states)
+        hidden_states, attentions, past_key_values = self.attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            **kwargs
+        )
+        if self.config.fuse_norm:
+            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
+        else:
+            hidden_states = residual + hidden_states
+            residual = hidden_states
+            hidden_states = self.mlp_norm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states, attentions, past_key_values)
+
+        return outputs
+
+
+class ABCPreTrainedModel(PreTrainedModel):
+
+    config_class = ABCConfig
+    base_model_prefix = 'model'
+    supports_gradient_checkpointing = True
+    _no_split_modules = ['ABCBlock']
+    _supports_cache_class = True
+
+    def __init__(self, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+
+    def _init_weights(
+        self,
+        module: nn.Module,
+        prenorm_residual_strategy: Optional[str] = 'rescale',
+        num_residuals_per_layer: int = 2,
+    ):
+        if isinstance(module, (nn.Linear, nn.Conv1d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+        elif hasattr(module, 'reset_parameters'):
+            module.reset_parameters()
+
+        if prenorm_residual_strategy is not None:
+            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+            #
+            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+            p = None
+            if hasattr(module, 'o_proj'):
+                p = module.o_proj.weight
+            elif hasattr(module, 'down_proj'):
+                p = module.down_proj.weight
+            if p is not None:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                # We need to reinit p since this code could be called multiple times
+                # Having just p *= scale would repeatedly scale it down
+                if prenorm_residual_strategy == 'rescale':
+                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                    with torch.no_grad():
+                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
+                elif prenorm_residual_strategy == 'zero':
+                    nn.init.zeros_(p)
+                else:
+                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
+
+
+class ABCModel(ABCPreTrainedModel):
+
+    def __init__(self, config: ABCConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([ABCBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+
+        self.gradient_checkpointing = False
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,  # noqa
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Unpack[Dict]
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        if output_attentions:
+            warnings.warn("`ABCModel` does not `output_attentions` now, setting it to `False`.")
+            output_attentions = False
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
+        hidden_states = inputs_embeds
+
+        if use_cache and not isinstance(past_key_values, Cache):
+            past_key_values = Cache.from_legacy_cache(past_key_values)
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+            use_cache = False
+
+        all_hidden_states = () if output_hidden_states else None
+        all_attns = () if output_attentions else None
+        for layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
+                    layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    past_key_values,
+                    use_cache,
+                    output_attentions,
+                    **kwargs
+                )
+            else:
+                hidden_states, attentions, past_key_values = layer(
+                    hidden_states,
+                    attention_mask,
+                    past_key_values=past_key_values,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    **kwargs
+                )
+
+            if output_attentions:
+                all_attns += (attentions,)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if not return_dict:
+            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_attns
+        )
+
+
+class ABCForCausalLM(ABCPreTrainedModel, GenerationMixin):
+
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = ABCModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.criterion = None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embeddings
+
+    def set_input_embeddings(self, value):
+        self.model.embeddings = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def generate(self, *args, **kwargs):
+        try:
+            return super().generate(*args, **kwargs)
+        except AttributeError as exception:
+            if 'past_key_values' in str(exception):
+                raise AttributeError(
+                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
+                    f"which is not supported for {self.__class__.__name__}. "
+                    f"Try another generation strategy instead. "
+                    f"For the available generation strategies, check this doc: "
+                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
+                )
+            else:
+                raise exception
+
+    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: bool = True,
+        logits_to_keep: Optional[int] = None,
+        **kwargs
+    ):
+        # only last token for `inputs_ids` if the `past_key_values` is not empty.
+        if past_key_values is not None and len(past_key_values) > 0:
+            input_ids = input_ids[:, -1:]
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and len(past_key_values) == 0:
+            model_inputs = {'inputs_embeds': inputs_embeds}
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard.
+            # Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {'input_ids': input_ids.contiguous()}
+
+        if logits_to_keep is not None:
+            model_inputs['logits_to_keep'] = logits_to_keep
+
+        model_inputs.update({
+            'past_key_values': past_key_values,
+            'use_cache': use_cache,
+            'attention_mask': attention_mask,
+        })
+        return model_inputs
+
+    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        logits_to_keep: Optional[int] = 0,
+        **kwargs: Unpack[Dict]
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
+
+        hidden_states = outputs[0]
+        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
+
+        loss, logits = None, None
+        if not fuse_linear_and_cross_entropy or labels is None:
+            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
+        if labels is not None:
+            if getattr(self, 'criterion', None) is None:
+                if fuse_linear_and_cross_entropy:
+                    criterion = FusedLinearCrossEntropyLoss()
+                elif self.config.fuse_cross_entropy:
+                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
+                else:
+                    criterion = nn.CrossEntropyLoss()
+            else:
+                criterion = self.criterion
+            labels = labels.to(hidden_states.device)
+            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+            if fuse_linear_and_cross_entropy:
+                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+            else:
+                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
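In `ABCForCausalLM.forward`, the loss path depends on `fuse_cross_entropy`: during training with fusion enabled, logits are never materialized and `FusedLinearCrossEntropyLoss` consumes the hidden states together with the `lm_head` weights; otherwise labels are shifted left and padded with `ignore_index` before an ordinary cross-entropy. A standalone sketch in plain PyTorch (no fused kernels, not part of the commit) showing that this shift-and-pad trick matches the conventional next-token objective:

import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq, vocab = 2, 5, 11
logits = torch.randn(batch, seq, vocab)
labels = torch.randint(0, vocab, (batch, seq))

criterion = nn.CrossEntropyLoss()
# shift labels left and pad the last position with ignore_index, as done above
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
loss_a = criterion(logits.view(shifted.numel(), -1), shifted.view(-1))

# conventional "shift logits / shift labels" formulation for comparison
loss_b = criterion(logits[:, :-1].reshape(-1, vocab), labels[:, 1:].reshape(-1))
assert torch.allclose(loss_a, loss_b)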
fla/models/bitnet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (682 Bytes)
fla/models/bitnet/__pycache__/configuration_bitnet.cpython-312.pyc ADDED
Binary file (2.37 kB)
fla/models/bitnet/__pycache__/modeling_bitnet.cpython-312.pyc ADDED
Binary file (18.6 kB)
fla/models/bitnet/configuration_bitnet.py ADDED
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+
+from typing import Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class BitNetConfig(PretrainedConfig):
+
+    model_type = 'bitnet'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        hidden_size: int = 2048,
+        num_hidden_layers: int = 24,
+        num_heads: int = 32,
+        num_kv_heads: int = None,
+        window_size: Optional[int] = None,
+        rope_theta: Optional[float] = 10000.,
+        max_position_embeddings: int = 2048,
+        hidden_ratio: Optional[int] = 4,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = "swish",
+        initializer_range: float = 0.006,
+        elementwise_affine: Optional[bool] = True,
+        norm_eps: float = 1e-6,
+        use_cache: bool = True,
+        pad_token_id: int = None,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        tie_word_embeddings: bool = False,
+        fuse_norm: bool = True,
+        fuse_swiglu: bool = True,
+        fuse_cross_entropy: bool = True,
+        vocab_size: int = 32000,
+        **kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.window_size = window_size
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+
+        self.initializer_range = initializer_range
+        self.elementwise_affine = elementwise_affine
+        self.norm_eps = norm_eps
+        self.use_cache = use_cache
+
+        self.fuse_norm = fuse_norm
+        self.fuse_swiglu = fuse_swiglu
+        self.fuse_cross_entropy = fuse_cross_entropy
+        self.vocab_size = vocab_size
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
fla/models/bitnet/modeling_bitnet.py ADDED
@@ -0,0 +1,441 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import math
+import warnings
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.utils.deprecation import deprecate_kwarg
+
+from fla.layers.bitattn import BitAttention
+from fla.models.bitnet.configuration_bitnet import BitNetConfig
+from fla.models.utils import Cache
+from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
+from fla.modules.activations import swiglu
+from fla.modules.fused_bitlinear import FusedBitLinear
+
+if TYPE_CHECKING:
+    from transformers.processing_utils import Unpack
+
+logger = logging.get_logger(__name__)
+
+
+class BitNetMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        hidden_ratio: Optional[int] = None,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = 'swish',
+        fuse_swiglu: bool = True
+    ) -> BitNetMLP:
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        # the final number of params is `hidden_ratio * hidden_size^2`
+        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
+        if hidden_ratio is None:
+            hidden_ratio = 4
+        if intermediate_size is None:
+            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
+            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.fuse_swiglu = fuse_swiglu
+
+        if hidden_act != 'swish':
+            raise ValueError(f'Unsupported hidden_act: {hidden_act}')
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        **kwargs: Unpack[Any]
+    ) -> torch.Tensor:
+        gate, y = self.gate_proj(x), self.up_proj(x)
+        return self.down_proj(swiglu(gate, y))
+
+
+class BitNetBlock(nn.Module):
+
+    def __init__(self, config: BitNetConfig, layer_idx: int):
+        super().__init__()
+
+        self.config = config
+        self.layer_idx = layer_idx
+
+        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        self.attn = BitAttention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_heads,
+            num_kv_heads=config.num_kv_heads,
+            window_size=config.window_size,
+            rope_theta=config.rope_theta,
+            max_position_embeddings=config.max_position_embeddings,
+            layer_idx=layer_idx
+        )
+        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        self.mlp = BitNetMLP(
+            hidden_size=config.hidden_size,
+            hidden_ratio=config.hidden_ratio,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            fuse_swiglu=config.fuse_swiglu
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs: Unpack[Any]
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+
+        residual = hidden_states
+        hidden_states = self.attn_norm(hidden_states)
+        hidden_states, attentions, past_key_values = self.attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            **kwargs
+        )
+        if self.config.fuse_norm:
+            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
+        else:
+            hidden_states = residual + hidden_states
+            residual = hidden_states
+            hidden_states = self.mlp_norm(hidden_states)
+        hidden_states = self.mlp(hidden_states, **kwargs)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attentions,)
+
+        if use_cache:
+            outputs += (past_key_values,)
+
+        return outputs
+
+
+class BitNetPreTrainedModel(PreTrainedModel):
+
+    config_class = BitNetConfig
+    base_model_prefix = 'model'
+    supports_gradient_checkpointing = True
+    _no_split_modules = ['BitNetBlock']
+    _supports_cache_class = True
+
+    def __init__(self, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+
+    def _init_weights(
+        self,
+        module: nn.Module,
+        rescale_prenorm_residual: bool = False,
+        num_residuals_per_layer: int = 2,
+    ):
+        if isinstance(module, (nn.Linear, nn.Conv1d, FusedBitLinear)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+        elif hasattr(module, 'reset_parameters'):
+            module.reset_parameters()
+
+        if rescale_prenorm_residual:
+            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+            #
+            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+            p = None
+            if hasattr(module, 'o_proj'):
+                p = module.o_proj.weight
+            elif hasattr(module, 'down_proj'):
+                p = module.down_proj.weight
+            if p is not None:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                # We need to reinit p since this code could be called multiple times
+                # Having just p *= scale would repeatedly scale it down
+                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                with torch.no_grad():
+                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
+
+
+class BitNetModel(BitNetPreTrainedModel):
+
+    def __init__(
+        self,
+        config: BitNetConfig
+    ) -> BitNetModel:
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([BitNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+
+        self.gradient_checkpointing = False
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Unpack[Any]
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        if output_attentions:
+            warnings.warn(
+                "`BitNetModel` does not support output attention weights now, so `output_attentions` is set to `False`."
+            )
+            output_attentions = False
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is None and inputs_embeds is None:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if use_cache and not isinstance(past_key_values, Cache):
+            past_key_values = Cache.from_legacy_cache(past_key_values)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
+
+        # embed positions
+        hidden_states = inputs_embeds
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        all_hidden_states = () if output_hidden_states else None
+        all_attns = () if output_attentions else None
+        next_cache = None
+
+        for layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    **kwargs
+                )
+            else:
+                layer_outputs = layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    past_key_values=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    **kwargs
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_attns
+        )
+
+
+class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
+
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = BitNetModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.criterion = None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embeddings
+
+    def set_input_embeddings(self, value):
+        self.model.embeddings = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: bool = True,
+        logits_to_keep: Optional[int] = None,
+        **kwargs
+    ):
+        # only last token for `inputs_ids` if the `past_key_values` is not empty.
+        if past_key_values is not None and len(past_key_values) > 0:
+            input_ids = input_ids[:, -1:]
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and len(past_key_values) == 0:
+            model_inputs = {'inputs_embeds': inputs_embeds}
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard.
+            # Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {'input_ids': input_ids.contiguous()}
+
+        if logits_to_keep is not None:
+            model_inputs['logits_to_keep'] = logits_to_keep
+
+        model_inputs.update({
+            'past_key_values': past_key_values,
+            'use_cache': use_cache,
+            'attention_mask': attention_mask,
+        })
+        return model_inputs
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        logits_to_keep: Optional[int] = 0,
+        **kwargs: Unpack[Any]
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
+
+        hidden_states = outputs[0]
+        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
+
+        loss, logits = None, None
+        if not fuse_linear_and_cross_entropy or labels is None:
+            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
+        if labels is not None:
+            if getattr(self, 'criterion', None) is None:
+                if fuse_linear_and_cross_entropy:
+                    criterion = FusedLinearCrossEntropyLoss()
+                elif self.config.fuse_cross_entropy:
+                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
+                else:
+                    criterion = nn.CrossEntropyLoss()
+            else:
+                criterion = self.criterion
+            labels = labels.to(hidden_states.device)
+            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+            if fuse_linear_and_cross_entropy:
+                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+            else:
+                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
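When `intermediate_size` is not given, `BitNetMLP.__init__` (above) sizes the MLP by rounding 2/3 * hidden_size * hidden_ratio up to the next multiple of 256. A small standalone check of that rule (the helper name is ours, not part of the commit):

def bitnet_intermediate_size(hidden_size: int, hidden_ratio: int = 4) -> int:
    # same arithmetic as BitNetMLP.__init__
    intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
    return 256 * ((intermediate_size + 256 - 1) // 256)

assert bitnet_intermediate_size(2048) == 5632  # 2/3 * 8192 = 5461.33 -> rounded up to 5632
assert bitnet_intermediate_size(1024) == 2816  # 2/3 * 4096 = 2730.67 -> rounded up to 2816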
fla/models/delta_net/__init__.py ADDED
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
+from fla.models.delta_net.modeling_delta_net import DeltaNetForCausalLM, DeltaNetModel
+
+AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig)
+AutoModel.register(DeltaNetConfig, DeltaNetModel)
+AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)
+
+__all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel']
fla/models/delta_net/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (701 Bytes)
fla/models/delta_net/configuration_delta_net.py ADDED
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+
+from typing import Dict, Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class DeltaNetConfig(PretrainedConfig):
+
+    model_type = 'delta_net'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        attn_mode: str = "chunk",
+        hidden_size: int = 2048,
+        expand_k: int = 1,
+        expand_v: int = 1,
+        use_gate: bool = False,
+        use_short_conv: bool = True,
+        conv_size: int = 4,
+        use_beta: bool = True,
+        use_output_norm: bool = True,
+        num_heads: int = 16,
+        qk_norm: str = 'l2',
+        qk_activation: str = 'silu',
+        max_position_embeddings: int = 2048,
+        hidden_ratio: Optional[int] = 4,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = "swish",
+        num_hidden_layers: int = 24,
+        norm_eps: float = 1e-6,
+        attn: Optional[Dict] = None,
+        use_cache: bool = True,
+        pad_token_id: int = None,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        tie_word_embeddings: bool = False,
+        initializer_range: float = 0.006,
+        fuse_norm: bool = True,
+        fuse_swiglu: bool = True,
+        fuse_cross_entropy: bool = True,
+        vocab_size: int = 32000,
+        **kwargs
+    ):
+        self.attn_mode = attn_mode
+        self.hidden_size = hidden_size
+        self.expand_k = expand_k
+        self.expand_v = expand_v
+        self.use_gate = use_gate
+        self.use_short_conv = use_short_conv
+        self.conv_size = conv_size
+        self.use_beta = use_beta
+        self.use_output_norm = use_output_norm
+        self.num_heads = num_heads
+        self.qk_norm = qk_norm
+        self.qk_activation = qk_activation
+        self.max_position_embeddings = max_position_embeddings
+
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.num_hidden_layers = num_hidden_layers
+        self.norm_eps = norm_eps
+        self.attn = attn
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        self.fuse_norm = fuse_norm
+        self.fuse_swiglu = fuse_swiglu
+        self.fuse_cross_entropy = fuse_cross_entropy
+        self.vocab_size = vocab_size
+
+        if attn is not None:
+            if not isinstance(attn, Dict):
+                raise ValueError("attn must be a dictionary")
+            if 'layers' not in attn:
+                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+            if 'num_heads' not in attn:
+                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+            attn['qkv_bias'] = attn.get('qkv_bias', False)
+            attn['window_size'] = attn.get('window_size', None)
+            attn['rope_theta'] = attn.get('rope_theta', 10000.)
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
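The optional `attn` dict switches the listed layers to standard softmax attention, giving a hybrid model; only 'layers' and 'num_heads' are required, and the remaining keys fall back to the defaults filled in above. A short sketch (not part of the commit) of constructing such a config:

from fla.models.delta_net import DeltaNetConfig

config = DeltaNetConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    attn={'layers': [1, 5], 'num_heads': 16},  # layers 1 and 5 use standard attention
)
print(config.attn)
# {'layers': [1, 5], 'num_heads': 16, 'num_kv_heads': 16, 'qkv_bias': False,
#  'window_size': None, 'rope_theta': 10000.0}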
fla/models/delta_net/modeling_delta_net.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.delta_net import DeltaNet
20
+ from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as DeltaNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ if TYPE_CHECKING:
29
+ from transformers.processing_utils import Unpack
30
+
31
+
32
+ class DeltaNetBlock(nn.Module):
33
+ def __init__(self, config: DeltaNetConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = DeltaNet(
53
+ mode=config.attn_mode,
54
+ hidden_size=config.hidden_size,
55
+ expand_k=config.expand_k,
56
+ expand_v=config.expand_v,
57
+ num_heads=config.num_heads,
58
+ use_gate=config.use_gate,
59
+ use_beta=config.use_beta,
60
+ use_short_conv=config.use_short_conv,
61
+ use_output_norm=config.use_output_norm,
62
+ conv_size=config.conv_size,
63
+ qk_norm=config.qk_norm,
64
+ qk_activation=config.qk_activation,
65
+ norm_eps=config.norm_eps,
66
+ layer_idx=layer_idx
67
+ )
68
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
69
+ self.mlp = DeltaNetMLP(
70
+ hidden_size=config.hidden_size,
71
+ hidden_ratio=config.hidden_ratio,
72
+ intermediate_size=config.intermediate_size,
73
+ hidden_act=config.hidden_act,
74
+ fuse_swiglu=config.fuse_swiglu
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ hidden_states: torch.Tensor,
80
+ attention_mask: Optional[torch.Tensor] = None,
81
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
82
+ use_cache: Optional[bool] = False,
83
+ output_attentions: Optional[bool] = False,
84
+ **kwargs: Unpack[Dict]
85
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
86
+ residual = hidden_states
87
+ hidden_states = self.attn_norm(hidden_states)
88
+ hidden_states, attentions, past_key_values = self.attn(
89
+ hidden_states=hidden_states,
90
+ attention_mask=attention_mask,
91
+ past_key_values=past_key_values,
92
+ use_cache=use_cache,
93
+ output_attentions=output_attentions,
94
+ **kwargs
95
+ )
96
+ if self.config.fuse_norm:
97
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
98
+ else:
99
+ hidden_states = residual + hidden_states
100
+ residual = hidden_states
101
+ hidden_states = self.mlp_norm(hidden_states)
102
+ hidden_states = self.mlp(hidden_states, **kwargs)
103
+ hidden_states = residual + hidden_states
104
+
105
+ outputs = (hidden_states, attentions, past_key_values)
106
+
107
+ return outputs
108
+
109
+
110
+ class DeltaNetPreTrainedModel(PreTrainedModel):
111
+
112
+ config_class = DeltaNetConfig
113
+ base_model_prefix = 'model'
114
+ supports_gradient_checkpointing = True
115
+ _no_split_modules = ['DeltaNetBlock']
116
+ _supports_cache_class = True
117
+
118
+ def __init__(self, *inputs, **kwargs):
119
+ super().__init__(*inputs, **kwargs)
120
+
121
+ def _init_weights(
122
+ self,
123
+ module: nn.Module,
124
+ prenorm_residual_strategy: Optional[str] = 'rescale',
125
+ num_residuals_per_layer: int = 2,
126
+ ):
127
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
128
+ # Slightly different from the TF version which uses truncated_normal for initialization
129
+ # cf https://github.com/pytorch/pytorch/pull/5617
130
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
131
+ if module.bias is not None:
132
+ nn.init.zeros_(module.bias)
133
+ elif isinstance(module, nn.Embedding):
134
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
135
+ elif hasattr(module, 'reset_parameters'):
136
+ module.reset_parameters()
137
+
138
+ if prenorm_residual_strategy is not None:
139
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
140
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
141
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
142
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
143
+ #
144
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
145
+ p = None
146
+ if hasattr(module, 'o_proj'):
147
+ p = module.o_proj.weight
148
+ elif hasattr(module, 'down_proj'):
149
+ p = module.down_proj.weight
150
+ if p is not None:
151
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
152
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
153
+ # We need to reinit p since this code could be called multiple times
154
+ # Having just p *= scale would repeatedly scale it down
155
+ if prenorm_residual_strategy == 'rescale':
156
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
157
+ with torch.no_grad():
158
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
159
+ elif prenorm_residual_strategy == 'zero':
160
+ nn.init.zeros_(p)
161
+ else:
162
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
163
+
164
+
165
+ class DeltaNetModel(DeltaNetPreTrainedModel):
166
+
167
+ def __init__(self, config: DeltaNetConfig):
168
+ super().__init__(config)
169
+ self.padding_idx = config.pad_token_id
170
+ self.vocab_size = config.vocab_size
171
+
172
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
173
+ self.layers = nn.ModuleList([DeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
174
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
175
+
176
+ self.gradient_checkpointing = False
177
+
178
+ self.post_init()
179
+
180
+ def get_input_embeddings(self):
181
+ return self.embeddings
182
+
183
+ def set_input_embeddings(self, value):
184
+ self.embeddings = value
185
+
186
+ def forward(
187
+ self,
188
+ input_ids: Optional[torch.LongTensor] = None,
189
+ attention_mask: Optional[torch.Tensor] = None, # noqa
190
+ inputs_embeds: Optional[torch.FloatTensor] = None,
191
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
192
+ use_cache: Optional[bool] = None,
193
+ output_attentions: Optional[bool] = None,
194
+ output_hidden_states: Optional[bool] = None,
195
+ return_dict: Optional[bool] = None,
196
+ **kwargs: Unpack[Dict]
197
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
198
+ if output_attentions:
199
+ warnings.warn("`DeltaNetModel` does not `output_attentions` now, setting it to `False`.")
200
+ output_attentions = False
201
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
202
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
203
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
204
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
205
+
206
+ # retrieve input_ids and inputs_embeds
207
+ if input_ids is not None and inputs_embeds is not None:
208
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
209
+ if input_ids is None and inputs_embeds is None:
210
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
211
+
212
+ if inputs_embeds is None:
213
+ inputs_embeds = self.embeddings(input_ids)
214
+ hidden_states = inputs_embeds
215
+
216
+ if use_cache and not isinstance(past_key_values, Cache):
217
+ past_key_values = Cache.from_legacy_cache(past_key_values)
218
+
219
+ if self.gradient_checkpointing and self.training and use_cache:
220
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
221
+ use_cache = False
222
+
223
+ all_hidden_states = () if output_hidden_states else None
224
+ all_attns = () if output_attentions else None
225
+ for layer in self.layers:
226
+ if output_hidden_states:
227
+ all_hidden_states += (hidden_states,)
228
+
229
+ if self.gradient_checkpointing and self.training:
230
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
231
+ layer.__call__,
232
+ hidden_states,
233
+ attention_mask,
234
+ past_key_values,
235
+ use_cache,
236
+ output_attentions,
237
+ **kwargs
238
+ )
239
+ else:
240
+ hidden_states, attentions, past_key_values = layer(
241
+ hidden_states,
242
+ attention_mask=attention_mask,
243
+ past_key_values=past_key_values,
244
+ use_cache=use_cache,
245
+ output_attentions=output_attentions,
246
+ **kwargs
247
+ )
248
+
249
+ if output_attentions:
250
+ all_attns += (attentions,)
251
+
252
+ hidden_states = self.norm(hidden_states)
253
+
254
+ # add hidden states from the last decoder layer
255
+ if output_hidden_states:
256
+ all_hidden_states += (hidden_states,)
257
+
258
+ if not return_dict:
259
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
260
+ return BaseModelOutputWithPast(
261
+ last_hidden_state=hidden_states,
262
+ past_key_values=past_key_values,
263
+ hidden_states=all_hidden_states,
264
+ attentions=all_attns
265
+ )
266
+
267
+
268
+ class DeltaNetForCausalLM(DeltaNetPreTrainedModel, GenerationMixin):
269
+
270
+ _tied_weights_keys = ["lm_head.weight"]
271
+
272
+ def __init__(self, config):
273
+ super().__init__(config)
274
+ self.model = DeltaNetModel(config)
275
+ self.vocab_size = config.vocab_size
276
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
277
+ self.criterion = None
278
+
279
+ # Initialize weights and apply final processing
280
+ self.post_init()
281
+
282
+ def get_input_embeddings(self):
283
+ return self.model.embeddings
284
+
285
+ def set_input_embeddings(self, value):
286
+ self.model.embeddings = value
287
+
288
+ def get_output_embeddings(self):
289
+ return self.lm_head
290
+
291
+ def set_output_embeddings(self, new_embeddings):
292
+ self.lm_head = new_embeddings
293
+
294
+ def set_decoder(self, decoder):
295
+ self.model = decoder
296
+
297
+ def get_decoder(self):
298
+ return self.model
299
+
300
+ def generate(self, *args, **kwargs):
301
+ try:
302
+ return super().generate(*args, **kwargs)
303
+ except AttributeError as exception:
304
+ if 'past_key_values' in str(exception):
305
+ raise AttributeError(
306
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
307
+ f"which is not supported for {self.__class__.__name__}. "
308
+ f"Try another generation strategy instead. "
309
+ f"For the available generation strategies, check this doc: "
310
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
311
+ )
312
+ else:
313
+ raise exception
314
+
315
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
316
+ def prepare_inputs_for_generation(
317
+ self,
318
+ input_ids: torch.LongTensor = None,
319
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
320
+ attention_mask: Optional[torch.Tensor] = None,
321
+ inputs_embeds: Optional[torch.Tensor] = None,
322
+ use_cache: bool = True,
323
+ logits_to_keep: Optional[int] = None,
324
+ **kwargs
325
+ ):
326
+ # only keep the last token of `input_ids` if `past_key_values` is not empty.
327
+ if past_key_values is not None and len(past_key_values) > 0:
328
+ input_ids = input_ids[:, -1:]
329
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
330
+ if inputs_embeds is not None and len(past_key_values) == 0:
331
+ model_inputs = {'inputs_embeds': inputs_embeds}
332
+ else:
333
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
334
+ # recompiles graphs as the stride of the inputs is a guard.
335
+ # Ref: https://github.com/huggingface/transformers/pull/29114
336
+ # TODO: use `next_tokens` directly instead.
337
+ model_inputs = {'input_ids': input_ids.contiguous()}
338
+
339
+ if logits_to_keep is not None:
340
+ model_inputs['logits_to_keep'] = logits_to_keep
341
+
342
+ model_inputs.update({
343
+ 'past_key_values': past_key_values,
344
+ 'use_cache': use_cache,
345
+ 'attention_mask': attention_mask,
346
+ })
347
+ return model_inputs
348
+
349
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
350
+ def forward(
351
+ self,
352
+ input_ids: torch.LongTensor = None,
353
+ attention_mask: Optional[torch.Tensor] = None,
354
+ inputs_embeds: Optional[torch.Tensor] = None,
355
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
356
+ labels: Optional[torch.LongTensor] = None,
357
+ use_cache: Optional[bool] = None,
358
+ output_attentions: Optional[bool] = None,
359
+ output_hidden_states: Optional[bool] = None,
360
+ return_dict: Optional[bool] = None,
361
+ logits_to_keep: Optional[int] = 0,
362
+ **kwargs: Unpack[Dict]
363
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
364
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
365
+ output_hidden_states = (
366
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
367
+ )
368
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
369
+
370
+ outputs = self.model(
371
+ input_ids=input_ids,
372
+ attention_mask=attention_mask,
373
+ inputs_embeds=inputs_embeds,
374
+ past_key_values=past_key_values,
375
+ use_cache=use_cache,
376
+ output_attentions=output_attentions,
377
+ output_hidden_states=output_hidden_states,
378
+ return_dict=return_dict,
379
+ **kwargs
380
+ )
381
+
382
+ hidden_states = outputs[0]
383
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
384
+
385
+ loss, logits = None, None
386
+ if not fuse_linear_and_cross_entropy or labels is None:
387
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
388
+ if labels is not None:
389
+ if getattr(self, 'criterion', None) is None:
390
+ if fuse_linear_and_cross_entropy:
391
+ criterion = FusedLinearCrossEntropyLoss()
392
+ elif self.config.fuse_cross_entropy:
393
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
394
+ else:
395
+ criterion = nn.CrossEntropyLoss()
396
+ else:
397
+ criterion = self.criterion
398
+ labels = labels.to(hidden_states.device)
399
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
400
+ if fuse_linear_and_cross_entropy:
401
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
402
+ else:
403
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
404
+
405
+ if not return_dict:
406
+ output = (logits,) + outputs[1:]
407
+ return (loss,) + output if loss is not None else output
408
+
409
+ return CausalLMOutputWithPast(
410
+ loss=loss,
411
+ logits=logits,
412
+ past_key_values=outputs.past_key_values,
413
+ hidden_states=outputs.hidden_states,
414
+ attentions=outputs.attentions,
415
+ )
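A note on the `logits_to_keep` slicing used in the forward pass above: Python treats `-0` as `0`, so the default `logits_to_keep=0` keeps logits for every position, while any positive value keeps only that many trailing positions. A minimal sketch in plain PyTorch, independent of this repository:

import torch

h = torch.randn(2, 8, 16)        # (batch, seq_len, hidden)
assert h[:, -0:].shape[1] == 8   # -0 == 0, so the full sequence is kept
assert h[:, -1:].shape[1] == 1   # keep only the last position, as in incremental decoding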
fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc ADDED
Binary file (2.5 kB).
 
fla/models/forgetting_transformer/modeling_forgetting_transformer.py ADDED
@@ -0,0 +1,408 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.forgetting_attn import ForgettingAttention
19
+ from fla.models.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig
20
+ from fla.models.utils import Cache
21
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
22
+ from fla.modules import GatedMLP as ForgettingTransformerMLP
23
+ from fla.modules import RMSNorm
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class ForgettingTransformerBlock(nn.Module):
33
+
34
+ def __init__(self, config: ForgettingTransformerConfig, layer_idx: int):
35
+ super().__init__()
36
+
37
+ self.config = config
38
+ self.layer_idx = layer_idx
39
+
40
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
41
+ self.attn = ForgettingAttention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.num_heads,
44
+ num_kv_heads=config.num_kv_heads,
45
+ qkv_bias=config.qkv_bias,
46
+ qk_norm=config.qk_norm,
47
+ window_size=config.window_size,
48
+ use_output_gate=config.use_output_gate,
49
+ layer_idx=layer_idx
50
+ )
51
+
52
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
53
+ self.mlp = ForgettingTransformerMLP(
54
+ hidden_size=config.hidden_size,
55
+ hidden_ratio=config.hidden_ratio,
56
+ intermediate_size=config.intermediate_size,
57
+ hidden_act=config.hidden_act,
58
+ fuse_swiglu=config.fuse_swiglu
59
+ )
60
+
61
+ def forward(
62
+ self,
63
+ hidden_states: torch.Tensor,
64
+ attention_mask: Optional[torch.Tensor] = None,
65
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
66
+ output_attentions: Optional[bool] = False,
67
+ use_cache: Optional[bool] = False,
68
+ **kwargs: Unpack[Any]
69
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
70
+
71
+ residual = hidden_states
72
+ hidden_states = self.attn_norm(hidden_states)
73
+ hidden_states, attentions, past_key_values = self.attn(
74
+ hidden_states=hidden_states,
75
+ attention_mask=attention_mask,
76
+ past_key_values=past_key_values,
77
+ use_cache=use_cache,
78
+ output_attentions=output_attentions,
79
+ **kwargs
80
+ )
81
+ if self.config.fuse_norm:
82
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
83
+ else:
84
+ hidden_states = residual + hidden_states
85
+ residual = hidden_states
86
+ hidden_states = self.mlp_norm(hidden_states)
87
+ hidden_states = self.mlp(hidden_states, **kwargs)
88
+ hidden_states = residual + hidden_states
89
+
90
+ outputs = (hidden_states,)
91
+
92
+ if output_attentions:
93
+ outputs += (attentions,)
94
+
95
+ if use_cache:
96
+ outputs += (past_key_values,)
97
+
98
+ return outputs
99
+
100
+
101
+ class ForgettingTransformerPreTrainedModel(PreTrainedModel):
102
+
103
+ config_class = ForgettingTransformerConfig
104
+ base_model_prefix = 'model'
105
+ supports_gradient_checkpointing = True
106
+ _no_split_modules = ['ForgettingTransformerBlock']
107
+ _supports_cache_class = True
108
+
109
+ def __init__(self, *inputs, **kwargs):
110
+ super().__init__(*inputs, **kwargs)
111
+
112
+ def _init_weights(
113
+ self,
114
+ module: nn.Module,
115
+ rescale_prenorm_residual: bool = False,
116
+ num_residuals_per_layer: int = 2,
117
+ ):
118
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
119
+ # Slightly different from the TF version which uses truncated_normal for initialization
120
+ # cf https://github.com/pytorch/pytorch/pull/5617
121
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
122
+ if module.bias is not None:
123
+ nn.init.zeros_(module.bias)
124
+ elif isinstance(module, nn.Embedding):
125
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
126
+ elif hasattr(module, 'reset_parameters'):
127
+ module.reset_parameters()
128
+
129
+ if rescale_prenorm_residual:
130
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
131
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
132
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
133
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
134
+ #
135
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
136
+ p = None
137
+ if hasattr(module, 'o_proj'):
138
+ p = module.o_proj.weight
139
+ elif hasattr(module, 'down_proj'):
140
+ p = module.down_proj.weight
141
+ if p is not None:
142
+ # Special Scaled Initialization --> There are 2 Layer Norms per ForgettingTransformer Block
143
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
144
+ # We need to reinit p since this code could be called multiple times
145
+ # Having just p *= scale would repeatedly scale it down
146
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
147
+ with torch.no_grad():
148
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
149
+
150
+
151
+ class ForgettingTransformerModel(ForgettingTransformerPreTrainedModel):
152
+
153
+ def __init__(
154
+ self,
155
+ config: ForgettingTransformerConfig
156
+ ) -> ForgettingTransformerModel:
157
+ super().__init__(config)
158
+ self.padding_idx = config.pad_token_id
159
+ self.vocab_size = config.vocab_size
160
+
161
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
162
+ self.layers = nn.ModuleList([
163
+ ForgettingTransformerBlock(config, layer_idx)
164
+ for layer_idx in range(config.num_hidden_layers)
165
+ ])
166
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
167
+
168
+ self.gradient_checkpointing = False
169
+
170
+ self.post_init()
171
+
172
+ def get_input_embeddings(self):
173
+ return self.embeddings
174
+
175
+ def set_input_embeddings(self, value):
176
+ self.embeddings = value
177
+
178
+ def forward(
179
+ self,
180
+ input_ids: Optional[torch.LongTensor] = None,
181
+ attention_mask: Optional[torch.Tensor] = None,
182
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
183
+ inputs_embeds: Optional[torch.FloatTensor] = None,
184
+ use_cache: Optional[bool] = None,
185
+ output_attentions: Optional[bool] = None,
186
+ output_hidden_states: Optional[bool] = None,
187
+ return_dict: Optional[bool] = None,
188
+ **kwargs: Unpack[Any]
189
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
190
+ if output_attentions:
191
+ warnings.warn(
192
+ "`ForgettingTransformerModel` does not support output attention weights now, "
193
+ "so `output_attentions` is set to `False`."
194
+ )
195
+ output_attentions = False
196
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
197
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
198
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
199
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
200
+
201
+ # retrieve input_ids and inputs_embeds
202
+ if input_ids is not None and inputs_embeds is not None:
203
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
204
+ elif input_ids is None and inputs_embeds is None:
205
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
206
+
207
+ if use_cache and not isinstance(past_key_values, Cache):
208
+ past_key_values = Cache.from_legacy_cache(past_key_values)
209
+
210
+ if inputs_embeds is None:
211
+ inputs_embeds = self.embeddings(input_ids)
212
+
213
+ # embed positions
214
+ hidden_states = inputs_embeds
215
+
216
+ if self.gradient_checkpointing and self.training:
217
+ if use_cache:
218
+ logger.warning_once(
219
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
220
+ )
221
+ use_cache = False
222
+
223
+ all_hidden_states = () if output_hidden_states else None
224
+ all_attns = () if output_attentions else None
225
+ next_cache = None
226
+
227
+ for layer in self.layers:
228
+ if output_hidden_states:
229
+ all_hidden_states += (hidden_states,)
230
+
231
+ if self.gradient_checkpointing and self.training:
232
+ layer_outputs = self._gradient_checkpointing_func(
233
+ layer.__call__,
234
+ hidden_states,
235
+ attention_mask,
236
+ past_key_values,
237
+ output_attentions,
238
+ use_cache,
239
+ **kwargs
240
+ )
241
+ else:
242
+ layer_outputs = layer(
243
+ hidden_states,
244
+ attention_mask=attention_mask,
245
+ past_key_values=past_key_values,
246
+ output_attentions=output_attentions,
247
+ use_cache=use_cache,
248
+ **kwargs
249
+ )
250
+
251
+ hidden_states = layer_outputs[0]
252
+
253
+ if use_cache:
254
+ next_cache = layer_outputs[2 if output_attentions else 1]
255
+
256
+ if output_attentions:
257
+ all_attns += (layer_outputs[1],)
258
+
259
+ hidden_states = self.norm(hidden_states)
260
+
261
+ # add hidden states from the last decoder layer
262
+ if output_hidden_states:
263
+ all_hidden_states += (hidden_states,)
264
+
265
+ if not return_dict:
266
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)
267
+
268
+ return BaseModelOutputWithPast(
269
+ last_hidden_state=hidden_states,
270
+ past_key_values=next_cache,
271
+ hidden_states=all_hidden_states,
272
+ attentions=all_attns
273
+ )
274
+
275
+
276
+ class ForgettingTransformerForCausalLM(ForgettingTransformerPreTrainedModel, GenerationMixin):
277
+
278
+ _tied_weights_keys = ["lm_head.weight"]
279
+
280
+ def __init__(self, config):
281
+ super().__init__(config)
282
+ self.model = ForgettingTransformerModel(config)
283
+ self.vocab_size = config.vocab_size
284
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
285
+ self.criterion = None
286
+
287
+ # Initialize weights and apply final processing
288
+ self.post_init()
289
+
290
+ def get_input_embeddings(self):
291
+ return self.model.embeddings
292
+
293
+ def set_input_embeddings(self, value):
294
+ self.model.embeddings = value
295
+
296
+ def get_output_embeddings(self):
297
+ return self.lm_head
298
+
299
+ def set_output_embeddings(self, new_embeddings):
300
+ self.lm_head = new_embeddings
301
+
302
+ def set_decoder(self, decoder):
303
+ self.model = decoder
304
+
305
+ def get_decoder(self):
306
+ return self.model
307
+
308
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
309
+ def prepare_inputs_for_generation(
310
+ self,
311
+ input_ids: torch.LongTensor = None,
312
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
313
+ attention_mask: Optional[torch.Tensor] = None,
314
+ inputs_embeds: Optional[torch.Tensor] = None,
315
+ use_cache: bool = True,
316
+ logits_to_keep: Optional[int] = None,
317
+ **kwargs
318
+ ):
319
+ # only keep the last token of `input_ids` if `past_key_values` is not empty.
320
+ if past_key_values is not None and len(past_key_values) > 0:
321
+ input_ids = input_ids[:, -1:]
322
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
323
+ if inputs_embeds is not None and len(past_key_values) == 0:
324
+ model_inputs = {'inputs_embeds': inputs_embeds}
325
+ else:
326
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
327
+ # recompiles graphs as the stride of the inputs is a guard.
328
+ # Ref: https://github.com/huggingface/transformers/pull/29114
329
+ # TODO: use `next_tokens` directly instead.
330
+ model_inputs = {'input_ids': input_ids.contiguous()}
331
+
332
+ if logits_to_keep is not None:
333
+ model_inputs['logits_to_keep'] = logits_to_keep
334
+
335
+ model_inputs.update({
336
+ 'past_key_values': past_key_values,
337
+ 'use_cache': use_cache,
338
+ 'attention_mask': attention_mask,
339
+ })
340
+ return model_inputs
341
+
342
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
343
+ def forward(
344
+ self,
345
+ input_ids: torch.LongTensor = None,
346
+ attention_mask: Optional[torch.Tensor] = None,
347
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
348
+ inputs_embeds: Optional[torch.FloatTensor] = None,
349
+ labels: Optional[torch.LongTensor] = None,
350
+ use_cache: Optional[bool] = None,
351
+ output_attentions: Optional[bool] = None,
352
+ output_hidden_states: Optional[bool] = None,
353
+ return_dict: Optional[bool] = None,
354
+ logits_to_keep: Optional[int] = 0,
355
+ **kwargs: Unpack[Any]
356
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
357
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
358
+ output_hidden_states = (
359
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
360
+ )
361
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
362
+
363
+ outputs = self.model(
364
+ input_ids=input_ids,
365
+ attention_mask=attention_mask,
366
+ past_key_values=past_key_values,
367
+ inputs_embeds=inputs_embeds,
368
+ use_cache=use_cache,
369
+ output_attentions=output_attentions,
370
+ output_hidden_states=output_hidden_states,
371
+ return_dict=return_dict,
372
+ **kwargs
373
+ )
374
+
375
+ hidden_states = outputs[0]
376
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
377
+ logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
378
+
379
+ loss = None
380
+ if labels is not None:
381
+ if getattr(self, 'criterion', None) is None:
382
+ if fuse_linear_and_cross_entropy:
383
+ criterion = FusedLinearCrossEntropyLoss()
384
+ elif self.config.fuse_cross_entropy:
385
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
386
+ else:
387
+ criterion = nn.CrossEntropyLoss()
388
+ else:
389
+ criterion = self.criterion
390
+ # Enable model parallelism
391
+ labels = labels.to(hidden_states.device)
392
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
393
+ if fuse_linear_and_cross_entropy:
394
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
395
+ else:
396
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
397
+
398
+ if not return_dict:
399
+ output = (logits,) + outputs[1:]
400
+ return (loss,) + output if loss is not None else output
401
+
402
+ return CausalLMOutputWithPast(
403
+ loss=loss,
404
+ logits=logits,
405
+ past_key_values=outputs.past_key_values,
406
+ hidden_states=outputs.hidden_states,
407
+ attentions=outputs.attentions,
408
+ )
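The loss path above shifts the labels left by one and pads the final position with the criterion's `ignore_index` instead of slicing the logits, so position t is trained to predict token t+1. A minimal sketch of that transformation (the value -100 is the default `ignore_index` of `nn.CrossEntropyLoss`, used here only for illustration):

import torch

labels = torch.tensor([[10, 11, 12, 13]])
ignore_index = -100
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), 1)
# shifted -> tensor([[  11,   12,   13, -100]]); the padded last position is excluded from the loss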
fla/models/gated_deltanet/configuration_gated_deltanet.py ADDED
@@ -0,0 +1,83 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Dict, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class GatedDeltaNetConfig(PretrainedConfig):
+     model_type = 'gated_deltanet'
+     keys_to_ignore_at_inference = ['past_key_values']
+
+     def __init__(
+         self,
+         attn_mode: str = "chunk",
+         hidden_size: int = 2048,
+         expand_v: int = 2,
+         use_gate: bool = True,
+         use_short_conv: bool = True,
+         conv_size: int = 4,
+         head_dim: int = 256,
+         num_heads: int = 6,
+         max_position_embeddings: int = 2048,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         hidden_act: str = "swish",
+         num_hidden_layers: int = 21,
+         norm_eps: float = 1e-6,
+         attn: Optional[Dict] = None,
+         use_cache: bool = True,
+         pad_token_id: int = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         initializer_range: float = 0.006,
+         fuse_norm: bool = True,
+         fuse_swiglu: bool = True,
+         fuse_cross_entropy: bool = True,
+         vocab_size: int = 32000,
+         **kwargs
+     ):
+         self.attn_mode = attn_mode
+         self.hidden_size = hidden_size
+         self.expand_v = expand_v
+         self.use_gate = use_gate
+         self.use_short_conv = use_short_conv
+         self.conv_size = conv_size
+         self.head_dim = head_dim
+         self.num_heads = num_heads
+         self.max_position_embeddings = max_position_embeddings
+
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.num_hidden_layers = num_hidden_layers
+         self.norm_eps = norm_eps
+         self.attn = attn
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+
+         self.fuse_norm = fuse_norm
+         self.fuse_swiglu = fuse_swiglu
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.vocab_size = vocab_size
+
+         if attn is not None:
+             if not isinstance(attn, Dict):
+                 raise ValueError("attn must be a dictionary")
+             if 'layers' not in attn:
+                 raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+             if 'num_heads' not in attn:
+                 raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+             attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+             attn['qkv_bias'] = attn.get('qkv_bias', False)
+             attn['window_size'] = attn.get('window_size', None)
+             attn['rope_theta'] = attn.get('rope_theta', 10000.)
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
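For hybrid models, the `attn` dictionary must provide the indices of the layers that use standard attention and their number of heads; the remaining keys fall back to the defaults filled in by `__init__` above. A minimal sketch (the sizes are illustrative, not the defaults of any released checkpoint):

from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig

config = GatedDeltaNetConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    # layers 3 and 7 become softmax-attention layers; 'layers' and 'num_heads' are required,
    # while num_kv_heads, qkv_bias, window_size and rope_theta default as shown above
    attn={'layers': [3, 7], 'num_heads': 8},
)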
fla/models/gated_deltanet/modeling_gated_deltanet.py ADDED
@@ -0,0 +1,412 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gated_deltanet import GatedDeltaNet
20
+ from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GatedDeltaNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class GatedDeltaNetBlock(nn.Module):
34
+ def __init__(self, config: GatedDeltaNetConfig, layer_idx: int):
35
+ super().__init__()
36
+
37
+ self.config = config
38
+ self.layer_idx = layer_idx
39
+
40
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
41
+ if config.attn is not None and layer_idx in config.attn['layers']:
42
+ self.attn = Attention(
43
+ hidden_size=config.hidden_size,
44
+ num_heads=config.attn['num_heads'],
45
+ num_kv_heads=config.attn['num_kv_heads'],
46
+ qkv_bias=config.attn['qkv_bias'],
47
+ window_size=config.attn['window_size'],
48
+ rope_theta=config.attn['rope_theta'],
49
+ max_position_embeddings=config.max_position_embeddings,
50
+ layer_idx=layer_idx
51
+ )
52
+ else:
53
+ self.attn = GatedDeltaNet(
54
+ mode=config.attn_mode,
55
+ hidden_size=config.hidden_size,
56
+ expand_v=config.expand_v,
57
+ head_dim=config.head_dim,
58
+ num_heads=config.num_heads,
59
+ use_gate=config.use_gate,
60
+ use_short_conv=config.use_short_conv,
61
+ conv_size=config.conv_size,
62
+ norm_eps=config.norm_eps,
63
+ layer_idx=layer_idx
64
+ )
65
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
66
+ self.mlp = GatedDeltaNetMLP(
67
+ hidden_size=config.hidden_size,
68
+ hidden_ratio=config.hidden_ratio,
69
+ intermediate_size=config.intermediate_size,
70
+ hidden_act=config.hidden_act,
71
+ fuse_swiglu=config.fuse_swiglu
72
+ )
73
+
74
+ def forward(
75
+ self,
76
+ hidden_states: torch.Tensor,
77
+ attention_mask: Optional[torch.Tensor] = None,
78
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
79
+ use_cache: Optional[bool] = False,
80
+ output_attentions: Optional[bool] = False,
81
+ **kwargs: Unpack[Dict]
82
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
83
+ residual = hidden_states
84
+ hidden_states = self.attn_norm(hidden_states)
85
+ hidden_states, attentions, past_key_values = self.attn(
86
+ hidden_states=hidden_states,
87
+ attention_mask=attention_mask,
88
+ past_key_values=past_key_values,
89
+ use_cache=use_cache,
90
+ output_attentions=output_attentions,
91
+ **kwargs
92
+ )
93
+ if self.config.fuse_norm:
94
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
95
+ else:
96
+ hidden_states = residual + hidden_states
97
+ residual = hidden_states
98
+ hidden_states = self.mlp_norm(hidden_states)
99
+ hidden_states = self.mlp(hidden_states, **kwargs)
100
+ hidden_states = residual + hidden_states
101
+
102
+ outputs = (hidden_states, attentions, past_key_values)
103
+
104
+ return outputs
105
+
106
+
107
+ class GatedDeltaNetPreTrainedModel(PreTrainedModel):
108
+
109
+ config_class = GatedDeltaNetConfig
110
+ base_model_prefix = 'model'
111
+ supports_gradient_checkpointing = True
112
+ _no_split_modules = ['GatedDeltaNetBlock']
113
+ _supports_cache_class = True
114
+
115
+ def __init__(self, *inputs, **kwargs):
116
+ super().__init__(*inputs, **kwargs)
117
+
118
+ def _init_weights(
119
+ self,
120
+ module: nn.Module,
121
+ prenorm_residual_strategy: Optional[str] = 'rescale',
122
+ num_residuals_per_layer: int = 2,
123
+ ):
124
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
125
+ # Slightly different from the TF version which uses truncated_normal for initialization
126
+ # cf https://github.com/pytorch/pytorch/pull/5617
127
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
128
+ if module.bias is not None:
129
+ nn.init.zeros_(module.bias)
130
+ elif isinstance(module, nn.Embedding):
131
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
132
+ elif hasattr(module, 'reset_parameters'):
133
+ module.reset_parameters()
134
+
135
+ if prenorm_residual_strategy is not None:
136
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
137
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
138
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
139
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
140
+ #
141
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
142
+ p = None
143
+ if hasattr(module, 'o_proj'):
144
+ p = module.o_proj.weight
145
+ elif hasattr(module, 'down_proj'):
146
+ p = module.down_proj.weight
147
+ if p is not None:
148
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
149
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
150
+ # We need to reinit p since this code could be called multiple times
151
+ # Having just p *= scale would repeatedly scale it down
152
+ if prenorm_residual_strategy == 'rescale':
153
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
154
+ with torch.no_grad():
155
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
156
+ elif prenorm_residual_strategy == 'zero':
157
+ nn.init.zeros_(p)
158
+ else:
159
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
160
+
161
+
162
+ class GatedDeltaNetModel(GatedDeltaNetPreTrainedModel):
163
+
164
+ def __init__(self, config: GatedDeltaNetConfig):
165
+ super().__init__(config)
166
+ self.padding_idx = config.pad_token_id
167
+ self.vocab_size = config.vocab_size
168
+
169
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
170
+ self.layers = nn.ModuleList([GatedDeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
171
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
172
+
173
+ self.gradient_checkpointing = False
174
+
175
+ self.post_init()
176
+
177
+ def get_input_embeddings(self):
178
+ return self.embeddings
179
+
180
+ def set_input_embeddings(self, value):
181
+ self.embeddings = value
182
+
183
+ def forward(
184
+ self,
185
+ input_ids: Optional[torch.LongTensor] = None,
186
+ attention_mask: Optional[torch.Tensor] = None, # noqa
187
+ inputs_embeds: Optional[torch.FloatTensor] = None,
188
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
189
+ use_cache: Optional[bool] = None,
190
+ output_attentions: Optional[bool] = None,
191
+ output_hidden_states: Optional[bool] = None,
192
+ return_dict: Optional[bool] = None,
193
+ **kwargs: Unpack[Dict]
194
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
195
+ if output_attentions:
196
+ warnings.warn("`GatedDeltaNetModel` does not support `output_attentions` now, setting it to `False`.")
197
+ output_attentions = False
198
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
199
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
200
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
201
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
202
+
203
+ # retrieve input_ids and inputs_embeds
204
+ if input_ids is not None and inputs_embeds is not None:
205
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
206
+ if input_ids is None and inputs_embeds is None:
207
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
208
+
209
+ if inputs_embeds is None:
210
+ inputs_embeds = self.embeddings(input_ids)
211
+ hidden_states = inputs_embeds
212
+
213
+ if use_cache and not isinstance(past_key_values, Cache):
214
+ past_key_values = Cache.from_legacy_cache(past_key_values)
215
+
216
+ if self.gradient_checkpointing and self.training and use_cache:
217
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
218
+ use_cache = False
219
+
220
+ all_hidden_states = () if output_hidden_states else None
221
+ all_attns = () if output_attentions else None
222
+ for layer in self.layers:
223
+ if output_hidden_states:
224
+ all_hidden_states += (hidden_states,)
225
+
226
+ if self.gradient_checkpointing and self.training:
227
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
228
+ layer.__call__,
229
+ hidden_states,
230
+ attention_mask,
231
+ past_key_values,
232
+ use_cache,
233
+ output_attentions,
234
+ **kwargs
235
+ )
236
+ else:
237
+ hidden_states, attentions, past_key_values = layer(
238
+ hidden_states,
239
+ attention_mask=attention_mask,
240
+ past_key_values=past_key_values,
241
+ use_cache=use_cache,
242
+ output_attentions=output_attentions,
243
+ **kwargs
244
+ )
245
+
246
+ if output_attentions:
247
+ all_attns += (attentions,)
248
+
249
+ hidden_states = self.norm(hidden_states)
250
+
251
+ # add hidden states from the last decoder layer
252
+ if output_hidden_states:
253
+ all_hidden_states += (hidden_states,)
254
+
255
+ if not return_dict:
256
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
257
+ return BaseModelOutputWithPast(
258
+ last_hidden_state=hidden_states,
259
+ past_key_values=past_key_values,
260
+ hidden_states=all_hidden_states,
261
+ attentions=all_attns
262
+ )
263
+
264
+
265
+ class GatedDeltaNetForCausalLM(GatedDeltaNetPreTrainedModel, GenerationMixin):
266
+
267
+ _tied_weights_keys = ["lm_head.weight"]
268
+
269
+ def __init__(self, config):
270
+ super().__init__(config)
271
+ self.model = GatedDeltaNetModel(config)
272
+ self.vocab_size = config.vocab_size
273
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
274
+ self.criterion = None
275
+
276
+ # Initialize weights and apply final processing
277
+ self.post_init()
278
+
279
+ def get_input_embeddings(self):
280
+ return self.model.embeddings
281
+
282
+ def set_input_embeddings(self, value):
283
+ self.model.embeddings = value
284
+
285
+ def get_output_embeddings(self):
286
+ return self.lm_head
287
+
288
+ def set_output_embeddings(self, new_embeddings):
289
+ self.lm_head = new_embeddings
290
+
291
+ def set_decoder(self, decoder):
292
+ self.model = decoder
293
+
294
+ def get_decoder(self):
295
+ return self.model
296
+
297
+ def generate(self, *args, **kwargs):
298
+ try:
299
+ return super().generate(*args, **kwargs)
300
+ except AttributeError as exception:
301
+ if 'past_key_values' in str(exception):
302
+ raise AttributeError(
303
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
304
+ f"which is not supported for {self.__class__.__name__}. "
305
+ f"Try another generation strategy instead. "
306
+ f"For the available generation strategies, check this doc: "
307
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
308
+ )
309
+ else:
310
+ raise exception
311
+
312
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
313
+ def prepare_inputs_for_generation(
314
+ self,
315
+ input_ids: torch.LongTensor = None,
316
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
317
+ attention_mask: Optional[torch.Tensor] = None,
318
+ inputs_embeds: Optional[torch.Tensor] = None,
319
+ use_cache: bool = True,
320
+ logits_to_keep: Optional[int] = None,
321
+ **kwargs
322
+ ):
323
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
324
+ if past_key_values is not None and len(past_key_values) > 0:
325
+ input_ids = input_ids[:, -1:]
326
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
327
+ if inputs_embeds is not None and len(past_key_values) == 0:
328
+ model_inputs = {'inputs_embeds': inputs_embeds}
329
+ else:
330
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
331
+ # recompiles graphs as the stride of the inputs is a guard.
332
+ # Ref: https://github.com/huggingface/transformers/pull/29114
333
+ # TODO: use `next_tokens` directly instead.
334
+ model_inputs = {'input_ids': input_ids.contiguous()}
335
+
336
+ if logits_to_keep is not None:
337
+ model_inputs['logits_to_keep'] = logits_to_keep
338
+
339
+ model_inputs.update({
340
+ 'past_key_values': past_key_values,
341
+ 'use_cache': use_cache,
342
+ 'attention_mask': attention_mask,
343
+ })
344
+ return model_inputs
345
+
346
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
347
+ def forward(
348
+ self,
349
+ input_ids: torch.LongTensor = None,
350
+ attention_mask: Optional[torch.Tensor] = None,
351
+ inputs_embeds: Optional[torch.Tensor] = None,
352
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
353
+ labels: Optional[torch.LongTensor] = None,
354
+ use_cache: Optional[bool] = None,
355
+ output_attentions: Optional[bool] = None,
356
+ output_hidden_states: Optional[bool] = None,
357
+ return_dict: Optional[bool] = None,
358
+ logits_to_keep: Optional[int] = 0,
359
+ **kwargs: Unpack[Dict]
360
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
361
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
362
+ output_hidden_states = (
363
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
364
+ )
365
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
366
+
367
+ outputs = self.model(
368
+ input_ids=input_ids,
369
+ attention_mask=attention_mask,
370
+ inputs_embeds=inputs_embeds,
371
+ past_key_values=past_key_values,
372
+ use_cache=use_cache,
373
+ output_attentions=output_attentions,
374
+ output_hidden_states=output_hidden_states,
375
+ return_dict=return_dict,
376
+ **kwargs
377
+ )
378
+
379
+ hidden_states = outputs[0]
380
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
381
+
382
+ loss, logits = None, None
383
+ if not fuse_linear_and_cross_entropy or labels is None:
384
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
385
+ if labels is not None:
386
+ if getattr(self, 'criterion', None) is None:
387
+ if fuse_linear_and_cross_entropy:
388
+ criterion = FusedLinearCrossEntropyLoss()
389
+ elif self.config.fuse_cross_entropy:
390
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
391
+ else:
392
+ criterion = nn.CrossEntropyLoss()
393
+ else:
394
+ criterion = self.criterion
395
+ labels = labels.to(hidden_states.device)
396
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
397
+ if fuse_linear_and_cross_entropy:
398
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
399
+ else:
400
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
401
+
402
+ if not return_dict:
403
+ output = (logits,) + outputs[1:]
404
+ return (loss,) + output if loss is not None else output
405
+
406
+ return CausalLMOutputWithPast(
407
+ loss=loss,
408
+ logits=logits,
409
+ past_key_values=outputs.past_key_values,
410
+ hidden_states=outputs.hidden_states,
411
+ attentions=outputs.attentions,
412
+ )
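Putting the configuration and modeling files together, a small model can be instantiated directly from the config. A minimal sketch with tiny illustrative sizes; note that the fla layers rely on Triton kernels, so the forward pass generally expects a CUDA device:

import torch
from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
from fla.models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForCausalLM

config = GatedDeltaNetConfig(hidden_size=256, num_hidden_layers=2, num_heads=2, head_dim=64, vocab_size=1000)
model = GatedDeltaNetForCausalLM(config)
input_ids = torch.randint(0, config.vocab_size, (1, 16))
outputs = model(input_ids=input_ids, labels=input_ids)  # loss follows the shifted-label scheme noted earlier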
fla/models/gated_deltaproduct/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+ from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
+ from fla.models.gated_deltaproduct.modeling_gated_deltaproduct import GatedDeltaProductForCausalLM, GatedDeltaProductModel
+
+ AutoConfig.register(GatedDeltaProductConfig.model_type, GatedDeltaProductConfig)
+ AutoModel.register(GatedDeltaProductConfig, GatedDeltaProductModel)
+ AutoModelForCausalLM.register(GatedDeltaProductConfig, GatedDeltaProductForCausalLM)
+
+ __all__ = [
+     "GatedDeltaProductConfig",
+     "GatedDeltaProductForCausalLM",
+     "GatedDeltaProductModel",
+ ]
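Because these registrations run at import time, the model type becomes resolvable through the transformers Auto classes. A minimal sketch, assuming the `fla` package is importable; passing `use_beta_conv=False` is an assumption, since `GatedDeltaProductBlock` reads `config.use_beta_conv` but the configuration below does not declare it:

from transformers import AutoConfig, AutoModelForCausalLM

import fla.models.gated_deltaproduct  # noqa: F401  (importing executes the register() calls above)

config = AutoConfig.for_model("gated_deltaproduct", use_beta_conv=False)  # extra kwargs are stored on the config
model = AutoModelForCausalLM.from_config(config)  # dispatches to GatedDeltaProductForCausalLM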
fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py ADDED
@@ -0,0 +1,90 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Dict, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class GatedDeltaProductConfig(PretrainedConfig):
+     model_type = "gated_deltaproduct"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         attn_mode: str = "chunk",
+         hidden_size: int = 2048,
+         expand_v: int = 2,
+         use_gate: bool = True,
+         use_short_conv: bool = True,
+         conv_size: int = 4,
+         head_dim: int = 256,
+         num_heads: int = 6,
+         max_position_embeddings: int = 2048,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         hidden_act: str = "swish",
+         num_hidden_layers: int = 21,
+         norm_first: bool = False,
+         norm_eps: float = 1e-6,
+         attn: Optional[Dict] = None,
+         use_cache: bool = True,
+         pad_token_id: int | None = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         initializer_range: float = 0.006,
+         fuse_cross_entropy: bool = True,
+         vocab_size: int = 32000,
+         use_forget_gate: bool = False,  # when true Gated DeltaProduct, when false DeltaProduct
+         allow_neg_eigval: bool = False,  # when true (Gated) DeltaProduct [-1, 1], when false (Gated) DeltaProduct [0, 1]
+         num_householder: int = 1,
+         **kwargs,
+     ):
+         self.attn_mode = attn_mode
+         self.hidden_size = hidden_size
+         self.expand_v = expand_v
+         self.use_gate = use_gate
+         self.use_short_conv = use_short_conv
+         self.conv_size = conv_size
+         self.head_dim = head_dim
+         self.num_heads = num_heads
+         self.max_position_embeddings = max_position_embeddings
+
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.num_hidden_layers = num_hidden_layers
+         self.norm_first = norm_first
+         self.norm_eps = norm_eps
+         self.attn = attn
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.vocab_size = vocab_size
+
+         # DeltaProduct specific
+         self.allow_neg_eigval = allow_neg_eigval
+         self.num_householder = num_householder
+         self.use_forget_gate = use_forget_gate
+
+         if attn is not None:
+             if not isinstance(attn, Dict):
+                 raise ValueError("attn must be a dictionary")
+             if "layers" not in attn:
+                 raise ValueError(
+                     "Layer indices must be provided to initialize hybrid attention layers"
+                 )
+             if "num_heads" not in attn:
+                 raise ValueError(
+                     "Number of heads must be provided to initialize hybrid attention layers"
+                 )
+             attn["num_kv_heads"] = attn.get("num_kv_heads", attn["num_heads"])
+             attn["window_size"] = attn.get("window_size", None)
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
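The three DeltaProduct-specific options select the variant described by the inline comments above. A minimal sketch of the two ends of the spectrum (values are illustrative):

from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig

# DeltaProduct: no forget gate, eigenvalues restricted to [0, 1], two Householder steps per token
delta_product = GatedDeltaProductConfig(use_forget_gate=False, allow_neg_eigval=False, num_householder=2)

# Gated DeltaProduct with the extended [-1, 1] eigenvalue range
gated_delta_product = GatedDeltaProductConfig(use_forget_gate=True, allow_neg_eigval=True, num_householder=2)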
fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py ADDED
@@ -0,0 +1,520 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.activations import ACT2FN
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.utils import logging
17
+ from transformers.utils.deprecation import deprecate_kwarg
18
+
19
+ from fla.layers.attn import Attention
20
+ from fla.layers.gated_deltaproduct import GatedDeltaProduct
21
+ from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
22
+ from fla.models.utils import Cache
23
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
24
+ from fla.modules.activations import swiglu_linear
25
+ from fla.modules.layernorm import rms_norm_linear
26
+
27
+ if TYPE_CHECKING:
28
+ from transformers.processing_utils import Unpack
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class GatedDeltaNetMLP(nn.Module):
34
+ def __init__(
35
+ self,
36
+ hidden_size: int,
37
+ hidden_ratio: Optional[int] = None,
38
+ intermediate_size: Optional[int] = None,
39
+ hidden_act: str = "swish",
40
+ norm_first: bool = True,
41
+ norm_eps: float = 1e-5,
42
+ ) -> GatedDeltaNetMLP:
43
+ super().__init__()
44
+
45
+ self.hidden_size = hidden_size
46
+ # the final number of params is `hidden_ratio * hidden_size^2`
47
+ # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
48
+ if hidden_ratio is None:
49
+ hidden_ratio = 4
50
+ if intermediate_size is None:
51
+ intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
52
+ intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
53
+ self.hidden_ratio = hidden_ratio
54
+ self.intermediate_size = intermediate_size
55
+ self.norm_first = norm_first
56
+
57
+ if norm_first:
58
+ self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)
59
+
60
+ self.gate_proj = nn.Linear(
61
+ self.hidden_size, self.intermediate_size * 2, bias=False
62
+ )
63
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
64
+ self.act_fn = ACT2FN[hidden_act]
65
+
66
+ def forward(
67
+ self,
68
+ x: torch.Tensor,
69
+ **kwargs: Unpack[Dict],
70
+ ) -> torch.Tensor:
71
+ if self.norm_first:
72
+ x = rms_norm_linear(
73
+ x,
74
+ self.norm.weight,
75
+ self.norm.bias,
76
+ self.gate_proj.weight,
77
+ self.gate_proj.bias,
78
+ )
79
+ else:
80
+ x = self.gate_proj(x)
81
+ gate, y = x.chunk(2, -1)
82
+ return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
83
+
84
+
85
+ class GatedDeltaProductBlock(nn.Module):
86
+ def __init__(self, config: GatedDeltaProductConfig, layer_idx: int):
87
+ super().__init__()
88
+ self.hidden_size = config.hidden_size
89
+
90
+ if not config.norm_first:
91
+ self.attn_norm = RMSNorm(
92
+ hidden_size=config.hidden_size, eps=config.norm_eps
93
+ )
94
+ if config.attn is not None and layer_idx in config.attn["layers"]:
95
+ self.attn = Attention(
96
+ hidden_size=config.hidden_size,
97
+ num_heads=config.attn["num_heads"],
98
+ num_kv_heads=config.attn["num_kv_heads"],
99
+ window_size=config.attn["window_size"],
100
+ max_position_embeddings=config.max_position_embeddings,
101
+ layer_idx=layer_idx,
102
+ )
103
+ else:
104
+ self.attn = GatedDeltaProduct(
105
+ mode=config.attn_mode,
106
+ hidden_size=config.hidden_size,
107
+ expand_v=config.expand_v,
108
+ head_dim=config.head_dim,
109
+ num_heads=config.num_heads,
110
+ use_gate=config.use_gate,
111
+ use_forget_gate=config.use_forget_gate,
112
+ use_short_conv=config.use_short_conv,
113
+ conv_size=config.conv_size,
114
+ norm_first=config.norm_first,
115
+ norm_eps=config.norm_eps,
116
+ allow_neg_eigval=config.allow_neg_eigval,
117
+ num_householder=config.num_householder,
118
+ layer_idx=layer_idx,
119
+ use_beta_conv=config.use_beta_conv
120
+ )
121
+ if not config.norm_first:
122
+ self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
123
+ self.mlp = GatedDeltaNetMLP(
124
+ hidden_size=config.hidden_size,
125
+ hidden_ratio=config.hidden_ratio,
126
+ intermediate_size=config.intermediate_size,
127
+ hidden_act=config.hidden_act,
128
+ norm_first=config.norm_first,
129
+ norm_eps=config.norm_eps,
130
+ )
131
+
132
+ def forward(
133
+ self,
134
+ hidden_states: torch.Tensor,
135
+ attention_mask: Optional[torch.Tensor] = None,
136
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
137
+ use_cache: Optional[bool] = False,
138
+ output_attentions: Optional[bool] = False,
139
+ **kwargs: Unpack[Dict],
140
+ ) -> Tuple[
141
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
142
+ ]:
143
+ residual = hidden_states
144
+ if hasattr(self, "attn_norm"):
145
+ hidden_states = self.attn_norm(hidden_states)
146
+ hidden_states, attentions, past_key_values = self.attn(
147
+ hidden_states=hidden_states,
148
+ attention_mask=attention_mask,
149
+ past_key_values=past_key_values,
150
+ use_cache=use_cache,
151
+ output_attentions=output_attentions,
152
+ **kwargs,
153
+ )
154
+ if hasattr(self, "mlp_norm"):
155
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
156
+ else:
157
+ hidden_states = residual + hidden_states
158
+ residual = hidden_states
159
+ hidden_states = self.mlp(hidden_states, **kwargs)
160
+ hidden_states = residual + hidden_states
161
+
162
+ outputs = (hidden_states, attentions, past_key_values)
163
+
164
+ return outputs
165
+
166
+
167
+ class GatedDeltaProductPreTrainedModel(PreTrainedModel):
168
+ config_class = GatedDeltaProductConfig
169
+ supports_gradient_checkpointing = True
170
+ _no_split_modules = ["GatedDeltaProductBlock"]
171
+
172
+ def __init__(self, *inputs, **kwargs):
173
+ super().__init__(*inputs, **kwargs)
174
+
175
+ def _init_weights(
176
+ self,
177
+ module: nn.Module,
178
+ rescale_prenorm_residual: bool = True,
179
+ num_residuals_per_layer: int = 2,
180
+ ):
181
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
182
+ # Slightly different from the TF version which uses truncated_normal for initialization
183
+ # cf https://github.com/pytorch/pytorch/pull/5617
184
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
185
+ if module.bias is not None:
186
+ nn.init.zeros_(module.bias)
187
+ elif isinstance(module, nn.Embedding):
188
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
189
+ if module.padding_idx is not None:
190
+ module.weight.data[module.padding_idx].zero_()
191
+
192
+ if rescale_prenorm_residual:
193
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
194
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
195
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
196
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
197
+ #
198
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
199
+ for name, p in module.named_parameters():
200
+ if name in ["o_proj.weight", "down_proj.weight"]:
201
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
202
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
203
+ # We need to reinit p since this code could be called multiple times
204
+ # Having just p *= scale would repeatedly scale it down
205
+ with torch.no_grad():
206
+ p /= math.sqrt(
207
+ num_residuals_per_layer * self.config.num_hidden_layers
208
+ )
209
+
210
+
211
+ class GatedDeltaProductModel(GatedDeltaProductPreTrainedModel):
212
+ def __init__(self, config: GatedDeltaProductConfig):
213
+ super().__init__(config)
214
+ self.padding_idx = config.pad_token_id
215
+ self.vocab_size = config.vocab_size
216
+
217
+ self.embeddings = nn.Embedding(
218
+ config.vocab_size, config.hidden_size, self.padding_idx
219
+ )
220
+ self.layers = nn.ModuleList(
221
+ [
222
+ GatedDeltaProductBlock(config, layer_idx)
223
+ for layer_idx in range(config.num_hidden_layers)
224
+ ]
225
+ )
226
+ self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
227
+
228
+ self.gradient_checkpointing = False
229
+
230
+ self.post_init()
231
+
232
+ def get_input_embeddings(self):
233
+ return self.embeddings
234
+
235
+ def set_input_embeddings(self, value):
236
+ self.embeddings = value
237
+
238
+ def forward(
239
+ self,
240
+ input_ids: Optional[torch.LongTensor] = None,
241
+ attention_mask: Optional[torch.Tensor] = None,
242
+ inputs_embeds: Optional[torch.FloatTensor] = None,
243
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
244
+ use_cache: Optional[bool] = None,
245
+ output_attentions: Optional[bool] = None,
246
+ output_hidden_states: Optional[bool] = None,
247
+ return_dict: Optional[bool] = None,
248
+ **kwargs: Unpack[Dict],
249
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
250
+ if output_attentions:
251
+ warnings.warn(
+ "`GatedDeltaProductModel` does not support `output_attentions` for now, setting it to `False`.",
253
+ stacklevel=2,
254
+ )
255
+ output_attentions = False
256
+ output_attentions = (
257
+ output_attentions
258
+ if output_attentions is not None
259
+ else self.config.output_attentions
260
+ )
261
+ output_hidden_states = (
262
+ output_hidden_states
263
+ if output_hidden_states is not None
264
+ else self.config.output_hidden_states
265
+ )
266
+ use_cache = (
267
+ use_cache
268
+ if use_cache is not None
269
+ else (self.config.use_cache if not self.training else False)
270
+ )
271
+ return_dict = (
272
+ return_dict if return_dict is not None else self.config.use_return_dict
273
+ )
274
+
275
+ # retrieve input_ids and inputs_embeds
276
+ if input_ids is not None and inputs_embeds is not None:
277
+ raise ValueError(
278
+ "You cannot specify both input_ids and inputs_embeds at the same time"
279
+ )
280
+ if input_ids is None and inputs_embeds is None:
281
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
282
+
283
+ if inputs_embeds is None:
284
+ inputs_embeds = self.embeddings(input_ids)
285
+ hidden_states = inputs_embeds
286
+
287
+ if use_cache and not isinstance(past_key_values, Cache):
288
+ past_key_values = Cache.from_legacy_cache(past_key_values)
289
+
290
+ if self.gradient_checkpointing and self.training and use_cache:
291
+ logger.warning_once(
292
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
293
+ )
294
+ use_cache = False
295
+
296
+ all_hidden_states = () if output_hidden_states else None
297
+ all_attns = () if output_attentions else None
298
+ for layer in self.layers:
299
+ if output_hidden_states:
300
+ all_hidden_states += (hidden_states,)
301
+
302
+ if self.gradient_checkpointing and self.training:
303
+ hidden_states, attentions, past_key_values = (
304
+ self._gradient_checkpointing_func(
305
+ layer.__call__,
306
+ hidden_states,
307
+ attention_mask,
308
+ past_key_values,
309
+ use_cache,
310
+ output_attentions,
311
+ **kwargs,
312
+ )
313
+ )
314
+ else:
315
+ hidden_states, attentions, past_key_values = layer(
316
+ hidden_states,
317
+ attention_mask=attention_mask,
318
+ past_key_values=past_key_values,
319
+ use_cache=use_cache,
320
+ output_attentions=output_attentions,
321
+ **kwargs,
322
+ )
323
+
324
+ if output_attentions:
325
+ all_attns += (attentions,)
326
+
327
+ hidden_states = self.norm(hidden_states)
328
+ # add hidden states from the last decoder layer
329
+ if output_hidden_states:
330
+ all_hidden_states += (hidden_states,)
331
+
332
+ if not return_dict:
333
+ return tuple(
334
+ i
335
+ for i in [
336
+ hidden_states,
337
+ past_key_values,
338
+ all_hidden_states,
339
+ all_attns,
340
+ ]
341
+ if i is not None
342
+ )
343
+ return BaseModelOutputWithPast(
344
+ last_hidden_state=hidden_states,
345
+ past_key_values=past_key_values,
346
+ hidden_states=all_hidden_states,
347
+ attentions=all_attns,
348
+ )
349
+
350
+
351
+ class GatedDeltaProductForCausalLM(GatedDeltaProductPreTrainedModel, GenerationMixin):
352
+ _tied_weights_keys = ["lm_head.weight"]
353
+
354
+ def __init__(self, config):
355
+ super().__init__(config)
356
+ self.model = GatedDeltaProductModel(config)
357
+ self.vocab_size = config.vocab_size
358
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
359
+
360
+ # Initialize weights and apply final processing
361
+ self.post_init()
362
+
363
+ def get_input_embeddings(self):
364
+ return self.model.embeddings
365
+
366
+ def set_input_embeddings(self, value):
367
+ self.model.embeddings = value
368
+
369
+ def get_output_embeddings(self):
370
+ return self.lm_head
371
+
372
+ def set_output_embeddings(self, new_embeddings):
373
+ self.lm_head = new_embeddings
374
+
375
+ def set_decoder(self, decoder):
376
+ self.model = decoder
377
+
378
+ def get_decoder(self):
379
+ return self.model
380
+
381
+ def generate(self, *args, **kwargs):
382
+ try:
383
+ return super().generate(*args, **kwargs)
384
+ except AttributeError as exception:
385
+ if "past_key_values" in str(exception):
386
+ raise AttributeError(
387
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
388
+ f"which is not supported for {self.__class__.__name__}. "
389
+ f"Try another generation strategy instead. "
390
+ f"For the available generation strategies, check this doc: "
391
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
392
+ )
393
+ else:
394
+ raise exception
395
+
396
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
397
+ def prepare_inputs_for_generation(
398
+ self,
399
+ input_ids: torch.LongTensor = None,
400
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
401
+ attention_mask: Optional[torch.Tensor] = None,
402
+ inputs_embeds: Optional[torch.Tensor] = None,
403
+ use_cache: bool = True,
404
+ num_logits_to_keep: Optional[int] = None,
405
+ logits_to_keep: Optional[int] = None,
406
+ **kwargs,
407
+ ):
408
+ # only the last token of `input_ids` is needed if `past_key_values` is passed along and is not empty.
409
+ if past_key_values is not None and len(past_key_values) > 0:
410
+ input_ids = input_ids[:, -1:]
411
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
412
+ if inputs_embeds is not None and past_key_values is None:
413
+ model_inputs = {"inputs_embeds": inputs_embeds}
414
+ else:
415
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
416
+ # recompiles graphs as the stride of the inputs is a guard.
417
+ # Ref: https://github.com/huggingface/transformers/pull/29114
418
+ # TODO: use `next_tokens` directly instead.
419
+ model_inputs = {"input_ids": input_ids.contiguous()}
420
+
421
+ if logits_to_keep is not None:
422
+ model_inputs['logits_to_keep'] = logits_to_keep
423
+
424
+ model_inputs.update(
425
+ {
426
+ "past_key_values": past_key_values,
427
+ "use_cache": use_cache,
428
+ "attention_mask": attention_mask,
429
+ "num_logits_to_keep": num_logits_to_keep,
430
+ }
431
+ )
432
+ return model_inputs
433
+
434
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
435
+ def forward(
436
+ self,
437
+ input_ids: torch.LongTensor = None,
438
+ attention_mask: Optional[torch.Tensor] = None,
439
+ inputs_embeds: Optional[torch.Tensor] = None,
440
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
441
+ labels: Optional[torch.LongTensor] = None,
442
+ use_cache: Optional[bool] = None,
443
+ output_attentions: Optional[bool] = None,
444
+ output_hidden_states: Optional[bool] = None,
445
+ return_dict: Optional[bool] = None,
446
+ num_logits_to_keep: Optional[int] = 0,
447
+ logits_to_keep: Optional[int] = 0,
448
+ **kwargs: Unpack[Dict],
449
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
450
+ num_logits_to_keep = 0 if num_logits_to_keep is None else num_logits_to_keep
451
+ output_attentions = (
452
+ output_attentions
453
+ if output_attentions is not None
454
+ else self.config.output_attentions
455
+ )
456
+ output_hidden_states = (
457
+ output_hidden_states
458
+ if output_hidden_states is not None
459
+ else self.config.output_hidden_states
460
+ )
461
+ return_dict = (
462
+ return_dict if return_dict is not None else self.config.use_return_dict
463
+ )
464
+ kwargs.pop("num_items_in_batch", None)
465
+ outputs = self.model(
466
+ input_ids=input_ids,
467
+ attention_mask=attention_mask,
468
+ inputs_embeds=inputs_embeds,
469
+ past_key_values=past_key_values,
470
+ use_cache=use_cache,
471
+ output_attentions=output_attentions,
472
+ output_hidden_states=output_hidden_states,
473
+ return_dict=return_dict,
474
+ **kwargs,
475
+ )
476
+ hidden_states = outputs[0]
477
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
478
+
479
+ loss, logits = None, None
480
+ if not fuse_linear_and_cross_entropy or labels is None:
481
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
482
+ if labels is not None:
483
+ if self.config.fuse_cross_entropy:
484
+ if fuse_linear_and_cross_entropy:
485
+ loss_fct = FusedLinearCrossEntropyLoss()
486
+ else:
487
+ loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
488
+ else:
489
+ loss_fct = nn.CrossEntropyLoss()
490
+ # Enable model parallelism
491
+ labels = labels.to(hidden_states.device)
492
+ labels = torch.cat(
493
+ (
494
+ labels[..., 1:],
495
+ torch.full_like(labels[:, :1], loss_fct.ignore_index),
496
+ ),
497
+ 1,
498
+ )
499
+ if fuse_linear_and_cross_entropy:
500
+ loss = loss_fct(
501
+ hidden_states.view(-1, self.config.hidden_size),
502
+ labels.view(-1),
503
+ self.lm_head.weight,
504
+ self.lm_head.bias,
505
+ )
506
+ else:
507
+ loss = loss_fct(
508
+ logits.view(-1, self.config.vocab_size), labels.view(-1)
509
+ )
510
+
511
+ if not return_dict:
512
+ output = (logits,) + outputs[1:]
513
+ return (loss, *output) if loss is not None else output
514
+ return CausalLMOutputWithPast(
515
+ loss=loss,
516
+ logits=logits,
517
+ past_key_values=outputs.past_key_values,
518
+ hidden_states=outputs.hidden_states,
519
+ attentions=outputs.attentions,
520
+ )
fla/models/gla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes).
 
fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc ADDED
Binary file (18.6 kB).
 
fla/models/gla/modeling_gla.py ADDED
@@ -0,0 +1,417 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gla import GatedLinearAttention
20
+ from fla.models.gla.configuration_gla import GLAConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GLAMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class GLABlock(nn.Module):
33
+ def __init__(self, config: GLAConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = GatedLinearAttention(
53
+ mode=config.attn_mode,
54
+ hidden_size=config.hidden_size,
55
+ expand_k=config.expand_k,
56
+ expand_v=config.expand_v,
57
+ num_heads=config.num_heads,
58
+ num_kv_heads=config.num_kv_heads,
59
+ feature_map=config.feature_map,
60
+ use_short_conv=config.use_short_conv,
61
+ conv_size=config.conv_size,
62
+ use_output_gate=config.use_output_gate,
63
+ gate_fn=config.hidden_act,
64
+ elementwise_affine=config.elementwise_affine,
65
+ norm_eps=config.norm_eps,
66
+ clamp_min=config.clamp_min,
67
+ fuse_norm=config.fuse_norm,
68
+ layer_idx=layer_idx
69
+ )
70
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
71
+ self.mlp = GLAMLP(
72
+ hidden_size=config.hidden_size,
73
+ hidden_ratio=config.hidden_ratio,
74
+ intermediate_size=config.intermediate_size,
75
+ hidden_act=config.hidden_act,
76
+ fuse_swiglu=config.fuse_swiglu
77
+ )
78
+
79
+ def forward(
80
+ self,
81
+ hidden_states: torch.Tensor,
82
+ attention_mask: Optional[torch.Tensor] = None,
83
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
84
+ use_cache: Optional[bool] = False,
85
+ output_attentions: Optional[bool] = False,
86
+ **kwargs: Unpack[Dict]
87
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
88
+ residual = hidden_states
89
+ hidden_states = self.attn_norm(hidden_states)
90
+ hidden_states, attentions, past_key_values = self.attn(
91
+ hidden_states=hidden_states,
92
+ attention_mask=attention_mask,
93
+ past_key_values=past_key_values,
94
+ use_cache=use_cache,
95
+ output_attentions=output_attentions,
96
+ **kwargs
97
+ )
98
+ if self.config.fuse_norm:
99
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
100
+ else:
101
+ hidden_states = residual + hidden_states
102
+ residual = hidden_states
103
+ hidden_states = self.mlp_norm(hidden_states)
104
+ hidden_states = self.mlp(hidden_states, **kwargs)
105
+ hidden_states = residual + hidden_states
106
+
107
+ outputs = (hidden_states, attentions, past_key_values)
108
+
109
+ return outputs
110
+
111
+
112
+ class GLAPreTrainedModel(PreTrainedModel):
113
+
114
+ config_class = GLAConfig
115
+ base_model_prefix = 'model'
116
+ supports_gradient_checkpointing = True
117
+ _no_split_modules = ['GLABlock']
118
+ _supports_cache_class = True
119
+
120
+ def __init__(self, *inputs, **kwargs):
121
+ super().__init__(*inputs, **kwargs)
122
+
123
+ def _init_weights(
124
+ self,
125
+ module: nn.Module,
126
+ prenorm_residual_strategy: Optional[str] = 'rescale',
127
+ num_residuals_per_layer: int = 2,
128
+ ):
129
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
130
+ # Slightly different from the TF version which uses truncated_normal for initialization
131
+ # cf https://github.com/pytorch/pytorch/pull/5617
132
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
133
+ if module.bias is not None:
134
+ nn.init.zeros_(module.bias)
135
+ elif isinstance(module, nn.Embedding):
136
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
137
+ elif hasattr(module, 'reset_parameters'):
138
+ module.reset_parameters()
139
+
140
+ if prenorm_residual_strategy is not None:
141
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
142
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
143
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
144
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
145
+ #
146
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
147
+ p = None
148
+ if hasattr(module, 'o_proj'):
149
+ p = module.o_proj.weight
150
+ elif hasattr(module, 'down_proj'):
151
+ p = module.down_proj.weight
152
+ if p is not None:
153
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
154
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
155
+ # We need to reinit p since this code could be called multiple times
156
+ # Having just p *= scale would repeatedly scale it down
157
+ if prenorm_residual_strategy == 'rescale':
158
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
159
+ with torch.no_grad():
160
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
161
+ elif prenorm_residual_strategy == 'zero':
162
+ nn.init.zeros_(p)
163
+ else:
164
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
165
+
166
+
167
+ class GLAModel(GLAPreTrainedModel):
168
+
169
+ def __init__(self, config: GLAConfig):
170
+ super().__init__(config)
171
+ self.padding_idx = config.pad_token_id
172
+ self.vocab_size = config.vocab_size
173
+
174
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
175
+ self.layers = nn.ModuleList([GLABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
176
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
177
+
178
+ self.gradient_checkpointing = False
179
+
180
+ self.post_init()
181
+
182
+ def get_input_embeddings(self):
183
+ return self.embeddings
184
+
185
+ def set_input_embeddings(self, value):
186
+ self.embeddings = value
187
+
188
+ def forward(
189
+ self,
190
+ input_ids: Optional[torch.LongTensor] = None,
191
+ attention_mask: Optional[torch.Tensor] = None, # noqa
192
+ inputs_embeds: Optional[torch.FloatTensor] = None,
193
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
194
+ use_cache: Optional[bool] = None,
195
+ output_attentions: Optional[bool] = None,
196
+ output_hidden_states: Optional[bool] = None,
197
+ return_dict: Optional[bool] = None,
198
+ **kwargs: Unpack[Dict]
199
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
200
+ if output_attentions:
201
+ warnings.warn("`GLAModel` does not `output_attentions` now, setting it to `False`.")
202
+ output_attentions = False
203
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
204
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
205
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
206
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
207
+
208
+ # retrieve input_ids and inputs_embeds
209
+ if input_ids is not None and inputs_embeds is not None:
210
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
211
+ if input_ids is None and inputs_embeds is None:
212
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
213
+
214
+ if inputs_embeds is None:
215
+ inputs_embeds = self.embeddings(input_ids)
216
+ hidden_states = inputs_embeds
217
+
218
+ if use_cache and not isinstance(past_key_values, Cache):
219
+ past_key_values = Cache.from_legacy_cache(past_key_values)
220
+
221
+ if self.gradient_checkpointing and self.training and use_cache:
222
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
223
+ use_cache = False
224
+
225
+ all_hidden_states = () if output_hidden_states else None
226
+ all_attns = () if output_attentions else None
227
+ for layer in self.layers:
228
+ if output_hidden_states:
229
+ all_hidden_states += (hidden_states,)
230
+
231
+ if self.gradient_checkpointing and self.training:
232
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
233
+ layer.__call__,
234
+ hidden_states,
235
+ attention_mask,
236
+ past_key_values,
237
+ use_cache,
238
+ output_attentions,
239
+ **kwargs
240
+ )
241
+ else:
242
+ hidden_states, attentions, past_key_values = layer(
243
+ hidden_states,
244
+ attention_mask=attention_mask,
245
+ past_key_values=past_key_values,
246
+ use_cache=use_cache,
247
+ output_attentions=output_attentions,
248
+ **kwargs
249
+ )
250
+
251
+ if output_attentions:
252
+ all_attns += (attentions,)
253
+
254
+ hidden_states = self.norm(hidden_states)
255
+
256
+ # add hidden states from the last decoder layer
257
+ if output_hidden_states:
258
+ all_hidden_states += (hidden_states,)
259
+
260
+ if not return_dict:
261
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
262
+ return BaseModelOutputWithPast(
263
+ last_hidden_state=hidden_states,
264
+ past_key_values=past_key_values,
265
+ hidden_states=all_hidden_states,
266
+ attentions=all_attns
267
+ )
268
+
269
+
270
+ class GLAForCausalLM(GLAPreTrainedModel, GenerationMixin):
271
+
272
+ _tied_weights_keys = ["lm_head.weight"]
273
+
274
+ def __init__(self, config):
275
+ super().__init__(config)
276
+ self.model = GLAModel(config)
277
+ self.vocab_size = config.vocab_size
278
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
279
+ self.criterion = None
280
+
281
+ # Initialize weights and apply final processing
282
+ self.post_init()
283
+
284
+ def get_input_embeddings(self):
285
+ return self.model.embeddings
286
+
287
+ def set_input_embeddings(self, value):
288
+ self.model.embeddings = value
289
+
290
+ def get_output_embeddings(self):
291
+ return self.lm_head
292
+
293
+ def set_output_embeddings(self, new_embeddings):
294
+ self.lm_head = new_embeddings
295
+
296
+ def set_decoder(self, decoder):
297
+ self.model = decoder
298
+
299
+ def get_decoder(self):
300
+ return self.model
301
+
302
+ def generate(self, *args, **kwargs):
303
+ try:
304
+ return super().generate(*args, **kwargs)
305
+ except AttributeError as exception:
306
+ if 'past_key_values' in str(exception):
307
+ raise AttributeError(
308
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
309
+ f"which is not supported for {self.__class__.__name__}. "
310
+ f"Try another generation strategy instead. "
311
+ f"For the available generation strategies, check this doc: "
312
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
313
+ )
314
+ else:
315
+ raise exception
316
+
317
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
318
+ def prepare_inputs_for_generation(
319
+ self,
320
+ input_ids: torch.LongTensor = None,
321
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
322
+ attention_mask: Optional[torch.Tensor] = None,
323
+ inputs_embeds: Optional[torch.Tensor] = None,
324
+ use_cache: bool = True,
325
+ logits_to_keep: Optional[int] = None,
326
+ **kwargs
327
+ ):
328
+ # only the last token of `input_ids` is needed if `past_key_values` is not empty.
329
+ if past_key_values is not None and len(past_key_values) > 0:
330
+ input_ids = input_ids[:, -1:]
331
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
332
+ if inputs_embeds is not None and len(past_key_values) == 0:
333
+ model_inputs = {'inputs_embeds': inputs_embeds}
334
+ else:
335
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
336
+ # recompiles graphs as the stride of the inputs is a guard.
337
+ # Ref: https://github.com/huggingface/transformers/pull/29114
338
+ # TODO: use `next_tokens` directly instead.
339
+ model_inputs = {'input_ids': input_ids.contiguous()}
340
+
341
+ if logits_to_keep is not None:
342
+ model_inputs['logits_to_keep'] = logits_to_keep
343
+
344
+ model_inputs.update({
345
+ 'past_key_values': past_key_values,
346
+ 'use_cache': use_cache,
347
+ 'attention_mask': attention_mask,
348
+ })
349
+ return model_inputs
350
+
351
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
352
+ def forward(
353
+ self,
354
+ input_ids: torch.LongTensor = None,
355
+ attention_mask: Optional[torch.Tensor] = None,
356
+ inputs_embeds: Optional[torch.Tensor] = None,
357
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
358
+ labels: Optional[torch.LongTensor] = None,
359
+ use_cache: Optional[bool] = None,
360
+ output_attentions: Optional[bool] = None,
361
+ output_hidden_states: Optional[bool] = None,
362
+ return_dict: Optional[bool] = None,
363
+ logits_to_keep: Optional[int] = 0,
364
+ **kwargs: Unpack[Dict]
365
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
366
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
367
+ output_hidden_states = (
368
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
369
+ )
370
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
371
+
372
+ outputs = self.model(
373
+ input_ids=input_ids,
374
+ attention_mask=attention_mask,
375
+ inputs_embeds=inputs_embeds,
376
+ past_key_values=past_key_values,
377
+ use_cache=use_cache,
378
+ output_attentions=output_attentions,
379
+ output_hidden_states=output_hidden_states,
380
+ return_dict=return_dict,
381
+ **kwargs
382
+ )
383
+
384
+ hidden_states = outputs[0]
385
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
386
+
387
+ loss, logits = None, None
388
+ if not fuse_linear_and_cross_entropy or labels is None:
389
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
390
+ if labels is not None:
391
+ if getattr(self, 'criterion', None) is None:
392
+ if fuse_linear_and_cross_entropy:
393
+ criterion = FusedLinearCrossEntropyLoss()
394
+ elif self.config.fuse_cross_entropy:
395
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
396
+ else:
397
+ criterion = nn.CrossEntropyLoss()
398
+ else:
399
+ criterion = self.criterion
400
+ labels = labels.to(hidden_states.device)
401
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
402
+ if fuse_linear_and_cross_entropy:
403
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
404
+ else:
405
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
406
+
407
+ if not return_dict:
408
+ output = (logits,) + outputs[1:]
409
+ return (loss,) + output if loss is not None else output
410
+
411
+ return CausalLMOutputWithPast(
412
+ loss=loss,
413
+ logits=logits,
414
+ past_key_values=outputs.past_key_values,
415
+ hidden_states=outputs.hidden_states,
416
+ attentions=outputs.attentions,
417
+ )
fla/models/gsa/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # -*- coding: utf-8 -*-
+
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+ from fla.models.gsa.configuration_gsa import GSAConfig
+ from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel
+
+ AutoConfig.register(GSAConfig.model_type, GSAConfig)
+ AutoModel.register(GSAConfig, GSAModel)
+ AutoModelForCausalLM.register(GSAConfig, GSAForCausalLM)
+
+
+ __all__ = ['GSAConfig', 'GSAForCausalLM', 'GSAModel']
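Illustrative usage (a minimal sketch, not part of this commit): because the module above registers GSAConfig with the transformers Auto classes at import time, the model resolves through the generic factories; the sizes below are made-up examples.

from transformers import AutoModelForCausalLM

from fla.models.gsa import GSAConfig  # importing the package runs the Auto* registrations above

# Made-up, deliberately small sizes, purely for illustration.
config = GSAConfig(hidden_size=512, num_hidden_layers=4, num_heads=4, num_slots=32, vocab_size=32000)
model = AutoModelForCausalLM.from_config(config)  # resolves to GSAForCausalLM via the registry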
fla/models/gsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes).
 
fla/models/gsa/configuration_gsa.py ADDED
@@ -0,0 +1,97 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Dict, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class GSAConfig(PretrainedConfig):
+
+     model_type = 'gsa'
+     keys_to_ignore_at_inference = ['past_key_values']
+
+     def __init__(
+         self,
+         hidden_size: int = 2048,
+         gate_logit_normalizer: Optional[int] = 8,
+         clamp_min: Optional[float] = None,
+         clamp_max: Optional[float] = None,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         num_hidden_layers: int = 24,
+         num_heads: int = 4,
+         num_kv_heads: Optional[int] = None,
+         num_slots: Optional[int] = 64,
+         use_short_conv: bool = False,
+         conv_size: int = 4,
+         exapnd_k: float = 1,
+         exapnd_v: float = 1,
+         feature_map: str = 'swish',
+         use_output_gate: bool = False,
+         use_norm: bool = True,
+         max_position_embeddings: int = 2048,
+         hidden_act: str = "swish",
+         elementwise_affine: Optional[bool] = True,
+         norm_eps: float = 1e-6,
+         attn: Optional[Dict] = None,
+         use_cache: bool = True,
+         pad_token_id: int = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         initializer_range: float = 0.006,
+         tie_word_embeddings: bool = False,
+         fuse_norm: bool = True,
+         fuse_swiglu: bool = True,
+         fuse_cross_entropy: bool = True,
+         vocab_size: int = 32000,
+         **kwargs
+     ):
+         self.hidden_size = hidden_size
+         self.gate_logit_normalizer = gate_logit_normalizer
+         self.clamp_min = clamp_min
+         self.clamp_max = clamp_max
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_heads = num_heads
+         self.num_kv_heads = num_kv_heads
+         self.num_slots = num_slots
+         self.use_short_conv = use_short_conv
+         self.conv_size = conv_size
+         self.expand_k = exapnd_k
+         self.expand_v = exapnd_v
+         self.feature_map = feature_map
+         self.use_output_gate = use_output_gate
+         self.use_norm = use_norm
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_act = hidden_act
+         self.elementwise_affine = elementwise_affine
+         self.norm_eps = norm_eps
+         self.attn = attn
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+
+         self.fuse_norm = fuse_norm
+         self.fuse_swiglu = fuse_swiglu
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.vocab_size = vocab_size
+
+         if attn is not None:
+             if not isinstance(attn, Dict):
+                 raise ValueError("attn must be a dictionary")
+             if 'layers' not in attn:
+                 raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+             if 'num_heads' not in attn:
+                 raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+             attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+             attn['qkv_bias'] = attn.get('qkv_bias', False)
+             attn['window_size'] = attn.get('window_size', None)
+             attn['rope_theta'] = attn.get('rope_theta', 10000.)
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
fla/models/gsa/modeling_gsa.py ADDED
@@ -0,0 +1,420 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gsa import GatedSlotAttention
20
+ from fla.models.gsa.configuration_gsa import GSAConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GSAMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class GSABlock(nn.Module):
33
+ def __init__(self, config: GSAConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = GatedSlotAttention(
53
+ hidden_size=config.hidden_size,
54
+ expand_k=config.expand_k,
55
+ expand_v=config.expand_v,
56
+ num_heads=config.num_heads,
57
+ num_kv_heads=config.num_kv_heads,
58
+ num_slots=config.num_slots,
59
+ use_short_conv=config.use_short_conv,
60
+ conv_size=config.conv_size,
61
+ feature_map=config.feature_map,
62
+ use_output_gate=config.use_output_gate,
63
+ use_norm=config.use_norm,
64
+ gate_fn=config.hidden_act,
65
+ gate_logit_normalizer=config.gate_logit_normalizer,
66
+ elementwise_affine=config.elementwise_affine,
67
+ norm_eps=config.norm_eps,
68
+ fuse_norm=config.fuse_norm,
69
+ layer_idx=layer_idx
70
+ )
71
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
72
+ self.mlp = GSAMLP(
73
+ hidden_size=config.hidden_size,
74
+ hidden_ratio=config.hidden_ratio,
75
+ intermediate_size=config.intermediate_size,
76
+ hidden_act=config.hidden_act,
77
+ fuse_swiglu=config.fuse_swiglu
78
+ )
79
+
80
+ def forward(
81
+ self,
82
+ hidden_states: torch.Tensor,
83
+ attention_mask: Optional[torch.Tensor] = None,
84
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
85
+ use_cache: Optional[bool] = False,
86
+ output_attentions: Optional[bool] = False,
87
+ **kwargs: Unpack[Dict]
88
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
89
+ residual = hidden_states
90
+ hidden_states = self.attn_norm(hidden_states)
91
+ hidden_states, attentions, past_key_values = self.attn(
92
+ hidden_states=hidden_states,
93
+ attention_mask=attention_mask,
94
+ past_key_values=past_key_values,
95
+ use_cache=use_cache,
96
+ output_attentions=output_attentions,
97
+ **kwargs
98
+ )
99
+ if self.config.fuse_norm:
100
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
101
+ else:
102
+ hidden_states = residual + hidden_states
103
+ residual = hidden_states
104
+ hidden_states = self.mlp_norm(hidden_states)
105
+ hidden_states = self.mlp(hidden_states, **kwargs)
106
+ hidden_states = residual + hidden_states
107
+
108
+ outputs = (hidden_states, attentions, past_key_values)
109
+
110
+ return outputs
111
+
112
+
113
+ class GSAPreTrainedModel(PreTrainedModel):
114
+
115
+ config_class = GSAConfig
116
+ base_model_prefix = 'model'
117
+ supports_gradient_checkpointing = True
118
+ _no_split_modules = ['GSABlock']
119
+ _supports_cache_class = True
120
+
121
+ def __init__(self, *inputs, **kwargs):
122
+ super().__init__(*inputs, **kwargs)
123
+
124
+ def _init_weights(
125
+ self,
126
+ module: nn.Module,
127
+ prenorm_residual_strategy: Optional[str] = 'rescale',
128
+ num_residuals_per_layer: int = 2,
129
+ ):
130
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
131
+ # Slightly different from the TF version which uses truncated_normal for initialization
132
+ # cf https://github.com/pytorch/pytorch/pull/5617
133
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
134
+ if module.bias is not None:
135
+ nn.init.zeros_(module.bias)
136
+ elif isinstance(module, nn.Embedding):
137
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
138
+ elif hasattr(module, 'reset_parameters'):
139
+ module.reset_parameters()
140
+
141
+ if prenorm_residual_strategy is not None:
142
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
143
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
144
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
145
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
146
+ #
147
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
148
+ p = None
149
+ if hasattr(module, 'o_proj'):
150
+ p = module.o_proj.weight
151
+ elif hasattr(module, 'down_proj'):
152
+ p = module.down_proj.weight
153
+ if p is not None:
154
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
155
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
156
+ # We need to reinit p since this code could be called multiple times
157
+ # Having just p *= scale would repeatedly scale it down
158
+ if prenorm_residual_strategy == 'rescale':
159
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
160
+ with torch.no_grad():
161
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
162
+ elif prenorm_residual_strategy == 'zero':
163
+ nn.init.zeros_(p)
164
+ else:
165
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
166
+
167
+
168
+ class GSAModel(GSAPreTrainedModel):
169
+
170
+ def __init__(self, config: GSAConfig):
171
+ super().__init__(config)
172
+ self.padding_idx = config.pad_token_id
173
+ self.vocab_size = config.vocab_size
174
+
175
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
176
+ self.layers = nn.ModuleList([GSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
177
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
178
+
179
+ self.gradient_checkpointing = False
180
+
181
+ self.post_init()
182
+
183
+ def get_input_embeddings(self):
184
+ return self.embeddings
185
+
186
+ def set_input_embeddings(self, value):
187
+ self.embeddings = value
188
+
189
+ def forward(
190
+ self,
191
+ input_ids: Optional[torch.LongTensor] = None,
192
+ attention_mask: Optional[torch.Tensor] = None, # noqa
193
+ inputs_embeds: Optional[torch.FloatTensor] = None,
194
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
195
+ use_cache: Optional[bool] = None,
196
+ output_attentions: Optional[bool] = None,
197
+ output_hidden_states: Optional[bool] = None,
198
+ return_dict: Optional[bool] = None,
199
+ **kwargs: Unpack[Dict]
200
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
201
+ if output_attentions:
202
+ warnings.warn("`GSAModel` does not `output_attentions` now, setting it to `False`.")
203
+ output_attentions = False
204
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
205
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
206
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
207
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
208
+
209
+ # retrieve input_ids and inputs_embeds
210
+ if input_ids is not None and inputs_embeds is not None:
211
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
212
+ if input_ids is None and inputs_embeds is None:
213
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
214
+
215
+ if inputs_embeds is None:
216
+ inputs_embeds = self.embeddings(input_ids)
217
+ hidden_states = inputs_embeds
218
+
219
+ if use_cache and not isinstance(past_key_values, Cache):
220
+ past_key_values = Cache.from_legacy_cache(past_key_values)
221
+
222
+ if self.gradient_checkpointing and self.training and use_cache:
223
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
224
+ use_cache = False
225
+
226
+ all_hidden_states = () if output_hidden_states else None
227
+ all_attns = () if output_attentions else None
228
+ for layer in self.layers:
229
+ if output_hidden_states:
230
+ all_hidden_states += (hidden_states,)
231
+
232
+ if self.gradient_checkpointing and self.training:
233
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
234
+ layer.__call__,
235
+ hidden_states,
236
+ attention_mask,
237
+ past_key_values,
238
+ use_cache,
239
+ output_attentions,
240
+ **kwargs
241
+ )
242
+ else:
243
+ hidden_states, attentions, past_key_values = layer(
244
+ hidden_states,
245
+ attention_mask=attention_mask,
246
+ past_key_values=past_key_values,
247
+ use_cache=use_cache,
248
+ output_attentions=output_attentions,
249
+ **kwargs
250
+ )
251
+
252
+ if output_attentions:
253
+ all_attns += (attentions,)
254
+
255
+ hidden_states = self.norm(hidden_states)
256
+
257
+ # add hidden states from the last decoder layer
258
+ if output_hidden_states:
259
+ all_hidden_states += (hidden_states,)
260
+
261
+ if not return_dict:
262
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
263
+ return BaseModelOutputWithPast(
264
+ last_hidden_state=hidden_states,
265
+ past_key_values=past_key_values,
266
+ hidden_states=all_hidden_states,
267
+ attentions=all_attns
268
+ )
269
+
270
+
271
+ class GSAForCausalLM(GSAPreTrainedModel, GenerationMixin):
272
+
273
+ _tied_weights_keys = ["lm_head.weight"]
274
+
275
+ def __init__(self, config):
276
+
277
+ super().__init__(config)
278
+ self.model = GSAModel(config)
279
+ self.vocab_size = config.vocab_size
280
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
281
+ self.criterion = None
282
+
283
+ # Initialize weights and apply final processing
284
+ self.post_init()
285
+
286
+ def get_input_embeddings(self):
287
+ return self.model.embeddings
288
+
289
+ def set_input_embeddings(self, value):
290
+ self.model.embeddings = value
291
+
292
+ def get_output_embeddings(self):
293
+ return self.lm_head
294
+
295
+ def set_output_embeddings(self, new_embeddings):
296
+ self.lm_head = new_embeddings
297
+
298
+ def set_decoder(self, decoder):
299
+ self.model = decoder
300
+
301
+ def get_decoder(self):
302
+ return self.model
303
+
304
+ def generate(self, *args, **kwargs):
305
+ try:
306
+ return super().generate(*args, **kwargs)
307
+ except AttributeError as exception:
308
+ if 'past_key_values' in str(exception):
309
+ raise AttributeError(
310
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
311
+ f"which is not supported for {self.__class__.__name__}. "
312
+ f"Try another generation strategy instead. "
313
+ f"For the available generation strategies, check this doc: "
314
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
315
+ )
316
+ else:
317
+ raise exception
318
+
319
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
320
+ def prepare_inputs_for_generation(
321
+ self,
322
+ input_ids: torch.LongTensor = None,
323
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
324
+ attention_mask: Optional[torch.Tensor] = None,
325
+ inputs_embeds: Optional[torch.Tensor] = None,
326
+ use_cache: bool = True,
327
+ logits_to_keep: Optional[int] = None,
328
+ **kwargs
329
+ ):
330
+ # only the last token of `input_ids` is needed if `past_key_values` is not empty.
331
+ if past_key_values is not None and len(past_key_values) > 0:
332
+ input_ids = input_ids[:, -1:]
333
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
334
+ if inputs_embeds is not None and len(past_key_values) == 0:
335
+ model_inputs = {'inputs_embeds': inputs_embeds}
336
+ else:
337
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
338
+ # recompiles graphs as the stride of the inputs is a guard.
339
+ # Ref: https://github.com/huggingface/transformers/pull/29114
340
+ # TODO: use `next_tokens` directly instead.
341
+ model_inputs = {'input_ids': input_ids.contiguous()}
342
+
343
+ if logits_to_keep is not None:
344
+ model_inputs['logits_to_keep'] = logits_to_keep
345
+
346
+ model_inputs.update({
347
+ 'past_key_values': past_key_values,
348
+ 'use_cache': use_cache,
349
+ 'attention_mask': attention_mask,
350
+ })
351
+ return model_inputs
352
+
353
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
354
+ def forward(
355
+ self,
356
+ input_ids: torch.LongTensor = None,
357
+ attention_mask: Optional[torch.Tensor] = None,
358
+ inputs_embeds: Optional[torch.Tensor] = None,
359
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
360
+ labels: Optional[torch.LongTensor] = None,
361
+ use_cache: Optional[bool] = None,
362
+ output_attentions: Optional[bool] = None,
363
+ output_hidden_states: Optional[bool] = None,
364
+ return_dict: Optional[bool] = None,
365
+ logits_to_keep: Optional[int] = 0,
366
+ **kwargs: Unpack[Dict]
367
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
368
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
369
+ output_hidden_states = (
370
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
371
+ )
372
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
373
+
374
+ outputs = self.model(
375
+ input_ids=input_ids,
376
+ attention_mask=attention_mask,
377
+ inputs_embeds=inputs_embeds,
378
+ past_key_values=past_key_values,
379
+ use_cache=use_cache,
380
+ output_attentions=output_attentions,
381
+ output_hidden_states=output_hidden_states,
382
+ return_dict=return_dict,
383
+ **kwargs
384
+ )
385
+
386
+ hidden_states = outputs[0]
387
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
388
+
389
+ loss, logits = None, None
390
+ if not fuse_linear_and_cross_entropy or labels is None:
391
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
392
+ if labels is not None:
393
+ if getattr(self, 'criterion', None) is None:
394
+ if fuse_linear_and_cross_entropy:
395
+ criterion = FusedLinearCrossEntropyLoss()
396
+ elif self.config.fuse_cross_entropy:
397
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
398
+ else:
399
+ criterion = nn.CrossEntropyLoss()
400
+ else:
401
+ criterion = self.criterion
402
+ # Enable model parallelism
403
+ labels = labels.to(hidden_states.device)
404
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
405
+ if fuse_linear_and_cross_entropy:
406
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
407
+ else:
408
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
409
+
410
+ if not return_dict:
411
+ output = (logits,) + outputs[1:]
412
+ return (loss,) + output if loss is not None else output
413
+
414
+ return CausalLMOutputWithPast(
415
+ loss=loss,
416
+ logits=logits,
417
+ past_key_values=outputs.past_key_values,
418
+ hidden_states=outputs.hidden_states,
419
+ attentions=outputs.attentions,
420
+ )
fla/models/hgrn/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # -*- coding: utf-8 -*-
+
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+ from fla.models.hgrn.configuration_hgrn import HGRNConfig
+ from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel
+
+ AutoConfig.register(HGRNConfig.model_type, HGRNConfig)
+ AutoModel.register(HGRNConfig, HGRNModel)
+ AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM)
+
+
+ __all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel']
fla/models/hgrn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (665 Bytes).
 
fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc ADDED
Binary file (18.8 kB).
 
fla/models/hgrn/configuration_hgrn.py ADDED
@@ -0,0 +1,81 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Dict, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class HGRNConfig(PretrainedConfig):
+
+     model_type = 'hgrn'
+     keys_to_ignore_at_inference = ['past_key_values']
+
+     def __init__(
+         self,
+         attn_mode: str = "fused_recurrent",
+         hidden_size: int = 2048,
+         num_hidden_layers: int = 24,
+         expand_ratio: Optional[int] = 1,
+         use_short_conv: bool = False,
+         conv_size: int = 4,
+         use_lower_bound: bool = True,
+         max_position_embeddings: int = 2048,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         hidden_act: str = "swish",
+         elementwise_affine: Optional[bool] = True,
+         norm_eps: float = 1e-6,
+         attn: Optional[Dict] = None,
+         use_cache: bool = True,
+         pad_token_id: int = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         initializer_range: float = 0.006,
+         fuse_norm: bool = True,
+         fuse_swiglu: bool = True,
+         fuse_cross_entropy: bool = True,
+         vocab_size: int = 32000,
+         **kwargs
+     ):
+         self.attn_mode = attn_mode
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.expand_ratio = expand_ratio
+         self.use_short_conv = use_short_conv
+         self.conv_size = conv_size
+         self.use_lower_bound = use_lower_bound
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.elementwise_affine = elementwise_affine
+         self.attn = attn
+         self.norm_eps = norm_eps
+         self.hidden_act = hidden_act
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+
+         self.fuse_norm = fuse_norm
+         self.fuse_swiglu = fuse_swiglu
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.vocab_size = vocab_size
+
+         if attn is not None:
+             if not isinstance(attn, Dict):
+                 raise ValueError("attn must be a dictionary")
+             if 'layers' not in attn:
+                 raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+             if 'num_heads' not in attn:
+                 raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+             attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+             attn['qkv_bias'] = attn.get('qkv_bias', False)
+             attn['window_size'] = attn.get('window_size', None)
+             attn['rope_theta'] = attn.get('rope_theta', 10000.)
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
fla/models/hgrn2/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # -*- coding: utf-8 -*-
+
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+ from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config
+ from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model
+
+ AutoConfig.register(HGRN2Config.model_type, HGRN2Config)
+ AutoModel.register(HGRN2Config, HGRN2Model)
+ AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM)
+
+
+ __all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model']
fla/models/hgrn2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (674 Bytes).
 
fla/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc ADDED
Binary file (3.55 kB).
 
fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc ADDED
Binary file (18.9 kB).
 
fla/models/hgrn2/configuration_hgrn2.py ADDED
@@ -0,0 +1,91 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Dict, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class HGRN2Config(PretrainedConfig):
+
+     model_type = 'hgrn2'
+     keys_to_ignore_at_inference = ['past_key_values']
+
+     def __init__(
+         self,
+         hidden_size: int = 2048,
+         num_hidden_layers: int = 24,
+         attn_mode: str = "chunk",
+         num_heads: Optional[int] = None,
+         expand_ratio: Optional[int] = 128,
+         use_short_conv: bool = False,
+         conv_size: int = 4,
+         use_lower_bound: bool = True,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         hidden_act: str = "swish",
+         max_position_embeddings: int = 2048,
+         elementwise_affine: Optional[bool] = True,
+         norm_eps: float = 1e-6,
+         attn: Optional[Dict] = None,
+         use_cache: bool = True,
+         pad_token_id: int = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         initializer_range: float = 0.006,
+         fuse_norm: bool = True,
+         fuse_swiglu: bool = True,
+         fuse_cross_entropy: bool = True,
+         vocab_size: int = 32000,
+         **kwargs
+     ):
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.attn_mode = attn_mode
+
+         if expand_ratio is None and num_heads is not None:
+             expand_ratio = hidden_size // num_heads
+         elif expand_ratio is not None and num_heads is None:
+             num_heads = hidden_size // expand_ratio
+         elif expand_ratio is None and num_heads is None:
+             raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
+         self.num_heads = num_heads
+         self.expand_ratio = expand_ratio
+
+         self.use_short_conv = use_short_conv
+         self.conv_size = conv_size
+         self.use_lower_bound = use_lower_bound
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.elementwise_affine = elementwise_affine
+         self.norm_eps = norm_eps
+         self.attn = attn
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+
+         self.fuse_norm = fuse_norm
+         self.fuse_swiglu = fuse_swiglu
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.vocab_size = vocab_size
+
+         if attn is not None:
+             if not isinstance(attn, Dict):
+                 raise ValueError("attn must be a dictionary")
+             if 'layers' not in attn:
+                 raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+             if 'num_heads' not in attn:
+                 raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+             attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+             attn['qkv_bias'] = attn.get('qkv_bias', False)
+             attn['window_size'] = attn.get('window_size', None)
+             attn['rope_theta'] = attn.get('rope_theta', 10000.)
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
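A quick worked example of the num_heads/expand_ratio coupling handled in __init__ above (a sketch using only the defaults visible in this file; not an additional file in the commit):

from fla.models.hgrn2 import HGRN2Config

config = HGRN2Config()                                  # defaults: hidden_size=2048, expand_ratio=128, num_heads=None
assert config.num_heads == 2048 // 128                  # -> 16 heads derived from the ratio
config = HGRN2Config(num_heads=32, expand_ratio=None)
assert config.expand_ratio == 2048 // 32                # -> ratio 64 derived from the head count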
fla/models/hgrn2/modeling_hgrn2.py ADDED
@@ -0,0 +1,421 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.hgrn2 import HGRN2Attention
20
+ from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as HGRN2MLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class HGRN2Block(nn.Module):
33
+ def __init__(self, config: HGRN2Config, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = HGRN2Attention(
53
+ mode=config.attn_mode,
54
+ hidden_size=config.hidden_size,
55
+ num_heads=config.num_heads,
56
+ expand_ratio=config.expand_ratio,
57
+ use_short_conv=config.use_short_conv,
58
+ conv_size=config.conv_size,
59
+ elementwise_affine=config.elementwise_affine,
60
+ norm_eps=config.norm_eps,
61
+ layer_idx=layer_idx
62
+ )
63
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
64
+ self.mlp = HGRN2MLP(
65
+ hidden_size=config.hidden_size,
66
+ hidden_ratio=config.hidden_ratio,
67
+ intermediate_size=config.intermediate_size,
68
+ hidden_act=config.hidden_act,
69
+ fuse_swiglu=config.fuse_swiglu
70
+ )
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.Tensor] = None,
76
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
77
+ use_cache: Optional[bool] = False,
78
+ output_attentions: Optional[bool] = False,
79
+ lower_bound: Optional[torch.Tensor] = None,
80
+ **kwargs: Unpack[Dict]
81
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
82
+ residual = hidden_states
83
+ hidden_states = self.attn_norm(hidden_states)
84
+ hidden_states, attentions, past_key_values = self.attn(
85
+ hidden_states=hidden_states,
86
+ attention_mask=attention_mask,
87
+ past_key_values=past_key_values,
88
+ use_cache=use_cache,
89
+ output_attentions=output_attentions,
90
+ lower_bound=lower_bound,
91
+ **kwargs
92
+ )
93
+ if self.config.fuse_norm:
94
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
95
+ else:
96
+ hidden_states = residual + hidden_states
97
+ residual = hidden_states
98
+ hidden_states = self.mlp_norm(hidden_states)
99
+ hidden_states = self.mlp(hidden_states, **kwargs)
100
+ hidden_states = residual + hidden_states
101
+
102
+ outputs = (hidden_states, attentions, past_key_values)
103
+
104
+ return outputs
105
+
106
+
107
+ class HGRN2PreTrainedModel(PreTrainedModel):
108
+
109
+ config_class = HGRN2Config
110
+ base_model_prefix = 'model'
111
+ supports_gradient_checkpointing = True
112
+ _no_split_modules = ['HGRN2Block']
113
+ _supports_cache_class = True
114
+
115
+ def __init__(self, *inputs, **kwargs):
116
+ super().__init__(*inputs, **kwargs)
117
+
118
+ def _init_weights(
119
+ self,
120
+ module: nn.Module,
121
+ prenorm_residual_strategy: Optional[str] = 'rescale',
122
+ num_residuals_per_layer: int = 2,
123
+ ):
124
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
125
+ # Slightly different from the TF version which uses truncated_normal for initialization
126
+ # cf https://github.com/pytorch/pytorch/pull/5617
127
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
128
+ if module.bias is not None:
129
+ nn.init.zeros_(module.bias)
130
+ elif isinstance(module, nn.Embedding):
131
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
132
+ elif hasattr(module, 'reset_parameters'):
133
+ module.reset_parameters()
134
+
135
+ if prenorm_residual_strategy is not None:
136
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
137
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
138
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
139
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
140
+ #
141
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
142
+ p = None
143
+ if hasattr(module, 'o_proj'):
144
+ p = module.o_proj.weight
145
+ elif hasattr(module, 'down_proj'):
146
+ p = module.down_proj.weight
147
+ if p is not None:
148
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
149
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
150
+ # We need to reinit p since this code could be called multiple times
151
+ # Having just p *= scale would repeatedly scale it down
152
+ if prenorm_residual_strategy == 'rescale':
153
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
154
+ with torch.no_grad():
155
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
156
+ elif prenorm_residual_strategy == 'zero':
157
+ nn.init.zeros_(p)
158
+ else:
159
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
160
+
161
+
162
+ class HGRN2Model(HGRN2PreTrainedModel):
163
+
164
+ def __init__(self, config: HGRN2Config):
165
+ super().__init__(config)
166
+ self.padding_idx = config.pad_token_id
167
+ self.vocab_size = config.vocab_size
168
+
169
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
170
+ if config.use_lower_bound:
171
+ self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size))
172
+ self.layers = nn.ModuleList([HGRN2Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
173
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
174
+
175
+ self.gradient_checkpointing = False
176
+
177
+ self.post_init()
178
+
179
+ def get_input_embeddings(self):
180
+ return self.embeddings
181
+
182
+ def set_input_embeddings(self, value):
183
+ self.embeddings = value
184
+
185
+ def forward(
186
+ self,
187
+ input_ids: Optional[torch.LongTensor] = None,
188
+ attention_mask: Optional[torch.Tensor] = None, # noqa
189
+ inputs_embeds: Optional[torch.FloatTensor] = None,
190
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
191
+ use_cache: Optional[bool] = None,
192
+ output_attentions: Optional[bool] = None,
193
+ output_hidden_states: Optional[bool] = None,
194
+ return_dict: Optional[bool] = None,
195
+ **kwargs: Unpack[Dict]
196
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
197
+ if output_attentions:
198
+ warnings.warn("`HGRN2Model` does not support `output_attentions` for now, setting it to `False`.")
199
+ output_attentions = False
200
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
201
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
202
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
203
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
204
+
205
+ # retrieve input_ids and inputs_embeds
206
+ if input_ids is not None and inputs_embeds is not None:
207
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
208
+ if input_ids is None and inputs_embeds is None:
209
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
210
+
211
+ if inputs_embeds is None:
212
+ inputs_embeds = self.embeddings(input_ids)
213
+ hidden_states = inputs_embeds
214
+
215
+ if use_cache and not isinstance(past_key_values, Cache):
216
+ past_key_values = Cache.from_legacy_cache(past_key_values)
217
+
218
+ if self.gradient_checkpointing and self.training and use_cache:
219
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
220
+ use_cache = False
221
+
222
+ all_hidden_states = () if output_hidden_states else None
223
+ all_attns = () if output_attentions else None
224
+
225
+ if self.config.use_lower_bound:
226
+ lower_bounds = self.lower_bounds.softmax(0)
227
+ lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0]
228
+ for i, layer in enumerate(self.layers):
229
+ if output_hidden_states:
230
+ all_hidden_states += (hidden_states,)
231
+
232
+ lower_bound = lower_bounds[i] if self.config.use_lower_bound else None
233
+ if self.gradient_checkpointing and self.training:
234
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
235
+ layer.__call__,
236
+ hidden_states,
237
+ attention_mask,
238
+ past_key_values,
239
+ use_cache,
240
+ output_attentions,
241
+ lower_bound,
242
+ **kwargs
243
+ )
244
+ else:
245
+ hidden_states, attentions, past_key_values = layer(
246
+ hidden_states,
247
+ attention_mask=attention_mask,
248
+ past_key_values=past_key_values,
249
+ use_cache=use_cache,
250
+ output_attentions=output_attentions,
251
+ lower_bound=lower_bound,
252
+ **kwargs
253
+ )
254
+
255
+ if output_attentions:
256
+ all_attns += (attentions,)
257
+
258
+ hidden_states = self.norm(hidden_states)
259
+
260
+ # add hidden states from the last decoder layer
261
+ if output_hidden_states:
262
+ all_hidden_states += (hidden_states,)
263
+
264
+ if not return_dict:
265
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
266
+ return BaseModelOutputWithPast(
267
+ last_hidden_state=hidden_states,
268
+ past_key_values=past_key_values,
269
+ hidden_states=all_hidden_states,
270
+ attentions=all_attns
271
+ )
272
+
273
+
274
+ class HGRN2ForCausalLM(HGRN2PreTrainedModel, GenerationMixin):
275
+
276
+ _tied_weights_keys = ["lm_head.weight"]
277
+
278
+ def __init__(self, config):
279
+ super().__init__(config)
280
+ self.model = HGRN2Model(config)
281
+ self.vocab_size = config.vocab_size
282
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
283
+ self.criterion = None
284
+
285
+ # Initialize weights and apply final processing
286
+ self.post_init()
287
+
288
+ def get_input_embeddings(self):
289
+ return self.model.embeddings
290
+
291
+ def set_input_embeddings(self, value):
292
+ self.model.embeddings = value
293
+
294
+ def get_output_embeddings(self):
295
+ return self.lm_head
296
+
297
+ def set_output_embeddings(self, new_embeddings):
298
+ self.lm_head = new_embeddings
299
+
300
+ def set_decoder(self, decoder):
301
+ self.model = decoder
302
+
303
+ def get_decoder(self):
304
+ return self.model
305
+
306
+ def generate(self, *args, **kwargs):
307
+ try:
308
+ return super().generate(*args, **kwargs)
309
+ except AttributeError as exception:
310
+ if 'past_key_values' in str(exception):
311
+ raise AttributeError(
312
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
313
+ f"which is not supported for {self.__class__.__name__}. "
314
+ f"Try another generation strategy instead. "
315
+ f"For the available generation strategies, check this doc: "
316
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
317
+ )
318
+ else:
319
+ raise exception
320
+
321
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
322
+ def prepare_inputs_for_generation(
323
+ self,
324
+ input_ids: torch.LongTensor = None,
325
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
326
+ attention_mask: Optional[torch.Tensor] = None,
327
+ inputs_embeds: Optional[torch.Tensor] = None,
328
+ use_cache: bool = True,
329
+ logits_to_keep: Optional[int] = None,
330
+ **kwargs: Unpack[Dict]
331
+ ):
332
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
333
+ if past_key_values is not None and len(past_key_values) > 0:
334
+ input_ids = input_ids[:, -1:]
335
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
336
+ if inputs_embeds is not None and len(past_key_values) == 0:
337
+ model_inputs = {'inputs_embeds': inputs_embeds}
338
+ else:
339
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
340
+ # recompiles graphs as the stride of the inputs is a guard.
341
+ # Ref: https://github.com/huggingface/transformers/pull/29114
342
+ # TODO: use `next_tokens` directly instead.
343
+ model_inputs = {'input_ids': input_ids.contiguous()}
344
+
345
+ if logits_to_keep is not None:
346
+ model_inputs['logits_to_keep'] = logits_to_keep
347
+
348
+ model_inputs.update({
349
+ 'past_key_values': past_key_values,
350
+ 'use_cache': use_cache,
351
+ 'attention_mask': attention_mask,
352
+ })
353
+ return model_inputs
354
+
355
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
356
+ def forward(
357
+ self,
358
+ input_ids: torch.LongTensor = None,
359
+ attention_mask: Optional[torch.Tensor] = None,
360
+ inputs_embeds: Optional[torch.Tensor] = None,
361
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
362
+ labels: Optional[torch.LongTensor] = None,
363
+ use_cache: Optional[bool] = None,
364
+ output_attentions: Optional[bool] = None,
365
+ output_hidden_states: Optional[bool] = None,
366
+ return_dict: Optional[bool] = None,
367
+ logits_to_keep: Optional[int] = 0,
368
+ **kwargs: Unpack[Dict]
369
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
370
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
371
+ output_hidden_states = (
372
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
373
+ )
374
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
375
+
376
+ outputs = self.model(
377
+ input_ids=input_ids,
378
+ attention_mask=attention_mask,
379
+ inputs_embeds=inputs_embeds,
380
+ past_key_values=past_key_values,
381
+ use_cache=use_cache,
382
+ output_attentions=output_attentions,
383
+ output_hidden_states=output_hidden_states,
384
+ return_dict=return_dict,
385
+ **kwargs
386
+ )
387
+
388
+ hidden_states = outputs[0]
389
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
390
+
391
+ loss, logits = None, None
392
+ if not fuse_linear_and_cross_entropy or labels is None:
393
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
394
+ if labels is not None:
395
+ if getattr(self, 'criterion', None) is None:
396
+ if fuse_linear_and_cross_entropy:
397
+ criterion = FusedLinearCrossEntropyLoss()
398
+ elif self.config.fuse_cross_entropy:
399
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
400
+ else:
401
+ criterion = nn.CrossEntropyLoss()
402
+ else:
403
+ criterion = self.criterion
404
+ labels = labels.to(hidden_states.device)
405
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
406
+ if fuse_linear_and_cross_entropy:
407
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
408
+ else:
409
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
410
+
411
+ if not return_dict:
412
+ output = (logits,) + outputs[1:]
413
+ return (loss,) + output if loss is not None else output
414
+
415
+ return CausalLMOutputWithPast(
416
+ loss=loss,
417
+ logits=logits,
418
+ past_key_values=outputs.past_key_values,
419
+ hidden_states=outputs.hidden_states,
420
+ attentions=outputs.attentions,
421
+ )
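
The `lower_bounds` parameter in `HGRN2Model` is turned into per-layer gate lower bounds by a softmax over the layer dimension followed by a shifted cumulative sum, so the first layer is bounded by zero and deeper layers receive monotonically increasing bounds. A standalone sketch of that computation with the zero initialization used above (values are illustrative only):

    import torch

    num_layers, hidden_size = 4, 3
    raw = torch.zeros(num_layers, hidden_size)   # same init as HGRN2Model.lower_bounds
    lb = raw.softmax(0)                          # each column sums to 1 across layers
    lb = lb.cumsum(0) - lb[0]                    # layer 0 -> 0, deeper layers increase
    print(lb[:, 0])                              # tensor([0.0000, 0.2500, 0.5000, 0.7500])
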
fla/models/lightnet/configuration_lightnet.py ADDED
@@ -0,0 +1,83 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class LightNetConfig(PretrainedConfig):
9
+
10
+ model_type = 'lightnet'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ hidden_size: int = 2048,
16
+ num_hidden_layers: int = 24,
17
+ attn_mode: str = "chunk",
18
+ num_heads: Optional[int] = None,
19
+ expand_ratio: Optional[int] = 128,
20
+ use_short_conv: bool = False,
21
+ conv_size: int = 4,
22
+ hidden_ratio: Optional[int] = 4,
23
+ intermediate_size: Optional[int] = None,
24
+ hidden_act: str = "swish",
25
+ max_position_embeddings: int = 2048,
26
+ gate_low_rank_dim: int = 128,
27
+ elementwise_affine: Optional[bool] = True,
28
+ norm_eps: float = 1e-6,
29
+ attn: Optional[Dict] = None,
30
+ use_cache: bool = True,
31
+ pad_token_id: int = None,
32
+ bos_token_id: int = 1,
33
+ eos_token_id: int = 2,
34
+ tie_word_embeddings: bool = False,
35
+ initializer_range: float = 0.006,
36
+ fuse_norm: bool = True,
37
+ fuse_swiglu: bool = True,
38
+ fuse_cross_entropy: bool = True,
39
+ vocab_size: int = 32000,
40
+ **kwargs
41
+ ):
42
+ self.hidden_size = hidden_size
43
+ self.num_hidden_layers = num_hidden_layers
44
+ self.attn_mode = attn_mode
45
+ self.num_heads = num_heads
46
+ self.expand_ratio = expand_ratio
47
+ self.use_short_conv = use_short_conv
48
+ self.conv_size = conv_size
49
+ self.max_position_embeddings = max_position_embeddings
50
+ self.gate_low_rank_dim = gate_low_rank_dim
51
+ self.hidden_ratio = hidden_ratio
52
+ self.intermediate_size = intermediate_size
53
+ self.hidden_act = hidden_act
54
+ self.elementwise_affine = elementwise_affine
55
+ self.norm_eps = norm_eps
56
+ self.attn = attn
57
+ self.use_cache = use_cache
58
+ self.initializer_range = initializer_range
59
+
60
+ self.fuse_norm = fuse_norm
61
+ self.fuse_swiglu = fuse_swiglu
62
+ self.fuse_cross_entropy = fuse_cross_entropy
63
+ self.vocab_size = vocab_size
64
+
65
+ if attn is not None:
66
+ if not isinstance(attn, Dict):
67
+ raise ValueError("attn must be a dictionary")
68
+ if 'layers' not in attn:
69
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
70
+ if 'num_heads' not in attn:
71
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
72
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
73
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
74
+ attn['window_size'] = attn.get('window_size', None)
75
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
76
+
77
+ super().__init__(
78
+ pad_token_id=pad_token_id,
79
+ bos_token_id=bos_token_id,
80
+ eos_token_id=eos_token_id,
81
+ tie_word_embeddings=tie_word_embeddings,
82
+ **kwargs,
83
+ )
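
The optional `attn` dictionary above switches the listed layer indices to standard softmax attention while the remaining layers keep the LightNet token mixer. Only `layers` and `num_heads` are mandatory; the other fields are filled with defaults, as in this minimal sketch (layer indices are illustrative):

    from fla.models.lightnet.configuration_lightnet import LightNetConfig

    # Layers 1 and 3 become softmax-attention layers; all other layers stay LightNet layers.
    config = LightNetConfig(attn={'layers': [1, 3], 'num_heads': 8})
    print(config.attn)
    # {'layers': [1, 3], 'num_heads': 8, 'num_kv_heads': 8, 'qkv_bias': False,
    #  'window_size': None, 'rope_theta': 10000.0}
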
fla/models/lightnet/modeling_lightnet.py ADDED
@@ -0,0 +1,410 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.lightnet import LightNetAttention
20
+ from fla.models.lightnet.configuration_lightnet import LightNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as LightNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class LightNetBlock(nn.Module):
33
+ def __init__(self, config: LightNetConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ max_position_embeddings=config.max_position_embeddings,
48
+ layer_idx=layer_idx
49
+ )
50
+ else:
51
+ self.attn = LightNetAttention(
52
+ mode=config.attn_mode,
53
+ hidden_size=config.hidden_size,
54
+ num_heads=config.num_heads,
55
+ expand_ratio=config.expand_ratio,
56
+ use_short_conv=config.use_short_conv,
57
+ conv_size=config.conv_size,
58
+ gate_low_rank_dim=config.gate_low_rank_dim,
59
+ elementwise_affine=config.elementwise_affine,
60
+ norm_eps=config.norm_eps,
61
+ layer_idx=layer_idx
62
+ )
63
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
64
+ self.mlp = LightNetMLP(
65
+ hidden_size=config.hidden_size,
66
+ hidden_ratio=config.hidden_ratio,
67
+ intermediate_size=config.intermediate_size,
68
+ hidden_act=config.hidden_act,
69
+ fuse_swiglu=config.fuse_swiglu
70
+ )
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.Tensor] = None,
76
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
77
+ use_cache: Optional[bool] = False,
78
+ output_attentions: Optional[bool] = False,
79
+ **kwargs: Unpack[Dict]
80
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
81
+ residual = hidden_states
82
+ hidden_states = self.attn_norm(hidden_states)
83
+ hidden_states, attentions, past_key_values = self.attn(
84
+ hidden_states=hidden_states,
85
+ attention_mask=attention_mask,
86
+ past_key_values=past_key_values,
87
+ use_cache=use_cache,
88
+ output_attentions=output_attentions,
89
+ **kwargs
90
+ )
91
+ if self.config.fuse_norm:
92
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
93
+ else:
94
+ hidden_states = residual + hidden_states
95
+ residual = hidden_states
96
+ hidden_states = self.mlp_norm(hidden_states)
97
+ hidden_states = self.mlp(hidden_states, **kwargs)
98
+ hidden_states = residual + hidden_states
99
+
100
+ outputs = (hidden_states, attentions, past_key_values)
101
+
102
+ return outputs
103
+
104
+
105
+ class LightNetPreTrainedModel(PreTrainedModel):
106
+
107
+ config_class = LightNetConfig
108
+ supports_gradient_checkpointing = True
109
+ _no_split_modules = ['LightNetBlock']
110
+ _supports_cache_class = True
111
+
112
+ def __init__(self, *inputs, **kwargs):
113
+ super().__init__(*inputs, **kwargs)
114
+
115
+ def _init_weights(
116
+ self,
117
+ module: nn.Module,
118
+ prenorm_residual_strategy: Optional[str] = 'rescale',
119
+ num_residuals_per_layer: int = 2,
120
+ ):
121
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
122
+ # Slightly different from the TF version which uses truncated_normal for initialization
123
+ # cf https://github.com/pytorch/pytorch/pull/5617
124
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
125
+ if module.bias is not None:
126
+ nn.init.zeros_(module.bias)
127
+ elif isinstance(module, nn.Embedding):
128
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
129
+ elif hasattr(module, 'reset_parameters'):
130
+ module.reset_parameters()
131
+
132
+ if prenorm_residual_strategy is not None:
133
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
134
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
135
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
136
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
137
+ #
138
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
139
+ p = None
140
+ if hasattr(module, 'o_proj'):
141
+ p = module.o_proj.weight
142
+ elif hasattr(module, 'down_proj'):
143
+ p = module.down_proj.weight
144
+ if p is not None:
145
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
146
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
147
+ # We need to reinit p since this code could be called multiple times
148
+ # Having just p *= scale would repeatedly scale it down
149
+ if prenorm_residual_strategy == 'rescale':
150
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
151
+ with torch.no_grad():
152
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
153
+ elif prenorm_residual_strategy == 'zero':
154
+ nn.init.zeros_(p)
155
+ else:
156
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
157
+
158
+
159
+ class LightNetModel(LightNetPreTrainedModel):
160
+
161
+ def __init__(self, config: LightNetConfig):
162
+ super().__init__(config)
163
+ self.padding_idx = config.pad_token_id
164
+ self.vocab_size = config.vocab_size
165
+
166
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
167
+ self.layers = nn.ModuleList([LightNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
168
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
169
+
170
+ self.gradient_checkpointing = False
171
+
172
+ self.post_init()
173
+
174
+ def get_input_embeddings(self):
175
+ return self.embeddings
176
+
177
+ def set_input_embeddings(self, value):
178
+ self.embeddings = value
179
+
180
+ def forward(
181
+ self,
182
+ input_ids: Optional[torch.LongTensor] = None,
183
+ attention_mask: Optional[torch.Tensor] = None, # noqa
184
+ inputs_embeds: Optional[torch.FloatTensor] = None,
185
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
186
+ use_cache: Optional[bool] = None,
187
+ output_attentions: Optional[bool] = None,
188
+ output_hidden_states: Optional[bool] = None,
189
+ return_dict: Optional[bool] = None,
190
+ **kwargs: Unpack[Dict]
191
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
192
+ if output_attentions:
193
+ warnings.warn("`LightNetModel` does not support `output_attentions` for now, setting it to `False`.")
194
+ output_attentions = False
195
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
196
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
197
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
198
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
199
+
200
+ # retrieve input_ids and inputs_embeds
201
+ if input_ids is not None and inputs_embeds is not None:
202
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
203
+ if input_ids is None and inputs_embeds is None:
204
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
205
+
206
+ if inputs_embeds is None:
207
+ inputs_embeds = self.embeddings(input_ids)
208
+ hidden_states = inputs_embeds
209
+
210
+ if use_cache and not isinstance(past_key_values, Cache):
211
+ past_key_values = Cache.from_legacy_cache(past_key_values)
212
+
213
+ if self.gradient_checkpointing and self.training and use_cache:
214
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
215
+ use_cache = False
216
+
217
+ all_hidden_states = () if output_hidden_states else None
218
+ all_attns = () if output_attentions else None
219
+
220
+ for i, layer in enumerate(self.layers):
221
+ if output_hidden_states:
222
+ all_hidden_states += (hidden_states,)
223
+
224
+ if self.gradient_checkpointing and self.training:
225
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
226
+ layer.__call__,
227
+ hidden_states,
228
+ attention_mask,
229
+ past_key_values,
230
+ use_cache,
231
+ output_attentions,
232
+ **kwargs
233
+ )
234
+ else:
235
+ hidden_states, attentions, past_key_values = layer(
236
+ hidden_states,
237
+ attention_mask=attention_mask,
238
+ past_key_values=past_key_values,
239
+ use_cache=use_cache,
240
+ output_attentions=output_attentions,
241
+ **kwargs
242
+ )
243
+
244
+ if output_attentions:
245
+ all_attns += (attentions,)
246
+
247
+ hidden_states = self.norm(hidden_states)
248
+
249
+ # add hidden states from the last decoder layer
250
+ if output_hidden_states:
251
+ all_hidden_states += (hidden_states,)
252
+
253
+ if not return_dict:
254
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
255
+ return BaseModelOutputWithPast(
256
+ last_hidden_state=hidden_states,
257
+ past_key_values=past_key_values,
258
+ hidden_states=all_hidden_states,
259
+ attentions=all_attns
260
+ )
261
+
262
+
263
+ class LightNetForCausalLM(LightNetPreTrainedModel, GenerationMixin):
264
+
265
+ _tied_weights_keys = ["lm_head.weight"]
266
+
267
+ def __init__(self, config):
268
+ super().__init__(config)
269
+ self.model = LightNetModel(config)
270
+ self.vocab_size = config.vocab_size
271
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
272
+ self.criterion = None
273
+
274
+ # Initialize weights and apply final processing
275
+ self.post_init()
276
+
277
+ def get_input_embeddings(self):
278
+ return self.model.embeddings
279
+
280
+ def set_input_embeddings(self, value):
281
+ self.model.embeddings = value
282
+
283
+ def get_output_embeddings(self):
284
+ return self.lm_head
285
+
286
+ def set_output_embeddings(self, new_embeddings):
287
+ self.lm_head = new_embeddings
288
+
289
+ def set_decoder(self, decoder):
290
+ self.model = decoder
291
+
292
+ def get_decoder(self):
293
+ return self.model
294
+
295
+ def generate(self, *args, **kwargs):
296
+ try:
297
+ return super().generate(*args, **kwargs)
298
+ except AttributeError as exception:
299
+ if 'past_key_values' in str(exception):
300
+ raise AttributeError(
301
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
302
+ f"which is not supported for {self.__class__.__name__}. "
303
+ f"Try another generation strategy instead. "
304
+ f"For the available generation strategies, check this doc: "
305
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
306
+ )
307
+ else:
308
+ raise exception
309
+
310
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
311
+ def prepare_inputs_for_generation(
312
+ self,
313
+ input_ids: torch.LongTensor = None,
314
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ inputs_embeds: Optional[torch.Tensor] = None,
317
+ use_cache: bool = True,
318
+ logits_to_keep: Optional[int] = None,
319
+ **kwargs: Unpack[Dict]
320
+ ):
321
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
322
+ if past_key_values is not None and len(past_key_values) > 0:
323
+ input_ids = input_ids[:, -1:]
324
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
325
+ if inputs_embeds is not None and len(past_key_values) == 0:
326
+ model_inputs = {'inputs_embeds': inputs_embeds}
327
+ else:
328
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
329
+ # recompiles graphs as the stride of the inputs is a guard.
330
+ # Ref: https://github.com/huggingface/transformers/pull/29114
331
+ # TODO: use `next_tokens` directly instead.
332
+ model_inputs = {'input_ids': input_ids.contiguous()}
333
+
334
+ if logits_to_keep is not None:
335
+ model_inputs['logits_to_keep'] = logits_to_keep
336
+
337
+ model_inputs.update({
338
+ 'past_key_values': past_key_values,
339
+ 'use_cache': use_cache,
340
+ 'attention_mask': attention_mask,
341
+ })
342
+ return model_inputs
343
+
344
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
345
+ def forward(
346
+ self,
347
+ input_ids: torch.LongTensor = None,
348
+ attention_mask: Optional[torch.Tensor] = None,
349
+ inputs_embeds: Optional[torch.Tensor] = None,
350
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
351
+ labels: Optional[torch.LongTensor] = None,
352
+ use_cache: Optional[bool] = None,
353
+ output_attentions: Optional[bool] = None,
354
+ output_hidden_states: Optional[bool] = None,
355
+ return_dict: Optional[bool] = None,
356
+ logits_to_keep: Optional[int] = 0,
357
+ **kwargs: Unpack[Dict]
358
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
359
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
360
+ output_hidden_states = (
361
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
362
+ )
363
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
364
+
365
+ outputs = self.model(
366
+ input_ids=input_ids,
367
+ attention_mask=attention_mask,
368
+ inputs_embeds=inputs_embeds,
369
+ past_key_values=past_key_values,
370
+ use_cache=use_cache,
371
+ output_attentions=output_attentions,
372
+ output_hidden_states=output_hidden_states,
373
+ return_dict=return_dict,
374
+ **kwargs
375
+ )
376
+
377
+ hidden_states = outputs[0]
378
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
379
+
380
+ loss, logits = None, None
381
+ if not fuse_linear_and_cross_entropy or labels is None:
382
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
383
+ if labels is not None:
384
+ if getattr(self, 'criterion', None) is None:
385
+ if fuse_linear_and_cross_entropy:
386
+ criterion = FusedLinearCrossEntropyLoss()
387
+ elif self.config.fuse_cross_entropy:
388
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
389
+ else:
390
+ criterion = nn.CrossEntropyLoss()
391
+ else:
392
+ criterion = self.criterion
393
+ labels = labels.to(hidden_states.device)
394
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
395
+ if fuse_linear_and_cross_entropy:
396
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
397
+ else:
398
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
399
+
400
+ if not return_dict:
401
+ output = (logits,) + outputs[1:]
402
+ return (loss,) + output if loss is not None else output
403
+
404
+ return CausalLMOutputWithPast(
405
+ loss=loss,
406
+ logits=logits,
407
+ past_key_values=outputs.past_key_values,
408
+ hidden_states=outputs.hidden_states,
409
+ attentions=outputs.attentions,
410
+ )
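
All of the `*ForCausalLM.forward` methods in this commit align labels with logits the same way: the labels are shifted left by one position and the vacated last slot is filled with the criterion's `ignore_index`, so the logits at position t are scored against the token at position t + 1. A small standalone illustration (-100 is the default `ignore_index` of `nn.CrossEntropyLoss`):

    import torch

    ignore_index = -100
    labels = torch.tensor([[11, 12, 13, 14]])
    shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), 1)
    print(shifted)   # tensor([[  12,   13,   14, -100]])
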
fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-312.pyc ADDED
Binary file (3.65 kB).
 
fla/models/linear_attn/configuration_linear_attn.py ADDED
@@ -0,0 +1,91 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class LinearAttentionConfig(PretrainedConfig):
9
+
10
+ model_type = 'linear_attn'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ attn_mode: str = "fused_chunk",
16
+ hidden_size: int = 2048,
17
+ expand_k: int = 1,
18
+ expand_v: int = 1,
19
+ hidden_ratio: Optional[int] = 4,
20
+ intermediate_size: Optional[int] = None,
21
+ num_hidden_layers: int = 24,
22
+ num_heads: int = 4,
23
+ num_kv_heads: Optional[int] = None,
24
+ feature_map: str = "elementwise_product",
25
+ tie_feature_map_qk: bool = False,
26
+ norm_q: bool = False,
27
+ norm_k: bool = False,
28
+ norm_feature_map: bool = False,
29
+ hidden_act: str = "swish",
30
+ max_position_embeddings: int = 2048,
31
+ elementwise_affine: Optional[bool] = True,
32
+ norm_eps: float = 1e-6,
33
+ attn: Optional[Dict] = None,
34
+ use_cache: bool = True,
35
+ pad_token_id: int = None,
36
+ bos_token_id: int = 1,
37
+ eos_token_id: int = 2,
38
+ tie_word_embeddings: bool = False,
39
+ initializer_range: float = 0.006,
40
+ fuse_norm: bool = True,
41
+ fuse_swiglu: bool = True,
42
+ fuse_cross_entropy: bool = True,
43
+ vocab_size: int = 32000,
44
+ **kwargs
45
+ ):
46
+ self.attn_mode = attn_mode
47
+ self.hidden_size = hidden_size
48
+ self.expand_k = expand_k
49
+ self.expand_v = expand_v
50
+ self.hidden_ratio = hidden_ratio
51
+ self.intermediate_size = intermediate_size
52
+ self.num_hidden_layers = num_hidden_layers
53
+ self.num_heads = num_heads
54
+ self.num_kv_heads = num_kv_heads
55
+ self.feature_map = feature_map
56
+ self.tie_feature_map_qk = tie_feature_map_qk
57
+ self.norm_q = norm_q
58
+ self.norm_k = norm_k
59
+ self.norm_feature_map = norm_feature_map
60
+ self.hidden_act = hidden_act
61
+ self.max_position_embeddings = max_position_embeddings
62
+ self.elementwise_affine = elementwise_affine
63
+ self.norm_eps = norm_eps
64
+ self.attn = attn
65
+ self.use_cache = use_cache
66
+ self.initializer_range = initializer_range
67
+
68
+ self.fuse_norm = fuse_norm
69
+ self.fuse_swiglu = fuse_swiglu
70
+ self.fuse_cross_entropy = fuse_cross_entropy
71
+ self.vocab_size = vocab_size
72
+
73
+ if attn is not None:
74
+ if not isinstance(attn, Dict):
75
+ raise ValueError("attn must be a dictionary")
76
+ if 'layers' not in attn:
77
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
78
+ if 'num_heads' not in attn:
79
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
80
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
81
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
82
+ attn['window_size'] = attn.get('window_size', None)
83
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
84
+
85
+ super().__init__(
86
+ pad_token_id=pad_token_id,
87
+ bos_token_id=bos_token_id,
88
+ eos_token_id=eos_token_id,
89
+ tie_word_embeddings=tie_word_embeddings,
90
+ **kwargs,
91
+ )
fla/models/linear_attn/modeling_linear_attn.py ADDED
@@ -0,0 +1,406 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.linear_attn import LinearAttention
20
+ from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as LinearAttentionMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class LinearAttentionBlock(nn.Module):
30
+ def __init__(self, config: LinearAttentionConfig, layer_idx: int):
31
+ super().__init__()
32
+
33
+ self.config = config
34
+ self.layer_idx = layer_idx
35
+
36
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
37
+ if config.attn is not None and layer_idx in config.attn['layers']:
38
+ self.attn = Attention(
39
+ hidden_size=config.hidden_size,
40
+ num_heads=config.attn['num_heads'],
41
+ num_kv_heads=config.attn['num_kv_heads'],
42
+ qkv_bias=config.attn['qkv_bias'],
43
+ window_size=config.attn['window_size'],
44
+ rope_theta=config.attn['rope_theta'],
45
+ max_position_embeddings=config.max_position_embeddings,
46
+ layer_idx=layer_idx
47
+ )
48
+ else:
49
+ self.attn = LinearAttention(
50
+ mode=config.attn_mode,
51
+ hidden_size=config.hidden_size,
52
+ expand_k=config.expand_k,
53
+ expand_v=config.expand_v,
54
+ num_heads=config.num_heads,
55
+ num_kv_heads=config.num_kv_heads,
56
+ feature_map=config.feature_map,
57
+ tie_feature_map_qk=config.tie_feature_map_qk,
58
+ norm_q=config.norm_q,
59
+ norm_k=config.norm_k,
60
+ do_feature_map_norm=config.norm_feature_map,
61
+ elementwise_affine=config.elementwise_affine,
62
+ norm_eps=config.norm_eps,
63
+ layer_idx=layer_idx
64
+ )
65
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
66
+ self.mlp = LinearAttentionMLP(
67
+ hidden_size=config.hidden_size,
68
+ hidden_ratio=config.hidden_ratio,
69
+ intermediate_size=config.intermediate_size,
70
+ hidden_act=config.hidden_act,
71
+ fuse_swiglu=config.fuse_swiglu
72
+ )
73
+
74
+ def forward(
75
+ self,
76
+ hidden_states: torch.Tensor,
77
+ attention_mask: Optional[torch.Tensor] = None,
78
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
79
+ use_cache: Optional[bool] = False,
80
+ output_attentions: Optional[bool] = False,
81
+ **kwargs,
82
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
83
+ residual = hidden_states
84
+ # currently not supported
85
+ attentions, past_key_values = None, None
86
+ hidden_states = self.attn_norm(hidden_states)
87
+ hidden_states = self.attn(hidden_states=hidden_states, **kwargs)
88
+ if self.config.fuse_norm:
89
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
90
+ else:
91
+ hidden_states = residual + hidden_states
92
+ residual = hidden_states
93
+ hidden_states = self.mlp_norm(hidden_states)
94
+ hidden_states = self.mlp(hidden_states, **kwargs)
95
+ hidden_states = residual + hidden_states
96
+
97
+ outputs = (hidden_states, attentions, past_key_values)
98
+
99
+ return outputs
100
+
101
+
102
+ class LinearAttentionPreTrainedModel(PreTrainedModel):
103
+
104
+ config_class = LinearAttentionConfig
105
+ base_model_prefix = 'model'
106
+ supports_gradient_checkpointing = True
107
+ _no_split_modules = ['LinearAttentionBlock']
108
+ _supports_cache_class = True
109
+
110
+ def __init__(self, *inputs, **kwargs):
111
+ super().__init__(*inputs, **kwargs)
112
+
113
+ def _init_weights(
114
+ self,
115
+ module: nn.Module,
116
+ prenorm_residual_strategy: Optional[str] = 'rescale',
117
+ num_residuals_per_layer: int = 2,
118
+ ):
119
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
120
+ # Slightly different from the TF version which uses truncated_normal for initialization
121
+ # cf https://github.com/pytorch/pytorch/pull/5617
122
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
123
+ if module.bias is not None:
124
+ nn.init.zeros_(module.bias)
125
+ elif isinstance(module, nn.Embedding):
126
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
127
+ elif hasattr(module, 'reset_parameters'):
128
+ module.reset_parameters()
129
+
130
+ if prenorm_residual_strategy is not None:
131
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
132
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
133
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
134
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
135
+ #
136
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
137
+ p = None
138
+ if hasattr(module, 'o_proj'):
139
+ p = module.o_proj.weight
140
+ elif hasattr(module, 'down_proj'):
141
+ p = module.down_proj.weight
142
+ if p is not None:
143
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
144
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
145
+ # We need to reinit p since this code could be called multiple times
146
+ # Having just p *= scale would repeatedly scale it down
147
+ if prenorm_residual_strategy == 'rescale':
148
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
149
+ with torch.no_grad():
150
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
151
+ elif prenorm_residual_strategy == 'zero':
152
+ nn.init.zeros_(p)
153
+ else:
154
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
155
+
156
+
157
+ class LinearAttentionModel(LinearAttentionPreTrainedModel):
158
+
159
+ def __init__(self, config: LinearAttentionConfig):
160
+ super().__init__(config)
161
+ self.padding_idx = config.pad_token_id
162
+ self.vocab_size = config.vocab_size
163
+
164
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
165
+ self.layers = nn.ModuleList([LinearAttentionBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
166
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
167
+
168
+ self.gradient_checkpointing = False
169
+
170
+ self.post_init()
171
+
172
+ def get_input_embeddings(self):
173
+ return self.embeddings
174
+
175
+ def set_input_embeddings(self, value):
176
+ self.embeddings = value
177
+
178
+ def forward(
179
+ self,
180
+ input_ids: Optional[torch.LongTensor] = None,
181
+ attention_mask: Optional[torch.Tensor] = None, # noqa
182
+ inputs_embeds: Optional[torch.FloatTensor] = None,
183
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
184
+ use_cache: Optional[bool] = None,
185
+ output_attentions: Optional[bool] = None,
186
+ output_hidden_states: Optional[bool] = None,
187
+ return_dict: Optional[bool] = None
188
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
189
+ if output_attentions:
190
+ warnings.warn(
191
+ "`LinearAttentionModel` does not support output attention weights now, "
192
+ "so `output_attentions` is set to `False`."
193
+ )
194
+ output_attentions = False
195
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
196
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
197
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
198
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
199
+
200
+ # retrieve input_ids and inputs_embeds
201
+ if input_ids is not None and inputs_embeds is not None:
202
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
203
+ if input_ids is None and inputs_embeds is None:
204
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
205
+
206
+ if inputs_embeds is None:
207
+ inputs_embeds = self.embeddings(input_ids)
208
+ hidden_states = inputs_embeds
209
+
210
+ if use_cache and not isinstance(past_key_values, Cache):
211
+ past_key_values = Cache.from_legacy_cache(past_key_values)
212
+
213
+ if self.gradient_checkpointing and self.training and use_cache:
214
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
215
+ use_cache = False
216
+
217
+ all_hidden_states = () if output_hidden_states else None
218
+ all_attns = () if output_attentions else None
219
+
220
+ for i, layer in enumerate(self.layers):
221
+ if output_hidden_states:
222
+ all_hidden_states += (hidden_states,)
223
+
224
+ if self.gradient_checkpointing and self.training:
225
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
226
+ layer.__call__,
227
+ hidden_states,
228
+ attention_mask,
229
+ past_key_values,
230
+ use_cache,
231
+ output_attentions,
232
+ )
233
+ else:
234
+ hidden_states, attentions, past_key_values = layer(
235
+ hidden_states,
236
+ attention_mask=attention_mask,
237
+ past_key_values=past_key_values,
238
+ use_cache=use_cache,
239
+ output_attentions=output_attentions
240
+ )
241
+
242
+ if output_attentions:
243
+ all_attns += (attentions,)
244
+
245
+ hidden_states = self.norm(hidden_states)
246
+
247
+ # add hidden states from the last decoder layer
248
+ if output_hidden_states:
249
+ all_hidden_states += (hidden_states,)
250
+
251
+ if not return_dict:
252
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
253
+ return BaseModelOutputWithPast(
254
+ last_hidden_state=hidden_states,
255
+ past_key_values=past_key_values,
256
+ hidden_states=all_hidden_states,
257
+ attentions=all_attns
258
+ )
259
+
260
+
261
+ class LinearAttentionForCausalLM(LinearAttentionPreTrainedModel, GenerationMixin):
262
+
263
+ _tied_weights_keys = ["lm_head.weight"]
264
+
265
+ def __init__(self, config):
266
+ super().__init__(config)
267
+ self.model = LinearAttentionModel(config)
268
+ self.vocab_size = config.vocab_size
269
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
270
+ self.criterion = None
271
+
272
+ # Initialize weights and apply final processing
273
+ self.post_init()
274
+
275
+ def get_input_embeddings(self):
276
+ return self.model.embeddings
277
+
278
+ def set_input_embeddings(self, value):
279
+ self.model.embeddings = value
280
+
281
+ def get_output_embeddings(self):
282
+ return self.lm_head
283
+
284
+ def set_output_embeddings(self, new_embeddings):
285
+ self.lm_head = new_embeddings
286
+
287
+ def set_decoder(self, decoder):
288
+ self.model = decoder
289
+
290
+ def get_decoder(self):
291
+ return self.model
292
+
293
+ def generate(self, *args, **kwargs):
294
+ try:
295
+ return super().generate(*args, **kwargs)
296
+ except AttributeError as exception:
297
+ if 'past_key_values' in str(exception):
298
+ raise AttributeError(
299
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
300
+ f"which is not supported for {self.__class__.__name__}. "
301
+ f"Try another generation strategy instead. "
302
+ f"For the available generation strategies, check this doc: "
303
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
304
+ )
305
+ else:
306
+ raise exception
307
+
308
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
309
+ def prepare_inputs_for_generation(
310
+ self,
311
+ input_ids: torch.LongTensor = None,
312
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
313
+ attention_mask: Optional[torch.Tensor] = None,
314
+ inputs_embeds: Optional[torch.Tensor] = None,
315
+ use_cache: bool = True,
316
+ logits_to_keep: Optional[int] = None,
317
+ **kwargs
318
+ ):
319
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
320
+ if past_key_values is not None and len(past_key_values) > 0:
321
+ input_ids = input_ids[:, -1:]
322
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
323
+ if inputs_embeds is not None and len(past_key_values) == 0:
324
+ model_inputs = {'inputs_embeds': inputs_embeds}
325
+ else:
326
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
327
+ # recompiles graphs as the stride of the inputs is a guard.
328
+ # Ref: https://github.com/huggingface/transformers/pull/29114
329
+ # TODO: use `next_tokens` directly instead.
330
+ model_inputs = {'input_ids': input_ids.contiguous()}
331
+
332
+ if logits_to_keep is not None:
333
+ model_inputs['logits_to_keep'] = logits_to_keep
334
+
335
+ model_inputs.update({
336
+ 'past_key_values': past_key_values,
337
+ 'use_cache': use_cache,
338
+ 'attention_mask': attention_mask,
339
+ })
340
+ return model_inputs
341
+
342
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
343
+ def forward(
344
+ self,
345
+ input_ids: torch.LongTensor = None,
346
+ attention_mask: Optional[torch.Tensor] = None,
347
+ inputs_embeds: Optional[torch.Tensor] = None,
348
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
349
+ labels: Optional[torch.LongTensor] = None,
350
+ use_cache: Optional[bool] = None,
351
+ output_attentions: Optional[bool] = None,
352
+ output_hidden_states: Optional[bool] = None,
353
+ return_dict: Optional[bool] = None,
354
+ logits_to_keep: Optional[int] = 0
355
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
356
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
357
+ output_hidden_states = (
358
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
359
+ )
360
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
361
+
362
+ outputs = self.model(
363
+ input_ids=input_ids,
364
+ attention_mask=attention_mask,
365
+ inputs_embeds=inputs_embeds,
366
+ past_key_values=past_key_values,
367
+ use_cache=use_cache,
368
+ output_attentions=output_attentions,
369
+ output_hidden_states=output_hidden_states,
370
+ return_dict=return_dict
371
+ )
372
+
373
+ hidden_states = outputs[0]
374
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
375
+
376
+ loss, logits = None, None
377
+ if not fuse_linear_and_cross_entropy or labels is None:
378
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
379
+ if labels is not None:
380
+ if getattr(self, 'criterion', None) is None:
381
+ if fuse_linear_and_cross_entropy:
382
+ criterion = FusedLinearCrossEntropyLoss()
383
+ elif self.config.fuse_cross_entropy:
384
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
385
+ else:
386
+ criterion = nn.CrossEntropyLoss()
387
+ else:
388
+ criterion = self.criterion
389
+ labels = labels.to(hidden_states.device)
390
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
391
+ if fuse_linear_and_cross_entropy:
392
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
393
+ else:
394
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
395
+
396
+ if not return_dict:
397
+ output = (logits,) + outputs[1:]
398
+ return (loss,) + output if loss is not None else output
399
+
400
+ return CausalLMOutputWithPast(
401
+ loss=loss,
402
+ logits=logits,
403
+ past_key_values=outputs.past_key_values,
404
+ hidden_states=outputs.hidden_states,
405
+ attentions=outputs.attentions,
406
+ )
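A note on the loss path in the `forward` above: rather than slicing `logits[..., :-1, :]`, the labels are shifted left by one position and the last slot is filled with `ignore_index`, so the fused-linear and plain cross-entropy branches share the same indexing. A minimal sketch of that shift with made-up toy shapes (illustration only, not part of the repository):

```python
import torch
import torch.nn as nn

# Toy shapes; the real model produces (batch, seq_len, vocab_size) logits from lm_head.
batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

criterion = nn.CrossEntropyLoss()  # ignore_index defaults to -100

# Shift labels left by one and pad the final position with ignore_index,
# so position t is trained to predict token t + 1 without slicing the logits.
shifted = torch.cat(
    (labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), dim=1
)
loss = criterion(logits.view(shifted.numel(), -1), shifted.view(-1))
print(loss)
```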
fla/models/mamba/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.mamba.configuration_mamba import MambaConfig
6
+ from fla.models.mamba.modeling_mamba import MambaBlock, MambaForCausalLM, MambaModel
7
+
8
+ AutoConfig.register(MambaConfig.model_type, MambaConfig, True)
9
+ AutoModel.register(MambaConfig, MambaModel, True)
10
+ AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True)
11
+
12
+
13
+ __all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock']
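Once this module is imported, the registrations above expose the fla Mamba classes through the standard `transformers` Auto API. A small usage sketch (assumes `fla` is installed; the config values are arbitrary and only for illustration):

```python
from transformers import AutoModelForCausalLM

import fla.models.mamba  # noqa: F401  # runs the AutoConfig/AutoModel registrations above
from fla.models.mamba import MambaConfig

# Arbitrary small configuration, purely for demonstration.
config = MambaConfig(vocab_size=1000, hidden_size=256, num_hidden_layers=2)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # MambaForCausalLM
```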
fla/models/mamba/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (717 Bytes). View file
 
fla/models/mamba/configuration_mamba.py ADDED
@@ -0,0 +1,166 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """MAMBA configuration"""
16
+
17
+ import math
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+
21
+
22
+ class MambaConfig(PretrainedConfig):
23
+ """
24
+ This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA
25
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
26
+ defaults will yield a similar configuration to that of the MAMBA
27
+ [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture.
28
+
29
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
30
+ documentation from [`PretrainedConfig`] for more information.
31
+
32
+
33
+ Args:
34
+ vocab_size (`int`, *optional*):
35
+ Vocabulary size of the Mamba model.
36
+ hidden_size (`int`, *optional*):
37
+ Dimensionality of the embeddings and hidden states. Default: 2048.
38
+ state_size (`int`, *optional*):
39
+ Shape of the state space latents. Default: 16.
40
+ num_hidden_layers (`int`, *optional*):
41
+ Number of hidden layers in the model. Default: 48.
42
+ layer_norm_epsilon (`float`, *optional*):
43
+ The epsilon to use in the layer normalization layers. Default: 1e-5.
44
+ pad_token_id (`int`, *optional*):
45
+ Padding token id. Default: 0.
46
+ bos_token_id (`int`, *optional*):
47
+ The id of the beginning of sentence token in the vocabulary. Default: 1.
48
+ eos_token_id (`int`, *optional*):
49
+ The id of the end of sentence token in the vocabulary. Default: 2.
50
+ expand (`int`, *optional*):
51
+ Expanding factor used to determine the intermediate size. Default: 2.
52
+ conv_kernel (`int`, *optional*):
53
+ Size of the convolution kernel. Default: 4.
54
+ use_bias (`bool`, *optional*):
55
+ Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block. Default: `False`.
56
+ use_conv_bias (`bool`, *optional*):
57
+ Whether or not to use bias in the convolution layer of the mixer block. Default: `True`.
58
+ hidden_act (`str`, *optional*):
59
+ The non-linear activation function (function or string) in the decoder. Default: `"silu"`.
60
+ initializer_range (`float`, *optional*):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices. Default: 0.1.
62
+ residual_in_fp32 (`bool`, *optional*):
63
+ Whether or not residuals should be in `float32`.
64
+ If set to `False` residuals will keep the same `dtype` as the rest of the model. Default: `False`.
65
+ time_step_rank (`Union[int,str]`, *optional*):
66
+ Rank of the discretization projection matrix.
67
+ `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`. Default: `"auto"`.
68
+ time_step_scale (`float`, *optional*):
69
+ Scale used to scale `dt_proj.bias`. Default: 1.0.
70
+ time_step_min (`float`, *optional*):
71
+ Minimum `time_step` used to bound `dt_proj.bias`. Default: 0.001.
72
+ time_step_max (`float`, *optional*):
73
+ Maximum `time_step` used to bound `dt_proj.bias`. Default: 0.1.
74
+ time_step_init_scheme (`float`, *optional*):
75
+ Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]`. Default: `"random"`.
76
+ time_step_floor (`float`, *optional*):
77
+ Minimum clamping value of the `dt_proj.bias` layer initialization. Default: 0.0001.
78
+ window_size (`int`, *optional*):
79
+ The window size used for sliding window attention. Default: 2048.
80
+ rescale_prenorm_residual (`bool`, *optional*):
81
+ Whether or not to rescale `out_proj` weights when initializing. Default: `False`.
82
+ use_cache (`bool`, *optional*):
83
+ Whether or not the cache should be used. Default: `True`.
84
+
85
+
86
+ Example:
87
+
88
+ ```python
89
+ >>> from transformers import MambaConfig, MambaModel
90
+
91
+ >>> # Initializing a Mamba configuration
92
+ >>> configuration = MambaConfig()
93
+
94
+ >>> # Initializing a model (with random weights) from the configuration
95
+ >>> model = MambaModel(configuration)
96
+
97
+ >>> # Accessing the model configuration
98
+ >>> configuration = model.config
99
+ ```"""
100
+
101
+ model_type = "mamba"
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size: int = 32000,
106
+ hidden_size: int = 2048,
107
+ state_size: int = 16,
108
+ num_hidden_layers: int = 48,
109
+ layer_norm_epsilon=1e-5,
110
+ pad_token_id: int = 0,
111
+ bos_token_id: int = 1,
112
+ eos_token_id: int = 2,
113
+ expand: int = 2,
114
+ conv_kernel: int = 4,
115
+ use_bias: bool = False,
116
+ use_conv_bias: bool = True,
117
+ hidden_act: str = "silu",
118
+ initializer_range: float = 0.1,
119
+ residual_in_fp32: bool = False,
120
+ time_step_rank: str = "auto",
121
+ time_step_scale: float = 1.0,
122
+ time_step_min: float = 0.001,
123
+ time_step_max: float = 0.1,
124
+ time_step_init_scheme: str = "random",
125
+ time_step_floor: float = 1e-4,
126
+ rescale_prenorm_residual: bool = False,
127
+ use_cache: bool = True,
128
+ fuse_norm: bool = True,
129
+ fuse_cross_entropy: bool = True,
130
+ tie_word_embeddings: bool = False,
131
+ **kwargs,
132
+ ):
133
+ self.vocab_size = vocab_size
134
+ self.hidden_size = hidden_size
135
+ self.state_size = state_size
136
+ self.num_hidden_layers = num_hidden_layers
137
+ self.layer_norm_epsilon = layer_norm_epsilon
138
+ self.conv_kernel = conv_kernel
139
+ self.expand = expand
140
+ self.intermediate_size = int(expand * self.hidden_size)
141
+ self.bos_token_id = bos_token_id
142
+ self.eos_token_id = eos_token_id
143
+ self.pad_token_id = pad_token_id
144
+ self.use_bias = use_bias
145
+ self.use_conv_bias = use_conv_bias
146
+ self.hidden_act = hidden_act
147
+ self.initializer_range = initializer_range
148
+ self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
149
+ self.time_step_scale = time_step_scale
150
+ self.time_step_min = time_step_min
151
+ self.time_step_max = time_step_max
152
+ self.time_step_init_scheme = time_step_init_scheme
153
+ self.time_step_floor = time_step_floor
154
+ self.rescale_prenorm_residual = rescale_prenorm_residual
155
+ self.residual_in_fp32 = residual_in_fp32
156
+ self.use_cache = use_cache
157
+ self.fuse_norm = fuse_norm
158
+ self.fuse_cross_entropy = fuse_cross_entropy
159
+
160
+ super().__init__(
161
+ bos_token_id=bos_token_id,
162
+ eos_token_id=eos_token_id,
163
+ pad_token_id=pad_token_id,
164
+ tie_word_embeddings=tie_word_embeddings,
165
+ **kwargs
166
+ )
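Two fields of this configuration are derived rather than stored verbatim: `intermediate_size` is computed as `expand * hidden_size`, and `time_step_rank="auto"` resolves to `math.ceil(hidden_size / 16)`. A short sketch of that behaviour (assumes the `fla` package is importable):

```python
import math

from fla.models.mamba import MambaConfig

# Defaults: hidden_size=2048, expand=2, time_step_rank="auto".
config = MambaConfig()
assert config.intermediate_size == 2 * 2048           # expand * hidden_size
assert config.time_step_rank == math.ceil(2048 / 16)  # "auto" -> 128

# An explicit integer bypasses the "auto" rule.
config = MambaConfig(hidden_size=512, time_step_rank=32)
assert config.intermediate_size == 1024
assert config.time_step_rank == 32
```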
fla/models/mamba/modeling_mamba.py ADDED
@@ -0,0 +1,843 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch MAMBA model."""
16
+
17
+ import math
18
+ import warnings
19
+ from dataclasses import dataclass
20
+ from typing import Any, Dict, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from transformers.activations import ACT2FN
26
+ from transformers.configuration_utils import PretrainedConfig
27
+ from transformers.generation import GenerationMixin
28
+ from transformers.modeling_utils import PreTrainedModel
29
+ from transformers.utils import ModelOutput, logging
30
+ from transformers.utils.deprecation import deprecate_kwarg
31
+
32
+ from fla.models.mamba.configuration_mamba import MambaConfig
33
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+
38
+ with warnings.catch_warnings():
39
+ warnings.simplefilter('ignore')
40
+ try:
41
+ from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
42
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
43
+ except ImportError:
44
+ selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
45
+
46
+ try:
47
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
48
+ except ImportError:
49
+ causal_conv1d_update, causal_conv1d_fn = None, None
50
+ is_fast_path_available = all((
51
+ selective_state_update,
52
+ selective_scan_fn,
53
+ causal_conv1d_fn,
54
+ causal_conv1d_update,
55
+ mamba_inner_fn
56
+ ))
57
+
58
+
59
+ class MambaCache:
60
+ """
61
+ Cache for mamba model which does not have attention mechanism and key value states.
62
+
63
+ Arguments:
64
+ config (`PretrainedConfig`):
65
+ The configuration file defining the shape-related attributes required to initialize the static cache.
66
+ batch_size (`int`):
67
+ The batch size with which the model will be used. Note that a new instance must be instantiated if a
68
+ smaller batch size is used.
69
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
70
+ The default `dtype` to use when initializing the layer.
71
+ device (`torch.device` or `str`, *optional*):
72
+ The device on which the cache should be initialized. Should be the same as the layer.
73
+
74
+ Attributes:
75
+ dtype: (`torch.dtype`):
76
+ The default `dtype` used to initialize the cache.
77
+ intermediate_size: (`int`):
78
+ Model's intermediate_size taken from config.
79
+ ssm_state_size: (`int`):
80
+ Model's state_size taken from config.
81
+ conv_kernel_size: (`int`):
82
+ Model's convolution kernel size taken from config
83
+ conv_states: (`torch.Tensor`):
84
+ A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
85
+ ssm_states: (`torch.Tensor`):
86
+ A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states
87
+
88
+ Example:
89
+
90
+ ```python
91
+ >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
92
+
93
+ >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
94
+ >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
95
+
96
+ >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
97
+
98
+ >>> # Prepare a cache class and pass it to model's forward
99
+ >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
100
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
101
+ >>> outputs.past_key_values
102
+ MambaCache()
103
+ ```
104
+ """
105
+
106
+ # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
107
+ def __init__(
108
+ self,
109
+ config: PretrainedConfig,
110
+ batch_size: int = None,
111
+ dtype: torch.dtype = torch.float16,
112
+ device: Optional[Union[torch.device, str]] = None,
113
+ max_batch_size: Optional[int] = None,
114
+ ):
115
+ if max_batch_size is not None:
116
+ logger.warning_once(
117
+ f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
118
+ "v4.46. Use the more precisely named 'batch_size' argument instead."
119
+ )
120
+ self.dtype = dtype
121
+ self.batch_size = batch_size or max_batch_size
122
+ self.intermediate_size = config.intermediate_size
123
+ self.ssm_state_size = config.state_size
124
+ self.conv_kernel_size = config.conv_kernel
125
+
126
+ self.conv_states: torch.Tensor = torch.zeros(
127
+ config.num_hidden_layers,
128
+ self.batch_size,
129
+ self.intermediate_size,
130
+ self.conv_kernel_size,
131
+ device=device,
132
+ dtype=dtype,
133
+ )
134
+ self.ssm_states: torch.Tensor = torch.zeros(
135
+ config.num_hidden_layers,
136
+ self.batch_size,
137
+ self.intermediate_size,
138
+ self.ssm_state_size,
139
+ device=device,
140
+ dtype=dtype,
141
+ )
142
+
143
+ torch._dynamo.mark_static_address(self.conv_states)
144
+ torch._dynamo.mark_static_address(self.ssm_states)
145
+
146
+ def update_conv_state(
147
+ self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
148
+ ) -> torch.Tensor:
149
+ conv_state = self.conv_states[layer_idx]
150
+ cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
151
+
152
+ conv_state = conv_state.roll(shifts=-1, dims=-1)
153
+ conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
154
+ self.conv_states[layer_idx].zero_()
155
+ self.conv_states[layer_idx] += conv_state
156
+ return self.conv_states[layer_idx]
157
+
158
+ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
159
+ self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
160
+ return self.ssm_states[layer_idx]
161
+
162
+ def reset(self):
163
+ self.conv_states.zero_()
164
+ self.ssm_states.zero_()
165
+
166
+
167
+ class MambaMixer(nn.Module):
168
+ """
169
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
170
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
171
+ ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
172
+ and is why Mamba is called **selective** state spaces)
173
+ """
174
+
175
+ def __init__(self, config: MambaConfig, layer_idx: int):
176
+ super().__init__()
177
+ self.config = config
178
+ self.hidden_size = config.hidden_size
179
+ self.ssm_state_size = config.state_size
180
+ self.conv_kernel_size = config.conv_kernel
181
+ self.intermediate_size = config.intermediate_size
182
+ self.time_step_rank = int(config.time_step_rank)
183
+ self.layer_idx = layer_idx
184
+ self.use_conv_bias = config.use_conv_bias
185
+ self.conv1d = nn.Conv1d(
186
+ in_channels=self.intermediate_size,
187
+ out_channels=self.intermediate_size,
188
+ bias=config.use_conv_bias,
189
+ kernel_size=config.conv_kernel,
190
+ groups=self.intermediate_size,
191
+ padding=config.conv_kernel - 1,
192
+ )
193
+
194
+ self.activation = config.hidden_act
195
+ self.act = ACT2FN[config.hidden_act]
196
+
197
+ # projection of the input hidden states
198
+ self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
199
+ # selective projection used to make dt, B and C input dependant
200
+ self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
201
+ # time step projection (discretization)
202
+ self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
203
+
204
+ # S4D real initialization. These are not discretized!
205
+ # The core is to load them, compute the discrete states, then write the updated state. This keeps the memory bounded.
206
+ A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
207
+ A = A.expand(self.intermediate_size, -1).contiguous()
208
+
209
+ self.A_log = nn.Parameter(torch.log(A))
210
+ self.D = nn.Parameter(torch.ones(self.intermediate_size))
211
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
212
+ self.use_bias = config.use_bias
213
+
214
+ if not is_fast_path_available:
215
+ logger.warning_once(
216
+ "The fast path is not available because on of "
217
+ "`(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
218
+ " is None. Falling back to the naive implementation. "
219
+ "To install follow https://github.com/state-spaces/mamba/#installation and"
220
+ " https://github.com/Dao-AILab/causal-conv1d"
221
+ )
222
+
223
+ def cuda_kernels_forward(
224
+ self,
225
+ hidden_states: torch.Tensor,
226
+ cache_params: Optional[MambaCache] = None,
227
+ cache_position: Optional[torch.LongTensor] = None,
228
+ attention_mask: Optional[torch.LongTensor] = None,
229
+ ):
230
+ # 1. Gated MLP's linear projection
231
+ projected_states = self.in_proj(hidden_states).transpose(1, 2)
232
+
233
+ if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
234
+ contextualized_states = mamba_inner_fn(
235
+ projected_states,
236
+ self.conv1d.weight,
237
+ self.conv1d.bias if self.use_conv_bias else None,
238
+ self.x_proj.weight,
239
+ self.dt_proj.weight,
240
+ self.out_proj.weight,
241
+ self.out_proj.bias.float() if self.use_bias else None,
242
+ -torch.exp(self.A_log.float()),
243
+ None, # input-dependent B
244
+ None, # input-dependent C
245
+ self.D.float(),
246
+ delta_bias=self.dt_proj.bias.float(),
247
+ delta_softplus=True,
248
+ )
249
+
250
+ else:
251
+ hidden_states, gate = projected_states.chunk(2, dim=1)
252
+
253
+ if attention_mask is not None:
254
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
255
+
256
+ # 2. Convolution sequence transformation
257
+ conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
258
+ if cache_params is not None and cache_position[0] > 0:
259
+ hidden_states = causal_conv1d_update(
260
+ hidden_states.squeeze(-1),
261
+ cache_params.conv_states[self.layer_idx],
262
+ conv_weights,
263
+ self.conv1d.bias,
264
+ self.activation,
265
+ )
266
+ hidden_states = hidden_states.unsqueeze(-1)
267
+ else:
268
+ if cache_params is not None:
269
+ conv_states = nn.functional.pad(
270
+ hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
271
+ )
272
+ cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
273
+ hidden_states = causal_conv1d_fn(
274
+ hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
275
+ )
276
+
277
+ if attention_mask is not None:
278
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
279
+
280
+ # 3. State Space Model sequence transformation
281
+ # 3.a. input varying initialization of time_step, B and C
282
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
283
+ time_step, B, C = torch.split(
284
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
285
+ )
286
+ discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
287
+
288
+ A = -torch.exp(self.A_log.float())
289
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
290
+ time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
291
+ if cache_params is not None and cache_position[0] > 0:
292
+ scan_outputs = selective_state_update(
293
+ cache_params.ssm_states[self.layer_idx],
294
+ hidden_states[..., 0],
295
+ discrete_time_step[..., 0],
296
+ A,
297
+ B[:, 0],
298
+ C[:, 0],
299
+ self.D,
300
+ gate[..., 0],
301
+ time_proj_bias,
302
+ dt_softplus=True,
303
+ ).unsqueeze(-1)
304
+ else:
305
+ scan_outputs, ssm_state = selective_scan_fn(
306
+ hidden_states,
307
+ discrete_time_step,
308
+ A,
309
+ B.transpose(1, 2),
310
+ C.transpose(1, 2),
311
+ self.D.float(),
312
+ gate,
313
+ time_proj_bias,
314
+ delta_softplus=True,
315
+ return_last_state=True,
316
+ )
317
+ if ssm_state is not None and cache_params is not None:
318
+ cache_params.update_ssm_state(self.layer_idx, ssm_state)
319
+
320
+ # 4. Final linear projection
321
+ contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
322
+ return contextualized_states
323
+
324
+ def slow_forward(
325
+ self,
326
+ input_states,
327
+ cache_params: Optional[MambaCache] = None,
328
+ cache_position: Optional[torch.LongTensor] = None,
329
+ attention_mask: Optional[torch.LongTensor] = None
330
+ ):
331
+ batch_size, seq_len, _ = input_states.shape
332
+ dtype = input_states.dtype
333
+ # 1. Gated MLP's linear projection
334
+ # [batch, 2 * intermediate_size, seq_len]
335
+ projected_states = self.in_proj(input_states).transpose(1, 2)
336
+ hidden_states, gate = projected_states.chunk(2, dim=1)
337
+
338
+ if attention_mask is not None:
339
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
340
+
341
+ # 2. Convolution sequence transformation
342
+ if cache_params is not None:
343
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
344
+ ssm_state = ssm_state.to(hidden_states.device)
345
+ # use `cache_position.shape[0]` to check whether we are in prefill
346
+ # stage, it's equivalent to check `cache_position[0] == 0`, which
347
+ # breaks dynamo fullgraph constraints
348
+ if cache_position.shape[0] == self.conv_kernel_size:
349
+ conv_state = nn.functional.pad(
350
+ hidden_states,
351
+ (self.conv_kernel_size - hidden_states.shape[-1], 0)
352
+ )
353
+
354
+ cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
355
+ # [batch, intermediate_size, seq_len]
356
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
357
+ else:
358
+ conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
359
+ hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
360
+ if self.use_conv_bias:
361
+ hidden_states += self.conv1d.bias
362
+ # [batch, intermediate_size, 1] : decoding
363
+ hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)
364
+ else:
365
+ ssm_state = torch.zeros(
366
+ (batch_size, self.intermediate_size, self.ssm_state_size),
367
+ device=hidden_states.device, dtype=dtype
368
+ )
369
+ # [batch, intermediate_size, seq_len]
370
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
371
+
372
+ if attention_mask is not None:
373
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
374
+
375
+ # 3. State Space Model sequence transformation
376
+ # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
377
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
378
+ time_step, B, C = torch.split(
379
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
380
+ )
381
+ # [batch, seq_len, intermediate_size]
382
+ discrete_time_step = self.dt_proj(time_step)
383
+ # [batch, intermediate_size, seq_len]
384
+ discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)
385
+
386
+ # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
387
+ # [intermediate_size, ssm_state_size]
388
+ A = -torch.exp(self.A_log.float())
389
+ # [batch, intermediate_size, seq_len, ssm_state_size]
390
+ discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])
391
+ # [batch, intermediate_size, seq_len, ssm_state_size]
392
+ discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
393
+ deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
394
+
395
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
396
+ scan_outputs = []
397
+ for i in range(seq_len):
398
+ # [batch, intermediate_size, ssm_state]
399
+ ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
400
+ # [batch, intermediate_size, 1]
401
+ scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))
402
+ scan_outputs.append(scan_output[:, :, 0])
403
+ # [batch, seq_len, intermediate_size]
404
+ scan_output = torch.stack(scan_outputs, dim=-1)
405
+ scan_output = scan_output + (hidden_states * self.D[None, :, None])
406
+ scan_output = (scan_output * self.act(gate))
407
+
408
+ if cache_params is not None:
409
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
410
+
411
+ # 4. Final linear projection
412
+ # [batch, seq_len, hidden_size]
413
+ contextualized_states = self.out_proj(scan_output.transpose(1, 2))
414
+ return contextualized_states
415
+ # fmt: on
416
+
417
+ def forward(
418
+ self,
419
+ hidden_states,
420
+ cache_params: Optional[MambaCache] = None,
421
+ cache_position: Optional[torch.LongTensor] = None,
422
+ attention_mask: Optional[torch.LongTensor] = None,
423
+ ):
424
+ if is_fast_path_available and "cuda" in self.x_proj.weight.device.type:
425
+ return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
426
+ return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
427
+
428
+
429
+ class MambaBlock(nn.Module):
430
+ def __init__(self, config, layer_idx):
431
+ super().__init__()
432
+ self.config = config
433
+ self.layer_idx = layer_idx
434
+ self.residual_in_fp32 = config.residual_in_fp32
435
+ self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
436
+ self.mixer = MambaMixer(config, layer_idx=layer_idx)
437
+
438
+ def forward(
439
+ self,
440
+ hidden_states,
441
+ cache_params: Optional[MambaCache] = None,
442
+ cache_position: Optional[torch.LongTensor] = None,
443
+ attention_mask: Optional[torch.LongTensor] = None,
444
+ ):
445
+ residual = hidden_states
446
+ hidden_states = self.norm(hidden_states)
447
+ if self.residual_in_fp32:
448
+ residual = residual.to(torch.float32)
449
+
450
+ hidden_states = self.mixer(
451
+ hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
452
+ )
453
+ hidden_states = residual + hidden_states
454
+ if self.residual_in_fp32:
455
+ hidden_states = hidden_states.to(dtype=self.norm.weight.dtype)
456
+ return hidden_states
457
+
458
+
459
+ class MambaPreTrainedModel(PreTrainedModel):
460
+ """
461
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
462
+ models.
463
+ """
464
+
465
+ config_class = MambaConfig
466
+ base_model_prefix = "backbone"
467
+ _no_split_modules = ["MambaBlock", "MambaMixer"]
468
+ supports_gradient_checkpointing = True
469
+ _is_stateful = True
470
+
471
+ def _init_weights(self, module):
472
+ """Initialize the weights."""
473
+ if isinstance(module, nn.Linear):
474
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
475
+ if module.bias is not None:
476
+ if not getattr(module.bias, "_no_reinit", False):
477
+ nn.init.zeros_(module.bias)
478
+ elif isinstance(module, MambaMixer):
479
+ module.A_log._no_weight_decay = True
480
+ module.D._no_weight_decay = True
481
+
482
+ dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
483
+ if self.config.time_step_init_scheme == "constant":
484
+ nn.init.constant_(module.dt_proj.weight, dt_init_std)
485
+ elif self.config.time_step_init_scheme == "random":
486
+ nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)
487
+
488
+ dt = torch.exp(
489
+ torch.rand(self.config.intermediate_size)
490
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
491
+ + math.log(self.config.time_step_min)
492
+ ).clamp(min=self.config.time_step_floor)
493
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
494
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
495
+ with torch.no_grad():
496
+ module.dt_proj.bias.data = nn.Parameter(inv_dt.to(module.dt_proj.bias.device))
497
+ module.dt_proj.bias._no_reinit = True
498
+ elif isinstance(module, nn.Embedding):
499
+ nn.init.normal_(module.weight, std=self.config.initializer_range)
500
+ elif hasattr(module, 'reset_parameters'):
501
+ module.reset_parameters()
502
+
503
+ if self.config.rescale_prenorm_residual:
504
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
505
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
506
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
507
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
508
+ #
509
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
510
+ for name, p in module.named_parameters():
511
+ if name in ["out_proj.weight"]:
512
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
513
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
514
+ # We need to reinit p since this code could be called multiple times
515
+ # Having just p *= scale would repeatedly scale it down
516
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
517
+ with torch.no_grad():
518
+ p /= math.sqrt(self.config.num_hidden_layers)
519
+
520
+
521
+ @dataclass
522
+ class MambaOutput(ModelOutput):
523
+ """
524
+ Class for the MAMBA model outputs.
525
+
526
+ Args:
527
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
528
+ Sequence of hidden-states at the output of the last layer of the model.
529
+ cache_params (`MambaCache`):
530
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
531
+ avoid providing the old `input_ids`.
532
+
533
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
534
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*,
535
+ returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
536
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
537
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
538
+
539
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
540
+ """
541
+
542
+ last_hidden_state: Optional[torch.FloatTensor] = None
543
+ cache_params: Optional[MambaCache] = None
544
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
545
+
546
+
547
+ @dataclass
548
+ class MambaCausalLMOutput(ModelOutput):
549
+ """
550
+ Base class for causal language model (or autoregressive) outputs.
551
+
552
+ Args:
553
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
554
+ Language modeling loss (for next-token prediction).
555
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
556
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
557
+ cache_params (`MambaCache`):
558
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
559
+ avoid providing the old `input_ids`.
560
+
561
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
562
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*,
563
+ returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
564
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
565
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
566
+
567
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
568
+ """
569
+
570
+ loss: Optional[torch.FloatTensor] = None
571
+ logits: Optional[torch.FloatTensor] = None
572
+ cache_params: Optional[MambaCache] = None
573
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
574
+
575
+
576
+ class MambaModel(MambaPreTrainedModel):
577
+ def __init__(self, config):
578
+ super().__init__(config)
579
+
580
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
581
+ self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
582
+
583
+ self.gradient_checkpointing = False
584
+ self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
585
+ # Initialize weights and apply final processing
586
+ self._register_load_state_dict_pre_hook(self.load_hook)
587
+ self.post_init()
588
+
589
+ def load_hook(self, state_dict, prefix, *args):
590
+ for k in state_dict:
591
+ if "embedding." in k:
592
+ state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
593
+ break
594
+
595
+ def get_input_embeddings(self):
596
+ return self.embeddings
597
+
598
+ def set_input_embeddings(self, new_embeddings):
599
+ self.embeddings = new_embeddings
600
+
601
+ def forward(
602
+ self,
603
+ input_ids: Optional[torch.LongTensor] = None,
604
+ inputs_embeds: Optional[torch.LongTensor] = None,
605
+ cache_params: Optional[MambaCache] = None,
606
+ use_cache: Optional[bool] = None,
607
+ output_hidden_states: Optional[bool] = None,
608
+ return_dict: Optional[bool] = None,
609
+ cache_position: Optional[torch.LongTensor] = None,
610
+ attention_mask: Optional[torch.LongTensor] = None,
611
+ ) -> Union[Tuple, MambaOutput]:
612
+ output_hidden_states = (
613
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
614
+ )
615
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
616
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
617
+
618
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
619
+ raise ValueError(
620
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
621
+ )
622
+
623
+ if inputs_embeds is None:
624
+ inputs_embeds = self.embeddings(input_ids)
625
+
626
+ if self.gradient_checkpointing and self.training and use_cache:
627
+ use_cache = False
628
+
629
+ if use_cache:
630
+ if cache_params is None:
631
+ cache_params = MambaCache(
632
+ self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
633
+ )
634
+ cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
635
+ elif cache_position is None:
636
+ # cases when we do manual forward instead of using `model.generate` which will initiate
637
+ # `cache_position` and makes sure it is not None, throw error here instead of doing some
638
+ # hack to conjecture the current cache position
639
+ raise ValueError(
640
+ "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
641
+ "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
642
+ "be initialized for you automatically"
643
+ )
644
+ else:
645
+ cache_params = None
646
+
647
+ hidden_states = inputs_embeds
648
+ all_hidden_states = () if output_hidden_states else None
649
+ for mixer_block in self.layers:
650
+ if self.gradient_checkpointing and self.training:
651
+ hidden_states = self._gradient_checkpointing_func(
652
+ mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
653
+ )
654
+ else:
655
+ hidden_states = mixer_block(
656
+ hidden_states,
657
+ cache_params=cache_params,
658
+ cache_position=cache_position,
659
+ attention_mask=attention_mask,
660
+ )
661
+
662
+ if output_hidden_states:
663
+ all_hidden_states = all_hidden_states + (hidden_states,)
664
+
665
+ hidden_states = self.norm_f(hidden_states)
666
+
667
+ if output_hidden_states:
668
+ all_hidden_states = all_hidden_states + (hidden_states,)
669
+
670
+ if not return_dict:
671
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
672
+
673
+ return MambaOutput(
674
+ last_hidden_state=hidden_states,
675
+ cache_params=cache_params if use_cache else None,
676
+ hidden_states=all_hidden_states,
677
+ )
678
+
679
+
680
+ class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin):
681
+
682
+ _tied_weights_keys = ["lm_head.weight"]
683
+
684
+ def __init__(self, config):
685
+ super().__init__(config)
686
+ self.backbone = MambaModel(config)
687
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
688
+ self.criterion = None
689
+
690
+ # Initialize weights and apply final processing
691
+ self.post_init()
692
+
693
+ def get_output_embeddings(self):
694
+ return self.lm_head
695
+
696
+ def set_output_embeddings(self, new_embeddings):
697
+ self.lm_head = new_embeddings
698
+
699
+ def get_input_embeddings(self):
700
+ return self.backbone.get_input_embeddings()
701
+
702
+ def set_input_embeddings(self, new_embeddings):
703
+ return self.backbone.set_input_embeddings(new_embeddings)
704
+
705
+ def _update_model_kwargs_for_generation(
706
+ self, outputs: ModelOutput,
707
+ model_kwargs: Dict[str, Any],
708
+ num_new_tokens: int = 1,
709
+ **kwargs
710
+ ) -> Dict[str, Any]:
711
+ model_kwargs["cache_params"] = outputs.get("cache_params", None)
712
+ if (
713
+ model_kwargs.get("use_cache", True)
714
+ and "cache_position" in model_kwargs
715
+ and model_kwargs["cache_position"] is not None
716
+ ):
717
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
718
+
719
+ if "attention_mask" in model_kwargs:
720
+ attention_mask = model_kwargs["attention_mask"]
721
+ model_kwargs["attention_mask"] = torch.cat(
722
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
723
+ )
724
+
725
+ return model_kwargs
726
+
727
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
728
+ def prepare_inputs_for_generation(
729
+ self,
730
+ input_ids,
731
+ inputs_embeds=None,
732
+ use_cache=None,
733
+ cache_params: Optional[MambaCache] = None,
734
+ cache_position: Optional[torch.LongTensor] = None,
735
+ attention_mask: Optional[torch.LongTensor] = None,
736
+ logits_to_keep: Optional[int] = None,
737
+ **kwargs,
738
+ ):
739
+ if use_cache:
740
+ # `cache_position` should have been initialized in `generate`
741
+ if cache_position is None:
742
+ raise ValueError(
743
+ "`cache_position` should not be None as it should have been initialized in "
744
+ "`model.generate`, you are responsible for passing in a valid `cache_position` if "
745
+ "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
746
+ )
747
+ if cache_position[0] > 0:
748
+ input_ids = input_ids[:, -1].unsqueeze(-1)
749
+
750
+ if attention_mask is not None:
751
+ attention_mask = None
752
+
753
+ else:
754
+ # we initialize the `cache_position` to full size of `conv_states` at prefill stage
755
+ # considering padding will be applied when input length is shorter, and truncation
756
+ # will be applied when it is longer, so it will be equivalent to always have it match
757
+ # the length of `cache_params.conv_states`, which is `config.conv_kernel`
758
+ cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
759
+
760
+ if inputs_embeds is not None and cache_params is None:
761
+ model_inputs = {"inputs_embeds": inputs_embeds}
762
+ else:
763
+ model_inputs = {"input_ids": input_ids.contiguous()}
764
+
765
+ if logits_to_keep is not None:
766
+ model_inputs['logits_to_keep'] = logits_to_keep
767
+
768
+ model_inputs.update({
769
+ 'cache_params': cache_params,
770
+ 'use_cache': use_cache,
771
+ 'cache_position': cache_position,
772
+ 'attention_mask': attention_mask,
773
+ 'logits_to_keep': logits_to_keep,
774
+ })
775
+ return model_inputs
776
+
777
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
778
+ def forward(
779
+ self,
780
+ input_ids: Optional[torch.LongTensor] = None,
781
+ attention_mask: Optional[torch.LongTensor] = None,
782
+ inputs_embeds: Optional[torch.FloatTensor] = None,
783
+ cache_params: Optional[MambaCache] = None,
784
+ labels: Optional[torch.LongTensor] = None,
785
+ output_hidden_states: Optional[bool] = None,
786
+ return_dict: Optional[bool] = None,
787
+ use_cache: Optional[bool] = None,
788
+ cache_position: Optional[torch.Tensor] = None,
789
+ logits_to_keep: Optional[int] = 0,
790
+ **kwargs, # for now we need this for generation
791
+ ) -> Union[Tuple, MambaCausalLMOutput]:
792
+ r"""
793
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
794
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
795
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
796
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
797
+ """
798
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
799
+
800
+ mamba_outputs = self.backbone(
801
+ input_ids,
802
+ cache_params=cache_params,
803
+ inputs_embeds=inputs_embeds,
804
+ output_hidden_states=output_hidden_states,
805
+ return_dict=return_dict,
806
+ use_cache=use_cache,
807
+ cache_position=cache_position,
808
+ attention_mask=attention_mask,
809
+ )
810
+ hidden_states = mamba_outputs[0]
811
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
812
+
813
+ loss, logits = None, None
814
+ if not fuse_linear_and_cross_entropy or labels is None:
815
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
816
+ if labels is not None:
817
+ if getattr(self, 'criterion', None) is None:
818
+ if fuse_linear_and_cross_entropy:
819
+ criterion = FusedLinearCrossEntropyLoss()
820
+ elif self.config.fuse_cross_entropy:
821
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
822
+ else:
823
+ criterion = nn.CrossEntropyLoss()
824
+ else:
825
+ criterion = self.criterion
826
+ # Enable model parallelism
827
+ labels = labels.to(hidden_states.device)
828
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
829
+ if fuse_linear_and_cross_entropy:
830
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
831
+ else:
832
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
833
+
834
+ if not return_dict:
835
+ output = (logits,) + mamba_outputs[1:]
836
+ return (loss,) + output if loss is not None else output
837
+
838
+ return MambaCausalLMOutput(
839
+ loss=loss,
840
+ logits=logits,
841
+ cache_params=mamba_outputs.cache_params,
842
+ hidden_states=mamba_outputs.hidden_states,
843
+ )
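The core of `slow_forward` above is the sequential selective-scan recurrence applied after discretization: `h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t` and `y_t = C_t · h_t + D * x_t`. The stand-alone sketch below reproduces that loop on random toy tensors (illustration only; the gate multiplication and `out_proj` that follow in the model are omitted):

```python
import torch

# Toy sizes; the model works on (batch, intermediate_size, seq_len) activations.
batch, dim, state, seq_len = 2, 4, 3, 6
x = torch.randn(batch, dim, seq_len)                      # post-convolution activations
dt = torch.nn.functional.softplus(torch.randn(batch, dim, seq_len))
A = -torch.rand(dim, state)                               # -exp(A_log) in the model
B = torch.randn(batch, seq_len, state)                    # input-dependent
C = torch.randn(batch, seq_len, state)                    # input-dependent
D = torch.ones(dim)                                       # skip connection weight

# Discretization: A_bar = exp(dt * A), and dt * B * x for the input term.
discrete_A = torch.exp(A[None, :, None, :] * dt[:, :, :, None])     # [b, d, l, n]
deltaB_u = dt[:, :, :, None] * B[:, None, :, :] * x[:, :, :, None]  # [b, d, l, n]

h = torch.zeros(batch, dim, state)
ys = []
for t in range(seq_len):
    h = discrete_A[:, :, t, :] * h + deltaB_u[:, :, t, :]           # state update
    ys.append(torch.matmul(h, C[:, t, :].unsqueeze(-1))[:, :, 0])   # y_t = C_t . h_t
y = torch.stack(ys, dim=-1) + x * D[None, :, None]                  # add D * x skip term
print(y.shape)  # torch.Size([2, 4, 6])
```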
fla/models/mamba2/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.mamba2.configuration_mamba2 import Mamba2Config
6
+ from fla.models.mamba2.modeling_mamba2 import Mamba2ForCausalLM, Mamba2Model
7
+
8
+ AutoConfig.register(Mamba2Config.model_type, Mamba2Config, True)
9
+ AutoModel.register(Mamba2Config, Mamba2Model, True)
10
+ AutoModelForCausalLM.register(Mamba2Config, Mamba2ForCausalLM, True)
11
+
12
+
13
+ __all__ = ['Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model']
fla/models/mamba2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (695 Bytes). View file
 
fla/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc ADDED
Binary file (7.5 kB). View file