Erland committed on
Commit
7fdd671
·
verified ·
1 Parent(s): ec1fbcf

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla/layers/__pycache__/delta_net.cpython-311.pyc +0 -0
  2. fla/models/__init__.py +51 -0
  3. fla/models/__pycache__/__init__.cpython-311.pyc +0 -0
  4. fla/models/__pycache__/utils.cpython-311.pyc +0 -0
  5. fla/models/lightnet/__pycache__/__init__.cpython-311.pyc +0 -0
  6. fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-311.pyc +0 -0
  7. fla/models/linear_attn/modeling_linear_attn.py +406 -0
  8. fla/models/mamba/__pycache__/configuration_mamba.cpython-311.pyc +0 -0
  9. fla/models/mamba/__pycache__/modeling_mamba.cpython-311.pyc +0 -0
  10. fla/models/mamba2/__pycache__/configuration_mamba2.cpython-311.pyc +0 -0
  11. fla/models/mamba2/__pycache__/modeling_mamba2.cpython-311.pyc +0 -0
  12. fla/models/mamba2/configuration_mamba2.py +170 -0
  13. fla/models/nsa/__init__.py +15 -0
  14. fla/models/nsa/__pycache__/configuration_nsa.cpython-311.pyc +0 -0
  15. fla/models/nsa/__pycache__/modeling_nsa.cpython-311.pyc +0 -0
  16. fla/models/nsa/configuration_nsa.py +75 -0
  17. fla/models/retnet/__init__.py +13 -0
  18. fla/models/retnet/__pycache__/configuration_retnet.cpython-311.pyc +0 -0
  19. fla/models/retnet/__pycache__/modeling_retnet.cpython-311.pyc +0 -0
  20. fla/models/retnet/configuration_retnet.py +92 -0
  21. fla/models/retnet/modeling_retnet.py +425 -0
  22. fla/models/rwkv6/__pycache__/__init__.cpython-311.pyc +0 -0
  23. fla/models/rwkv6/__pycache__/configuration_rwkv6.cpython-311.pyc +0 -0
  24. fla/models/rwkv6/__pycache__/modeling_rwkv6.cpython-311.pyc +0 -0
  25. fla/models/rwkv6/configuration_rwkv6.py +82 -0
  26. fla/models/rwkv7/__init__.py +13 -0
  27. fla/models/rwkv7/__pycache__/__init__.cpython-311.pyc +0 -0
  28. fla/models/rwkv7/__pycache__/configuration_rwkv7.cpython-311.pyc +0 -0
  29. fla/models/rwkv7/__pycache__/modeling_rwkv7.cpython-311.pyc +0 -0
  30. fla/models/rwkv7/configuration_rwkv7.py +105 -0
  31. fla/models/rwkv7/modeling_rwkv7.py +505 -0
  32. fla/models/samba/__init__.py +13 -0
  33. fla/models/samba/__pycache__/configuration_samba.cpython-311.pyc +0 -0
  34. fla/models/samba/__pycache__/modeling_samba.cpython-311.pyc +0 -0
  35. fla/models/samba/configuration_samba.py +92 -0
  36. fla/models/samba/modeling_samba.py +413 -0
  37. fla/models/transformer/__init__.py +13 -0
  38. fla/models/transformer/__pycache__/__init__.cpython-311.pyc +0 -0
  39. fla/models/transformer/__pycache__/configuration_transformer.cpython-311.pyc +0 -0
  40. fla/models/transformer/__pycache__/modeling_transformer.cpython-311.pyc +0 -0
  41. fla/models/transformer/configuration_transformer.py +74 -0
  42. fla/models/transformer/modeling_transformer.py +437 -0
  43. fla/models/transformer_mtp/__init__.py +13 -0
  44. fla/models/transformer_mtp/__pycache__/__init__.cpython-311.pyc +0 -0
  45. fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-311.pyc +0 -0
  46. fla/models/transformer_mtp/__pycache__/modeling_transformer.cpython-311.pyc +0 -0
  47. fla/models/transformer_mtp/configuration_transformer.py +76 -0
  48. fla/models/transformer_mtp/modeling_transformer.py +601 -0
  49. fla/models/transformer_vanilla/__init__.py +13 -0
  50. fla/models/transformer_vanilla/configuration_transformer.py +71 -0
fla/layers/__pycache__/delta_net.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
fla/models/__init__.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel
from fla.models.bitnet import BitNetConfig, BitNetForCausalLM, BitNetModel
from fla.models.delta_net import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel
from fla.models.forgetting_transformer import (
    ForgettingTransformerConfig,
    ForgettingTransformerForCausalLM,
    ForgettingTransformerModel
)
from fla.models.gated_deltanet import GatedDeltaNetConfig, GatedDeltaNetForCausalLM, GatedDeltaNetModel
from fla.models.gated_deltaproduct import GatedDeltaProductConfig, GatedDeltaProductForCausalLM, GatedDeltaProductModel
from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel
from fla.models.gsa import GSAConfig, GSAForCausalLM, GSAModel
from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel
from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model
from fla.models.lightnet import LightNetConfig, LightNetForCausalLM, LightNetModel
from fla.models.linear_attn import LinearAttentionConfig, LinearAttentionForCausalLM, LinearAttentionModel
from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel
from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model
from fla.models.nsa import NSAConfig, NSAForCausalLM, NSAModel
from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel
from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model
from fla.models.rwkv7 import RWKV7Config, RWKV7ForCausalLM, RWKV7Model
from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel
from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel
from fla.models.transformer_mtp import MTPTransformerConfig, MTPTransformerForCausalLM, MTPTransformerModel

# Public API of `fla.models`: one (Config, ForCausalLM, Model) triple per
# architecture. Entries are kept in the same order as the imports above so the
# two lists can be cross-checked at a glance.
__all__ = [
    'ABCConfig', 'ABCForCausalLM', 'ABCModel',
    'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel',
    'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel',
    'ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel',
    'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel',
    'GatedDeltaProductConfig', 'GatedDeltaProductForCausalLM', 'GatedDeltaProductModel',
    'GLAConfig', 'GLAForCausalLM', 'GLAModel',
    'GSAConfig', 'GSAForCausalLM', 'GSAModel',
    'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel',
    'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model',
    'LightNetConfig', 'LightNetForCausalLM', 'LightNetModel',
    'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel',
    'MambaConfig', 'MambaForCausalLM', 'MambaModel',
    'Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model',
    'NSAConfig', 'NSAForCausalLM', 'NSAModel',
    'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel',
    'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model',
    'RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model',
    'SambaConfig', 'SambaForCausalLM', 'SambaModel',
    'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel',
    'MTPTransformerConfig', 'MTPTransformerForCausalLM', 'MTPTransformerModel',
]
fla/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.45 kB). View file
 
fla/models/__pycache__/utils.cpython-311.pyc ADDED
Binary file (7.19 kB). View file
 
fla/models/lightnet/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (756 Bytes). View file
 
fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-311.pyc ADDED
Binary file (4.05 kB). View file
 
fla/models/linear_attn/modeling_linear_attn.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.linear_attn import LinearAttention
20
+ from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as LinearAttentionMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
class LinearAttentionBlock(nn.Module):
    """One decoder block: (RMS)norm -> token mixer -> (RMS)norm -> gated MLP.

    The token mixer is a standard softmax `Attention` for the layer indices
    listed in ``config.attn['layers']`` (when ``config.attn`` is set), and a
    `LinearAttention` layer otherwise.
    """

    def __init__(self, config: LinearAttentionConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Fused RMSNorm kernel when `config.fuse_norm`, else PyTorch's nn.RMSNorm.
        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        if config.attn is not None and layer_idx in config.attn['layers']:
            # This layer index is configured to use full (windowed) softmax attention.
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = LinearAttention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_k=config.expand_k,
                expand_v=config.expand_v,
                num_heads=config.num_heads,
                num_kv_heads=config.num_kv_heads,
                feature_map=config.feature_map,
                tie_feature_map_qk=config.tie_feature_map_qk,
                norm_q=config.norm_q,
                norm_k=config.norm_k,
                do_feature_map_norm=config.norm_feature_map,
                elementwise_affine=config.elementwise_affine,
                norm_eps=config.norm_eps,
                layer_idx=layer_idx
            )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = LinearAttentionMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one decoder block.

        Returns a 3-tuple ``(hidden_states, attentions, past_key_values)``.
        Attention-weight output and KV caching are not supported by this block,
        so the last two entries are always ``None``.

        NOTE(review): `attention_mask`, `past_key_values` and `use_cache` are
        accepted but NOT forwarded to the mixer below — only extra **kwargs
        are passed through. Presumably intentional ("currently not supported"),
        but verify against sibling model implementations.
        """
        residual = hidden_states
        # currently not supported
        attentions, past_key_values = None, None
        hidden_states = self.attn_norm(hidden_states)
        hidden_states = self.attn(hidden_states=hidden_states, **kwargs)
        if self.config.fuse_norm:
            # Fused kernel: adds the residual and normalizes in one pass,
            # returning both the normalized states and the updated residual.
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs
100
+
101
+
102
class LinearAttentionPreTrainedModel(PreTrainedModel):
    """Base class wiring `LinearAttentionConfig` into HF's `PreTrainedModel`
    machinery (weight init, gradient checkpointing, cache support)."""

    config_class = LinearAttentionConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['LinearAttentionBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        """Initialize `module`'s weights.

        Linear/Conv1d and Embedding weights get a normal init with
        ``config.initializer_range``; anything exposing `reset_parameters`
        falls back to its own init. Residual-path output projections
        (`o_proj`/`down_proj`) are then re-initialized according to
        `prenorm_residual_strategy` ('rescale', 'zero', or None to skip).
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is not None:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                if prenorm_residual_strategy == 'rescale':
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
                elif prenorm_residual_strategy == 'zero':
                    nn.init.zeros_(p)
                else:
                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
