zaydzuhri committed
Commit 8faa4ba · verified · 1 Parent(s): b3f00e2

Add files using upload-large-folder tool

Files changed (50). This view is limited to 50 files because the commit contains too many changes.
  1. fla/__pycache__/utils.cpython-312.pyc +0 -0
  2. fla/layers/__init__.py +44 -0
  3. fla/layers/abc.py +218 -0
  4. fla/layers/delta_net.py +291 -0
  5. fla/layers/forgetting_attn.py +109 -0
  6. fla/layers/gated_deltanet.py +293 -0
  7. fla/layers/gated_deltaproduct.py +351 -0
  8. fla/layers/gla.py +294 -0
  9. fla/layers/gsa.py +227 -0
  10. fla/layers/hgrn.py +168 -0
  11. fla/layers/hgrn2.py +211 -0
  12. fla/layers/linear_attn.py +166 -0
  13. fla/layers/multiscale_retention.py +298 -0
  14. fla/layers/nsa.py +138 -0
  15. fla/layers/rebased.py +133 -0
  16. fla/layers/rwkv6.py +307 -0
  17. fla/layers/rwkv7.py +221 -0
  18. fla/models/__init__.py +55 -0
  19. fla/models/bitnet/__pycache__/configuration_bitnet.cpython-312.pyc +0 -0
  20. fla/models/bitnet/modeling_bitnet.py +441 -0
  21. fla/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc +0 -0
  22. fla/models/gated_deltanet/__pycache__/configuration_gated_deltanet.cpython-312.pyc +0 -0
  23. fla/models/gated_deltanet/configuration_gated_deltanet.py +83 -0
  24. fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc +0 -0
  25. fla/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc +0 -0
  26. fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc +0 -0
  27. fla/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc +0 -0
  28. fla/models/mamba2/__pycache__/modeling_mamba2.cpython-312.pyc +0 -0
  29. fla/models/nsa/__pycache__/configuration_nsa.cpython-312.pyc +0 -0
  30. fla/models/nsa/__pycache__/modeling_nsa.cpython-312.pyc +0 -0
  31. fla/models/retnet/__pycache__/modeling_retnet.cpython-312.pyc +0 -0
  32. fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc +0 -0
  33. fla/models/samba/__pycache__/__init__.cpython-312.pyc +0 -0
  34. fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc +0 -0
  35. fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc +0 -0
  36. fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  37. fla/models/transformer/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
  38. fla/models/transformer_mtp/__pycache__/__init__.cpython-312.pyc +0 -0
  39. fla/models/utils.py +147 -0
  40. fla/modules/__init__.py +30 -0
  41. fla/modules/__pycache__/activations.cpython-312.pyc +0 -0
  42. fla/modules/__pycache__/fused_kl_div.cpython-312.pyc +0 -0
  43. fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc +0 -0
  44. fla/modules/__pycache__/seq_to_top.cpython-312.pyc +0 -0
  45. fla/modules/fused_linear_listnet_loss.py +427 -0
  46. fla/modules/l2norm.py +176 -0
  47. fla/modules/layernorm_gated.py +528 -0
  48. fla/modules/mlp.py +127 -0
  49. fla/ops/__init__.py +45 -0
  50. flame/components/__pycache__/__init__.cpython-312.pyc +0 -0
fla/__pycache__/utils.cpython-312.pyc ADDED
Binary file (12.2 kB).
 
fla/layers/__init__.py ADDED
@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from .abc import ABCAttention
from .attn import Attention
from .based import BasedLinearAttention
from .bitattn import BitAttention
from .delta_net import DeltaNet
from .forgetting_attn import ForgettingAttention
from .gated_deltanet import GatedDeltaNet
from .gated_deltaproduct import GatedDeltaProduct
from .gla import GatedLinearAttention
from .gsa import GatedSlotAttention
from .hgrn import HGRNAttention
from .hgrn2 import HGRN2Attention
from .lightnet import LightNetAttention
from .linear_attn import LinearAttention
from .multiscale_retention import MultiScaleRetention
from .nsa import NativeSparseAttention
from .rebased import ReBasedLinearAttention
from .rwkv6 import RWKV6Attention
from .rwkv7 import RWKV7Attention

__all__ = [
    'ABCAttention',
    'Attention',
    'BasedLinearAttention',
    'BitAttention',
    'DeltaNet',
    'ForgettingAttention',
    'GatedDeltaNet',
    'GatedDeltaProduct',
    'GatedLinearAttention',
    'GatedSlotAttention',
    'HGRNAttention',
    'HGRN2Attention',
    'LightNetAttention',
    'LinearAttention',
    'MultiScaleRetention',
    'NativeSparseAttention',
    'ReBasedLinearAttention',
    'RWKV6Attention',
    'RWKV7Attention',
]
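
A quick sanity check of the exports above (my sketch, not part of the commit; it assumes the `fla` package and its dependencies such as torch, triton and einops are installed):

import fla.layers

print(len(fla.layers.__all__))   # 19 exported layer classes
print(fla.layers.__all__[:3])    # ['ABCAttention', 'Attention', 'BasedLinearAttention']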
fla/layers/abc.py ADDED
@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange

from fla.modules import FusedRMSNormGated, RMSNorm, RotaryEmbedding, ShortConvolution
from fla.modules.activations import swiglu, swish
from fla.ops.abc.chunk import chunk_abc

if TYPE_CHECKING:
    from fla.models.utils import Cache


class ABCAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        num_slots: Optional[int] = None,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_low_rank_dim: int = 16,
        gate_logit_normalizer: int = 16,
        use_rope: bool = True,
        use_input_gate: bool = False,
        use_output_gate: bool = True,
        use_norm: bool = True,
        clamp_min: Optional[float] = -32,
        clamp_max: Optional[float] = 32,
        layer_idx: Optional[int] = None,
        **kwargs
    ) -> ABCAttention:
        super().__init__()

        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.key_dim = int(self.hidden_size * self.expand_k)
        self.value_dim = int(self.hidden_size * self.expand_v)
        self.head_k_dim = self.key_dim // self.num_heads
        self.head_v_dim = self.value_dim // self.num_heads

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.gate_low_rank_dim = gate_low_rank_dim
        self.gate_logit_normalizer = gate_logit_normalizer

        self.use_rope = use_rope
        self.use_input_gate = use_input_gate
        self.use_output_gate = use_output_gate
        self.use_norm = use_norm

        if num_slots is None:
            num_slots = self.head_k_dim
        self.num_slots = num_slots

        self.norm_eps = norm_eps

        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.layer_idx = layer_idx

        if layer_idx is None:
            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)

        if use_output_gate:
            self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False)
        self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')

        if self.use_norm:
            if self.use_output_gate:
                self.g_norm = FusedRMSNormGated(
                    hidden_size=self.head_v_dim,
                    elementwise_affine=elementwise_affine,
                    eps=norm_eps
                )
            else:
                self.g_norm = RMSNorm(
                    hidden_size=self.head_v_dim,
                    elementwise_affine=elementwise_affine,
                    eps=norm_eps
                )

        if self.use_rope:
            self.rotary = RotaryEmbedding(self.head_k_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if cu_seqlens is not None:
            raise NotImplementedError("Training with cu_seqlens is not supported yet for ABCAttention")
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)

        if self.use_input_gate:
            q, k, v = map(lambda x: swish(x), (q, k, v))
        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])

        q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
        if self.use_rope:
            seqlen_offset = 0
            if past_key_values is not None:
                seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            q, k = self.rotary(q, k, seqlen_offset=seqlen_offset)

        s = rearrange(self.s_proj(hidden_states), '... (h m) -> ... h m', m=self.num_slots)
        s = s.clamp_(self.clamp_min, self.clamp_max)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        o, recurrent_state = chunk_abc(
            q=q,
            k=k,
            v=v,
            s=s,
            initial_state=recurrent_state,
            output_final_state=use_cache,
            head_first=False
        )
        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_norm and not self.use_output_gate:
            o = self.g_norm(o)
        elif self.use_output_gate:
            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
            o = self.g_norm(o, g) if self.use_norm else swiglu(g, o)
        o = rearrange(o, '... h d -> ... (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values

    def state_size(self, seq_len: int = 2048):
        return 2 * self.num_slots * self.hidden_size
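
A minimal smoke test for the layer above; this is my sketch, not part of the commit, and it assumes a CUDA device since `chunk_abc` is a Triton kernel (shapes follow the defaults: key_dim = 0.5 * hidden_size, value_dim = hidden_size):

import torch

from fla.layers.abc import ABCAttention

# hypothetical configuration; layer_idx=0 silences the cache warning in __init__
layer = ABCAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 256, 1024, device='cuda', dtype=torch.bfloat16)
o, _, cache = layer(x)       # (output, attention-weights placeholder, past_key_values)
assert o.shape == x.shape    # o_proj maps the value dim back to hidden_size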
fla/layers/delta_net.py ADDED
@@ -0,0 +1,291 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn import functional as F

from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


def elu_p1(x):
    return (F.elu(x, 1., False) + 1.).to(x)


def sum_norm(x):
    return (x / x.sum(-1, keepdim=True)).to(x)


class DeltaNet(nn.Module):
    r"""
    The layer implementation for [Parallelizing Linear Transformers with the Delta Rule over Sequence Length](https://arxiv.org/abs/2406.06484). # noqa
    DeltaNet was originally proposed in [Linear Transformers Are Secretly Fast Weight Programmers](https://arxiv.org/abs/2102.11174). # noqa

    Args:
        mode (str, Optional):
            Which DeltaNet kernel to use.
            Currently available: `chunk` and `fused_recurrent` (`fused_chunk` is deprecated).
            Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 1.0.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 1.0.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        use_beta (bool, Optional):
            Whether to use beta. Default: `True`.
        use_gate (bool, Optional):
            Whether to use output gate. Default: `False`.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `True`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        allow_neg_eigval (bool, Optional):
            Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2.
            See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537)
        layer_idx (int, Optional):
            The index of the layer. Default: None.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
        qk_activation (str, Optional):
            The activation function for the query and key. Default: `silu`.
        qk_norm (str, Optional):
            The normalization method for the query and key. Default: `l2`.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        d_model: int = None,
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
        num_heads: int = 4,
        use_beta: bool = True,
        use_gate: bool = False,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        allow_neg_eigval: bool = False,
        layer_idx: int = None,
        qk_activation: str = 'silu',
        qk_norm: str = 'l2',
        norm_eps: float = 1e-5,
        **kwargs
    ) -> DeltaNet:
        super().__init__()

        self.mode = mode
        self.qk_activation = qk_activation
        self.qk_norm = qk_norm

        assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
        assert self.qk_norm in ['l2', 'sum']

        if d_model is not None:
            hidden_size = d_model
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.allow_neg_eigval = allow_neg_eigval

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads
        self.layer_idx = layer_idx

        self.silu = nn.SiLU()
        if mode == 'fused_chunk':
            raise NotImplementedError("fused_chunk_delta_rule is now deprecated. Please use `chunk_delta_rule` instead.")
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        self.use_beta = use_beta
        if self.use_beta:
            self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu' if qk_activation == 'silu' else None
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu' if qk_activation == 'silu' else None
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.value_dim,
                kernel_size=conv_size,
                activation='silu'
            )
        else:
            raise UserWarning(
                "ShortConvolution is crucial to the performance. "
                "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
            )
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # change to inference mode.
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            if self.qk_activation == 'silu':
                q, k = self.silu(q), self.silu(k)
            v = self.silu(self.v_proj(hidden_states))

        q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
        if self.qk_activation != 'silu':
            if self.qk_activation == 'relu':
                q, k = q.relu(), k.relu()
            elif self.qk_activation == 'elu':
                q, k = elu_p1(q), elu_p1(k)
            elif self.qk_activation == 'identity':
                pass
            else:
                raise NotImplementedError

        if self.qk_norm == 'sum':
            q = sum_norm(q).to(q)
            k = sum_norm(k).to(k)

        if self.use_beta:
            beta = self.b_proj(hidden_states).sigmoid()
        else:
            beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])

        if self.allow_neg_eigval:
            beta = beta * 2.

        # dealing with padding
        if attention_mask is not None:
            beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_delta_rule(
                q=q,
                k=k,
                v=v,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_delta_rule(
                q=q,
                k=k,
                v=v,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_gate:
            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values
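
For intuition, the recurrence that `fused_recurrent_delta_rule` and `chunk_delta_rule` implement is the classical delta rule over a fast-weight state S that maps keys to values: S_t = S_{t-1} + beta_t (v_t - S_{t-1} k_t) k_t^T, with output o_t = S_t q_t. A naive per-token reference in plain PyTorch (my sketch, for intuition only; it omits the q/k L2 normalization the kernel applies when `use_qk_l2norm_in_kernel=True`):

import torch

def delta_rule_reference(q, k, v, beta):
    # q, k: (B, T, H, Dk); v: (B, T, H, Dv); beta: (B, T, H)
    B, T, H, Dk = k.shape
    Dv = v.shape[-1]
    S = k.new_zeros(B, H, Dv, Dk)                   # fast-weight state per head
    outs = []
    for t in range(T):
        kt, vt = k[:, t], v[:, t]                   # (B, H, Dk), (B, H, Dv)
        bt = beta[:, t].unsqueeze(-1)               # (B, H, 1)
        pred = torch.einsum('bhvk,bhk->bhv', S, kt)                   # what S predicts for v_t
        S = S + torch.einsum('bhv,bhk->bhvk', bt * (vt - pred), kt)   # delta update
        outs.append(torch.einsum('bhvk,bhk->bhv', S, q[:, t]))
    return torch.stack(outs, dim=1)                 # (B, T, H, Dv)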
fla/layers/forgetting_attn.py ADDED
@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from transformers.utils import logging

from fla.modules import GroupNorm
from fla.ops.forgetting_attn.parallel import parallel_forgetting_attn

if TYPE_CHECKING:
    from fla.models.utils import Cache


logger = logging.get_logger(__name__)


class ForgettingAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        use_output_gate: bool = False,
        layer_idx: int = None
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm

        self.window_size = window_size
        self.use_output_gate = use_output_gate
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.f_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)

        if use_output_gate:
            self.g_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if qk_norm:
            self.q_norm = GroupNorm(
                num_groups=self.num_heads,
                hidden_size=self.hidden_size,
                is_rms_norm=True,
            )
            self.k_norm = GroupNorm(
                num_groups=self.num_kv_heads,
                hidden_size=self.kv_dim,
                is_rms_norm=True,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        cu_seqlens = kwargs.get('cu_seqlens', None)
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        f = F.logsigmoid(self.f_proj(hidden_states).float())
        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        o = parallel_forgetting_attn(q, k, v, f, cu_seqlens=cu_seqlens)
        o = rearrange(o, '... h d -> ... (h d)')
        if self.use_output_gate:
            o = self.g_proj(hidden_states).sigmoid() * o
        o = self.o_proj(o)

        return o, None, past_key_values
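
For intuition, forgetting attention is softmax attention whose logits are biased by cumulative log forget gates: score(t, s) = q_t . k_s / sqrt(D) + sum over s < i <= t of f_i, where f = logsigmoid(f_proj(x)). A naive O(T^2) single-head reference (my sketch; the exact masking and scaling conventions of `parallel_forgetting_attn` are an assumption here):

import torch
import torch.nn.functional as F

def forgetting_attn_reference(q, k, v, f):
    # q, k, v: (T, D); f: (T,) log forget gates in (-inf, 0)
    T, D = q.shape
    cum = f.cumsum(0)                                # F_t = sum_{i <= t} f_i
    scores = (q @ k.t()) / D ** 0.5
    scores = scores + (cum[:, None] - cum[None, :])  # adds sum_{s < i <= t} f_i
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~mask, float('-inf'))
    return F.softmax(scores, dim=-1) @ v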
fla/layers/gated_deltanet.py ADDED
@@ -0,0 +1,293 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import math
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn import functional as F

from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


@torch.compile
def elu_p1(x):
    return (F.elu(x, 1., False) + 1.).to(x)


@torch.compile
def sum_norm(x):
    return (x / x.sum(-1, keepdim=True)).to(x)


class GatedDeltaNet(nn.Module):
    """
    The layer implementation for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa

    Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters.

    Parameter allocation when use_gate=True:
        - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each
        - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each
        - Others are ignorably small.
        - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size
    NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim.

    Parameter allocation when use_gate=False:
        - 1 * hidden_size * hidden_size for the q_proj and k_proj each
        - 2 * hidden_size * hidden_size for the v_proj and o_proj each
        - Others are ignorably small.
        - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size

    Args:
        hidden_size (int, Optional):
            The hidden size of the input. Default: 2048.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 2.0.
        head_dim (int, Optional):
            The dimension of each head. Default: 256.
        num_heads (int, Optional):
            The number of heads. Default: 6.
        mode (str, Optional):
            Which Gated DeltaNet kernel to use.
            Currently available: `chunk` and `fused_recurrent`.
            Default: `chunk`.
        use_gate (bool, Optional):
            Whether to use output gate. Default: `True`.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `True`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
        norm_eps (float, Optional):
            The epsilon value for the normalization layer. Default: 1e-5.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        expand_v: float = 2,
        head_dim: int = 256,
        num_heads: int = 6,
        mode: str = 'chunk',
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        layer_idx: int = None,
        norm_eps: float = 1e-5,
        **kwargs
    ) -> GatedDeltaNet:
        super().__init__()

        self.mode = mode

        self.hidden_size = hidden_size
        self.expand_v = expand_v

        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.head_dim = head_dim
        self.num_heads = num_heads

        self.key_dim = int(self.num_heads * self.head_dim)
        self.value_dim = int(self.key_dim * self.expand_v)
        self.head_k_dim = head_dim
        self.head_v_dim = int(head_dim * self.expand_v)
        self.layer_idx = layer_idx

        # Consistency check: Ensure expand_v produces integer values
        if not math.isclose(self.key_dim * expand_v, self.value_dim, rel_tol=1e-5):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
                f"Resulting value_dim would be {self.key_dim * expand_v}, which is invalid for nn.Linear."
            )
        if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. "
                f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated."
            )
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
        self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)

        A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        # hard coded for now
        dt_min = 0.001
        dt_max = 0.1
        dt_init_floor = 1e-4
        dt = torch.exp(
            torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu'
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu'
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.value_dim,
                kernel_size=conv_size,
                activation='silu'
            )
        else:
            raise UserWarning(
                "ShortConvolution is crucial to the performance. "
                "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
            )
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
        if self.training:
            assert mode == 'chunk', "Only chunk mode is supported in training."

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = F.silu(self.q_proj(hidden_states))
            k = F.silu(self.k_proj(hidden_states))
            v = F.silu(self.v_proj(hidden_states))

        q, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (q, k))
        v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
        beta = self.b_proj(hidden_states).sigmoid()
        g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)

        # dealing with padding
        if attention_mask is not None:
            beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
            g = g.mul(attention_mask[:, -g.shape[-2]:, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'chunk':
            o, recurrent_state = chunk_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True
            )
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True
            )
        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_gate:
            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values
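
The "around 6 * hidden_size * hidden_size parameters" claim in the docstring above is easy to verify by construction alone (my sketch, not part of the commit; instantiation needs no GPU):

from fla.layers.gated_deltanet import GatedDeltaNet

hidden_size = 2048
# the defaults satisfy the docstring constraint: num_heads * head_dim = 6 * 256 = 0.75 * hidden_size
layer = GatedDeltaNet(hidden_size=hidden_size, head_dim=256, num_heads=6, layer_idx=0)
n_params = sum(p.numel() for p in layer.parameters())
print(n_params / hidden_size ** 2)  # ~6.01; short convolutions and gates add a little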
fla/layers/gated_deltaproduct.py ADDED
@@ -0,0 +1,351 @@
from __future__ import annotations

import math
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
from fla.ops.delta_rule import chunk_delta_rule
from fla.ops.gated_delta_rule import chunk_gated_delta_rule

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


def elu_p1(x):
    return (F.elu(x, 1.0, False) + 1.0).to(x)


def sum_norm(x):
    return (x / x.sum(-1, keepdim=True)).to(x)


def interleave_multiple_sequences(*sequences):
    """
    Interleave multiple sequences together.
    For example, with sequences [A1, A2], [B1, B2], [C1, C2],
    returns [A1, B1, C1, A2, B2, C2].
    """
    if isinstance(sequences[0], (list, tuple)):
        sequences = sequences[0]

    if len(sequences) == 1:
        return sequences[0]

    # All sequences must have the same shape
    assert all(s.shape == sequences[0].shape for s in sequences)

    # Get the original shape
    batch_size, seq_len, *rest = sequences[0].shape

    # Stack sequences along a new dimension
    stacked = torch.stack(sequences, dim=2)

    # Reshape to interleave
    reshaped = stacked.view(batch_size, seq_len * len(sequences), *rest)

    return reshaped


class GatedDeltaProduct(nn.Module):
    """
    Generalized version of GatedDoubleDeltaNet that supports an arbitrary number of Householder transformations.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        expand_v: float = 2,
        head_dim: int = 256,
        num_heads: int = 6,
        num_householder: int = 2,  # number of householder transformations
        mode: str = 'chunk',
        use_gate: bool = True,
        use_forget_gate: bool = True,  # when True Gated DeltaProduct, when False DeltaProduct
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        layer_idx: int | None = None,
        norm_eps: float = 1e-5,
        allow_neg_eigval: bool = False,  # when True (Gated) DeltaProduct [-1, 1], when False (Gated) DeltaProduct [0, 1]
        **kwargs,
    ) -> None:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_v = expand_v
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.num_householder = num_householder
        self.allow_neg_eigval = allow_neg_eigval
        self.use_forget_gate = use_forget_gate
        self.key_dim = self.num_heads * self.head_dim
        self.value_dim = int(self.key_dim * self.expand_v)
        self.head_qk_dim = head_dim
        self.head_v_dim = int(head_dim * self.expand_v)
        self.layer_idx = layer_idx
        self.silu = nn.SiLU()
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."

        # Create multiple projection layers for each householder transformation
        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_projs = nn.ModuleList(
            [nn.Linear(hidden_size, self.key_dim, bias=False) for _ in range(num_householder)]
        )
        self.v_projs = nn.ModuleList(
            [nn.Linear(hidden_size, self.value_dim, bias=False) for _ in range(num_householder)]
        )
        self.b_projs = nn.ModuleList(
            [nn.Linear(hidden_size, self.num_heads, bias=False) for _ in range(num_householder)]
        )
        if use_short_conv:
            self.q_conv1ds = nn.ModuleList(
                [ShortConvolution(hidden_size=self.key_dim, kernel_size=conv_size, activation='silu')
                 for _ in range(num_householder)]
            )
            self.k_conv1ds = nn.ModuleList(
                [ShortConvolution(hidden_size=self.key_dim, kernel_size=conv_size, activation='silu')
                 for _ in range(num_householder)]
            )
            self.v_conv1ds = nn.ModuleList(
                [ShortConvolution(hidden_size=self.value_dim, kernel_size=conv_size, activation='silu')
                 for _ in range(num_householder)]
            )

        if self.use_forget_gate:
            self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
            A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
            A_log = torch.log(A)
            self.A_log = nn.Parameter(A_log)
            self.A_log._no_weight_decay = True

            # Initialize dt parameters
            dt_min = 0.001
            dt_max = 0.1
            dt_init_floor = 1e-4
            dt = torch.exp(
                torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
                + math.log(dt_min)
            )
            dt = torch.clamp(dt, min=dt_init_floor)
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            self.dt_bias = nn.Parameter(inv_dt)
            self.dt_bias._no_weight_decay = True

        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
        self.k_id = torch.nn.Identity()
        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, '_is_hf_initialized', False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding)."
            )

        mode = 'chunk'  # 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
        if self.training:
            assert mode == 'chunk', "Only chunk mode is supported in training."

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        # Process each householder transformation
        ks, vs, betas = [], [], []
        conv_states = []

        for i in range(self.num_householder):
            if self.use_short_conv:
                conv_state_q, conv_state_k, conv_state_v = None, None, None
                if last_state is not None:
                    conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'][i]
                conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None

                k, conv_state_k = self.k_conv1ds[i](
                    x=self.k_projs[i](hidden_states),
                    mask=conv_mask,
                    cache=conv_state_k,
                    output_final_state=use_cache,
                )
                v, conv_state_v = self.v_conv1ds[i](
                    x=self.v_projs[i](hidden_states),
                    mask=conv_mask,
                    cache=conv_state_v,
                    output_final_state=use_cache,
                )
                conv_states.append((conv_state_q, conv_state_k, conv_state_v))
            else:
                k = self.silu(self.k_projs[i](hidden_states))
                v = self.silu(self.v_projs[i](hidden_states))

            ks.append(k)
            vs.append(v)

            beta = self.b_projs[i](hidden_states).sigmoid()  # bs, sequence_length, num_heads
            if attention_mask is not None:
                beta = beta.mul(attention_mask[:, -hidden_states.shape[1]:, None])
            if self.allow_neg_eigval:
                beta = beta * 2
            betas.append(beta)

        if self.use_short_conv:
            q, conv_state_q = self.q_conv1ds[0](
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
            )
        else:
            q = self.silu(self.q_proj(hidden_states))
        q = interleave_multiple_sequences([torch.zeros_like(q)] * (self.num_householder - 1) + [q])
        # Interleave all sequences
        k = interleave_multiple_sequences(ks)
        v = interleave_multiple_sequences(vs)
        beta = interleave_multiple_sequences(betas)

        q, k, v = (rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads) for x in (q, k, v))

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        offsets = kwargs.get('offsets')

        if mode == 'chunk':
            if self.use_forget_gate:
                g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)
                if attention_mask is not None:
                    g = g.mul(attention_mask[:, -g.shape[-2]:, None])

                # Interleave g with zeros for non-first transformations
                g = interleave_multiple_sequences([g] + [torch.zeros_like(g)] * (self.num_householder - 1))

                o, recurrent_state = chunk_gated_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    g=g,
                    beta=beta,
                    initial_state=recurrent_state,
                    output_final_state=use_cache,
                    cu_seqlens=offsets,
                    head_first=False,
                    use_qk_l2norm_in_kernel=True
                )
            else:
                o, recurrent_state = chunk_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    beta=beta,
                    initial_state=recurrent_state,
                    output_final_state=use_cache,
                    cu_seqlens=offsets,
                    head_first=False,
                    use_qk_l2norm_in_kernel=True
                )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        # Take every nth element for n householder transformations
        o = o[:, self.num_householder - 1::self.num_householder, :]

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=conv_states if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[2],
            )

        if self.use_gate:
            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads)
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values
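
A tiny demo of the token-wise interleaving that `interleave_multiple_sequences` performs, which is how the num_householder key/value/beta streams are merged along the time axis before the kernel call (my sketch, mirroring the stack-then-view logic above):

import torch

a = torch.tensor([[1, 2, 3]])         # (batch=1, seq_len=3): stream A
b = torch.tensor([[4, 5, 6]])         # stream B
stacked = torch.stack([a, b], dim=2)  # (1, 3, 2): pair up per-token entries
print(stacked.view(1, 6))             # tensor([[1, 4, 2, 5, 3, 6]])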
fla/layers/gla.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange, repeat
13
+
14
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
15
+ from fla.modules.activations import ACT2FN
16
+ from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.processing_utils import Unpack
20
+
21
+ from fla.models.utils import Cache
22
+
23
+
24
+ class GatedLinearAttention(nn.Module):
25
+ r"""
26
+ The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
27
+
28
+ Args:
29
+ mode (str, Optional):
30
+ Which GLA kernel to use.
31
+ Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
32
+ Default: `chunk`.
33
+ hidden_size (int, Optional):
34
+ The hidden size of the input. Default: 1024.
35
+ expand_k (float, Optional):
36
+ The expansion ratio for the key dim. Default: 0.5.
37
+ expand_v (float, Optional):
38
+ The expansion ratio for the value dim. Default: 1.0.
39
+ num_heads (int, Optional):
40
+ The number of heads. Default: 4.
41
+ num_kv_heads (int, Optional):
42
+ The number of key/value heads, used for MQA. Default: None.
43
+ feature_map (str, Optional):
44
+ Feature map function applied to queries/keys. Default: None.
45
+ use_short_conv (bool, Optional):
46
+ Whether to use short convolutions. Default: `False`.
47
+ conv_size (int, Optional):
48
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
49
+ conv_bias (bool, Optional):
50
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
51
+ use_output_gate (bool, Optional):
52
+ Whether to use output gate. Default: `True`.
53
+ gate_fn (str, Optional):
54
+ The activation function for the output gate. Default: `swish`.
55
+ elementwise_affine (bool, Optional):
56
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
57
+ norm_eps (float, Optional):
58
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
59
+ gate_logit_normalizer (int, Optional):
60
+ The normalizer for the gate logits, appied after `logsigmoid`. Default: 16.
61
+ gate_low_rank_dim (int, Optional):
62
+ The low rank dim for the gate projection. Default: 16.
63
+ clamp_min (float, Optional):
64
+ The minimum value for the gate logits. Default: None.
65
+ fuse_norm (bool, Optional):
66
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
67
+ layer_idx (int, Optional):
68
+ The index of the layer. Default: None.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ mode: str = 'chunk',
74
+ hidden_size: int = 1024,
75
+ expand_k: float = 0.5,
76
+ expand_v: float = 1.0,
77
+ num_heads: int = 4,
78
+ num_kv_heads: Optional[int] = None,
79
+ feature_map: Optional[str] = None,
80
+ use_short_conv: bool = False,
81
+ conv_size: int = 4,
82
+ conv_bias: bool = False,
83
+ use_output_gate: bool = True,
84
+ gate_fn: str = 'swish',
85
+ elementwise_affine: Optional[bool] = True,
86
+ norm_eps: float = 1e-5,
87
+ gate_logit_normalizer: int = 16,
88
+ gate_low_rank_dim: int = 16,
89
+ clamp_min: Optional[float] = None,
90
+ fuse_norm: bool = True,
91
+ layer_idx: int = None,
92
+ ) -> GatedLinearAttention:
93
+ super().__init__()
94
+
95
+ self.mode = mode
96
+ self.hidden_size = hidden_size
97
+ self.expand_k = expand_k
98
+ self.expand_v = expand_v
99
+ self.num_heads = num_heads
100
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
101
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
102
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
103
+
104
+ self.use_short_conv = use_short_conv
105
+ self.conv_size = conv_size
106
+ self.conv_bias = conv_bias
107
+ self.use_output_gate = use_output_gate
108
+
109
+ self.key_dim = int(hidden_size * expand_k)
110
+ self.value_dim = int(hidden_size * expand_v)
111
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
112
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
113
+ self.clamp_min = clamp_min
114
+ self.layer_idx = layer_idx
115
+
116
+ assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`."
117
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
118
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
119
+
120
+ self.head_k_dim = self.key_dim // num_heads
121
+ self.head_v_dim = self.value_dim // num_heads
122
+
123
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
124
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
125
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
126
+ if self.use_output_gate:
127
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
128
+
129
+ if use_short_conv:
130
+ self.conv_size = conv_size
131
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
132
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
133
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
134
+
135
+ self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
136
+ nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True))
137
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
138
+
139
+ if gate_fn == 'swish' and fuse_norm and use_output_gate:
140
+ self.g_norm_swish_gate = FusedRMSNormGated(
141
+ hidden_size=self.head_v_dim,
142
+ elementwise_affine=elementwise_affine,
143
+ eps=norm_eps
144
+ )
145
+ self.fuse_norm_and_gate = True
146
+ else:
147
+ self.fuse_norm_and_gate = False
148
+ self.g_norm = RMSNorm(
149
+ hidden_size=self.head_v_dim,
150
+ elementwise_affine=elementwise_affine,
151
+ eps=norm_eps
152
+ )
153
+ self.gate_fn = ACT2FN[gate_fn]
154
+
155
+ self.gate_logit_normalizer = gate_logit_normalizer
156
+
157
+ def forward(
158
+ self,
159
+ hidden_states: torch.Tensor,
160
+ attention_mask: Optional[torch.Tensor] = None,
161
+ past_key_values: Optional[Cache] = None,
162
+ use_cache: Optional[bool] = False,
163
+ output_attentions: Optional[bool] = False,
164
+ **kwargs: Unpack[Dict]
165
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
166
+ if attention_mask is not None:
167
+ assert len(attention_mask.shape) == 2, (
168
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
169
+ "for padding purposes (0 indicating padding). "
170
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
171
+ )
172
+
173
+ # launching the triton kernel for just one token will actually be slower
174
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
175
+
176
+ last_state = None
177
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
178
+ last_state = past_key_values[self.layer_idx]
179
+
180
+ cu_seqlens = kwargs.get('cu_seqlens', None)
181
+ if self.use_short_conv:
182
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
183
+ if last_state is not None:
184
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
185
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
186
+ q, conv_state_q = self.q_conv1d(
187
+ x=self.q_proj(hidden_states),
188
+ mask=conv_mask,
189
+ cache=conv_state_q,
190
+ output_final_state=use_cache,
191
+ cu_seqlens=cu_seqlens
192
+ )
193
+ k, conv_state_k = self.k_conv1d(
194
+ x=self.k_proj(hidden_states),
195
+ mask=conv_mask,
196
+ cache=conv_state_k,
197
+ output_final_state=use_cache,
198
+ cu_seqlens=cu_seqlens
199
+ )
200
+ v, conv_state_v = self.v_conv1d(
201
+ x=self.v_proj(hidden_states),
202
+ mask=conv_mask,
203
+ cache=conv_state_v,
204
+ output_final_state=use_cache,
205
+ cu_seqlens=cu_seqlens
206
+ )
207
+ else:
208
+ q = self.q_proj(hidden_states)
209
+ k = self.k_proj(hidden_states)
210
+ v = self.v_proj(hidden_states)
211
+ gk = self.gk_proj(hidden_states)
212
+
213
+ if self.feature_map_fn is not None:
214
+ q, k = map(self.feature_map_fn, (q, k))
215
+ # dealing with left-padding
216
+ if attention_mask is not None:
217
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
218
+ q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
219
+ if self.num_kv_groups > 1:
220
+ k, gk = (repeat(x, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_k_dim) for x in (k, gk))
221
+ v = repeat(v, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_v_dim)
222
+ else:
223
+ k, gk = (rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim) for x in (k, gk))
224
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
225
+ gk = F.logsigmoid(gk) / self.gate_logit_normalizer
226
+
227
+ if self.clamp_min is not None:
228
+ gk = torch.clamp_min(gk, self.clamp_min)
229
+
230
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
231
+ if mode == 'fused_recurrent':
232
+ o, recurrent_state = fused_recurrent_gla(
233
+ q=q,
234
+ k=k,
235
+ v=v,
236
+ gk=gk,
237
+ initial_state=recurrent_state,
238
+ output_final_state=use_cache,
239
+ cu_seqlens=cu_seqlens,
240
+ head_first=False
241
+ )
242
+ elif mode == 'fused_chunk':
243
+ o, recurrent_state = fused_chunk_gla(
244
+ q=q,
245
+ k=k,
246
+ v=v,
247
+ g=gk,
248
+ initial_state=recurrent_state,
249
+ output_final_state=use_cache,
250
+ head_first=False
251
+ )
252
+ elif mode == 'chunk':
253
+ o, recurrent_state = chunk_gla(
254
+ q=q,
255
+ k=k,
256
+ v=v,
257
+ g=gk,
258
+ initial_state=recurrent_state,
259
+ output_final_state=use_cache,
260
+ cu_seqlens=cu_seqlens,
261
+ head_first=False
262
+ )
263
+ else:
264
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
265
+
266
+ if past_key_values is not None:
267
+ past_key_values.update(
268
+ recurrent_state=recurrent_state,
269
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
270
+ layer_idx=self.layer_idx,
271
+ offset=q.shape[1]
272
+ )
273
+
274
+ if self.use_output_gate:
275
+ g = self.g_proj(hidden_states)
276
+ if self.fuse_norm_and_gate:
277
+ g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim)
278
+ o = self.g_norm_swish_gate(o, g)
279
+ o = rearrange(o, 'b t h d -> b t (h d)')
280
+ else:
281
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
282
+ o = o * self.gate_fn(g)
283
+ else:
284
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
285
+ o = self.o_proj(o)
286
+
287
+ return o, None, past_key_values
288
+
289
+ def state_size(self, **kwargs) -> int:
290
+ state_size = self.key_dim * self.head_v_dim
291
+ for module in self.children():
292
+ if isinstance(module, ShortConvolution):
293
+ state_size += module.state_size
294
+ return state_size
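
A minimal usage sketch for `GatedLinearAttention` (not part of the commit itself; it assumes the package is installed and a CUDA device is available, since the chunk/fused kernels are Triton-based). Shapes follow the `(batch, seq_len, hidden)` convention used throughout:

```python
import torch
from fla.layers.gla import GatedLinearAttention

# hypothetical smoke test; arguments mirror the constructor defaults above
layer = GatedLinearAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda()
x = torch.randn(2, 128, 1024, device='cuda')
o, _, past_key_values = layer(x)  # (output, attentions, cache)
assert o.shape == x.shape
```

Note that `gk = F.logsigmoid(gk) / self.gate_logit_normalizer` keeps the log-gates in `(-inf, 0]`, so the recurrent state decays monotonically; `clamp_min` can additionally bound the decay from below.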
fla/layers/gsa.py ADDED
@@ -0,0 +1,227 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+
14
+ from fla.modules import RMSNorm, ShortConvolution
15
+ from fla.modules.feature_map import ReLUFeatureMap, SwishFeatureMap, T2RFeatureMap
16
+ from fla.modules.layernorm import rms_norm_linear
17
+ from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from fla.models.utils import Cache
23
+
24
+
25
+ class GatedSlotAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ mode: str = 'chunk',
30
+ hidden_size: int = 1024,
31
+ expand_k: float = 1.,
32
+ expand_v: float = 1.,
33
+ num_heads: int = 4,
34
+ num_kv_heads: Optional[int] = None,
35
+ use_short_conv: bool = False,
36
+ conv_size: int = 4,
37
+ conv_bias: bool = False,
38
+ num_slots: Optional[int] = None,
39
+ elementwise_affine: Optional[bool] = True,
40
+ norm_eps: float = 1e-5,
41
+ gate_logit_normalizer: int = 8,
42
+ feature_map: str = 'swish',
43
+ use_output_gate: bool = False,
44
+ use_norm: bool = True,
45
+ layer_idx: Optional[int] = None,
46
+ scale: Optional[float] = 1.,
47
+ **kwargs
48
+ ) -> None:
49
+ super().__init__()
50
+
51
+ self.mode = mode
52
+ self.hidden_size = hidden_size
53
+ self.expand_k = expand_k
54
+ self.expand_v = expand_v
55
+ self.num_heads = num_heads
56
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
57
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
58
+ self.key_dim = int(hidden_size * expand_k)
59
+ self.value_dim = int(hidden_size * expand_v)
60
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
61
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
62
+ self.head_k_dim = self.key_dim // self.num_heads
63
+ self.head_v_dim = self.value_dim // self.num_heads
64
+
65
+ self.use_short_conv = use_short_conv
66
+ self.conv_size = conv_size
67
+ self.conv_bias = conv_bias
68
+
69
+ self.gate_logit_normalizer = gate_logit_normalizer
70
+
71
+ self.use_output_gate = use_output_gate
72
+ self.use_norm = use_norm
73
+ self.scale = scale
74
+
75
+ if num_slots is None:
76
+ num_slots = self.head_k_dim
77
+ self.num_slots = num_slots
78
+
79
+ self.layer_idx = layer_idx
80
+
81
+ if layer_idx is None:
82
+ warnings.warn(
83
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
84
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
85
+ "when creating this class."
86
+ )
87
+
88
+ self.register_module('feature_map', None)
89
+ if feature_map == 'swish':
90
+ self.feature_map = SwishFeatureMap()
91
+ elif feature_map == 'relu':
92
+ self.feature_map = ReLUFeatureMap()
93
+ elif feature_map == 't2r':
94
+ self.feature_map = T2RFeatureMap(self.head_k_dim, self.head_k_dim)
95
+ else:
96
+ raise NotImplementedError(f"Feature map `{feature_map}` is not supported now.")
97
+
98
+ self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
99
+ self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
100
+ self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
101
+ self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False)
102
+
103
+ if use_short_conv:
104
+ self.conv_size = conv_size
105
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
106
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
107
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
108
+
109
+ self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, eps=norm_eps)
110
+ self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
111
+
112
+ def forward(
113
+ self,
114
+ hidden_states: torch.Tensor,
115
+ attention_mask: Optional[torch.Tensor] = None,
116
+ past_key_values: Optional[Cache] = None,
117
+ use_cache: Optional[bool] = False,
118
+ output_attentions: Optional[bool] = False,
119
+ **kwargs: Unpack[Dict]
120
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
121
+ if attention_mask is not None:
122
+ assert len(attention_mask.shape) == 2, (
123
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
124
+ "for padding purposes (0 indicating padding). "
125
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
126
+ )
127
+
128
+ # launching the triton kernel for just one token will actually be slower
129
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
130
+
131
+ last_state = None
132
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
133
+ last_state = past_key_values[self.layer_idx]
134
+
135
+ cu_seqlens = kwargs.get('cu_seqlens', None)
136
+ if self.use_short_conv:
137
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
138
+ if last_state is not None:
139
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
140
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
141
+ q, conv_state_q = self.q_conv1d(
142
+ x=self.q_proj(hidden_states),
143
+ mask=conv_mask,
144
+ cache=conv_state_q,
145
+ output_final_state=use_cache,
146
+ cu_seqlens=cu_seqlens
147
+ )
148
+ k, conv_state_k = self.k_conv1d(
149
+ x=self.k_proj(hidden_states),
150
+ mask=conv_mask,
151
+ cache=conv_state_k,
152
+ output_final_state=use_cache,
153
+ cu_seqlens=cu_seqlens
154
+ )
155
+ v, conv_state_v = self.v_conv1d(
156
+ x=self.v_proj(hidden_states),
157
+ mask=conv_mask,
158
+ cache=conv_state_v,
159
+ output_final_state=use_cache,
160
+ cu_seqlens=cu_seqlens
161
+ )
162
+ else:
163
+ q = self.q_proj(hidden_states)
164
+ k = self.k_proj(hidden_states)
165
+ v = self.v_proj(hidden_states)
166
+ f = self.f_proj(hidden_states)
167
+
168
+ q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
169
+ k = rearrange(k, 'b t (h d) -> b t h d', d=self.head_k_dim)
170
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
171
+ f = rearrange(f, 'b t (h m) -> b t h m', m=self.num_slots)
172
+
173
+ if self.feature_map is not None:
174
+ q, k = map(lambda x: self.feature_map(x), (q, k))
175
+ v = F.silu(v)
176
+
177
+ f = F.logsigmoid(f) / self.gate_logit_normalizer
178
+ s = (1 - f.exp()).to(f.dtype)
179
+ # dealing with left-padding
180
+ if attention_mask is not None:
181
+ s = s.mul_(attention_mask[:, -s.shape[1]:, None, None])
182
+ v = v.mul_(attention_mask[:, -v.shape[1]:, None, None])
183
+
184
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
185
+ if mode == 'fused_recurrent':
186
+ o, recurrent_state = fused_recurrent_gsa(
187
+ q=q,
188
+ k=k,
189
+ v=v,
190
+ s=s,
191
+ g=f,
192
+ initial_state=recurrent_state,
193
+ output_final_state=use_cache,
194
+ scale=self.scale,
195
+ cu_seqlens=cu_seqlens,
196
+ head_first=False
197
+ )
198
+ elif mode == 'chunk':
199
+ o, recurrent_state = chunk_gsa(
200
+ q=q,
201
+ k=k,
202
+ v=v,
203
+ s=s,
204
+ g=f,
205
+ initial_state=recurrent_state,
206
+ output_final_state=use_cache,
207
+ scale=self.scale,
208
+ cu_seqlens=cu_seqlens,
209
+ head_first=False
210
+ )
211
+ else:
212
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
213
+
214
+ if past_key_values is not None:
215
+ past_key_values.update(
216
+ recurrent_state=recurrent_state,
217
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
218
+ layer_idx=self.layer_idx,
219
+ offset=q.shape[1]
220
+ )
221
+
222
+ o = rearrange(o, 'b t h d -> b t (h d)')
223
+ o = rms_norm_linear(F.silu(o), self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
224
+ return o, None, past_key_values
225
+
226
+ def state_size(self, *args, **kwargs) -> int:
227
+ return 2 * self.num_slots * self.hidden_size
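
As above, a hedged smoke test for `GatedSlotAttention` (illustrative, assuming a CUDA device): `num_slots` defaults to `head_k_dim` when left unset, and passing `layer_idx` avoids the cache warning emitted in `__init__`.

```python
import torch
from fla.layers.gsa import GatedSlotAttention

layer = GatedSlotAttention(hidden_size=1024, num_heads=4, num_slots=64, layer_idx=0).cuda()
x = torch.randn(2, 128, 1024, device='cuda')
o, _, _ = layer(x)
assert o.shape == x.shape  # state size is 2 * num_slots * hidden_size
```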
fla/layers/hgrn.py ADDED
@@ -0,0 +1,168 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from fla.modules import FusedRMSNormGated, ShortConvolution
15
+ from fla.modules.activations import swiglu
16
+ from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.processing_utils import Unpack
20
+
21
+ from fla.models.utils import Cache
22
+
23
+
24
+ class HGRNAttention(nn.Module):
25
+
26
+ def __init__(
27
+ self,
28
+ mode: str = 'chunk',
29
+ hidden_size: int = 1024,
30
+ expand_ratio: Optional[int] = 1,
31
+ use_short_conv: bool = False,
32
+ conv_size: int = 4,
33
+ conv_bias: bool = False,
34
+ elementwise_affine: Optional[bool] = True,
35
+ norm_eps: float = 1e-5,
36
+ layer_idx: Optional[int] = None
37
+ ) -> None:
38
+ super().__init__()
39
+
40
+ self.mode = mode
41
+ self.hidden_size = hidden_size
42
+ self.expand_ratio = expand_ratio
43
+ self.input_dim = int(hidden_size * expand_ratio)
44
+
45
+ self.use_short_conv = use_short_conv
46
+ self.conv_size = conv_size
47
+ self.conv_bias = conv_bias
48
+
49
+ self.layer_idx = layer_idx
50
+
51
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
52
+
53
+ self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
54
+ self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
55
+ self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
56
+
57
+ if use_short_conv:
58
+ self.conv_size = conv_size
59
+ self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
60
+ self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
61
+ self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
62
+
63
+ self.g_norm = FusedRMSNormGated(
64
+ hidden_size=self.input_dim,
65
+ elementwise_affine=elementwise_affine,
66
+ eps=norm_eps
67
+ )
68
+ self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
69
+
70
+ def forward(
71
+ self,
72
+ hidden_states: torch.Tensor,
73
+ attention_mask: Optional[torch.Tensor] = None,
74
+ past_key_values: Optional[Cache] = None,
75
+ use_cache: Optional[bool] = False,
76
+ output_attentions: Optional[bool] = False,
77
+ lower_bound: Optional[torch.Tensor] = None,
78
+ **kwargs: Unpack[Dict]
79
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
80
+ if attention_mask is not None:
81
+ assert len(attention_mask.shape) == 2, (
82
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
83
+ "for padding purposes (0 indicating padding). "
84
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
85
+ )
86
+
87
+ # launching the triton kernel for just one token will actually be slower
88
+ mode = 'fused_recurrent' if not self.training and hidden_states.shape[1] <= 64 else self.mode
89
+
90
+ last_state = None
91
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
92
+ last_state = past_key_values[self.layer_idx]
93
+
94
+ cu_seqlens = kwargs.get('cu_seqlens', None)
95
+ if self.use_short_conv:
96
+ conv_state_i, conv_state_f = None, None
97
+ if last_state is not None:
98
+ conv_state_i, conv_state_f = last_state['conv_state']
99
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
100
+ i, conv_state_i = self.i_conv1d(
101
+ x=self.i_proj(hidden_states),
102
+ mask=conv_mask,
103
+ cache=conv_state_i,
104
+ output_final_state=use_cache,
105
+ cu_seqlens=cu_seqlens
106
+ )
107
+ f, conv_state_f = self.f_conv1d(
108
+ x=self.f_proj(hidden_states),
109
+ mask=conv_mask,
110
+ cache=conv_state_f,
111
+ output_final_state=use_cache,
112
+ cu_seqlens=cu_seqlens
113
+ )
114
+ else:
115
+ i = self.i_proj(hidden_states)
116
+ f = self.f_proj(hidden_states)
117
+
118
+ # the lower bound for the first layer is zero
119
+ if lower_bound is None or self.layer_idx == 0:
120
+ i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
121
+ else:
122
+ g = lower_bound + (1 - lower_bound) * f.sigmoid()
123
+ i, f = swiglu(i, 1 - g), g.log()
124
+
125
+ # dealing with left-padding
126
+ if attention_mask is not None:
127
+ i = i.mul_(attention_mask[:, -i.shape[-2]:, None])
128
+
129
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
130
+ if mode == 'chunk':
131
+ if cu_seqlens is not None:
132
+ raise NotImplementedError("Chunk mode does not support variable-length sequences.")
133
+ o, recurrent_state = chunk_hgrn(
134
+ x=i,
135
+ g=f,
136
+ initial_state=recurrent_state,
137
+ output_final_state=use_cache,
138
+ )
139
+ elif mode == 'fused_recurrent':
140
+ o, recurrent_state = fused_recurrent_hgrn(
141
+ x=i,
142
+ g=f,
143
+ initial_state=recurrent_state,
144
+ output_final_state=use_cache,
145
+ cu_seqlens=cu_seqlens
146
+ )
147
+ else:
148
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
149
+
150
+ if past_key_values is not None:
151
+ past_key_values.update(
152
+ recurrent_state=recurrent_state,
153
+ conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None,
154
+ layer_idx=self.layer_idx,
155
+ offset=i.shape[1]
156
+ )
157
+
158
+ o = self.g_norm(o, self.g_proj(hidden_states))
159
+ o = self.o_proj(o)
160
+
161
+ return o, None, past_key_values
162
+
163
+ def state_size(self, **kwargs) -> int:
164
+ state_size = self.hidden_size
165
+ for module in self.children():
166
+ if isinstance(module, ShortConvolution):
167
+ state_size += module.state_size
168
+ return state_size
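
The lower-bound reparameterization in `forward` above has a simple invariant worth spelling out: for non-first layers the forget gate becomes `g = lb + (1 - lb) * sigmoid(f)`, so it can never fall below the per-layer lower bound. A small, self-contained check (illustrative only; `lb` is a hypothetical lower bound):

```python
import torch

f = torch.randn(1024)
lb = torch.tensor(0.3)            # hypothetical per-layer lower bound
g = lb + (1 - lb) * f.sigmoid()   # the reparameterization used in forward
assert torch.all(g >= lb) and torch.all(g < 1)
# the input gate is then 1 - g and the log-forget gate is g.log(), both well-defined
```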
fla/layers/hgrn2.py ADDED
@@ -0,0 +1,211 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from fla.modules import RMSNorm, ShortConvolution
16
+ from fla.modules.activations import swish
17
+ from fla.modules.layernorm import rms_norm_linear
18
+ from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
19
+
20
+ if TYPE_CHECKING:
21
+ from transformers.processing_utils import Unpack
22
+
23
+ from fla.models.utils import Cache
24
+
25
+
26
+ class HGRN2Attention(nn.Module):
27
+
28
+ def __init__(
29
+ self,
30
+ mode: str = 'chunk',
31
+ hidden_size: int = 1024,
32
+ num_heads: Optional[int] = None,
33
+ expand_ratio: Optional[int] = 128,
34
+ use_short_conv: bool = False,
35
+ conv_size: int = 4,
36
+ conv_bias: bool = False,
37
+ elementwise_affine: Optional[bool] = True,
38
+ norm_eps: float = 1e-5,
39
+ layer_idx: Optional[int] = None
40
+ ) -> None:
41
+ super().__init__()
42
+
43
+ self.mode = mode
44
+ self.hidden_size = hidden_size
45
+
46
+ if expand_ratio is None and num_heads is not None:
47
+ expand_ratio = hidden_size // num_heads
48
+ elif expand_ratio is not None and num_heads is None:
49
+ num_heads = hidden_size // expand_ratio
50
+ elif expand_ratio is None and num_heads is None:
51
+ raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
52
+ self.num_heads = num_heads
53
+ self.expand_ratio = expand_ratio
54
+
55
+ self.use_short_conv = use_short_conv
56
+ self.conv_size = conv_size
57
+ self.conv_bias = conv_bias
58
+
59
+ self.forget_dim = int(self.num_heads * self.expand_ratio)
60
+ self.input_dim = hidden_size
61
+ self.layer_idx = layer_idx
62
+
63
+ assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
64
+ assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}"
65
+ assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}"
66
+
67
+ self.head_f_dim = self.expand_ratio
68
+ self.head_i_dim = self.hidden_size // num_heads
69
+
70
+ self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
71
+ self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
72
+ self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
73
+
74
+ if use_short_conv:
75
+ self.conv_size = conv_size
76
+ self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
77
+ self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
78
+ self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
79
+
80
+ self.g_norm = RMSNorm(hidden_size=self.hidden_size, elementwise_affine=elementwise_affine, eps=norm_eps)
81
+ self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
82
+
83
+ def forward(
84
+ self,
85
+ hidden_states: torch.Tensor,
86
+ attention_mask: Optional[torch.Tensor] = None,
87
+ past_key_values: Optional[Cache] = None,
88
+ use_cache: Optional[bool] = False,
89
+ output_attentions: Optional[bool] = False,
90
+ lower_bound: Optional[torch.Tensor] = None,
91
+ **kwargs: Unpack[Dict]
92
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
93
+ if attention_mask is not None:
94
+ assert len(attention_mask.shape) == 2, (
95
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
96
+ "for padding purposes (0 indicating padding). "
97
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
98
+ )
99
+
100
+ # launching the triton kernel for just one token will actually be slower
101
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
102
+
103
+ last_state = None
104
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
105
+ last_state = past_key_values[self.layer_idx]
106
+
107
+ cu_seqlens = kwargs.get('cu_seqlens', None)
108
+ if self.use_short_conv:
109
+ conv_state_q, conv_state_f, conv_state_i = None, None, None
110
+ if last_state is not None:
111
+ conv_state_q, conv_state_f, conv_state_i = last_state['conv_state']
112
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
113
+ q, conv_state_q = self.q_conv1d(
114
+ x=self.q_proj(hidden_states),
115
+ mask=conv_mask,
116
+ cache=conv_state_q,
117
+ output_final_state=use_cache,
118
+ cu_seqlens=cu_seqlens
119
+ )
120
+ f, conv_state_f = self.f_conv1d(
121
+ x=self.f_proj(hidden_states),
122
+ mask=conv_mask,
123
+ cache=conv_state_f,
124
+ output_final_state=use_cache,
125
+ cu_seqlens=cu_seqlens
126
+ )
127
+ i, conv_state_i = self.i_conv1d(
128
+ x=self.i_proj(hidden_states),
129
+ mask=conv_mask,
130
+ cache=conv_state_i,
131
+ output_final_state=use_cache,
132
+ cu_seqlens=cu_seqlens
133
+ )
134
+ else:
135
+ q = self.q_proj(hidden_states)
136
+ f = self.f_proj(hidden_states)
137
+ i = self.i_proj(hidden_states)
138
+
139
+ # dealing with left-padding
140
+ if attention_mask is not None:
141
+ i = i.mul_(attention_mask[:, -i.shape[-2]:, None])
142
+
143
+ q = swish(q)
144
+
145
+ # improve precision
146
+ f = f.float()
147
+
148
+ # the lower bound for the first layer is zero
149
+ if lower_bound is None or self.layer_idx == 0:
150
+ k, g = 1 - f.sigmoid(), F.logsigmoid(f)
151
+ else:
152
+ g = lower_bound + (1 - lower_bound) * f.sigmoid()
153
+ k, g = 1 - g, g.log()
154
+
155
+ q, k, g = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k.to(i), g))
156
+ i = rearrange(i, '... (h d) -> ... h d', d=self.head_i_dim)
157
+
158
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
159
+ if mode == 'fused_recurrent':
160
+ o, recurrent_state = fused_recurrent_gla(
161
+ q=q,
162
+ k=k,
163
+ v=i,
164
+ gk=g,
165
+ initial_state=recurrent_state,
166
+ output_final_state=use_cache,
167
+ cu_seqlens=cu_seqlens,
168
+ head_first=False
169
+ )
170
+ elif mode == 'fused_chunk':
171
+ o, recurrent_state = fused_chunk_gla(
172
+ q=q,
173
+ k=k,
174
+ v=i,
175
+ g=g,
176
+ initial_state=recurrent_state,
177
+ output_final_state=use_cache,
178
+ head_first=False
179
+ )
180
+ elif mode == 'chunk':
181
+ o, recurrent_state = chunk_gla(
182
+ q=q,
183
+ k=k,
184
+ v=i,
185
+ g=g,
186
+ initial_state=recurrent_state,
187
+ output_final_state=use_cache,
188
+ cu_seqlens=cu_seqlens,
189
+ head_first=False
190
+ )
191
+ else:
192
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
193
+
194
+ if past_key_values is not None:
195
+ past_key_values.update(
196
+ recurrent_state=recurrent_state,
197
+ conv_state=(conv_state_q, conv_state_f, conv_state_i) if self.use_short_conv else None,
198
+ layer_idx=self.layer_idx,
199
+ offset=q.shape[1]
200
+ )
201
+
202
+ o = rearrange(o, '... h d -> ... (h d)')
203
+ o = rms_norm_linear(o, self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
204
+ return o, None, past_key_values
205
+
206
+ def state_size(self, **kwargs) -> int:
207
+ state_size = self.forget_dim * self.head_i_dim
208
+ for module in self.children():
209
+ if isinstance(module, ShortConvolution):
210
+ state_size += module.state_size
211
+ return state_size
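
HGRN2 has no kernels of its own; as the imports show, it reuses the GLA kernels with `k = 1 - g` as keys and the token projection `i` as values. A hedged usage sketch, assuming a CUDA device:

```python
import torch
from fla.layers.hgrn2 import HGRN2Attention

# num_heads is derived from expand_ratio in __init__: 1024 // 128 = 8 heads
layer = HGRN2Attention(hidden_size=1024, expand_ratio=128, layer_idx=0).cuda()
x = torch.randn(2, 128, 1024, device='cuda')
o, _, _ = layer(x)
assert o.shape == x.shape
```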
fla/layers/linear_attn.py ADDED
@@ -0,0 +1,166 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from fla.modules import RMSNorm
12
+ from fla.modules.feature_map import DPFPFeatureMap, HadamardFeatureMap, HedgehogFeatureMap, T2RFeatureMap
13
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn
14
+
15
+
16
+ class LinearAttention(nn.Module):
17
+
18
+ def __init__(
19
+ self,
20
+ mode: str = 'chunk',
21
+ hidden_size: int = 1024,
22
+ expand_k: float = 1.0,
23
+ expand_v: float = 1.0,
24
+ num_heads: int = 8,
25
+ num_kv_heads: Optional[int] = None,
26
+ feature_map: str = 'elementwise_product',
27
+ tie_feature_map_qk: bool = False,
28
+ output_norm: str = 'rmsnorm',
29
+ norm_q: bool = False,
30
+ norm_k: bool = False,
31
+ do_feature_map_norm: bool = False,
32
+ elementwise_affine: bool = True,
33
+ norm_eps: float = 1e-5,
34
+ **kwargs
35
+ ):
36
+ super().__init__()
37
+
38
+ self.hidden_size = hidden_size
39
+ self.mode = mode
40
+ self.num_heads = num_heads
41
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
42
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
43
+ self.key_dim = int(hidden_size * expand_k)
44
+ self.value_dim = int(hidden_size * expand_v)
45
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
46
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
47
+
48
+ assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
49
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
50
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
51
+
52
+ self.head_k_dim = self.key_dim // num_heads
53
+ self.head_v_dim = self.value_dim // num_heads
54
+ self.do_feature_map_norm = do_feature_map_norm
55
+
56
+ if feature_map == 'hedgehog':
57
+ if tie_feature_map_qk:
58
+ self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim)
59
+ else:
60
+ self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_k_dim)
61
+ self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim)
62
+
63
+ elif feature_map == 't2r':
64
+ if tie_feature_map_qk:
65
+ self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim)
66
+ else:
67
+ self.feature_map_q = T2RFeatureMap(head_dim=self.head_k_dim)
68
+ self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim)
69
+
70
+ elif feature_map == 'elementwise_product':
71
+ if tie_feature_map_qk:
72
+ self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim)
73
+ else:
74
+ self.feature_map_q = HadamardFeatureMap(head_dim=self.head_k_dim)
75
+ self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim)
76
+
77
+ elif feature_map == 'dpfp':
78
+ self.feature_map_q = DPFPFeatureMap(head_dim=self.head_k_dim)
79
+ self.feature_map_k = DPFPFeatureMap(head_dim=self.head_k_dim)
80
+
81
+ elif feature_map == 'elu':
82
+ def elu(x):
83
+ return F.elu(x) + 1
84
+ self.feature_map_q = elu
85
+ self.feature_map_k = elu
86
+
87
+ elif feature_map == 'relu':
88
+ self.feature_map_q = nn.ReLU()
89
+ self.feature_map_k = nn.ReLU()
90
+
91
+ elif feature_map == 'identity':
92
+ self.feature_map_q = nn.Identity()
93
+ self.feature_map_k = nn.Identity()
94
+ else:
95
+ raise NotImplementedError(f"Not supported feature map `{feature_map}`.")
96
+
97
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
98
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
99
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
100
+
101
+ if output_norm == 'rmsnorm':
102
+ self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
103
+ elif output_norm == 'identity':
104
+ self.norm = nn.Identity()
105
+ else:
106
+ raise NotImplementedError(f"Not supported output norm `{output_norm}`.")
107
+
108
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
109
+
110
+ self.norm_q = norm_q
111
+ self.norm_k = norm_k
112
+
113
+ def forward(
114
+ self,
115
+ hidden_states: torch.Tensor,
116
+ **kwargs
117
+ ) -> torch.Tensor:
118
+ mode = self.mode
119
+ q = self.q_proj(hidden_states)
120
+ k = self.k_proj(hidden_states)
121
+ v = self.v_proj(hidden_states)
122
+
123
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
124
+ if self.num_kv_groups > 1:
125
+ k = repeat(k, '... (h d) -> ... (h g) d', d=self.head_k_dim, g=self.num_kv_groups)
126
+ v = repeat(v, '... (h d) -> ... (h g) d', d=self.head_v_dim, g=self.num_kv_groups)
127
+ else:
128
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
129
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
130
+
131
+ q = self.feature_map_q(q)
132
+ k = self.feature_map_k(k)
133
+
134
+ if self.norm_q:
135
+ q = q / (q.sum(-1, True) + 1e-4)
136
+ if self.norm_k:
137
+ k = k / (k.sum(-1, True) + 1e-4)
138
+
139
+ if mode == 'chunk':
140
+ o, final_state = chunk_linear_attn(
141
+ q=q,
142
+ k=k,
143
+ v=v,
144
+ normalize=self.do_feature_map_norm,
145
+ head_first=False
146
+ )
147
+ elif mode == 'fused_chunk':
148
+ o, final_state = fused_chunk_linear_attn(
149
+ q=q,
150
+ k=k,
151
+ v=v,
152
+ normalize=self.do_feature_map_norm,
153
+ )
154
+ elif mode == 'fused_recurrent':
155
+ o, final_state = fused_recurrent_linear_attn(
156
+ q=q,
157
+ k=k,
158
+ v=v,
159
+ normalize=self.do_feature_map_norm,
160
+ )
161
+ else:
162
+ raise NotImplementedError
163
+ o = self.norm(o)
164
+ o = rearrange(o, '... h d -> ... (h d)')
165
+ o = self.o_proj(o)
166
+ return o
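
Unlike the recurrent layers above, `LinearAttention.forward` returns a single tensor and does not thread a cache. A hedged example (illustrative only) using the `elu` feature map, i.e. `elu(x) + 1`, the classic positive feature map:

```python
import torch
from fla.layers.linear_attn import LinearAttention

layer = LinearAttention(hidden_size=1024, num_heads=8, feature_map='elu').cuda()
x = torch.randn(2, 128, 1024, device='cuda')
o = layer(x)  # no cache and no attention weights are returned
assert o.shape == x.shape
```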
fla/layers/multiscale_retention.py ADDED
@@ -0,0 +1,298 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange, repeat
11
+ from transformers.activations import ACT2FN
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.modules.rotary import RotaryEmbedding
15
+ from fla.ops.retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ class MultiScaleRetention(nn.Module):
22
+ r"""
23
+ The layer implementaion for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf). # noqa
24
+
25
+ Args:
26
+ mode (str, Optional):
27
+ Which Retention kernel to use.
28
+ Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`.
29
+ Default: `chunk`.
30
+ hidden_size (int, Optional):
31
+ The hidden size of the input. Default: 1024.
32
+ expand_k (float, Optional):
33
+ The expansion ratio for the key dim. Default: 1.0.
34
+ expand_v (float, Optional):
35
+ The expansion ratio for the value dim. Default: 2.0.
36
+ num_heads (int, Optional):
37
+ The number of heads. Default: 8.
38
+ num_kv_heads (int, Optional):
39
+ The number of key/value heads, used for MQA. Default: None.
40
+ feature_map (str, Optional):
41
+ Feature map function applied to queries/keys. Default: None.
42
+ use_short_conv (bool, Optional):
43
+ Whether to use short convolutions. Default: `False`.
44
+ conv_size (int, Optional):
45
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
46
+ conv_bias (bool, Optional):
47
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
48
+ use_output_gate (bool, Optional):
49
+ Whether to use output gate. Default: `True`.
50
+ gate_fn (str, Optional):
51
+ The activation function for the output gate. Default: `swish`.
52
+ elementwise_affine (bool, Optional):
53
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
54
+ norm_eps (float, Optional):
55
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
56
+ fuse_norm (bool, Optional):
57
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
58
+ layer_idx (int, Optional):
59
+ The index of the layer. Default: None.
60
+ """
61
+
62
+ def __init__(
63
+ self,
64
+ mode: str = 'chunk',
65
+ hidden_size: int = 1024,
66
+ expand_k: float = 1.0,
67
+ expand_v: float = 2.0,
68
+ num_heads: int = 8,
69
+ num_kv_heads: Optional[int] = None,
70
+ feature_map: Optional[str] = None,
71
+ use_short_conv: bool = False,
72
+ conv_size: int = 4,
73
+ conv_bias: bool = False,
74
+ use_output_gate: bool = True,
75
+ gate_fn: str = 'swish',
76
+ elementwise_affine: Optional[bool] = True,
77
+ norm_eps: float = 1e-5,
78
+ fuse_norm: bool = True,
79
+ layer_idx: Optional[int] = None,
80
+ **kwargs
81
+ ) -> None:
82
+ super().__init__()
83
+
84
+ self.mode = mode
85
+ self.hidden_size = hidden_size
86
+ self.expand_k = expand_k
87
+ self.expand_v = expand_v
88
+ self.num_heads = num_heads
89
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
90
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
91
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
92
+
93
+ self.use_short_conv = use_short_conv
94
+ self.conv_size = conv_size
95
+ self.conv_bias = conv_bias
96
+ self.use_output_gate = use_output_gate
97
+
98
+ self.key_dim = int(hidden_size * expand_k)
99
+ self.value_dim = int(hidden_size * expand_v)
100
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
101
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
102
+ self.layer_idx = layer_idx
103
+
104
+ assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not supported mode `{mode}`."
105
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
106
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
107
+
108
+ self.head_k_dim = self.key_dim // num_heads
109
+ self.head_v_dim = self.value_dim // num_heads
110
+
111
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
112
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
113
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
114
+ if self.use_output_gate:
115
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
116
+
117
+ if use_short_conv:
118
+ self.conv_size = conv_size
119
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
120
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
121
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
122
+
123
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
124
+
125
+ if gate_fn == 'swish' and fuse_norm and use_output_gate:
126
+ self.g_norm_swish_gate = FusedRMSNormGated(
127
+ hidden_size=self.head_v_dim,
128
+ elementwise_affine=elementwise_affine,
129
+ eps=norm_eps
130
+ )
131
+ self.fuse_norm_and_gate = True
132
+ else:
133
+ self.fuse_norm_and_gate = False
134
+ self.g_norm = RMSNorm(
135
+ hidden_size=self.head_v_dim,
136
+ elementwise_affine=elementwise_affine,
137
+ eps=norm_eps
138
+ )
139
+ self.gate_fn = ACT2FN[gate_fn]
140
+
141
+ # TODO: fix this issue
142
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180
143
+ # Ideally, we would want to support arbitrary d_head_qk
144
+ assert self.head_k_dim <= 256, "head_k_dim must be less than or equal to 256"
145
+ self.rotary = RotaryEmbedding(dim=self.head_k_dim)
146
+
147
+ def forward(
148
+ self,
149
+ hidden_states: torch.Tensor,
150
+ attention_mask: Optional[torch.Tensor] = None,
151
+ past_key_values: Optional[Cache] = None,
152
+ use_cache: Optional[bool] = False,
153
+ output_attentions: Optional[bool] = False,
154
+ **kwargs
155
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
156
+ if attention_mask is not None:
157
+ assert len(attention_mask.shape) == 2, (
158
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
159
+ "for padding purposes (0 indicating padding). "
160
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
161
+ )
162
+
163
+ # launching the triton kernel for just one token will actually be slower
164
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
165
+
166
+ last_state = None
167
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
168
+ last_state = past_key_values[self.layer_idx]
169
+
170
+ cu_seqlens = kwargs.get('cu_seqlens', None)
171
+ if self.use_short_conv:
172
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
173
+ if last_state is not None:
174
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
175
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
176
+ q, conv_state_q = self.q_conv1d(
177
+ x=self.q_proj(hidden_states),
178
+ mask=conv_mask,
179
+ cache=conv_state_q,
180
+ output_final_state=use_cache,
181
+ cu_seqlens=cu_seqlens
182
+ )
183
+ k, conv_state_k = self.k_conv1d(
184
+ x=self.k_proj(hidden_states),
185
+ mask=conv_mask,
186
+ cache=conv_state_k,
187
+ output_final_state=use_cache,
188
+ cu_seqlens=cu_seqlens
189
+ )
190
+ v, conv_state_v = self.v_conv1d(
191
+ x=self.v_proj(hidden_states),
192
+ mask=conv_mask,
193
+ cache=conv_state_v,
194
+ output_final_state=use_cache,
195
+ cu_seqlens=cu_seqlens
196
+ )
197
+ else:
198
+ q = self.q_proj(hidden_states)
199
+ k = self.k_proj(hidden_states)
200
+ v = self.v_proj(hidden_states)
201
+
202
+ # dealing with left-padding
203
+ if attention_mask is not None:
204
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
205
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
206
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
207
+ if self.feature_map_fn is not None:
208
+ q, k = map(self.feature_map_fn, (q, k))
209
+
210
+ seqlen_offset, max_seqlen = 0, q.shape[1]
211
+ if past_key_values is not None:
212
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
213
+ max_seqlen = q.shape[1] + seqlen_offset
214
+
215
+ if attention_mask is not None:
216
+ # to deliminate the offsets of padding tokens
217
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
218
+ max_seqlen = q.shape[1] + max(seqlen_offset)
219
+
220
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
221
+
222
+ if self.num_kv_groups > 1:
223
+ k = repeat(k, 'b t h d -> b t (h g) d', g=self.num_kv_groups)
224
+ v = repeat(v, 'b t (h d) -> b t (h g) d', d=self.head_v_dim, g=self.num_kv_groups)
225
+ else:
226
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
227
+
228
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
229
+ if mode == 'chunk':
230
+ o, recurrent_state = chunk_retention(
231
+ q=q,
232
+ k=k,
233
+ v=v,
234
+ initial_state=recurrent_state,
235
+ output_final_state=use_cache,
236
+ cu_seqlens=cu_seqlens,
237
+ head_first=False
238
+ )
239
+ elif mode == 'fused_chunk':
240
+ o, recurrent_state = fused_chunk_retention(
241
+ q=q,
242
+ k=k,
243
+ v=v,
244
+ initial_state=recurrent_state,
245
+ output_final_state=use_cache,
246
+ cu_seqlens=cu_seqlens,
247
+ head_first=False
248
+ )
249
+ elif mode == 'parallel':
250
+ o, recurrent_state = parallel_retention(
251
+ q=q,
252
+ k=k,
253
+ v=v,
254
+ cu_seqlens=cu_seqlens,
255
+ head_first=False
256
+ )
257
+ elif mode == 'fused_recurrent':
258
+ o, recurrent_state = fused_recurrent_retention(
259
+ q=q,
260
+ k=k,
261
+ v=v,
262
+ initial_state=recurrent_state,
263
+ output_final_state=use_cache,
264
+ cu_seqlens=cu_seqlens,
265
+ head_first=False
266
+ )
267
+ else:
268
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
269
+
270
+ if past_key_values is not None:
271
+ past_key_values.update(
272
+ recurrent_state=recurrent_state,
273
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
274
+ layer_idx=self.layer_idx,
275
+ offset=q.shape[1]
276
+ )
277
+
278
+ if self.use_output_gate:
279
+ g = self.g_proj(hidden_states)
280
+ if self.fuse_norm_and_gate:
281
+ g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim)
282
+ o = self.g_norm_swish_gate(o, g)
283
+ o = rearrange(o, 'b t h d -> b t (h d)')
284
+ else:
285
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
286
+ o = o * self.gate_fn(g)
287
+ else:
288
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
289
+ o = self.o_proj(o)
290
+
291
+ return o, None, past_key_values
292
+
293
+ def state_size(self, **kwargs) -> int:
294
+ state_size = self.key_dim * self.head_v_dim
295
+ for module in self.children():
296
+ if isinstance(module, ShortConvolution):
297
+ state_size += module.state_size
298
+ return state_size
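
A hedged smoke test for `MultiScaleRetention` (illustrative, assuming a CUDA device): with the default `expand_v=2.0` the value heads are twice as wide as the key heads, and `o_proj` maps back to `hidden_size`.

```python
import torch
from fla.layers.multiscale_retention import MultiScaleRetention

layer = MultiScaleRetention(hidden_size=1024, num_heads=8, layer_idx=0).cuda()
x = torch.randn(2, 128, 1024, device='cuda')
o, _, _ = layer(x)
assert o.shape == x.shape  # value_dim is 2048 internally, projected back to 1024
```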
fla/layers/nsa.py ADDED
@@ -0,0 +1,138 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from transformers.utils import logging
12
+
13
+ from fla.modules import RotaryEmbedding
14
+ from fla.ops.nsa.parallel import parallel_nsa
15
+
16
+ if TYPE_CHECKING:
17
+ from fla.models.utils import Cache
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class NativeSparseAttention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ hidden_size: int = 2048,
27
+ num_heads: int = 64,
28
+ num_kv_heads: Optional[int] = 4,
29
+ head_dim: int = 64,
30
+ qkv_bias: bool = False,
31
+ block_size: Optional[int] = 64,
32
+ block_counts: Optional[Union[torch.LongTensor, int]] = 16,
33
+ window_size: Optional[int] = 512,
34
+ rope_theta: Optional[float] = 10000.,
35
+ max_position_embeddings: Optional[int] = None,
36
+ layer_idx: Optional[int] = None
37
+ ):
38
+ super().__init__()
39
+
40
+ self.hidden_size = hidden_size
41
+ self.num_heads = num_heads
42
+ if num_kv_heads is None:
43
+ self.num_kv_heads = self.num_heads
44
+ else:
45
+ self.num_kv_heads = num_kv_heads
46
+ self.num_kv_groups = num_heads // self.num_kv_heads
47
+ self.head_dim = head_dim
48
+ self.kv_dim = self.num_kv_heads * self.head_dim
49
+ self.qkv_bias = qkv_bias
50
+
51
+ self.block_size = block_size
52
+ self.block_counts = block_counts
53
+ self.window_size = window_size
54
+ self.rope_theta = rope_theta
55
+ self.max_position_embeddings = max_position_embeddings
56
+ self.layer_idx = layer_idx
57
+
58
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.qkv_bias)
59
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
60
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
61
+ self.g_proj = nn.Linear(self.hidden_size, self.num_heads * 3, bias=False)
62
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
63
+
64
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
65
+
66
+ def forward(
67
+ self,
68
+ hidden_states: torch.Tensor,
69
+ attention_mask: Optional[torch.LongTensor] = None,
70
+ past_key_values: Optional[Cache] = None,
71
+ output_attentions: bool = False,
72
+ use_cache: bool = False,
73
+ **kwargs,
74
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
75
+ if attention_mask is not None:
76
+ assert len(attention_mask.shape) == 2, (
77
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
78
+ "for padding purposes (0 indicating padding). "
79
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
80
+ )
81
+
82
+ batch_size, seq_len, _ = hidden_states.size()
83
+
84
+ q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
85
+ k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
86
+ v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
87
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=3)
88
+ g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
89
+
90
+ cu_seqlens = kwargs.get('cu_seqlens', None)
91
+
92
+ seqlen_offset, max_seqlen = 0, seq_len
93
+ if past_key_values is not None:
94
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
95
+ max_seqlen = q.shape[1] + seqlen_offset
96
+
97
+ if attention_mask is not None:
98
+ # to account for the offsets introduced by padding tokens
99
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
100
+ max_seqlen = q.shape[1] + max(seqlen_offset)
101
+
102
+ if self.max_position_embeddings is not None:
103
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
104
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
105
+
106
+ if past_key_values is not None:
107
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
108
+ k_cached, v_cached = past_key_values.update(
109
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
110
+ layer_idx=self.layer_idx,
111
+ offset=seq_len,
112
+ cache_kwargs=dict(window_size=self.window_size)
113
+ )['attn_state']
114
+ if cache_has_content:
115
+ k, v = k_cached, v_cached
116
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
117
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
118
+
119
+ o = parallel_nsa(
120
+ q=q,
121
+ k=k,
122
+ v=v,
123
+ g_cmp=g_cmp,
124
+ g_slc=g_slc,
125
+ g_swa=g_swa,
126
+ block_size=self.block_size,
127
+ block_counts=self.block_counts,
128
+ window_size=self.window_size,
129
+ cu_seqlens=cu_seqlens,
130
+ head_first=False
131
+ )
132
+ o = o.reshape(batch_size, seq_len, -1)
133
+ o = self.o_proj(o)
134
+
135
+ # attention weights are not materialized by the sparse kernels
137
+ attentions = None
137
+
138
+ return o, attentions, past_key_values
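
A hedged sketch for `NativeSparseAttention` (illustrative only): `g_proj` emits three sigmoid gates per head that weight the compressed (`g_cmp`), selected (`g_slc`), and sliding-window (`g_swa`) branches of `parallel_nsa`.

```python
import torch
from fla.layers.nsa import NativeSparseAttention

layer = NativeSparseAttention(hidden_size=2048, num_heads=64, num_kv_heads=4,
                              block_size=64, block_counts=16, window_size=512,
                              layer_idx=0).cuda()
x = torch.randn(1, 256, 2048, device='cuda')
o, attentions, _ = layer(x)
assert o.shape == x.shape and attentions is None
```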
fla/layers/rebased.py ADDED
@@ -0,0 +1,133 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from einops import rearrange
15
+
16
+ from fla.modules.feature_map import RebasedFeatureMap
17
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
18
+ from fla.ops.rebased import parallel_rebased
19
+
20
+
21
+ class ReBasedLinearAttention(nn.Module):
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size: int,
26
+ l_max: int = 2048,
27
+ feature_dim: int = 16,
28
+ num_key_value_heads: int = 16,
29
+ num_heads: int = 16,
30
+ use_gamma: Optional[bool] = True,
31
+ use_beta: Optional[bool] = True,
32
+ normalize: Optional[bool] = True,
33
+ causal: bool = True,
34
+ eps: float = 1e-5,
35
+ mode: str = "parallel",
36
+ layer_idx: Optional[int] = None,
37
+ **kwargs
38
+ ) -> None:
39
+ super().__init__()
40
+ self.hidden_size = hidden_size
41
+ self.l_max = l_max
42
+ self.mode = mode
43
+ assert self.mode in ["fused_chunk", "parallel", 'chunk']
44
+
45
+ self.feature_dim = feature_dim
46
+ self.num_key_value_heads = num_key_value_heads
47
+ self.num_heads = num_heads
48
+ self.head_dim = self.hidden_size // self.num_key_value_heads
49
+ self.use_gamma = use_gamma
50
+ self.use_beta = use_beta
51
+ self.normalize = normalize
52
+ self.causal = causal
53
+ self.eps = eps
54
+ self.mode = mode
55
+ self.layer_idx = layer_idx
56
+
57
+ self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
58
+ self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
59
+ self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
60
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
61
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
62
+ self.dropout = nn.Identity()
63
+
64
+ def forward(self, hidden_states: torch.Tensor, **kwargs):
65
+ mode = self.mode
66
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
67
+ q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v])
68
+ q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
69
+ if mode == "fused_chunk":
70
+ o, _ = fused_chunk_linear_attn(
71
+ q=q,
72
+ k=k,
73
+ v=v,
74
+ normalize=True,
75
+ scale=1,
76
+ head_first=False
77
+ )
78
+ elif mode == 'chunk':
79
+ o, _ = chunk_linear_attn(
80
+ q=q,
81
+ k=k,
82
+ v=v,
83
+ normalize=True,
84
+ scale=1,
85
+ head_first=False
86
+ )
87
+ elif mode == 'parallel':
88
+ assert q.shape[-1] <= 128
89
+ o = parallel_rebased(
90
+ q=q,
91
+ k=k,
92
+ v=v,
93
+ eps=self.eps,
94
+ use_scale=True,
95
+ use_normalize=True,
96
+ head_first=False
97
+ )
98
+ o = self.o_proj(o)
99
+ o = self.dropout(o)
100
+ return o
101
+
102
+ # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
103
+ def forward_reference(
104
+ self,
105
+ hidden_states: torch.Tensor,
106
+ filters: torch.Tensor = None,
107
+ *args,
108
+ **kwargs
109
+ ):
110
+ """
111
+ hidden_states (torch.Tensor): tensor of shape (b, t, d)
112
+ Returns: tensor of shape (b, t, d)
113
+ """
114
+ b, t, _ = hidden_states.size()
115
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
116
+
117
+ q = q.view(b, t, -1, self.feature_dim).transpose(1, 2)
118
+ k = k.view(b, t, -1, self.feature_dim).transpose(1, 2)
119
+ v = v.view(b, t, -1, self.head_dim).transpose(1, 2)
120
+
121
+ # Linear attention
122
+ q, k = self.feature_map(q), self.feature_map(k)
123
+ q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
124
+
125
+ # Compute attention
126
+ if self.causal:
127
+ y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
128
+ else:
129
+ y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
130
+ y = rearrange(y, 'b h t d -> b t (h d)')
131
+ y = self.o_proj(y.to(hidden_states.dtype))
132
+ y = self.dropout(y)
133
+ return y.to(hidden_states.dtype)
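
`forward_reference` gives a quadratic-time, pure-PyTorch baseline that is easiest to inspect and should run without Triton kernels (a heavily hedged sketch; argument values are illustrative):

```python
import torch
from fla.layers.rebased import ReBasedLinearAttention

layer = ReBasedLinearAttention(hidden_size=1024, num_heads=16, num_key_value_heads=16)
x = torch.randn(1, 64, 1024)
y = layer.forward_reference(x)  # quadratic-time baseline for debugging
print(y.shape)  # expected: torch.Size([1, 64, 1024])
```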
fla/layers/rwkv6.py ADDED
@@ -0,0 +1,307 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ # "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence" [https://arxiv.org/abs/2404.05892]
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+
+ from fla.modules import GroupNorm
+ from fla.modules.activations import ACT2FN
+ from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
+
+ if TYPE_CHECKING:
+     from fla.models.utils import Cache
+
+
+ class RWKV6Attention(nn.Module):
+
+     def __init__(
+         self,
+         mode: str = 'chunk',
+         hidden_size: int = 1024,
+         expand_k: float = 0.5,
+         expand_v: float = 1.0,
+         num_heads: int = 4,
+         gate_fn: str = 'swish',
+         proj_low_rank_dim: int = 32,
+         gate_low_rank_dim: int = 64,
+         fuse_norm: bool = True,
+         elementwise_affine: Optional[bool] = True,
+         norm_eps: float = 1e-5,
+         layer_idx: Optional[int] = None,
+         **kwargs
+     ) -> RWKV6Attention:
+         super().__init__()
+
+         self.mode = mode
+         self.hidden_size = hidden_size
+         self.expand_k = expand_k
+         self.expand_v = expand_v
+         self.num_heads = num_heads
+         self.proj_low_rank_dim = proj_low_rank_dim
+         self.gate_low_rank_dim = gate_low_rank_dim
+
+         self.key_dim = int(hidden_size * expand_k)
+         self.value_dim = int(hidden_size * expand_v)
+         self.layer_idx = layer_idx
+
+         assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
+         assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
+         assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
+
+         self.head_k_dim = self.key_dim // num_heads
+         self.head_v_dim = self.value_dim // num_heads
+
+         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+         self.x_proj = nn.Sequential(
+             LerpLinear(hidden_size, proj_low_rank_dim * 5),
+             nn.Tanh(),
+             nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=False)
+         )
+         self.x_bias = nn.Parameter(torch.zeros(5, hidden_size))
+
+         self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
+         self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
+         self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
+         self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
+         self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
+         self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_k_dim))
+
+         # TODO: fuse GroupNorm and output gate
+         self.g_norm = GroupNorm(self.num_heads, self.value_dim, elementwise_affine=elementwise_affine, bias=True, eps=norm_eps)
+         self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
+         self.gate_fn = ACT2FN[gate_fn]
+
+         self.apply(self._initialize_weights)
+
+     def _initialize_weights(self, module: nn.Module):
+         if getattr(module, "_is_hf_initialized", False):
+             return
+         if isinstance(module, nn.Linear):
+             nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         if isinstance(module, nn.Parameter):
+             nn.init.xavier_uniform_(module, gain=2 ** -2.5)
+         module._is_hf_initialized = True
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[Cache] = None,
+         use_cache: Optional[bool] = False,
+         output_attentions: Optional[bool] = False,
+         **kwargs
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+         if attention_mask is not None:
+             assert len(attention_mask.shape) == 2, (
+                 "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
+                 "for padding purposes (0 indicating padding). "
+                 "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
+             )
+
+         batch_size, seq_len, hidden_size = hidden_states.shape
+         # launching the triton kernel for just one token will actually be slower
+         mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
+
+         last_state = None
+         if past_key_values is not None and len(past_key_values) > self.layer_idx:
+             last_state = past_key_values[self.layer_idx]
+
+         if attention_mask is not None:
+             hidden_states = hidden_states.mul_(attention_mask[:, -hidden_states.shape[-2]:, None])
+         if hidden_states.shape[1] == 1 and last_state is not None:
+             shifted = last_state['conv_state'].unsqueeze(1)
+         else:
+             shifted = self.time_shift(hidden_states)
+             if last_state is not None:
+                 shifted[:, 0] = last_state['conv_state']
+
+         delta = shifted - hidden_states
+         x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim)
+         x = torch.einsum('b t n r, h n r -> b t n h', self.x_proj[1](x), self.x_proj[2].weight.view(hidden_size, 5, -1))
+
+         r, w, k, v, g = x.add_(self.x_bias).unbind(-2)
+         r = self.r_proj(hidden_states, r, delta)
+         w = self.w_proj(hidden_states, w, delta)
+         k = self.k_proj(hidden_states, k, delta)
+         v = self.v_proj(hidden_states, v, delta)
+         g = self.g_proj(hidden_states, g, delta)
+
+         # dealing with left-padding
+         if attention_mask is not None:
+             v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
+         r, w, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (r, w, k))
+         v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
+         w = -torch.exp(w)
+         u = self.bonus
+
+         recurrent_state = last_state['recurrent_state'] if last_state is not None else None
+         cu_seqlens = kwargs.get('cu_seqlens', None)
+         if mode == 'fused_recurrent':
+             o, recurrent_state = fused_recurrent_rwkv6(
+                 r=r,
+                 k=k,
+                 v=v,
+                 w=w,
+                 u=u,
+                 scale=1.,
+                 initial_state=recurrent_state,
+                 output_final_state=use_cache,
+                 cu_seqlens=cu_seqlens,
+                 head_first=False
+             )
+         elif mode == 'chunk':
+             o, recurrent_state = chunk_rwkv6(
+                 q=r,
+                 k=k,
+                 v=v,
+                 g=w,
+                 u=u,
+                 scale=1.,
+                 initial_state=recurrent_state,
+                 output_final_state=use_cache,
+                 cu_seqlens=cu_seqlens,
+                 head_first=False
+             )
+         else:
+             raise NotImplementedError(f"Not supported mode `{mode}`.")
+
+         if past_key_values is not None:
+             past_key_values.update(
+                 recurrent_state=recurrent_state,
+                 conv_state=hidden_states[:, -1],
+                 layer_idx=self.layer_idx,
+                 offset=r.shape[1]
+             )
+
+         o = self.g_norm(rearrange(o, '... h d -> ... (h d)')) * self.gate_fn(g)
+         o = self.o_proj(o)
+
+         return o, None, past_key_values
+
+
+ class LoRA(nn.Module):
+
+     def __init__(
+         self,
+         input_dim: int,
+         output_dim: int,
+         low_rank_dim: int,
+         bias: Optional[bool] = True,
+         activation: Optional[str] = 'tanh'
+     ):
+         super().__init__()
+
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.low_rank_dim = low_rank_dim
+         self.bias = bias
+
+         if activation is None:
+             self.activation = nn.Identity()
+         elif activation == 'sigmoid':
+             self.activation = nn.Sigmoid()
+         elif activation == 'tanh':
+             self.activation = nn.Tanh()
+         elif activation == 'relu':
+             self.activation = nn.ReLU()
+         else:
+             raise ValueError(f"Not supported activation `{activation}`.")
+
+         self.lora = nn.Sequential(
+             nn.Linear(input_dim, low_rank_dim, bias=False),
+             self.activation,
+             nn.Linear(low_rank_dim, output_dim, bias=bias)
+         )
+
+     def __repr__(self) -> str:
+         s = f"{self.__class__.__name__}("
+         s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}"
+         if not self.bias:
+             s += f", bias={self.bias}"
+         s += ")"
+         return s
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.lora(x)
+
+
+ class LerpLinear(nn.Module):
+
+     def __init__(
+         self,
+         input_dim: int,
+         output_dim: int,
+         low_rank_dim: Optional[int] = None
+     ):
+         super().__init__()
+
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.low_rank_dim = low_rank_dim
+
+         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+         if low_rank_dim is None:
+             self.linear = nn.Linear(input_dim, output_dim, bias=False)
+         else:
+             self.linear = LoRA(input_dim, output_dim, low_rank_dim)
+         self.mu = nn.Parameter(torch.zeros(input_dim))
+
+     def __repr__(self) -> str:
+         s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
+         if self.low_rank_dim is not None:
+             s += f", low_rank_dim={self.low_rank_dim}"
+         s += ")"
+         return s
+
+     def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if delta is None:
+             shifted = self.time_shift(x)
+             if len(shifted.shape) == 2:
+                 shifted = shifted.unsqueeze(1)
+             delta = shifted - x
+         return self.linear(x + delta * self.mu)
+
+
+ class DDLerpLinear(nn.Module):
+
+     def __init__(
+         self,
+         input_dim: int,
+         output_dim: int,
+         low_rank_dim: Optional[int] = None
+     ):
+         super().__init__()
+
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.low_rank_dim = low_rank_dim
+
+         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+         if low_rank_dim is None:
+             self.linear = nn.Linear(input_dim, output_dim, bias=False)
+         else:
+             self.linear = LoRA(input_dim, output_dim, low_rank_dim)
+
+     def __repr__(self) -> str:
+         s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
+         if self.low_rank_dim is not None:
+             s += f", low_rank_dim={self.low_rank_dim}"
+         s += ")"
+         return s
+
+     def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if delta is None:
+             shifted = self.time_shift(x)
+             if len(shifted.shape) == 2:
+                 shifted = shifted.unsqueeze(1)
+             delta = shifted - x
+         return self.linear(x + delta * mu)
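`LerpLinear` and `DDLerpLinear` implement RWKV's token-shift trick: each projection sees a per-channel interpolation between the current token and the previous one. A self-contained sketch of the same computation in plain PyTorch (hypothetical sizes):

import torch
import torch.nn as nn

x = torch.randn(2, 8, 16)                 # [batch, seq, hidden]
time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # shift tokens right by one step
mu = torch.rand(16)                       # per-channel mixing factor

shifted = time_shift(x)                   # row t now holds token t-1 (row 0 is zeros)
delta = shifted - x
blended = x + delta * mu                  # per-channel lerp between x_t and x_{t-1}
assert blended.shape == x.shape

In `DDLerpLinear` the mixing factor `mu` is itself data-dependent (produced by the low-rank `x_proj` path), which is the "dynamic recurrence" of the Finch paper.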
fla/layers/rwkv7.py ADDED
@@ -0,0 +1,221 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from torch.nn import functional as F
+
+ from fla.layers.rwkv6 import LoRA
+ from fla.modules import GroupNorm
+ from fla.modules.l2norm import l2_norm
+ from fla.ops.rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7
+
+ if TYPE_CHECKING:
+     from fla.models.utils import Cache
+
+
+ class RWKV7Attention(nn.Module):
+
+     def __init__(
+         self,
+         mode: str = 'chunk',
+         hidden_size: int = 1024,
+         head_dim: Optional[int] = 64,
+         num_heads: Optional[int] = None,
+         decay_low_rank_dim: int = 64,
+         gate_low_rank_dim: int = 128,
+         a_low_rank_dim: int = 64,
+         v_low_rank_dim: int = 16,
+         elementwise_affine: Optional[bool] = True,
+         norm_eps: float = 1e-5,
+         layer_idx: Optional[int] = None,
+         fuse_norm: bool = False,
+         value_dim: Optional[int] = None,
+         **kwargs
+     ) -> RWKV7Attention:
+         super().__init__()
+
+         self.mode = mode
+         assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
+         self.hidden_size = hidden_size
+
+         self.key_dim = hidden_size
+         self.value_dim = value_dim if value_dim is not None else hidden_size
+         if head_dim is None and num_heads is None:
+             raise ValueError("Either `head_dim` or `num_heads` must be specified.")
+         elif head_dim is not None:
+             self.head_dim = head_dim
+             self.num_heads = int(hidden_size // head_dim)
+         elif num_heads is not None:
+             self.head_dim = int(hidden_size // num_heads)
+             self.num_heads = num_heads
+         self.head_v_dim = int(self.value_dim // self.num_heads)
+
+         self.decay_low_rank_dim = decay_low_rank_dim
+         self.gate_low_rank_dim = gate_low_rank_dim
+         self.a_low_rank_dim = a_low_rank_dim
+         self.v_low_rank_dim = v_low_rank_dim
+         self.layer_idx = layer_idx
+         self.fuse_norm = fuse_norm
+
+         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+         self.x_x = nn.Parameter(torch.zeros(6, hidden_size))
+
+         self.k_k = nn.Parameter(torch.zeros(self.key_dim))
+         self.k_a = nn.Parameter(torch.zeros(self.key_dim))
+         self.r_k = nn.Parameter(torch.zeros(self.num_heads, self.head_dim))
+
+         self.r_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
+         self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
+         self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
+         self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
+
+         self.w_lora = LoRA(hidden_size, self.key_dim, low_rank_dim=decay_low_rank_dim, activation='tanh')
+         if self.layer_idx != 0:
+             self.v_lora = LoRA(hidden_size, self.value_dim, low_rank_dim=v_low_rank_dim, activation=None)
+         self.a_lora = LoRA(hidden_size, self.key_dim, low_rank_dim=a_low_rank_dim, activation=None)
+         self.g_lora = LoRA(hidden_size, self.value_dim, low_rank_dim=gate_low_rank_dim, activation='sigmoid', bias=False)
+
+         if self.fuse_norm:
+             self.g_norm = GroupNorm(
+                 num_groups=self.num_heads,
+                 hidden_size=self.value_dim,
+                 elementwise_affine=elementwise_affine,
+                 eps=self.head_dim * norm_eps,
+                 bias=True,
+             )
+         else:
+             self.g_norm = nn.GroupNorm(
+                 num_groups=self.num_heads,
+                 num_channels=self.value_dim,
+                 eps=self.head_dim * norm_eps,
+                 affine=elementwise_affine
+             )
+
+         self.apply(self._initialize_weights)
+
+     def _initialize_weights(self, module: nn.Module):
+         if getattr(module, "_is_hf_initialized", False):
+             return
+         if isinstance(module, nn.Linear):
+             nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         if isinstance(module, nn.Parameter):
+             nn.init.xavier_uniform_(module, gain=2 ** -2.5)
+         module._is_hf_initialized = True
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[Cache] = None,
+         use_cache: Optional[bool] = False,
+         output_attentions: Optional[bool] = False,
+         v_first: Optional[torch.Tensor] = None,
+         **kwargs
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+         if attention_mask is not None:
+             assert len(attention_mask.shape) == 2, (
+                 "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
+                 "for padding purposes (0 indicating padding). "
+                 "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
+             )
+
+         batch_size, seq_len, _ = hidden_states.shape
+
+         if self.training:
+             # if training, use chunk mode no matter how short the sequence is
+             mode = 'chunk'
+         else:
+             # launching the triton kernel for just one token will actually be slower
+             mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
+
+         last_state = None
+         if past_key_values is not None and len(past_key_values) > self.layer_idx:
+             last_state = past_key_values[self.layer_idx]
+
+         if attention_mask is not None:
+             hidden_states = hidden_states.mul(attention_mask[:, -hidden_states.shape[-2]:, None])
+         if hidden_states.shape[1] == 1 and last_state is not None:
+             shifted = last_state['conv_state'].unsqueeze(1)
+         else:
+             shifted = self.time_shift(hidden_states)
+             if last_state is not None:
+                 shifted[:, 0] = last_state['conv_state']
+
+         # [batch_size, seq_len, hidden_size]
+         delta = shifted - hidden_states
+         xr, xw, xk, xv, xa, xg = hidden_states.addcmul(delta, self.x_x.view(6, 1, 1, -1)).unbind(0)
+
+         r = self.r_proj(xr)
+         # -math.exp(-0.5) = -0.6065306597126334
+         # the .to(torch.float) may be unnecessary here, since the LoRA runs in bfloat16
+         # and applying sigmoid to a bf16 input has no numerical issue
+         # FIXME: check if we can remove .to(torch.float)
+         w = -0.6065306597126334 * self.w_lora(xw).to(torch.float).sigmoid()
+
+         k = self.k_proj(xk)
+         v = self.v_proj(xv)
+
+         if self.layer_idx == 0:
+             v_first = v
+         else:
+             v = torch.lerp(v, v_first, self.v_lora(xv).sigmoid())
+         a = self.a_lora(xa).sigmoid()
+         g = self.g_lora(xg)
+
+         if self.fuse_norm:
+             kk = l2_norm(rearrange(k * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim))
+         else:
+             kk = F.normalize(rearrange(k * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim), dim=-1, p=2.0)
+
+         k = k.addcmul(k * (a - 1), self.k_a)
+
+         # dealing with left-padding
+         if attention_mask is not None:
+             v = v * attention_mask[:, -v.shape[-2]:, None]
+         r, w, k, a = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_dim), (r, w, k, a))
+         v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
+
+         recurrent_state = last_state['recurrent_state'] if last_state is not None else None
+
+         rwkv7_fn = chunk_rwkv7 if mode == 'chunk' else fused_recurrent_rwkv7
+         cu_seqlens = kwargs.get('cu_seqlens', None)
+         o, recurrent_state = rwkv7_fn(
+             r=r,
+             w=w,
+             k=k,
+             v=v,
+             a=-kk,
+             b=kk * a,
+             scale=1.,
+             initial_state=recurrent_state,
+             output_final_state=use_cache,
+             cu_seqlens=cu_seqlens,
+             head_first=False
+         )
+
+         if past_key_values is not None:
+             past_key_values.update(
+                 recurrent_state=recurrent_state,
+                 conv_state=hidden_states[:, -1],
+                 layer_idx=self.layer_idx,
+                 offset=r.shape[1]
+             )
+
+         if self.fuse_norm:
+             o = self.g_norm(rearrange(o, '... h d -> ... (h d)'))
+         else:
+             o = self.g_norm(rearrange(o, 'b t h d -> (b t) (h d)')).view(batch_size, seq_len, -1)
+
+         o = o + ((r * k * self.r_k).sum(-1, keepdim=True) * v).view(batch_size, seq_len, -1)
+         o = self.o_proj(o * g)
+
+         return o, None, past_key_values, v_first
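Note that `forward` returns four values here: RWKV-7 sets `v_first` in layer 0 and every later layer lerps its own values toward it, so the caller has to thread it through the stack. A sketch of the expected calling pattern (hypothetical `layers` container, not part of this diff):

# Hypothetical block stack illustrating how `v_first` is threaded between layers.
v_first = None
for layer in layers:  # each element wraps an RWKV7Attention instance as `layer.attn`
    hidden_states, _, past_key_values, v_first = layer.attn(
        hidden_states,
        past_key_values=past_key_values,
        v_first=v_first,  # layer 0 sets it; later layers lerp toward it
    )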
fla/models/__init__.py ADDED
@@ -0,0 +1,55 @@
+ # -*- coding: utf-8 -*-
+
+ from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel
+ from fla.models.bitnet import BitNetConfig, BitNetForCausalLM, BitNetModel
+ from fla.models.delta_net import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel
+ from fla.models.forgetting_transformer import (
+     ForgettingTransformerConfig,
+     ForgettingTransformerForCausalLM,
+     ForgettingTransformerModel
+ )
+ from fla.models.gated_deltanet import GatedDeltaNetConfig, GatedDeltaNetForCausalLM, GatedDeltaNetModel
+ from fla.models.gated_deltaproduct import GatedDeltaProductConfig, GatedDeltaProductForCausalLM, GatedDeltaProductModel
+ from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel
+ from fla.models.gsa import GSAConfig, GSAForCausalLM, GSAModel
+ from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel
+ from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model
+ from fla.models.lightnet import LightNetConfig, LightNetForCausalLM, LightNetModel
+ from fla.models.linear_attn import LinearAttentionConfig, LinearAttentionForCausalLM, LinearAttentionModel
+ from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel
+ from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model
+ from fla.models.nsa import NSAConfig, NSAForCausalLM, NSAModel
+ from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel
+ from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model
+ from fla.models.rwkv7 import RWKV7Config, RWKV7ForCausalLM, RWKV7Model
+ from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel
+ from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel
+ from fla.models.transformer_top import TOPTransformerConfig, TOPTransformerForCausalLM, TOPTransformerModel
+ from fla.models.transformer_mtp import MTPTransformerConfig, MTPTransformerForCausalLM, MTPTransformerModel
+ from fla.models.transformer_dsmtp import DSMTPTransformerConfig, DSMTPTransformerForCausalLM, DSMTPTransformerModel
+
+ __all__ = [
+     'ABCConfig', 'ABCForCausalLM', 'ABCModel',
+     'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel',
+     'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel',
+     'ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel',
+     'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel',
+     'GLAConfig', 'GLAForCausalLM', 'GLAModel',
+     'GSAConfig', 'GSAForCausalLM', 'GSAModel',
+     'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel',
+     'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model',
+     'LightNetConfig', 'LightNetForCausalLM', 'LightNetModel',
+     'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel',
+     'MambaConfig', 'MambaForCausalLM', 'MambaModel',
+     'Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model',
+     'NSAConfig', 'NSAForCausalLM', 'NSAModel',
+     'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel',
+     'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model',
+     'RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model',
+     'SambaConfig', 'SambaForCausalLM', 'SambaModel',
+     'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel',
+     'TOPTransformerConfig', 'TOPTransformerForCausalLM', 'TOPTransformerModel',
+     'MTPTransformerConfig', 'MTPTransformerForCausalLM', 'MTPTransformerModel',
+     'DSMTPTransformerConfig', 'DSMTPTransformerForCausalLM', 'DSMTPTransformerModel',
+     'GatedDeltaProductConfig', 'GatedDeltaProductForCausalLM', 'GatedDeltaProductModel',
+ ]
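Every entry above follows the standard `transformers` config/model pairing, so any exported pair can be instantiated the same way. A small sketch (hypothetical sizes):

from fla.models import RWKV7Config, RWKV7ForCausalLM

config = RWKV7Config(hidden_size=512, num_hidden_layers=4)  # hypothetical small config
model = RWKV7ForCausalLM(config)
print(sum(p.numel() for p in model.parameters()))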
fla/models/bitnet/__pycache__/configuration_bitnet.cpython-312.pyc ADDED
Binary file (2.4 kB). View file
 
fla/models/bitnet/modeling_bitnet.py ADDED
@@ -0,0 +1,441 @@
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ import math
+ import warnings
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.utils.checkpoint
+ from transformers.generation import GenerationMixin
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import logging
+ from transformers.utils.deprecation import deprecate_kwarg
+
+ from fla.layers.bitattn import BitAttention
+ from fla.models.bitnet.configuration_bitnet import BitNetConfig
+ from fla.models.utils import Cache
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
+ from fla.modules.activations import swiglu
+ from fla.modules.fused_bitlinear import FusedBitLinear
+
+ if TYPE_CHECKING:
+     from transformers.processing_utils import Unpack
+
+ logger = logging.get_logger(__name__)
+
+
+ class BitNetMLP(nn.Module):
+
+     def __init__(
+         self,
+         hidden_size: int,
+         hidden_ratio: Optional[int] = None,
+         intermediate_size: Optional[int] = None,
+         hidden_act: str = 'swish',
+         fuse_swiglu: bool = True
+     ) -> BitNetMLP:
+         super().__init__()
+
+         self.hidden_size = hidden_size
+         # the final number of params is `hidden_ratio * hidden_size^2`
+         # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
+         if hidden_ratio is None:
+             hidden_ratio = 4
+         if intermediate_size is None:
+             intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
+             intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.fuse_swiglu = fuse_swiglu
+
+         if hidden_act != 'swish':
+             raise ValueError(f'Unsupported hidden_act: {hidden_act}')
+
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         **kwargs: Unpack[Any]
+     ) -> torch.Tensor:
+         gate, y = self.gate_proj(x), self.up_proj(x)
+         return self.down_proj(swiglu(gate, y))
+
+
+ class BitNetBlock(nn.Module):
+
+     def __init__(self, config: BitNetConfig, layer_idx: int):
+         super().__init__()
+
+         self.config = config
+         self.layer_idx = layer_idx
+
+         self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+         self.attn = BitAttention(
+             hidden_size=config.hidden_size,
+             num_heads=config.num_heads,
+             num_kv_heads=config.num_kv_heads,
+             window_size=config.window_size,
+             rope_theta=config.rope_theta,
+             max_position_embeddings=config.max_position_embeddings,
+             layer_idx=layer_idx
+         )
+         self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+         self.mlp = BitNetMLP(
+             hidden_size=config.hidden_size,
+             hidden_ratio=config.hidden_ratio,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             fuse_swiglu=config.fuse_swiglu
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         **kwargs: Unpack[Any]
+     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+
+         residual = hidden_states
+         hidden_states = self.attn_norm(hidden_states)
+         hidden_states, attentions, past_key_values = self.attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             **kwargs
+         )
+         if self.config.fuse_norm:
+             hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
+         else:
+             hidden_states = residual + hidden_states
+             residual = hidden_states
+             hidden_states = self.mlp_norm(hidden_states)
+         hidden_states = self.mlp(hidden_states, **kwargs)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (attentions,)
+
+         if use_cache:
+             outputs += (past_key_values,)
+
+         return outputs
+
+
+ class BitNetPreTrainedModel(PreTrainedModel):
+
+     config_class = BitNetConfig
+     base_model_prefix = 'model'
+     supports_gradient_checkpointing = True
+     _no_split_modules = ['BitNetBlock']
+     _supports_cache_class = True
+
+     def __init__(self, *inputs, **kwargs):
+         super().__init__(*inputs, **kwargs)
+
+     def _init_weights(
+         self,
+         module: nn.Module,
+         rescale_prenorm_residual: bool = False,
+         num_residuals_per_layer: int = 2,
+     ):
+         if isinstance(module, (nn.Linear, nn.Conv1d, FusedBitLinear)):
+             # Slightly different from the TF version which uses truncated_normal for initialization
+             # cf https://github.com/pytorch/pytorch/pull/5617
+             nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+         elif hasattr(module, 'reset_parameters'):
+             module.reset_parameters()
+
+         if rescale_prenorm_residual:
+             # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+             # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+             # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+             # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+             #
+             # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+             p = None
+             if hasattr(module, 'o_proj'):
+                 p = module.o_proj.weight
+             elif hasattr(module, 'down_proj'):
+                 p = module.down_proj.weight
+             if p is not None:
+                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                 # Following PyTorch init, except scaled by 1/sqrt(2 * n_layer)
+                 # We need to reinit p since this code could be called multiple times
+                 # Having just p *= scale would repeatedly scale it down
+                 nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                 with torch.no_grad():
+                     p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
+
+
+ class BitNetModel(BitNetPreTrainedModel):
+
+     def __init__(
+         self,
+         config: BitNetConfig
+     ) -> BitNetModel:
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         self.layers = nn.ModuleList([BitNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+         self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+
+         self.gradient_checkpointing = False
+
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embeddings
+
+     def set_input_embeddings(self, value):
+         self.embeddings = value
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs: Unpack[Any]
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         if output_attentions:
+             warnings.warn(
+                 "`BitNetModel` does not support outputting attention weights yet, so `output_attentions` is forced to `False`."
+             )
+             output_attentions = False
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is None and inputs_embeds is None:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         if use_cache and not isinstance(past_key_values, Cache):
+             past_key_values = Cache.from_legacy_cache(past_key_values)
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embeddings(input_ids)
+
+         # embed positions
+         hidden_states = inputs_embeds
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         all_hidden_states = () if output_hidden_states else None
+         all_attns = () if output_attentions else None
+         next_cache = None
+
+         for layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     layer.__call__,
+                     hidden_states,
+                     attention_mask,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                     **kwargs
+                 )
+             else:
+                 layer_outputs = layer(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     past_key_values=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     **kwargs
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_cache = layer_outputs[2 if output_attentions else 1]
+
+             if output_attentions:
+                 all_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_attns
+         )
+
+
+ class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
+
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = BitNetModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.criterion = None
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embeddings
+
+     def set_input_embeddings(self, value):
+         self.model.embeddings = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.LongTensor = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         use_cache: bool = True,
+         logits_to_keep: Optional[int] = None,
+         **kwargs
+     ):
+         # only keep the last token of `input_ids` if `past_key_values` is not empty
+         if past_key_values is not None and len(past_key_values) > 0:
+             input_ids = input_ids[:, -1:]
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and len(past_key_values) == 0:
+             model_inputs = {'inputs_embeds': inputs_embeds}
+         else:
+             # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+             # recompiles graphs as the stride of the inputs is a guard.
+             # Ref: https://github.com/huggingface/transformers/pull/29114
+             # TODO: use `next_tokens` directly instead.
+             model_inputs = {'input_ids': input_ids.contiguous()}
+
+         if logits_to_keep is not None:
+             model_inputs['logits_to_keep'] = logits_to_keep
+
+         model_inputs.update({
+             'past_key_values': past_key_values,
+             'use_cache': use_cache,
+             'attention_mask': attention_mask,
+         })
+         return model_inputs
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         logits_to_keep: Optional[int] = 0,
+         **kwargs: Unpack[Any]
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             **kwargs
+         )
+
+         hidden_states = outputs[0]
+         fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
+
+         loss, logits = None, None
+         if not fuse_linear_and_cross_entropy or labels is None:
+             logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
+         if labels is not None:
+             if getattr(self, 'criterion', None) is None:
+                 if fuse_linear_and_cross_entropy:
+                     criterion = FusedLinearCrossEntropyLoss()
+                 elif self.config.fuse_cross_entropy:
+                     criterion = FusedCrossEntropyLoss(inplace_backward=True)
+                 else:
+                     criterion = nn.CrossEntropyLoss()
+             else:
+                 criterion = self.criterion
+             labels = labels.to(hidden_states.device)
+             labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+             if fuse_linear_and_cross_entropy:
+                 loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+             else:
+                 loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
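The loss path above shifts the labels left by one and pads the tail with `ignore_index`, so position t is supervised by token t+1; with `fuse_cross_entropy` enabled during training the full logits are never materialized. The shift in isolation (plain PyTorch, hypothetical values; -100 is the default `ignore_index` of `nn.CrossEntropyLoss`):

import torch

ignore_index = -100
labels = torch.tensor([[5, 7, 9, 2]])
# shift left and pad: position t is now supervised by token t+1
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), 1)
print(shifted)  # tensor([[   7,    9,    2, -100]])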
fla/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc ADDED
Binary file (3.62 kB). View file
 
fla/models/gated_deltanet/__pycache__/configuration_gated_deltanet.cpython-312.pyc ADDED
Binary file (3.37 kB). View file
 
fla/models/gated_deltanet/configuration_gated_deltanet.py ADDED
@@ -0,0 +1,83 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Dict, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class GatedDeltaNetConfig(PretrainedConfig):
+     model_type = 'gated_deltanet'
+     keys_to_ignore_at_inference = ['past_key_values']
+
+     def __init__(
+         self,
+         attn_mode: str = "chunk",
+         hidden_size: int = 2048,
+         expand_v: int = 2,
+         use_gate: bool = True,
+         use_short_conv: bool = True,
+         conv_size: int = 4,
+         head_dim: int = 256,
+         num_heads: int = 6,
+         max_position_embeddings: int = 2048,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         hidden_act: str = "swish",
+         num_hidden_layers: int = 21,
+         norm_eps: float = 1e-6,
+         attn: Optional[Dict] = None,
+         use_cache: bool = True,
+         pad_token_id: Optional[int] = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         initializer_range: float = 0.006,
+         fuse_norm: bool = True,
+         fuse_swiglu: bool = True,
+         fuse_cross_entropy: bool = True,
+         vocab_size: int = 32000,
+         **kwargs
+     ):
+         self.attn_mode = attn_mode
+         self.hidden_size = hidden_size
+         self.expand_v = expand_v
+         self.use_gate = use_gate
+         self.use_short_conv = use_short_conv
+         self.conv_size = conv_size
+         self.head_dim = head_dim
+         self.num_heads = num_heads
+         self.max_position_embeddings = max_position_embeddings
+
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.num_hidden_layers = num_hidden_layers
+         self.norm_eps = norm_eps
+         self.attn = attn
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+
+         self.fuse_norm = fuse_norm
+         self.fuse_swiglu = fuse_swiglu
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.vocab_size = vocab_size
+
+         if attn is not None:
+             if not isinstance(attn, Dict):
+                 raise ValueError("attn must be a dictionary")
+             if 'layers' not in attn:
+                 raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+             if 'num_heads' not in attn:
+                 raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+             attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+             attn['qkv_bias'] = attn.get('qkv_bias', False)
+             attn['window_size'] = attn.get('window_size', None)
+             attn['rope_theta'] = attn.get('rope_theta', 10000.)
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
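The `attn` argument turns selected layers into standard softmax-attention layers; the validation above requires `layers` and `num_heads` and fills in the remaining keys with defaults. A configuration sketch (hypothetical layer choices):

from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig

config = GatedDeltaNetConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    # make layers 3 and 7 standard attention layers; other keys get defaults
    attn={'layers': [3, 7], 'num_heads': 8},
)
print(config.attn['window_size'])  # None (filled in by the constructor)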
fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc ADDED
Binary file (18.6 kB). View file
 
fla/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc ADDED
Binary file (3.31 kB). View file
 
fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc ADDED
Binary file (19 kB). View file
 
fla/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc ADDED
Binary file (41.6 kB). View file
 
fla/models/mamba2/__pycache__/modeling_mamba2.cpython-312.pyc ADDED
Binary file (52.4 kB). View file
 
fla/models/nsa/__pycache__/configuration_nsa.cpython-312.pyc ADDED
Binary file (2.67 kB). View file
 
fla/models/nsa/__pycache__/modeling_nsa.cpython-312.pyc ADDED
Binary file (17.6 kB). View file
 
fla/models/retnet/__pycache__/modeling_retnet.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (715 Bytes). View file
 
fla/models/samba/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (745 Bytes). View file
 
fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc ADDED
Binary file (3.42 kB). View file
 
fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc ADDED
Binary file (20.9 kB). View file
 
fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.54 kB). View file
 
fla/models/transformer/__pycache__/modeling_transformer.cpython-312.pyc ADDED
Binary file (17.1 kB). View file
 
fla/models/transformer_mtp/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (823 Bytes). View file
 
fla/models/utils.py ADDED
@@ -0,0 +1,147 @@
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import torch
+ import transformers
+
+
+ class Cache(transformers.cache_utils.Cache):
+     """
+     A cache used for storing hidden states produced by flash linear attention models.
+
+     It stores the states of each layer as a tensor of shape `[batch_size, key_dim, value_dim]`.
+     """
+
+     is_compileable = True
+
+     def __init__(
+         self,
+         seen_tokens: int = 0
+     ) -> Cache:
+         super().__init__()
+
+         self.states: List[Dict[str, Any]] = []
+
+         self._seen_tokens = seen_tokens  # Used in `generate` to keep a tally of how many tokens the cache has seen
+
+     def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
+         if layer_idx < len(self):
+             return self.states[layer_idx]
+         else:
+             raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+     def __iter__(self):
+         for state in self.states:
+             yield state
+
+     def __len__(self):
+         return len(self.states)
+
+     def update(
+         self,
+         recurrent_state: torch.Tensor = None,
+         attn_state: Tuple[torch.Tensor, torch.Tensor] = None,
+         conv_state: Tuple[torch.Tensor] = None,
+         ffn_state: torch.Tensor = None,
+         layer_idx: int = 0,
+         offset: Optional[int] = 1,
+         cache_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> Dict[str, Any]:
+         """
+         Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.
+
+         Args:
+             recurrent_state (`torch.Tensor`, `optional`):
+                 The new recurrent state to cache.
+             attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
+                 The new attention key/value states to cache.
+             conv_state (`Tuple[torch.Tensor]`, `optional`):
+                 The new convolution state to cache.
+             ffn_state (`torch.Tensor`, `optional`):
+                 The new feed-forward state to cache.
+             layer_idx (`int`, defaults to 0):
+                 The index of the layer to cache the states for.
+             offset (`int`, `optional`, defaults to 1):
+                 The number of new tokens being processed.
+             cache_kwargs (`Dict[str, Any]`, `optional`):
+                 Additional arguments for the cache subclass, e.g. `window_size` for sliding-window attention.
+
+         Return:
+             Dictionary of the updated state.
+         """
+
+         # Update the number of seen tokens
+         if layer_idx == 0:
+             self._seen_tokens += offset
+
+         if cache_kwargs is None:
+             cache_kwargs = {}
+         if attn_state is not None:
+             if not isinstance(attn_state, Tuple) or len(attn_state) != 2:
+                 raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
+             input_size = attn_state[0].shape[-2]
+             window_size = cache_kwargs.get('window_size', None)
+         if len(self.states) <= layer_idx:
+             if attn_state is not None:
+                 if window_size is not None and input_size > window_size:
+                     attn_state = (attn_state[0][..., -window_size:, :].contiguous(),
+                                   attn_state[1][..., -window_size:, :].contiguous())
+             state = dict(
+                 recurrent_state=recurrent_state,
+                 attn_state=attn_state,
+                 conv_state=conv_state,
+                 ffn_state=ffn_state
+             )
+             self.states.append(state)
+         else:
+             state = self.states[layer_idx]
+             if recurrent_state is not None:
+                 state['recurrent_state'] = recurrent_state
+             if attn_state is not None:
+                 key_state, value_state = state['attn_state']
+                 if window_size is not None and key_state.shape[-2] == window_size:
+                     # DO NOT allocate new memory if the cache is full
+                     # roll the key/value states to the left by `input_size`
+                     key_state = key_state.roll(-input_size, -2)
+                     value_state = value_state.roll(-input_size, -2)
+                     # replace the last `input_size` tokens with the new key/value states
+                     key_state[..., -input_size:, :] = attn_state[0]
+                     value_state[..., -input_size:, :] = attn_state[1]
+                     attn_state = (key_state, value_state)
+                 else:
+                     attn_state = (torch.cat([key_state, attn_state[0]], -2),
+                                   torch.cat([value_state, attn_state[1]], -2),)
+                 state['attn_state'] = attn_state
+             if conv_state is not None:
+                 state['conv_state'] = conv_state
+             if ffn_state is not None:
+                 state['ffn_state'] = ffn_state
+
+         return state
+
+     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+         if len(self.states) <= layer_idx:
+             return 0
+         return self._seen_tokens
+
+     def get_max_length(self) -> Optional[int]:
+         """Returns the maximum sequence length of the cached states. This cache does not have a maximum length."""
+         return None
+
+     def to_legacy_cache(self) -> Tuple:
+         return tuple(self.states)
+
+     @classmethod
+     @torch.compiler.disable
+     def from_legacy_cache(
+         cls,
+         past_key_values: Optional[Tuple] = None,
+         seen_tokens: int = 0
+     ) -> Cache:
+         """Converts a cache in the legacy cache format into an equivalent `Cache`."""
+
+         cache = cls(seen_tokens)
+         if isinstance(past_key_values, list):
+             for layer_idx in range(len(past_key_values)):
+                 cache.states.append(past_key_values[layer_idx])
+         return cache
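A small sketch of the sliding-window behaviour of `update` (hypothetical shapes): once the cached key/value states reach `window_size`, the branch above rolls them in place rather than concatenating.

import torch
from fla.models.utils import Cache

cache = Cache()
k = torch.randn(1, 4, 8, 16)  # [batch, heads, seq, head_dim]
v = torch.randn(1, 4, 8, 16)
# first call fills layer 0; with window_size=8 the cache is now full
cache.update(attn_state=(k, v), layer_idx=0, offset=8, cache_kwargs={'window_size': 8})
# next token: the cache rolls left by 1 and overwrites the last slot
k1, v1 = torch.randn(1, 4, 1, 16), torch.randn(1, 4, 1, 16)
state = cache.update(attn_state=(k1, v1), layer_idx=0, offset=1, cache_kwargs={'window_size': 8})
assert state['attn_state'][0].shape[-2] == 8  # window length is preserved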
fla/modules/__init__.py ADDED
@@ -0,0 +1,30 @@
+ # -*- coding: utf-8 -*-
+
+ from fla.modules.convolution import ImplicitLongConvolution, LongConvolution, ShortConvolution
+ from fla.modules.fused_bitlinear import BitLinear, FusedBitLinear
+ from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss
+ from fla.modules.fused_kl_div import FusedKLDivLoss
+ from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
+ from fla.modules.fused_linear_listnet_loss import FusedLinearListNetLoss
+ from fla.modules.fused_norm_gate import (
+     FusedLayerNormGated,
+     FusedLayerNormSwishGate,
+     FusedLayerNormSwishGateLinear,
+     FusedRMSNormGated,
+     FusedRMSNormSwishGate,
+     FusedRMSNormSwishGateLinear
+ )
+ from fla.modules.layernorm import GroupNorm, GroupNormLinear, LayerNorm, LayerNormLinear, RMSNorm, RMSNormLinear
+ from fla.modules.mlp import GatedMLP
+ from fla.modules.rotary import RotaryEmbedding
+
+ __all__ = [
+     'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
+     'BitLinear', 'FusedBitLinear',
+     'FusedCrossEntropyLoss', 'FusedLinearCrossEntropyLoss', 'FusedKLDivLoss', 'FusedLinearListNetLoss',
+     'GroupNorm', 'GroupNormLinear', 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
+     'FusedLayerNormGated', 'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear',
+     'FusedRMSNormGated', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
+     'GatedMLP',
+     'RotaryEmbedding'
+ ]
fla/modules/__pycache__/activations.cpython-312.pyc ADDED
Binary file (23 kB). View file
 
fla/modules/__pycache__/fused_kl_div.cpython-312.pyc ADDED
Binary file (11.7 kB). View file
 
fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc ADDED
Binary file (20.6 kB). View file
 
fla/modules/__pycache__/seq_to_top.cpython-312.pyc ADDED
Binary file (4.12 kB). View file
 
fla/modules/fused_linear_listnet_loss.py ADDED
@@ -0,0 +1,427 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Code adapted from
4
+ # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py
5
+
6
+ from functools import partial
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import triton
13
+ import triton.language as tl
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module
16
+ from torch.distributed.tensor.parallel import ParallelStyle
17
+
18
+ from fla.ops.utils import logsumexp_fwd
19
+ from fla.ops.utils.op import exp
20
+ from fla.utils import input_guard
21
+
22
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576
23
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
24
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
25
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
26
+ MAX_FUSED_SIZE = 65536 // 2
27
+
28
+ @triton.jit
29
+ def listnet_kernel(
30
+ logits,
31
+ targets, # Now full target distributions
32
+ lse_logits,
33
+ lse_targets,
34
+ loss,
35
+ total,
36
+ ignore_index,
37
+ logit_scale: tl.constexpr,
38
+ reduction: tl.constexpr,
39
+ V: tl.constexpr,
40
+ BV: tl.constexpr
41
+ ):
42
+ i_n = tl.program_id(0).to(tl.int64)
43
+ NV = tl.cdiv(V, BV)
44
+
45
+ # Pointers to current token's data
46
+ logits_ptr = logits + i_n * V
47
+ targets_ptr = targets + i_n * V
48
+ loss_ptr = loss + i_n
49
+
50
+ # Compute prediction softmax
51
+ b_lse_logits = tl.load(lse_logits + i_n)
52
+ b_lse_targets = tl.load(lse_targets + i_n)
53
+ b_loss = 0.0
54
+
55
+ # Compute gradient: softmax(pred) - softmax(target)
56
+ for iv in range(0, NV):
57
+ o_v = iv * BV + tl.arange(0, BV)
58
+ mask = o_v < V
59
+
60
+ # Load target and compute softmax
61
+ t_val = tl.load(targets_ptr + o_v, mask=mask, other=0.0)
62
+ p_target = tl.exp(t_val - b_lse_targets)
63
+
64
+ # Load logits and compute softmax
65
+ l_val = tl.load(logits_ptr + o_v, mask=mask, other=0.0) * logit_scale
66
+ l_val_minus_lse = l_val - b_lse_logits
67
+ p_pred = tl.exp(l_val_minus_lse)
68
+
69
+ # Gradient calculation
70
+ grad_val = p_pred - p_target
71
+ if reduction == "mean":
72
+ grad_val = grad_val / total
73
+ grad_val = tl.where(b_lse_targets == float('inf'), 0.0, grad_val)
74
+ tl.store(logits_ptr + o_v, grad_val, mask=mask)
75
+
76
+ # Cross-entropy loss
77
+ # instead of: b_loss -= tl.sum(p_target * tl.log(p_pred), axis=0)
78
+ b_loss -= tl.sum(p_target * l_val_minus_lse, axis=0)
79
+
80
+ tl.store(loss_ptr, b_loss)
81
+
+ @triton.jit
+ def elementwise_mul_kernel(
+     x,
+     g,
+     N: tl.constexpr,
+     B: tl.constexpr
+ ):
+     """
+     This function multiplies each element of the tensor pointed by x with the value pointed by g.
+     The multiplication is performed in-place on the tensor pointed by x.
+
+     Parameters:
+     x:
+         Pointer to the input tensor.
+     g:
+         Pointer to the gradient output value.
+     N (int):
+         The number of columns in the input tensor.
+     B (int):
+         The block size for Triton operations.
+     """
+
+     # Get the program ID and convert it to int64 to avoid overflow
+     i_x = tl.program_id(0).to(tl.int64)
+     o_x = i_x * B + tl.arange(0, B)
+
+     # Load the gradient output value
+     b_g = tl.load(g)
+     b_x = tl.load(x + o_x, mask=o_x < N)
+     tl.store(x + o_x, b_x * b_g, mask=o_x < N)
+
+
+ def fused_linear_listnet_forward(
+     x: torch.Tensor,
+     targets: torch.Tensor,  # float tensor [N, V]
+     weight: torch.Tensor,
+     bias: torch.Tensor = None,
+     ignore_index: int = -100,
+     logit_scale: float = 1.0,
+     num_chunks: int = 8,
+     reduction: str = "mean"
+ ):
+     N, H, V = *x.shape, weight.shape[0]
+     BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+     NC = min(num_chunks, triton.cdiv(V, H))
+     C = triton.next_power_of_2(triton.cdiv(N, NC))
+     NC = triton.cdiv(N, C)
+
+     # Initialize outputs
+     dx = torch.zeros_like(x)
+     dw = torch.zeros_like(weight, dtype=torch.float) if weight is not None else None
+     db = torch.zeros_like(bias, dtype=torch.float) if bias is not None else None
+     loss = torch.zeros(N, device=x.device, dtype=torch.float)
+     total = N  # all tokens are considered
+
+     for ic in range(NC):
+         start, end = ic * C, min((ic + 1) * C, N)
+         c_x = x[start:end]
+         c_logits = F.linear(c_x, weight, bias)
+         c_targets = targets[start:end]
+         c_lse_logits = logsumexp_fwd(c_logits, scale=logit_scale, dtype=torch.float)
+         c_lse_targets = logsumexp_fwd(c_targets, dtype=torch.float).nan_to_num(nan=float("inf"))
+         c_loss = loss[start:end]
+
+         # Call the ListNet kernel; it overwrites c_logits with the gradients
+         listnet_kernel[(c_logits.shape[0],)](
+             logits=c_logits,
+             targets=c_targets,  # full target distributions
+             lse_logits=c_lse_logits,
+             lse_targets=c_lse_targets,
+             loss=c_loss,
+             total=total,
+             ignore_index=ignore_index,
+             logit_scale=logit_scale,
+             reduction=reduction,
+             V=V,
+             BV=BV,
+             num_warps=32
+         )
+
+         # Backward through the linear layer
+         dx[start:end] = torch.mm(c_logits, weight)
+         if weight is not None:
+             dw += c_logits.t() @ c_x
+         if bias is not None:
+             db += c_logits.sum(0)
+
+     loss = loss.sum()
+     if reduction == "mean":
+         loss = loss / total
+
+     return loss, dx, dw, db
+
+
+ def fused_linear_listnet_backward(
+     do: torch.Tensor,
+     dx: torch.Tensor,
+     dw: torch.Tensor,
+     db: torch.Tensor
+ ):
+     # If the loss is the last layer, do is 1.0. Skip the mul to save time
+     if torch.ne(do, torch.tensor(1.0, device=do.device)):
+         # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+         # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+         N, H = dx.shape
+         B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
+
+         elementwise_mul_kernel[(triton.cdiv(N * H, B),)](
+             x=dx,
+             g=do,
+             N=N*H,
+             B=B,
+             num_warps=32,
+         )
+
+         # handle dw
+         if dw is not None:
+             V, H = dw.shape
+             elementwise_mul_kernel[(triton.cdiv(V * H, B),)](
+                 x=dw,
+                 g=do,
+                 N=V*H,
+                 B=B,
+                 num_warps=32,
+             )
+
+         if db is not None:
+             V = db.shape[0]
+             elementwise_mul_kernel[(triton.cdiv(V, B),)](
+                 x=db,
+                 g=do,
+                 N=V,
+                 B=B,
+                 num_warps=32,
+             )
+     return dx, dw, db
+
+
+ class FusedLinearListNetFunction(torch.autograd.Function):
+
+     @staticmethod
+     def forward(
+         ctx,
+         x: torch.Tensor,
+         targets: torch.Tensor,  # float targets
+         weight: torch.Tensor,
+         bias: torch.Tensor = None,
+         ignore_index: int = -100,
+         logit_scale: float = 1.0,
+         num_chunks: int = 8,
+         reduction: str = "mean"
+     ):
+         loss, dx, dw, db = fused_linear_listnet_forward(
+             x, targets, weight, bias, ignore_index,
+             logit_scale, num_chunks, reduction
+         )
+         ctx.save_for_backward(dx, dw, db)
+         return loss
+
+     @staticmethod
+     def backward(ctx, do):
+         dx, dw, db = ctx.saved_tensors
+         dx, dw, db = fused_linear_listnet_backward(do, dx, dw, db)
+         return dx, None, dw, db, None, None, None, None
+
+
+ def fused_linear_listnet_loss(
+     x: torch.Tensor,
+     target: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor = None,
+     ignore_index: int = -100,
+     label_smoothing: float = 0.0,
+     logit_scale: float = 1.0,
+     num_chunks: int = 8,
+     reduction: str = "mean"
+ ) -> torch.Tensor:
+     """
+     Args:
+         x (torch.Tensor): [batch_size * seq_len, hidden_size]
+         target (torch.Tensor): [batch_size * seq_len, vocab_size]
+             Unnormalized target scores; they are turned into a distribution via softmax.
+         weight (torch.Tensor): [vocab_size, hidden_size]
+             where `vocab_size` is the number of classes.
+         bias (Optional[torch.Tensor]): [vocab_size]
+             where `vocab_size` is the number of classes.
+         ignore_index: int.
+             Kept for API compatibility; rows whose target scores logsumexp to inf
+             contribute zero loss and gradient.
+         label_smoothing: float
+             Kept for API compatibility; currently unused.
+         logit_scale: float
+             A scaling factor applied to the logits. Default: 1.0
+         num_chunks: int
+             The number of chunks to split the input tensor into for processing.
+             This can help optimize memory usage and computation speed.
+             Default: 8
+         reduction:
+             Specifies the reduction to apply to the output: 'mean' | 'sum'.
+             'mean': the weighted mean of the output is taken,
+             'sum': the output will be summed.
+             Default: 'mean'.
+     Returns:
+         loss: a scalar tensor.
+     """
+     return FusedLinearListNetFunction.apply(
+         x,
+         target,
+         weight,
+         bias,
+         ignore_index,
+         logit_scale,
+         num_chunks,
+         reduction
+     )
+
+
+ class FusedLinearListNetLoss(nn.Module):
+
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         label_smoothing: float = 0.0,
+         logit_scale: float = 1.0,
+         num_chunks: int = 8,
+         reduction: str = "mean"
+     ):
+         """
+         Args:
+             ignore_index: int.
+                 Kept for API compatibility; rows whose target scores logsumexp to inf
+                 contribute zero loss and gradient.
+             label_smoothing: float
+                 Kept for API compatibility; currently unused.
+             logit_scale: float
+                 A scaling factor applied to the logits. Default: 1.0
+             num_chunks: int
+                 The number of chunks to split the input tensor into for processing.
+                 This can help optimize memory usage and computation speed.
+                 Default: 8
+             reduction:
+                 Specifies the reduction to apply to the output: 'mean' | 'sum'.
+                 'mean': the weighted mean of the output is taken,
+                 'sum': the output will be summed.
+                 Default: 'mean'.
+         """
+         super().__init__()
+
+         assert reduction in ["mean", "sum"], f"reduction: {reduction} is not supported"
+
+         self.ignore_index = ignore_index
+         self.label_smoothing = label_smoothing
+         self.logit_scale = logit_scale
+         self.num_chunks = num_chunks
+         self.reduction = reduction
+
+     @torch.compiler.disable
+     def forward(
+         self,
+         x: torch.Tensor,
+         target: torch.Tensor,
+         weight: torch.Tensor,
+         bias: Optional[torch.Tensor] = None
+     ):
+         """
+         Args:
+             x (torch.Tensor): [batch_size, seq_len, hidden_size]
+             target (torch.Tensor): [batch_size, seq_len, vocab_size]
+                 Unnormalized target scores over the vocabulary.
+             weight (torch.Tensor): [vocab_size, hidden_size]
+                 where `vocab_size` is the number of classes.
+             bias (Optional[torch.Tensor]): [vocab_size]
+                 where `vocab_size` is the number of classes.
+         Returns:
+             loss
+         """
+         loss = fused_linear_listnet_loss(
+             x.view(-1, x.shape[-1]),
+             target.view(-1, target.shape[-1]),
+             weight=weight,
+             bias=bias,
+             ignore_index=self.ignore_index,
+             label_smoothing=self.label_smoothing,
+             logit_scale=self.logit_scale,
+             num_chunks=self.num_chunks,
+             reduction=self.reduction
+         )
+         return loss
+
+
+ class LinearLossParallel(ParallelStyle):
+     def __init__(
+         self,
+         *,
+         sequence_dim: int = 1,
+         use_local_output: bool = False,
+     ):
+         super().__init__()
+
+         self.sequence_sharding = (Shard(sequence_dim),)
+         self.use_local_output = use_local_output
+
+     @staticmethod
+     def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
+         x, target, weight, bias = inputs
+
+         if not isinstance(x, DTensor):
+             # assume the input passed in is already sharded on the sequence dim and create the DTensor
+             x = DTensor.from_local(x, device_mesh, sequence_sharding)
+         if x.placements != sequence_sharding:
+             x = x.redistribute(placements=sequence_sharding, async_op=True)
+         if not isinstance(target, DTensor):
+             target = DTensor.from_local(target, device_mesh, [Replicate()])
+         if target.placements != sequence_sharding:
+             target = target.redistribute(placements=sequence_sharding, async_op=True)
+
+         if not isinstance(weight, DTensor):
+             weight = DTensor.from_local(weight, device_mesh, [Replicate()])
+         if weight.placements != [Replicate()]:
+             # we replicate the weight/bias in FLCE
+             weight = weight.redistribute(placements=[Replicate()], async_op=True)
+
+         if bias is not None and not isinstance(bias, DTensor):
+             bias = DTensor.from_local(bias, device_mesh, [Replicate()])
+         if bias is not None and bias.placements != [Replicate()]:
+             bias = bias.redistribute(placements=[Replicate()], async_op=True)
+
+         return x.to_local(), target.to_local(), weight.to_local(), bias.to_local() if bias is not None else bias
+
+     @staticmethod
+     def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
+         return outputs.to_local() if use_local_output else outputs
+
+     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+         return distribute_module(
+             module,
+             device_mesh,
+             partition_fn=None,
+             input_fn=partial(self._prepare_input_fn, self.sequence_sharding),
+             output_fn=partial(self._prepare_output_fn, self.use_local_output)
+         )
+
+
+ # Naive ListNet loss function implementation
+ def list_net_loss(y_pred, y_true):
+     """
+     ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".
+     :param y_pred: predictions from the model, shape [*, slate_length]
+     :param y_true: ground truth labels, shape [*, slate_length]
+     :return: loss value, a torch.Tensor
+     """
+     return torch.mean(-torch.sum(F.softmax(y_true, dim=-1).nan_to_num(nan=0) * F.log_softmax(y_pred, dim=-1), dim=-1))
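A minimal usage sketch of the fused loss above (shapes are hypothetical; a CUDA device is assumed, since the Triton kernels do not run on CPU):

import torch
from fla.modules.fused_linear_listnet_loss import FusedLinearListNetLoss

B, T, H, V = 2, 16, 512, 32000  # hypothetical sizes
x = torch.randn(B, T, H, device='cuda', dtype=torch.bfloat16, requires_grad=True)
# unnormalized target scores over the vocabulary, e.g. teacher logits
target = torch.randn(B, T, V, device='cuda', dtype=torch.bfloat16)
weight = torch.randn(V, H, device='cuda', dtype=torch.bfloat16, requires_grad=True)

criterion = FusedLinearListNetLoss(reduction='mean')
loss = criterion(x, target, weight)  # scalar; logits are materialized only one chunk at a time
loss.backward()                      # gradients were already computed during the forward pass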
fla/modules/l2norm.py ADDED
@@ -0,0 +1,176 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from fla.utils import input_guard
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps)
+         for num_warps in [1, 2, 4, 8, 16, 32]
+     ],
+     key=['N']
+ )
+ @triton.jit
+ def l2norm_fwd_kernel(
+     X,
+     Y,
+     N,
+     eps,
+     BLOCK_N: tl.constexpr,
+ ):
+     i_m = tl.program_id(0)
+     X += i_m * N
+     Y += i_m * N
+     # Compute the squared L2 norm of the row
+     cols = tl.arange(0, BLOCK_N)
+     mask = cols < N
+     x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
+     xbar = tl.where(mask, x, 0.0)
+     var = tl.sum(xbar * xbar, axis=0)
+     rstd = 1 / tl.sqrt(var + eps)
+     # Normalize
+     y = x * rstd
+     # Write output
+     tl.store(Y + cols, y, mask=mask)
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps)
+         for num_warps in [1, 2, 4, 8, 16, 32]
+     ],
+     key=['N']
+ )
+ @triton.jit
+ def l2norm_bwd_kernel(
+     X,
+     DY,
+     DX,
+     N,
+     eps,
+     BLOCK_N: tl.constexpr,
+ ):
+     i_m = tl.program_id(0)
+     X += i_m * N
+     DX += i_m * N
+     DY += i_m * N
+
+     cols = tl.arange(0, BLOCK_N)
+     mask = cols < N
+     x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
+     x = tl.where(mask, x, 0.0)
+     var = tl.sum(x * x)
+     rstd = 1 / tl.sqrt(var + eps)
+     # dx = dy / ||x|| - (dy . x) * x / ||x||^3
+     dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32)
+     dy = tl.where(mask, dy, 0.0)
+     dx = dy * rstd - tl.sum(dy * x) * (1 / (var + eps)) * rstd * x
+     tl.store(DX + cols, dx, mask=mask)
+
+
+ def l2norm_fwd(
+     x: torch.Tensor,
+     eps: float = 1e-6,
+     output_dtype: Optional[torch.dtype] = None
+ ):
+     x_shape_og = x.shape
+     x = x.reshape(-1, x.shape[-1])
+     # allocate output
+     if output_dtype is None:
+         y = torch.empty_like(x)
+     else:
+         y = torch.empty_like(x, dtype=output_dtype)
+     assert y.stride(-1) == 1
+     N = x.shape[-1]
+     M = x.shape[0]
+     # Less than 64KB per feature: enqueue fused kernel
+     MAX_FUSED_SIZE = 65536 // x.element_size()
+     BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+     if N > BLOCK_N:
+         raise RuntimeError("This l2norm doesn't support feature dim >= 64KB.")
+     l2norm_fwd_kernel[(M,)](
+         x,
+         y,
+         N,
+         eps,
+         BLOCK_N,
+     )
+     return y.reshape(x_shape_og)
+
+
+ def l2norm_bwd(
+     x: torch.Tensor,
+     dy: torch.Tensor,
+     eps: float = 1e-6
+ ):
+     x_shape_og = x.shape
+     x = x.reshape(-1, x.shape[-1])
+     dy = dy.reshape(-1, dy.shape[-1])
+     if dy.stride(-1) != 1:
+         dy = dy.contiguous()
+     assert dy.shape == x.shape
+     # allocate output
+     dx = torch.empty_like(x)
+     M = x.shape[0]
+     N = x.shape[-1]
+     # Less than 64KB per feature: enqueue fused kernel
+     MAX_FUSED_SIZE = 65536 // x.element_size()
+     BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+     if N > BLOCK_N:
+         raise RuntimeError("This l2norm doesn't support feature dim >= 64KB.")
+     l2norm_bwd_kernel[(M,)](
+         x,
+         dy,
+         dx,
+         N,
+         eps,
+         BLOCK_N,
+     )
+     return dx.reshape(x_shape_og)
+
+
+ class L2NormFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     def forward(
+         ctx,
+         x,
+         eps=1e-6,
+         output_dtype=None
+     ):
+         y = l2norm_fwd(x, eps, output_dtype)
+         ctx.eps = eps
+         ctx.x_dtype = x.dtype
+         ctx.save_for_backward(x)
+         return y
+
+     @staticmethod
+     @input_guard
+     def backward(ctx, dy):
+         x, = ctx.saved_tensors
+         dx = l2norm_bwd(x, dy, ctx.eps)
+         return dx, None, None
+
+
+ def l2_norm(
+     x: torch.Tensor,
+     eps: float = 1e-6,
+     output_dtype: Optional[torch.dtype] = None
+ ) -> torch.Tensor:
+     return L2NormFunction.apply(x, eps, output_dtype)
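A quick sanity check of the autograd wrapper above against `F.normalize` (CUDA device assumed; the two differ only in how eps enters the denominator):

import torch
import torch.nn.functional as F
from fla.modules.l2norm import l2_norm

x = torch.randn(4, 2048, 64, device='cuda', requires_grad=True)
y = l2_norm(x)                     # each vector along the last dim is scaled to unit L2 norm
ref = F.normalize(x, p=2, dim=-1)  # PyTorch reference
assert torch.allclose(y, ref, atol=1e-4)
y.sum().backward()                 # exercises the Triton backward kernel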
fla/modules/layernorm_gated.py ADDED
@@ -0,0 +1,528 @@
+ # Copyright (c) 2024, Tri Dao.
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
+ # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
+
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import triton
+ import triton.language as tl
+ from einops import rearrange
+
+ from fla.utils import get_multiprocessor_count, input_guard
+
+
+ def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
+     dtype = x.dtype
+     weight = weight.float()
+     bias = bias.float() if bias is not None else None
+     if upcast:
+         x = x.float()
+         z = z.float() if z is not None else z
+     if z is not None and not norm_before_gate:
+         x = x * F.silu(z)
+     if group_size is None:
+         rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
+         out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
+     else:
+         x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
+         rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
+         out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
+         if bias is not None:
+             out = out + bias
+     if z is not None and norm_before_gate:
+         out *= F.silu(z)
+     return out.to(dtype)
+
+
+ @triton.heuristics({
+     "HAS_BIAS": lambda args: args["B"] is not None,
+     "HAS_Z": lambda args: args["Z"] is not None,
+ })
+ @triton.jit
+ def layer_norm_fwd_kernel(
+     X,  # pointer to the input
+     Y,  # pointer to the output
+     W,  # pointer to the weights
+     B,  # pointer to the biases
+     Z,  # pointer to the other branch
+     Mean,  # pointer to the mean
+     Rstd,  # pointer to the 1/std
+     stride_x_row,  # how much to increase the pointer when moving by 1 row
+     stride_y_row,
+     stride_z_row,
+     M,  # number of rows in X
+     N,  # number of columns in X
+     eps,  # epsilon to avoid division by zero
+     BLOCK_N: tl.constexpr,
+     HAS_BIAS: tl.constexpr,
+     HAS_Z: tl.constexpr,
+     NORM_BEFORE_GATE: tl.constexpr,
+     IS_RMS_NORM: tl.constexpr,
+ ):
+     # Map the program id to the row of X and Y it should compute.
+     row = tl.program_id(0)
+     group = tl.program_id(1)
+     X += row * stride_x_row + group * N
+     Y += row * stride_y_row + group * N
+     if HAS_Z:
+         Z += row * stride_z_row + group * N
+     if not IS_RMS_NORM:
+         Mean += group * M
+     Rstd += group * M
+     W += group * N
+     if HAS_BIAS:
+         B += group * N
+     # Compute mean and variance
+     cols = tl.arange(0, BLOCK_N)
+     x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+     if HAS_Z and not NORM_BEFORE_GATE:
+         z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
+         x *= z * tl.sigmoid(z)
+     if not IS_RMS_NORM:
+         mean = tl.sum(x, axis=0) / N
+         tl.store(Mean + row, mean)
+         xbar = tl.where(cols < N, x - mean, 0.)
+         var = tl.sum(xbar * xbar, axis=0) / N
+     else:
+         xbar = tl.where(cols < N, x, 0.)
+         var = tl.sum(xbar * xbar, axis=0) / N
+     rstd = 1 / tl.sqrt(var + eps)
+     tl.store(Rstd + row, rstd)
+     # Normalize and apply linear transformation
+     mask = cols < N
+     w = tl.load(W + cols, mask=mask).to(tl.float32)
+     if HAS_BIAS:
+         b = tl.load(B + cols, mask=mask).to(tl.float32)
+     x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+     y = x_hat * w + b if HAS_BIAS else x_hat * w
+     if HAS_Z and NORM_BEFORE_GATE:
+         z = tl.load(Z + cols, mask=mask).to(tl.float32)
+         y *= z * tl.sigmoid(z)
+     # Write output
+     tl.store(Y + cols, y, mask=mask)
+
+
+ def layer_norm_fwd(
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor,
+     eps: float,
+     z: torch.Tensor = None,
+     out: torch.Tensor = None,
+     group_size: int = None,
+     norm_before_gate: bool = True,
+     is_rms_norm: bool = False,
+ ):
+     M, N = x.shape
+     if group_size is None:
+         group_size = N
+     assert N % group_size == 0
+     ngroups = N // group_size
+     assert x.stride(-1) == 1
+     if z is not None:
+         assert z.stride(-1) == 1
+         assert z.shape == (M, N)
+     assert weight.shape == (N,)
+     assert weight.stride(-1) == 1
+     if bias is not None:
+         assert bias.stride(-1) == 1
+         assert bias.shape == (N,)
+     # allocate output
+     if out is not None:
+         assert out.shape == x.shape
+     else:
+         out = torch.empty_like(x)
+     assert out.stride(-1) == 1
+     mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None
+     rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
+     # Less than 64KB per feature: enqueue fused kernel
+     MAX_FUSED_SIZE = 65536 // x.element_size()
+     BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
+     if group_size > BLOCK_N:
+         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+     # heuristics for number of warps
+     num_warps = min(max(BLOCK_N // 256, 1), 8)
+     grid = (M, ngroups)
+     layer_norm_fwd_kernel[grid](
+         x,
+         out,
+         weight,
+         bias,
+         z,
+         mean,
+         rstd,
+         x.stride(0),
+         out.stride(0),
+         z.stride(0) if z is not None else 0,
+         M,
+         group_size,
+         eps,
+         BLOCK_N=BLOCK_N,
+         NORM_BEFORE_GATE=norm_before_gate,
+         IS_RMS_NORM=is_rms_norm,
+         num_warps=num_warps
+     )
+     return out, mean, rstd
+
+
+ @triton.heuristics({
+     "HAS_BIAS": lambda args: args["B"] is not None,
+     "HAS_Z": lambda args: args["Z"] is not None,
+     "RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None,
+ })
+ @triton.jit
+ def layer_norm_bwd_kernel(
+     X,  # pointer to the input
+     W,  # pointer to the weights
+     B,  # pointer to the biases
+     Z,  # pointer to the other branch
+     Y,  # pointer to the output to be recomputed
+     DY,  # pointer to the output gradient
+     DX,  # pointer to the input gradient
+     DW,  # pointer to the partial sum of weights gradient
+     DB,  # pointer to the partial sum of biases gradient
+     DZ,  # pointer to the other branch
+     Mean,  # pointer to the mean
+     Rstd,  # pointer to the 1/std
+     stride_x_row,  # how much to increase the pointer when moving by 1 row
+     stride_z_row,
+     stride_y_row,
+     stride_dy_row,
+     stride_dx_row,
+     stride_dz_row,
+     stride_dw_row,
+     stride_db_row,
+     M,  # number of rows in X
+     N,  # number of columns in X
+     eps,  # epsilon to avoid division by zero
+     rows_per_program,
+     NORM_BEFORE_GATE: tl.constexpr,
+     IS_RMS_NORM: tl.constexpr,
+     HAS_BIAS: tl.constexpr,
+     HAS_Z: tl.constexpr,
+     RECOMPUTE_OUTPUT: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+ ):
+     # Map the program id to the elements of X, DX, and DY it should compute.
+     row_block_id = tl.program_id(0)
+     group = tl.program_id(1)
+     row_start = row_block_id * rows_per_program
+     cols = tl.arange(0, BLOCK_N)
+     mask = cols < N
+     X += row_start * stride_x_row + group * N
+     if HAS_Z:
+         Z += row_start * stride_z_row + group * N
+         DZ += row_start * stride_dz_row + group * N
+     DY += row_start * stride_dy_row + group * N
+     DX += row_start * stride_dx_row + group * N
+     if RECOMPUTE_OUTPUT:
+         Y += row_start * stride_y_row + group * N
+     if not IS_RMS_NORM:
+         Mean += group * M
+     Rstd += group * M
+     W += group * N
+     w = tl.load(W + cols, mask=mask).to(tl.float32)
+     if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
+         B += group * N
+         b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
+     dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
+     if HAS_BIAS:
+         db = tl.zeros((BLOCK_N,), dtype=tl.float32)
+     row_end = min((row_block_id + 1) * rows_per_program, M)
+     for row in range(row_start, row_end):
+         # Load data to SRAM
+         x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
+         dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
+         if not IS_RMS_NORM:
+             mean = tl.load(Mean + row)
+         if HAS_Z and not NORM_BEFORE_GATE:
+             z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
+             x_og = x
+             x = x_og * z * tl.sigmoid(z)
+         rstd = tl.load(Rstd + row)
+         # Compute dx
+         xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+         xhat = tl.where(mask, xhat, 0.)
+         if HAS_Z and NORM_BEFORE_GATE:
+             z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
+             z_sigmoid = tl.sigmoid(z)
+             y = xhat * w + b if HAS_BIAS else xhat * w
+             if RECOMPUTE_OUTPUT:
+                 tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
+             dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
+             tl.store(DZ + cols, dz, mask=mask)
+             dy *= z * z_sigmoid
+         else:
+             if RECOMPUTE_OUTPUT:
+                 y = xhat * w + b if HAS_BIAS else xhat * w
+                 tl.store(Y + cols, y, mask=mask)
+         wdy = w * dy
+         c1 = tl.sum(xhat * wdy, axis=0) / N
+         if not IS_RMS_NORM:
+             c2 = tl.sum(wdy, axis=0) / N
+             dx = (wdy - (xhat * c1 + c2)) * rstd
+         else:
+             dx = (wdy - xhat * c1) * rstd
+         dw += dy * xhat
+         if HAS_BIAS:
+             db += dy
+         if HAS_Z and not NORM_BEFORE_GATE:
+             z_sigmoid = tl.sigmoid(z)
+             dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
+             tl.store(DZ + cols, dz, mask=mask)
+             dx *= z * z_sigmoid
+         # Write dx
+         tl.store(DX + cols, dx, mask=mask)
+
+         X += stride_x_row
+         if HAS_Z:
+             Z += stride_z_row
+             DZ += stride_dz_row
+         if RECOMPUTE_OUTPUT:
+             Y += stride_y_row
+         DY += stride_dy_row
+         DX += stride_dx_row
+     tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
+     if HAS_BIAS:
+         tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)
+
+
+ def layer_norm_bwd(
+     dy: torch.Tensor,
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor,
+     eps: float,
+     mean: torch.Tensor,
+     rstd: torch.Tensor,
+     z: torch.Tensor = None,
+     group_size: int = None,
+     norm_before_gate: bool = True,
+     is_rms_norm: bool = False,
+     recompute_output: bool = False,
+     dz: torch.Tensor = None,
+     out: torch.Tensor = None,
+ ):
+     M, N = x.shape
+     if group_size is None:
+         group_size = N
+     assert N % group_size == 0
+     ngroups = N // group_size
+     assert x.stride(-1) == 1
+     assert dy.stride(-1) == 1
+     assert dy.shape == (M, N)
+     if z is not None:
+         assert z.stride(-1) == 1
+         assert z.shape == (M, N)
+     assert weight.shape == (N,)
+     assert weight.stride(-1) == 1
+     if bias is not None:
+         assert bias.stride(-1) == 1
+         assert bias.shape == (N,)
+     # allocate output
+     dx = torch.empty_like(x)
+     if dz is not None:
+         assert z is not None
+         assert dz.shape == z.shape
+         assert dz.stride(-1) == 1
+     else:
+         dz = torch.empty_like(z) if z is not None else None
+     if recompute_output:
+         if out is None:
+             out = torch.empty_like(x)
+         assert out.shape == x.shape
+
+     # Less than 64KB per feature: enqueue fused kernel
+     MAX_FUSED_SIZE = 65536 // x.element_size()
+     BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
+     if group_size > BLOCK_N:
+         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+     # heuristics for number of warps
+     num_warps = min(max(BLOCK_N // 256, 1), 8)
+     sm_count = get_multiprocessor_count(x.device.index)
+     # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
+     # would limit the occupancy.
+     nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
+     _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
+     _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
+     rows_per_program = math.ceil(M / nrow_groups)
+     grid = (nrow_groups, ngroups)
+     layer_norm_bwd_kernel[grid](
+         x,
+         weight,
+         bias,
+         z,
+         out if recompute_output else None,
+         dy,
+         dx,
+         _dw,
+         _db,
+         dz,
+         mean,
+         rstd,
+         x.stride(0),
+         z.stride(0) if z is not None else 0,
+         0 if not recompute_output else out.stride(0),
+         dy.stride(0),
+         dx.stride(0),
+         dz.stride(0) if dz is not None else 0,
+         _dw.stride(0),
+         _db.stride(0) if _db is not None else 0,
+         M, group_size, eps,
+         rows_per_program,
+         BLOCK_N=BLOCK_N,
+         NORM_BEFORE_GATE=norm_before_gate,
+         IS_RMS_NORM=is_rms_norm,
+         num_warps=num_warps
+     )
+     dw = _dw.sum(0).to(weight.dtype)
+     db = _db.sum(0).to(bias.dtype) if bias is not None else None
+     return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)
+
+
+ class LayerNormFn(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,
+                 is_rms_norm=False):
+         """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
+         """
+         x_shape_og = x.shape
+         # reshape input data into 2D tensor
+         x = x.reshape(-1, x.shape[-1])
+         if x.stride(-1) != 1:
+             x = x.contiguous()
+         if z is not None:
+             assert z.shape == x_shape_og
+             z = z.reshape(-1, z.shape[-1])
+             if z.stride(-1) != 1:
+                 z = z.contiguous()
+         weight = weight.contiguous()
+         if bias is not None:
+             bias = bias.contiguous()
+         y, mean, rstd = layer_norm_fwd(
+             x,
+             weight,
+             bias,
+             eps,
+             z=z,
+             group_size=group_size,
+             norm_before_gate=norm_before_gate,
+             is_rms_norm=is_rms_norm,
+         )
+         ctx.save_for_backward(x, weight, bias, mean, rstd, z)
+         ctx.x_shape_og = x_shape_og
+         ctx.eps = eps
+         ctx.group_size = group_size
+         ctx.norm_before_gate = norm_before_gate
+         ctx.is_rms_norm = is_rms_norm
+         return y.reshape(x_shape_og)
+
+     @staticmethod
+     @input_guard
+     def backward(ctx, dy):
+         x, weight, bias, mean, rstd, z = ctx.saved_tensors
+         dy = dy.reshape(-1, dy.shape[-1])
+         if dy.stride(-1) != 1:
+             dy = dy.contiguous()
+         assert dy.shape == x.shape
+         dx, dw, db, dz = layer_norm_bwd(
+             dy,
+             x,
+             weight,
+             bias,
+             ctx.eps,
+             mean,
+             rstd,
+             z,
+             ctx.group_size,
+             ctx.norm_before_gate,
+             ctx.is_rms_norm
+         )
+         dx = dx.reshape(ctx.x_shape_og)
+         dz = dz.reshape(ctx.x_shape_og) if dz is not None else None
+         return dx, dw, db, dz, None, None, None, None
+
+
+ def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
+     return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)
+
+
+ def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
+     return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)
+
+
+ class LayerNormGated(nn.Module):
+
+     def __init__(
+         self,
+         hidden_size,
+         eps: float = 1e-5,
+         group_size: Optional[int] = None,
+         norm_before_gate: bool = True,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         """If group_size is not None, we do GroupNorm with each group having group_size elements.
+         group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
+         """
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+         self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+         self.group_size = group_size
+         self.norm_before_gate = norm_before_gate
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.ones_(self.weight)
+         torch.nn.init.zeros_(self.bias)
+
+     def forward(self, x, z=None):
+         """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
+         """
+         return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,
+                             norm_before_gate=self.norm_before_gate)
+
+
+ class RMSNormGated(nn.Module):
+
+     def __init__(
+         self,
+         hidden_size,
+         eps: float = 1e-5,
+         group_size: Optional[int] = None,
+         norm_before_gate: bool = False,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         """If group_size is not None, we do GroupNorm with each group having group_size elements.
+         group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
+         """
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+         self.register_parameter("bias", None)
+         self.group_size = group_size
+         self.norm_before_gate = norm_before_gate
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.ones_(self.weight)
+
+     def forward(self, x, z=None):
+         """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
+         """
+         return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
+                           norm_before_gate=self.norm_before_gate)
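A small consistency check of the fused gated RMSNorm above against the pure-PyTorch `rms_norm_ref` defined at the top of this file (CUDA device assumed, shapes hypothetical):

import torch
from fla.modules.layernorm_gated import RMSNormGated, rms_norm_ref

B, T, D = 2, 128, 512
norm = RMSNormGated(D, eps=1e-5, norm_before_gate=False).cuda()
x = torch.randn(B, T, D, device='cuda')
z = torch.randn(B, T, D, device='cuda')  # gate branch: norm(x * silu(z)) since norm_before_gate=False

y = norm(x, z)
ref = rms_norm_ref(x, norm.weight, norm.bias, z=z, eps=norm.eps, norm_before_gate=False)
assert torch.allclose(y, ref, atol=1e-4)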
fla/modules/mlp.py ADDED
@@ -0,0 +1,127 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from __future__ import annotations
+
+ from functools import partial
+ from typing import TYPE_CHECKING, Any, Optional
+
+ import torch
+ import torch.nn as nn
+ from torch.distributed import DeviceMesh
+ from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_module
+ from torch.distributed.tensor.parallel import ParallelStyle
+
+ from fla.modules.activations import swiglu, swiglu_linear
+
+ if TYPE_CHECKING:
+     from transformers.processing_utils import Unpack
+
+
+ class GatedMLP(nn.Module):
+
+     def __init__(
+         self,
+         hidden_size: int,
+         hidden_ratio: Optional[int] = None,
+         intermediate_size: Optional[int] = None,
+         hidden_act: str = 'swish',
+         fuse_swiglu: bool = True
+     ) -> GatedMLP:
+         super().__init__()
+
+         self.hidden_size = hidden_size
+         # the final number of params is `hidden_ratio * hidden_size^2`
+         # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
+         if hidden_ratio is None:
+             hidden_ratio = 4
+         if intermediate_size is None:
+             intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
+             intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.fuse_swiglu = fuse_swiglu
+
+         if hidden_act != 'swish':
+             raise ValueError(f'Unsupported hidden_act: {hidden_act}')
+
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         if self.fuse_swiglu:
+             self.swiglu_linear = SwiGLULinear()
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         **kwargs: Unpack[Any]
+     ) -> torch.Tensor:
+         gate, y = self.gate_proj(x), self.up_proj(x)
+         if self.fuse_swiglu:
+             return self.swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
+         else:
+             return self.down_proj(swiglu(gate, y))
+
+
+ class SwiGLULinear(nn.Module):
+
+     def forward(self, x, y, weight, bias):
+         return swiglu_linear(x, y, weight, bias)
+
+
+ class SwiGLULinearParallel(ParallelStyle):
+     def __init__(
+         self,
+         *,
+         input_layouts: Optional[Placement] = None,
+         output_layouts: Optional[Placement] = None,
+         use_local_output: bool = True,
+     ):
+         super().__init__()
+         self.input_layouts = (input_layouts or Shard(-1),)
+         self.output_layouts = (output_layouts or Replicate(),)
+         self.desired_input_layouts = (Shard(-1),)
+         self.use_local_output = use_local_output
+
+     @staticmethod
+     def _prepare_input_fn(
+         input_layouts, desired_input_layouts, mod, inputs, device_mesh
+     ):
+         x, y, weight, bias = inputs
+         if not isinstance(x, DTensor):
+             x = DTensor.from_local(x, device_mesh, input_layouts, run_check=False)
+         if x.placements != desired_input_layouts:
+             x = x.redistribute(placements=desired_input_layouts, async_op=True)
+
+         if not isinstance(y, DTensor):
+             y = DTensor.from_local(y, device_mesh, input_layouts, run_check=False)
+         if y.placements != desired_input_layouts:
+             y = y.redistribute(placements=desired_input_layouts, async_op=True)
+
+         if not isinstance(weight, DTensor):
+             weight = DTensor.from_local(weight, device_mesh, (Shard(1),))
+
+         if bias is not None and not isinstance(bias, DTensor):
+             bias = DTensor.from_local(bias, device_mesh, (Replicate(),))
+
+         return x, y, weight, bias
+
+     @staticmethod
+     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+         # Rowwise sharding produces partial output, depending on output layouts:
+         # 1. to replicate -> allreduce
+         # 2. to shard -> reduce_scatter
+         if outputs.placements != output_layouts:
+             outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+         # back to local tensor if use_local_output is True
+         return outputs.to_local() if use_local_output else outputs
+
+     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+         return distribute_module(
+             module,
+             device_mesh,
+             partition_fn=None,
+             input_fn=partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
+             output_fn=partial(self._prepare_output_fn, self.output_layouts, self.use_local_output)
+         )
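To make the sizing rule concrete: with hidden_size=2048 and the default hidden_ratio=4, 2048 * 4 * 2/3 ≈ 5461, which rounds up to the next multiple of 256, i.e. 5632. A minimal usage sketch (CUDA device assumed for the fused Triton swiglu path):

import torch
from fla.modules.mlp import GatedMLP

mlp = GatedMLP(hidden_size=2048).cuda()  # intermediate_size resolves to 5632
assert mlp.intermediate_size == 5632
x = torch.randn(2, 16, 2048, device='cuda')
y = mlp(x)  # SwiGLU: down_proj(swish(gate) * up), with the last matmul fused into the kernel
assert y.shape == x.shape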
fla/ops/__init__.py ADDED
@@ -0,0 +1,45 @@
+ # -*- coding: utf-8 -*-
+
+ from .abc import chunk_abc
+ from .attn import parallel_attn
+ from .based import fused_chunk_based, parallel_based
+ from .delta_rule import chunk_delta_rule, fused_chunk_delta_rule, fused_recurrent_delta_rule
+ from .forgetting_attn import parallel_forgetting_attn
+ from .gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
+ from .generalized_delta_rule import (
+     chunk_dplr_delta_rule,
+     chunk_iplr_delta_rule,
+     fused_recurrent_dplr_delta_rule,
+     fused_recurrent_iplr_delta_rule
+ )
+ from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
+ from .gsa import chunk_gsa, fused_recurrent_gsa
+ from .hgrn import fused_recurrent_hgrn
+ from .lightning_attn import chunk_lightning_attn, fused_recurrent_lightning_attn
+ from .linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn
+ from .nsa import parallel_nsa
+ from .retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention
+ from .rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
+ from .rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7
+ from .simple_gla import chunk_simple_gla, fused_recurrent_simple_gla, parallel_simple_gla
+
+ __all__ = [
+     'chunk_abc',
+     'parallel_attn',
+     'fused_chunk_based', 'parallel_based',
+     'chunk_delta_rule', 'fused_chunk_delta_rule', 'fused_recurrent_delta_rule',
+     'parallel_forgetting_attn',
+     'chunk_gated_delta_rule', 'fused_recurrent_gated_delta_rule',
+     'chunk_dplr_delta_rule', 'chunk_iplr_delta_rule',
+     'fused_recurrent_dplr_delta_rule', 'fused_recurrent_iplr_delta_rule',
+     'chunk_gla', 'fused_chunk_gla', 'fused_recurrent_gla',
+     'chunk_gsa', 'fused_recurrent_gsa',
+     'fused_recurrent_hgrn',
+     'chunk_lightning_attn', 'fused_recurrent_lightning_attn',
+     'chunk_linear_attn', 'fused_chunk_linear_attn', 'fused_recurrent_linear_attn',
+     'parallel_nsa',
+     'chunk_retention', 'fused_chunk_retention', 'fused_recurrent_retention', 'parallel_retention',
+     'chunk_rwkv6', 'fused_recurrent_rwkv6',
+     'chunk_rwkv7', 'fused_recurrent_rwkv7',
+     'chunk_simple_gla', 'fused_recurrent_simple_gla', 'parallel_simple_gla',
+ ]
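A trivial smoke test of the re-exports above (each op's exact signature lives in its own module docstring):

import fla.ops

# every name listed in __all__ should be importable from the package root namespace
for name in fla.ops.__all__:
    assert hasattr(fla.ops, name), name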
flame/components/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (137 Bytes). View file