Erland committed on
Commit
f9582a6
·
verified ·
1 Parent(s): 9eb53fa

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. configs/gsa_340M.json +29 -0
  3. fla/layers/__pycache__/delta_net.cpython-311.pyc +0 -0
  4. fla/layers/__pycache__/gated_deltanet.cpython-311.pyc +0 -0
  5. fla/layers/__pycache__/gated_deltaproduct.cpython-311.pyc +0 -0
  6. fla/layers/__pycache__/gla.cpython-311.pyc +0 -0
  7. fla/layers/__pycache__/gsa.cpython-311.pyc +0 -0
  8. fla/layers/__pycache__/hgrn.cpython-311.pyc +0 -0
  9. fla/layers/__pycache__/hgrn2.cpython-311.pyc +0 -0
  10. fla/layers/__pycache__/lightnet.cpython-311.pyc +0 -0
  11. fla/layers/__pycache__/linear_attn.cpython-311.pyc +0 -0
  12. fla/layers/__pycache__/multiscale_retention.cpython-311.pyc +0 -0
  13. fla/layers/__pycache__/nsa.cpython-311.pyc +0 -0
  14. fla/layers/__pycache__/rebased.cpython-311.pyc +0 -0
  15. fla/layers/__pycache__/rwkv6.cpython-311.pyc +0 -0
  16. fla/layers/__pycache__/rwkv7.cpython-311.pyc +0 -0
  17. fla/layers/__pycache__/utils.cpython-311.pyc +0 -0
  18. fla/layers/delta_net.py +291 -0
  19. fla/models/delta_net/__init__.py +12 -0
  20. fla/models/delta_net/__pycache__/modeling_delta_net.cpython-311.pyc +0 -0
  21. fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-311.pyc +0 -0
  22. fla/models/gated_deltanet/__init__.py +12 -0
  23. fla/models/gated_deltanet/__pycache__/configuration_gated_deltanet.cpython-311.pyc +0 -0
  24. fla/models/gated_deltanet/__pycache__/modeling_gated_deltanet.cpython-311.pyc +0 -0
  25. fla/models/gated_deltanet/configuration_gated_deltanet.py +83 -0
  26. fla/models/gated_deltaproduct/__init__.py +14 -0
  27. fla/models/gated_deltaproduct/__pycache__/__init__.cpython-311.pyc +0 -0
  28. fla/models/gated_deltaproduct/__pycache__/modeling_gated_deltaproduct.cpython-311.pyc +0 -0
  29. fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py +90 -0
  30. fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py +520 -0
  31. fla/models/gla/__init__.py +13 -0
  32. fla/models/gla/__pycache__/__init__.cpython-311.pyc +0 -0
  33. fla/models/gla/__pycache__/configuration_gla.cpython-311.pyc +0 -0
  34. fla/models/gla/configuration_gla.py +95 -0
  35. fla/models/gla/modeling_gla.py +417 -0
  36. fla/models/gsa/__init__.py +13 -0
  37. fla/models/gsa/__pycache__/__init__.cpython-311.pyc +0 -0
  38. fla/models/gsa/__pycache__/configuration_gsa.cpython-311.pyc +0 -0
  39. fla/models/gsa/__pycache__/modeling_gsa.cpython-311.pyc +0 -0
  40. fla/models/gsa/configuration_gsa.py +97 -0
  41. fla/models/gsa/modeling_gsa.py +420 -0
  42. fla/models/hgrn/__init__.py +13 -0
  43. fla/models/hgrn/__pycache__/__init__.cpython-311.pyc +0 -0
  44. fla/models/hgrn/__pycache__/configuration_hgrn.cpython-311.pyc +0 -0
  45. fla/models/hgrn/__pycache__/modeling_hgrn.cpython-311.pyc +0 -0
  46. fla/models/hgrn/configuration_hgrn.py +81 -0
  47. fla/models/hgrn/modeling_hgrn.py +420 -0
  48. fla/models/hgrn2/__init__.py +13 -0
  49. fla/models/hgrn2/__pycache__/__init__.cpython-311.pyc +0 -0
  50. fla/models/hgrn2/__pycache__/configuration_hgrn2.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tb/20251107-1856/wandb/run-20251107_185630--transformer.120M.batch48.seqlen2048.context2048.warmup1000.update1.steps15000.lr5e-4.cosine.amd-202511071855/run--transformer.120M.batch48.seqlen2048.context2048.warmup1000.update1.steps15000.lr5e-4.cosine.amd-202511071855.wandb filter=lfs diff=lfs merge=lfs -text
configs/gsa_340M.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_size": 4,
4
+ "eos_token_id": 2,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 4,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.006,
16
+ "intermediate_size": null,
17
+ "model_type": "gsa",
18
+ "num_heads": 4,
19
+ "num_hidden_layers": 24,
20
+ "num_slots": 64,
21
+ "norm_eps": 1e-06,
22
+ "share_conv_kernel": true,
23
+ "tie_word_embeddings": false,
24
+ "use_cache": true,
25
+ "use_norm": true,
26
+ "use_output_gate": true,
27
+ "use_rope": false,
28
+ "use_short_conv": false
29
+ }
fla/layers/__pycache__/delta_net.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
fla/layers/__pycache__/gated_deltanet.cpython-311.pyc ADDED
Binary file (13.9 kB). View file
 
fla/layers/__pycache__/gated_deltaproduct.cpython-311.pyc ADDED
Binary file (16.3 kB). View file
 
fla/layers/__pycache__/gla.cpython-311.pyc ADDED
Binary file (13.7 kB). View file
 
fla/layers/__pycache__/gsa.cpython-311.pyc ADDED
Binary file (10.3 kB). View file
 
fla/layers/__pycache__/hgrn.cpython-311.pyc ADDED
Binary file (7.25 kB). View file
 
fla/layers/__pycache__/hgrn2.cpython-311.pyc ADDED
Binary file (9.11 kB). View file
 
fla/layers/__pycache__/lightnet.cpython-311.pyc ADDED
Binary file (9.35 kB). View file
 
fla/layers/__pycache__/linear_attn.cpython-311.pyc ADDED
Binary file (7.99 kB). View file
 
fla/layers/__pycache__/multiscale_retention.cpython-311.pyc ADDED
Binary file (13 kB). View file
 
fla/layers/__pycache__/nsa.cpython-311.pyc ADDED
Binary file (6.75 kB). View file
 
fla/layers/__pycache__/rebased.cpython-311.pyc ADDED
Binary file (7.2 kB). View file
 
fla/layers/__pycache__/rwkv6.cpython-311.pyc ADDED
Binary file (15.6 kB). View file
 
fla/layers/__pycache__/rwkv7.cpython-311.pyc ADDED
Binary file (11 kB). View file
 
fla/layers/__pycache__/utils.cpython-311.pyc ADDED
Binary file (9.92 kB). View file
 
fla/layers/delta_net.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from torch.nn import functional as F
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
15
+
16
+ if TYPE_CHECKING:
17
+ from transformers.processing_utils import Unpack
18
+
19
+ from fla.models.utils import Cache
20
+
21
+
22
def elu_p1(x):
    """Return ``elu(x) + 1``, cast back to the dtype/device of ``x``."""
    shifted = F.elu(x, alpha=1., inplace=False) + 1.
    return shifted.to(x)
24
+
25
+
26
def sum_norm(x):
    """Divide ``x`` by the sum over its last dimension, keeping the dtype/device of ``x``."""
    denom = x.sum(-1, keepdim=True)
    normalized = x / denom
    return normalized.to(x)