155
+
156
+
157
class LinearAttentionModel(LinearAttentionPreTrainedModel):
    """Bare LinearAttention decoder: token embeddings -> N blocks -> final norm."""

    def __init__(self, config: LinearAttentionConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LinearAttentionBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of `input_ids` / `inputs_embeds` must be provided.
        Attention-weight output is unsupported: if requested it is forced off
        (with a warning) and `attentions` in the output is a tuple of `None`s.
        """
        if output_attentions:
            warnings.warn(
                "`LinearAttentionModel` does not support output attention weights now, "
                "so `output_attentions` is set to `False`."
            )
            output_attentions = False
        # Fall back to config defaults for any flag left as None.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        # Caching is disabled during training regardless of config.
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        # Normalize legacy tuple caches into the project's Cache class.
        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                # Hidden states are recorded *before* each layer.
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
259
+
260
+
261
class LinearAttentionForCausalLM(LinearAttentionPreTrainedModel, GenerationMixin):
    """LinearAttention decoder with a tied LM head for causal language modeling."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LinearAttentionModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Loss is built lazily in forward() unless a criterion is injected here.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        """Delegate to `GenerationMixin.generate`, rewording cache-related
        AttributeErrors into an actionable message; everything else re-raises."""
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        """Assemble the model inputs for one decoding step (HF generation hook)."""
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone and (optionally) compute the LM loss.

        `logits_to_keep` limits the head to the last N positions; the default
        0 keeps all positions (`hidden_states[:, -0:]` is the full sequence).
        When `config.fuse_cross_entropy` is set and the model is training,
        logits are never materialized: the fused linear+CE loss consumes the
        hidden states and the lm_head weights directly, so `logits` is None.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            # Shift labels left by one so position t predicts token t+1;
            # the final position is padded with ignore_index.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/mamba/__pycache__/configuration_mamba.cpython-311.pyc ADDED
Binary file (7.33 kB). View file
 
fla/models/mamba/__pycache__/modeling_mamba.cpython-311.pyc ADDED
Binary file (42.9 kB). View file
 
fla/models/mamba2/__pycache__/configuration_mamba2.cpython-311.pyc ADDED
Binary file (7.7 kB). View file
 
fla/models/mamba2/__pycache__/modeling_mamba2.cpython-311.pyc ADDED
Binary file (53.5 kB). View file
 
fla/models/mamba2/configuration_mamba2.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """MAMBA2 configuration"""
15
+
16
+ import math
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+
20
+
21
class Mamba2Config(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`Mamba2Model`]. It is used to instantiate a MAMBA2
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the MAMBA2
    [state-spaces/mamba2-2.8b](https://huggingface.co/state-spaces/mamba2-2.8b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        num_heads (`int`, *optional*, defaults to 64):
            Number of heads for the evolution matrices of mamba 2.
        head_dim (`int`, *optional*, defaults to 64):
            Dimension of each head.
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Mamba2Model`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimensionality of the embeddings and hidden states.
        state_size (`int`, *optional*, defaults to 128): shape of the state space latents.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the model.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon to use in the layer normalization layers.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the beginning of sentence token in the vocabulary.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the end of sentence token in the vocabulary.
        expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
        conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
        n_groups (`int`, *optional*, defaults to 1):
            Number of groups for the evolution matrices of mamba 2.
        use_bias (`bool`, *optional*, defaults to `False`):
            Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
        use_conv_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias in the convolution layer of the mixer block.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        initializer_range (`float`, *optional*, defaults to 0.1):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        residual_in_fp32 (`bool`, *optional*, defaults to `True`):
            Whether or not residuals should be in `float32`.
            If set to `False` residuals will keep the same `dtype` as the rest of the model
        time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
            Rank of the discretization projection matrix.
            `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
        time_step_min (`float`, *optional*, defaults to 0.001):
            Minimum `time_step` used to bound `dt_proj.bias`.
        time_step_max (`float`, *optional*, defaults to 0.1):
            Maximum `time_step` used to bound `dt_proj.bias`.
        time_step_floor (`float`, *optional*, defaults to 0.0001):
            Minimum clamping value of the `dt_proj.bias` layer initialization.
        time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`):
            Accepted range of time step values.
        rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
            Whether or not to rescale `out_proj` weights when initializing.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the cache should be used.
        rms_norm (`bool`, *optional*, defaults to `True`):
            Whether to use RMS norm or not.
        chunk_size (`int`, *optional*, defaults to 256):
            Size of the chunks that will comprise the sequence.
        fuse_norm (`bool`, *optional*, defaults to `True`):
            Whether to use fused normalization kernels.
        fuse_cross_entropy (`bool`, *optional*, defaults to `True`):
            Whether to use the fused cross-entropy loss.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie word embeddings or not.
    """

    model_type = "mamba2"

    def __init__(
        self,
        num_heads: int = 64,
        head_dim: int = 64,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        state_size: int = 128,
        num_hidden_layers: int = 48,
        layer_norm_epsilon: float = 1e-5,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        expand: int = 2,
        conv_kernel: int = 4,
        n_groups: int = 1,
        use_bias: bool = False,
        use_conv_bias: bool = True,
        hidden_act: str = "silu",
        initializer_range: float = 0.1,
        residual_in_fp32: bool = True,
        time_step_rank: str = "auto",
        time_step_min: float = 0.001,
        time_step_max: float = 0.1,
        time_step_floor: float = 1e-4,
        time_step_limit=(0.0, float("inf")),
        rescale_prenorm_residual: bool = True,
        use_cache: bool = True,
        rms_norm: bool = True,
        chunk_size: int = 256,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        tie_word_embeddings: bool = False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.num_hidden_layers = num_hidden_layers
        self.layer_norm_epsilon = layer_norm_epsilon
        self.conv_kernel = conv_kernel
        self.expand = expand

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.use_bias = use_bias
        self.use_conv_bias = use_conv_bias
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        # Resolve "auto" rank here so downstream code always sees an int.
        self.time_step_rank = (
            math.ceil(self.hidden_size / 16)
            if time_step_rank == "auto"
            else time_step_rank
        )
        self.time_step_min = time_step_min
        self.time_step_max = time_step_max
        self.time_step_floor = time_step_floor
        self.rescale_prenorm_residual = rescale_prenorm_residual
        self.residual_in_fp32 = residual_in_fp32
        self.use_cache = use_cache
        self.n_groups = n_groups
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.rms_norm = rms_norm
        self.chunk_size = chunk_size
        self.time_step_limit = time_step_limit
        self.fuse_norm = fuse_norm
        self.fuse_cross_entropy = fuse_cross_entropy
        self.tie_word_embeddings = tie_word_embeddings

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/nsa/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.nsa.configuration_nsa import NSAConfig
from fla.models.nsa.modeling_nsa import NSAForCausalLM, NSAModel

# Register NSA with the transformers Auto* factories so that
# `AutoModel.from_pretrained(...)` can resolve model_type 'nsa'
# to the classes defined in this package.
AutoConfig.register(NSAConfig.model_type, NSAConfig)
AutoModel.register(NSAConfig, NSAModel)
AutoModelForCausalLM.register(NSAConfig, NSAForCausalLM)


__all__ = [
    'NSAConfig', 'NSAModel', 'NSAForCausalLM',
]
fla/models/nsa/__pycache__/configuration_nsa.cpython-311.pyc ADDED
Binary file (2.94 kB). View file
 
fla/models/nsa/__pycache__/modeling_nsa.cpython-311.pyc ADDED
Binary file (18.5 kB). View file
 
fla/models/nsa/configuration_nsa.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class NSAConfig(PretrainedConfig):
    """Configuration for NSA models.

    Stores the hyperparameters consumed by ``NSAModel`` / ``NSAForCausalLM``:
    sparse-attention geometry (heads, block size/counts, sliding window),
    MLP sizing, normalization, and kernel-fusion switches.
    """

    model_type = 'nsa'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 64,
        num_kv_heads: int = 4,
        head_dim: int = 32,
        qkv_bias: bool = False,
        block_size: int = 64,
        block_counts: Optional[int] = 16,
        window_size: Optional[int] = 512,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.006,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs,
    ):
        # Model size / vocabulary.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.max_position_embeddings = max_position_embeddings

        # Sparse-attention geometry.
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.qkv_bias = qkv_bias
        self.block_size = block_size
        self.block_counts = block_counts
        self.window_size = window_size
        self.rope_theta = rope_theta

        # Feed-forward network sizing and activation.
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        # Initialization and normalization.
        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        # Kernel-fusion switches.
        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/retnet/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.retnet.configuration_retnet import RetNetConfig
from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel

# Register RetNet with the transformers Auto* factories so that
# `AutoModel.from_pretrained(...)` can resolve model_type 'retnet'
# to the classes defined in this package.
AutoConfig.register(RetNetConfig.model_type, RetNetConfig)
AutoModel.register(RetNetConfig, RetNetModel)
AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM)


__all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel']
fla/models/retnet/__pycache__/configuration_retnet.cpython-311.pyc ADDED
Binary file (3.87 kB). View file
 
fla/models/retnet/__pycache__/modeling_retnet.cpython-311.pyc ADDED
Binary file (19.2 kB). View file
 
fla/models/retnet/configuration_retnet.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, Optional
6
+
7
+ from transformers.configuration_utils import PretrainedConfig
8
+
9
+
10
class RetNetConfig(PretrainedConfig):
    """Configuration for RetNet (Retentive Network) models.

    Stores the hyperparameters consumed by ``RetNetModel`` /
    ``RetNetForCausalLM``: multi-scale retention settings (mode, key/value
    expansion, heads), MLP sizing, normalization, kernel-fusion switches,
    and an optional ``attn`` dict describing hybrid softmax-attention layers.
    """

    model_type = 'retnet'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        expand_k: int = 1,
        expand_v: int = 2,
        hidden_ratio: Optional[int] = 2,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 8,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        hidden_act: str = "swish",
        use_short_conv: bool = False,
        conv_size: int = 4,
        use_output_gate: bool = True,
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ) -> RetNetConfig:
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.feature_map = feature_map
        # fix: `hidden_act` was previously assigned twice; the redundant
        # second assignment has been removed.
        self.hidden_act = hidden_act
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.use_output_gate = use_output_gate
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        # Kernel-fusion switches.
        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        # Validate and normalize the hybrid-attention spec: 'layers' and
        # 'num_heads' are required; the remaining keys get defaults in place.
        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/retnet/modeling_retnet.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.multiscale_retention import MultiScaleRetention
20
+ from fla.models.retnet.configuration_retnet import RetNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as RetNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
class RetNetBlock(nn.Module):
    """A single RetNet decoder layer.

    Pre-norm residual structure: RMSNorm -> token mixer (multi-scale
    retention, or softmax attention for designated hybrid layers) ->
    RMSNorm -> gated MLP.
    """

    def __init__(self, config: RetNetConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Use the fused RMSNorm implementation when `fuse_norm` is set,
        # otherwise fall back to PyTorch's nn.RMSNorm.
        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        # Hybrid models replace retention with standard attention at the
        # layer indices listed in `config.attn['layers']`.
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = MultiScaleRetention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_k=config.expand_k,
                expand_v=config.expand_v,
                num_heads=config.num_heads,
                num_kv_heads=config.num_kv_heads,
                feature_map=config.feature_map,
                use_output_gate=config.use_output_gate,
                gate_fn=config.hidden_act,
                elementwise_affine=config.elementwise_affine,
                norm_eps=config.norm_eps,
                fuse_norm=config.fuse_norm,
                layer_idx=layer_idx
            )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = RetNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one decoder layer; returns (hidden_states, attentions, past_key_values)."""

        residual = hidden_states

        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            # Fused path: mlp_norm adds the residual and normalizes in one
            # call, returning both the normalized states and the new residual.
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            # Unfused path: apply the residual add and the norm separately.
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs
109
+
110
+
111
class RetNetPreTrainedModel(PreTrainedModel):
    """Base class wiring RetNet modules into the transformers
    PreTrainedModel machinery (config class, checkpointing, sharding)."""

    config_class = RetNetConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['RetNetBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        """Initialize weights for `module`.

        Linear/Conv1d/Embedding weights get a normal init with the config's
        `initializer_range`; modules exposing `reset_parameters` delegate to
        it. Residual output projections (`o_proj`/`down_proj`) are then
        rescaled ('rescale') or zeroed ('zero') per
        `prenorm_residual_strategy`.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is not None:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                if prenorm_residual_strategy == 'rescale':
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
                elif prenorm_residual_strategy == 'zero':
                    nn.init.zeros_(p)
                else:
                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
164
+
165
+
166
class RetNetModel(RetNetPreTrainedModel):
    """The bare RetNet decoder: token embeddings, a stack of
    `RetNetBlock` layers, and a final RMSNorm."""

    def __init__(self, config: RetNetConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [RetNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of `input_ids` / `inputs_embeds` must be given.
        Attention-weight output is unsupported and silently disabled.
        Returns a `BaseModelOutputWithPast` (or an equivalent tuple when
        `return_dict` is False).
        """
        if output_attentions:
            warnings.warn(
                "`RetNetModel` does not support output attention weights now, so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Caching is disabled during training regardless of config.
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        # Normalize any legacy cache format into the fla Cache class.
        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
273
+
274
+
275
class RetNetForCausalLM(RetNetPreTrainedModel, GenerationMixin):
    """RetNet decoder with a tied-weight language-modeling head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = RetNetModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Loss is constructed lazily in forward() unless set externally.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        """Delegate to GenerationMixin.generate, rewrapping cache-related
        AttributeErrors with a more actionable message."""
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            # Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'"
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = True,
        logits_to_keep: Optional[int] = None,
        **kwargs: Unpack[Dict]
    ):
        # only last token for `inputs_ids` if the `past_key_values` is passed along.
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        # NOTE(review): if `inputs_embeds` is given while `past_key_values` is
        # None, `len(past_key_values)` below would raise TypeError — confirm
        # that generate() always supplies a Cache object here.
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute LM logits and, when `labels` is given, the next-token loss.

        During training with `fuse_cross_entropy`, logits are never
        materialized: the fused linear+CE kernel consumes the hidden states
        and the lm_head weights directly, so `logits` in the output is None.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            # Shift labels left by one; the final position gets ignore_index.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/rwkv6/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (744 Bytes). View file
 
fla/models/rwkv6/__pycache__/configuration_rwkv6.cpython-311.pyc ADDED
Binary file (3.72 kB). View file
 
fla/models/rwkv6/__pycache__/modeling_rwkv6.cpython-311.pyc ADDED
Binary file (22.2 kB). View file
 
fla/models/rwkv6/configuration_rwkv6.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class RWKV6Config(PretrainedConfig):
    """Configuration for RWKV6 models: token-mixing settings (expansion
    ratios, heads, low-rank projection dims), MLP sizing, normalization,
    and an optional `attn` dict for hybrid softmax-attention layers."""

    model_type = 'rwkv6'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        expand_k: float = 0.5,  # annotation corrected: default is fractional
        expand_v: int = 1,
        hidden_ratio: Optional[float] = 3.5,  # annotation corrected: default is fractional
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        proj_low_rank_dim: int = 32,
        gate_low_rank_dim: int = 64,
        hidden_act: str = "sqrelu",
        max_position_embeddings: int = 2048,
        norm_first: bool = True,
        norm_bias: bool = True,
        norm_eps: float = 1e-5,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.norm_first = norm_first
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.proj_low_rank_dim = proj_low_rank_dim
        self.gate_low_rank_dim = gate_low_rank_dim
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.norm_bias = norm_bias
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_norm = fuse_norm
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        # Validate and normalize the hybrid-attention spec: 'layers' and
        # 'num_heads' are required; the remaining keys get defaults in place.
        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/rwkv7/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config
from fla.models.rwkv7.modeling_rwkv7 import RWKV7ForCausalLM, RWKV7Model

# Register RWKV7 with the transformers Auto* factories; the trailing `True`
# (exist_ok) allows overriding a previously registered 'rwkv7' entry.
AutoConfig.register(RWKV7Config.model_type, RWKV7Config, True)
AutoModel.register(RWKV7Config, RWKV7Model, True)
AutoModelForCausalLM.register(RWKV7Config, RWKV7ForCausalLM, True)


__all__ = ['RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model']
fla/models/rwkv7/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (744 Bytes). View file
 
fla/models/rwkv7/__pycache__/configuration_rwkv7.cpython-311.pyc ADDED
Binary file (4.81 kB). View file
 
fla/models/rwkv7/__pycache__/modeling_rwkv7.cpython-311.pyc ADDED
Binary file (23.3 kB). View file
 
fla/models/rwkv7/configuration_rwkv7.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, List, Optional, Union
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class RWKV7Config(PretrainedConfig):
    """Configuration for RWKV7 models.

    Besides the usual sizing/normalization knobs, this derives `head_dim` /
    `num_heads` from each other when only one is given, and normalizes
    `value_dim` into a per-layer list validated against `hidden_size`.
    """

    model_type = 'rwkv7'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        head_dim: Optional[int] = 64,
        num_heads: Optional[int] = None,
        decay_low_rank_dim: int = 64,
        gate_low_rank_dim: int = 128,
        a_low_rank_dim: int = 64,
        v_low_rank_dim: int = 16,
        hidden_act: str = "sqrelu",
        max_position_embeddings: int = 2048,
        norm_first: bool = True,
        norm_bias: bool = True,
        norm_eps: float = 1e-5,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        value_dim: Optional[Union[int, List[int]]] = None,
        **kwargs
    ):
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.norm_first = norm_first
        self.num_hidden_layers = num_hidden_layers

        # Derive whichever of head_dim / num_heads was omitted; if both are
        # provided they are used as-is (no consistency check is performed).
        if head_dim is None and num_heads is not None:
            head_dim = int(hidden_size // num_heads)
        elif head_dim is not None and num_heads is None:
            num_heads = int(hidden_size // head_dim)

        # Normalize value_dim to one entry per layer, each a positive
        # multiple of hidden_size no smaller than hidden_size.
        if value_dim is None:
            value_dim = [hidden_size] * num_hidden_layers
        elif isinstance(value_dim, int):
            assert value_dim >= hidden_size, "value_dim must be greater than hidden_size"
            assert value_dim % hidden_size == 0, "value_dim must be divisible by hidden_size"
            value_dim = [value_dim] * num_hidden_layers
        else:
            assert len(value_dim) == num_hidden_layers, "value_dim must have the same length as num_hidden_layers"
            for v in value_dim:
                assert v >= hidden_size, "value_dim must be greater than hidden_size"
                assert v % hidden_size == 0, "value_dim must be divisible by hidden_size"

        self.head_dim = head_dim
        self.num_heads = num_heads
        self.value_dim = value_dim

        # Low-rank projection dims for decay / gate / a / v parameterizations.
        self.decay_low_rank_dim = decay_low_rank_dim
        self.gate_low_rank_dim = gate_low_rank_dim
        self.a_low_rank_dim = a_low_rank_dim
        self.v_low_rank_dim = v_low_rank_dim
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.norm_bias = norm_bias
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_norm = fuse_norm
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        # Validate and normalize the hybrid-attention spec: 'layers' and
        # 'num_heads' are required; the remaining keys get defaults in place.
        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/rwkv7/modeling_rwkv7.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.rwkv7 import RWKV7Attention
20
+ from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, LayerNorm
23
+ from fla.modules.activations import ACT2FN
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
class RWKV7FeedForward(nn.Module):
    """Channel-mixing feed-forward block of RWKV-7.

    Each token is interpolated with the previous token (token shift,
    weighted by the learned vector ``x_k``) before a
    key -> activation -> value projection.
    """

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'sqrelu',
        layer_idx: int = None
    ) -> RWKV7FeedForward:
        super().__init__()

        self.hidden_size = hidden_size
        # Default expansion is 4x the hidden size, rounded up to a multiple of 32.
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio)
            intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size

        # Shifts the sequence one step to the right along the time axis:
        # pads one step at the front and drops the last step.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        # Per-channel interpolation weights between current and shifted token.
        self.x_k = nn.Parameter(torch.zeros(hidden_size))

        self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

        self.layer_idx = layer_idx

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        state: Optional[Cache] = None
    ) -> torch.Tensor:
        """Run the FFN; returns ``(output, state)``.

        ``state[self.layer_idx]['ffn_state']`` caches the last token of ``x``
        so that single-token decoding can reconstruct the token shift.
        """
        if attention_mask is not None:
            # Zero out padded positions before mixing.
            x = x.mul(attention_mask[:, -x.shape[-2]:, None])
        if x.shape[1] == 1 and state is not None and state[self.layer_idx]['ffn_state'] is not None:
            # Single-token decoding: the "previous token" comes from the cache.
            shifted = state[self.layer_idx]['ffn_state'].unsqueeze(1)
        else:
            shifted = self.time_shift(x)
            if state is not None and state[self.layer_idx]['ffn_state'] is not None:
                # NOTE(review): `[-1]` indexes dim 0 of the cached state, while
                # `x[:, -1]` (stored below) is batch-first — confirm the cached
                # layout makes this select the intended per-batch last token.
                shifted[:, 0] = state[self.layer_idx]['ffn_state'][-1]
        if state is not None:
            # no need to update the offset twice
            state.update(ffn_state=x[:, -1], layer_idx=self.layer_idx, offset=0)
        return self.value(self.act_fn(self.key(x.addcmul(shifted - x, self.x_k)))), state
80
+
81
+
82
class RWKV7Block(nn.Module):
    """One RWKV-7 decoder layer: time mixing (attention) followed by
    channel mixing (feed-forward), each with a pre-norm residual."""

    def __init__(
        self,
        config: RWKV7Config,
        layer_idx: int
    ) -> RWKV7Block:
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Extra normalization of the embeddings, applied only before layer 0.
        if config.norm_first and layer_idx == 0:
            self.pre_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
                config.hidden_size,
                bias=config.norm_bias,
                eps=config.norm_eps
            )
        self.attn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
            config.hidden_size,
            bias=config.norm_bias,
            eps=config.norm_eps
        )
        # Hybrid models: layers listed in config.attn['layers'] use softmax
        # attention instead of RWKV-7 time mixing.
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = RWKV7Attention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                head_dim=config.head_dim,
                num_heads=config.num_heads,
                decay_low_rank_dim=config.decay_low_rank_dim,
                gate_low_rank_dim=config.gate_low_rank_dim,
                a_low_rank_dim=config.a_low_rank_dim,
                v_low_rank_dim=config.v_low_rank_dim,
                norm_eps=config.norm_eps,
                fuse_norm=config.fuse_norm,
                layer_idx=layer_idx,
                value_dim=config.value_dim[layer_idx]
            )
        self.ffn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
            config.hidden_size,
            bias=config.norm_bias,
            eps=config.norm_eps
        )
        self.ffn = RWKV7FeedForward(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            layer_idx=layer_idx
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        v_first: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # `v_first` is threaded through all layers (RWKV-7 value residual).
        residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states
        hidden_states = self.attn_norm(residual)
        hidden_states, attentions, past_key_values, v_first = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            v_first=v_first,
            **kwargs
        )
        # The fused LayerNorm variant folds the residual add into the norm.
        if self.config.fuse_norm:
            hidden_states, residual = self.ffn_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.ffn_norm(hidden_states)
        hidden_states, past_key_values = self.ffn(hidden_states, attention_mask, past_key_values)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values, v_first)

        return outputs