28
+
29
+
30
class DeltaNet(nn.Module):
    r"""
    The layer implementation for [Parallelizing Linear Transformers with the Delta Rule over Sequence Length](https://arxiv.org/abs/2406.06484).  # noqa
    DeltaNet was originally proposed in [Linear Transformers Are Secretly Fast Weight Programmers](https://arxiv.org/abs/2102.11174).  # noqa

    Args:
        mode (str, Optional):
            Which DeltaNet kernel to use.
            Currently available: `chunk` and `fused_recurrent`.
            `fused_chunk` is deprecated and raises `NotImplementedError`.
            Default: `chunk`.
        d_model (int, Optional):
            If not `None`, overrides `hidden_size`. Default: `None`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 1.0.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 1.0.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        use_beta (bool, Optional):
            Whether to use beta. Default: `True`.
        use_gate (bool, Optional):
            Whether to use output gate. Default: `False`.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `True`.
            Turning it off emits a warning, since short convolutions are crucial to the performance.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        allow_neg_eigval (bool, Optional):
            Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2.
            See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537)
        layer_idx (int, Optional):
            The index of the layer. Default: None.
        qk_activation (str, Optional):
            The activation function for the query and key, one of `silu`, `relu`, `elu`, `identity`. Default: `silu`.
        qk_norm (str, Optional):
            The normalization method for the query and key, one of `l2`, `sum`. Default: `l2`.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        d_model: int = None,
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
        num_heads: int = 4,
        use_beta: bool = True,
        use_gate: bool = False,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        allow_neg_eigval: bool = False,
        layer_idx: int = None,
        qk_activation: str = 'silu',
        qk_norm: str = 'l2',
        norm_eps: float = 1e-5,
        **kwargs
    ) -> None:
        super().__init__()

        self.mode = mode
        self.qk_activation = qk_activation
        self.qk_norm = qk_norm

        assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
        assert self.qk_norm in ['l2', 'sum']

        # `d_model` is an alternative spelling of `hidden_size` and takes precedence when given.
        if d_model is not None:
            hidden_size = d_model
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.allow_neg_eigval = allow_neg_eigval

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads
        self.layer_idx = layer_idx

        self.silu = nn.SiLU()
        if mode == 'fused_chunk':
            raise NotImplementedError("fused_chunk_delta_rule is now deprecated. Please use `chunk_delta_rule` instead.")
        assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        self.use_beta = use_beta
        if self.use_beta:
            self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
        if use_short_conv:
            # The SiLU activation is fused into the q/k convolutions only when it is the selected
            # q/k activation; the other activations are applied in `forward`.
            qk_conv_activation = 'silu' if qk_activation == 'silu' else None
            self.q_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation=qk_conv_activation
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation=qk_conv_activation
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.value_dim,
                kernel_size=conv_size,
                activation='silu'
            )
        else:
            # Warn instead of raising: `forward` implements the no-convolution path, and raising
            # `UserWarning` as an exception made that path unreachable.
            import warnings
            warnings.warn(
                "ShortConvolution is crucial to the performance. "
                "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
            )
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """
        Apply the DeltaNet token mixer to `hidden_states`.

        Args:
            hidden_states:
                Input tensor; `shape[1]` is used as the sequence length.
            attention_mask:
                Optional 2-D 0-1 padding mask of shape `[batch_size, seq_len]` (0 indicating padding).
                Masks with any other number of dimensions are rejected.
            past_key_values:
                Optional cache. The entry at `layer_idx`, if present, supplies `conv_state` and
                `recurrent_state`; the cache is updated in place with the new states.
            use_cache:
                Forwarded as `output_final_state` to the short convolutions and the delta-rule kernel.
            output_attentions:
                Not used by this layer.
            **kwargs:
                Only `cu_seqlens` is read; it is forwarded to the convolutions and the kernel.

        Returns:
            A tuple `(o, None, past_key_values)` where `o` is the output of `o_proj`.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # Short inputs (<= 64 positions) always use the recurrent kernel, regardless of `self.mode`.
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            # Keep only the mask columns that correspond to the positions in this call.
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            if self.qk_activation == 'silu':
                q, k = self.silu(q), self.silu(k)
            v = self.silu(self.v_proj(hidden_states))

        # Split the projected feature dimension into heads.
        q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
        # SiLU was already applied above (fused in the convolution or explicitly); handle the rest here.
        if self.qk_activation != 'silu':
            if self.qk_activation == 'relu':
                q, k = q.relu(), k.relu()
            elif self.qk_activation == 'elu':
                q, k = elu_p1(q), elu_p1(k)
            elif self.qk_activation == 'identity':
                pass
            else:
                raise NotImplementedError

        # `l2` normalization is delegated to the kernel via `use_qk_l2norm_in_kernel`.
        if self.qk_norm == 'sum':
            q = sum_norm(q).to(q)
            k = sum_norm(k).to(k)

        if self.use_beta:
            beta = self.b_proj(hidden_states).sigmoid()
        else:
            beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])

        if self.allow_neg_eigval:
            beta = beta * 2.

        # dealing with padding: zero out beta at padded positions
        if attention_mask is not None:
            beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        use_l2norm_in_kernel = self.qk_norm == 'l2'
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_delta_rule(
                q=q,
                k=k,
                v=v,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=use_l2norm_in_kernel
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_delta_rule(
                q=q,
                k=k,
                v=v,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=use_l2norm_in_kernel
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_gate:
            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        # Merge the heads back into a single feature dimension before the output projection.
        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values
fla/models/delta_net/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
6
+ from fla.models.delta_net.modeling_delta_net import DeltaNetForCausalLM, DeltaNetModel
7
+
8
# `exist_ok=True` keeps registration idempotent: it does not raise if the model type / config
# class is already registered (e.g. the package is imported twice or the name is already taken).
AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig, exist_ok=True)
AutoModel.register(DeltaNetConfig, DeltaNetModel, exist_ok=True)
AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM, exist_ok=True)
11
+
12
+ __all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel']
fla/models/delta_net/__pycache__/modeling_delta_net.cpython-311.pyc ADDED
Binary file (19.4 kB). View file
 
fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-311.pyc ADDED
Binary file (18.2 kB). View file
 
fla/models/gated_deltanet/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
6
+ from fla.models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForCausalLM, GatedDeltaNetModel
7
+
8
# `exist_ok=True` keeps registration idempotent: it does not raise if the model type / config
# class is already registered (e.g. the package is imported twice or the name is already taken).
AutoConfig.register(GatedDeltaNetConfig.model_type, GatedDeltaNetConfig, exist_ok=True)
AutoModel.register(GatedDeltaNetConfig, GatedDeltaNetModel, exist_ok=True)
AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM, exist_ok=True)
11
+
12
+ __all__ = ['GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel']
fla/models/gated_deltanet/__pycache__/configuration_gated_deltanet.cpython-311.pyc ADDED
Binary file (3.75 kB). View file
 
fla/models/gated_deltanet/__pycache__/modeling_gated_deltanet.cpython-311.pyc ADDED
Binary file (19.4 kB). View file
 
fla/models/gated_deltanet/configuration_gated_deltanet.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class GatedDeltaNetConfig(PretrainedConfig):
    """
    Configuration for Gated DeltaNet models.

    Every constructor argument is stored as an attribute of the same name. When `attn` is given
    it must be a dictionary with at least `layers` and `num_heads`; the keys `num_kv_heads`,
    `qkv_bias`, `window_size` and `rope_theta` are filled in with defaults when absent. The
    dictionary is completed in place, so the caller's object is modified.

    Raises:
        ValueError: If `attn` is not a dictionary, or lacks `layers` or `num_heads`.
    """

    model_type = 'gated_deltanet'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        expand_v: int = 2,
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        head_dim: int = 256,
        num_heads: int = 6,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        num_hidden_layers: int = 21,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        # Token-mixing layer settings.
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.expand_v = expand_v
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings

        # Feed-forward and overall model settings.
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_hidden_layers = num_hidden_layers
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        # Fused-kernel switches and vocabulary size.
        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            # Mandatory keys, checked in this order, each with its own error message.
            mandatory = (
                ('layers', "Layer indices must be provided to initialize hybrid attention layers"),
                ('num_heads', "Number of heads must be provided to initialize hybrid attention layers"),
            )
            for key, message in mandatory:
                if key not in attn:
                    raise ValueError(message)
            # Optional keys: keep the caller's value, otherwise insert the default.
            attn.setdefault('num_kv_heads', attn['num_heads'])
            attn.setdefault('qkv_bias', False)
            attn.setdefault('window_size', None)
            attn.setdefault('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/gated_deltaproduct/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
2
+
3
+ from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
4
+ from fla.models.gated_deltaproduct.modeling_gated_deltaproduct import GatedDeltaProductForCausalLM, GatedDeltaProductModel
5
+
6
# `exist_ok=True` keeps registration idempotent: it does not raise if the model type / config
# class is already registered (e.g. the package is imported twice or the name is already taken).
AutoConfig.register(GatedDeltaProductConfig.model_type, GatedDeltaProductConfig, exist_ok=True)
AutoModel.register(GatedDeltaProductConfig, GatedDeltaProductModel, exist_ok=True)
AutoModelForCausalLM.register(GatedDeltaProductConfig, GatedDeltaProductForCausalLM, exist_ok=True)
9
+
10
+ __all__ = [
11
+ "GatedDeltaProductConfig",
12
+ "GatedDeltaProductForCausalLM",
13
+ "GatedDeltaProductModel",
14
+ ]
fla/models/gated_deltaproduct/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (860 Bytes). View file
 
fla/models/gated_deltaproduct/__pycache__/modeling_gated_deltaproduct.cpython-311.pyc ADDED
Binary file (21.5 kB). View file
 
fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class GatedDeltaProductConfig(PretrainedConfig):
    """
    Configuration for (Gated) DeltaProduct models.

    Every constructor argument is stored as an attribute of the same name. When `attn` is given
    it must be a dictionary with at least `layers` and `num_heads`; `num_kv_heads` and
    `window_size` are filled in with defaults when absent. The dictionary is completed in place,
    so the caller's object is modified.

    Raises:
        ValueError: If `attn` is not a dictionary, or lacks `layers` or `num_heads`.
    """

    model_type = "gated_deltaproduct"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        expand_v: int = 2,
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        head_dim: int = 256,
        num_heads: int = 6,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        num_hidden_layers: int = 21,
        norm_first: bool = False,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        # `Optional[int]` rather than `int | None`: this module has no
        # `from __future__ import annotations`, so the annotation is evaluated at definition
        # time and the `|` form fails on Python < 3.10.
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        use_forget_gate: bool = False,  # when true Gated DeltaProduct, when false DeltaProduct
        allow_neg_eigval: bool = False,  # when true (Gated) DeltaProduct [-1, 1], when false (Gated) DeltaProduct [0, 1]
        num_householder: int = 1,
        **kwargs,
    ):
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.expand_v = expand_v
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_hidden_layers = num_hidden_layers
        self.norm_first = norm_first
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        # DeltaProduct specific
        self.allow_neg_eigval = allow_neg_eigval
        self.num_householder = num_householder
        self.use_forget_gate = use_forget_gate

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if "layers" not in attn:
                raise ValueError(
                    "Layer indices must be provided to initialize hybrid attention layers"
                )
            if "num_heads" not in attn:
                raise ValueError(
                    "Number of heads must be provided to initialize hybrid attention layers"
                )
            # Fill in optional keys, keeping any value the caller already supplied.
            attn["num_kv_heads"] = attn.get("num_kv_heads", attn["num_heads"])
            attn["window_size"] = attn.get("window_size", None)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.activations import ACT2FN
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.utils import logging
17
+ from transformers.utils.deprecation import deprecate_kwarg
18
+
19
+ from fla.layers.attn import Attention
20
+ from fla.layers.gated_deltaproduct import GatedDeltaProduct
21
+ from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
22
+ from fla.models.utils import Cache
23
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
24
+ from fla.modules.activations import swiglu_linear
25
+ from fla.modules.layernorm import rms_norm_linear
26
+
27
+ if TYPE_CHECKING:
28
+ from transformers.processing_utils import Unpack
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
class GatedDeltaNetMLP(nn.Module):
    """
    Gated feed-forward block with a single fused gate/up projection.

    `gate_proj` produces `2 * intermediate_size` features that are split in half and passed,
    together with the `down_proj` weights, to `swiglu_linear`. When `norm_first` is set, an
    RMSNorm is applied to the input through the fused `rms_norm_linear` helper.
    """

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        norm_first: bool = True,
        norm_eps: float = 1e-5,
    ) -> GatedDeltaNetMLP:
        super().__init__()

        # the final number of params is `hidden_ratio * hidden_size^2`
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            # Start from `2/3 * hidden_size * hidden_ratio` and round up to a multiple of 256.
            target = int(hidden_size * hidden_ratio * 2 / 3)
            intermediate_size = 256 * ((target + 256 - 1) // 256)

        self.hidden_size = hidden_size
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.norm_first = norm_first

        if norm_first:
            self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)

        # One projection yields both halves (gate and value) of the gated unit.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(
        self,
        x: torch.Tensor,
        **kwargs: Unpack[Dict],
    ) -> torch.Tensor:
        """Run the gated feed-forward on `x`; extra keyword arguments are ignored."""
        if not self.norm_first:
            projected = self.gate_proj(x)
        else:
            projected = rms_norm_linear(
                x,
                self.norm.weight,
                self.norm.bias,
                self.gate_proj.weight,
                self.gate_proj.bias,
            )
        gate, y = projected.chunk(2, -1)
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
83
+
84
+
85
class GatedDeltaProductBlock(nn.Module):
    """
    One residual block: a token-mixing layer followed by a gated MLP.

    The token mixer is softmax `Attention` for the layer indices listed in
    `config.attn["layers"]` and `GatedDeltaProduct` otherwise. When `config.norm_first` is
    false the block owns `attn_norm` / `mlp_norm`; otherwise `norm_first` is passed down to
    the `GatedDeltaProduct` layer and the MLP.
    """

    def __init__(self, config: GatedDeltaProductConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if not config.norm_first:
            self.attn_norm = RMSNorm(
                hidden_size=config.hidden_size, eps=config.norm_eps
            )
        if config.attn is not None and layer_idx in config.attn["layers"]:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn["num_heads"],
                num_kv_heads=config.attn["num_kv_heads"],
                window_size=config.attn["window_size"],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx,
            )
        else:
            # `use_beta_conv` is not declared by `GatedDeltaProductConfig`; it only exists when
            # supplied as an extra config kwarg. Forward it only in that case so that a default
            # config no longer fails with AttributeError and the layer's own default applies.
            optional_kwargs = {}
            if hasattr(config, "use_beta_conv"):
                optional_kwargs["use_beta_conv"] = config.use_beta_conv
            self.attn = GatedDeltaProduct(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_v=config.expand_v,
                head_dim=config.head_dim,
                num_heads=config.num_heads,
                use_gate=config.use_gate,
                use_forget_gate=config.use_forget_gate,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                norm_first=config.norm_first,
                norm_eps=config.norm_eps,
                allow_neg_eigval=config.allow_neg_eigval,
                num_householder=config.num_householder,
                layer_idx=layer_idx,
                **optional_kwargs,
            )
        if not config.norm_first:
            self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        self.mlp = GatedDeltaNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            norm_first=config.norm_first,
            norm_eps=config.norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict],
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Run the token mixer and the MLP, each with a residual connection.

        All arguments are forwarded to the token-mixing layer; `**kwargs` is also forwarded to
        the MLP.

        Returns:
            The 3-tuple `(hidden_states, attentions, past_key_values)`, where the last two
            items are whatever the token-mixing layer returned.
        """
        residual = hidden_states
        if hasattr(self, "attn_norm"):
            hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )
        if hasattr(self, "mlp_norm"):
            # NOTE(review): presumably a fused residual-add + norm that returns the normed
            # states together with the updated residual — confirm against `RMSNorm.forward`.
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs
165
+
166
+
167
class GatedDeltaProductPreTrainedModel(PreTrainedModel):
    """Shared base class: ties the config class to the models and defines weight init."""

    config_class = GatedDeltaProductConfig
    supports_gradient_checkpointing = True
    # Must name the residual block class defined in this file. The previous
    # value ("GatedDeltaNetBlock") matched no module of this model, so the
    # no-split hint had no effect.
    _no_split_modules = ["GatedDeltaProductBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = True,
        num_residuals_per_layer: int = 2,
    ):
        """Initialise ``module`` in place.

        Args:
            module: The sub-module to initialise.
            rescale_prenorm_residual: When true, parameters of ``module`` whose
                name is exactly ``o_proj.weight`` or ``down_proj.weight`` are
                divided by ``sqrt(num_residuals_per_layer * num_hidden_layers)``.
            num_residuals_per_layer: Number of residual additions per block used
                in the scaling factor above.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding row at zero after the random draw.
                module.weight.data[module.padding_idx].zero_()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                # Names are relative to `module`, so this only matches when
                # `module` is the direct parent of `o_proj` / `down_proj`.
                if name in ["o_proj.weight", "down_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Scale by 1/sqrt(num_residuals_per_layer * n_layer).
                    # NOTE(review): `p` is only divided here, not re-drawn. Repeated
                    # calls stay correct only if the Linear branch above re-initialises
                    # the projection before its parent is visited; confirm the call order.
                    with torch.no_grad():
                        p /= math.sqrt(
                            num_residuals_per_layer * self.config.num_hidden_layers
                        )
209
+
210
+
211
class GatedDeltaProductModel(GatedDeltaProductPreTrainedModel):
    """Backbone: token embeddings, a stack of ``GatedDeltaProductBlock`` and a final RMSNorm."""

    def __init__(self, config: GatedDeltaProductConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                GatedDeltaProductBlock(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the inputs (unless ``inputs_embeds`` is given) and run every block.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be provided.
        Flags left as ``None`` fall back to the config; ``use_cache`` falls
        back to ``False`` while training.

        Returns:
            A ``BaseModelOutputWithPast`` or, when ``return_dict`` is false, a
            tuple of the non-``None`` entries among (last hidden state, cache,
            all hidden states, all attentions).

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        if output_attentions:
            # Fixed: the message previously named `GatedDeltaNetModel`.
            warnings.warn(
                "`GatedDeltaProductModel` does not `output_attentions` now, setting it to `False`.",
                stacklevel=2,
            )
            output_attentions = False
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = (
            use_cache
            if use_cache is not None
            else (self.config.use_cache if not self.training else False)
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        # Anything that is not already a `Cache` (including None) is converted.
        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                # Record the input of each layer; the final output is appended after the loop.
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = (
                    self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        past_key_values,
                        use_cache,
                        output_attentions,
                        **kwargs,
                    )
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs,
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                i
                for i in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_attns,
                ]
                if i is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns,
        )
349
+
350
+
351
class GatedDeltaProductForCausalLM(GatedDeltaProductPreTrainedModel, GenerationMixin):
    """``GatedDeltaProductModel`` backbone with a bias-free linear LM head and next-token loss."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = GatedDeltaProductModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embeddings

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module."""
        self.model.embeddings = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone model."""
        self.model = decoder

    def get_decoder(self):
        """Return the backbone model."""
        return self.model

    def generate(self, *args, **kwargs):
        """Call ``GenerationMixin.generate`` and clarify cache-related ``AttributeError``s.

        Raises:
            AttributeError: Re-raised with an explanatory message when the
                original message mentions ``past_key_values``; any other
                ``AttributeError`` propagates unchanged.
        """
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if "past_key_values" in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                ) from exception
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        num_logits_to_keep: Optional[int] = None,
        logits_to_keep: Optional[int] = None,
        **kwargs,
    ):
        """Build the keyword arguments for one ``forward`` call during generation.

        ``num_logits_to_keep`` is kept in the signature for backward
        compatibility only. If a value reaches this method it is used as
        ``logits_to_keep`` when the latter is unset. It is never forwarded
        under its old name, so the deprecated-kwarg handling on ``forward``
        is not triggered at every step.

        Returns:
            A dict with ``input_ids`` or ``inputs_embeds``, optionally
            ``logits_to_keep``, plus ``past_key_values``, ``use_cache`` and
            ``attention_mask``.
        """
        # Honour the legacy name if it gets past the decorator.
        if logits_to_keep is None and num_logits_to_keep is not None:
            logits_to_keep = num_logits_to_keep

        # only last token for `inputs_ids` if the `past_key_values` is passed along is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: Optional[int] = 0,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone, then compute logits and/or the next-token loss.

        Logits are skipped (left as ``None``) only when the fused
        linear+cross-entropy loss is used, i.e. ``config.fuse_cross_entropy``
        is set, the module is in training mode and ``labels`` are given.
        ``num_logits_to_keep`` is accepted for signature compatibility and is
        not consulted; use ``logits_to_keep``.

        Returns:
            A ``CausalLMOutputWithPast`` or, when ``return_dict`` is false, a
            tuple ``(loss?, logits, *backbone_outputs[1:])``.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # Dropped so it is not forwarded to the backbone and its layers.
        kwargs.pop("num_items_in_batch", None)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            # With logits_to_keep == 0 the slice `[:, -0:]` keeps every position.
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if self.config.fuse_cross_entropy:
                if fuse_linear_and_cross_entropy:
                    loss_fct = FusedLinearCrossEntropyLoss()
                else:
                    loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
            else:
                loss_fct = nn.CrossEntropyLoss()
            # Enable model parallelism
            labels = labels.to(hidden_states.device)
            # Shift left by one: position t is scored against token t+1, and the
            # last position is filled with `ignore_index`.
            labels = torch.cat(
                (
                    labels[..., 1:],
                    torch.full_like(labels[:, :1], loss_fct.ignore_index),
                ),
                1,
            )
            if fuse_linear_and_cross_entropy:
                loss = loss_fct(
                    hidden_states.view(-1, self.config.hidden_size),
                    labels.view(-1),
                    self.lm_head.weight,
                    self.lm_head.bias,
                )
            else:
                loss = loss_fct(
                    logits.view(-1, self.config.vocab_size), labels.view(-1)
                )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss, *output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/gla/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.gla.configuration_gla import GLAConfig
6
+ from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel
7
+
8
# Register the GLA classes with the transformers Auto* factories: the config
# under `GLAConfig.model_type`, the two model classes under `GLAConfig`.
AutoConfig.register(GLAConfig.model_type, GLAConfig)
AutoModel.register(GLAConfig, GLAModel)
AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM)


# Names exported by `from fla.models.gla import *`.
__all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel']
fla/models/gla/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (736 Bytes). View file
 
fla/models/gla/__pycache__/configuration_gla.cpython-311.pyc ADDED
Binary file (4.16 kB). View file
 
fla/models/gla/configuration_gla.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class GLAConfig(PretrainedConfig):
    """Configuration for GLA models (``model_type = 'gla'``).

    Every constructor argument is stored as an attribute of the same name,
    except the token ids and ``tie_word_embeddings``, which are forwarded to
    ``PretrainedConfig.__init__`` together with ``**kwargs``.

    ``attn`` optionally describes hybrid attention layers. When given it must
    be a dict containing ``'layers'`` and ``'num_heads'``. Missing optional
    keys are filled with defaults: ``num_kv_heads`` (= ``num_heads``),
    ``qkv_bias`` (``False``), ``window_size`` (``None``) and ``rope_theta``
    (``10000.``). The defaults are written to a copy, so the caller's dict is
    left untouched.

    Raises:
        ValueError: If ``attn`` is not a dict or lacks ``'layers'`` /
            ``'num_heads'``.
    """

    model_type = 'gla'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        expand_k: float = 0.5,
        expand_v: int = 1,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        attn_mode: str = "chunk",
        use_short_conv: bool = False,
        conv_size: int = 4,
        use_output_gate: bool = True,
        clamp_min: Optional[float] = None,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_gk: bool = True,
        use_gv: bool = False,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        if attn is not None:
            if not isinstance(attn, dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            # Fill defaults on a shallow copy so the caller's dict is not mutated.
            attn = dict(attn)
            attn.setdefault('num_kv_heads', attn['num_heads'])
            attn.setdefault('qkv_bias', False)
            attn.setdefault('window_size', None)
            attn.setdefault('rope_theta', 10000.)

        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.feature_map = feature_map
        self.attn_mode = attn_mode
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.use_output_gate = use_output_gate
        self.clamp_min = clamp_min
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_gk = use_gk
        self.use_gv = use_gv
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/gla/modeling_gla.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gla import GatedLinearAttention
20
+ from fla.models.gla.configuration_gla import GLAConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GLAMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
class GLABlock(nn.Module):
    """Residual block: norm, token mixer, norm, gated MLP.

    The token mixer is ``Attention`` for layer indices listed in
    ``config.attn['layers']`` and ``GatedLinearAttention`` otherwise. Both
    norms use the project ``RMSNorm`` when ``config.fuse_norm`` is set and
    ``nn.RMSNorm`` otherwise.
    """

    def __init__(self, config: GLAConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Resolve the norm class once; nn.RMSNorm is only looked up when needed.
        norm_cls = RMSNorm if config.fuse_norm else nn.RMSNorm
        hybrid_cfg = config.attn

        self.attn_norm = norm_cls(config.hidden_size, eps=config.norm_eps)
        if hybrid_cfg is not None and layer_idx in hybrid_cfg['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=hybrid_cfg['num_heads'],
                num_kv_heads=hybrid_cfg['num_kv_heads'],
                qkv_bias=hybrid_cfg['qkv_bias'],
                window_size=hybrid_cfg['window_size'],
                rope_theta=hybrid_cfg['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = GatedLinearAttention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_k=config.expand_k,
                expand_v=config.expand_v,
                num_heads=config.num_heads,
                num_kv_heads=config.num_kv_heads,
                feature_map=config.feature_map,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                use_output_gate=config.use_output_gate,
                gate_fn=config.hidden_act,
                elementwise_affine=config.elementwise_affine,
                norm_eps=config.norm_eps,
                clamp_min=config.clamp_min,
                fuse_norm=config.fuse_norm,
                layer_idx=layer_idx
            )
        self.mlp_norm = norm_cls(config.hidden_size, eps=config.norm_eps)
        self.mlp = GLAMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply the token mixer and the MLP, each wrapped in a residual connection.

        Returns:
            ``(hidden_states, attentions, past_key_values)``; the last two are
            passed through from the token mixer.
        """
        shortcut = hidden_states
        normed = self.attn_norm(hidden_states)
        mixer_result = self.attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        hidden_states, attentions, past_key_values = mixer_result

        if not self.config.fuse_norm:
            # Plain path: explicit residual add, then a stand-alone norm.
            hidden_states = shortcut + hidden_states
            shortcut = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        else:
            # Fused path: the norm receives the residual and a third positional
            # `True`, and returns the normed activations plus the new residual.
            hidden_states, shortcut = self.mlp_norm(hidden_states, shortcut, True)

        hidden_states = shortcut + self.mlp(hidden_states, **kwargs)

        return (hidden_states, attentions, past_key_values)
110
+
111
+
112
class GLAPreTrainedModel(PreTrainedModel):
    """Shared base class for GLA models: class-level HF settings and weight init."""

    config_class = GLAConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['GLABlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        """Initialise ``module`` in place.

        Linear/Conv1d and Embedding weights are drawn from
        ``N(0, config.initializer_range)`` (biases zeroed); any other module
        exposing ``reset_parameters`` has it called. Afterwards, if ``module``
        owns an ``o_proj`` (checked first) or ``down_proj`` sub-module, that
        projection's weight is re-initialised according to
        ``prenorm_residual_strategy``: ``'rescale'``, ``'zero'`` or ``None``
        to skip.

        Raises:
            ValueError: If a projection is found and the strategy is not one
                of the values above.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Differs slightly from the TF version, which uses truncated_normal
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is None:
            return

        # GPT-2 style residual scaling: weights feeding the residual stream are
        # scaled by 1/sqrt(N), N being the number of residual additions.
        #   -- GPT-2 :: https://openai.com/blog/better-language-models/
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        residual_weight = None
        for proj_name in ('o_proj', 'down_proj'):
            if hasattr(module, proj_name):
                residual_weight = getattr(module, proj_name).weight
                break
        if residual_weight is None:
            return

        if prenorm_residual_strategy == 'rescale':
            # Re-draw before scaling: this hook can run several times and a bare
            # in-place division would shrink the weight on every call.
            nn.init.kaiming_uniform_(residual_weight, a=math.sqrt(5))
            with torch.no_grad():
                residual_weight /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
        elif prenorm_residual_strategy == 'zero':
            nn.init.zeros_(residual_weight)
        else:
            raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
165
+
166
+
167
class GLAModel(GLAPreTrainedModel):
    """GLA backbone: token embeddings, a stack of ``GLABlock`` and a final norm."""

    def __init__(self, config: GLAConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        blocks = [GLABlock(config, idx) for idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(blocks)
        final_norm_cls = RMSNorm if config.fuse_norm else nn.RMSNorm
        self.norm = final_norm_cls(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the inputs (unless ``inputs_embeds`` is given) and run every block.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be provided.
        Flags left as ``None`` fall back to the config; ``use_cache`` falls
        back to ``False`` while training.

        Returns:
            A ``BaseModelOutputWithPast`` or, when ``return_dict`` is false, a
            tuple of the non-``None`` entries among (last hidden state, cache,
            all hidden states, all attentions).

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        if output_attentions:
            warnings.warn("`GLAModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False

        # Resolve unset flags from the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = False if self.training else self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Exactly one input source is allowed.
        ids_given = input_ids is not None
        embeds_given = inputs_embeds is not None
        if ids_given and embeds_given:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if not ids_given and not embeds_given:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if not embeds_given:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        # Anything that is not already a `Cache` (including None) is converted.
        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for block in self.layers:
            if output_hidden_states:
                # Record each block's input; the final output is appended after the loop.
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                block_out = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                block_out = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs
                )
            hidden_states, attentions, past_key_values = block_out

            if output_attentions:
                all_attns = all_attns + (attentions,)

        hidden_states = self.norm(hidden_states)

        # Append the output of the last block (after the final norm).
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=past_key_values,
                hidden_states=all_hidden_states,
                attentions=all_attns
            )
        candidates = [hidden_states, past_key_values, all_hidden_states, all_attns]
        return tuple(item for item in candidates if item is not None)
268
+
269
+
270
class GLAForCausalLM(GLAPreTrainedModel, GenerationMixin):
    """``GLAModel`` backbone with a bias-free linear LM head and next-token loss."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = GLAModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Optional loss override; when left as None, `forward` picks a loss itself.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embeddings

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module."""
        self.model.embeddings = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone model."""
        self.model = decoder

    def get_decoder(self):
        """Return the backbone model."""
        return self.model

    def generate(self, *args, **kwargs):
        """Call ``GenerationMixin.generate`` and clarify cache-related ``AttributeError``s.

        Raises:
            AttributeError: Re-raised with an explanatory message when the
                original message mentions ``past_key_values``; any other
                ``AttributeError`` propagates unchanged.
        """
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                ) from exception
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        """Build the keyword arguments for one ``forward`` call during generation.

        A missing cache (``past_key_values is None``) is treated the same as an
        empty one. Previously ``len(None)`` raised ``TypeError`` whenever
        ``inputs_embeds`` was supplied without a cache.

        Returns:
            A dict with ``input_ids`` or ``inputs_embeds``, optionally
            ``logits_to_keep``, plus ``past_key_values``, ``use_cache`` and
            ``attention_mask``.
        """
        cache_is_empty = past_key_values is None or len(past_key_values) == 0

        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if not cache_is_empty:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_is_empty:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone, then compute logits and/or the next-token loss.

        Logits are skipped (left as ``None``) only when the fused
        linear+cross-entropy loss is used, i.e. ``config.fuse_cross_entropy``
        is set, the module is in training mode and ``labels`` are given.

        Returns:
            A ``CausalLMOutputWithPast`` or, when ``return_dict`` is false, a
            tuple ``(loss?, logits, *backbone_outputs[1:])``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            # With logits_to_keep == 0 the slice `[:, -0:]` keeps every position.
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            # Shift left by one: position t is scored against token t+1, and the
            # last position is filled with `ignore_index`.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/gsa/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.gsa.configuration_gsa import GSAConfig
6
+ from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel
7
+
8
# Register the GSA config under its `model_type` string, then map that config
# class to the backbone and causal-LM classes for the `transformers` Auto* factories.
AutoConfig.register(GSAConfig.model_type, GSAConfig)
AutoModel.register(GSAConfig, GSAModel)
AutoModelForCausalLM.register(GSAConfig, GSAForCausalLM)


# Public names of this package.
__all__ = [
    'GSAConfig',
    'GSAForCausalLM',
    'GSAModel',
]
fla/models/gsa/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (736 Bytes). View file
 
fla/models/gsa/__pycache__/configuration_gsa.cpython-311.pyc ADDED
Binary file (4.29 kB). View file
 
fla/models/gsa/__pycache__/modeling_gsa.cpython-311.pyc ADDED
Binary file (19.5 kB). View file
 
fla/models/gsa/configuration_gsa.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class GSAConfig(PretrainedConfig):
    """Configuration for GSA models (``model_type = 'gsa'``).

    Every constructor argument is stored on the instance as an attribute of the
    same name, with two exceptions: the historical, misspelled ``exapnd_k`` /
    ``exapnd_v`` arguments are stored as ``expand_k`` / ``expand_v``.  The
    correctly spelled ``expand_k`` / ``expand_v`` keywords are accepted as well
    and take precedence when they are not ``None``.

    Args:
        exapnd_k: Legacy (misspelled) name for ``expand_k``; used when
            ``expand_k`` is ``None``.
        exapnd_v: Legacy (misspelled) name for ``expand_v``; used when
            ``expand_v`` is ``None``.
        attn: Optional dictionary describing hybrid attention layers.  It must
            contain ``'layers'`` and ``'num_heads'``; ``'num_kv_heads'``,
            ``'qkv_bias'``, ``'window_size'`` and ``'rope_theta'`` are filled in
            with defaults when absent.  NOTE: the dictionary passed by the
            caller is updated in place.
        pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings:
            Forwarded to ``PretrainedConfig.__init__`` together with ``kwargs``.
        expand_k: Preferred spelling; overrides ``exapnd_k`` when given.
        expand_v: Preferred spelling; overrides ``exapnd_v`` when given.

    Raises:
        ValueError: If ``attn`` is given but is not a dictionary, or lacks the
            ``'layers'`` or ``'num_heads'`` key.
    """

    model_type = 'gsa'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        gate_logit_normalizer: Optional[int] = 8,
        clamp_min: Optional[float] = None,
        clamp_max: Optional[float] = None,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        num_slots: Optional[int] = 64,
        use_short_conv: bool = False,
        conv_size: int = 4,
        exapnd_k: float = 1,
        exapnd_v: float = 1,
        feature_map: str = 'swish',
        use_output_gate: bool = False,
        use_norm: bool = True,
        max_position_embeddings: int = 2048,
        hidden_act: str = "swish",
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        initializer_range: float = 0.006,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        expand_k: Optional[float] = None,
        expand_v: Optional[float] = None,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.gate_logit_normalizer = gate_logit_normalizer
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_slots = num_slots
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        # The misspelled keywords are kept so existing callers do not break;
        # the correctly spelled ones win whenever they are supplied.
        self.expand_k = exapnd_k if expand_k is None else expand_k
        self.expand_v = exapnd_v if expand_v is None else expand_v
        self.feature_map = feature_map
        self.use_output_gate = use_output_gate
        self.use_norm = use_norm
        self.max_position_embeddings = max_position_embeddings
        self.hidden_act = hidden_act
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            # Fill in optional hybrid-attention settings (in place, so
            # `self.attn` sees them too).
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/gsa/modeling_gsa.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gsa import GatedSlotAttention
20
+ from fla.models.gsa.configuration_gsa import GSAConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GSAMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
class GSABlock(nn.Module):
    """One residual layer of the GSA stack: norm -> token mixer -> norm -> gated MLP.

    The token mixer is `Attention` when `layer_idx` is listed in
    `config.attn['layers']`, and `GatedSlotAttention` otherwise.
    """

    def __init__(self, config: GSAConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Both norms share one class: the fla `RMSNorm` when `fuse_norm` is set,
        # the stock `nn.RMSNorm` otherwise.
        norm_cls = RMSNorm if config.fuse_norm else nn.RMSNorm
        hybrid = config.attn
        use_hybrid_attn = hybrid is not None and layer_idx in hybrid['layers']

        self.attn_norm = norm_cls(config.hidden_size, eps=config.norm_eps)
        if use_hybrid_attn:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=hybrid['num_heads'],
                num_kv_heads=hybrid['num_kv_heads'],
                qkv_bias=hybrid['qkv_bias'],
                window_size=hybrid['window_size'],
                rope_theta=hybrid['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = GatedSlotAttention(
                hidden_size=config.hidden_size,
                expand_k=config.expand_k,
                expand_v=config.expand_v,
                num_heads=config.num_heads,
                num_kv_heads=config.num_kv_heads,
                num_slots=config.num_slots,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                feature_map=config.feature_map,
                use_output_gate=config.use_output_gate,
                use_norm=config.use_norm,
                gate_fn=config.hidden_act,
                gate_logit_normalizer=config.gate_logit_normalizer,
                elementwise_affine=config.elementwise_affine,
                norm_eps=config.norm_eps,
                fuse_norm=config.fuse_norm,
                layer_idx=layer_idx
            )
        self.mlp_norm = norm_cls(config.hidden_size, eps=config.norm_eps)
        self.mlp = GSAMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the block and return `(hidden_states, attentions, past_key_values)`.

        Extra `kwargs` are forwarded to both the token mixer and the MLP.
        """
        skip = hidden_states
        mixed, attentions, past_key_values = self.attn(
            hidden_states=self.attn_norm(hidden_states),
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            # The fla norm is called with the residual and a flag and hands back a
            # pair; presumably it fuses the residual add with the norm — TODO confirm.
            normed, skip = self.mlp_norm(mixed, skip, True)
        else:
            skip = skip + mixed
            normed = self.mlp_norm(skip)
        block_out = skip + self.mlp(normed, **kwargs)

        return (block_out, attentions, past_key_values)
111
+
112
+
113
class GSAPreTrainedModel(PreTrainedModel):
    """Common `PreTrainedModel` base for the GSA models; supplies weight initialisation."""

    config_class = GSAConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['GSABlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        """Initialise `module` in place.

        Linear/Conv1d/Embedding weights are drawn from N(0, initializer_range);
        Linear/Conv1d biases are zeroed; any other module exposing
        `reset_parameters` is reset through it.  Afterwards, if
        `prenorm_residual_strategy` is not None and the module owns an `o_proj`
        (checked first) or `down_proj`, that projection weight is re-initialised:

        * 'rescale': Kaiming-uniform, then divided by
          sqrt(num_residuals_per_layer * num_hidden_layers);
        * 'zero': all zeros.

        Raises:
            ValueError: for any other strategy name (only reached when such a
                projection weight exists).
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Plain normal init (the TF reference used truncated_normal),
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is None:
            return

        # GPT-2 style treatment of the projections that write into the residual
        # stream: scale them down with depth (1/sqrt(N), N = number of residual
        # additions), see https://openai.com/blog/better-language-models/ and
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        if hasattr(module, 'o_proj'):
            weight = module.o_proj.weight
        elif hasattr(module, 'down_proj'):
            weight = module.down_proj.weight
        else:
            return
        if weight is None:
            return

        if prenorm_residual_strategy == 'rescale':
            # Re-draw before scaling: this hook may run more than once, and
            # scaling the existing values would shrink them on every call.
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            with torch.no_grad():
                weight.div_(math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers))
        elif prenorm_residual_strategy == 'zero':
            nn.init.zeros_(weight)
        else:
            raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
166
+
167
+
168
class GSAModel(GSAPreTrainedModel):
    """GSA backbone: token embeddings, a stack of `GSABlock`s and a final norm."""

    def __init__(self, config: GSAConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([GSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        norm_cls = RMSNorm if config.fuse_norm else nn.RMSNorm
        self.norm = norm_cls(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the inputs (unless `inputs_embeds` is given) and run every layer.

        Exactly one of `input_ids` / `inputs_embeds` must be supplied.  Flags left
        as None fall back to the config; `use_cache` falls back to False while
        training.  Returns a `BaseModelOutputWithPast`, or, when `return_dict`
        is falsy, a tuple of the non-None entries among
        (hidden_states, past_key_values, all_hidden_states, all_attns).

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given.
        """
        if output_attentions:
            warnings.warn("`GSAModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = False if self.training else self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Resolve the input representation.
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            inputs_embeds = self.embeddings(input_ids)
        elif input_ids is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        hidden_states = inputs_embeds

        # Wrap legacy (or missing) caches into a `Cache` object.
        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for block in self.layers:
            # Record the input of each layer.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if checkpointing:
                layer_out = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                layer_out = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs
                )
            hidden_states, attentions, past_key_values = layer_out

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # The normalised output of the last layer closes the list.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=past_key_values,
                hidden_states=all_hidden_states,
                attentions=all_attns
            )
        candidates = (hidden_states, past_key_values, all_hidden_states, all_attns)
        return tuple(item for item in candidates if item is not None)
269
+
270
+
271
class GSAForCausalLM(GSAPreTrainedModel, GenerationMixin):
    """`GSAModel` backbone with a bias-free linear language-modelling head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):

        super().__init__(config)
        self.model = GSAModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Optional user-supplied loss; when None, `forward` picks one from the config.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        """Delegate to `GenerationMixin.generate`, rewording cache-related failures.

        Raises:
            AttributeError: with an explanatory message (chained to the original
                error) when the failure mentions `past_key_values`; any other
                `AttributeError` propagates unchanged.
        """
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                # Keep the original error as `__cause__` so the real failure
                # site stays visible in the traceback.
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                ) from exception
            raise

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        """Build the keyword arguments for one decoding step.

        Returns a dict holding either `inputs_embeds` (only when there is no
        non-empty cache yet) or `input_ids` (trimmed to the last token once a
        cache exists), plus `past_key_values`, `use_cache`, `attention_mask`
        and, when given, `logits_to_keep`.
        """
        # A missing cache is treated like an empty one, so `len()` is never
        # called on None.
        has_past = past_key_values is not None and len(past_key_values) > 0

        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if has_past:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and not has_past:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone, then compute logits and/or the language-modelling loss.

        `labels` are shifted inside this method: position t is scored against
        `labels[..., t + 1]`, and the last position is filled with the
        criterion's `ignore_index`.  When `config.fuse_cross_entropy` is set and
        the module is training, the loss is computed directly from the hidden
        states and the head weights; in that case `logits` is None whenever
        `labels` are given.

        Returns a `CausalLMOutputWithPast`, or a plain tuple
        `(loss?, logits, *backbone_outputs[1:])` when `return_dict` is falsy.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            # `-0:` selects the whole sequence, so the default of 0 keeps every position.
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            # Enable model parallelism
            labels = labels.to(hidden_states.device)
            # Shift left by one and pad the tail with `ignore_index`.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/hgrn/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.hgrn.configuration_hgrn import HGRNConfig
6
+ from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel
7
+
8
# Register the HGRN config under its `model_type` string, then map that config
# class to the backbone and causal-LM classes for the `transformers` Auto* factories.
AutoConfig.register(HGRNConfig.model_type, HGRNConfig)
AutoModel.register(HGRNConfig, HGRNModel)
AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM)


# Public names of this package.
__all__ = [
    'HGRNConfig',
    'HGRNForCausalLM',
    'HGRNModel',
]
fla/models/hgrn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (744 Bytes). View file
 
fla/models/hgrn/__pycache__/configuration_hgrn.cpython-311.pyc ADDED
Binary file (3.7 kB). View file
 
fla/models/hgrn/__pycache__/modeling_hgrn.cpython-311.pyc ADDED
Binary file (19.7 kB). View file
 
fla/models/hgrn/configuration_hgrn.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class HGRNConfig(PretrainedConfig):
    """Configuration for HGRN models (``model_type = 'hgrn'``).

    Each constructor argument is stored on the instance under the same name.
    ``pad_token_id``, ``bos_token_id``, ``eos_token_id``,
    ``tie_word_embeddings`` and any extra ``kwargs`` are handed to
    ``PretrainedConfig.__init__``.

    ``attn`` optionally describes hybrid attention layers.  It must be a
    dictionary containing ``'layers'`` and ``'num_heads'``; the keys
    ``'num_kv_heads'``, ``'qkv_bias'``, ``'window_size'`` and ``'rope_theta'``
    are given defaults when missing.  The caller's dictionary is updated in
    place.

    Raises:
        ValueError: If ``attn`` is not a dictionary or lacks a required key.
    """

    model_type = 'hgrn'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "fused_recurrent",
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        expand_ratio: Optional[int] = 1,
        use_short_conv: bool = False,
        conv_size: int = 4,
        use_lower_bound: bool = True,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        # Architecture settings.
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.expand_ratio = expand_ratio
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.use_lower_bound = use_lower_bound
        self.max_position_embeddings = max_position_embeddings
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.elementwise_affine = elementwise_affine
        self.attn = attn
        self.norm_eps = norm_eps
        self.hidden_act = hidden_act
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        # Kernel-fusion switches and vocabulary size.
        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, dict):
                raise ValueError("attn must be a dictionary")
            required = (
                ('layers', "Layer indices must be provided to initialize hybrid attention layers"),
                ('num_heads', "Number of heads must be provided to initialize hybrid attention layers"),
            )
            for key, complaint in required:
                if key not in attn:
                    raise ValueError(complaint)
            # Only fill what the caller left out; existing entries are kept as they are.
            attn.setdefault('num_kv_heads', attn['num_heads'])
            attn.setdefault('qkv_bias', False)
            attn.setdefault('window_size', None)
            attn.setdefault('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/hgrn/modeling_hgrn.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.hgrn import HGRNAttention
20
+ from fla.models.hgrn.configuration_hgrn import HGRNConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as HGRNMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
class HGRNBlock(nn.Module):
    """One residual layer of the HGRN stack: norm -> token mixer -> norm -> gated MLP.

    The token mixer is `Attention` when `layer_idx` is listed in
    `config.attn['layers']`, and `HGRNAttention` otherwise.
    """

    def __init__(self, config: HGRNConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Both norms share one class: the fla `RMSNorm` when `fuse_norm` is set,
        # the stock `nn.RMSNorm` otherwise.
        norm_cls = RMSNorm if config.fuse_norm else nn.RMSNorm
        hybrid = config.attn
        use_hybrid_attn = hybrid is not None and layer_idx in hybrid['layers']

        self.attn_norm = norm_cls(config.hidden_size, eps=config.norm_eps)
        if use_hybrid_attn:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=hybrid['num_heads'],
                num_kv_heads=hybrid['num_kv_heads'],
                qkv_bias=hybrid['qkv_bias'],
                window_size=hybrid['window_size'],
                rope_theta=hybrid['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = HGRNAttention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_ratio=config.expand_ratio,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                elementwise_affine=config.elementwise_affine,
                norm_eps=config.norm_eps,
                layer_idx=layer_idx
            )
        self.mlp_norm = norm_cls(config.hidden_size, eps=config.norm_eps)
        self.mlp = HGRNMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        lower_bound: Optional[torch.Tensor] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the block and return `(hidden_states, attentions, past_key_values)`.

        `lower_bound` is passed straight through to the token mixer; extra
        `kwargs` are forwarded to both the token mixer and the MLP.
        NOTE(review): the `lower_bound` default is `False` although it is
        annotated as an optional tensor — confirm what the mixer expects.
        """
        skip = hidden_states
        mixed, attentions, past_key_values = self.attn(
            hidden_states=self.attn_norm(hidden_states),
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            lower_bound=lower_bound,
            **kwargs
        )
        if self.config.fuse_norm:
            # The fla norm is called with the residual and a flag and hands back a
            # pair; presumably it fuses the residual add with the norm — TODO confirm.
            normed, skip = self.mlp_norm(mixed, skip, True)
        else:
            skip = skip + mixed
            normed = self.mlp_norm(skip)
        block_out = skip + self.mlp(normed, **kwargs)

        return (block_out, attentions, past_key_values)
104
+
105
+
106
class HGRNPreTrainedModel(PreTrainedModel):
    """Common `PreTrainedModel` base for the HGRN models; supplies weight initialisation."""

    config_class = HGRNConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['HGRNBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        """Initialise `module` in place.

        Linear/Conv1d/Embedding weights are drawn from N(0, initializer_range);
        Linear/Conv1d biases are zeroed; any other module exposing
        `reset_parameters` is reset through it.  Afterwards, if
        `prenorm_residual_strategy` is not None and the module owns an `o_proj`
        (checked first) or `down_proj`, that projection weight is re-initialised:

        * 'rescale': Kaiming-uniform, then divided by
          sqrt(num_residuals_per_layer * num_hidden_layers);
        * 'zero': all zeros.

        Raises:
            ValueError: for any other strategy name (only reached when such a
                projection weight exists).
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Plain normal init (the TF reference used truncated_normal),
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is None:
            return

        # GPT-2 style treatment of the projections that write into the residual
        # stream: scale them down with depth (1/sqrt(N), N = number of residual
        # additions), see https://openai.com/blog/better-language-models/ and
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        if hasattr(module, 'o_proj'):
            weight = module.o_proj.weight
        elif hasattr(module, 'down_proj'):
            weight = module.down_proj.weight
        else:
            return
        if weight is None:
            return

        if prenorm_residual_strategy == 'rescale':
            # Re-draw before scaling: this hook may run more than once, and
            # scaling the existing values would shrink them on every call.
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            with torch.no_grad():
                weight.div_(math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers))
        elif prenorm_residual_strategy == 'zero':
            nn.init.zeros_(weight)
        else:
            raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
159
+
160
+
161
class HGRNModel(HGRNPreTrainedModel):
    """Stack of `HGRNBlock` layers on top of a token embedding, with a final norm."""

    def __init__(self, config: HGRNConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        if config.use_lower_bound:
            # One learnable row per layer; turned into per-layer bounds in `forward`.
            self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size))
        self.layers = nn.ModuleList([HGRNBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the embedding, every layer and the final norm.

        Exactly one of `input_ids` and `inputs_embeds` must be given. Flags left
        as `None` fall back to the corresponding config values; `use_cache`
        falls back to `False` while training.

        Returns:
            A `BaseModelOutputWithPast`, or (when `return_dict` is false) a tuple
            of the non-`None` values among the last hidden state, the cache, all
            hidden states and all attentions.

        Raises:
            ValueError: If both or neither of `input_ids` / `inputs_embeds` are given.
        """
        # Resolve the config default *before* the support check, so that a
        # config-level `output_attentions=True` is warned about and disabled
        # exactly like an explicit argument.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        if output_attentions:
            warnings.warn("`HGRNModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        # Wrap legacy (or missing) caches into a `Cache` object when caching is requested.
        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None

        if self.config.use_lower_bound:
            # Softmax across layers, then cumulative sum minus the first row:
            # layer 0 gets a bound of zero and later layers get increasing bounds.
            lower_bounds = self.lower_bounds.softmax(0)
            lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0]
        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            lower_bound = lower_bounds[i] if self.config.use_lower_bound else None
            if self.gradient_checkpointing and self.training:
                # The checkpointing wrapper receives the layer arguments positionally.
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    lower_bound,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    lower_bound=lower_bound,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
271
+
272
+
273
class HGRNForCausalLM(HGRNPreTrainedModel, GenerationMixin):
    """`HGRNModel` with a linear language-modeling head and an optional loss."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = HGRNModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Optional user-supplied loss; when left as None, `forward` picks one from the config.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped model."""
        return self.model.embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped model."""
        self.model.embeddings = value

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped backbone model."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped backbone model."""
        return self.model

    def generate(self, *args, **kwargs):
        """Delegate to the parent `generate`, clarifying `past_key_values` errors.

        Raises:
            AttributeError: Re-raised with an explanatory message (chained to the
                original error) when the original message mentions
                `past_key_values`; otherwise the original error propagates.
        """
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                # Chain explicitly so the original traceback stays attached.
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                ) from exception
            raise

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs: Unpack[Dict]
    ):
        """Build the keyword arguments for one generation step.

        Returns:
            A dict with either `inputs_embeds` (only while the cache is absent or
            empty) or `input_ids`, plus `past_key_values`, `use_cache`,
            `attention_mask` and, when given, `logits_to_keep`.
        """
        # A missing cache is treated like an empty one; calling `len()` on
        # `None` would raise a TypeError.
        has_cache = past_key_values is not None and len(past_key_values) > 0

        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if has_cache:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and not has_cache:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone, then compute logits and/or the language-modeling loss.

        When `config.fuse_cross_entropy` is set, the model is training and
        `labels` are given, the head projection is fused into the loss and
        `logits` is returned as `None`. Otherwise logits are computed for the
        last `logits_to_keep` positions (all positions for `0` or `None`).

        Returns:
            A `CausalLMOutputWithPast`, or (when `return_dict` is false) a tuple
            `(loss?, logits, *backbone_outputs[1:])`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            # `-0:` slices the whole sequence, so 0 keeps every position.
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            # Shift labels left by one (next-token targets) and pad the last
            # position with the criterion's ignore index.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/hgrn2/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config
6
+ from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model
7
+
8
# Register the HGRN2 config and model classes with the transformers Auto*
# classes, keyed by `HGRN2Config.model_type` / `HGRN2Config`.
AutoConfig.register(HGRN2Config.model_type, HGRN2Config)
AutoModel.register(HGRN2Config, HGRN2Model)
AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM)


# Public API of this package.
__all__ = [
    'HGRN2Config',
    'HGRN2ForCausalLM',
    'HGRN2Model',
]
fla/models/hgrn2/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (753 Bytes). View file
 
fla/models/hgrn2/__pycache__/configuration_hgrn2.cpython-311.pyc ADDED
Binary file (3.99 kB). View file