177
+
178
+
179
class RWKV7PreTrainedModel(PreTrainedModel):
    """Base class wiring RWKV-7 modules into the HF `PreTrainedModel` API
    (weight init, cache support, gradient checkpointing)."""

    config_class = RWKV7Config
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['RWKV7Block']
    _supports_cache_class = True
    _skip_keys_device_placement = ["past_key_values"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = True,
        num_residuals_per_layer: int = 2,
    ):
        # Generic fallback init only — the official RWKV-7 init is NOT
        # reproduced here, hence the warning below.
        warnings.warn(
            "RWKV-7 employs a carefully designed initialization strategy tailored to its architecture. "
            "The detailed initialization scheme is currently not implemented here but can be found in the "
            "official code repository. We emphasize that using the recommended initialization is essential "
            "for replicating the results in RWKV-7 paper. Deviations from the prescribed initialization "
            "may lead to performance degradation.\n"
            "Alternatively, please generate initial weights from the official RWKV code repository, and "
            "convert the PyTorch checkpoint into FLA supported format."
        )
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Parameter):
            # NOTE(review): `Module.apply` only visits nn.Module instances, so
            # this branch appears unreachable when called via post_init — confirm.
            nn.init.normal_(module, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
239
+
240
+
241
class RWKV7Model(RWKV7PreTrainedModel):
    """Bare RWKV-7 decoder stack: embeddings -> N RWKV7Blocks -> final norm."""

    def __init__(self, config: RWKV7Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([RWKV7Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
            config.hidden_size,
            bias=config.norm_bias,
            eps=config.norm_eps
        )

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack; returns `BaseModelOutputWithPast` (or tuple)."""
        # Attention weights are never materialized by the linear-attention kernels.
        if output_attentions:
            warnings.warn("`RWKV7Model` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        # Caching is disabled during training by default.
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        # Normalize any legacy cache format into the fla Cache class.
        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None

        # `v_first` (value-residual of the first layer) is threaded through all layers.
        v_first = torch.zeros_like(hidden_states)
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values, v_first = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    v_first,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values, v_first = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    v_first=v_first,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
350
+
351
+
352
class RWKV7ForCausalLM(RWKV7PreTrainedModel, GenerationMixin):
    """RWKV-7 decoder with a tied language-modeling head and (optionally
    fused) cross-entropy loss."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = RWKV7Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Lazily selected in forward() unless a criterion is injected externally.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        """Delegate to HF generation, translating cache-manipulation errors
        into an actionable message."""
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        # NOTE(review): this `len(past_key_values)` is not guarded against
        # past_key_values being None — confirm callers always pass a cache here.
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        labels: Optional[torch.LongTensor] = None,
        shift_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone, project to logits, and optionally compute the LM loss.

        When fused linear+CE is active during training, logits are never
        materialized and `logits` in the output is None.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        has_labels = (labels is not None) or (shift_labels is not None)
        if not (fuse_linear_and_cross_entropy and has_labels):
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if has_labels:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion

            # shift_labels: See https://github.com/huggingface/transformers/pull/36607/files.
            if shift_labels is None:
                # Shift labels left by one and pad the tail with ignore_index.
                shift_labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            shift_labels = shift_labels.to(hidden_states.device)

            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, shift_labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(shift_labels.numel(), -1), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/samba/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.samba.configuration_samba import SambaConfig
from fla.models.samba.modeling_samba import SambaBlock, SambaForCausalLM, SambaModel

# Register Samba with the HF Auto* factories so that
# AutoModel.from_pretrained / AutoConfig.from_pretrained resolve
# model_type "samba". The trailing positional True allows re-registration
# (exist_ok) when this module is imported more than once.
AutoConfig.register(SambaConfig.model_type, SambaConfig, True)
AutoModel.register(SambaConfig, SambaModel, True)
AutoModelForCausalLM.register(SambaConfig, SambaForCausalLM, True)


__all__ = ['SambaConfig', 'SambaForCausalLM', 'SambaModel', 'SambaBlock']
fla/models/samba/__pycache__/configuration_samba.cpython-311.pyc ADDED
Binary file (3.61 kB). View file
 
fla/models/samba/__pycache__/modeling_samba.cpython-311.pyc ADDED
Binary file (21.7 kB). View file
 
fla/models/samba/configuration_samba.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ from typing import Dict, Optional
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+
8
+
9
class SambaConfig(PretrainedConfig):
    """Configuration for Samba: Mamba SSM layers hybridized with
    sliding-window attention layers (indices given in ``attn['layers']``)."""

    model_type = "samba"

    def __init__(
        self,
        hidden_size: int = 2304,
        state_size: int = 16,
        num_hidden_layers: int = 18,
        norm_eps: float = 1e-5,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        expand: int = 2,
        conv_kernel: int = 4,
        use_bias: bool = False,
        use_conv_bias: bool = True,
        hidden_act: str = "swish",
        # annotation corrected: the default 0.02 is a float, not a str
        initializer_range: float = 0.02,
        residual_in_fp32: bool = False,
        time_step_rank: str = "auto",
        time_step_scale: float = 1.0,
        time_step_min: float = 0.001,
        time_step_max: float = 0.1,
        time_step_init_scheme: str = "random",
        time_step_floor: float = 1e-4,
        max_position_embeddings: int = 2048,
        # NOTE(review): mutable default argument — this dict is shared across
        # all instances constructed without an explicit `attn`; confirm nothing
        # mutates config.attn in place.
        attn: Optional[Dict] = {
            'layers': (1, 3, 5, 7, 9, 11, 13, 15, 17),
            'num_heads': 18,
            'num_kv_heads': 18,
            'qkv_bias': False,
            'window_size': 2048,
            'rope_theta': 10000.
        },
        hidden_ratio: Optional[int] = 4,
        rescale_prenorm_residual: bool = False,
        use_cache: bool = True,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        tie_word_embeddings: bool = False,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.num_hidden_layers = num_hidden_layers
        self.norm_eps = norm_eps
        self.conv_kernel = conv_kernel
        self.expand = expand
        # Mamba inner width = expand * hidden_size.
        self.intermediate_size = int(expand * self.hidden_size)
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.use_bias = use_bias
        self.use_conv_bias = use_conv_bias
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        # "auto" picks the Mamba default dt rank: ceil(hidden_size / 16).
        self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
        self.time_step_scale = time_step_scale
        self.time_step_min = time_step_min
        self.time_step_max = time_step_max
        self.time_step_init_scheme = time_step_init_scheme
        self.time_step_floor = time_step_floor
        self.max_position_embeddings = max_position_embeddings
        self.attn = attn
        self.hidden_ratio = hidden_ratio
        self.rescale_prenorm_residual = rescale_prenorm_residual
        self.residual_in_fp32 = residual_in_fp32
        self.use_cache = use_cache

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )
fla/models/samba/modeling_samba.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_utils import PreTrainedModel
14
+ from transformers.utils import ModelOutput, logging
15
+ from transformers.utils.deprecation import deprecate_kwarg
16
+
17
+ from fla.layers.attn import Attention
18
+ from fla.models.mamba.modeling_mamba import MambaCache, MambaMixer
19
+ from fla.models.samba.configuration_samba import SambaConfig
20
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
21
+ from fla.modules import GatedMLP as SambaMLP
22
+ from fla.modules import RMSNorm
23
+
24
+ if TYPE_CHECKING:
25
+ from transformers.processing_utils import Unpack
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
class SambaBlock(nn.Module):
    """One Samba layer: a mixer (MambaMixer or sliding-window Attention,
    chosen per layer index) followed by a gated MLP, each with a pre-norm
    residual."""

    def __init__(self, config, layer_idx):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.mixer_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        # Layers listed in config.attn['layers'] use attention; all others use Mamba.
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.mixer = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.mixer = MambaMixer(config, layer_idx=layer_idx)
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = SambaMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[Tuple[torch.Tensor]] = None,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # NOTE(review): only hidden_states is returned; the cache objects are
        # presumably updated in place by the mixers — confirm for both branches.
        residual = hidden_states
        hidden_states = self.mixer_norm(hidden_states)
        if isinstance(self.mixer, MambaMixer):
            hidden_states = self.mixer(hidden_states, cache_params=cache_params, **kwargs)
        else:
            hidden_states, _, cache_params = self.mixer(hidden_states=hidden_states, past_key_values=cache_params, **kwargs)
        # The fused RMSNorm variant folds the residual add into the norm call.
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states
        return hidden_states
81
+
82
+
83
class SambaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SambaConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["SambaBlock"]
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of one submodule.

        Invoked once per submodule by `PreTrainedModel.post_init()` via
        `Module.apply`. Linear/embedding weights get a normal init; MambaMixer
        gets the Mamba-specific dt projection init.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                # dt_proj.bias is flagged `_no_reinit` below so its
                # softplus-inverse initialization survives re-application.
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, MambaMixer):
            # SSM state parameters are excluded from weight decay.
            module.A_log._no_weight_decay = True
            module.D._no_weight_decay = True

            dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
            if self.config.time_step_init_scheme == "constant":
                nn.init.constant_(module.dt_proj.weight, dt_init_std)
            elif self.config.time_step_init_scheme == "random":
                nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)

            # Sample dt log-uniformly in [time_step_min, time_step_max], then
            # store its inverse-softplus so softplus(bias) recovers dt.
            dt = torch.exp(
                torch.rand(self.config.intermediate_size)
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
                + math.log(self.config.time_step_min)
            ).clamp(min=self.config.time_step_floor)
            # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                module.dt_proj.bias.data = nn.Parameter(inv_dt.to(module.dt_proj.bias.device))
            module.dt_proj.bias._no_reinit = True
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if self.config.rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        # Bug fix: SambaConfig defines `num_hidden_layers`;
                        # `num_layers` does not exist and raised AttributeError
                        # whenever rescale_prenorm_residual=True.
                        p /= math.sqrt(self.config.num_hidden_layers)
142
+
143
+
144
@dataclass
class SambaOutput(ModelOutput):
    """
    Class for the Samba model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    # Final-layer hidden states.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # SSM + convolutional states for incremental decoding.
    cache_params: Optional[MambaCache] = None
    # Per-layer hidden states when output_hidden_states=True.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
168
+
169
+
170
+ @dataclass
171
+ class SambaCausalLMOutput(ModelOutput):
172
+ """
173
+ Base class for causal language model (or autoregressive) outputs.
174
+
175
+ Args:
176
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
177
+ Language modeling loss (for next-token prediction).
178
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
179
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
180
+ cache_params (`MambaCache`):
181
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
182
+ avoid providing the old `input_ids`.
183
+
184
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
185
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*,
186
+ returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
187
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
188
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
189
+
190
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
191
+ """
192
+
193
+ loss: Optional[torch.FloatTensor] = None
194
+ logits: Optional[torch.FloatTensor] = None
195
+ cache_params: Optional[MambaCache] = None
196
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
197
+
198
+
199
+ class SambaModel(SambaPreTrainedModel):
200
+ def __init__(self, config):
201
+ super().__init__(config)
202
+
203
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
204
+ self.layers = nn.ModuleList([SambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
205
+
206
+ self.gradient_checkpointing = False
207
+ self.norm_f = RMSNorm(config.hidden_size, eps=config.norm_eps)
208
+ # Initialize weights and apply final processing
209
+ self.post_init()
210
+
211
+ def get_input_embeddings(self):
212
+ return self.embeddings
213
+
214
+ def set_input_embeddings(self, new_embeddings):
215
+ self.embeddings = new_embeddings
216
+
217
+ def forward(
218
+ self,
219
+ input_ids: Optional[torch.LongTensor] = None,
220
+ inputs_embeds: Optional[torch.LongTensor] = None,
221
+ cache_params: Optional[MambaCache] = None,
222
+ use_cache: Optional[bool] = None,
223
+ output_hidden_states: Optional[bool] = None,
224
+ return_dict: Optional[bool] = None,
225
+ **kwargs: Unpack[Dict]
226
+ ) -> Union[Tuple, SambaOutput]:
227
+ output_hidden_states = (
228
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
229
+ )
230
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
231
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
232
+
233
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
234
+ raise ValueError(
235
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
236
+ )
237
+
238
+ if inputs_embeds is None:
239
+ inputs_embeds = self.embeddings(input_ids)
240
+
241
+ if self.gradient_checkpointing and self.training and use_cache:
242
+ use_cache = False
243
+
244
+ if cache_params is None and use_cache:
245
+ cache_params = MambaCache(
246
+ self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
247
+ )
248
+
249
+ hidden_states = inputs_embeds
250
+ all_hidden_states = () if output_hidden_states else None
251
+ for mixer_block in self.layers:
252
+ if self.gradient_checkpointing and self.training:
253
+ hidden_states = self._gradient_checkpointing_func(
254
+ mixer_block.__call__,
255
+ hidden_states,
256
+ cache_params,
257
+ **kwargs
258
+ )
259
+ else:
260
+ hidden_states = mixer_block(
261
+ hidden_states,
262
+ cache_params=cache_params,
263
+ **kwargs
264
+ )
265
+
266
+ if output_hidden_states:
267
+ all_hidden_states = all_hidden_states + (hidden_states,)
268
+
269
+ if use_cache:
270
+ cache_params.seqlen_offset += inputs_embeds.shape[1]
271
+
272
+ hidden_states = self.norm_f(hidden_states)
273
+
274
+ if output_hidden_states:
275
+ all_hidden_states = all_hidden_states + (hidden_states,)
276
+
277
+ if not return_dict:
278
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
279
+
280
+ return SambaOutput(
281
+ last_hidden_state=hidden_states,
282
+ cache_params=cache_params if use_cache else None,
283
+ hidden_states=all_hidden_states,
284
+ )
285
+
286
+
287
+ class SambaForCausalLM(SambaPreTrainedModel, GenerationMixin):
288
+
289
+ _tied_weights_keys = ["lm_head.weight"]
290
+
291
+ def __init__(self, config):
292
+ super().__init__(config)
293
+ self.backbone = SambaModel(config)
294
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
295
+ self.criterion = None
296
+
297
+ # Initialize weights and apply final processing
298
+ self.post_init()
299
+
300
+ def get_output_embeddings(self):
301
+ return self.lm_head
302
+
303
+ def set_output_embeddings(self, new_embeddings):
304
+ self.lm_head = new_embeddings
305
+
306
+ def get_input_embeddings(self):
307
+ return self.backbone.get_input_embeddings()
308
+
309
+ def set_input_embeddings(self, new_embeddings):
310
+ return self.backbone.set_input_embeddings(new_embeddings)
311
+
312
+ def _update_model_kwargs_for_generation(
313
+ self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
314
+ ) -> Dict[str, Any]:
315
+ model_kwargs["cache_params"] = outputs.get("cache_params", None)
316
+ return model_kwargs
317
+
318
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
319
+ def prepare_inputs_for_generation(
320
+ self,
321
+ input_ids,
322
+ cache_params:
323
+ Optional[MambaCache] = None,
324
+ inputs_embeds=None,
325
+ attention_mask=None,
326
+ use_cache: Optional[bool] = True,
327
+ logits_to_keep: Optional[int] = None,
328
+ **kwargs: Unpack[Dict]
329
+ ):
330
+ # only last token for inputs_ids if the state is passed along.
331
+ if cache_params is not None:
332
+ input_ids = input_ids[:, -1].unsqueeze(-1)
333
+
334
+ if inputs_embeds is not None and cache_params is None:
335
+ model_inputs = {"inputs_embeds": inputs_embeds}
336
+ else:
337
+ model_inputs = {"input_ids": input_ids}
338
+
339
+ if logits_to_keep is not None:
340
+ model_inputs['logits_to_keep'] = logits_to_keep
341
+
342
+ model_inputs.update({
343
+ 'cache_params': cache_params,
344
+ 'use_cache': use_cache,
345
+ 'attention_mask': attention_mask,
346
+ 'logits_to_keep': logits_to_keep,
347
+ })
348
+ return model_inputs
349
+
350
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
351
+ def forward(
352
+ self,
353
+ input_ids: Optional[torch.LongTensor] = None,
354
+ attention_mask: Optional[torch.Tensor] = None, # noqa
355
+ inputs_embeds: Optional[torch.FloatTensor] = None,
356
+ cache_params: Optional[MambaCache] = None,
357
+ labels: Optional[torch.LongTensor] = None,
358
+ output_hidden_states: Optional[bool] = None,
359
+ return_dict: Optional[bool] = None,
360
+ use_cache: Optional[bool] = None,
361
+ logits_to_keep: Optional[int] = 0,
362
+ **kwargs: Unpack[Dict]
363
+ ) -> Union[Tuple, SambaCausalLMOutput]:
364
+ r"""
365
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
366
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
367
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
368
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
369
+ """
370
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
371
+
372
+ outputs = self.backbone(
373
+ input_ids,
374
+ cache_params=cache_params,
375
+ inputs_embeds=inputs_embeds,
376
+ output_hidden_states=output_hidden_states,
377
+ return_dict=return_dict,
378
+ use_cache=use_cache,
379
+ **kwargs
380
+ )
381
+ hidden_states = outputs[0]
382
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
383
+
384
+ loss, logits = None, None
385
+ if not fuse_linear_and_cross_entropy or labels is None:
386
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
387
+ if labels is not None:
388
+ if getattr(self, 'criterion', None) is None:
389
+ if fuse_linear_and_cross_entropy:
390
+ criterion = FusedLinearCrossEntropyLoss()
391
+ elif self.config.fuse_cross_entropy:
392
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
393
+ else:
394
+ criterion = nn.CrossEntropyLoss()
395
+ else:
396
+ criterion = self.criterion
397
+ labels = labels.to(hidden_states.device)
398
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
399
+ if fuse_linear_and_cross_entropy:
400
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
401
+ else:
402
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
403
+
404
+ if not return_dict:
405
+ output = (logits,) + outputs[1:]
406
+ return (loss,) + output if loss is not None else output
407
+
408
+ return SambaCausalLMOutput(
409
+ loss=loss,
410
+ logits=logits,
411
+ cache_params=outputs.cache_params,
412
+ hidden_states=outputs.hidden_states,
413
+ )
fla/models/transformer/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.transformer.configuration_transformer import TransformerConfig
6
+ from fla.models.transformer.modeling_transformer import TransformerForCausalLM, TransformerModel
7
+
8
+ AutoConfig.register(TransformerConfig.model_type, TransformerConfig)
9
+ AutoModel.register(TransformerConfig, TransformerModel)
10
+ AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)
11
+
12
+
13
+ __all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel']
fla/models/transformer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (785 Bytes). View file
 
fla/models/transformer/__pycache__/configuration_transformer.cpython-311.pyc ADDED
Binary file (2.87 kB). View file
 
fla/models/transformer/__pycache__/modeling_transformer.cpython-311.pyc ADDED
Binary file (19.5 kB). View file
 
fla/models/transformer/configuration_transformer.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class TransformerConfig(PretrainedConfig):
9
+
10
+ model_type = 'transformer'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ hidden_size: int = 2048,
16
+ num_hidden_layers: int = 24,
17
+ num_heads: int = 32,
18
+ num_kv_heads: int = None,
19
+ qkv_bias: bool = False,
20
+ qk_norm: bool = False,
21
+ window_size: Optional[int] = None,
22
+ rope_theta: Optional[float] = 10000.,
23
+ max_position_embeddings: int = 2048,
24
+ hidden_ratio: Optional[int] = 4,
25
+ intermediate_size: Optional[int] = None,
26
+ hidden_act: str = "swish",
27
+ initializer_range: float = 0.006,
28
+ elementwise_affine: Optional[bool] = True,
29
+ norm_eps: float = 1e-6,
30
+ use_cache: bool = True,
31
+ pad_token_id: int = None,
32
+ bos_token_id: int = 1,
33
+ eos_token_id: int = 2,
34
+ tie_word_embeddings: bool = False,
35
+ fuse_norm: bool = True,
36
+ fuse_swiglu: bool = True,
37
+ fuse_cross_entropy: bool = True,
38
+ vocab_size: int = 32000,
39
+ use_myopic_loss: bool = False,
40
+ **kwargs,
41
+ ):
42
+ self.hidden_size = hidden_size
43
+ self.num_hidden_layers = num_hidden_layers
44
+ self.num_heads = num_heads
45
+ self.num_kv_heads = num_kv_heads
46
+ self.qkv_bias = qkv_bias
47
+ self.qk_norm = qk_norm
48
+ self.window_size = window_size
49
+ self.rope_theta = rope_theta
50
+ self.max_position_embeddings = max_position_embeddings
51
+
52
+ self.hidden_ratio = hidden_ratio
53
+ self.intermediate_size = intermediate_size
54
+ self.hidden_act = hidden_act
55
+
56
+ self.initializer_range = initializer_range
57
+ self.elementwise_affine = elementwise_affine
58
+ self.norm_eps = norm_eps
59
+ self.use_cache = use_cache
60
+
61
+ self.fuse_norm = fuse_norm
62
+ self.fuse_swiglu = fuse_swiglu
63
+ self.fuse_cross_entropy = fuse_cross_entropy
64
+ self.vocab_size = vocab_size
65
+
66
+ self.use_myopic_loss = use_myopic_loss
67
+
68
+ super().__init__(
69
+ pad_token_id=pad_token_id,
70
+ bos_token_id=bos_token_id,
71
+ eos_token_id=eos_token_id,
72
+ tie_word_embeddings=tie_word_embeddings,
73
+ **kwargs,
74
+ )
fla/models/transformer/modeling_transformer.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from dataclasses import dataclass
14
+ from transformers.generation import GenerationMixin
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
+ from transformers.utils.deprecation import deprecate_kwarg
19
+
20
+ import triton
21
+ import triton.language as tl
22
+
23
+ from fla.layers.attn import Attention
24
+ from fla.models.transformer.configuration_transformer import TransformerConfig
25
+ from fla.models.utils import Cache
26
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, FusedLinearListNetLoss
27
+ from fla.modules import GatedMLP as TransformerMLP
28
+ from fla.modules import RMSNorm
29
+ from fla.modules.seq_to_myopic import seq_to_myopic
30
+
31
+ if TYPE_CHECKING:
32
+ from transformers.processing_utils import Unpack
33
+
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+ @dataclass
38
+ class TOPLMOutputWithPast(CausalLMOutputWithPast):
39
+ ntp_loss: Optional[torch.FloatTensor] = None
40
+ top_loss: Optional[torch.FloatTensor] = None
41
+
42
+ class TransformerBlock(nn.Module):
43
+
44
+ def __init__(self, config: TransformerConfig, layer_idx: int):
45
+ super().__init__()
46
+
47
+ self.config = config
48
+ self.layer_idx = layer_idx
49
+
50
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
51
+ self.attn = Attention(
52
+ hidden_size=config.hidden_size,
53
+ num_heads=config.num_heads,
54
+ num_kv_heads=config.num_kv_heads,
55
+ qkv_bias=config.qkv_bias,
56
+ qk_norm=config.qk_norm,
57
+ window_size=config.window_size,
58
+ rope_theta=config.rope_theta,
59
+ max_position_embeddings=config.max_position_embeddings,
60
+ layer_idx=layer_idx
61
+ )
62
+
63
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
64
+ self.mlp = TransformerMLP(
65
+ hidden_size=config.hidden_size,
66
+ hidden_ratio=config.hidden_ratio,
67
+ intermediate_size=config.intermediate_size,
68
+ hidden_act=config.hidden_act,
69
+ fuse_swiglu=config.fuse_swiglu
70
+ )
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.Tensor] = None,
76
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
77
+ output_attentions: Optional[bool] = False,
78
+ use_cache: Optional[bool] = False,
79
+ **kwargs: Unpack[Any]
80
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
81
+
82
+ residual = hidden_states
83
+ hidden_states = self.attn_norm(hidden_states)
84
+ hidden_states, attentions, past_key_values = self.attn(
85
+ hidden_states=hidden_states,
86
+ attention_mask=attention_mask,
87
+ past_key_values=past_key_values,
88
+ use_cache=use_cache,
89
+ output_attentions=output_attentions,
90
+ **kwargs
91
+ )
92
+ if self.config.fuse_norm:
93
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
94
+ else:
95
+ hidden_states = residual + hidden_states
96
+ residual = hidden_states
97
+ hidden_states = self.mlp_norm(hidden_states)
98
+ hidden_states = self.mlp(hidden_states, **kwargs)
99
+ hidden_states = residual + hidden_states
100
+
101
+ outputs = (hidden_states,)
102
+
103
+ if output_attentions:
104
+ outputs += (attentions,)
105
+
106
+ if use_cache:
107
+ outputs += (past_key_values,)
108
+
109
+ return outputs
110
+
111
+
112
+ class TransformerPreTrainedModel(PreTrainedModel):
113
+
114
+ config_class = TransformerConfig
115
+ base_model_prefix = 'model'
116
+ supports_gradient_checkpointing = True
117
+ _no_split_modules = ['TransformerBlock']
118
+ _supports_cache_class = True
119
+
120
+ def __init__(self, *inputs, **kwargs):
121
+ super().__init__(*inputs, **kwargs)
122
+
123
+ def _init_weights(
124
+ self,
125
+ module: nn.Module,
126
+ rescale_prenorm_residual: bool = False,
127
+ num_residuals_per_layer: int = 2,
128
+ ):
129
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
130
+ # Slightly different from the TF version which uses truncated_normal for initialization
131
+ # cf https://github.com/pytorch/pytorch/pull/5617
132
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
133
+ if module.bias is not None:
134
+ nn.init.zeros_(module.bias)
135
+ elif isinstance(module, nn.Embedding):
136
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
137
+ elif hasattr(module, 'reset_parameters'):
138
+ module.reset_parameters()
139
+
140
+ if rescale_prenorm_residual:
141
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
142
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
143
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
144
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
145
+ #
146
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
147
+ p = None
148
+ if hasattr(module, 'o_proj'):
149
+ p = module.o_proj.weight
150
+ elif hasattr(module, 'down_proj'):
151
+ p = module.down_proj.weight
152
+ if p is not None:
153
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
154
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
155
+ # We need to reinit p since this code could be called multiple times
156
+ # Having just p *= scale would repeatedly scale it down
157
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
158
+ with torch.no_grad():
159
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
160
+
161
+
162
+ class TransformerModel(TransformerPreTrainedModel):
163
+
164
+ def __init__(
165
+ self,
166
+ config: TransformerConfig
167
+ ) -> TransformerModel:
168
+ super().__init__(config)
169
+ self.padding_idx = config.pad_token_id
170
+ self.vocab_size = config.vocab_size
171
+
172
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
173
+ self.layers = nn.ModuleList([TransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
174
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
175
+
176
+ self.gradient_checkpointing = False
177
+
178
+ self.post_init()
179
+
180
+ def get_input_embeddings(self):
181
+ return self.embeddings
182
+
183
+ def set_input_embeddings(self, value):
184
+ self.embeddings = value
185
+
186
+ def forward(
187
+ self,
188
+ input_ids: Optional[torch.LongTensor] = None,
189
+ attention_mask: Optional[torch.Tensor] = None,
190
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
191
+ inputs_embeds: Optional[torch.FloatTensor] = None,
192
+ use_cache: Optional[bool] = None,
193
+ output_attentions: Optional[bool] = None,
194
+ output_hidden_states: Optional[bool] = None,
195
+ return_dict: Optional[bool] = None,
196
+ **kwargs: Unpack[Any]
197
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
198
+ if output_attentions:
199
+ warnings.warn(
200
+ "`TransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`."
201
+ )
202
+ output_attentions = False
203
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
204
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
205
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
206
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
207
+
208
+ # retrieve input_ids and inputs_embeds
209
+ if input_ids is not None and inputs_embeds is not None:
210
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
211
+ elif input_ids is None and inputs_embeds is None:
212
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
213
+
214
+ if use_cache and not isinstance(past_key_values, Cache):
215
+ past_key_values = Cache.from_legacy_cache(past_key_values)
216
+
217
+ if inputs_embeds is None:
218
+ inputs_embeds = self.embeddings(input_ids)
219
+
220
+ # embed positions
221
+ hidden_states = inputs_embeds
222
+
223
+ if self.gradient_checkpointing and self.training:
224
+ if use_cache:
225
+ logger.warning_once(
226
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
227
+ )
228
+ use_cache = False
229
+
230
+ all_hidden_states = () if output_hidden_states else None
231
+ all_attns = () if output_attentions else None
232
+ next_cache = None
233
+
234
+ for layer in self.layers:
235
+ if output_hidden_states:
236
+ all_hidden_states += (hidden_states,)
237
+
238
+ if self.gradient_checkpointing and self.training:
239
+ layer_outputs = self._gradient_checkpointing_func(
240
+ layer.__call__,
241
+ hidden_states,
242
+ attention_mask,
243
+ past_key_values,
244
+ output_attentions,
245
+ use_cache,
246
+ **kwargs
247
+ )
248
+ else:
249
+ layer_outputs = layer(
250
+ hidden_states,
251
+ attention_mask=attention_mask,
252
+ past_key_values=past_key_values,
253
+ output_attentions=output_attentions,
254
+ use_cache=use_cache,
255
+ **kwargs
256
+ )
257
+
258
+ hidden_states = layer_outputs[0]
259
+
260
+ if use_cache:
261
+ next_cache = layer_outputs[2 if output_attentions else 1]
262
+
263
+ if output_attentions:
264
+ all_attns += (layer_outputs[1],)
265
+
266
+ hidden_states = self.norm(hidden_states)
267
+
268
+ # add hidden states from the last decoder layer
269
+ if output_hidden_states:
270
+ all_hidden_states += (hidden_states,)
271
+
272
+ if not return_dict:
273
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)
274
+
275
+ return BaseModelOutputWithPast(
276
+ last_hidden_state=hidden_states,
277
+ past_key_values=next_cache,
278
+ hidden_states=all_hidden_states,
279
+ attentions=all_attns
280
+ )
281
+
282
+
283
+ class TransformerForCausalLM(TransformerPreTrainedModel, GenerationMixin):
284
+
285
+ _tied_weights_keys = ["lm_head.weight"]
286
+
287
+ def __init__(self, config):
288
+ super().__init__(config)
289
+ self.model = TransformerModel(config)
290
+ self.vocab_size = config.vocab_size
291
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
292
+ if config.use_myopic_loss:
293
+ self.myopic_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
294
+ self.myopic_criterion = FusedLinearListNetLoss()
295
+ self.criterion = None
296
+ self.pad_token_id = config.pad_token_id
297
+
298
+ # Initialize weights and apply final processing
299
+ self.post_init()
300
+
301
+ def get_input_embeddings(self):
302
+ return self.model.embeddings
303
+
304
+ def set_input_embeddings(self, value):
305
+ self.model.embeddings = value
306
+
307
+ def get_output_embeddings(self):
308
+ return self.lm_head
309
+
310
+ def set_output_embeddings(self, new_embeddings):
311
+ self.lm_head = new_embeddings
312
+
313
+ def set_decoder(self, decoder):
314
+ self.model = decoder
315
+
316
+ def get_decoder(self):
317
+ return self.model
318
+
319
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
320
+ def prepare_inputs_for_generation(
321
+ self,
322
+ input_ids: torch.LongTensor = None,
323
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
324
+ attention_mask: Optional[torch.Tensor] = None,
325
+ inputs_embeds: Optional[torch.Tensor] = None,
326
+ use_cache: bool = True,
327
+ logits_to_keep: Optional[int] = None,
328
+ **kwargs
329
+ ):
330
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
331
+ if past_key_values is not None and len(past_key_values) > 0:
332
+ input_ids = input_ids[:, -1:]
333
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
334
+ if inputs_embeds is not None and len(past_key_values) == 0:
335
+ model_inputs = {'inputs_embeds': inputs_embeds}
336
+ else:
337
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
338
+ # recompiles graphs as the stride of the inputs is a guard.
339
+ # Ref: https://github.com/huggingface/transformers/pull/29114
340
+ # TODO: use `next_tokens` directly instead.
341
+ model_inputs = {'input_ids': input_ids.contiguous()}
342
+
343
+ if logits_to_keep is not None:
344
+ model_inputs['logits_to_keep'] = logits_to_keep
345
+
346
+ model_inputs.update({
347
+ 'past_key_values': past_key_values,
348
+ 'use_cache': use_cache,
349
+ 'attention_mask': attention_mask,
350
+ })
351
+ return model_inputs
352
+
353
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
354
+ def forward(
355
+ self,
356
+ input_ids: torch.LongTensor = None,
357
+ attention_mask: Optional[torch.Tensor] = None,
358
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
359
+ inputs_embeds: Optional[torch.FloatTensor] = None,
360
+ labels: Optional[torch.LongTensor] = None,
361
+ use_cache: Optional[bool] = None,
362
+ output_attentions: Optional[bool] = None,
363
+ output_hidden_states: Optional[bool] = None,
364
+ return_dict: Optional[bool] = None,
365
+ logits_to_keep: Optional[int] = 0,
366
+ **kwargs: Unpack[Any]
367
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
368
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
369
+ output_hidden_states = (
370
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
371
+ )
372
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
373
+
374
+ outputs = self.model(
375
+ input_ids=input_ids,
376
+ attention_mask=attention_mask,
377
+ past_key_values=past_key_values,
378
+ inputs_embeds=inputs_embeds,
379
+ use_cache=use_cache,
380
+ output_attentions=output_attentions,
381
+ output_hidden_states=output_hidden_states,
382
+ return_dict=return_dict,
383
+ **kwargs
384
+ )
385
+
386
+ hidden_states = outputs[0]
387
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
388
+ logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
389
+
390
+ loss = None
391
+ ntp_loss = None
392
+ myopic_loss = None
393
+ if labels is not None:
394
+ if getattr(self, 'criterion', None) is None:
395
+ if fuse_linear_and_cross_entropy:
396
+ criterion = FusedLinearCrossEntropyLoss()
397
+ elif self.config.fuse_cross_entropy:
398
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
399
+ else:
400
+ criterion = nn.CrossEntropyLoss()
401
+ else:
402
+ criterion = self.criterion
403
+ # Enable model parallelism
404
+ labels = labels.to(hidden_states.device)
405
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
406
+ ntp_labels = labels[..., :hidden_states.shape[1]].contiguous()
407
+ if fuse_linear_and_cross_entropy:
408
+ ntp_loss = criterion(hidden_states, ntp_labels, self.lm_head.weight, self.lm_head.bias)
409
+ else:
410
+ ntp_loss = criterion(logits.view(ntp_labels.numel(), -1), ntp_labels.reshape(-1))
411
+
412
+ if self.config.use_myopic_loss:
413
+ myopic_labels = seq_to_myopic(labels, self.vocab_size, hidden_states.shape[1], pad_token_id=self.pad_token_id).contiguous()
414
+ myopic_loss = self.myopic_criterion(hidden_states, myopic_labels, self.myopic_head.weight, self.myopic_head.bias)
415
+ # print(f"NTP Loss: {ntp_loss.item()}, Myopic Loss: {myopic_loss.item()}")
416
+ # For debugging, get the index where the myopic label is the highest and print the corresponding logits
417
+ # idx_max = torch.argmax(myopic_labels.view(-1, self.vocab_size), dim=1)
418
+ # # Print the labels and logits at that index
419
+ # print(f"Labels: {myopic_labels.view(-1, self.vocab_size)[0, idx_max[0]-3:idx_max[0]+3]}")
420
+ # print(f"Logits: {F.sigmoid(myopic_logits).view(-1, self.vocab_size)[0, idx_max[0]-3:idx_max[0]+3]}")
421
+ loss = ntp_loss + myopic_loss
422
+ else:
423
+ loss = ntp_loss
424
+
425
+ if not return_dict:
426
+ output = (logits,) + outputs[1:]
427
+ return (loss,) + output if loss is not None else output
428
+
429
+ return TOPLMOutputWithPast(
430
+ loss=loss,
431
+ ntp_loss=ntp_loss,
432
+ top_loss=myopic_loss,
433
+ logits=logits,
434
+ past_key_values=outputs.past_key_values,
435
+ hidden_states=outputs.hidden_states,
436
+ attentions=outputs.attentions,
437
+ )
fla/models/transformer_mtp/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.transformer_mtp.configuration_transformer import MTPTransformerConfig
6
+ from fla.models.transformer_mtp.modeling_transformer import MTPTransformerForCausalLM, MTPTransformerModel
7
+
8
+ AutoConfig.register(MTPTransformerConfig.model_type, MTPTransformerConfig)
9
+ AutoModel.register(MTPTransformerConfig, MTPTransformerModel)
10
+ AutoModelForCausalLM.register(MTPTransformerConfig, MTPTransformerForCausalLM)
11
+
12
+
13
+ __all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel']
fla/models/transformer_mtp/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (852 Bytes). View file
 
fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-311.pyc ADDED
Binary file (2.98 kB). View file
 
fla/models/transformer_mtp/__pycache__/modeling_transformer.cpython-311.pyc ADDED
Binary file (26 kB). View file
 
fla/models/transformer_mtp/configuration_transformer.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class MTPTransformerConfig(PretrainedConfig):
    """Configuration for the multi-token-prediction (MTP) Transformer.

    Mirrors the vanilla Transformer configuration and adds two MTP-specific
    options: ``n_future_tokens`` (the number of prediction heads, i.e. how many
    future tokens each position predicts) and ``use_custom_backward`` (whether
    training routes through the memory-saving sequential-heads backward pass).
    """

    model_type = 'mtp_transformer'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: int = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.006,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        n_future_tokens: int = 1,
        use_custom_backward: Optional[bool] = False,
        **kwargs,
    ):
        # Vocabulary / embedding table size.
        self.vocab_size = vocab_size

        # Backbone geometry: attention heads, layers and rotary embeddings.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Feed-forward block sizing and activation.
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        # Normalization, initialization and KV-cache behavior.
        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        # Kernel-fusion toggles (fused RMSNorm / SwiGLU / cross-entropy).
        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy

        # Multi-token-prediction specifics.
        self.n_future_tokens = n_future_tokens
        self.use_custom_backward = use_custom_backward

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/transformer_mtp/modeling_transformer.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from dataclasses import dataclass
14
+ from transformers.generation import GenerationMixin
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
+ from transformers.utils.deprecation import deprecate_kwarg
19
+
20
+ import triton
21
+ import triton.language as tl
22
+
23
+ from fla.layers.attn import Attention
24
+ from fla.models.transformer_mtp.configuration_transformer import MTPTransformerConfig
25
+ from fla.models.utils import Cache
26
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
27
+ from fla.modules import GatedMLP as TransformerMLP
28
+ from fla.modules import RMSNorm
29
+
30
+ if TYPE_CHECKING:
31
+ from transformers.processing_utils import Unpack
32
+
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ class SequentialHeadsCustomBackward(torch.autograd.Function):
37
+ @staticmethod
38
+ def forward(ctx, trunk_output, lm_head, norm_layer, logits_to_keep, *prediction_heads):
39
+ # We now need the norm layer in the forward pass calculation
40
+ ctx.prediction_heads = prediction_heads
41
+ ctx.lm_head = lm_head
42
+ ctx.norm_layer = norm_layer
43
+ ctx.logits_to_keep = logits_to_keep
44
+ ctx.save_for_backward(trunk_output)
45
+
46
+ latents = []
47
+ for head in prediction_heads:
48
+ # Assuming head forward signature is `head(hidden_states)`
49
+ latent = head(trunk_output)[0]
50
+ latents.append(latent)
51
+
52
+ latents_stacked = torch.stack(latents, dim=-2)
53
+ # Apply the final norm before the lm_head
54
+ normalized_latents = norm_layer(latents_stacked)
55
+ all_logits = lm_head(normalized_latents[:, -logits_to_keep:])
56
+ return all_logits
57
+
58
+ @staticmethod
59
+ def backward(ctx, grad_output):
60
+ trunk_output, = ctx.saved_tensors
61
+ prediction_heads = ctx.prediction_heads
62
+ lm_head = ctx.lm_head
63
+ norm_layer = ctx.norm_layer
64
+ logits_to_keep = ctx.logits_to_keep
65
+
66
+ d = trunk_output.detach().requires_grad_(True)
67
+ grad_output_per_head = grad_output.unbind(dim=2)
68
+
69
+ # We need to manually handle the backward pass for the final norm layer once
70
+ # before the loop, as its gradient depends on all heads.
71
+ # To do this, we reconstruct the input to the lm_head and do a backward pass.
72
+ with torch.enable_grad():
73
+ # Re-run the head computations to get the input to the norm layer
74
+ latents = []
75
+ for head in prediction_heads:
76
+ latents.append(head(d)[0])
77
+ latents_stacked = torch.stack(latents, dim=-2)
78
+ latents_stacked.requires_grad_(True)
79
+ # The part of the graph we need to backprop through first
80
+ normalized_latents = norm_layer(latents_stacked)
81
+
82
+ # Backpropagate through the lm_head and norm_layer
83
+ normalized_latents.backward(lm_head.weight.grad @ grad_output)
84
+
85
+ # Now, `latents_stacked.grad` contains the sum of gradients from all heads
86
+ # just before the final normalization. We can now unbind it.
87
+ grad_per_head_latent = latents_stacked.grad.unbind(dim=-2)
88
+
89
+ # Now, backpropagate through each head individually.
90
+ for i, head in enumerate(prediction_heads):
91
+ with torch.enable_grad():
92
+ head_latent = head(d)[0]
93
+ # Backpropagate using the gradient for this specific head's output
94
+ head_latent.backward(gradient=grad_per_head_latent[i])
95
+
96
+ num_nones = 2 + len(prediction_heads) # for lm_head, norm_layer, and *prediction_heads
97
+ return (d.grad,) + (None,) * num_nones
98
+
99
def seq_to_mtp(
    long_input_ids: torch.Tensor,
    model_seq_len: int,
    n_future_tokens: int
) -> torch.Tensor:
    """Generate multi-token-prediction targets on the fly from a long sequence.

    Assumes `long_input_ids` contains both the tokens for the model's input
    AND the extra future tokens needed as labels, so no artificial padding is
    added.

    Args:
        long_input_ids (torch.Tensor): Dataloader sequences of shape
            (B, >= model_seq_len + n_future_tokens).
        model_seq_len (int): The sequence length `T` that the model processes.
        n_future_tokens (int): Number of future tokens to predict per position.

    Returns:
        torch.Tensor: Targets of shape (B, n_future_tokens, T), where
        out[b, k, t] is the (k+1)-th token after long_input_ids[b, t].
        (Note: head axis first, matching the per-head loss loop that indexes
        `all_labels[:, i, :]`.)

    Raises:
        ValueError: If the input is too short to supply all targets.
    """
    B, total_len = long_input_ids.shape
    # Explicit check instead of `assert`: asserts are stripped under `-O`.
    if total_len < model_seq_len + n_future_tokens:
        raise ValueError(
            "long_input_ids must be at least model_seq_len + n_future_tokens long."
        )

    # Sliding windows of size `n_future_tokens + 1`: window t holds the input
    # token at position t followed by its n_future_tokens targets.
    # Example (n=3): t=0 -> [t0, t1, t2, t3]; t=1 -> [t1, t2, t3, t4].
    # Shape: (B, total_len - n_future_tokens, n_future_tokens + 1).
    windows = long_input_ids.unfold(dimension=1, size=n_future_tokens + 1, step=1)

    # Drop the leading input token of each window, keeping only the targets:
    # [t0, t1, t2, t3] -> [t1, t2, t3].
    all_targets = windows[:, :, 1:]

    # Keep targets for the first `model_seq_len` positions only, then move the
    # head axis in front of time: (B, T, K) -> (B, K, T).
    return all_targets[:, :model_seq_len, :].transpose(1, 2)
146
+
147
+
148
@dataclass
class MTPLMOutputWithPast(CausalLMOutputWithPast):
    """Causal-LM output extended with the two components of the MTP loss.

    The inherited `loss` field holds the total; the parts are exposed
    separately for logging.
    """
    # Next-token-prediction loss from the first prediction head (i == 0).
    ntp_loss: Optional[torch.FloatTensor] = None
    # Summed loss over the remaining (future-token) prediction heads.
    mtp_loss: Optional[torch.FloatTensor] = None
152
+
153
class MTPTransformerBlock(nn.Module):
    """Pre-norm Transformer block: attention then gated MLP, each behind an RMSNorm.

    Used both for the shared trunk layers and for the per-head prediction
    layers of the MTP model.
    """

    def __init__(self, config: MTPTransformerConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # `fuse_norm` selects the fused RMSNorm from fla.modules over the
        # plain torch implementation.
        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.attn = Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            qkv_bias=config.qkv_bias,
            qk_norm=config.qk_norm,
            window_size=config.window_size,
            rope_theta=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            layer_idx=layer_idx
        )

        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = TransformerMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs: Unpack[Any]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run attention then MLP with pre-norm residual connections.

        Returns `(hidden_states[, attentions][, past_key_values])`, the
        optional elements appearing when `output_attentions` / `use_cache`
        are set.
        """
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            # Fused path: adds the residual and applies mlp_norm in one kernel,
            # returning the normed activations plus the updated residual.
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attentions,)

        if use_cache:
            outputs += (past_key_values,)

        return outputs
221
+
222
+
223
class MTPTransformerPreTrainedModel(PreTrainedModel):
    # Shared Hugging Face plumbing (config class, checkpoint loading,
    # weight initialization) for all MTP Transformer models.

    config_class = MTPTransformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['MTPTransformerBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = False,
        num_residuals_per_layer: int = 2,
    ):
        """Initialize the weights of `module`.

        Linear/Conv1d/Embedding weights are drawn from N(0, initializer_range);
        modules exposing `reset_parameters` re-use their own initialization.
        With `rescale_prenorm_residual`, output projections (`o_proj` /
        `down_proj`) are additionally scaled by
        1/sqrt(num_residuals_per_layer * num_hidden_layers), following GPT-2.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
271
+
272
+
273
class MTPTransformerModel(MTPTransformerPreTrainedModel):
    """Transformer trunk plus `n_future_tokens` parallel prediction-head blocks.

    The first `num_hidden_layers - n_future_tokens` blocks form a shared trunk;
    each block in `extra_heads` consumes the trunk output and produces the
    latent for one future-token prediction head.

    Fixes vs. the previous version:
    - `output_hidden_states` no longer crashes: the final check referenced a
      non-existent `self.custom_backward` attribute instead of the local
      `use_custom_backward` flag.
    - `return_all_heads` is now honoured: the computed `n_heads_to_use` was
      previously ignored and every head always ran, so inference paid for (and
      returned) all heads instead of just the next-token head.
    """

    def __init__(
        self,
        config: MTPTransformerConfig
    ) -> MTPTransformerModel:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Shared trunk layers; the last `n_future_tokens` layer budget goes to
        # the per-head blocks below.
        self.layers = nn.ModuleList(
            [MTPTransformerBlock(config, layer_idx)
             for layer_idx in range(config.num_hidden_layers - config.n_future_tokens)]
        )
        # One extra block per predicted future token, all fed from the trunk.
        self.extra_heads = nn.ModuleList(
            [MTPTransformerBlock(config, layer_idx) for layer_idx in range(config.n_future_tokens)]
        )
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_all_heads: bool = False,  # if Training, this is True
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the trunk, then the prediction-head blocks.

        Returns per-head latents of shape (B, T, n_heads, D) on the standard
        path, or the raw 3-D trunk output when the custom sequential-heads
        backward is active (the heads/norm are then applied by the LM wrapper).
        """
        if output_attentions:
            warnings.warn(
                "`TransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_custom_backward = self.config.use_custom_backward and self.training
        if self.training and return_all_heads is False:
            logger.warning_once(
                "`return_all_heads=False` is incompatible with training. Setting `return_all_heads=True`..."
            )
            return_all_heads = True

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        next_cache = None

        # ---- shared trunk ----
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    **kwargs
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    **kwargs
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_attns += (layer_outputs[1],)

        trunk = hidden_states

        # Fix: honour `return_all_heads` — inference only needs the first
        # (next-token) head; training uses all of them.
        n_heads_to_use = self.config.n_future_tokens if return_all_heads else 1
        prediction_heads = self.extra_heads[:n_heads_to_use]

        if use_custom_backward:
            # The heads, final norm and lm_head are applied inside
            # `SequentialHeadsCustomBackward` by the CausalLM wrapper; hand
            # back the raw trunk output.
            hidden_states = trunk
        else:
            # ---- per-head blocks, each fed the same trunk output ----
            latents = []
            for i, layer in enumerate(prediction_heads):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        trunk,  # Use trunk instead of hidden states
                        attention_mask,
                        past_key_values,
                        output_attentions,
                        use_cache,
                        **kwargs
                    )
                else:
                    layer_outputs = layer(
                        trunk,  # Use trunk instead of hidden states
                        attention_mask=attention_mask,
                        past_key_values=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        **kwargs
                    )
                hidden_states = layer_outputs[0]
                latents.append(hidden_states)

                if use_cache:
                    next_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_attns += (layer_outputs[1],)

            hidden_states = torch.stack(latents, dim=-2)  # (B, T, n_heads_to_use, D)
            hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        # Fix: previously referenced the non-existent `self.custom_backward`.
        if output_hidden_states and not use_custom_backward:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
443
+
444
+
445
class MTPTransformerForCausalLM(MTPTransformerPreTrainedModel, GenerationMixin):
    """Causal LM head over `MTPTransformerModel` with a multi-token-prediction loss.

    Fixes vs. the previous version:
    - On the custom-backward path the model returns the 3-D trunk output, so
      unpacking `B, T, H, D = hidden_states.shape` crashed; the head count now
      comes from the config and only `T` is read from the tensor.
    - Fused linear+cross-entropy needs the 4-D per-head latents, which the
      custom-backward path does not expose; fusion is disabled in that case.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = MTPTransformerModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # May be set externally to override the default loss; see `forward`.
        self.criterion = None
        self.pad_token_id = config.pad_token_id

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute per-head logits and the NTP + MTP losses.

        `labels` must be long enough to supply `n_future_tokens` targets for
        every model position (see `seq_to_mtp`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        # (B, T, n_heads, D) on the standard path; 3-D trunk output (B, T, D)
        # when the custom sequential-heads backward is active.
        hidden_states = outputs[0]
        use_custom_backward = self.config.use_custom_backward and self.training
        # Fix: the fused linear+CE kernel consumes the 4-D per-head latents,
        # which the custom-backward path does not produce — disable fusion there.
        fuse_linear_and_cross_entropy = (
            self.config.fuse_cross_entropy and self.training and not use_custom_backward
        )

        if use_custom_backward:
            # Heads, final norm and lm_head run inside the custom autograd fn.
            all_logits = SequentialHeadsCustomBackward.apply(
                hidden_states, self.lm_head, self.model.norm, logits_to_keep, *self.model.extra_heads
            )
        else:
            all_logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])

        loss, ntp_loss, mtp_loss = None, None, None
        if labels is not None:
            # Fix: derive the head count from the config — `hidden_states` is
            # 3-D on the custom-backward path, so its shape cannot be unpacked
            # into (B, T, H, D).
            T = hidden_states.shape[1]
            n_heads_prediction = self.config.n_future_tokens
            loss = torch.zeros(1, device=hidden_states.device)
            ntp_loss = torch.zeros(1, device=hidden_states.device)
            mtp_loss = torch.zeros(1, device=hidden_states.device)
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            # Enable model parallelism
            labels = labels.to(hidden_states.device)
            # (B, n_heads_prediction, T): row k holds the (k+1)-step-ahead targets.
            all_labels = seq_to_mtp(labels, n_future_tokens=n_heads_prediction, model_seq_len=T)
            # Loop across prediction heads
            for i in range(n_heads_prediction):
                head_labels = all_labels[:, i, :]
                if fuse_linear_and_cross_entropy:
                    current_loss = criterion(
                        hidden_states[:, :, i, :], head_labels.contiguous(), self.lm_head.weight, self.lm_head.bias
                    )
                else:
                    logits = all_logits[:, :, i, :]
                    current_loss = criterion(logits.view(head_labels.numel(), -1), head_labels.reshape(-1))
                if i == 0:  # next-token-prediction head
                    ntp_loss = current_loss
                else:
                    mtp_loss += current_loss
                loss += current_loss

        if not return_dict:
            output = (all_logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MTPLMOutputWithPast(
            loss=loss,
            ntp_loss=ntp_loss,
            mtp_loss=mtp_loss,
            logits=all_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/transformer_vanilla/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

# NOTE(review): this is `transformer_vanilla/__init__.py`, yet it imports from
# the sibling `fla.models.transformer` package rather than from
# `fla.models.transformer_vanilla` — presumably a copy-paste leftover; confirm
# whether it should re-export this package's own configuration/model modules.
from fla.models.transformer.configuration_transformer import TransformerConfig
from fla.models.transformer.modeling_transformer import TransformerForCausalLM, TransformerModel

# NOTE(review): if `fla.models.transformer` performs the same registrations on
# import, importing both packages may trigger a duplicate-registration error in
# `transformers` — verify against the package's top-level imports.
AutoConfig.register(TransformerConfig.model_type, TransformerConfig)
AutoModel.register(TransformerConfig, TransformerModel)
AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)


__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel']
fla/models/transformer_vanilla/configuration_transformer.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class MTPTransformerConfig(PretrainedConfig):
    """Vanilla Transformer configuration.

    NOTE(review): the class is named ``MTPTransformerConfig`` but lives in
    ``transformer_vanilla``, uses ``model_type = 'transformer'`` and has no
    MTP-specific fields — this looks like a copy-paste leftover from the MTP
    package; confirm the intended class name (likely ``TransformerConfig``,
    which is what this package's ``__init__`` expects to import).
    """

    model_type = 'transformer'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: int = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.006,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs,
    ):
        # Attention geometry and rotary-embedding settings.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Feed-forward block sizing and activation.
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        # Normalization, initialization and KV-cache behavior.
        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        # Kernel-fusion toggles.
        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )