zaydzuhri commited on
Commit
92bc85b
·
verified ·
1 Parent(s): c93077d

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. fla/__pycache__/utils.cpython-312.pyc +0 -0
  3. fla/layers/__init__.py +44 -0
  4. fla/layers/__pycache__/__init__.cpython-312.pyc +0 -0
  5. fla/layers/__pycache__/abc.cpython-312.pyc +0 -0
  6. fla/layers/__pycache__/attn.cpython-312.pyc +0 -0
  7. fla/layers/__pycache__/based.cpython-312.pyc +0 -0
  8. fla/layers/__pycache__/bitattn.cpython-312.pyc +0 -0
  9. fla/layers/__pycache__/delta_net.cpython-312.pyc +0 -0
  10. fla/layers/__pycache__/gated_deltanet.cpython-312.pyc +0 -0
  11. fla/layers/__pycache__/gated_deltaproduct.cpython-312.pyc +0 -0
  12. fla/layers/__pycache__/gla.cpython-312.pyc +0 -0
  13. fla/layers/__pycache__/gsa.cpython-312.pyc +0 -0
  14. fla/layers/__pycache__/hgrn.cpython-312.pyc +0 -0
  15. fla/layers/__pycache__/hgrn2.cpython-312.pyc +0 -0
  16. fla/layers/__pycache__/lightnet.cpython-312.pyc +0 -0
  17. fla/layers/__pycache__/linear_attn.cpython-312.pyc +0 -0
  18. fla/layers/__pycache__/multiscale_retention.cpython-312.pyc +0 -0
  19. fla/layers/__pycache__/nsa.cpython-312.pyc +0 -0
  20. fla/layers/__pycache__/rebased.cpython-312.pyc +0 -0
  21. fla/layers/__pycache__/rwkv6.cpython-312.pyc +0 -0
  22. fla/layers/__pycache__/rwkv7.cpython-312.pyc +0 -0
  23. fla/layers/abc.py +218 -0
  24. fla/layers/attn.py +203 -0
  25. fla/layers/based.py +96 -0
  26. fla/layers/bitattn.py +192 -0
  27. fla/layers/gated_deltaproduct.py +351 -0
  28. fla/layers/gla.py +294 -0
  29. fla/layers/gsa.py +227 -0
  30. fla/layers/hgrn.py +168 -0
  31. fla/layers/hgrn2.py +211 -0
  32. fla/layers/lightnet.py +210 -0
  33. fla/layers/linear_attn.py +166 -0
  34. fla/layers/multiscale_retention.py +298 -0
  35. fla/layers/rebased.py +133 -0
  36. fla/layers/rwkv6.py +307 -0
  37. fla/layers/simple_gla.py +261 -0
  38. fla/models/__init__.py +53 -0
  39. fla/models/utils.py +147 -0
  40. fla/modules/__init__.py +30 -0
  41. fla/modules/activations.py +471 -0
  42. fla/modules/convolution.py +434 -0
  43. fla/modules/fused_bitlinear.py +638 -0
  44. fla/modules/fused_cross_entropy.py +419 -0
  45. fla/modules/fused_kl_div.py +323 -0
  46. fla/modules/fused_linear_cross_entropy.py +570 -0
  47. fla/modules/fused_linear_listnet_loss.py +427 -0
  48. fla/modules/fused_norm_gate.py +995 -0
  49. fla/modules/grpo.py +396 -0
  50. fla/modules/l2norm.py +176 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tb/20250712-1101/wandb/run-20250712_110147-top_transformer-top.code.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine-202507121056/run-top_transformer-top.code.340M.batch16.seqlen4096.context4096.warmup1000.update1.steps100000.lr3e-4.cosine-202507121056.wandb filter=lfs diff=lfs merge=lfs -text
fla/__pycache__/utils.cpython-312.pyc ADDED
Binary file (12.2 kB). View file
 
fla/layers/__init__.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# Public API of `fla.layers`: each import re-exports one attention /
# token-mixing layer implementation from its submodule, and `__all__`
# declares exactly what `from fla.layers import *` exposes.

from .abc import ABCAttention
from .attn import Attention
from .based import BasedLinearAttention
from .bitattn import BitAttention
from .delta_net import DeltaNet
from .forgetting_attn import ForgettingAttention
from .gated_deltanet import GatedDeltaNet
from .gated_deltaproduct import GatedDeltaProduct
from .gla import GatedLinearAttention
from .gsa import GatedSlotAttention
from .hgrn import HGRNAttention
from .hgrn2 import HGRN2Attention
from .lightnet import LightNetAttention
from .linear_attn import LinearAttention
from .multiscale_retention import MultiScaleRetention
from .nsa import NativeSparseAttention
from .rebased import ReBasedLinearAttention
from .rwkv6 import RWKV6Attention
from .rwkv7 import RWKV7Attention

__all__ = [
    'ABCAttention',
    'Attention',
    'BasedLinearAttention',
    'BitAttention',
    'DeltaNet',
    'ForgettingAttention',
    'GatedDeltaNet',
    'GatedDeltaProduct',
    'GatedLinearAttention',
    'GatedSlotAttention',
    'HGRNAttention',
    'HGRN2Attention',
    'LightNetAttention',
    'LinearAttention',
    'MultiScaleRetention',
    'NativeSparseAttention',
    'ReBasedLinearAttention',
    'RWKV6Attention',
    'RWKV7Attention',
]
fla/layers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.2 kB). View file
 
fla/layers/__pycache__/abc.cpython-312.pyc ADDED
Binary file (9.56 kB). View file
 
fla/layers/__pycache__/attn.cpython-312.pyc ADDED
Binary file (9.5 kB). View file
 
fla/layers/__pycache__/based.cpython-312.pyc ADDED
Binary file (6.46 kB). View file
 
fla/layers/__pycache__/bitattn.cpython-312.pyc ADDED
Binary file (9.06 kB). View file
 
fla/layers/__pycache__/delta_net.cpython-312.pyc ADDED
Binary file (12.5 kB). View file
 
fla/layers/__pycache__/gated_deltanet.cpython-312.pyc ADDED
Binary file (13.4 kB). View file
 
fla/layers/__pycache__/gated_deltaproduct.cpython-312.pyc ADDED
Binary file (14.4 kB). View file
 
fla/layers/__pycache__/gla.cpython-312.pyc ADDED
Binary file (13.3 kB). View file
 
fla/layers/__pycache__/gsa.cpython-312.pyc ADDED
Binary file (10.1 kB). View file
 
fla/layers/__pycache__/hgrn.cpython-312.pyc ADDED
Binary file (6.7 kB). View file
 
fla/layers/__pycache__/hgrn2.cpython-312.pyc ADDED
Binary file (8.6 kB). View file
 
fla/layers/__pycache__/lightnet.cpython-312.pyc ADDED
Binary file (8.85 kB). View file
 
fla/layers/__pycache__/linear_attn.cpython-312.pyc ADDED
Binary file (7.49 kB). View file
 
fla/layers/__pycache__/multiscale_retention.cpython-312.pyc ADDED
Binary file (12.5 kB). View file
 
fla/layers/__pycache__/nsa.cpython-312.pyc ADDED
Binary file (6.55 kB). View file
 
fla/layers/__pycache__/rebased.cpython-312.pyc ADDED
Binary file (6.75 kB). View file
 
fla/layers/__pycache__/rwkv6.cpython-312.pyc ADDED
Binary file (14.5 kB). View file
 
fla/layers/__pycache__/rwkv7.cpython-312.pyc ADDED
Binary file (10.5 kB). View file
 
fla/layers/abc.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, RotaryEmbedding, ShortConvolution
14
+ from fla.modules.activations import swiglu, swish
15
+ from fla.ops.abc.chunk import chunk_abc
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
class ABCAttention(nn.Module):
    """Attention with Bounded-memory Control (ABC).

    Projects hidden states to q/k/v (optionally through short convolutions),
    optionally applies RoPE, and runs the slot-controlled chunked linear
    attention kernel ``chunk_abc``. The output is passed through a (gated)
    RMS norm and a final linear projection.

    Args:
        hidden_size: Model dimension.
        expand_k: Expansion ratio for the total key dimension.
        expand_v: Expansion ratio for the total value dimension.
        num_heads: Number of attention heads.
        use_short_conv: If True, apply a ShortConvolution to q/k/v.
        conv_size: Kernel size of the short convolution.
        conv_bias: Whether the short convolution uses a bias.
        num_slots: Number of ABC memory slots; defaults to the per-head
            key dimension when not given.
        elementwise_affine: Whether norm layers have learnable affine params.
        norm_eps: Epsilon used by the normalization layers.
        gate_low_rank_dim: Stored for config compatibility (unused here).
        gate_logit_normalizer: Stored for config compatibility (unused here).
        use_rope: If True, apply rotary embeddings to q/k.
        use_input_gate: If True, apply swish to q/k/v before attention.
        use_output_gate: If True, gate the output with a learned projection.
        use_norm: If True, normalize the attention output.
        clamp_min: Lower bound for the slot logits.
        clamp_max: Upper bound for the slot logits.
        layer_idx: Index of this layer; required when a cache is used.
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        num_slots: Optional[int] = None,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_low_rank_dim: int = 16,
        gate_logit_normalizer: int = 16,
        use_rope: bool = True,
        use_input_gate: bool = False,
        use_output_gate: bool = True,
        use_norm: bool = True,
        clamp_min: Optional[float] = -32,
        clamp_max: Optional[float] = 32,
        layer_idx: Optional[int] = None,
        **kwargs
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.key_dim = int(self.hidden_size * self.expand_k)
        self.value_dim = int(self.hidden_size * self.expand_v)
        self.head_k_dim = self.key_dim // self.num_heads
        self.head_v_dim = self.value_dim // self.num_heads

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.gate_low_rank_dim = gate_low_rank_dim
        self.gate_logit_normalizer = gate_logit_normalizer

        self.use_rope = use_rope
        self.use_input_gate = use_input_gate
        self.use_output_gate = use_output_gate
        self.use_norm = use_norm

        # Default the number of memory slots to the per-head key dimension.
        if num_slots is None:
            num_slots = self.head_k_dim
        self.num_slots = num_slots

        self.norm_eps = norm_eps

        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.layer_idx = layer_idx

        if layer_idx is None:
            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)

        if use_output_gate:
            self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        # Slot-logit projection: one logit per (head, slot).
        self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False)
        self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

        if use_short_conv:
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')

        if self.use_norm:
            if self.use_output_gate:
                self.g_norm = FusedRMSNormGated(
                    hidden_size=self.head_v_dim,
                    elementwise_affine=elementwise_affine,
                    eps=norm_eps
                )
            else:
                self.g_norm = RMSNorm(
                    hidden_size=self.head_v_dim,
                    elementwise_affine=elementwise_affine,
                    eps=norm_eps
                )

        if self.use_rope:
            self.rotary = RotaryEmbedding(self.head_k_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Run ABC attention over `hidden_states`.

        Returns a tuple ``(output, None, past_key_values)``; attention
        probabilities are never materialized by the chunked kernel.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if cu_seqlens is not None:
            raise NotImplementedError("Training with cu_seqlens is not supported yet for ABCAttention")
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)

        if self.use_input_gate:
            q, k, v = map(lambda x: swish(x), (q, k, v))
        # Zero out values at (left-)padded positions so they do not leak
        # into the recurrent state.
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])

        q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
        if self.use_rope:
            seqlen_offset = 0
            if past_key_values is not None:
                seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            q, k = self.rotary(q, k, seqlen_offset=seqlen_offset)

        # Slot logits, clamped in-place for numerical stability.
        s = rearrange(self.s_proj(hidden_states), '... (h m) -> ... h m', m=self.num_slots)
        s = s.clamp_(self.clamp_min, self.clamp_max)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        o, recurrent_state = chunk_abc(
            q=q,
            k=k,
            v=v,
            s=s,
            initial_state=recurrent_state,
            output_final_state=use_cache,
            head_first=False
        )
        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_norm and not self.use_output_gate:
            o = self.g_norm(o)
        elif self.use_output_gate:
            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
            o = self.g_norm(o, g) if self.use_norm else swiglu(g, o)
        o = rearrange(o, '... h d -> ... (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values

    def state_size(self, seq_len: int = 2048) -> int:
        # The ABC state is sequence-length independent: two slot-state
        # matrices of shape (num_slots, hidden_size).
        return 2 * self.num_slots * self.hidden_size
fla/layers/attn.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from einops import rearrange
14
+ from transformers.utils import logging
15
+
16
+ from fla.modules import RMSNorm, RotaryEmbedding
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+ try:
22
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
23
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
24
+ except ImportError:
25
+ warnings.warn(
26
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
27
+ category=ImportWarning
28
+ )
29
+ flash_attn_func = None
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
class Attention(nn.Module):
    """Multi-head attention backed by `flash_attn`.

    Supports grouped-query attention (``num_kv_heads < num_heads``), optional
    per-head QK RMS normalization, sliding-window attention, and rotary
    position embeddings. Padded batches are unpadded and dispatched to the
    varlen flash-attention kernel.

    Args:
        hidden_size: Model dimension.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads; defaults to ``num_heads``.
        qkv_bias: Whether q/k/v projections use a bias.
        qk_norm: If True, RMS-normalize q and k per head.
        window_size: Sliding-window size; None disables windowing.
        rope_theta: Base for rotary embeddings.
        max_position_embeddings: Optional lower bound used for RoPE caching.
        layer_idx: Index of this layer; required when a cache is used.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: Optional[int] = None,
        layer_idx: int = None
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm

        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)

        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute causal (optionally windowed) attention.

        Returns ``(output, attentions, past_key_values)``. ``attentions`` is
        always None: flash-attn does not expose attention probabilities.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.size()

        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        # equivalent to cu_seqlens in `flash_attn`
        cu_seqlens = kwargs.get('cu_seqlens', None)

        seqlen_offset, max_seqlen = 0, q_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # to delimit the per-sequence offsets of (left-)padding tokens
                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
                max_seqlen = q.shape[1] + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)
        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

        if past_key_values is not None:
            cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
            k_cached, v_cached = past_key_values.update(
                attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                layer_idx=self.layer_idx,
                offset=q_len,
                cache_kwargs=dict(window_size=self.window_size)
            )['attn_state']
            if cache_has_content:
                k, v = k_cached, v_cached
                k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
                v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        if flash_attn_func is None:
            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_q, max_seqlen_k = max_seq_lens
            o = flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
            o = pad_input(o, indices_q, batch_size, q_len)
        elif cu_seqlens is not None:
            o = flash_attn_varlen_func(
                q.squeeze(0), k.squeeze(0), v.squeeze(0),
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            ).unsqueeze(0)
        else:
            o = flash_attn_func(
                q, k, v,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
        o = o.reshape(batch_size, q_len, -1)
        o = self.o_proj(o)

        # flash-attn never materializes attention probabilities, so this is
        # always None. Assigned unconditionally: the original only set it when
        # `output_attentions` was False, raising a NameError otherwise.
        attentions = None

        return o, attentions, past_key_values

    def _upad_input(self, q, k, v, attention_mask, q_len):
        """Strip padding tokens and build the cu_seqlens metadata expected by
        `flash_attn_varlen_func`. Assumes left padding for the `-q_len:` slice."""
        batch_size, seq_len, num_key_value_heads, head_dim = k.shape
        cache_mask = attention_mask[:, -seq_len:]
        seqlens = cache_mask.sum(-1, dtype=torch.int32)
        indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_k = seqlens.max().item()
        cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

        k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        if q_len == seq_len:
            q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_q = max_seqlen_k
            indices_q = indices_k
        elif q_len == 1:
            max_seqlen_q = 1
            # There is a memcpy here, that is very bad.
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
            indices_q = cu_seqlens_q[:-1]
            q = q.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -q_len:]
            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)

        return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
fla/layers/based.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ Linear attention in Based.
6
+ https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+
13
+ from fla.modules.feature_map import TaylorFeatureMap
14
+ from fla.ops.based import parallel_based
15
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
16
+
17
+
18
class BasedLinearAttention(nn.Module):
    """Linear attention from Based.

    Uses a 2nd-order Taylor feature map on q/k and one of three equivalent
    computation modes: ``'chunk'``, ``'fused_chunk'`` (feature map applied
    explicitly, then chunked linear attention), or ``'parallel'`` (fused
    kernel that applies the feature map internally).

    https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
    """

    def __init__(
        self,
        hidden_size: int,
        feature_dim: int = 16,
        num_key_value_heads: int = 12,
        num_heads: int = 12,
        feature_name: str = "taylor_exp",
        eps: float = 1e-12,
        causal: bool = True,
        mode: str = "parallel",
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.mode = mode
        self.feature_name = feature_name
        self.feature_dim = feature_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // self.num_key_value_heads
        assert self.hidden_size % self.head_dim == 0
        self.causal = causal

        self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.dropout = nn.Identity()
        self.feature_map = TaylorFeatureMap(feature_dim)
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        """Apply Based linear attention; returns a tensor of the input shape."""
        mode = self.mode
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v])
        if mode == "fused_chunk":
            q, k = self.feature_map(q), self.feature_map(k)
            o, _ = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
        elif mode == 'chunk':
            q, k = self.feature_map(q), self.feature_map(k)
            o, _ = chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
        elif mode == 'parallel':
            # the fused parallel kernel applies the feature map itself
            assert q.shape[-1] <= 128
            o = parallel_based(q, k, v, scale=1, use_norm=True, head_first=False)
        else:
            # fail loudly instead of hitting a NameError on `o` below
            raise NotImplementedError(f"Not supported mode `{mode}`.")
        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)
        o = self.dropout(o)
        return o

    # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119

    def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
        """
        Pure-PyTorch reference implementation (quadratic memory in t).

        x (torch.Tensor): tensor of shape (b, d, t)
        y (torch.Tensor): tensor of shape (b, d, t)
        """
        # hidden_states = hidden_states.transpose(1, 2)
        b, t, _ = hidden_states.size()
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        q = q.view(b, t, self.num_heads, self.feature_dim).transpose(1, 2)
        k = k.view(b, t, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
        v = v.view(b, t, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Linear attention
        q, k = self.feature_map(q), self.feature_map(k)
        q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)

        # Compute attention: cumulative sums realize the causal prefix products
        if self.causal:
            y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
        else:
            y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
        y = rearrange(y, 'b h t d -> b t (h d)')
        y = self.o_proj(y.to(hidden_states.dtype))
        y = self.dropout(y)
        return y.to(hidden_states.dtype)
fla/layers/bitattn.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from einops import rearrange
14
+ from transformers.utils import logging
15
+
16
+ from fla.modules import RotaryEmbedding
17
+ from fla.modules.fused_bitlinear import FusedBitLinear
18
+
19
+ if TYPE_CHECKING:
20
+ from fla.models.utils import Cache
21
+
22
+ try:
23
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
24
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
25
+ except ImportError:
26
+ warnings.warn(
27
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
28
+ category=ImportWarning
29
+ )
30
+ flash_attn_func = None
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
class BitAttention(nn.Module):
    """Multi-head attention with 1-bit (BitNet-style) linear projections.

    Identical in structure to the standard flash-attn `Attention` layer, but
    every projection is a `FusedBitLinear`. Supports grouped-query attention,
    sliding windows, and rotary embeddings; padded batches are unpadded and
    dispatched to the varlen flash-attention kernel.

    Args:
        hidden_size: Model dimension.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads; defaults to ``num_heads``.
        window_size: Sliding-window size; None disables windowing.
        rope_theta: Base for rotary embeddings.
        max_position_embeddings: Optional lower bound used for RoPE caching.
        norm_eps: Stored for config compatibility.
        layer_idx: Index of this layer; required when a cache is used.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: Optional[int] = None,
        norm_eps: float = 1e-5,
        layer_idx: int = None
    ):
        super().__init__()

        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.hidden_size = hidden_size
        self.head_dim = self.hidden_size // self.num_heads
        # (the original assigned kv_dim twice; once is enough)
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
        self.v_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
        self.o_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute causal (optionally windowed) attention.

        Returns ``(output, attentions, past_key_values)``. ``attentions`` is
        always None: flash-attn does not expose attention probabilities.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.size()

        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)

        # equivalent to cu_seqlens in `flash_attn`
        cu_seqlens = kwargs.get('cu_seqlens', None)

        seqlen_offset, max_seqlen = 0, q_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # to delimit the per-sequence offsets of (left-)padding tokens
                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
                max_seqlen = q.shape[1] + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)
        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

        if past_key_values is not None:
            cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
            k_cached, v_cached = past_key_values.update(
                attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                layer_idx=self.layer_idx,
                offset=q_len,
                cache_kwargs=dict(window_size=self.window_size)
            )['attn_state']
            if cache_has_content:
                k, v = k_cached, v_cached
                k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
                v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        if flash_attn_func is None:
            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_q, max_seqlen_k = max_seq_lens
            o = flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
            o = pad_input(o, indices_q, batch_size, q_len)
        elif cu_seqlens is not None:
            o = flash_attn_varlen_func(
                q.squeeze(0), k.squeeze(0), v.squeeze(0),
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            ).unsqueeze(0)
        else:
            o = flash_attn_func(
                q, k, v,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
        o = o.reshape(batch_size, q_len, -1)
        o = self.o_proj(o)

        # flash-attn never materializes attention probabilities, so this is
        # always None. Assigned unconditionally: the original only set it when
        # `output_attentions` was False, raising a NameError otherwise.
        attentions = None

        return o, attentions, past_key_values

    def _upad_input(self, q, k, v, attention_mask, q_len):
        """Strip padding tokens and build the cu_seqlens metadata expected by
        `flash_attn_varlen_func`. Assumes left padding for the `-q_len:` slice."""
        batch_size, seq_len, num_key_value_heads, head_dim = k.shape
        cache_mask = attention_mask[:, -seq_len:]
        seqlens = cache_mask.sum(-1, dtype=torch.int32)
        indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_k = seqlens.max().item()
        cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

        k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        if q_len == seq_len:
            q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_q = max_seqlen_k
            indices_q = indices_k
        elif q_len == 1:
            max_seqlen_q = 1
            # There is a memcpy here, that is very bad.
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
            indices_q = cu_seqlens_q[:-1]
            q = q.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -q_len:]
            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)

        return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
fla/layers/gated_deltaproduct.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
12
+ from fla.ops.delta_rule import chunk_delta_rule
13
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule
14
+
15
+ if TYPE_CHECKING:
16
+ from transformers.processing_utils import Unpack
17
+
18
+ from fla.models.utils import Cache
19
+
20
+
21
def elu_p1(x):
    """Shifted ELU feature map: elu(x) + 1, cast back to x's dtype/device."""
    shifted = F.elu(x, 1.0, False) + 1.0
    return shifted.to(x)
23
+
24
+
25
def sum_norm(x):
    """Normalize x so entries along the last dimension sum to 1, keeping dtype."""
    totals = x.sum(-1, keepdim=True)
    return x.div(totals).to(x)
27
+
28
+
29
def interleave_multiple_sequences(*sequences):
    """
    Interleave sequences along the time axis.

    Given [A1, A2], [B1, B2], [C1, C2] this returns
    [A1, B1, C1, A2, B2, C2]: variant t of every input is grouped
    together before variant t+1 of any input.
    """
    # Accept either varargs or a single list/tuple of tensors.
    if isinstance(sequences[0], (list, tuple)):
        sequences = sequences[0]

    n = len(sequences)
    if n == 1:
        return sequences[0]

    # Interleaving only makes sense for identically shaped inputs.
    reference = sequences[0].shape
    assert all(s.shape == reference for s in sequences)

    bsz, length, *tail = reference

    # Stacking at dim=2 places the n variants of each timestep adjacently,
    # so a flat view over (length, n) yields the interleaved order.
    return torch.stack(sequences, dim=2).view(bsz, length * n, *tail)
54
+
55
+
56
class GatedDeltaProduct(nn.Module):
    """
    Generalized version of GatedDoubleDeltaNet that supports arbitrary number of householder transformations.

    Each token is expanded into `num_householder` interleaved sub-steps, each with
    its own k/v/beta projection; the (gated) delta-rule kernel then runs over the
    interleaved sequence and only every `num_householder`-th output is kept.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        expand_v: float = 2,
        head_dim: int = 256,
        num_heads: int = 6,
        num_householder: int = 2,  # New parameter for number of householder transformations
        mode: str = "chunk",
        use_gate: bool = True,
        use_forget_gate: bool = True,  # when true Gated DeltaProduct, when false DeltaProduct
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        layer_idx: int | None = None,
        norm_eps: float = 1e-5,
        allow_neg_eigval: bool = False,  # when true (Gated) DeltaProduct [-1, 1], when false (Gated) DeltaProduct [0, 1]
        **kwargs,
    ) -> None:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_v = expand_v
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.num_householder = num_householder
        self.allow_neg_eigval = allow_neg_eigval
        self.use_forget_gate = use_forget_gate
        self.key_dim = self.num_heads * self.head_dim
        self.value_dim = int(self.key_dim * self.expand_v)
        self.head_qk_dim = head_dim
        self.head_v_dim = int(head_dim * self.expand_v)
        self.layer_idx = layer_idx
        self.silu = nn.SiLU()
        assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."
        # Create multiple projection layers for each householder transformation
        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)

        self.k_projs = nn.ModuleList(
            [
                nn.Linear(hidden_size, self.key_dim, bias=False)
                for _ in range(num_householder)
            ]
        )
        self.v_projs = nn.ModuleList(
            [
                nn.Linear(hidden_size, self.value_dim, bias=False)
                for _ in range(num_householder)
            ]
        )
        self.b_projs = nn.ModuleList(
            [
                nn.Linear(hidden_size, self.num_heads, bias=False)
                for _ in range(num_householder)
            ]
        )
        if use_short_conv:
            # NOTE(review): num_householder q-convolutions are allocated here,
            # but forward() only ever uses self.q_conv1ds[0] — confirm intended.
            self.q_conv1ds = nn.ModuleList(
                [
                    ShortConvolution(
                        hidden_size=self.key_dim,
                        kernel_size=conv_size,
                        activation="silu",
                    )
                    for _ in range(num_householder)
                ]
            )
            self.k_conv1ds = nn.ModuleList(
                [
                    ShortConvolution(
                        hidden_size=self.key_dim,
                        kernel_size=conv_size,
                        activation="silu",
                    )
                    for _ in range(num_householder)
                ]
            )
            self.v_conv1ds = nn.ModuleList(
                [
                    ShortConvolution(
                        hidden_size=self.value_dim,
                        kernel_size=conv_size,
                        activation="silu",
                    )
                    for _ in range(num_householder)
                ]
            )

        if self.use_forget_gate:
            # Decay strength per head: g = -exp(A_log) * softplus(a_proj(x) + dt_bias),
            # following the Mamba-style dt parameterization below.
            self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
            A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
            A_log = torch.log(A)
            self.A_log = nn.Parameter(A_log)
            self.A_log._no_weight_decay = True

            # Initialize dt parameters
            dt_min = 0.001
            dt_max = 0.1
            dt_init_floor = 1e-4
            dt = torch.exp(
                torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
                + math.log(dt_min)
            )
            dt = torch.clamp(dt, min=dt_init_floor)
            # Inverse of softplus so that softplus(dt_bias) == dt at init.
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            self.dt_bias = nn.Parameter(inv_dt)
            self.dt_bias._no_weight_decay = True

        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
        # NOTE(review): k_id appears unused in forward — possibly leftover.
        self.k_id = torch.nn.Identity()
        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        # Skip modules already initialized by the HF loading machinery.
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Run the (gated) delta-product layer.

        Returns (output, None, past_key_values); attentions are never materialized.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding)."
            )

        mode = (
            "chunk"  # 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
        )
        if self.training:
            assert mode == "chunk", "Only chunk mode is supported in training."

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        # Process each householder transformation
        ks, vs, betas = [], [], []
        conv_states = []

        for i in range(self.num_householder):
            if self.use_short_conv:
                conv_state_q, conv_state_k, conv_state_v = None, None, None
                if last_state is not None:
                    conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"][i]
                conv_mask = (
                    attention_mask[:, -hidden_states.shape[1]:]
                    if attention_mask is not None
                    else None
                )

                k, conv_state_k = self.k_conv1ds[i](
                    x=self.k_projs[i](hidden_states),
                    mask=conv_mask,
                    cache=conv_state_k,
                    output_final_state=use_cache,
                )
                v, conv_state_v = self.v_conv1ds[i](
                    x=self.v_projs[i](hidden_states),
                    mask=conv_mask,
                    cache=conv_state_v,
                    output_final_state=use_cache,
                )
                conv_states.append((conv_state_q, conv_state_k, conv_state_v))
            else:
                k = self.silu(self.k_projs[i](hidden_states))
                v = self.silu(self.v_projs[i](hidden_states))

            ks.append(k)
            vs.append(v)

            beta = self.b_projs[i](hidden_states).sigmoid()  # bs, sequence_length, num_heads
            # Zero out beta at padded positions so they do not update the state.
            if attention_mask is not None:
                beta = beta.mul(attention_mask[:, -hidden_states.shape[1]:, None])
            if self.allow_neg_eigval:
                # Stretch beta from [0, 1] to [0, 2], allowing eigenvalues in [-1, 1].
                beta = beta * 2
            betas.append(beta)

        if self.use_short_conv:
            # NOTE(review): uses q_conv1ds[0] with conv_state_q/conv_mask left over
            # from the last loop iteration above — confirm this is intentional.
            q, conv_state_q = self.q_conv1ds[0](
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
            )
        else:
            q = self.silu(self.q_proj(hidden_states))
        # Queries are only active on the last sub-step of each token; earlier
        # sub-steps apply state updates without producing output.
        q = interleave_multiple_sequences(
            [torch.zeros_like(q)] * (self.num_householder - 1) + [q]
        )
        # Interleave all sequences
        k = interleave_multiple_sequences(ks)
        v = interleave_multiple_sequences(vs)
        beta = interleave_multiple_sequences(betas)

        q, k, v = (
            rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in (q, k, v)
        )

        recurrent_state = (
            last_state["recurrent_state"] if last_state is not None else None
        )
        offsets = kwargs.get("offsets")

        if mode == "chunk":
            if self.use_forget_gate:
                g = -self.A_log.float().exp() * F.softplus(
                    self.a_proj(hidden_states).float() + self.dt_bias
                )
                if attention_mask is not None:
                    g = g.mul(attention_mask[:, -g.shape[-2]:, None])

                # Interleave g with zeros for non-first transformations
                g = interleave_multiple_sequences(
                    [g] + [torch.zeros_like(g)] * (self.num_householder - 1)
                )

                o, recurrent_state = chunk_gated_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    g=g,
                    beta=beta,
                    initial_state=recurrent_state,
                    output_final_state=use_cache,
                    cu_seqlens=offsets,
                    head_first=False,
                    use_qk_l2norm_in_kernel=True
                )
            else:
                o, recurrent_state = chunk_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    beta=beta,
                    initial_state=recurrent_state,
                    output_final_state=use_cache,
                    cu_seqlens=offsets,
                    head_first=False,
                    use_qk_l2norm_in_kernel=True
                )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        # Take every nth element for n householder transformations
        o = o[:, self.num_householder - 1:: self.num_householder, :]

        if past_key_values is not None:
            # NOTE(review): q here is (batch, t * num_householder, heads, dim), so
            # q.shape[2] is the head count; sibling layers in this package pass the
            # time-step count as `offset` — verify this value.
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=conv_states if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[2],
            )

        if self.use_gate:
            g = rearrange(
                self.g_proj(hidden_states),
                "... (h d) -> ... h d",
                h=self.num_heads,
            )
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        o = rearrange(o, "b t h d -> b t (h d)")
        o = self.o_proj(o)

        return o, None, past_key_values
fla/layers/gla.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange, repeat
13
+
14
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
15
+ from fla.modules.activations import ACT2FN
16
+ from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.processing_utils import Unpack
20
+
21
+ from fla.models.utils import Cache
22
+
23
+
24
class GatedLinearAttention(nn.Module):
    r"""
    The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635).  # noqa

    Args:
        mode (str, Optional):
            Which GLA kernel to use.
            Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
            Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 0.5.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 1.0.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        num_kv_heads (int, Optional):
            The number of key/value heads, used for MQA. Default: None.
        feature_map (str, Optional):
            Feature map function applied to queries/keys. Default: None.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `False`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        use_output_gate (bool, Optional):
            Whether to use output gate. Default: `True`.
        gate_fn (str, Optional):
            The activation function for the output gate. Default: `swish`.
        elementwise_affine (bool, Optional):
            If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
        gate_logit_normalizer (int, Optional):
            The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
        gate_low_rank_dim (int, Optional):
            The low rank dim for the gate projection. Default: 16.
        clamp_min (float, Optional):
            The minimum value for the gate logits. Default: None.
        fuse_norm (bool, Optional):
            Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        use_output_gate: bool = True,
        gate_fn: str = 'swish',
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_logit_normalizer: int = 16,
        gate_low_rank_dim: int = 16,
        clamp_min: Optional[float] = None,
        fuse_norm: bool = True,
        layer_idx: int = None,
    ) -> None:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.use_output_gate = use_output_gate

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.clamp_min = clamp_min
        self.layer_idx = layer_idx

        # Fixed typo in the original message ("suppoerted").
        assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
        if self.use_output_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
            self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        # Low-rank parameterization of the forget-gate logits.
        self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
                                     nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True))
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        if gate_fn == 'swish' and fuse_norm and use_output_gate:
            self.g_norm_swish_gate = FusedRMSNormGated(
                hidden_size=self.head_v_dim,
                elementwise_affine=elementwise_affine,
                eps=norm_eps
            )
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(
                hidden_size=self.head_v_dim,
                elementwise_affine=elementwise_affine,
                eps=norm_eps
            )
            self.gate_fn = ACT2FN[gate_fn]

        self.gate_logit_normalizer = gate_logit_normalizer

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Run gated linear attention.

        Returns (output, None, past_key_values); attention weights are never
        materialized by the linear-attention kernels.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
        gk = self.gk_proj(hidden_states)

        if self.feature_map_fn is not None:
            q, k = map(self.feature_map_fn, (q, k))
        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
        q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
        if self.num_kv_groups > 1:
            # MQA/GQA: replicate kv heads across groups to match query heads.
            k, gk = (repeat(x, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_k_dim) for x in (k, gk))
            v = repeat(v, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_v_dim)
        else:
            k, gk = (rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim) for x in (k, gk))
            v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
        # Gate logits in log space, scaled down to keep decay rates moderate.
        gk = F.logsigmoid(gk) / self.gate_logit_normalizer

        if self.clamp_min is not None:
            gk = torch.clamp_min(gk, self.clamp_min)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gla(
                q=q,
                k=k,
                v=v,
                gk=gk,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'fused_chunk':
            o, recurrent_state = fused_chunk_gla(
                q=q,
                k=k,
                v=v,
                g=gk,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                head_first=False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_gla(
                q=q,
                k=k,
                v=v,
                g=gk,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_output_gate:
            g = self.g_proj(hidden_states)
            if self.fuse_norm_and_gate:
                g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim)
                o = self.g_norm_swish_gate(o, g)
                o = rearrange(o, 'b t h d -> b t (h d)')
            else:
                o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
                o = o * self.gate_fn(g)
        else:
            o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        """Return the per-sequence recurrent state size in elements."""
        state_size = self.key_dim * self.head_v_dim
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
fla/layers/gsa.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+
14
+ from fla.modules import RMSNorm, ShortConvolution
15
+ from fla.modules.feature_map import ReLUFeatureMap, SwishFeatureMap, T2RFeatureMap
16
+ from fla.modules.layernorm import rms_norm_linear
17
+ from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from fla.models.utils import Cache
23
+
24
+
25
class GatedSlotAttention(nn.Module):
    """Gated Slot Attention layer.

    Projects the input to q/k/v plus a per-slot forget gate `f`, optionally
    applies short convolutions and a feature map to q/k, then runs the GSA
    kernel (`chunk_gsa` or `fused_recurrent_gsa`) over `num_slots` memory slots.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.,
        expand_v: float = 1.,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        num_slots: Optional[int] = None,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_logit_normalizer: int = 8,
        feature_map: str = 'swish',
        use_output_gate: bool = False,
        use_norm: bool = True,
        layer_idx: Optional[int] = None,
        scale: Optional[float] = 1.,
        **kwargs
    ) -> None:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.head_k_dim = self.key_dim // self.num_heads
        self.head_v_dim = self.value_dim // self.num_heads

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.gate_logit_normalizer = gate_logit_normalizer

        self.use_output_gate = use_output_gate
        self.use_norm = use_norm
        self.scale = scale

        if num_slots is None:
            num_slots = self.head_k_dim
        self.num_slots = num_slots

        self.layer_idx = layer_idx

        if layer_idx is None:
            # Fixed the original message, which read "will to errors".
            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        self.register_module('feature_map', None)
        if feature_map == 'swish':
            self.feature_map = SwishFeatureMap()
        elif feature_map == 'relu':
            self.feature_map = ReLUFeatureMap()
        elif feature_map == 't2r':
            self.feature_map = T2RFeatureMap(self.head_k_dim, self.head_k_dim)
        else:
            raise NotImplementedError(f"Feature map `{feature_map}` is not supported now.")

        self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
        self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
            self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Run gated slot attention.

        Returns (output, None, past_key_values); attention weights are never
        materialized by the GSA kernels.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
        f = self.f_proj(hidden_states)

        q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
        k = rearrange(k, 'b t (h d) -> b t h d', d=self.head_k_dim)
        v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
        f = rearrange(f, 'b t (h m) -> b t h m', m=self.num_slots)

        if self.feature_map is not None:
            q, k = map(lambda x: self.feature_map(x), (q, k))
        v = F.silu(v)

        # Forget gate in log space; s = 1 - exp(f) is the slot write strength.
        f = F.logsigmoid(f) / self.gate_logit_normalizer
        s = (1 - f.exp()).to(f.dtype)
        # dealing with left-padding
        if attention_mask is not None:
            s = s.mul_(attention_mask[:, -s.shape[1]:, None, None])
            v = v.mul_(attention_mask[:, -v.shape[1]:, None, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gsa(
                q=q,
                k=k,
                v=v,
                s=s,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                scale=self.scale,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_gsa(
                q=q,
                k=k,
                v=v,
                s=s,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                scale=self.scale,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        o = rearrange(o, 'b t h d -> b t (h d)')
        # Fused silu -> RMSNorm -> output projection.
        o = rms_norm_linear(F.silu(o), self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
        return o, None, past_key_values

    def state_size(self, *args, **kwargs) -> int:
        """Return the per-sequence recurrent state size in elements."""
        return 2 * self.num_slots * self.hidden_size
fla/layers/hgrn.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from fla.modules import FusedRMSNormGated, ShortConvolution
15
+ from fla.modules.activations import swiglu
16
+ from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.processing_utils import Unpack
20
+
21
+ from fla.models.utils import Cache
22
+
23
+
24
class HGRNAttention(nn.Module):
    """HGRN token-mixing layer.

    Implements "Hierarchically Gated Recurrent Neural Network for Sequence
    Modeling" (https://arxiv.org/abs/2311.04823): an elementwise gated linear
    recurrence over the sequence dimension with an optional per-layer lower
    bound on the forget gate and an output gate fused into a gated RMSNorm.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_ratio: Optional[int] = 1,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None
    ) -> HGRNAttention:
        """
        Args:
            mode: Which kernel to use, one of `chunk` or `fused_recurrent`.
            hidden_size: Input/output feature dimension.
            expand_ratio: Expansion ratio for the internal recurrence dimension.
            use_short_conv: Whether to apply short convolutions to the i/f branches.
            conv_size: Kernel size of the short convolution.
            conv_bias: Whether the short convolution uses a bias (stored, unused here).
            elementwise_affine: Whether the gated RMSNorm has learnable parameters.
            norm_eps: Epsilon of the gated RMSNorm.
            layer_idx: Index of this layer in the model, used for cache lookup.
        """
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_ratio = expand_ratio
        self.input_dim = int(hidden_size * expand_ratio)

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.layer_idx = layer_idx

        # fix: typo in the error message ("suppoerted" -> "supported")
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."

        self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
        self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
        self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            # NOTE(review): `q_conv1d` is never called in `forward` (only the
            # i/f branches are convolved); kept to preserve the parameter /
            # checkpoint layout — confirm whether it can be dropped upstream.
            self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
            self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
            self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)

        self.g_norm = FusedRMSNormGated(
            hidden_size=self.input_dim,
            elementwise_affine=elementwise_affine,
            eps=norm_eps
        )
        self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        lower_bound: Optional[torch.Tensor] = None,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Run the HGRN recurrence over `hidden_states`.

        Args:
            hidden_states: Inputs of shape [batch_size, seq_len, hidden_size].
            attention_mask: Optional 0-1 padding mask of shape [batch_size, seq_len].
            past_key_values: Optional cache holding recurrent/conv states per layer.
            use_cache: Whether to emit final states for incremental decoding.
            output_attentions: Unused; kept for interface compatibility.
            lower_bound: Optional lower bound for the forget gate (zero for layer 0).

        Returns:
            Tuple of (output, None, past_key_values); the middle slot mirrors
            the attentions slot of standard attention layers.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if not self.training and hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_i, conv_state_f = None, None
            if last_state is not None:
                conv_state_i, conv_state_f = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            i, conv_state_i = self.i_conv1d(
                x=self.i_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_i,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            f, conv_state_f = self.f_conv1d(
                x=self.f_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_f,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            i = self.i_proj(hidden_states)
            f = self.f_proj(hidden_states)

        # the lower bound for the first layer is zero
        if lower_bound is None or self.layer_idx == 0:
            i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
        else:
            # interpolate the forget gate into [lower_bound, 1] before taking logs
            g = lower_bound + (1 - lower_bound) * f.sigmoid()
            i, f = swiglu(i, 1 - g), g.log()

        # dealing with left-padding: zero out inputs at padded positions
        if attention_mask is not None:
            i = i.mul_(attention_mask[:, -i.shape[-2]:, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'chunk':
            if cu_seqlens is not None:
                raise NotImplementedError("Chunk mode does not support variable-length sequences.")
            o, recurrent_state = chunk_hgrn(
                x=i,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
            )
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_hgrn(
                x=i,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                # fix: `i` is [batch_size, seq_len, input_dim], so the number of
                # tokens consumed is dim 1. The original `i.shape[2]` advanced
                # the cache offset by the feature dim; sibling layers in this
                # package use `q.shape[1]` (token count) here.
                offset=i.shape[1]
            )

        o = self.g_norm(o, self.g_proj(hidden_states))
        o = self.o_proj(o)

        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        """Return the per-sequence state size (recurrent plus conv states).

        NOTE(review): uses `hidden_size`, which equals the recurrence dim only
        when `expand_ratio == 1`; for other ratios this likely undercounts
        (`input_dim` would be exact) — confirm against the kernel state shape.
        """
        state_size = self.hidden_size
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
fla/layers/hgrn2.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from fla.modules import RMSNorm, ShortConvolution
16
+ from fla.modules.activations import swish
17
+ from fla.modules.layernorm import rms_norm_linear
18
+ from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
19
+
20
+ if TYPE_CHECKING:
21
+ from transformers.processing_utils import Unpack
22
+
23
+ from fla.models.utils import Cache
24
+
25
+
26
class HGRN2Attention(nn.Module):
    """HGRN2 token-mixing layer.

    Implements "HGRN2: Gated Linear RNNs with State Expansion"
    (https://arxiv.org/abs/2404.07904). The forget gate is expanded to
    `num_heads * expand_ratio` dims and the layer is computed through the
    GLA kernels with q = swish(q_proj), k = 1 - sigmoid(f), gk = log(sigmoid(f)).
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        num_heads: Optional[int] = None,
        expand_ratio: Optional[int] = 128,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None
    ) -> HGRN2Attention:
        """
        Args:
            mode: Which kernel to use: `chunk`, `fused_recurrent` or `fused_chunk`.
            hidden_size: Input/output feature dimension.
            num_heads: Number of heads; derived from `expand_ratio` when omitted.
            expand_ratio: Per-head forget-gate dimension; derived from `num_heads`
                when omitted. At least one of the two must be given.
            use_short_conv: Whether to apply short convolutions to q/f/i branches.
            conv_size: Kernel size of the short convolution.
            conv_bias: Whether the short convolution uses a bias (stored, unused here).
            elementwise_affine: Whether the RMSNorm has learnable parameters.
            norm_eps: Epsilon of the RMSNorm.
            layer_idx: Index of this layer in the model, used for cache lookup.
        """
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size

        # resolve the missing one of (expand_ratio, num_heads) from the other
        if expand_ratio is None and num_heads is not None:
            expand_ratio = hidden_size // num_heads
        elif expand_ratio is not None and num_heads is None:
            num_heads = hidden_size // expand_ratio
        elif expand_ratio is None and num_heads is None:
            raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
        self.num_heads = num_heads
        self.expand_ratio = expand_ratio

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        # q/f live in the expanded "forget" space; i keeps the model width
        self.forget_dim = int(self.num_heads * self.expand_ratio)
        self.input_dim = hidden_size
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`."
        assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}"
        assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}"

        self.head_f_dim = self.expand_ratio
        self.head_i_dim = self.hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
        self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
        self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
            self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
            self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)

        self.g_norm = RMSNorm(hidden_size=self.hidden_size, elementwise_affine=elementwise_affine, eps=norm_eps)
        self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        lower_bound: Optional[torch.Tensor] = None,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Run the HGRN2 recurrence over `hidden_states`.

        Args:
            hidden_states: Inputs of shape [batch_size, seq_len, hidden_size].
            attention_mask: Optional 0-1 padding mask of shape [batch_size, seq_len].
            past_key_values: Optional cache holding recurrent/conv states per layer.
            use_cache: Whether to emit final states for incremental decoding.
            output_attentions: Unused; kept for interface compatibility.
            lower_bound: Optional lower bound for the forget gate (zero for layer 0).

        Returns:
            Tuple of (output, None, past_key_values).
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_f, conv_state_i = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_f, conv_state_i = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            f, conv_state_f = self.f_conv1d(
                x=self.f_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_f,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            i, conv_state_i = self.i_conv1d(
                x=self.i_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_i,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            f = self.f_proj(hidden_states)
            i = self.i_proj(hidden_states)

        # dealing with left-padding: zero out values at padded positions
        if attention_mask is not None:
            i = i.mul_(attention_mask[:, -i.shape[-2]:, None])

        q = swish(q)

        # improve precision
        f = f.float()

        # the lower bound for the first layer is zero
        if lower_bound is None or self.layer_idx == 0:
            k, g = 1 - f.sigmoid(), F.logsigmoid(f)
        else:
            # interpolate the forget gate into [lower_bound, 1] before taking logs
            g = lower_bound + (1 - lower_bound) * f.sigmoid()
            k, g = 1 - g, g.log()

        # split heads: q/k/g -> [..., h, head_f_dim], i -> [..., h, head_i_dim]
        q, k, g = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k.to(i), g))
        i = rearrange(i, '... (h d) -> ... h d', d=self.head_i_dim)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gla(
                q=q,
                k=k,
                v=i,
                gk=g,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'fused_chunk':
            # NOTE(review): unlike the other branches, `cu_seqlens` is not
            # forwarded here — variable-length batches presumably unsupported
            # by the fused-chunk kernel; confirm.
            o, recurrent_state = fused_chunk_gla(
                q=q,
                k=k,
                v=i,
                g=g,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                head_first=False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_gla(
                q=q,
                k=k,
                v=i,
                g=g,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_f, conv_state_i) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                # q is [batch, seq_len, heads, dim] here, so dim 1 is the token count
                offset=q.shape[1]
            )

        # merge heads and apply the fused RMSNorm + output projection
        o = rearrange(o, '... h d -> ... (h d)')
        o = rms_norm_linear(o, self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        """Return the per-sequence state size (recurrent plus conv states)."""
        state_size = self.forget_dim * self.head_i_dim
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
fla/layers/lightnet.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # ["You Only Scan Once: Efficient Multi-dimension Sequential Modeling with LightNet"](https://arxiv.org/abs/2405.21022)
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from fla.modules import FusedRMSNormGated, ShortConvolution
16
+ from fla.modules.fused_norm_gate import rms_norm_swish_gate_linear
17
+ from fla.ops.gla import chunk_gla, fused_recurrent_gla
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from fla.models.utils import Cache
23
+
24
+
25
class LightNetAttention(nn.Module):
    """LightNet token-mixing layer.

    Implements "You Only Scan Once: Efficient Multi-dimension Sequential
    Modeling with LightNet" (https://arxiv.org/abs/2405.21022). Keys are
    normalized with a running `logcumsumexp` over time (a cumulative softmax)
    and the layer is computed through the GLA kernels.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        num_heads: Optional[int] = None,
        expand_ratio: Optional[int] = 128,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        gate_low_rank_dim: int = 128,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None
    ) -> LightNetAttention:
        """
        Args:
            mode: Which kernel to use (see NOTE on the assert below).
            hidden_size: Input/output feature dimension.
            num_heads: Number of heads; derived from `expand_ratio` when omitted.
            expand_ratio: Per-head key dimension; derived from `num_heads` when
                omitted. At least one of the two must be given.
            use_short_conv: Whether to apply short convolutions to q/k/v branches.
            conv_size: Kernel size of the short convolution.
            conv_bias: Whether the short convolution uses a bias (stored, unused here).
            gate_low_rank_dim: Bottleneck dim of the low-rank output gate.
            elementwise_affine: Whether the gated RMSNorm has learnable parameters.
            norm_eps: Epsilon of the gated RMSNorm.
            layer_idx: Index of this layer in the model, used for cache lookup.
        """
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size

        # resolve the missing one of (expand_ratio, num_heads) from the other
        if expand_ratio is None and num_heads is not None:
            expand_ratio = hidden_size // num_heads
        elif expand_ratio is not None and num_heads is None:
            num_heads = hidden_size // expand_ratio
        elif expand_ratio is None and num_heads is None:
            raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
        self.num_heads = num_heads
        self.expand_ratio = expand_ratio

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.key_dim = int(self.num_heads * self.expand_ratio)
        self.value_dim = hidden_size
        self.gate_low_rank_dim = gate_low_rank_dim
        self.layer_idx = layer_idx

        # NOTE(review): this assert accepts 'fused_chunk', but forward() only
        # dispatches 'chunk' and 'fused_recurrent' — with mode='fused_chunk'
        # and seq_len > 64 forward() raises NotImplementedError. Confirm which
        # side is intended.
        assert mode in ['chunk', 'fused_chunk'], f"Not suppoerted mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_f_dim = self.expand_ratio
        self.head_i_dim = self.hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None)
            self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None)
            self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation=None)

        # low-rank output gate: hidden -> gate_low_rank_dim -> hidden
        self.g_proj = nn.Sequential(
            nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
            nn.Linear(gate_low_rank_dim, hidden_size, bias=False)
        )
        self.g_norm = FusedRMSNormGated(
            hidden_size=hidden_size,
            elementwise_affine=elementwise_affine,
            eps=norm_eps
        )
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Run the LightNet recurrence over `hidden_states`.

        Args:
            hidden_states: Inputs of shape [batch_size, seq_len, hidden_size].
            attention_mask: Optional 0-1 padding mask of shape [batch_size, seq_len].
            past_key_values: Optional cache holding recurrent/conv states per layer.
            use_cache: Whether to emit final states for incremental decoding.
            output_attentions: Unused; kept for interface compatibility.

        Returns:
            Tuple of (output, None, past_key_values).
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)

        # dealing with left-padding: zero out values at padded positions
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])

        q = F.silu(q)
        # split heads: q/k -> [..., h, head_f_dim], v -> [..., h, head_i_dim]
        q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k))
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_i_dim)
        # TODO: this 2 steps took huge amount of time, which should be optimized
        # running log-normalizer over the time dim (cumulative log-sum-exp of keys)
        z = k.float().logcumsumexp(1)

        if cu_seqlens is not None:
            raise NotImplementedError("LightNet does not support variable-length sequences for now.")
        # k becomes cumulative-softmax weights exp(k - z); g is the log-decay
        # z_{t-1} - z_t (with z_{-1} := z_0, so the first step has zero decay)
        k, g = torch.exp(k - z).to(k.dtype), (torch.cat((z[:, :1], z[:, :-1]), 1) - z).to(k.dtype)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gla(
                q=q,
                k=k,
                v=v,
                gk=g,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_gla(
                q=q,
                k=k,
                v=v,
                g=g,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                # q is [batch, seq_len, heads, dim] here, so dim 1 is the token count
                offset=q.shape[1]
            )

        # merge heads, then fused (RMSNorm ∘ swish-gate ∘ output projection)
        o = rms_norm_swish_gate_linear(
            rearrange(o, 'b t h d -> b t (h d)'),
            self.g_proj(hidden_states),
            self.g_norm.weight,
            self.g_norm.bias,
            self.o_proj.weight,
            self.o_proj.bias
        )
        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        """Return the per-sequence state size (recurrent plus conv states)."""
        state_size = self.key_dim * self.head_i_dim
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
fla/layers/linear_attn.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from fla.modules import RMSNorm
12
+ from fla.modules.feature_map import DPFPFeatureMap, HadamardFeatureMap, HedgehogFeatureMap, T2RFeatureMap
13
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn
14
+
15
+
16
class LinearAttention(nn.Module):
    """Multi-head linear attention with pluggable feature maps.

    Queries and keys are passed through a kernel feature map (hedgehog, t2r,
    elementwise product, dpfp, elu, relu or identity) and the softmax-free
    attention is evaluated by chunked, fused-chunk or fused-recurrent kernels.
    Supports grouped KV heads (MQA/GQA) via `num_kv_heads`.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: str = 1024,
        expand_k: int = 1.0,
        expand_v: int = 1.0,
        num_heads: int = 8,
        num_kv_heads: Optional[int] = None,
        feature_map: str = 'elementwise_product',
        tie_feature_map_qk: bool = False,
        output_norm: str = 'rmsnorm',
        norm_q: bool = False,
        norm_k: bool = False,
        do_feature_map_norm: bool = False,
        elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        **kwargs
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.mode = mode
        self.num_heads = num_heads
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups

        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads
        self.do_feature_map_norm = do_feature_map_norm

        # learnable feature maps that come in q/k pairs and may be tied
        paired_maps = {
            'hedgehog': HedgehogFeatureMap,
            't2r': T2RFeatureMap,
            'elementwise_product': HadamardFeatureMap,
        }
        if feature_map in paired_maps:
            make = paired_maps[feature_map]
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = make(head_dim=self.head_k_dim)
            else:
                self.feature_map_q = make(head_dim=self.head_k_dim)
                self.feature_map_k = make(head_dim=self.head_k_dim)
        elif feature_map == 'dpfp':
            # dpfp is never tied
            self.feature_map_q = DPFPFeatureMap(head_dim=self.head_k_dim)
            self.feature_map_k = DPFPFeatureMap(head_dim=self.head_k_dim)
        elif feature_map == 'elu':
            def elu(x):
                # shifted ELU keeps features strictly positive
                return F.elu(x) + 1
            self.feature_map_q = elu
            self.feature_map_k = elu
        elif feature_map == 'relu':
            self.feature_map_q = nn.ReLU()
            self.feature_map_k = nn.ReLU()
        elif feature_map == 'identity':
            self.feature_map_q = nn.Identity()
            self.feature_map_k = nn.Identity()
        else:
            raise NotImplementedError(f"Not supported feature map `{feature_map}`.")

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)

        if output_norm == 'rmsnorm':
            self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
        elif output_norm == 'identity':
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f"Not supported output norm `{output_norm}`.")

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        self.norm_q = norm_q
        self.norm_k = norm_k

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Apply linear attention to `hidden_states` and return the mixed output.

        The final recurrent state produced by the kernel is discarded; this
        layer does not participate in incremental-decoding caches.
        """
        mode = self.mode

        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_k_dim)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        if self.num_kv_groups > 1:
            # GQA: replicate each kv head across its group of query heads
            k = repeat(k, '... (h d) -> ... (h g) d', d=self.head_k_dim, g=self.num_kv_groups)
            v = repeat(v, '... (h d) -> ... (h g) d', d=self.head_v_dim, g=self.num_kv_groups)
        else:
            k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
            v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)

        q = self.feature_map_q(q)
        k = self.feature_map_k(k)

        # optional L1-style normalization of the mapped features
        if self.norm_q:
            q = q / (q.sum(-1, True) + 1e-4)
        if self.norm_k:
            k = k / (k.sum(-1, True) + 1e-4)

        if mode == 'chunk':
            o, final_state = chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=self.do_feature_map_norm,
                head_first=False
            )
        elif mode == 'fused_chunk':
            o, final_state = fused_chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=self.do_feature_map_norm,
            )
        elif mode == 'fused_recurrent':
            o, final_state = fused_recurrent_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=self.do_feature_map_norm,
            )
        else:
            raise NotImplementedError
        return self.o_proj(rearrange(self.norm(o), '... h d -> ... (h d)'))
fla/layers/multiscale_retention.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange, repeat
11
+ from transformers.activations import ACT2FN
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.modules.rotary import RotaryEmbedding
15
+ from fla.ops.retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ class MultiScaleRetention(nn.Module):
22
+ r"""
23
+ The layer implementaion for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf). # noqa
24
+
25
+ Args:
26
+ mode (str, Optional):
27
+ Which Retention kernel to use.
28
+ Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`.
29
+ Default: `chunk`.
30
+ hidden_size (int, Optional):
31
+ The hidden size of the input. Default: 1024.
32
+ expand_k (float, Optional):
33
+ The expansion ratio for the key dim. Default: 1.0.
34
+ expand_v (float, Optional):
35
+ The expansion ratio for the value dim. Default: 2.0.
36
+ num_heads (int, Optional):
37
+ The number of heads. Default: 8.
38
+ num_kv_heads (int, Optional):
39
+ The number of key/value heads, used for MQA. Default: None.
40
+ feature_map (str, Optional):
41
+ Feature map function applied to queries/keys. Default: None.
42
+ use_short_conv (bool, Optional):
43
+ Whether to use short convolutions. Default: `False`.
44
+ conv_size (int, Optional):
45
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
46
+ conv_bias (bool, Optional):
47
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
48
+ use_output_gate (bool, Optional):
49
+ Whether to use output gate. Default: `True`.
50
+ gate_fn (str, Optional):
51
+ The activation function for the output gate. Default: `swish`.
52
+ elementwise_affine (bool, Optional):
53
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
54
+ norm_eps (float, Optional):
55
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
56
+ fuse_norm (bool, Optional):
57
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
58
+ layer_idx (int, Optional):
59
+ The index of the layer. Default: None.
60
+ """
61
+
62
    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 2.0,
        num_heads: int = 8,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        use_output_gate: bool = True,
        gate_fn: str = 'swish',
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        fuse_norm: bool = True,
        layer_idx: int = None,
        **kwargs
    ) -> MultiScaleRetention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        # MQA/GQA support: fall back to standard MHA when num_kv_heads is not given
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        # NOTE(review): conv_bias is stored but never forwarded to the ShortConvolution
        # modules below — confirm whether that is intentional
        self.conv_bias = conv_bias
        self.use_output_gate = use_output_gate

        # queries/keys and values live in separately expanded spaces
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        # per-KV-group dims used by the k/v projections (queries always use the full key_dim)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
        if self.use_output_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
            self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        # fuse RMSNorm with the swish output gate when possible for a smaller memory footprint
        if gate_fn == 'swish' and fuse_norm and use_output_gate:
            self.g_norm_swish_gate = FusedRMSNormGated(
                hidden_size=self.head_v_dim,
                elementwise_affine=elementwise_affine,
                eps=norm_eps
            )
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(
                hidden_size=self.head_v_dim,
                elementwise_affine=elementwise_affine,
                eps=norm_eps
            )
            self.gate_fn = ACT2FN[gate_fn]

        # TODO: fix this issue
        # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180
        # Ideally, we would want to support arbitrary d_head_qk
        assert self.head_k_dim <= 256, "head_k_dim must be less than or equal to 256"
        self.rotary = RotaryEmbedding(dim=self.head_k_dim)
147
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """
        Apply multi-scale retention to `hidden_states` of shape [batch_size, seq_len, hidden_size].

        Returns a tuple `(output, None, past_key_values)`; the middle slot exists for
        interface parity with attention layers that can return attention weights.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        # fetch this layer's cached state (recurrent + conv), if any
        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            # project then run the short causal conv (with silu), threading the conv cache
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)

        # dealing with left-padding: zero out values at padded positions (in place)
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
        q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
        k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
        if self.feature_map_fn is not None:
            q, k = map(self.feature_map_fn, (q, k))

        # rotary offsets: account for previously cached tokens (and left padding)
        seqlen_offset, max_seqlen = 0, q.shape[1]
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # to deliminate the offsets of padding tokens
                # (per-sample offsets become a tensor here)
                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
                max_seqlen = q.shape[1] + max(seqlen_offset)

        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

        # expand kv heads to match query heads for grouped-query attention
        if self.num_kv_groups > 1:
            k = repeat(k, 'b t h d -> b t (h g) d', g=self.num_kv_groups)
            v = repeat(v, 'b t (h d) -> b t (h g) d', d=self.head_v_dim, g=self.num_kv_groups)
        else:
            v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'chunk':
            o, recurrent_state = chunk_retention(
                q=q,
                k=k,
                v=v,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'fused_chunk':
            o, recurrent_state = fused_chunk_retention(
                q=q,
                k=k,
                v=v,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'parallel':
            # NOTE: the parallel kernel takes no initial_state, so any cached
            # recurrent state is not consumed in this mode
            o, recurrent_state = parallel_retention(
                q=q,
                k=k,
                v=v,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_retention(
                q=q,
                k=k,
                v=v,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        # persist recurrent/conv state and advance the cache offset by the new tokens
        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        # output gating + per-head normalization, fused when configured
        if self.use_output_gate:
            g = self.g_proj(hidden_states)
            if self.fuse_norm_and_gate:
                g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim)
                o = self.g_norm_swish_gate(o, g)
                o = rearrange(o, 'b t h d -> b t (h d)')
            else:
                o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
                o = o * self.gate_fn(g)
        else:
            o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values
292
+
293
+ def state_size(self, **kwargs) -> int:
294
+ state_size = self.key_dim * self.head_v_dim
295
+ for module in self.children():
296
+ if isinstance(module, ShortConvolution):
297
+ state_size += module.state_size
298
+ return state_size
fla/layers/rebased.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from einops import rearrange
15
+
16
+ from fla.modules.feature_map import RebasedFeatureMap
17
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
18
+ from fla.ops.rebased import parallel_rebased
19
+
20
+
21
class ReBasedLinearAttention(nn.Module):
    """
    ReBased linear attention layer with a learnable (gamma/beta) quadratic
    feature map over queries and keys.
    """

    def __init__(
        self,
        hidden_size: int,
        l_max: int = 2048,
        feature_dim: int = 16,
        num_key_value_heads: int = 16,
        num_heads: int = 16,
        use_gamma: Optional[bool] = True,
        use_beta: Optional[bool] = True,
        normalize: Optional[bool] = True,
        causal: bool = True,
        eps: float = 1e-5,
        mode: str = "parallel",
        layer_idx: Optional[int] = None,
        **kwargs
    ) -> ReBasedLinearAttention:
        super().__init__()
        self.hidden_size = hidden_size
        self.l_max = l_max
        self.mode = mode
        assert self.mode in ["fused_chunk", "parallel", 'chunk']

        # q/k are projected to a small feature_dim per head; v keeps head_dim
        self.feature_dim = feature_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // self.num_key_value_heads
        self.use_gamma = use_gamma
        self.use_beta = use_beta
        self.normalize = normalize
        self.causal = causal
        self.eps = eps
        self.mode = mode
        self.layer_idx = layer_idx

        self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
        self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        # no-op placeholder, kept so callers/configs expecting a dropout attr still work
        self.dropout = nn.Identity()

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        """Apply ReBased linear attention to [batch_size, seq_len, hidden_size] inputs."""
        mode = self.mode
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v])
        # the parallel kernel consumes unflattened (outer-product) features
        q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
        # NOTE(review): if these kernels return an (output, final_state) tuple, `o` would
        # be a tuple here and o_proj would fail — confirm against fla.ops.linear_attn.
        # Also, an unexpected `mode` leaves `o` unbound (NameError); __init__'s assert is
        # the only guard.
        if mode == "fused_chunk":
            o = fused_chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=True,
                scale=1,
                head_first=False
            )
        elif mode == 'chunk':
            o = chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=True,
                scale=1,
                head_first=False
            )
        elif mode == 'parallel':
            assert q.shape[-1] <= 128
            o = parallel_rebased(
                q=q,
                k=k,
                v=v,
                eps=self.eps,
                use_scale=True,
                use_normalize=True,
                head_first=False
            )
        o = self.o_proj(o)
        o = self.dropout(o)
        return o

    # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
    def forward_reference(
        self,
        hidden_states: torch.Tensor,
        filters: torch.Tensor = None,
        *args,
        **kwargs
    ):
        """
        Pure-PyTorch reference implementation (no fused kernels).

        hidden_states (torch.Tensor): tensor of shape (b, t, d)
        Returns a tensor of shape (b, t, d).
        """
        b, t, _ = hidden_states.size()
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        # [b, h, t, feature_dim] / [b, h, t, head_dim]
        q = q.view(b, t, -1, self.feature_dim).transpose(1, 2)
        k = k.view(b, t, -1, self.feature_dim).transpose(1, 2)
        v = v.view(b, t, -1, self.head_dim).transpose(1, 2)

        # Linear attention
        q, k = self.feature_map(q), self.feature_map(k)
        q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)

        # Compute attention: cumulative (causal) or global (non-causal) KV aggregation,
        # normalized by the running key sum
        if self.causal:
            y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
        else:
            y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
        y = rearrange(y, 'b h t d -> b t (h d)')
        y = self.o_proj(y.to(hidden_states.dtype))
        y = self.dropout(y)
        return y.to(hidden_states.dtype)
fla/layers/rwkv6.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence"[https://arxiv.org/abs/2404.05892]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange
13
+
14
+ from fla.modules import GroupNorm
15
+ from fla.modules.activations import ACT2FN
16
+ from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
class RWKV6Attention(nn.Module):
    """
    RWKV-6 token-mixing layer from "Eagle and Finch: RWKV with Matrix-Valued States
    and Dynamic Recurrence" (https://arxiv.org/abs/2404.05892).

    Projections (r/w/k/v/g) are data-dependent lerps between the current token and
    its time-shifted predecessor (see `LerpLinear`/`DDLerpLinear` in this module).
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        gate_fn: str = 'swish',
        proj_low_rank_dim: int = 32,
        gate_low_rank_dim: int = 64,
        fuse_norm: bool = True,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None,
        **kwargs
    ) -> RWKV6Attention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.proj_low_rank_dim = proj_low_rank_dim
        self.gate_low_rank_dim = gate_low_rank_dim

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        # shift the sequence right by one position (zero-pad the first token)
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        # shared low-rank bottleneck producing the 5 per-branch mixing signals
        self.x_proj = nn.Sequential(
            LerpLinear(hidden_size, proj_low_rank_dim * 5),
            nn.Tanh(),
            nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=False)
        )
        self.x_bias = nn.Parameter(torch.zeros(5, hidden_size))

        self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
        self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
        self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
        self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
        self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
        # per-head "bonus" (u) applied to the current token in the RWKV recurrence
        self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_k_dim))

        # TODO: fuse GroupNorm and output gate
        self.g_norm = GroupNorm(self.num_heads, self.value_dim, elementwise_affine=elementwise_affine, bias=True, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
        self.gate_fn = ACT2FN[gate_fn]

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        # NOTE(review): nn.Module.apply only visits submodules, never nn.Parameter
        # objects, so this branch appears unreachable — kept for safety.
        if isinstance(module, nn.Parameter):
            nn.init.xavier_uniform_(module, gain=2 ** -2.5)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """
        Apply RWKV-6 mixing to `hidden_states` of shape [batch_size, seq_len, hidden_size].

        Returns `(output, None, past_key_values)`; the middle slot exists for interface
        parity with attention layers that can return attention weights.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, seq_len, hidden_size = hidden_states.shape
        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        if attention_mask is not None:
            # zero out padded positions (NOTE: in-place; mutates the caller's tensor)
            hidden_states = hidden_states.mul_(attention_mask[:, -hidden_states.shape[-2]:, None])
        # time-shifted input: during decoding (seq_len == 1) the shifted token is the
        # cached last token; otherwise shift within the sequence and splice the cache in
        if hidden_states.shape[1] == 1 and last_state is not None:
            shifted = last_state['conv_state'].unsqueeze(1)
        else:
            shifted = self.time_shift(hidden_states)
            if last_state is not None:
                shifted[:, 0] = last_state['conv_state']

        delta = shifted - hidden_states
        # produce the 5 token-mixing signals via the shared low-rank projection
        x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim)
        x = torch.einsum('b t n r, h n r-> b t n h', self.x_proj[1](x), self.x_proj[2].weight.view(hidden_size, 5, -1))

        r, w, k, v, g = x.add_(self.x_bias).unbind(-2)
        r = self.r_proj(hidden_states, r, delta)
        w = self.w_proj(hidden_states, w, delta)
        k = self.k_proj(hidden_states, k, delta)
        v = self.v_proj(hidden_states, v, delta)
        g = self.g_proj(hidden_states, g, delta)

        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
        r, w, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (r, w, k))
        v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
        # decay is parameterized as -exp(w) so it stays strictly negative (log-space decay)
        w = -torch.exp(w)
        u = self.bonus

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        cu_seqlens = kwargs.get('cu_seqlens', None)
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_rwkv6(
                r=r,
                k=k,
                v=v,
                w=w,
                u=u,
                scale=1.,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_rwkv6(
                q=r,
                k=k,
                v=v,
                g=w,
                u=u,
                scale=1.,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=hidden_states[:, -1],
                layer_idx=self.layer_idx,
                # fix: r is laid out as [batch, seq_len, num_heads, head_dim] here, so the
                # number of new tokens is r.shape[1]; r.shape[2] is num_heads (the
                # previous code advanced the cache by the head count). This matches the
                # `offset=q.shape[1]` convention of the sibling layers in this package.
                offset=r.shape[1]
            )

        o = self.g_norm(rearrange(o, '... h d -> ... (h d)')) * self.gate_fn(g)
        o = self.o_proj(o)

        return o, None, past_key_values
189
+
190
+
191
class LoRA(nn.Module):
    """Low-rank projection: Linear(in -> r, no bias) -> activation -> Linear(r -> out)."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: int,
        bias: Optional[bool] = True,
        activation: Optional[str] = 'tanh'
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim
        self.bias = bias

        # dispatch table instead of an if/elif chain; None means no activation
        activation_classes = {
            None: nn.Identity,
            'sigmoid': nn.Sigmoid,
            'tanh': nn.Tanh,
            'relu': nn.ReLU,
        }
        if activation not in activation_classes:
            raise ValueError(f"Not supported activation `{activation}`.")
        self.activation = activation_classes[activation]()

        self.lora = nn.Sequential(
            nn.Linear(input_dim, low_rank_dim, bias=False),
            self.activation,
            nn.Linear(low_rank_dim, output_dim, bias=bias)
        )

    def __repr__(self) -> str:
        # bias is only reported when disabled, mirroring the compact default repr
        fields = f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}"
        if not self.bias:
            fields += f", bias={self.bias}"
        return f"{self.__class__.__name__}({fields})"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project `x` through the low-rank bottleneck."""
        return self.lora(x)
235
+
236
+
237
class LerpLinear(nn.Module):
    """Linear projection of a learned interpolation between x and its time-shifted predecessor."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: Optional[int] = None
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim

        # shift the sequence right by one step, zero-padding the first position
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        # full-rank projection by default; low-rank (LoRA) when a rank is given
        self.linear = (
            nn.Linear(input_dim, output_dim, bias=False)
            if low_rank_dim is None
            else LoRA(input_dim, output_dim, low_rank_dim)
        )
        # per-channel interpolation weight, initialized to "use x only"
        self.mu = nn.Parameter(torch.zeros(input_dim))

    def __repr__(self) -> str:
        rank = "" if self.low_rank_dim is None else f", low_rank_dim={self.low_rank_dim}"
        return f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}{rank})"

    def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Project lerp(x, shifted_x, mu); `delta` (shifted - x) may be precomputed."""
        if delta is None:
            prev = self.time_shift(x)
            if prev.dim() == 2:
                prev = prev.unsqueeze(1)
            delta = prev - x
        return self.linear(x + delta * self.mu)
272
+
273
+
274
class DDLerpLinear(nn.Module):
    """Data-dependent lerp projection: like LerpLinear, but `mu` is supplied per call."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: Optional[int] = None
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim

        # shift the sequence right by one step, zero-padding the first position
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        # full-rank projection by default; low-rank (LoRA) when a rank is given
        self.linear = (
            nn.Linear(input_dim, output_dim, bias=False)
            if low_rank_dim is None
            else LoRA(input_dim, output_dim, low_rank_dim)
        )

    def __repr__(self) -> str:
        rank = "" if self.low_rank_dim is None else f", low_rank_dim={self.low_rank_dim}"
        return f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}{rank})"

    def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Project lerp(x, shifted_x, mu) with caller-supplied `mu`; `delta` may be precomputed."""
        if delta is None:
            prev = self.time_shift(x)
            if prev.dim() == 2:
                prev = prev.unsqueeze(1)
            delta = prev - x
        return self.linear(x + delta * mu)
fla/layers/simple_gla.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.modules.activations import ACT2FN
15
+ from fla.ops.simple_gla import chunk_simple_gla, fused_recurrent_simple_gla
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
class SimpleGatedLinearAttention(nn.Module):
    r"""
    The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
    This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.

    Args:
        mode (str, Optional):
            Which GLA kernel to use.
            Currently available: `chunk` and `fused_recurrent`.
            Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 1.0.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 1.0.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        num_kv_heads (int, Optional):
            The number of key/value heads, used for MQA. Default: None.
        feature_map (str, Optional):
            Feature map function applied to queries/keys. Default: None.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `True`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        gate_fn (str, Optional):
            The activation function for the output gate. Default: `swish`.
        elementwise_affine (bool, Optional):
            If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
        gate_logit_normalizer (int, Optional):
            The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
        fuse_norm (bool, Optional):
            Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.,
        expand_v: float = 1.,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        gate_fn: str = 'swish',
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_logit_normalizer: int = 16,
        fuse_norm: bool = True,
        layer_idx: int = None,
    ) -> SimpleGatedLinearAttention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        # MQA/GQA support: fall back to standard MHA when num_kv_heads is not given
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        # NOTE(review): conv_bias is stored but never forwarded to the ShortConvolution
        # modules below — confirm whether that is intentional
        self.conv_bias = conv_bias

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        # per-KV-group dims used by the k/v projections (queries always use the full key_dim)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.layer_idx = layer_idx

        assert mode in ['chunk', "fused_recurrent"], f"Not suppoerted mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
        self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
            self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        # head-wise (scalar per head) forget-gate logits — the "simple" part of simple GLA
        self.gk_proj = nn.Linear(hidden_size, self.num_heads)

        # fuse RMSNorm with the swish output gate when possible for a smaller memory footprint
        if gate_fn == 'swish' and fuse_norm:
            self.g_norm_swish_gate = FusedRMSNormGated(
                hidden_size=self.head_v_dim,
                elementwise_affine=elementwise_affine,
                eps=norm_eps
            )
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(
                hidden_size=self.head_v_dim,
                elementwise_affine=elementwise_affine,
                eps=norm_eps
            )
            self.gate_fn = ACT2FN[gate_fn]
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        self.gate_logit_normalizer = gate_logit_normalizer

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """
        Apply simple (head-wise gated) GLA to `hidden_states` of shape
        [batch_size, seq_len, hidden_size].

        Returns `(output, None, past_key_values)`; the middle slot exists for interface
        parity with attention layers that can return attention weights.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        # fetch this layer's cached state (recurrent + conv), if any
        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            # project then run the short causal conv (with silu), threading the conv cache
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
        gk = self.gk_proj(hidden_states)

        if self.feature_map_fn is not None:
            q, k = map(self.feature_map_fn, (q, k))
        # dealing with left-padding: zero out values at padded positions (in place)
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
        q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
        # expand kv heads to match query heads for grouped-query attention
        if self.num_kv_groups > 1:
            k, v = (repeat(x, '... (h d) -> ... (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v))
        else:
            k, v = (rearrange(x, '... (h d) -> ... h d', h=self.num_kv_heads) for x in (k, v))
        # log-space forget gate in (-inf, 0), softened by the logit normalizer
        gk = F.logsigmoid(gk) / self.gate_logit_normalizer

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'chunk':
            o, recurrent_state = chunk_simple_gla(
                q=q,
                k=k,
                v=v,
                gk=gk,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_simple_gla(
                q=q,
                k=k,
                v=v,
                gk=gk,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        # persist recurrent/conv state and advance the cache offset by the new tokens
        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        # output gating + per-head normalization, fused when configured
        g = self.g_proj(hidden_states)
        if self.fuse_norm_and_gate:
            g = rearrange(g, 'b t (h d) -> b t h d', h=self.num_heads)
            o = self.g_norm_swish_gate(o, g)
            o = rearrange(o, 'b t h d -> b t (h d)')
        else:
            o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
            o = o * self.gate_fn(g)
        o = self.o_proj(o)

        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        """Return the per-sequence state size: the KV recurrent state plus any conv caches."""
        state_size = self.key_dim * self.head_v_dim
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
fla/models/__init__.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel
4
+ from fla.models.bitnet import BitNetConfig, BitNetForCausalLM, BitNetModel
5
+ from fla.models.delta_net import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel
6
+ from fla.models.forgetting_transformer import (
7
+ ForgettingTransformerConfig,
8
+ ForgettingTransformerForCausalLM,
9
+ ForgettingTransformerModel
10
+ )
11
+ from fla.models.gated_deltanet import GatedDeltaNetConfig, GatedDeltaNetForCausalLM, GatedDeltaNetModel
12
+ from fla.models.gated_deltaproduct import GatedDeltaProductConfig, GatedDeltaProductForCausalLM, GatedDeltaProductModel
13
+ from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel
14
+ from fla.models.gsa import GSAConfig, GSAForCausalLM, GSAModel
15
+ from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel
16
+ from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model
17
+ from fla.models.lightnet import LightNetConfig, LightNetForCausalLM, LightNetModel
18
+ from fla.models.linear_attn import LinearAttentionConfig, LinearAttentionForCausalLM, LinearAttentionModel
19
+ from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel
20
+ from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model
21
+ from fla.models.nsa import NSAConfig, NSAForCausalLM, NSAModel
22
+ from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel
23
+ from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model
24
+ from fla.models.rwkv7 import RWKV7Config, RWKV7ForCausalLM, RWKV7Model
25
+ from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel
26
+ from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel
27
+ from fla.models.transformer_top import TOPTransformerConfig, TOPTransformerForCausalLM, TOPTransformerModel
28
+ from fla.models.transformer_mtp import MTPTransformerConfig, MTPTransformerForCausalLM, MTPTransformerModel
29
+
30
+ __all__ = [
31
+ 'ABCConfig', 'ABCForCausalLM', 'ABCModel',
32
+ 'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel',
33
+ 'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel',
34
+ 'ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel',
35
+ 'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel',
36
+ 'GLAConfig', 'GLAForCausalLM', 'GLAModel',
37
+ 'GSAConfig', 'GSAForCausalLM', 'GSAModel',
38
+ 'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel',
39
+ 'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model',
40
+ 'LightNetConfig', 'LightNetForCausalLM', 'LightNetModel',
41
+ 'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel',
42
+ 'MambaConfig', 'MambaForCausalLM', 'MambaModel',
43
+ 'Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model',
44
+ 'NSAConfig', 'NSAForCausalLM', 'NSAModel',
45
+ 'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel',
46
+ 'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model',
47
+ 'RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model',
48
+ 'SambaConfig', 'SambaForCausalLM', 'SambaModel',
49
+ 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel',
50
+ 'TOPTransformerConfig', 'TOPTransformerForCausalLM', 'TOPTransformerModel',
51
+ 'MTPTransformerConfig', 'MTPTransformerForCausalLM', 'MTPTransformerModel',
52
+ 'GatedDeltaProductConfig', 'GatedDeltaProductForCausalLM', 'GatedDeltaProductModel',
53
+ ]
fla/models/utils.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ import torch
8
+ import transformers
9
+
10
+
11
class Cache(transformers.cache_utils.Cache):
    """
    A cache used for storing hidden states produced by flash linear attention models.

    It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`.
    """

    is_compileable = True

    def __init__(
        self,
        seen_tokens: int = 0
    ) -> Cache:
        super().__init__()

        # One dict per layer; keys: recurrent_state / attn_state / conv_state / ffn_state.
        self.states: List[Dict[str, Any]] = []

        self._seen_tokens = seen_tokens  # Used in `generate` to keep tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
        if layer_idx < len(self):
            return self.states[layer_idx]
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        for state in self.states:
            yield state

    def __len__(self):
        return len(self.states)

    def update(
        self,
        recurrent_state: torch.Tensor = None,
        attn_state: Tuple[torch.Tensor, torch.Tensor] = None,
        conv_state: Tuple[torch.Tensor] = None,
        ffn_state: torch.Tensor = None,
        layer_idx: int = 0,
        offset: Optional[int] = 1,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.

        Args:
            recurrent_state (`torch.Tensor`, `optional`):
                The new recurrent state to cache.
            attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
                The new attention key/value states to cache.
            conv_state (`Tuple[torch.Tensor]`, `optional`):
                The new convolution state to cache.
            ffn_state (`torch.Tensor`, `optional`):
                The new FFN state to cache.
            layer_idx (`int`, defaults to 0):
                The index of the layer to cache the states for.
            offset (`int`, `optional`, defaults to 1):
                The number of new tokens being processed.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass, e.g. `window_size`
                for sliding-window attention states.

        Return:
            Dictionary of the updated state.
        """

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += offset

        window_size = None
        if attn_state is not None:
            # Validate the tuple *before* indexing into it so malformed inputs
            # raise the intended ValueError rather than an opaque TypeError.
            if not isinstance(attn_state, Tuple) or len(attn_state) != 2:
                raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
            input_size = attn_state[0].shape[-2]
            # Guard against `cache_kwargs is None`, which previously raised
            # AttributeError whenever `attn_state` was passed without kwargs.
            window_size = (cache_kwargs or {}).get('window_size', None)
        if len(self.states) <= layer_idx:
            # NOTE(review): layers are assumed to be updated in ascending order;
            # a `layer_idx` that skips ahead would append at the wrong position.
            if attn_state is not None:
                if window_size is not None and input_size > window_size:
                    # Keep only the last `window_size` tokens for sliding-window layers.
                    attn_state = (attn_state[0][..., -window_size:, :].contiguous(),
                                  attn_state[1][..., -window_size:, :].contiguous())
            state = dict(
                recurrent_state=recurrent_state,
                attn_state=attn_state,
                conv_state=conv_state,
                ffn_state=ffn_state
            )
            self.states.append(state)
        else:
            state = self.states[layer_idx]
            if recurrent_state is not None:
                state['recurrent_state'] = recurrent_state
            if attn_state is not None:
                key_state, value_state = state['attn_state']
                if window_size is not None and key_state.shape[-2] == window_size:
                    # DO NOT allocate new memory if the cache is full
                    # roll the key/value states to the left by `input_size`
                    key_state = key_state.roll(-input_size, -2)
                    value_state = value_state.roll(-input_size, -2)
                    # replace the last `input_size` tokens with the new key/value states
                    key_state[..., -input_size:, :] = attn_state[0]
                    value_state[..., -input_size:, :] = attn_state[1]
                    attn_state = (key_state, value_state)
                else:
                    attn_state = (torch.cat([key_state, attn_state[0]], -2),
                                  torch.cat([value_state, attn_state[1]], -2),)
                state['attn_state'] = attn_state
            if conv_state is not None:
                state['conv_state'] = conv_state
            if ffn_state is not None:
                state['ffn_state'] = ffn_state

        return state

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.states) <= layer_idx:
            return 0
        return self._seen_tokens

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> Tuple:
        """Export the per-layer state dicts as a plain tuple (legacy format)."""
        return tuple(self.states)

    @classmethod
    @torch.compiler.disable
    def from_legacy_cache(
        cls,
        past_key_values: Optional[Tuple] = None,
        seen_tokens: int = 0
    ) -> Cache:
        """Converts a cache in the legacy cache format into an equivalent `Cache`."""

        cache = cls(seen_tokens)
        if isinstance(past_key_values, list):
            for layer_idx in range(len(past_key_values)):
                cache.states.append(past_key_values[layer_idx])
        return cache
fla/modules/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from fla.modules.convolution import ImplicitLongConvolution, LongConvolution, ShortConvolution
4
+ from fla.modules.fused_bitlinear import BitLinear, FusedBitLinear
5
+ from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss
6
+ from fla.modules.fused_kl_div import FusedKLDivLoss
7
+ from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
8
+ from fla.modules.fused_linear_listnet_loss import FusedLinearListNetLoss
9
+ from fla.modules.fused_norm_gate import (
10
+ FusedLayerNormGated,
11
+ FusedLayerNormSwishGate,
12
+ FusedLayerNormSwishGateLinear,
13
+ FusedRMSNormGated,
14
+ FusedRMSNormSwishGate,
15
+ FusedRMSNormSwishGateLinear
16
+ )
17
+ from fla.modules.layernorm import GroupNorm, GroupNormLinear, LayerNorm, LayerNormLinear, RMSNorm, RMSNormLinear
18
+ from fla.modules.mlp import GatedMLP
19
+ from fla.modules.rotary import RotaryEmbedding
20
+
21
+ __all__ = [
22
+ 'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
23
+ 'BitLinear', 'FusedBitLinear',
24
+ 'FusedCrossEntropyLoss', 'FusedLinearCrossEntropyLoss', 'FusedKLDivLoss',
25
+ 'GroupNorm', 'GroupNormLinear', 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
26
+ 'FusedLayerNormGated', 'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear',
27
+ 'FusedRMSNormGated', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
28
+ 'GatedMLP',
29
+ 'RotaryEmbedding'
30
+ ]
fla/modules/activations.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Tri Dao, Yu Zhang, Songlin Yang.
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import triton
7
+ import triton.language as tl
8
+
9
+ from fla.ops.utils.op import exp, log
10
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, get_multiprocessor_count, input_guard
11
+
12
# CUDA C++ sources for elementwise sigmoid, compiled at import time through
# torch's jiterator. Math is done in fp32 regardless of the tensor dtype.
sigmoid_fwd_codestring = """
template <typename T> T sigmoid_fwd(T x) {
    return 1.0f / (1.0f + ::exp(-float(x)));
}
"""
# Backward: g * s * (1 - s) with s = sigmoid(x).
sigmoid_bwd_codestring = """
template <typename T> T sigmoid_bwd(T x, T g) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    return float(g) * x_sigmoid * (1.0f - x_sigmoid);
}
"""

sigmoid_fwd_jit_fn = torch.cuda.jiterator._create_jit_fn(sigmoid_fwd_codestring)
sigmoid_bwd_jit_fn = torch.cuda.jiterator._create_jit_fn(sigmoid_bwd_codestring)
26
+
27
+
28
@torch.compiler.disable
def sigmoid_fwd(x):
    # Thin wrapper over the jiterator kernel; excluded from torch.compile
    # graphs, which cannot trace jiterator ops.
    return sigmoid_fwd_jit_fn(x)
31
+
32
+
33
@torch.compiler.disable
def sigmoid_bwd(x, g):
    # Backward wrapper: returns g * sigmoid(x) * (1 - sigmoid(x)) via jiterator.
    return sigmoid_bwd_jit_fn(x, g)
36
+
37
+
38
class SigmoidFunction(torch.autograd.Function):
    """Sigmoid with custom CUDA (jiterator) forward and backward kernels."""

    @staticmethod
    def forward(ctx, x):
        # Save the input (not the output): the backward recomputes sigmoid(x).
        ctx.save_for_backward(x)
        return sigmoid_fwd(x)

    @staticmethod
    def backward(ctx, dout):
        x, = ctx.saved_tensors
        return sigmoid_bwd(x, dout)


sigmoid = SigmoidFunction.apply
52
+
53
+
54
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def logsigmoid_fwd_kernel(
    x,
    y,
    temperature,
    T: tl.constexpr,  # total number of elements (flat)
    D: tl.constexpr,  # last-dim size; autotune key only
    B: tl.constexpr   # elements per program
):
    # y = log(sigmoid(x)) / temperature, computed over a flat view of x;
    # each program handles one contiguous block of B elements.
    i = tl.program_id(0)
    o_i = i * B + tl.arange(0, B)
    m_i = o_i < T

    b_x = tl.load(x + o_i, mask=m_i, other=0.).to(tl.float32)
    # Numerically stable form: log(sigmoid(x)) = min(0, x) - log(1 + exp(-|x|)).
    b_m = tl.minimum(0., b_x)
    b_z = 1. + exp(-tl.abs(b_x))
    b_y = (b_m - log(b_z)) / temperature
    tl.store(y + o_i, b_y.to(y.dtype.element_ty), mask=m_i)
79
+
80
+
81
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def logsigmoid_bwd_kernel(
    x,
    dx,
    dy,
    temperature,
    T: tl.constexpr,  # total number of elements (flat)
    D: tl.constexpr,  # last-dim size; autotune key only
    B: tl.constexpr   # elements per program
):
    # dx = dy * d/dx[log(sigmoid(x))] / temperature = dy * (1 - sigmoid(x)) / temperature.
    i = tl.program_id(0)
    o_i = i * B + tl.arange(0, B)
    m_i = o_i < T

    b_x = tl.load(x + o_i, mask=m_i, other=0.).to(tl.float32)
    b_dy = tl.load(dy + o_i, mask=m_i, other=0.).to(tl.float32)
    b_dx = b_dy * (1. - tl.sigmoid(b_x)) / temperature
    tl.store(dx + o_i, b_dx.to(dx.dtype.element_ty), mask=m_i)
106
+
107
+
108
def logsigmoid_fwd(x: torch.Tensor, temperature: float = 1.) -> torch.Tensor:
    """Launch the log-sigmoid forward kernel over a flat view of `x`."""
    T, D = x.numel(), x.shape[-1]
    # Size the block so the grid roughly matches the number of SMs.
    B = triton.next_power_of_2(triton.cdiv(T, get_multiprocessor_count(x.device.index)))
    y = torch.empty_like(x)
    logsigmoid_fwd_kernel[(triton.cdiv(T, B),)](
        x=x,
        y=y,
        temperature=temperature,
        T=T,
        D=D,
        B=B
    )
    return y
121
+
122
+
123
def logsigmoid_bwd(x: torch.Tensor, dy: torch.Tensor, temperature: float = 1.) -> torch.Tensor:
    """Launch the log-sigmoid backward kernel; returns dL/dx."""
    T, D = x.numel(), x.shape[-1]
    # Same grid/block sizing policy as the forward launch.
    B = triton.next_power_of_2(triton.cdiv(T, get_multiprocessor_count(x.device.index)))
    dx = torch.empty_like(x)
    logsigmoid_bwd_kernel[(triton.cdiv(T, B),)](
        x=x,
        dx=dx,
        dy=dy,
        temperature=temperature,
        T=T,
        D=D,
        B=B
    )
    return dx
137
+
138
+
139
class LogSigmoidFunction(torch.autograd.Function):
    """log(sigmoid(x)) / temperature with custom Triton forward/backward."""

    @staticmethod
    @input_guard
    def forward(ctx, x, temperature):
        ctx.save_for_backward(x,)
        # temperature is a plain float; stash it on ctx rather than as a tensor.
        ctx.temperature = temperature
        return logsigmoid_fwd(x, temperature)

    @staticmethod
    @input_guard
    def backward(ctx, dy):
        x, = ctx.saved_tensors
        # Second return is the (non-differentiable) temperature slot.
        return logsigmoid_bwd(x, dy, ctx.temperature), None
153
+
154
+
155
def logsigmoid(x: torch.Tensor, temperature: float = 1.) -> torch.Tensor:
    """Public entry point for the fused log-sigmoid autograd function."""
    return LogSigmoidFunction.apply(x, temperature)
157
+
158
+
159
# CUDA C++ sources for swish/SiLU (x * sigmoid(x)), compiled via jiterator.
swish_fwd_codestring = """
template <typename T> T swish_fwd(T x) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    return float(x) * x_sigmoid;
}
"""
# Backward: g * s * (1 - x*s + x) with s = sigmoid(x), i.e. g * (s + x*s*(1-s)).
swish_bwd_codestring = """
template <typename T> T swish_bwd(T x, T g) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    return float(g) * x_sigmoid * (1.0f - float(x) * x_sigmoid + float(x));
}
"""

swish_fwd_jit_fn = torch.cuda.jiterator._create_jit_fn(swish_fwd_codestring)
swish_bwd_jit_fn = torch.cuda.jiterator._create_jit_fn(swish_bwd_codestring)
174
+
175
+
176
@torch.compiler.disable
def swish_fwd(x):
    # Jiterator wrapper; kept out of torch.compile graphs.
    return swish_fwd_jit_fn(x)
179
+
180
+
181
@torch.compiler.disable
def swish_bwd(x, g):
    # Jiterator wrapper for the swish gradient.
    return swish_bwd_jit_fn(x, g)
184
+
185
+
186
class SwishFunction(torch.autograd.Function):
    """Swish/SiLU activation with custom CUDA (jiterator) kernels."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return swish_fwd(x)

    @staticmethod
    def backward(ctx, dout):
        x, = ctx.saved_tensors
        return swish_bwd(x, dout)


swish = SwishFunction.apply
200
+
201
+ # 1/sqrt(2*pi)-> 0.3989423
202
+ # 1/sqrt(2) -> 0.70710678
203
+ # sqrt(2/pi) -> 0.79788456
204
+
205
+
206
+ # this function is tanh approximation of gelu
207
+ # actual gelu is:
208
+ # x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
209
@torch.compile
def bias_gelu(y, bias):
    """Fused bias-add + tanh-approximated GeLU; result cast back to y's dtype."""
    x = bias + y
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
213
+
214
+
215
+ # gradient of tanh approximation of gelu
216
+ # gradient of actual gelu is:
217
+ # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
218
@torch.compile
def bias_gelu_bwd(g, y, bias):
    """Assume that y has shape (B, D) and bias has shape (D).

    Returns (grad_input, grad_bias): the input grad, and the same grad summed
    over the batch dimension for the broadcast bias.
    """
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    grad_y = ff * g
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
229
+
230
+
231
class GeLUFunction(torch.autograd.Function):
    """Tanh-approximated GeLU with a fused bias add (see `bias_gelu`)."""

    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        # `bias_gelu_bwd` returns the pair (grad_input, grad_bias). The previous
        # code did `tmp = bias_gelu_bwd(...); return tmp, tmp`, which handed
        # autograd the whole tuple in each gradient slot instead of one tensor
        # per input. Unpack and return the two gradients separately.
        grad_input, grad_bias = bias_gelu_bwd(grad_output, input, bias)
        return grad_input, grad_bias


bias_gelu_impl = GeLUFunction.apply
247
+
248
+
249
+ # this function is tanh approximation of gelu
250
+ # actual gelu is:
251
+ # x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
252
@torch.compile
def gelu_fwd(x):
    """Tanh-approximated GeLU (no bias), result cast back to x's dtype."""
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
255
+
256
+
257
+ # gradient of tanh approximation of gelu
258
+ # gradient of actual gelu is:
259
+ # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
260
@torch.compile
def gelu_bwd(g, x):
    """Gradient of the tanh-approximated GeLU w.r.t. x, scaled by upstream g."""
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return (ff * g).to(dtype=x.dtype)
268
+
269
+
270
class FastGeLUFunction(torch.autograd.Function):
    """Bias-free tanh-approximated GeLU with a hand-written backward pass."""

    @staticmethod
    # bias is an optional argument
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        tmp = gelu_bwd(grad_output, input)
        return tmp


fast_gelu_impl = FastGeLUFunction.apply
285
+
286
+
287
@torch.compile
def relu_bwd(g, x):
    """ReLU gradient: pass g through where x >= 0, zero elsewhere."""
    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
290
+
291
+
292
@torch.compile
def sqrelu_fwd(x):
    """Squared ReLU: relu(x)**2, computed in fp32 and cast back to x's dtype."""
    r = F.relu(x.float())
    return (r * r).to(dtype=x.dtype)
296
+
297
+
298
@torch.compile
def sqrelu_bwd(g, x):
    """Gradient of squared ReLU: 2 * relu(x) * g."""
    return (2.0 * g * F.relu(x.float())).to(dtype=x.dtype)
301
+
302
+
303
class SquaredReLUFunction(torch.autograd.Function):
    """Squared ReLU activation, ``relu(x) ** 2``, with a custom backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return sqrelu_fwd(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return sqrelu_bwd(grad_out, x)


sqrelu = SquaredReLUFunction.apply
317
+
318
+
319
# CUDA C++ sources for SwiGLU (swish(x) * y) via jiterator.
swiglu_fwd_codestring = """
template <typename T> T swiglu_fwd(T x, T y) {
    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
# Backward only: produces (dx, dy) from (x, y, g).
swiglu_bwd_codestring = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = float(x) * x_sigmoid * float(g);
}
"""

# Fused backward that also recomputes the forward output z (used by
# SwiGLULinearFunction, which discards z in its forward to save memory).
swiglu_fwdbwd_codestring = """
template <typename T> T swiglu_fwdbwd(T x, T y, T g, T& dx, T& dy, T& z) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    float x_swish = float(x) * x_sigmoid;
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = x_swish * float(g);
    z = x_swish * float(y);
}
"""


swiglu_fwd_jit_fn = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
swiglu_bwd_jit_fn = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
swiglu_fwdbwd_jit_fn = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_fwdbwd_codestring, num_outputs=3)
346
+
347
+
348
@torch.compiler.disable
def swiglu_fwd(x, y):
    # Jiterator wrapper; excluded from torch.compile graphs.
    return swiglu_fwd_jit_fn(x, y)
351
+
352
+
353
@torch.compiler.disable
def swiglu_bwd(x, y, g):
    # Returns (dx, dy) via the multi-output jiterator kernel.
    return swiglu_bwd_jit_fn(x, y, g)
356
+
357
+
358
@torch.compiler.disable
def swiglu_fwdbwd(x, y, g):
    # Returns (dx, dy, z) where z is the recomputed forward output.
    return swiglu_fwdbwd_jit_fn(x, y, g)
361
+
362
+
363
@torch.compile
def swiglu_fwd_torch(x, y):
    """Pure-torch SwiGLU forward; fallback used under torch.compile / DTensor."""
    return (F.silu(x.float()) * y).to(x.dtype)
366
+
367
+
368
@torch.compile
def swiglu_bwd_torch(x, y, g):
    """Pure-torch SwiGLU backward; mirrors the jiterator `swiglu_bwd` kernel."""
    dtype = x.dtype
    x, y, g = x.float(), y.float(), g.float()
    x_sigmoid = x.sigmoid()
    x_swish = x * x_sigmoid
    # d/dx[x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    dx = x_sigmoid * (1 + x * (1.0 - x_sigmoid)) * g * y
    dy = x_swish * g
    return dx.to(dtype), dy.to(dtype)
377
+
378
+
379
@torch.compile
def swiglu_fwdbwd_torch(x, y, g):
    """Pure-torch fused backward that also recomputes the forward output z."""
    dtype = x.dtype
    x, y, g = x.float(), y.float(), g.float()
    x_sigmoid = x.sigmoid()
    x_swish = x * x_sigmoid
    dx = x_sigmoid * (1 + x * (1.0 - x_sigmoid)) * g * y
    dy = x_swish * g
    z = x_swish * y
    return dx.to(dtype), dy.to(dtype), z.to(dtype)
389
+
390
+
391
class SwiGLUFunction(torch.autograd.Function):
    r"""
    Swish-Gated Linear Unit (SwiGLU) function.

    .. math::
        \text{SwiGLU}(x, y) = swish(x) * y = \frac{x}{1 + \exp(-x)} * y
    """

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        # Jiterator kernels are untraceable and unaware of DTensor sharding;
        # fall back to the pure-torch path in those cases.
        if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
            return swiglu_fwd_torch(x, y)
        else:
            return swiglu_fwd(x, y)

    @staticmethod
    def backward(ctx, dout):
        x, y = ctx.saved_tensors
        if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
            return swiglu_bwd_torch(x, y, dout)
        else:
            return swiglu_bwd(x, y, dout)
414
+
415
+
416
class SwiGLULinearFunction(torch.autograd.Function):
    r"""
    Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation.

    .. math::
        \text{SwiGLULinear}(x, y, W, b) = (swish(x) * y) W + b

    This simple wrap discards the intermediate results of SwiGLU(x, y) to save memory.
    """

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, x, y, weight, bias):
        with torch.no_grad():
            # Same compile/DTensor fallback policy as SwiGLUFunction.
            if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
                z = swiglu_fwd_torch(x, y)
            else:
                z = swiglu_fwd(x, y)
        out = F.linear(z, weight, bias)
        # We don't store z, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(x, y, weight)
        ctx.linear_bias_is_none = bias is None
        return out

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, dout, *args):
        x, y, weight = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        # Gradient flowing into the SwiGLU output: dout @ W.
        dz = F.linear(dout, weight.t()).view_as(x)
        with torch.no_grad():
            # Fused path recomputes z alongside (dx, dy) in one kernel.
            if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
                dx, dy, z = swiglu_fwdbwd_torch(x, y, dz)
            else:
                dx, dy, z = swiglu_fwdbwd(x, y, dz)
        dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1]))
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        return dx, dy, dlinear_weight, dlinear_bias
454
+
455
+
456
# Public entry points for the fused autograd functions.
swiglu = SwiGLUFunction.apply


swiglu_linear = SwiGLULinearFunction.apply


# Name -> activation callable registry, mirroring transformers' ACT2FN.
ACT2FN = {
    'relu': F.relu,
    'sigmoid': sigmoid,
    'logsigmoid': logsigmoid,
    'silu': swish,
    'swish': swish,
    'sqrelu': sqrelu,
    'gelu': fast_gelu_impl,
    'bias_gelu': bias_gelu_impl,
}
fla/modules/convolution.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py
4
+
5
+ import math
6
+ import warnings
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import triton
13
+ import triton.language as tl
14
+ from einops import rearrange
15
+
16
+ from fla.modules.activations import ACT2FN
17
+ from fla.ops.common.utils import prepare_position_ids, prepare_sequence_ids
18
+ from fla.utils import checkpoint, input_guard
19
+
20
+ try:
21
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
22
+ except ImportError:
23
+ causal_conv1d_fn = None
24
+ causal_conv1d_update = None
25
+
26
+
27
+ def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
28
+ seqlen = u.shape[-1]
29
+ fft_size = 2 * seqlen
30
+ k_f = torch.fft.rfft(k, n=fft_size) / fft_size
31
+ if k_rev is not None:
32
+ k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
33
+ k_f = k_f + k_rev_f.conj()
34
+ u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
35
+
36
+ if len(u.shape) > 3:
37
+ k_f = k_f.unsqueeze(1)
38
+ y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
39
+
40
+ out = y + u
41
+ if gelu:
42
+ out = F.gelu(out)
43
+ if dropout_mask is not None:
44
+ return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
45
+ else:
46
+ return out.to(dtype=u.dtype)
47
+
48
+
49
+ @checkpoint
50
+ def proj_then_conv1d(
51
+ x: torch.Tensor,
52
+ proj_weight: torch.Tensor,
53
+ conv1d_weight: torch.Tensor,
54
+ conv1d_bias: Optional[torch.Tensor] = None,
55
+ cache: Optional[torch.Tensor] = None
56
+ ) -> torch.Tensor:
57
+ # We do matmul and transpose BLH -> HBL at the same time
58
+ x = rearrange(proj_weight @ rearrange(x, "b t d -> d (b t)"), "d (b t) -> b d t", t=x.shape[-2])
59
+
60
+ if causal_conv1d_fn is None:
61
+ raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
62
+ if cache is None:
63
+ x = causal_conv1d_fn(
64
+ x=x,
65
+ weight=rearrange(conv1d_weight, "d 1 w -> d w"),
66
+ bias=conv1d_bias,
67
+ activation="silu",
68
+ ).transpose(1, 2)
69
+ else:
70
+ assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
71
+ x = x.squeeze(-1)
72
+ x = causal_conv1d_update(
73
+ x=x,
74
+ weight=rearrange(conv1d_weight, "d 1 w -> d w"),
75
+ bias=conv1d_bias,
76
+ cache=cache,
77
+ activation="silu",
78
+ )
79
+ return x
80
+
81
+
82
+ @triton.jit
83
+ def causal_conv1d_varlen_states_fwd_kernel(
84
+ x,
85
+ cache,
86
+ offsets,
87
+ D,
88
+ W,
89
+ BD: tl.constexpr,
90
+ BW: tl.constexpr
91
+ ):
92
+ i_d, i_w, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2)
93
+ eos = tl.load(offsets + i_n + 1)
94
+ bos = tl.maximum(tl.load(offsets + i_n), eos - W)
95
+ o_t = eos - (i_w + 1) * BW + tl.arange(0, BW)
96
+ o_d = i_d * BD + tl.arange(0, BD)
97
+ o_w = W - (i_w + 1) * BW + tl.arange(0, BW)
98
+
99
+ b_x = tl.load(x + o_t * D + o_d[:, None], mask=(o_t >= bos) & (o_d[:, None] < D), other=0)
100
+ tl.store(cache + i_n * D*W + o_d[:, None] * W + o_w, b_x, mask=(o_d[:, None] < D) & (o_w >= 0))
101
+
102
+
103
+ @input_guard
104
+ def causal_conv1d_varlen_states_fwd(
105
+ x: torch.Tensor,
106
+ cache: torch.Tensor,
107
+ cu_seqlens: torch.Tensor,
108
+ state_len: int
109
+ ) -> torch.Tensor:
110
+ N, D, W = len(cu_seqlens) - 1, x.shape[-1], state_len
111
+ cache = torch.empty(N, D, W, dtype=x.dtype, device=x.device) if cache is None else cache
112
+ BD = min(triton.next_power_of_2(D), 256)
113
+ BW = min(triton.next_power_of_2(state_len), 16)
114
+ grid = (triton.cdiv(D, BD), triton.cdiv(W, BW), N)
115
+ with torch.cuda.device(x.device.index):
116
+ causal_conv1d_varlen_states_fwd_kernel[grid](
117
+ x=x,
118
+ cache=cache,
119
+ offsets=cu_seqlens,
120
+ D=D,
121
+ W=W,
122
+ BW=BW,
123
+ BD=BD
124
+ )
125
+ return cache
126
+
127
+
128
+ class ShortConvolution(nn.Conv1d):
129
+ """
130
+ Simple wrapper around `nn.Conv1d` that accepts dimension last.
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ hidden_size: int,
136
+ kernel_size: int,
137
+ bias: bool = False,
138
+ activation: Optional[str] = 'silu',
139
+ use_fast_conv1d: Optional[bool] = True,
140
+ device: Optional[torch.device] = None,
141
+ dtype: Optional[torch.dtype] = None,
142
+ ):
143
+ super().__init__(
144
+ in_channels=hidden_size,
145
+ out_channels=hidden_size,
146
+ kernel_size=kernel_size,
147
+ groups=hidden_size,
148
+ bias=bias,
149
+ padding=kernel_size - 1,
150
+ device=device,
151
+ dtype=dtype,
152
+ )
153
+
154
+ self.hidden_size = hidden_size
155
+ self.activation = None
156
+ if activation is not None:
157
+ assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
158
+ self.activation = activation
159
+
160
+ if causal_conv1d_fn is None:
161
+ if use_fast_conv1d:
162
+ raise RuntimeError(
163
+ "Please either install `causal-conv1d>=1.4.0` to enable fast causal short convolution CUDA kernel "
164
+ "or set `use_fast_conv1d` to False"
165
+ )
166
+ else:
167
+ warnings.warn(
168
+ "The naive Pytorch verison is very slow in practice, "
169
+ "please run `pip install causal-conv1d>=1.4.0` to install fast causal short convolution CUDA kernel",
170
+ category=ImportWarning
171
+ )
172
+ self.use_fast_conv1d = use_fast_conv1d
173
+
174
+ def extra_repr(self):
175
+ s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
176
+ ', stride={stride}')
177
+ if self.padding != (0,) * len(self.padding):
178
+ s += ', padding={padding}'
179
+ if self.dilation != (1,) * len(self.dilation):
180
+ s += ', dilation={dilation}'
181
+ if self.output_padding != (0,) * len(self.output_padding):
182
+ s += ', output_padding={output_padding}'
183
+ if self.groups != 1:
184
+ s += ', groups={groups}'
185
+ if self.bias is None:
186
+ s += ', bias=False'
187
+ if self.padding_mode != 'zeros':
188
+ s += ', padding_mode={padding_mode}'
189
+ if self.activation is not None:
190
+ s += ', activation={activation}'
191
+ if not self.use_fast_conv1d:
192
+ s += ', use_fast_conv1d={use_fast_conv1d}'
193
+ return s.format(**self.__dict__)
194
+
195
+ def forward(
196
+ self,
197
+ x: torch.Tensor,
198
+ mask: Optional[torch.Tensor] = None,
199
+ cache: Optional[torch.Tensor] = None,
200
+ output_final_state: bool = False,
201
+ cu_seqlens: Optional[torch.LongTensor] = None,
202
+ **kwargs,
203
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
204
+ """
205
+ Args:
206
+ x (`torch.Tensor`):
207
+ Tensor of shape `[B, T, D]`.
208
+ If `seq_idx` is provided, `B` must be 1.
209
+ mask (`Optional[torch.Tensor]`):
210
+ Attention mask dealing with padded positions.
211
+ cache (`Optional[torch.Tensor]`):
212
+ Previous cache tensor of shape `[N, D, W]`, where `W` is the kernel size.
213
+ If provided, the cache is updated **inplace**.
214
+ output_final_state (Optional[bool]):
215
+ Whether to output the final state of shape `[N, D, W]`. Default: `False`.
216
+ cu_seqlens (Optional[torch.LongTensor]):
217
+ Cumulative sequence lengths for each batch. Used for varlen. Default: `None`.
218
+ Shape: [B+1]
219
+
220
+ Returns:
221
+ Tensor of shape `[B, T, D]`.
222
+ """
223
+
224
+ B, T, D, W = *x.shape, self.kernel_size[0]
225
+ N = B if cu_seqlens is None else len(cu_seqlens) - 1
226
+ if mask is not None:
227
+ if cu_seqlens is not None:
228
+ raise ValueError("`mask` and `cu_seqlens` cannot be provided at the same time")
229
+ x = x.mul_(mask.unsqueeze(-1))
230
+ if output_final_state and cache is None:
231
+ cache = x.new_zeros(N, D, W)
232
+ # during the decoding phase, we assume the batch is composed of sequences of length 1
233
+ if cache is not None and B * T == N:
234
+ return self.step(x, cache, cu_seqlens)
235
+
236
+ if cache is not None:
237
+ if cu_seqlens is not None:
238
+ cache = causal_conv1d_varlen_states_fwd(x, cache, cu_seqlens, W)
239
+ else:
240
+ cache[:, :, -min(W, T):].copy_(rearrange(x[..., -min(W, T):, :], 'n w d -> n d w'))
241
+
242
+ x = rearrange(x, 'b t d -> b d t')
243
+ if self.use_fast_conv1d:
244
+ # Sequence index for each token. Used for varlen.
245
+ # Suppose a batch consists of two sequences with lengths 3 and 4,
246
+ # seq_idx=[0, 0, 0, 1, 1, 1, 1] for this batch.
247
+ # NOTE: No need to provide this arg if `cu_seqlens` is passed.
248
+ # This arg is just for BC, and will be removed in the future.
249
+ # [B, T]
250
+ seq_idx = kwargs.get('seq_idx', None)
251
+ if cu_seqlens is not None and seq_idx is None:
252
+ seq_idx = prepare_sequence_ids(prepare_position_ids(cu_seqlens)).to(torch.int32).unsqueeze(0)
253
+ x = causal_conv1d_fn(
254
+ x=x,
255
+ weight=rearrange(self.weight, "d 1 w -> d w"),
256
+ bias=self.bias,
257
+ activation=self.activation,
258
+ seq_idx=seq_idx,
259
+ )
260
+ else:
261
+ if cu_seqlens is not None:
262
+ raise ValueError("`cu_seqlens` is not supported for the naive Pytorch version")
263
+ x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
264
+ if self.activation is not None:
265
+ x = ACT2FN[self.activation](x)
266
+ return rearrange(x, "b d t -> b t d"), cache
267
+
268
+ def step(
269
+ self,
270
+ x: torch.Tensor,
271
+ cache: torch.Tensor,
272
+ cu_seqlens: Optional[torch.LongTensor] = None
273
+ ):
274
+ shape = x.shape
275
+ x = x.squeeze(0) if cu_seqlens is not None else x.squeeze(1)
276
+ if self.use_fast_conv1d:
277
+ x = causal_conv1d_update(
278
+ x=x,
279
+ conv_state=cache,
280
+ weight=rearrange(self.weight, "d 1 w -> d w"),
281
+ bias=self.bias,
282
+ activation=self.activation,
283
+ )
284
+ else:
285
+ dtype = x.dtype
286
+ # we follow the fast mode that updates the cache in-place
287
+ cache.copy_(cache.roll(shifts=-1, dims=-1))
288
+ cache[:, :, -1] = x
289
+ x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
290
+ if self.bias is not None:
291
+ x = x + self.bias
292
+ if self.activation is not None:
293
+ x = ACT2FN[self.activation](x).to(dtype=dtype)
294
+ return x.view(shape), cache
295
+
296
+ @property
297
+ def state_size(self) -> int:
298
+ return self.hidden_size * self.kernel_size
299
+
300
+
301
class LongConvolution(nn.Module):
    """
    Applies a convolution over the sequence dimension with a dense, learned
    filter of length ``max_len``, evaluated via FFT convolution.

    Args:
        hidden_size (int): The number of expected features in the input and output.
        max_len (int): The maximum sequence length (filter length).
    Returns:
        y: [batch_size, seq_len, hidden_size] tensor
    """

    def __init__(
        self,
        hidden_size: int,
        max_len: int,
        **kwargs,
    ):
        """
        Initializes the LongConvolution module.

        Args:
            hidden_size (int): The number of expected features in the input and output.
            max_len (int): The maximum sequence length.
        """
        super().__init__()
        self.hidden_size = hidden_size
        # One dense filter per channel, learned end-to-end.
        self.filter = nn.Parameter(torch.randn(self.hidden_size, max_len), requires_grad=True)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """
        Applies the LongConvolution operation on the input tensor.

        Args:
            x: [batch_size, seq_len, hidden_size] tensor
        Returns:
            y: [batch_size, seq_len, hidden_size] tensor
        """
        # fft_conv works on channels-first layout.
        x = x.transpose(1, 2)
        out = fft_conv(x, self.filter, dropout_mask=None, gelu=False)
        return out.transpose(1, 2).to(dtype=x.dtype)
341
+
342
+
343
class PositionalEmbedding(nn.Module):
    def __init__(self, emb_dim: int, seq_len: int, **kwargs):
        """Complex exponential positional embeddings for implicit long convolution filters."""
        super().__init__()

        self.seq_len = seq_len
        # Time axis normalized so that t_f = 1; shape [1, L, 1].
        time = torch.linspace(0, 1, self.seq_len)[None, :, None]

        if emb_dim > 1:
            bands = (emb_dim - 1) // 2
        # Use the "proper" (unnormalized) linspace for the angular frequencies.
        positions = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
        angles = 2 * math.pi * positions / seq_len  # 1, L, 1

        freqs = torch.linspace(1e-4, bands - 1, bands)[None, None]
        phases = torch.exp(-1j * freqs * angles)
        # Concatenate [time, cos-like, sin-like] channels -> emb_dim features.
        emb = torch.cat([time, phases.real, phases.imag], dim=-1)
        self.z = nn.Parameter(emb, requires_grad=False)

    def forward(self, L):
        # First L positions, shape [1, L, emb_dim].
        return self.z[:, :L]
365
+
366
+
367
class ImplicitLongConvolution(nn.Module):
    """
    Long convolution with implicit filter parameterized by an MLP.

    Args:
        hidden_size (int):
            The number of expected features in the input and output.
        max_len (int):
            The maximum sequence length.
        d_emb (Optional[int]):
            The dimension of the positional embeddings. Must be odd and greater or equal to 3 (time, sine and cosine).
            Defaults to 3.
        d_hidden (Optional[int]):
            The number of features in the hidden layer of the MLP. Defaults to 16.

    Attributes:
        pos_emb (`PositionalEmbedding`): The positional embedding layer.
        mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter.
    """

    def __init__(
        self,
        hidden_size: int,
        max_len: int,
        d_emb: int = 3,
        d_hidden: int = 16,
        **kwargs,
    ):
        """Initialize the implicit long convolution (filter generated by an MLP)."""
        super().__init__()
        self.hidden_size = hidden_size
        self.d_emb = d_emb

        assert (
            d_emb % 2 != 0 and d_emb >= 3
        ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)"
        self.pos_emb = PositionalEmbedding(d_emb, max_len)

        # Two-layer MLP mapping positional features to per-channel filter taps.
        self.mlp = nn.Sequential(
            nn.Linear(d_emb, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, hidden_size),
        )

    def filter(self, seq_len: int, *args, **kwargs):
        """Materialize the implicit filter for `seq_len` taps; shape [1, hidden_size, seq_len]."""
        taps = self.mlp(self.pos_emb(seq_len))
        return taps.transpose(1, 2)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """
        Args:
            x: [batch_size, seq_len, hidden_size] tensor
        Returns:
            y: [batch_size, seq_len, hidden_size] tensor
        """
        x = x.transpose(1, 2)
        kernel = self.filter(x.shape[-1])
        out = fft_conv(x, kernel, dropout_mask=None, gelu=False)
        return out.transpose(1, 2).to(dtype=x.dtype)
fla/modules/fused_bitlinear.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # Implementations of BitLinear layer with fused LayerNorm and quantized Linear layer.
5
+ # [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764)
6
+ # [Scalable MatMul-free Language Modeling](https://arxiv.org/abs/2406.02528)
7
+
8
+ # Code adapted from https://github.com/ridgerchu/matmulfreellm/
9
+
10
+ from __future__ import annotations
11
+
12
+ import math
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import triton
18
+ import triton.language as tl
19
+
20
+ from fla.modules.layernorm import RMSNorm
21
+ from fla.utils import get_multiprocessor_count, input_guard, require_version
22
+
23
+
24
def activation_quant(x):
    """
    Per-token quantization to 8 bits (quantize-dequantize). No grouping is needed.

    Each row (last dim) is scaled so that its absolute maximum maps to 127,
    rounded to the nearest integer in [-128, 127], then rescaled back.

    Args:
        x: An activation tensor with shape [n, d].

    Returns:
        A fake-quantized activation tensor with shape [n, d].
    """
    # Per-row scale; the clamp guards against division by (near-)zero rows.
    max_abs = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5)
    scale = 127.0 / max_abs
    # Quantize and immediately de-quantize.
    return torch.clamp((x * scale).round(), min=-128, max=127) / scale
39
+
40
+
41
def weight_quant(w):
    """
    Per-tensor quantization to 1.58 bits (ternary, quantize-dequantize).

    The tensor is scaled by the reciprocal of its mean absolute value,
    rounded into {-1, 0, 1}, then rescaled back.

    Args:
        w: A weight tensor with shape [d, k].

    Returns:
        A fake-quantized weight tensor with shape [d, k].
    """
    # Scale from the mean absolute value; clamp avoids dividing by ~zero.
    mean_abs = w.abs().mean().clamp(min=1e-5)
    scale = 1.0 / mean_abs
    return torch.clamp((w * scale).round(), min=-1, max=1) / scale
56
+
57
+
58
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
@triton.jit
def layer_norm_fwd_kernel_quant(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    RESIDUAL_OUT,  # pointer to the residual
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr
):
    """
    Fused LayerNorm/RMSNorm forward whose output is 8-bit fake-quantized
    (quantize then de-quantize) per row. One program instance handles one row.
    """
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        # Save x (+ residual) so the backward pass can recompute from it.
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    if HAS_WEIGHT:
        w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    # `mean` is only referenced on the non-RMS path; the constexpr branch keeps this valid.
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd

    y = x_hat * w if HAS_WEIGHT else x_hat
    if HAS_BIAS:
        y = y + b

    # Apply 8-bit quantize/de-quantize to the output (per-row absmax scaling).
    scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), 1e-5)
    # Quantize and then de-quantize the tensor
    y = tl.extra.cuda.libdevice.round(y * scale)
    y = tl.maximum(tl.minimum(y, 127), -128) / scale

    # Write output
    tl.store(Y + cols, y, mask=mask)
138
+
139
+
140
def layer_norm_fwd_quant(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    residual: torch.Tensor = None,
    out_dtype: torch.dtype = None,
    residual_dtype: torch.dtype = None,
    is_rms_norm: bool = False
):
    """
    Host-side launcher for the fused (RMS/Layer)Norm + quantization forward kernel.

    Returns:
        (y, mean, rstd, residual_out): normalized + quantized output, per-row
        mean (None for RMSNorm), per-row 1/std, and the pre-norm tensor for the
        backward pass (`x` itself when no residual/dtype conversion was needed).
    """
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    # A separate residual_out buffer is needed when a residual is added or the
    # residual dtype differs from the input dtype.
    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    # One program per row; warp count is picked by the kernel's autotuner.
    layer_norm_fwd_kernel_quant[(M,)](
        x,
        y,
        weight,
        bias,
        residual,
        residual_out,
        mean,
        rstd,
        x.stride(0),
        y.stride(0),
        residual.stride(0) if residual is not None else 0,
        residual_out.stride(0) if residual_out is not None else 0,
        N,
        eps,
        is_rms_norm,
        BLOCK_N,
        residual is not None,
        residual_out is not None,
        weight is not None,
        bias is not None,
    )
    # residual_out is None if residual is None and residual_dtype == input_dtype
    return y, mean, rstd, residual_out if residual_out is not None else x
191
+
192
+
193
@triton.heuristics({
    "RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
)
@triton.jit
def layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    DRESIDUAL_IN,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    """
    Backward of the fused norm + quantization. Each program instance walks
    `rows_per_program` consecutive rows, accumulating partial dW/dB that the
    host sums afterwards; the quantized output Y is recomputed on the fly.
    """
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    if HAS_WEIGHT:
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        # Per-program partial weight-gradient accumulator.
        dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w if HAS_WEIGHT else xhat
            if HAS_BIAS:
                y = y + b

            # Apply the same 8-bit quantize/de-quantize as the forward pass so
            # the recomputed Y matches what was produced there.
            scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), 1e-5)
            # Quantize and then de-quantize the tensor
            y = tl.extra.cuda.libdevice.round(y * scale)
            y = tl.maximum(tl.minimum(y, 127), -128) / scale

            tl.store(Y + cols, y, mask=mask)
        wdy = dy
        if HAS_WEIGHT:
            wdy = dy * w
            dw += dy * xhat
        if HAS_BIAS:
            db += dy
        # Standard layer-norm backward: project out the mean/variance directions.
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        tl.store(DX + cols, dx, mask=mask)

        # Advance all row pointers to the next row handled by this program.
        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
    if HAS_WEIGHT:
        tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
317
+
318
+
319
def layer_norm_bwd(
    dy: torch.Tensor,
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    mean: torch.Tensor,
    rstd: torch.Tensor,
    dresidual: torch.Tensor = None,
    has_residual: bool = False,
    is_rms_norm: bool = False,
    x_dtype: torch.dtype = None,
    recompute_output: bool = False,
):
    """
    Host-side launcher for the fused norm backward kernel.

    Returns:
        (dx, dw, db, dresidual_in) and additionally the recomputed quantized
        output `y` when `recompute_output=True`.
    """
    M, N = x.shape
    # allocate output
    dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)
    dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # One program per SM; each accumulates a partial dW/dB row that is summed below.
    sm_count = get_multiprocessor_count(x.device.index)
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) if weight is not None else None
    _db = torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    layer_norm_bwd_kernel[grid](
        x,
        weight,
        bias,
        y,
        dy,
        dx,
        _dw,
        _db,
        dresidual,
        dresidual_in,
        mean,
        rstd,
        x.stride(0),
        0 if not recompute_output else y.stride(0),
        dy.stride(0),
        dx.stride(0),
        dresidual.stride(0) if dresidual is not None else 0,
        dresidual_in.stride(0) if dresidual_in is not None else 0,
        M,
        N,
        eps,
        rows_per_program,
        is_rms_norm,
        BLOCK_N,
        dresidual is not None,
        dresidual_in is not None,
        weight is not None,
        bias is not None,
    )
    # Reduce the per-program partial gradients.
    dw = _dw.sum(0).to(weight.dtype) if weight is not None else None
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype:
        dresidual_in = dx
    return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
385
+
386
+
387
class LayerNormLinearQuantFn(torch.autograd.Function):
    """
    Fused (RMS/Layer)Norm -> 8-bit activation quantization -> ternary-quantized
    linear projection, implemented as a custom autograd Function.

    The normalized (quantized) activations are NOT saved for backward; they are
    recomputed there from `residual_out` to save memory.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
        residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None)
        y, mean, rstd, residual_out = layer_norm_fwd_quant(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
        )
        y = y.reshape(x_shape_og)
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
        # Ternarize the projection weight (quantize-dequantize) before the matmul.
        linear_weight = weight_quant(linear_weight).to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    @input_guard
    def backward(ctx, dout, *args):
        # `x` here is `residual_out` from forward (the pre-norm activations).
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        # Gradient through the linear projection (weight was already quantized).
        dy = F.linear(dout, linear_weight.t())
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        assert dy.shape == x.shape
        if ctx.prenorm:
            # Second incoming gradient is for the prenorm residual output.
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        # Recompute the quantized normalized output `y` while computing the
        # norm gradients, so dW of the linear layer can be formed below.
        dx, dnorm_weight, dnorm_bias, dresidual_in, y = layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            ctx.has_residual,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True
        )
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        # One gradient slot per forward argument (None for non-tensor args).
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )
+
479
+
480
def layer_norm_linear_quant_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """
    Functional entry point for the fused norm + quantization + linear op.
    Thin wrapper forwarding all arguments to `LayerNormLinearQuantFn.apply`.
    """
    op_args = (
        x, norm_weight, norm_bias, linear_weight, linear_bias,
        residual, eps, prenorm, residual_in_fp32, is_rms_norm,
    )
    return LayerNormLinearQuantFn.apply(*op_args)
+ )
504
+
505
+
506
def rms_norm_linear_quant(
    x: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    linear_weight: torch.Tensor,
    linear_bias: torch.Tensor,
    residual: torch.Tensor = None,
    eps: float = 1e-5,
    prenorm: bool = False,
    residual_in_fp32: bool = False
):
    """
    RMSNorm-flavored convenience wrapper around `layer_norm_linear_quant_fn`
    (simply forces `is_rms_norm=True`).
    """
    return layer_norm_linear_quant_fn(
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=residual,
        eps=eps,
        prenorm=prenorm,
        residual_in_fp32=residual_in_fp32,
        is_rms_norm=True,
    )
+ )
529
+
530
+
531
@require_version("triton>=3.0", "Triton >= 3.0 is required to do online quantization.")
def bit_linear(x, weight, bias=None, norm_weight=None, norm_bias=None, eps=1e-8):
    """
    A functional version of BitLinear that applies quantization to activations and weights.

    Args:
        x: Input tensor with shape [n, d].
        weight: Weight tensor with shape [out_features, in_features].
        bias: Bias tensor with shape [out_features] (optional).
        norm_weight: Weight tensor for RMS normalization with shape [in_features].
        norm_bias: Bias tensor for RMS normalization with shape [in_features].
        eps: A small constant for numerical stability in normalization.

    Returns:
        Output tensor with shape [n, out_features].
    """
    # Fix: forward `eps` to the fused op. Previously the argument was accepted
    # but silently ignored, so the op fell back to its own default of 1e-6.
    return layer_norm_linear_quant_fn(
        x,
        norm_weight,
        norm_bias,
        weight,
        bias,
        eps=eps,
        is_rms_norm=True
    )
+ )
555
+
556
+
557
class BitLinear(nn.Linear):
    """
    A custom linear layer that applies quantization on both activations (8-bit,
    per token) and weights (ternary, per tensor), with an input RMSNorm.
    This is primarily for training; kernel optimization is needed for efficiency in deployment.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        norm_eps: float = 1e-8
    ):
        """
        Initializes the BitLinear layer.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: If set to False, the layer will not learn an additive bias. Default: False.
            norm_eps: Epsilon of the input RMSNorm. Default: 1e-8.
        """
        # Initialize the superclass nn.Linear with the given parameters
        super().__init__(in_features, out_features, bias=bias)

        self.norm = RMSNorm(in_features, eps=norm_eps)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({super().extra_repr()}, norm_eps={self.norm.eps})"

    def forward(self, x):
        """
        Overrides the forward pass to include quantization.

        Args:
            x: An input tensor with shape [n, d].

        Returns:
            An output tensor with shape [n, out_features].
        """
        # Weight tensor
        w = self.weight

        # Apply RMS normalization to the input
        x_norm = self.norm(x)

        # Apply quantization to both activations and weights
        # Uses Straight-Through Estimator (STE) trick with .detach() for gradient flow
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        # Fix: pass the (optional) bias, mirroring FusedBitLinear. Previously a
        # learned bias (bias=True) was silently ignored in the forward pass.
        y = F.linear(x_quant, w_quant, self.bias)

        return y
+ return y
610
+
611
+
612
class FusedBitLinear(BitLinear):
    """
    BitLinear variant whose RMSNorm, activation quantization and linear
    projection run through a single fused autograd function.
    This is primarily for training; kernel optimization is needed for efficiency in deployment.
    """

    def __init__(self, in_features, out_features, bias=False):
        """
        Initializes the FusedBitLinear layer.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: If set to False, the layer will not learn an additive bias. Default: False.
        """
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, x):
        # Delegate the whole norm -> quantize -> linear pipeline to the fused op.
        return layer_norm_linear_quant_fn(
            x,
            self.norm.weight,
            self.norm.bias,
            self.weight,
            self.bias,
            is_rms_norm=True
        )
+ )
fla/modules/fused_cross_entropy.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023, Tri Dao.
4
+
5
+ from typing import Any, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import triton
10
+ import triton.language as tl
11
+
12
+ from fla.ops.utils.op import exp, log
13
+ from fla.utils import input_guard
14
+
15
# Backward-compatibility shim: `all_gather_into_tensor` (and
# `reduce_scatter_tensor`) are the public replacements for the private
# `_all_gather_base` / `_reduce_scatter_base` and only exist in recent
# PyTorch. Alias the old name under the new one on older versions.
if not hasattr(torch.distributed, "all_gather_into_tensor"):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
21
+
22
+
23
@triton.heuristics({
    "HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0,
})
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    z_loss_ptr,
    logits_ptr,
    labels_ptr,
    label_smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    n_rows,
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
):
    # Each program instance handles one (row, column-block) tile of the logits.
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    # Cast the stride to int64 so row_idx * stride cannot overflow int32.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf"))
    logits = logits.to(tl.float32) * logit_scale
    max_logits = tl.max(logits, 0)
    if HAS_SMOOTHING:
        # Sum of (scaled) logits over this block, needed for the smoothing term.
        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
    # Numerically stable log-sum-exp of this column block; stored per
    # (block, row) so the host can reduce partial LSEs afterwards.
    lse = log(tl.sum(exp(logits - max_logits), 0)) + max_logits
    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
    if label_idx == ignore_index:
        loss = 0.0
        z_loss = 0.0
    else:
        # Shift the global label into this rank's local class range.
        label_idx -= class_start_idx
        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
            n_cols, (col_block_idx + 1) * BLOCK_SIZE
        ):
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - label_smoothing * sum_logits / total_classes
                    - (1 - label_smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the label_smoothing loss
            if HAS_SMOOTHING:
                loss = label_smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            # z-loss regularizer: lse_square_scale * lse^2, folded into the loss.
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            z_loss = 0.0
    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
90
+
91
+
92
@triton.heuristics({
    "HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0,
})
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    label_smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    # Each program instance handles one (row, column-block) tile.
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    # int64 strides avoid int32 overflow for large row indices.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    # Ignored rows contribute zero gradient.
    if label_idx != ignore_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    lse = tl.load(lse_ptr + row_idx)
    # Softmax probabilities recovered from the logits and the saved row LSE.
    probs = exp(logits - lse)
    # Extra term from the z-loss: d/dx [s * lse^2] = 2 * s * lse * softmax(x).
    probs += 2.0 * lse_square_scale * lse * probs
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_negative = label_smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - (1 - label_smoothing), probs) - smooth_negative
    else:
        # Standard CE gradient: softmax - one_hot(label).
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
138
+
139
+
140
def fused_cross_entropy_forward(
    logits: torch.Tensor,
    target: torch.Tensor,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index: int = -100,
    process_group=None,
):
    """Launch the fused CE forward kernel and reduce the per-block partials.

    Returns a 5-tuple ``(losses, z_losses, lse, total_classes, class_start_idx)``
    where ``losses``/``z_losses``/``lse`` are per-row fp32 tensors and the last
    two items describe this rank's slice of the vocab (for tensor parallel).
    """
    n_rows, n_cols = logits.shape
    assert target.shape == (n_rows,)
    world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
    total_classes = world_size * n_cols
    rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
    class_start_idx = rank * n_cols

    # The kernel indexes rows with a single stride; require contiguous cols.
    if logits.stride(-1) != 1:
        logits = logits.contiguous()
    # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
    MAX_BLOCK_SIZE = 64 * 1024
    BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
    num_warps = (
        4
        if BLOCK_SIZE < 2048
        else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
    )
    # We may split the lse computation across multiple blocks, then do a reduction
    # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
    # where having just one thread block processing more than 64k elements is slow.
    split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
    n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
    loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
    losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
    lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
    z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)

    # One program per (row, column-split) pair.
    cross_entropy_fwd_kernel[(n_rows, n_splits)](
        losses,  # data ptrs
        lse,
        z_losses,
        logits,
        target,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignore_index,
        total_classes,
        class_start_idx,
        n_cols,  # shapes
        n_rows,
        logits.stride(0),  # strides
        BLOCK_SIZE=BLOCK_SIZE,  # constants
        num_warps=num_warps,
        SPLIT=split
    )

    if split:
        # If there's no label_smoothing, if target are in the vocab of this partition, losses contains
        # - predicted logit, and 0 otherwise.
        # If there's label_smoothing=0.1, for target in the vocab of this partition, losses contains
        # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
        # For target not in the vocab of this partition, losses contains
        # -0.1 * sum logit / total_classes.
        if n_splits > 1:
            # Reduce the per-column-block partials into per-row values.
            lse = torch.logsumexp(lse, dim=0)
            losses = losses.sum(dim=0)
        if world_size > 1:
            # Tensor parallel: combine per-rank LSEs and sum the partial losses.
            lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            handle_losses.wait()
        # After the allreduce, if there's no label_smoothing, the total losses are - predicted_logit,
        # we just have to add the (global) lse.
        # If there's label_smoothing=0.1, the total losses are
        # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
        # Again, we just have to add the (global) lse.
        losses += lse
        if lse_square_scale != 0.0:
            # z-loss is computed host-side from the global LSE in the split path.
            z_losses = lse_square_scale * lse.square()
            z_losses.masked_fill_(target == ignore_index, 0.0)
            losses += z_losses
        else:
            z_losses = torch.zeros_like(losses)
        losses.masked_fill_(target == ignore_index, 0.0)

    return losses, z_losses, lse, total_classes, class_start_idx
229
+
230
+
231
class CrossEntropyLossFunction(torch.autograd.Function):
    # autograd bridge: forward launches the fused Triton kernels; backward
    # recomputes softmax probabilities from the saved logits and row LSEs.

    @staticmethod
    @input_guard
    def forward(
        ctx,
        logits,
        target,
        label_smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        ignore_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        losses, z_losses, lse, total_classes, class_start_idx = fused_cross_entropy_forward(
            logits,
            target,
            label_smoothing,
            logit_scale,
            lse_square_scale,
            ignore_index,
            process_group,
        )
        # Save exactly what backward needs: raw logits, per-row LSE, targets.
        ctx.save_for_backward(logits, lse, target)
        ctx.mark_non_differentiable(z_losses)
        ctx.label_smoothing = label_smoothing
        ctx.logit_scale = logit_scale
        ctx.lse_square_scale = lse_square_scale
        ctx.ignore_index = ignore_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward

        return losses, z_losses

    @staticmethod
    @input_guard
    def backward(ctx, grad_losses, grad_z_losses):
        del grad_z_losses  # z_losses are only for logging.

        logits, lse, target = ctx.saved_tensors
        # Optionally reuse the logits buffer for the gradient to save memory.
        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
        def grid(META): return (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
        cross_entropy_bwd_kernel[grid](
            dlogits,  # data ptrs
            grad_losses,
            logits,
            lse,
            target,
            ctx.label_smoothing,
            ctx.logit_scale,
            ctx.lse_square_scale,
            ctx.ignore_index,
            ctx.total_classes,
            ctx.class_start_idx,
            n_cols,  # shapes
            logits.stride(0),  # strides
            dlogits.stride(0),
            grad_losses.stride(0),
            BLOCK_SIZE=BLOCK_SIZE,  # constants
            num_warps=num_warps,
        )
        # NOTE(review): this returns 9 gradients while forward takes 8 inputs
        # after ctx — verify the trailing None count against forward's signature.
        return dlogits, None, None, None, None, None, None, None, None
298
+
299
+
300
def cross_entropy_loss(
    logits: torch.Tensor,
    target: torch.Tensor,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index=-100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Functional entry point for the fused cross-entropy loss.

    Args:
        logits: [batch, vocab_size]
        target: [batch,]
        label_smoothing: label-smoothing factor; 0.0 disables smoothing.
        logit_scale: multiply logits by this scale before computing the loss.
        lse_square_scale: if > 0, add ``lse_square_scale * lse(logits) ** 2``
            to the loss (the so-called "z-loss").
        ignore_index: rows whose target equals this value get loss 0.0.
        inplace_backward: if True, the backward pass overwrites the logits
            with their gradient, which saves memory.
        process_group: if not None, tensor-parallel mode — each rank holds one
            slice of the vocab and the loss is aggregated across ranks.

    Returns:
        losses: [batch,], float
        z_losses: [batch,], float
    """
    return CrossEntropyLossFunction.apply(
        logits,
        target,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignore_index,
        inplace_backward,
        process_group,
    )
342
+
343
+
344
class FusedCrossEntropyLoss(nn.Module):
    """Cross entropy backed by fused Triton kernels, with optional label
    smoothing, logit scaling, z-loss and tensor parallelism."""

    def __init__(
        self,
        ignore_index: int = -100,
        reduction: str = "mean",
        label_smoothing: float = 0.0,
        logit_scale: float = 1.0,
        lse_square_scale: float = 0.0,
        inplace_backward: bool = False,
        process_group: Any = None,
        return_z_loss: bool = False,
    ):
        """
        Arguments:
            ignore_index: int. If target == ignore_index, the loss is set to 0.0.
            reduction: str. One of 'mean', 'sum' or 'none'; 'mean' averages over
                non-ignored targets.
            label_smoothing: float
            logit_scale: float. Multiply logits by this scale before computing the loss.
            lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
                This is also referred to as "z-loss".
            inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
                This saves memory.
            process_group: if not None, we're doing Tensor Parallel: each process is responsible for
                one part of the vocab. The loss will be aggregated across processes.
            return_z_loss: bool. If True, we return the component of the loss contributed by
                the lse_square_scale value. This value is only for logging and does not support
                backprop.
        """
        super().__init__()
        if reduction not in ("mean", "none", "sum"):
            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.logit_scale = logit_scale
        self.lse_square_scale = lse_square_scale
        self.inplace_backward = inplace_backward
        self.process_group = process_group
        self.return_z_loss = return_z_loss

    def _reduce(self, values: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Apply the configured reduction; 'mean' averages over non-ignored targets.
        if self.reduction == "mean":
            return values.sum() / (target != self.ignore_index).sum()
        if self.reduction == "sum":
            return values.sum()
        return values

    def forward(self, input, target):
        """
        Arguments:
            input: (batch, vocab_size)
            target: (batch,)
        Returns:
            losses: (batch,) if reduction is 'none', else (1,), dtype float
            z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
        """
        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
        loss, z_loss = cross_entropy_loss(
            input,
            target,
            label_smoothing=self.label_smoothing,
            logit_scale=self.logit_scale,
            lse_square_scale=self.lse_square_scale,
            ignore_index=self.ignore_index,
            inplace_backward=self.inplace_backward,
            process_group=self.process_group,
        )
        # Shared reduction for both the loss and the (logging-only) z-loss
        # replaces the previous duplicated branches and no-op `else` arms.
        loss = self._reduce(loss, target)
        if not self.return_z_loss:
            return loss
        return loss, self._reduce(z_loss, target)
fla/modules/fused_kl_div.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.ops.utils.op import exp, log
12
+ from fla.utils import input_guard
13
+
14
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576
15
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
16
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
17
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
18
+ MAX_FUSED_SIZE = 65536 // 2
19
+
20
+
21
@triton.jit
def kl_div_kernel(
    logits,
    target_logits,
    loss,
    s_logits,
    s_loss,
    reduction: tl.constexpr,
    N: tl.constexpr,
    V: tl.constexpr,
    BV: tl.constexpr
):
    # Computes per-row KL(teacher || student) over the vocab and, in the same
    # launch, overwrites `logits` with the gradient w.r.t. the student logits.
    # https://github.com/triton-lang/triton/issues/1058
    # If N*V is too large, i_n * stride will overflow out of int32, so we convert to int64
    i_n = tl.program_id(0).to(tl.int64)

    logits += i_n * s_logits
    target_logits += i_n * s_logits

    # m is the max value. use the notation from the paper
    sm = float('-inf')
    tm = float('-inf')
    # d is the sum. use the notation from the paper
    sd, td = 0.0, 0.0

    NV = tl.cdiv(V, BV)
    # First pass: online-softmax statistics (running max + rescaled exp-sum)
    # for both the student and the teacher row.
    for iv in range(0, NV):
        o_x = iv * BV + tl.arange(0, BV)
        # for student
        b_sl = tl.load(logits + o_x, mask=o_x < V, other=float('-inf'))
        b_sm = tl.max(b_sl)
        m_new = tl.maximum(sm, b_sm)
        sd = sd * exp(sm - m_new) + tl.sum(exp(b_sl - m_new))
        sm = m_new
        # for teacher
        b_tl = tl.load(target_logits + o_x, mask=o_x < V, other=float('-inf'))
        b_tm = tl.max(b_tl)
        m_new = tl.maximum(tm, b_tm)
        td = td * exp(tm - m_new) + tl.sum(exp(b_tl - m_new))
        tm = m_new

    b_loss = 0.
    # KL(y_true || y) = exp(y_true) * (log(y_true) - log(y))
    # Second pass: accumulate the loss and store the student-logit gradient
    # softmax(student) - softmax(teacher) in place of the student logits.
    for iv in range(0, NV):
        o_x = iv * BV + tl.arange(0, BV)
        b_sl = tl.load(logits + o_x, mask=o_x < V, other=float('-inf'))
        b_tl = tl.load(target_logits + o_x, mask=o_x < V, other=float('-inf'))
        b_sp_log = b_sl - sm - log(sd)
        b_tp_log = b_tl - tm - log(td)
        b_sp = exp(b_sp_log)
        b_tp = exp(b_tp_log)
        b_kl = tl.where(o_x < V, b_tp * (b_tp_log - b_sp_log), 0)
        b_dl = -b_tp + b_sp
        b_loss += tl.sum(b_kl)
        if reduction == 'batchmean':
            b_dl = b_dl / N
        tl.store(logits + o_x, b_dl, mask=o_x < V)

    # Normalize the loss by the number of rows if reduction is 'batchmean'
    if reduction == 'batchmean':
        b_loss = b_loss / N

    tl.store(loss + i_n * s_loss, b_loss)
84
+
85
+
86
@triton.jit
def elementwise_mul_kernel(
    x,
    g,
    N: tl.constexpr,
    B: tl.constexpr
):
    """
    This function multiplies each element of the tensor pointed by x with the value pointed by g.
    The multiplication is performed in-place on the tensor pointed by x.

    Parameters:
    x:
        Pointer to the input tensor.
    g:
        Pointer to the (scalar) gradient output value.
    N (int):
        The total number of elements in the input tensor.
    B (int):
        The block size for Triton operations.
    """

    # Get the program ID and convert it to int64 to avoid overflow
    i_x = tl.program_id(0).to(tl.int64)
    o_x = i_x * B + tl.arange(0, B)

    # Load the gradient output value
    b_g = tl.load(g)
    b_x = tl.load(x + o_x, mask=o_x < N)
    tl.store(x + o_x, b_x * b_g, mask=o_x < N)
116
+
117
+
118
def fused_kl_div_forward(
    x: torch.Tensor,
    target_x: torch.Tensor,
    weight: torch.Tensor,
    target_weight: torch.Tensor,
    reduction: str = 'batchmean'
):
    """Chunked fused forward: materializes logits chunk by chunk, computes the
    KL loss and — inside the same kernel launch — the gradients w.r.t. ``x``
    and ``weight``, so backward only has to rescale them.

    NOTE(review): assumes ``x``/``target_x`` share a hidden size and
    ``weight``/``target_weight`` share a vocab size — confirm with callers.
    """
    device = x.device

    # ideally, we would like to achieve the same memory consumption as [N, H],
    # so the expected chunk size should be:
    # NC = ceil(V / H)
    # C = ceil(N / NC)
    # for ex: N = 4096*4, V = 32000, H = 4096 ==> NC = 8, C = ceil(N / NC) = 2048
    N, H, V = *x.shape, weight.shape[0]
    BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # TODO: in real cases, we may need to limit the number of chunks NC to
    # ensure the precisions of accumulated gradients
    NC = min(8, triton.cdiv(V, H))
    C = triton.next_power_of_2(triton.cdiv(N, NC))
    NC = triton.cdiv(N, C)

    dx = torch.zeros_like(x, device=device)
    dw = torch.zeros_like(weight, device=device) if weight is not None else None
    # we use fp32 for loss accumulator
    loss = torch.zeros(N, dtype=torch.float32, device=device)

    for ic in range(NC):
        start, end = ic * C, min((ic + 1) * C, N)
        # [C, N]
        c_sx = x[start:end]
        c_tx = target_x[start:end]
        # when doing matmul, use the original precision
        # [C, V]
        c_sl = F.linear(c_sx, weight)
        c_tl = F.linear(c_tx, target_weight)

        # unreduced loss
        c_loss = loss[start:end]

        # Here we calculate the gradient of c_sl in place so we can save memory.
        kl_div_kernel[(c_sx.shape[0],)](
            logits=c_sl,
            target_logits=c_tl,
            loss=c_loss,
            s_logits=c_sl.stride(-2),
            s_loss=c_loss.stride(-1),
            reduction=reduction,
            N=N,
            V=V,
            BV=BV,
            num_warps=32
        )

        # gradient of logits is computed in-place by the above triton kernel and is of shape: C x V
        # thus dx[start: end] should be of shape: C x H
        # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
        # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
        # Thus, we need an additional scaling factor of (n_non_ignore/total) to scale the gradients.
        # [C, H]

        dx[start:end] = torch.mm(c_sl, weight)

        if weight is not None:
            # Accumulate dW += dlogits^T @ x for this chunk.
            torch.addmm(input=dw, mat1=c_sl.t(), mat2=c_sx, out=dw)

    loss = loss.sum()
    return loss, dx, dw
186
+
187
+
188
def fused_kl_div_backward(
    do: torch.Tensor,
    dx: torch.Tensor,
    dw: torch.Tensor
):
    """Finalize the gradients precomputed during forward by scaling them
    in place with the incoming output gradient ``do``."""
    # If cross entropy is the last layer, do is 1.0. Skip the mul to save time
    if torch.ne(do, torch.tensor(1.0, device=do.device)):
        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
        # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
        N, H = dx.shape
        B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))

        elementwise_mul_kernel[(triton.cdiv(N * H, B),)](
            x=dx,
            g=do,
            N=N*H,
            B=B,
            num_warps=32,
        )

        # handle dw
        if dw is not None:
            V, H = dw.shape
            elementwise_mul_kernel[(triton.cdiv(V * H, B),)](
                x=dw,
                g=do,
                N=V*H,
                B=B,
                num_warps=32,
            )

    return dx, dw
220
+
221
+
222
class FusedKLDivLossFunction(torch.autograd.Function):
    """autograd wrapper around the fused KL-divergence helpers.

    The input gradients are materialized during ``forward`` (the Triton kernel
    computes them alongside the loss) and only rescaled in ``backward``.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        target_x: torch.Tensor,
        weight: torch.Tensor,
        target_weight: torch.Tensor,
        reduction: str
    ):
        loss, dx, dw = fused_kl_div_forward(
            x=x,
            target_x=target_x,
            weight=weight,
            target_weight=target_weight,
            reduction=reduction
        )
        # Stash the precomputed gradients for the backward pass.
        ctx.save_for_backward(dx, dw)
        return loss

    @staticmethod
    @input_guard
    def backward(ctx, do):
        dx, dw = fused_kl_div_backward(do, *ctx.saved_tensors)
        # Only x and weight receive gradients; the teacher tensors and the
        # reduction flag are non-differentiable.
        return dx, None, dw, None, None
250
+
251
+
252
def fused_kl_div_loss(
    x: torch.Tensor,
    target_x: torch.Tensor,
    weight: torch.Tensor,
    target_weight: torch.Tensor,
    reduction: str = 'batchmean'
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Functional entry point for the fused KL-divergence loss.

    Args:
        x (torch.Tensor): [batch_size * seq_len, hidden_size]
        target_x (torch.Tensor): [batch_size * seq_len, hidden_size]
        weight (torch.Tensor): [vocab_size, hidden_size]
            where `vocab_size` is the number of classes.
        target_weight (torch.Tensor): [vocab_size, hidden_size]
            where `vocab_size` is the number of classes.
        reduction:
            Specifies the reduction to apply to the output: 'batchmean'. Default: 'batchmean'.
    Returns:
        loss
    """
    # Thin wrapper delegating to the autograd Function.
    return FusedKLDivLossFunction.apply(x, target_x, weight, target_weight, reduction)
279
+
280
+
281
class FusedKLDivLoss(nn.Module):
    """Module wrapper around :func:`fused_kl_div_loss`."""

    def __init__(self, reduction: str = 'batchmean'):
        """
        Args:
            reduction:
                Specifies the reduction to apply to the output: 'batchmean'. Default: 'batchmean'.
        """
        super().__init__()
        # Only 'batchmean' is implemented by the underlying kernel.
        assert reduction in ['batchmean'], f"reduction: {reduction} is not supported"
        self.reduction = reduction

    def forward(
        self,
        x: torch.Tensor,
        target_x: torch.Tensor,
        weight: torch.Tensor,
        target_weight: torch.Tensor
    ):
        """
        Args:
            x (torch.Tensor): [batch_size * seq_len, hidden_size]
            target_x (torch.Tensor): [batch_size * seq_len, hidden_size]
            weight (torch.Tensor): [vocab_size, hidden_size]
                where `vocab_size` is the number of classes.
            target_weight (torch.Tensor): [vocab_size, hidden_size]
                where `vocab_size` is the number of classes.
        Returns:
            loss
        """
        return fused_kl_div_loss(
            x=x,
            target_x=target_x,
            weight=weight,
            target_weight=target_weight,
            reduction=self.reduction
        )
fla/modules/fused_linear_cross_entropy.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Code adapted from
4
+ # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py
5
+
6
+ from functools import partial
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import triton
13
+ import triton.language as tl
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module
16
+ from torch.distributed.tensor.parallel import ParallelStyle
17
+
18
+ from fla.ops.utils import logsumexp_fwd
19
+ from fla.ops.utils.op import exp
20
+ from fla.utils import input_guard
21
+
22
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576
23
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
24
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
25
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
26
+ MAX_FUSED_SIZE = 65536 // 2
27
+
28
+
29
@triton.jit
def cross_entropy_kernel(
    logits,
    lse,
    target,
    loss,
    total,
    ignore_index,
    label_smoothing: tl.constexpr,
    logit_scale: tl.constexpr,
    reduction: tl.constexpr,
    V: tl.constexpr,
    BV: tl.constexpr
):
    """
    This kernel computes both cross entropy loss and the gradient of the input.
    We only consider hard label + mean reduction for now.
    Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.

    Args:
        logits:
            Pointer to logits tensor.
        lse:
            Pointer to logsumexp tensor (precomputed in a separate kernel).
        target: Pointer to target tensor.
        loss:
            Pointer to tensor to store the loss.
        V (int):
            The number of columns in the input tensor.
        total (int):
            The number of non-ignored classes.
        ignore_index (int):
            The index to ignore in the target.
        label_smoothing (float):
            The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        logit_scale (float):
            Scale applied to the logits before the loss/gradient math.
        reduction (str):
            The string for the reduction to apply
        BV (int):
            The block size for vocab.
    """

    # https://github.com/triton-lang/triton/issues/1058
    # If B*T*V is too large, i_n * stride will overflow out of int32, so we convert to int64
    i_n = tl.program_id(0).to(tl.int64)
    NV = tl.cdiv(V, BV)

    # 1. Load target first because if the target is ignore_index, we can return right away
    b_y = tl.load(target + i_n)

    # 2. locate the start index
    logits += i_n * V

    if b_y == ignore_index:
        # set all x as 0 so the ignored row contributes zero gradient
        for i in range(0, V, BV):
            o_v = i + tl.arange(0, BV)
            tl.store(logits + o_v, 0.0, mask=o_v < V)
        return

    # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
    # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

    # 3. [Online softmax] first pass: compute logsumexp
    # we did this in another kernel
    b_l = tl.load(logits + b_y) * logit_scale
    b_lse = tl.load(lse + i_n)

    # 4. Calculate the loss
    # loss = lse - logits_l
    b_loss = b_lse - b_l

    # Label smoothing is a general case of normal cross entropy
    # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
    b_z = 0.0
    eps = label_smoothing / V

    # We need tl.debug_barrier() as mentioned in
    # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
    tl.debug_barrier()

    # 5. [Online Softmax] Second pass: compute gradients
    # For 'mean' reduction, gradients are normalized by number of non-ignored elements
    # dx_y = (softmax(x_y) - 1) / N
    # dx_i = softmax(x_i) / N, i != y
    # For label smoothing:
    # dx_i = (softmax(x_y) - label_smoothing / V) / N, i != y
    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
    # = dx_i - (1 - label_smoothing) / N
    for iv in range(0, NV):
        o_v = iv * BV + tl.arange(0, BV)
        b_logits = tl.load(logits + o_v, mask=o_v < V, other=float('-inf')) * logit_scale
        if label_smoothing > 0:
            # scale X beforehand to avoid overflow
            b_z += tl.sum(tl.where(o_v < V, -eps * b_logits, 0.0))
        b_p = (exp(b_logits - b_lse) - eps) * logit_scale
        if reduction == "mean":
            b_p = b_p / total
        tl.store(logits + o_v, b_p, mask=o_v < V)

    tl.debug_barrier()

    # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
    # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
    # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
    # Refer to H(q', p) in section 7 of the paper:
    # https://arxiv.org/pdf/1512.00567
    # pytorch:
    # https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
    # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
    if label_smoothing > 0:
        b_loss = b_loss * (1 - label_smoothing) + (b_z + label_smoothing * b_lse)

    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
    b_l = tl.load(logits + b_y)

    # Normalize the loss by the number of non-ignored elements if reduction is "mean"
    if reduction == 'mean':
        b_loss = b_loss / total
        b_l += (label_smoothing - 1) / total * logit_scale
    else:
        b_l += (label_smoothing - 1) * logit_scale

    tl.store(loss + i_n, b_loss)
    tl.store(logits + b_y, b_l)
+
156
+
157
@triton.jit
def elementwise_mul_kernel(
    x,
    g,
    N: tl.constexpr,
    B: tl.constexpr
):
    """
    This function multiplies each element of the tensor pointed by x with the value pointed by g.
    The multiplication is performed in-place on the tensor pointed by x.

    Parameters:
    x:
        Pointer to the input tensor.
    g:
        Pointer to the gradient output value (a single scalar that scales every element).
    N (int):
        The total number of elements in the input tensor.
    B (int):
        The block size for Triton operations; each program instance handles one B-sized tile.
    """

    # Get the program ID and convert it to int64 to avoid overflow when i_x * B
    # exceeds the int32 range for large tensors
    i_x = tl.program_id(0).to(tl.int64)
    o_x = i_x * B + tl.arange(0, B)

    # Load the scalar gradient once, then scale this tile in place
    b_g = tl.load(g)
    b_x = tl.load(x + o_x, mask=o_x < N)
    tl.store(x + o_x, b_x * b_g, mask=o_x < N)
187
+
188
+
189
def fused_linear_cross_entropy_forward(
    x: torch.Tensor,
    target: torch.LongTensor,
    weight: torch.Tensor,
    bias: torch.Tensor = None,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    num_chunks: int = 8,
    reduction: str = "mean"
):
    """Chunked forward pass of the fused linear + cross-entropy loss.

    Computes the reduced loss and, as a side effect of the same pass, the
    gradients w.r.t. ``x`` (``dx``), ``weight`` (``dw``) and ``bias`` (``db``),
    so the backward pass only has to rescale them by the incoming gradient.

    Returns:
        (loss, dx, dw, db): scalar loss plus precomputed gradients
        (``dw``/``db`` are ``None`` when ``weight``/``bias`` are ``None``).
    """
    device = x.device
    # inputs have shape: [N, H]
    # materialized activations will have shape: [N, V]
    # the increase in memory = [N, V]
    # reduction can be achieved by partitioning the number of tokens N into smaller chunks.

    # ideally, we would like to achieve the same memory consumption as [N, H],
    # so the expected chunk size should be:
    # NC = ceil(V / H)
    # C = ceil(N / NC)
    # for ex: N = 4096*4, V = 32000, H = 4096 ==> NC = 8, C = ceil(N / NC) = 2048
    N, H, V = *x.shape, weight.shape[0]
    BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # TODO: in real cases, we may need to limit the number of chunks NC to
    # ensure the precisions of accumulated gradients
    NC = min(num_chunks, triton.cdiv(V, H))
    C = triton.next_power_of_2(triton.cdiv(N, NC))
    NC = triton.cdiv(N, C)

    # [N, H]
    dx = torch.zeros_like(x, device=device)
    # [V, H]
    dw = torch.zeros_like(weight, device=device, dtype=torch.float) if weight is not None else None
    # [V]
    db = torch.zeros_like(bias, device=device, dtype=torch.float) if bias is not None else None
    # [N]
    loss = torch.zeros(N, device=device, dtype=torch.float)

    # number of non-ignored tokens; used by the kernel to normalize under 'mean'
    # NOTE(review): if every target equals ignore_index, total == 0 and the
    # kernel divides by zero under 'mean' reduction — confirm callers avoid this
    total = target.ne(ignore_index).sum().item()

    for ic in range(NC):
        start, end = ic * C, min((ic + 1) * C, N)
        # [C, N]
        c_x = x[start:end]
        # when doing matmul, use the original precision
        # [C, V]
        c_logits = F.linear(c_x, weight, bias)
        c_target = target[start:end]
        # [C]
        # keep lse in fp32 to maintain precision
        c_lse = logsumexp_fwd(c_logits, scale=logit_scale, dtype=torch.float)

        # unreduced loss
        c_loss = loss[start:end]

        # Here we calculate the gradient of c_logits in place so we can save memory.
        cross_entropy_kernel[(c_logits.shape[0],)](
            logits=c_logits,
            lse=c_lse,
            target=c_target,
            loss=c_loss,
            total=total,
            ignore_index=ignore_index,
            label_smoothing=label_smoothing,
            logit_scale=logit_scale,
            reduction=reduction,
            V=V,
            BV=BV,
            num_warps=32
        )

        # gradient of logits is computed in-place by the above triton kernel and is of shape: C x V
        # thus dx should be of shape: C x H
        dx[start:end] = torch.mm(c_logits, weight)

        # keep dw in fp32 to maintain precision
        if weight is not None:
            dw += c_logits.t() @ c_x

        if bias is not None:
            torch.add(input=db, other=c_logits.sum(0), out=db)

    loss = loss.sum()
    # downcast the fp32 accumulators back to the parameter dtypes
    if dw is not None:
        dw = dw.to(weight)
    if db is not None:
        db = db.to(bias)
    return loss, dx, dw, db
278
+
279
+
280
def fused_linear_cross_entropy_backward(
    do: torch.Tensor,
    dx: torch.Tensor,
    dw: torch.Tensor,
    db: torch.Tensor
):
    """Rescale the gradients precomputed at forward time by the upstream grad.

    Args:
        do: Scalar gradient flowing into the loss output.
        dx: Precomputed input gradient, shape [N, H]; scaled in place.
        dw: Precomputed weight gradient, shape [V, H], or ``None``.
        db: Precomputed bias gradient, shape [V], or ``None``.
    Returns:
        The (possibly rescaled) ``(dx, dw, db)``.
    """
    # If cross entropy is the last layer, do is 1.0. Skip the mul to save time
    if torch.ne(do, torch.tensor(1.0, device=do.device)):
        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
        # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
        N, H = dx.shape
        B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))

        elementwise_mul_kernel[(triton.cdiv(N * H, B),)](
            x=dx,
            g=do,
            N=N*H,
            B=B,
            num_warps=32,
        )

        # handle dw
        if dw is not None:
            V, H = dw.shape
            elementwise_mul_kernel[(triton.cdiv(V * H, B),)](
                x=dw,
                g=do,
                N=V*H,
                B=B,
                num_warps=32,
            )

        if db is not None:
            V = db.shape[0]
            elementwise_mul_kernel[(triton.cdiv(V, B),)](
                x=db,
                g=do,
                N=V,
                B=B,
                num_warps=32,
            )
    return dx, dw, db
322
+
323
+
324
class FusedLinearCrossEntropyFunction(torch.autograd.Function):
    """Autograd wrapper around the chunked fused linear + cross-entropy kernels.

    Gradients w.r.t. ``x``, ``weight`` and ``bias`` are computed eagerly inside
    :meth:`forward` (where the logits are already materialized chunk by chunk)
    and only rescaled by the upstream gradient in :meth:`backward`.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        target: torch.LongTensor,
        weight: torch.Tensor,
        bias: torch.Tensor = None,
        ignore_index: int = -100,
        label_smoothing: float = 0.0,
        logit_scale: float = 1.0,
        num_chunks: int = 8,
        reduction: str = "mean"
    ):
        """
        Fusing the last linear layer with cross-entropy loss
        Reference: https://github.com/mgmalek/efficient_cross_entropy

        Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
        the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
        compute the gradient at the forward pass. By doing so, we don't have to store the x and target
        for the backward pass.

        x (torch.Tensor): [batch_size * seq_len, hidden_size]
        target (torch.LongTensor): [batch_size * seq_len]
            where each value is in [0, vocab_size).
        weight (torch.Tensor): [vocab_size, hidden_size]
            where `vocab_size` is the number of classes.
        bias (Optional[torch.Tensor]): [vocab_size]
            where `vocab_size` is the number of classes.
        ignore_index:
            the index to ignore in the target.
        label_smoothing:
            the amount of smoothing when computing the loss, where 0.0 means no smoothing.
        logit_scale: float = 1.0,
            A scaling factor applied to the logits. Default: 1.0
        num_chunks: int
            The number of chunks to split the input tensor into for processing.
            This can help optimize memory usage and computation speed.
            Default: 8
        reduction:
            Specifies the reduction to apply to the output: 'mean' | 'sum'.
            'mean': the weighted mean of the output is taken,
            'sum': the output will be summed.
            Default: 'mean'.
        """
        loss, dx, dw, db = fused_linear_cross_entropy_forward(
            x,
            target,
            weight,
            bias,
            ignore_index,
            label_smoothing,
            logit_scale,
            num_chunks,
            reduction
        )
        # downcast to dtype and store for backward
        ctx.save_for_backward(
            dx.detach(),
            dw.detach() if weight is not None else None,
            db.detach() if bias is not None else None,
        )
        return loss

    @staticmethod
    @input_guard
    def backward(ctx, do):
        dx, dw, db = ctx.saved_tensors
        dx, dw, db = fused_linear_cross_entropy_backward(do, dx, dw, db)
        # One gradient slot per forward argument; non-differentiable args get None
        return dx, None, dw, db, None, None, None, None, None
397
+
398
+
399
def fused_linear_cross_entropy_loss(
    x: torch.Tensor,
    target: torch.LongTensor,
    weight: torch.Tensor,
    bias: torch.Tensor = None,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    num_chunks: int = 8,
    reduction: str = "mean"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Functional entry point for the fused linear + cross-entropy loss.

    Thin wrapper that forwards every argument, positionally and unchanged, to
    :class:`FusedLinearCrossEntropyFunction`.

    Args:
        x (torch.Tensor): [batch_size * seq_len, hidden_size]
        target (torch.LongTensor): [batch_size * seq_len], values in [0, vocab_size).
        weight (torch.Tensor): [vocab_size, hidden_size] classifier weight.
        bias (Optional[torch.Tensor]): [vocab_size] classifier bias.
        ignore_index: targets equal to this index contribute zero loss.
        label_smoothing: smoothing factor; 0.0 disables smoothing.
        logit_scale: scaling factor applied to the logits. Default: 1.0.
        num_chunks: number of chunks the tokens are split into (memory/speed
            trade-off). Default: 8.
        reduction: 'mean' (weighted mean) or 'sum'. Default: 'mean'.

    Returns:
        The reduced scalar loss.
    """
    args = (
        x, target, weight, bias,
        ignore_index, label_smoothing, logit_scale, num_chunks, reduction,
    )
    return FusedLinearCrossEntropyFunction.apply(*args)
447
+
448
+
449
+ class FusedLinearCrossEntropyLoss(nn.Module):
450
+
451
+ def __init__(
452
+ self,
453
+ ignore_index: int = -100,
454
+ label_smoothing: float = 0.0,
455
+ logit_scale: float = 1.0,
456
+ num_chunks: int = 8,
457
+ reduction: str = "mean"
458
+ ):
459
+ """
460
+ Args:
461
+ ignore_index: int.
462
+ If target == ignore_index, the loss is set to 0.0.
463
+ label_smoothing: float
464
+ logit_scale: float
465
+ A scaling factor applied to the logits. Default: 1.0
466
+ num_chunks: int
467
+ The number of chunks to split the input tensor into for processing.
468
+ This can help optimize memory usage and computation speed.
469
+ Default: 8
470
+ reduction:
471
+ Specifies the reduction to apply to the output: 'mean' | 'sum'.
472
+ 'mean': the weighted mean of the output is taken,
473
+ 'sum': the output will be summed.
474
+ Default: 'mean'.
475
+ """
476
+ super().__init__()
477
+
478
+ assert reduction in ["mean", "sum"], f"reduction: {reduction} is not supported"
479
+
480
+ self.ignore_index = ignore_index
481
+ self.label_smoothing = label_smoothing
482
+ self.logit_scale = logit_scale
483
+ self.num_chunks = num_chunks
484
+ self.reduction = reduction
485
+
486
+ @torch.compiler.disable
487
+ def forward(
488
+ self,
489
+ x: torch.Tensor,
490
+ target: torch.LongTensor,
491
+ weight: torch.Tensor,
492
+ bias: Optional[torch.Tensor] = None
493
+ ):
494
+ """
495
+ Args:
496
+ x (torch.Tensor): [batch_size, seq_len, hidden_size]
497
+ target (torch.LongTensor): [batch_size, seq_len]
498
+ where each value is in [0, V).
499
+ weight (torch.Tensor): [vocab_size, hidden_size]
500
+ where `vocab_size` is the number of classes.
501
+ bias (Optional[torch.Tensor]): [vocab_size]
502
+ where `vocab_size` is the number of classes.
503
+ Returns:
504
+ loss
505
+ """
506
+ loss = fused_linear_cross_entropy_loss(
507
+ x.view(-1, x.shape[-1]),
508
+ target.view(-1),
509
+ weight=weight,
510
+ bias=bias,
511
+ ignore_index=self.ignore_index,
512
+ label_smoothing=self.label_smoothing,
513
+ logit_scale=self.logit_scale,
514
+ num_chunks=self.num_chunks,
515
+ reduction=self.reduction
516
+ )
517
+ return loss
518
+
519
+
520
class LinearLossParallel(ParallelStyle):
    """Sequence-parallel ``ParallelStyle`` for the fused linear-loss modules.

    Shards the hidden states and targets along the sequence dimension while
    replicating the classifier weight/bias across the device mesh, then hands
    the module plain local tensors.
    """

    def __init__(
        self,
        *,
        sequence_dim: int = 1,
        use_local_output: bool = False,
    ):
        super().__init__()

        # placement describing sequence-dim sharding of activations/targets
        self.sequence_sharding = (Shard(sequence_dim),)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
        # inputs arrive as (x, target, weight, bias); each is wrapped into a
        # DTensor (if needed), redistributed to its expected placement, and
        # unwrapped back to a local tensor for the module
        x, target, weight, bias = inputs

        if not isinstance(x, DTensor):
            # assume the input passed in already sharded on the sequence dim and create the DTensor
            x = DTensor.from_local(x, device_mesh, sequence_sharding)
        if x.placements != sequence_sharding:
            x = x.redistribute(placements=sequence_sharding, async_op=True)
        if not isinstance(target, DTensor):
            target = DTensor.from_local(target, device_mesh, [Replicate()])
        if target.placements != sequence_sharding:
            # targets follow the same sequence sharding as the activations
            target = target.redistribute(placements=sequence_sharding, async_op=True)

        if not isinstance(weight, DTensor):
            weight = DTensor.from_local(weight, device_mesh, [Replicate()])
        if weight.placements != [Replicate()]:
            # we replicate the weight/bias in FLCE
            weight = weight.redistribute(placements=[Replicate()], async_op=True)

        if bias is not None and not isinstance(bias, DTensor):
            bias = DTensor.from_local(bias, device_mesh, [Replicate()])
        if bias is not None and bias.placements != [Replicate()]:
            bias = bias.redistribute(placements=[Replicate()], async_op=True)

        return x.to_local(), target.to_local(), weight.to_local(), bias.to_local() if bias is not None else bias

    @staticmethod
    def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
        # optionally unwrap the DTensor loss back to a plain local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=None,
            input_fn=partial(self._prepare_input_fn, self.sequence_sharding),
            output_fn=partial(self._prepare_output_fn, self.use_local_output)
        )
fla/modules/fused_linear_listnet_loss.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Code adapted from
4
+ # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py
5
+
6
+ from functools import partial
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import triton
13
+ import triton.language as tl
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module
16
+ from torch.distributed.tensor.parallel import ParallelStyle
17
+
18
+ from fla.ops.utils import logsumexp_fwd
19
+ from fla.ops.utils.op import exp
20
+ from fla.utils import input_guard
21
+
22
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576
23
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
24
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
25
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
26
+ MAX_FUSED_SIZE = 65536 // 2
27
+
28
@triton.jit
def listnet_kernel(
    logits,        # [N, V] predictions; overwritten in place with d(loss)/d(logits)
    targets,       # [N, V] unnormalized target scores (full distributions, not class indices)
    lse_logits,    # [N] logsumexp of logit_scale * logits, precomputed by the caller
    lse_targets,   # [N] logsumexp of targets; inf marks rows to be masked out
    loss,          # [N] output buffer for the per-row loss
    total,         # normalization constant for 'mean' reduction
    ignore_index,  # accepted for interface parity; not referenced in this kernel
    logit_scale: tl.constexpr,
    reduction: tl.constexpr,
    V: tl.constexpr,
    BV: tl.constexpr
):
    """Per-row ListNet loss with in-place gradient computation.

    loss_i = -sum_v softmax(targets_i)_v * log_softmax(logit_scale * logits_i)_v
    grad   = softmax(pred) - softmax(target), divided by ``total`` under 'mean'.
    Rows whose target logsumexp is inf (flagged by the caller's ``nan_to_num``)
    get zero gradients.
    """
    # int64 so that i_n * V cannot overflow int32 for large N * V
    i_n = tl.program_id(0).to(tl.int64)
    NV = tl.cdiv(V, BV)

    # Pointers to current token's data
    logits_ptr = logits + i_n * V
    targets_ptr = targets + i_n * V
    loss_ptr = loss + i_n

    # Row-wise normalizers (precomputed logsumexp values)
    b_lse_logits = tl.load(lse_logits + i_n)
    b_lse_targets = tl.load(lse_targets + i_n)
    b_loss = 0.0

    # Compute gradient: softmax(pred) - softmax(target), one BV-sized tile at a time
    for iv in range(0, NV):
        o_v = iv * BV + tl.arange(0, BV)
        mask = o_v < V

        # Load target scores and turn them into probabilities
        t_val = tl.load(targets_ptr + o_v, mask=mask, other=0.0)
        p_target = tl.exp(t_val - b_lse_targets)

        # Load logits and compute prediction probabilities
        l_val = tl.load(logits_ptr + o_v, mask=mask, other=0.0) * logit_scale
        l_val_minus_lse = l_val - b_lse_logits
        p_pred = tl.exp(l_val_minus_lse)

        # Gradient calculation (overwrites the logits buffer in place)
        grad_val = p_pred - p_target
        if reduction == "mean":
            grad_val = grad_val / total
        # rows flagged with an inf target lse contribute no gradient
        grad_val = tl.where(b_lse_targets == float('inf'), 0.0, grad_val)
        tl.store(logits_ptr + o_v, grad_val, mask=mask)

        # Cross-entropy loss
        # instead of: b_loss -= tl.sum(p_target * tl.log(p_pred), axis=0)
        b_loss -= tl.sum(p_target * l_val_minus_lse, axis=0)

    tl.store(loss_ptr, b_loss)
81
+
82
@triton.jit
def elementwise_mul_kernel(
    x,
    g,
    N: tl.constexpr,
    B: tl.constexpr
):
    """
    This function multiplies each element of the tensor pointed by x with the value pointed by g.
    The multiplication is performed in-place on the tensor pointed by x.

    Parameters:
    x:
        Pointer to the input tensor.
    g:
        Pointer to the gradient output value (a single scalar that scales every element).
    N (int):
        The total number of elements in the input tensor.
    B (int):
        The block size for Triton operations; each program instance handles one B-sized tile.
    """

    # Get the program ID and convert it to int64 to avoid overflow when i_x * B
    # exceeds the int32 range for large tensors
    i_x = tl.program_id(0).to(tl.int64)
    o_x = i_x * B + tl.arange(0, B)

    # Load the scalar gradient once, then scale this tile in place
    b_g = tl.load(g)
    b_x = tl.load(x + o_x, mask=o_x < N)
    tl.store(x + o_x, b_x * b_g, mask=o_x < N)
112
+
113
+
114
def fused_linear_listnet_forward(
    x: torch.Tensor,
    targets: torch.Tensor,  # Float tensor [N, V] of unnormalized target scores
    weight: torch.Tensor,
    bias: torch.Tensor = None,
    ignore_index: int = -100,
    logit_scale: float = 1.0,
    num_chunks: int = 8,
    reduction: str = "mean"
):
    """Chunked forward pass of the fused linear + ListNet loss.

    Materializes logits chunk by chunk, computes the per-row ListNet loss and
    the gradients w.r.t. ``x``, ``weight`` and ``bias`` in the same pass.

    Returns:
        (loss, dx, dw, db): reduced scalar loss plus precomputed gradients.
    """
    N, H, V = *x.shape, weight.shape[0]
    BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # cap the number of chunks so each [C, V] logits block stays roughly the
    # size of the [N, H] input
    NC = min(num_chunks, triton.cdiv(V, H))
    C = triton.next_power_of_2(triton.cdiv(N, NC))
    NC = triton.cdiv(N, C)

    # Initialize outputs (fp32 accumulators for the parameter gradients)
    dx = torch.zeros_like(x)
    dw = torch.zeros_like(weight, dtype=torch.float) if weight is not None else None
    db = torch.zeros_like(bias, dtype=torch.float) if bias is not None else None
    loss = torch.zeros(N, device=x.device, dtype=torch.float)
    # NOTE(review): unlike the cross-entropy variant, ignore_index is not
    # applied here — every token counts toward the normalization
    total = N  # All tokens considered

    for ic in range(NC):
        start, end = ic * C, min((ic + 1) * C, N)
        c_x = x[start:end]
        c_logits = F.linear(c_x, weight, bias)
        c_targets = targets[start:end]
        # keep logsumexp in fp32 for numerical stability
        c_lse_logits = logsumexp_fwd(c_logits, scale=logit_scale, dtype=torch.float)
        # NaN rows (e.g. all -inf target scores) are mapped to inf so the
        # kernel zeroes their gradients
        c_lse_targets = logsumexp_fwd(c_targets, dtype=torch.float).nan_to_num(nan=float("inf"))
        c_loss = loss[start:end]

        # Call ListNet kernel (fills c_loss; overwrites c_logits with its gradient)
        listnet_kernel[(c_logits.shape[0],)](
            logits=c_logits,
            targets=c_targets,  # Full target distributions
            lse_logits=c_lse_logits,
            lse_targets=c_lse_targets,
            loss=c_loss,
            total=total,
            ignore_index=ignore_index,
            logit_scale=logit_scale,
            reduction=reduction,
            V=V,
            BV=BV,
            num_warps=32
        )

        # Backward through linear layer
        dx[start:end] = torch.mm(c_logits, weight)
        if weight is not None:
            dw += c_logits.t() @ c_x
        if bias is not None:
            db += c_logits.sum(0)

    loss = loss.sum()
    if reduction == "mean":
        loss = loss / total

    return loss, dx, dw, db
174
+
175
+
176
def fused_linear_listnet_backward(
    do: torch.Tensor,
    dx: torch.Tensor,
    dw: torch.Tensor,
    db: torch.Tensor
):
    """Rescale the gradients precomputed at forward time by the upstream grad.

    Args:
        do: Scalar gradient flowing into the loss output.
        dx: Precomputed input gradient, shape [N, H]; scaled in place.
        dw: Precomputed weight gradient, shape [V, H], or ``None``.
        db: Precomputed bias gradient, shape [V], or ``None``.
    Returns:
        The (possibly rescaled) ``(dx, dw, db)``.
    """
    # If cross entropy is the last layer, do is 1.0. Skip the mul to save time
    if torch.ne(do, torch.tensor(1.0, device=do.device)):
        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
        # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
        N, H = dx.shape
        B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))

        elementwise_mul_kernel[(triton.cdiv(N * H, B),)](
            x=dx,
            g=do,
            N=N*H,
            B=B,
            num_warps=32,
        )

        # handle dw
        if dw is not None:
            V, H = dw.shape
            elementwise_mul_kernel[(triton.cdiv(V * H, B),)](
                x=dw,
                g=do,
                N=V*H,
                B=B,
                num_warps=32,
            )

        if db is not None:
            V = db.shape[0]
            elementwise_mul_kernel[(triton.cdiv(V, B),)](
                x=db,
                g=do,
                N=V,
                B=B,
                num_warps=32,
            )
    return dx, dw, db
218
+
219
+
220
class FusedLinearListNetFunction(torch.autograd.Function):
    """Autograd wrapper around the chunked fused linear + ListNet kernels.

    Gradients w.r.t. ``x``, ``weight`` and ``bias`` are computed eagerly inside
    :meth:`forward` and only rescaled by the upstream gradient in :meth:`backward`.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        targets: torch.Tensor,  # Float targets: full [N, V] distributions
        weight: torch.Tensor,
        bias: torch.Tensor = None,
        ignore_index: int = -100,
        logit_scale: float = 1.0,
        num_chunks: int = 8,
        reduction: str = "mean"
    ):
        # loss and all parameter gradients come out of the single fused pass
        loss, dx, dw, db = fused_linear_listnet_forward(
            x, targets, weight, bias, ignore_index,
            logit_scale, num_chunks, reduction
        )
        ctx.save_for_backward(dx, dw, db)
        return loss

    @staticmethod
    def backward(ctx, do):
        dx, dw, db = ctx.saved_tensors
        dx, dw, db = fused_linear_listnet_backward(do, dx, dw, db)
        # one gradient slot per forward argument; non-differentiable args get None
        return dx, None, dw, db, None, None, None, None
245
+
246
+
247
def fused_linear_listnet_loss(
    x: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor = None,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    num_chunks: int = 8,
    reduction: str = "mean"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        x (torch.Tensor): [batch_size * seq_len, hidden_size]
        target (torch.Tensor): [batch_size * seq_len, vocab_size]
            unnormalized target scores; they are softmax-normalized into a
            distribution inside the kernel (NOT integer class indices).
        weight (torch.Tensor): [vocab_size, hidden_size]
            where `vocab_size` is the number of classes.
        bias (Optional[torch.Tensor]): [vocab_size]
            where `vocab_size` is the number of classes.
        ignore_index: int.
            Accepted for interface parity with the cross-entropy variant;
            NOTE(review): it is forwarded but unused by the ListNet kernel.
        label_smoothing: float
            Accepted for interface parity but NOT forwarded to the underlying
            autograd function — it has no effect here.
        logit_scale: float
            A scaling factor applied to the logits. Default: 1.0
        num_chunks: int
            The number of chunks to split the input tensor into for processing.
            This can help optimize memory usage and computation speed.
            Default: 8
        reduction:
            Specifies the reduction to apply to the output: 'mean' | 'sum'.
            'mean': the weighted mean of the output is taken,
            'sum': the output will be summed.
            Default: 'mean'.
    Returns:
        The reduced scalar loss.
    """
    return FusedLinearListNetFunction.apply(
        x,
        target,
        weight,
        bias,
        ignore_index,
        logit_scale,
        num_chunks,
        reduction
    )
294
+
295
+
296
+ class FusedLinearListNetLoss(nn.Module):
297
+
298
+ def __init__(
299
+ self,
300
+ ignore_index: int = -100,
301
+ label_smoothing: float = 0.0,
302
+ logit_scale: float = 1.0,
303
+ num_chunks: int = 8,
304
+ reduction: str = "mean"
305
+ ):
306
+ """
307
+ Args:
308
+ ignore_index: int.
309
+ If target == ignore_index, the loss is set to 0.0.
310
+ label_smoothing: float
311
+ logit_scale: float
312
+ A scaling factor applied to the logits. Default: 1.0
313
+ num_chunks: int
314
+ The number of chunks to split the input tensor into for processing.
315
+ This can help optimize memory usage and computation speed.
316
+ Default: 8
317
+ reduction:
318
+ Specifies the reduction to apply to the output: 'mean' | 'sum'.
319
+ 'mean': the weighted mean of the output is taken,
320
+ 'sum': the output will be summed.
321
+ Default: 'mean'.
322
+ """
323
+ super().__init__()
324
+
325
+ assert reduction in ["mean", "sum"], f"reduction: {reduction} is not supported"
326
+
327
+ self.ignore_index = ignore_index
328
+ self.label_smoothing = label_smoothing
329
+ self.logit_scale = logit_scale
330
+ self.num_chunks = num_chunks
331
+ self.reduction = reduction
332
+
333
+ @torch.compiler.disable
334
+ def forward(
335
+ self,
336
+ x: torch.Tensor,
337
+ target: torch.LongTensor,
338
+ weight: torch.Tensor,
339
+ bias: Optional[torch.Tensor] = None
340
+ ):
341
+ """
342
+ Args:
343
+ x (torch.Tensor): [batch_size, seq_len, hidden_size]
344
+ target (torch.LongTensor): [batch_size, seq_len]
345
+ where each value is in [0, V).
346
+ weight (torch.Tensor): [vocab_size, hidden_size]
347
+ where `vocab_size` is the number of classes.
348
+ bias (Optional[torch.Tensor]): [vocab_size]
349
+ where `vocab_size` is the number of classes.
350
+ Returns:
351
+ loss
352
+ """
353
+ loss = fused_linear_listnet_loss(
354
+ x.view(-1, x.shape[-1]),
355
+ target.view(-1, target.shape[-1]),
356
+ weight=weight,
357
+ bias=bias,
358
+ ignore_index=self.ignore_index,
359
+ label_smoothing=self.label_smoothing,
360
+ logit_scale=self.logit_scale,
361
+ num_chunks=self.num_chunks,
362
+ reduction=self.reduction
363
+ )
364
+ return loss
365
+
366
+
367
class LinearLossParallel(ParallelStyle):
    """Sequence-parallel ``ParallelStyle`` for the fused linear-loss modules.

    Shards the hidden states and targets along the sequence dimension while
    replicating the classifier weight/bias across the device mesh, then hands
    the module plain local tensors.
    """

    def __init__(
        self,
        *,
        sequence_dim: int = 1,
        use_local_output: bool = False,
    ):
        super().__init__()

        # placement describing sequence-dim sharding of activations/targets
        self.sequence_sharding = (Shard(sequence_dim),)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
        # inputs arrive as (x, target, weight, bias); each is wrapped into a
        # DTensor (if needed), redistributed to its expected placement, and
        # unwrapped back to a local tensor for the module
        x, target, weight, bias = inputs

        if not isinstance(x, DTensor):
            # assume the input passed in already sharded on the sequence dim and create the DTensor
            x = DTensor.from_local(x, device_mesh, sequence_sharding)
        if x.placements != sequence_sharding:
            x = x.redistribute(placements=sequence_sharding, async_op=True)
        if not isinstance(target, DTensor):
            target = DTensor.from_local(target, device_mesh, [Replicate()])
        if target.placements != sequence_sharding:
            # targets follow the same sequence sharding as the activations
            target = target.redistribute(placements=sequence_sharding, async_op=True)

        if not isinstance(weight, DTensor):
            weight = DTensor.from_local(weight, device_mesh, [Replicate()])
        if weight.placements != [Replicate()]:
            # we replicate the weight/bias in FLCE
            weight = weight.redistribute(placements=[Replicate()], async_op=True)

        if bias is not None and not isinstance(bias, DTensor):
            bias = DTensor.from_local(bias, device_mesh, [Replicate()])
        if bias is not None and bias.placements != [Replicate()]:
            bias = bias.redistribute(placements=[Replicate()], async_op=True)

        return x.to_local(), target.to_local(), weight.to_local(), bias.to_local() if bias is not None else bias

    @staticmethod
    def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
        # optionally unwrap the DTensor loss back to a plain local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=None,
            input_fn=partial(self._prepare_input_fn, self.sequence_sharding),
            output_fn=partial(self._prepare_output_fn, self.use_local_output)
        )
418
+
419
def list_net_loss(y_pred, y_true):
    """
    Reference (non-fused) ListNet loss from
    "Learning to Rank: From Pairwise Approach to Listwise Approach".

    Args:
        y_pred: predictions from the model, shape [*, slate_length]
        y_true: ground truth labels, shape [*, slate_length]
    Returns:
        loss value, a scalar torch.Tensor
    """
    # Target distribution over each slate; NaN rows (e.g. all -inf labels)
    # contribute zero via nan_to_num.
    target_probs = F.softmax(y_true, dim=-1).nan_to_num(nan=0)
    log_pred = F.log_softmax(y_pred, dim=-1)
    # Per-slate cross entropy between the two distributions, averaged over slates.
    per_slate = -(target_probs * log_pred).sum(dim=-1)
    return per_slate.mean()
fla/modules/fused_norm_gate.py ADDED
@@ -0,0 +1,995 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import triton
13
+ import triton.language as tl
14
+
15
+ from fla.utils import get_multiprocessor_count, input_guard
16
+
17
+
18
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['N', 'HAS_RESIDUAL', 'STORE_RESIDUAL_OUT', 'IS_RMS_NORM', 'HAS_BIAS'],
)
@triton.jit
def layer_norm_gated_fwd_kernel(
    X,  # pointer to the input
    G,  # pointer to the gate
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    RESIDUAL_OUT,  # pointer to the residual
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    ACTIVATION: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr
):
    # Fused gated layer-norm forward: one program handles one row of X.
    # All pointers are advanced by `row * N`, i.e. inputs are assumed to be
    # contiguous 2D [M, N] buffers.
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * N
    Y += row * N
    G += row * N
    if HAS_RESIDUAL:
        RESIDUAL += row * N
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * N
    # Compute mean and variance (in fp32 for numerical stability).
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_RESIDUAL:
        # The residual is added *before* normalization (pre-norm residual).
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        # Persist x + residual so the host can reuse it (e.g. for prenorm).
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        # RMSNorm: no mean subtraction, second moment only.
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply the optional elementwise affine transformation.
    mask = cols < N
    if HAS_WEIGHT:
        w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w if HAS_WEIGHT else x_hat
    if HAS_BIAS:
        y = y + b

    # Swish output gate: 'swish' and 'silu' are aliases (y * g * sigmoid(g));
    # 'sigmoid' multiplies by sigmoid(g) only.
    g = tl.load(G + cols, mask=cols < N, other=0.0).to(tl.float32)
    if ACTIVATION == 'swish':
        y = y * g * tl.sigmoid(g)
    elif ACTIVATION == 'silu':
        y = y * g * tl.sigmoid(g)
    elif ACTIVATION == 'sigmoid':
        y = y * tl.sigmoid(g)

    # Write output
    tl.store(Y + cols, y, mask=mask)
96
+
97
+
98
def layer_norm_gated_fwd(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = 'swish',
    eps: float = 1e-5,
    residual: torch.Tensor = None,
    out_dtype: torch.dtype = None,
    residual_dtype: torch.dtype = None,
    is_rms_norm: bool = False
):
    """Launch the fused gated layer-norm forward kernel on 2D inputs.

    Args:
        x: input of shape ``[M, N]`` (rows are normalized independently).
        g: gate tensor of shape ``[M, N]``.
        weight: optional affine scale of shape ``[N]``, or ``None``.
        bias: optional affine shift of shape ``[N]``, or ``None``.
        activation: one of ``'swish'``/``'silu'``/``'sigmoid'``.
        eps: numerical-stability epsilon added to the variance.
        residual: optional ``[M, N]`` tensor added to ``x`` before normalization.
        out_dtype: dtype of the output; defaults to ``x.dtype``.
        residual_dtype: dtype used to store the pre-norm sum; overridden by
            ``residual.dtype`` when a residual is given.
        is_rms_norm: use RMSNorm (no mean subtraction) instead of LayerNorm.

    Returns:
        ``(y, mean, rstd, residual_out)`` where ``mean`` is ``None`` for
        RMSNorm and ``residual_out`` is the pre-norm sum (``x`` itself when no
        residual/dtype-upcast is involved).
    """
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    if residual is not None:
        assert residual.shape == (M, N)
    if weight is not None:
        assert weight.shape == (N,)
    if bias is not None:
        assert bias.shape == (N,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    # A separate residual_out buffer is needed when a residual is added or the
    # pre-norm sum must be kept in a different (e.g. fp32) dtype.
    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps

    # One program per row; num_warps/num_stages are chosen by autotuning.
    layer_norm_gated_fwd_kernel[(M,)](
        x,
        g,
        y,
        weight,
        bias,
        residual,
        residual_out,
        mean,
        rstd,
        N,
        eps,
        ACTIVATION=activation,
        IS_RMS_NORM=is_rms_norm,
        BLOCK_N=BLOCK_N,
        HAS_RESIDUAL=residual is not None,
        STORE_RESIDUAL_OUT=residual_out is not None,
        HAS_WEIGHT=weight is not None,
        HAS_BIAS=bias is not None,
    )
    # residual_out is None if residual is None and residual_dtype == input_dtype
    return y, mean, rstd, residual_out if residual_out is not None else x
156
+
157
+
158
@triton.heuristics({
    'RECOMPUTE_OUTPUT': lambda args: args["Y"] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['N', 'HAS_DRESIDUAL', 'STORE_DRESIDUAL', 'IS_RMS_NORM', 'HAS_BIAS'],
)
@triton.jit
def layer_norm_gated_bwd_kernel(
    X,  # pointer to the input
    G,  # pointer to the gate
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DG,  # pointer to the gate gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    DRESIDUAL_IN,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    rows_per_program,
    ACTIVATION: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Fused gated layer-norm backward. Each program walks `rows_per_program`
    # consecutive rows and accumulates per-program partial dW/dB, which the
    # host reduces afterwards. The normalized output y is recomputed here
    # (instead of being saved in forward) to save memory.
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * N
    G += row_start * N
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * N
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * N
    DY += row_start * N
    DX += row_start * N
    DG += row_start * N
    if RECOMPUTE_OUTPUT:
        Y += row_start * N
    if HAS_WEIGHT:
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)

    # Last program may own fewer rows than rows_per_program.
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        g = tl.load(G + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)

        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)

        # Recompute the pre-gate normalized output y.
        y = xhat * w if HAS_WEIGHT else xhat
        if HAS_BIAS:
            y = y + b
        if RECOMPUTE_OUTPUT:
            tl.store(Y + cols, y, mask=mask)

        # Gate gradients: d/dg of silu(g) is sigmoid(g) * (1 + g * (1 - sigmoid(g))).
        # After this branch, `dy` is the gradient w.r.t. the pre-gate y.
        sigmoid_g = tl.sigmoid(g)
        if ACTIVATION == 'swish':
            dg = dy * y * (sigmoid_g + g * sigmoid_g * (1 - sigmoid_g))
            dy = dy * g * sigmoid_g
        elif ACTIVATION == 'silu':
            dg = dy * y * (sigmoid_g + g * sigmoid_g * (1 - sigmoid_g))
            dy = dy * g * sigmoid_g
        elif ACTIVATION == 'sigmoid':
            dg = dy * y * sigmoid_g * (1 - sigmoid_g)
            dy = dy * sigmoid_g
        wdy = dy
        if HAS_WEIGHT:
            wdy = dy * w
            dw += dy * xhat
        if HAS_BIAS:
            db += dy
        # Standard layer-norm backward: project out the components of wdy
        # along xhat (and the constant direction for non-RMS norm).
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        tl.store(DX + cols, dx, mask=mask)
        tl.store(DG + cols, dg, mask=mask)

        # Advance all row pointers to the next row.
        X += N
        G += N
        if HAS_DRESIDUAL:
            DRESIDUAL += N
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += N
        if RECOMPUTE_OUTPUT:
            Y += N
        DY += N
        DX += N
        DG += N
    # Store this program's partial weight/bias gradients for host-side reduction.
    if HAS_WEIGHT:
        tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
289
+
290
+
291
def layer_norm_gated_bwd(
    dy: torch.Tensor,
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = 'swish',
    eps: float = 1e-5,
    mean: torch.Tensor = None,
    rstd: torch.Tensor = None,
    dresidual: torch.Tensor = None,
    has_residual: bool = False,
    is_rms_norm: bool = False,
    x_dtype: torch.dtype = None,
    recompute_output: bool = False,
):
    """Launch the fused gated layer-norm backward kernel.

    Args:
        dy: upstream gradient of shape ``[M, N]``.
        x: the (pre-norm) input saved from forward, shape ``[M, N]``.
        g: gate tensor saved from forward, shape ``[M, N]``.
        weight: optional affine scale ``[N]`` or ``None``.
        bias: optional affine shift ``[N]`` or ``None``.
        activation: must match the activation used in forward.
        eps: numerical-stability epsilon (forwarded to the kernel).
        mean: per-row means from forward (``None`` for RMSNorm).
        rstd: per-row reciprocal stds from forward.
        dresidual: optional upstream gradient of the prenorm residual output.
        has_residual: whether forward added a residual input.
        is_rms_norm: must match the forward setting.
        x_dtype: dtype for dx/dg buffers; defaults to ``x``'s dtype.
        recompute_output: also return the recomputed normalized output ``y``.

    Returns:
        ``(dx, dg, dw, db, dresidual_in)`` plus ``y`` when
        ``recompute_output`` is set.
    """
    M, N = x.shape
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.shape == (M, N)
    if weight is not None:
        assert weight.shape == (N,)
    if bias is not None:
        assert bias.shape == (N,)
    # allocate output
    dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)
    dg = torch.empty_like(g) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)
    # A separate dresidual_in buffer is only needed when dx's dtype differs
    # from the residual input's dtype (otherwise dx doubles as dresidual_in).
    dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # One program per SM; each accumulates partial dW/dB over its row slab,
    # which are summed on the host below.
    sm_count = get_multiprocessor_count(x.device.index)
    dw = torch.empty((sm_count, N), dtype=torch.float, device=weight.device) if weight is not None else None
    db = torch.empty((sm_count, N), dtype=torch.float, device=bias.device) if bias is not None else None
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    layer_norm_gated_bwd_kernel[grid](
        x,
        g,
        weight,
        bias,
        y,
        dy,
        dx,
        dg,
        dw,
        db,
        dresidual,
        dresidual_in,
        mean,
        rstd,
        M,
        N,
        eps,
        rows_per_program,
        ACTIVATION=activation,
        IS_RMS_NORM=is_rms_norm,
        BLOCK_N=BLOCK_N,
        HAS_DRESIDUAL=dresidual is not None,
        STORE_DRESIDUAL=dresidual_in is not None,
        HAS_WEIGHT=weight is not None,
        HAS_BIAS=bias is not None,
    )
    # Reduce the per-program partial sums to the final parameter gradients.
    dw = dw.sum(0).to(weight.dtype) if weight is not None else None
    db = db.sum(0).to(bias.dtype) if bias is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype:
        dresidual_in = dx
    return (dx, dg, dw, db, dresidual_in) if not recompute_output else (dx, dg, dw, db, dresidual_in, y)
364
+
365
+
366
class LayerNormGatedFunction(torch.autograd.Function):
    """Autograd wrapper for the fused gated LayerNorm/RMSNorm kernels.

    Forward flattens inputs to 2D, runs `layer_norm_gated_fwd`, and saves the
    pre-norm sum (not the raw input) for backward, where `layer_norm_gated_bwd`
    recomputes what it needs.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        g: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        activation: str,
        residual: Optional[torch.Tensor] = None,
        eps: float = 1e-6,
        prenorm: bool = False,
        residual_in_fp32: bool = False,
        is_rms_norm: bool = False,
    ):
        x_shape_og = x.shape
        g_shape_og = g.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        g = g.reshape(-1, g.shape[-1])
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
        # residual_in_fp32 only matters when no residual tensor is provided.
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float if residual_in_fp32 else None)
        )
        y, mean, rstd, residual_out = layer_norm_gated_fwd(
            x=x,
            g=g,
            weight=weight,
            bias=bias,
            activation=activation,
            eps=eps,
            residual=residual,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm
        )
        # residual_out is the pre-norm sum; backward re-normalizes from it.
        ctx.save_for_backward(residual_out, g, weight, bias, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.g_shape_og = g_shape_og
        ctx.activation = activation
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        return y if not prenorm else (y, residual_out.reshape(x_shape_og))

    @staticmethod
    @input_guard
    def backward(ctx, dy, *args):
        # `x` here is the saved pre-norm sum (residual_out from forward).
        x, g, weight, bias, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        assert dy.shape == x.shape
        if ctx.prenorm:
            # Second output's gradient (w.r.t. the prenorm residual stream).
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dg, dw, db, dresidual_in = layer_norm_gated_bwd(
            dy=dy,
            x=x,
            g=g,
            weight=weight,
            bias=bias,
            activation=ctx.activation,
            eps=ctx.eps,
            mean=mean,
            rstd=rstd,
            dresidual=dresidual,
            has_residual=ctx.has_residual,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        # One gradient (or None) per forward argument, in order.
        return (
            dx.reshape(ctx.x_shape_og),
            dg.reshape(ctx.g_shape_og),
            dw,
            db,
            None,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )
458
+
459
+
460
class LayerNormGatedLinearFunction(torch.autograd.Function):
    """Autograd wrapper fusing gated LayerNorm/RMSNorm with an output linear.

    The normalized, gated activation `y` is NOT saved for backward; it is
    recomputed by the backward kernel to save memory. The gate activation is
    fixed to the kernels' default ('swish') in both passes.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        g: torch.Tensor,
        norm_weight: torch.Tensor,
        norm_bias: torch.Tensor,
        linear_weight: torch.Tensor,
        linear_bias: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        eps: float = 1e-6,
        prenorm: bool = False,
        residual_in_fp32: bool = False,
        is_rms_norm: bool = False,
    ):
        x_shape_og = x.shape
        g_shape_og = g.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        g = g.reshape(-1, g.shape[-1])
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
        # residual_in_fp32 only matters when no residual tensor is provided.
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float if residual_in_fp32 else None)
        )
        y, mean, rstd, residual_out = layer_norm_gated_fwd(
            x=x,
            g=g,
            weight=norm_weight,
            bias=norm_bias,
            eps=eps,
            residual=residual,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm
        )
        y = y.reshape(x_shape_og)
        # Honor autocast for the linear projection.
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(residual_out, g, norm_weight, norm_bias, linear_weight, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.g_shape_og = g_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    @input_guard
    def backward(ctx, dout, *args):
        # `x` here is the saved pre-norm sum (residual_out from forward).
        x, g, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        # Backprop through the linear layer first: dy = dout @ W.
        dy = F.linear(dout, linear_weight.t())
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        # FIX: `layer_norm_gated_bwd` takes `weight`/`bias`, not
        # `norm_weight`/`norm_bias`; the old keywords raised a TypeError.
        dx, dg, dnorm_weight, dnorm_bias, dresidual_in, y = layer_norm_gated_bwd(
            dy=dy,
            x=x,
            g=g,
            weight=norm_weight,
            bias=norm_bias,
            eps=ctx.eps,
            mean=mean,
            rstd=rstd,
            dresidual=dresidual,
            has_residual=ctx.has_residual,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        # dW of the linear layer uses the recomputed gated activation y.
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        # One gradient (or None) per forward argument, in order.
        return (
            dx.reshape(ctx.x_shape_og),
            dg.reshape(ctx.g_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )
561
+
562
+
563
def layer_norm_gated(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = 'swish',
    residual: Optional[torch.Tensor] = None,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    eps: float = 1e-6
):
    """Gated layer normalization (mean-centered variant).

    Functional entry point that dispatches to ``LayerNormGatedFunction`` with
    ``is_rms_norm`` fixed to ``False``.
    """
    args = (x, g, weight, bias, activation, residual, eps, prenorm, residual_in_fp32, False)
    return LayerNormGatedFunction.apply(*args)
586
+
587
+
588
def rms_norm_gated(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = 'swish',
    residual: Optional[torch.Tensor] = None,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    eps: float = 1e-6
):
    """Gated RMS normalization.

    Functional entry point that dispatches to ``LayerNormGatedFunction`` with
    ``is_rms_norm`` fixed to ``True``.
    """
    args = (x, g, weight, bias, activation, residual, eps, prenorm, residual_in_fp32, True)
    return LayerNormGatedFunction.apply(*args)
611
+
612
+
613
def layer_norm_swish_gate_linear(
    x: torch.Tensor,
    g: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    linear_weight: torch.Tensor,
    linear_bias: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    eps: float = 1e-6
):
    """Layer norm + swish gate fused with an output linear projection.

    Dispatches to ``LayerNormGatedLinearFunction`` with ``is_rms_norm`` fixed
    to ``False``.
    """
    args = (
        x, g, norm_weight, norm_bias, linear_weight, linear_bias,
        residual, eps, prenorm, residual_in_fp32, False,
    )
    return LayerNormGatedLinearFunction.apply(*args)
638
+
639
+
640
def rms_norm_swish_gate_linear(
    x,
    g: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    linear_weight: torch.Tensor,
    linear_bias: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    eps: float = 1e-6
):
    """RMS norm + swish gate fused with an output linear projection.

    Dispatches to ``LayerNormGatedLinearFunction`` with ``is_rms_norm`` fixed
    to ``True``.
    """
    args = (
        x, g, norm_weight, norm_bias, linear_weight, linear_bias,
        residual, eps, prenorm, residual_in_fp32, True,
    )
    return LayerNormGatedLinearFunction.apply(*args)
665
+
666
+
667
class FusedLayerNormGated(nn.Module):
    """Layer normalization fused with a gating activation.

    The input is layer-normalized (with optional elementwise affine), then
    multiplied by the gate passed through the configured activation
    ('swish'/'silu'/'sigmoid').
    """

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        bias: bool = False,
        activation: str = 'swish',
        eps: float = 1e-5,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> FusedLayerNormGated:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        self.activation = activation

        if self.activation not in ('swish', 'silu', 'sigmoid'):
            raise ValueError(f"Unsupported activation: {self.activation}")

        # Parameters default to None so state_dict keys exist even without affine.
        self.register_parameter("weight", None)
        self.register_parameter("bias", None)
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
            if bias:
                self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize: weight -> ones, bias (when present) -> zeros."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def __repr__(self) -> str:
        parts = [f"{self.hidden_size}"]
        if not self.elementwise_affine:
            parts.append(f"elementwise_affine={self.elementwise_affine}")
        parts.append(f"eps={self.eps}")
        parts.append(f"activation={self.activation}")
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def forward(
        self,
        x: torch.Tensor,
        g: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        prenorm: bool = False,
        residual_in_fp32: bool = False
    ) -> torch.Tensor:
        """Apply gated layer norm; returns (y, residual_out) when prenorm."""
        return layer_norm_gated(
            x,
            g,
            self.weight,
            self.bias,
            activation=self.activation,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32
        )
733
+
734
+
735
class FusedRMSNormGated(nn.Module):
    """RMS normalization fused with a gating activation.

    Like `FusedLayerNormGated` but without mean subtraction; only a scale
    parameter is supported (no bias).
    """

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        activation: str = 'swish',
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> FusedRMSNormGated:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        self.activation = activation

        if self.activation not in ('swish', 'silu', 'sigmoid'):
            raise ValueError(f"Unsupported activation: {self.activation}")

        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
        # RMSNorm has no bias; register None so `self.bias` is always defined.
        self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize the scale parameter to ones."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)

    def __repr__(self) -> str:
        parts = [f"{self.hidden_size}"]
        if not self.elementwise_affine:
            parts.append(f"elementwise_affine={self.elementwise_affine}")
        parts.append(f"eps={self.eps}")
        parts.append(f"activation={self.activation}")
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def forward(
        self,
        x: torch.Tensor,
        g: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        prenorm: bool = False,
        residual_in_fp32: bool = False
    ) -> torch.Tensor:
        """Apply gated RMS norm; returns (y, residual_out) when prenorm."""
        return rms_norm_gated(
            x,
            g,
            self.weight,
            self.bias,
            activation=self.activation,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32
        )
797
+
798
+
799
class FusedLayerNormSwishGate(FusedLayerNormGated):
    """Backward-compatible alias of `FusedLayerNormGated`.

    Relies on the parent's default activation ('swish'); kept so existing
    call sites using the old class name continue to work.
    """

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        bias: bool = False,
        eps: float = 1e-5,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> FusedLayerNormSwishGate:
        super().__init__(
            hidden_size=hidden_size,
            elementwise_affine=elementwise_affine,
            bias=bias,
            eps=eps,
            device=device,
            dtype=dtype,
        )
818
+
819
+
820
class FusedRMSNormSwishGate(FusedRMSNormGated):
    """Backward-compatible alias of `FusedRMSNormGated`.

    Relies on the parent's default activation ('swish'); kept so existing
    call sites using the old class name continue to work.
    """

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> FusedRMSNormSwishGate:
        super().__init__(
            hidden_size=hidden_size,
            elementwise_affine=elementwise_affine,
            eps=eps,
            device=device,
            dtype=dtype,
        )
837
+
838
+
839
class FusedLayerNormGatedLinear(nn.Module):
    """Layer norm + swish gate fused with an externally-weighted linear layer.

    The linear weight/bias are passed to `forward` rather than owned by this
    module, so one norm instance can serve several projections.
    """

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> FusedLayerNormGatedLinear:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps

        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
        # Norm bias is unused here; register None so `self.bias` is always defined.
        self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize the norm scale to ones."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)

    def __repr__(self) -> str:
        parts = [f"{self.hidden_size}"]
        if not self.elementwise_affine:
            parts.append(f"elementwise_affine={self.elementwise_affine}")
        parts.append(f"eps={self.eps}")
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def forward(
        self,
        x: torch.Tensor,
        g: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        prenorm: bool = False,
        residual_in_fp32: bool = False
    ) -> torch.Tensor:
        """Normalize-gate `x` with `g`, then project with `weight`/`bias`."""
        return layer_norm_swish_gate_linear(
            x,
            g,
            self.weight,
            self.bias,
            weight,
            bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32
        )
898
+
899
+
900
class FusedLayerNormSwishGateLinear(FusedLayerNormGatedLinear):
    """Backward-compatible alias of `FusedLayerNormGatedLinear`.

    Kept so existing call sites using the old class name continue to work.
    """

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> FusedLayerNormSwishGateLinear:
        super().__init__(
            hidden_size=hidden_size,
            elementwise_affine=elementwise_affine,
            eps=eps,
            device=device,
            dtype=dtype,
        )
917
+
918
+
919
class FusedRMSNormGatedLinear(nn.Module):
    """RMS norm + swish gate fused with an externally-weighted linear layer.

    The linear weight/bias are passed to `forward` rather than owned by this
    module, so one norm instance can serve several projections.
    """

    def __init__(
        self,
        hidden_size,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> FusedRMSNormGatedLinear:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps

        # Parameters default to None; RMSNorm only ever uses a scale.
        self.register_parameter("weight", None)
        self.register_parameter("bias", None)
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize the norm scale to ones."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)

    def __repr__(self) -> str:
        parts = [f"{self.hidden_size}"]
        if not self.elementwise_affine:
            parts.append(f"elementwise_affine={self.elementwise_affine}")
        parts.append(f"eps={self.eps}")
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def forward(
        self,
        x: torch.Tensor,
        g: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        prenorm: bool = False,
        residual_in_fp32: bool = False
    ) -> torch.Tensor:
        """Normalize-gate `x` with `g`, then project with `weight`/`bias`."""
        return rms_norm_swish_gate_linear(
            x,
            g,
            self.weight,
            self.bias,
            weight,
            bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32
        )
977
+
978
+
979
class FusedRMSNormSwishGateLinear(FusedRMSNormGatedLinear):
    """Backward-compatible alias of `FusedRMSNormGatedLinear`.

    Kept so existing call sites using the old class name continue to work.
    """

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> FusedRMSNormSwishGateLinear:
        super().__init__(
            hidden_size=hidden_size,
            elementwise_affine=elementwise_affine,
            eps=eps,
            device=device,
            dtype=dtype,
        )
fla/modules/grpo.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py
4
+ """
5
+ # Get the per-token log probabilities for the completions for the model and the reference model
6
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
7
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
8
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
9
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
10
+
11
+ input_ids = input_ids[:, -logits_to_keep:]
12
+ # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
13
+ # See https://github.com/huggingface/trl/issues/2770
14
+ logits = logits[:, -logits_to_keep:]
15
+ return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
16
+
17
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
18
+ if return_outputs:
19
+ raise ValueError("The GRPOTrainer does not support returning outputs")
20
+ # Compute the per-token log probabilities for the model
21
+
22
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
23
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
24
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
25
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
26
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
27
+
28
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
29
+
30
+ # Compute the KL divergence between the model and the reference model
31
+ ref_per_token_logps = inputs["ref_per_token_logps"]
32
+ per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
33
+
34
+ # x - x.detach() allows for preserving gradients from x
35
+ advantages = inputs["advantages"]
36
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
37
+ per_token_loss = -(per_token_loss - self.beta * per_token_kl)
38
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
39
+
40
+ # Log the metrics
41
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
42
+ self._metrics["completion_length"].append(completion_length)
43
+
44
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
45
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
46
+
47
+ return loss
48
+ """
49
+
50
+
51
+ import torch
52
+ import triton
53
+ import triton.language as tl
54
+
55
+ from fla.ops.utils.op import exp, log
56
+ from fla.utils import input_guard
57
+
58
+
59
@triton.autotune(
    [triton.Config({'BLOCK_SIZE': BLOCK_SIZE}, num_warps=NUM_WARPS, num_stages=NUM_STAGES)
     for BLOCK_SIZE in [1024, 2048, 4096, 8192]
     for NUM_WARPS in [8, 16, 32]
     for NUM_STAGES in [1, 2, 4]
     ], key=['B', 'N']
)
@triton.jit
def grpo_fwd_kernel(
    logits_ptr,  # [B, L+1, N] policy logits (raw model output)
    ref_logp_ptr,  # [B, L] reference-model per-token log-probs
    input_ids_ptr,  # [B, start_idx + L] prompt + completion token ids
    advantages_ptr,  # [B] per-sequence advantages
    completion_mask_ptr,  # [B, L] int mask; positions with 0 are skipped
    loss_ptr,  # [B, L] loss output (KL rows live M rows below when save_kl)
    lse_ptr,  # [B, L] log-sum-exp output, reused by the backward kernel
    beta,  # KL penalty weight
    save_kl: tl.constexpr,
    B,
    M,  # B * L: row offset between the loss block and the KL block
    N,  # vocabulary size
    L,  # completion length
    start_idx,  # number of prompt tokens preceding the completion
    BLOCK_SIZE: tl.constexpr
):
    # One program handles one (batch, token) position.
    row_idx = tl.program_id(0)

    off_b = row_idx // L
    # 64-bit offsets: row * vocab_size can overflow int32 for large models.
    N = tl.cast(N, tl.int64)

    loss_ptr += row_idx

    completion_mask_ptr += row_idx
    not_skip = tl.load(completion_mask_ptr).to(tl.int1)
    if not_skip == 1:
        ref_logp_ptr += row_idx
        lse_ptr += row_idx
        advantages_ptr += off_b
        # logits has L+1 rows per batch while row_idx counts only L rows,
        # so add off_b to skip each earlier batch's trailing extra row.
        logits_ptr += N * (row_idx + off_b)
        # off_b*(start_idx+L) + start_idx + l  ==  row_idx + (off_b+1)*start_idx:
        # points at this position's completion token inside input_ids.
        input_ids_ptr += row_idx + (off_b+1) * start_idx
        base_cols = tl.arange(0, BLOCK_SIZE)

        # Online (streaming) log-sum-exp over the vocabulary, one block at a time.
        m_i = -float("inf")
        l_i = 0.0
        for start_n in tl.range(0, N, BLOCK_SIZE):
            cols = start_n + base_cols
            mask = cols < N
            logits = tl.load(logits_ptr+cols, mask=mask, other=-float('inf')).to(tl.float32)
            m_ij = tl.max(logits)
            new_m_i = tl.maximum(m_i, m_ij)
            # Rescale the running sum to the new max before accumulating.
            l_i = l_i * exp(m_i - new_m_i) + tl.sum(exp(logits - new_m_i))
            m_i = new_m_i
        lse = log(l_i) + m_i

        # Log-prob of the actually sampled token: logit - logsumexp.
        idx = tl.load(input_ids_ptr)
        x = tl.load(logits_ptr+idx).to(tl.float32)
        advantage = tl.load(advantages_ptr).to(tl.float32)
        ref_logp = tl.load(ref_logp_ptr)
        logp = x - lse
        diff = ref_logp - logp
        # k3 KL estimator: exp(d) - d - 1 >= 0.
        kl = exp(diff) - diff - 1
        # Per-token loss: beta-weighted KL minus the advantage term.
        loss = kl * beta - advantage

        tl.store(loss_ptr, loss.to(loss_ptr.dtype.element_ty))
        tl.store(lse_ptr, lse.to(lse_ptr.dtype.element_ty))
        if save_kl:
            # The KL block is stored M rows below the loss block.
            tl.store(loss_ptr+M, kl.to(loss_ptr.dtype.element_ty))
    else:
        # store 0
        tl.store(loss_ptr, 0.0)
        if save_kl:
            tl.store(loss_ptr+M, 0.0)
131
+
132
+
133
@triton.autotune(
    [triton.Config({'BLOCK_SIZE': BLOCK_SIZE}, num_warps=NUM_WARPS, num_stages=NUM_STAGES)
     for BLOCK_SIZE in [1024, 2048, 4096, 8192]
     for NUM_WARPS in [8, 16, 32]
     for NUM_STAGES in [1, 2, 4]
     ], key=['B', 'N']
)
@triton.jit
def grpo_bwd_kernel(
    dloss_ptr,  # [B, L] upstream gradient of the per-token loss
    dlogits_ptr,  # [B, L+1, N] output gradient buffer
    logits_ptr,  # [B, L+1, N] saved forward logits
    ref_logp_ptr,  # [B, L] reference-model per-token log-probs
    input_ids_ptr,  # [B, start_idx + L] prompt + completion token ids
    advantages_ptr,  # [B] per-sequence advantages
    completion_mask_ptr,  # [B, L] int mask; masked rows get zero gradient
    lse_ptr,  # [B, L] log-sum-exp saved by the forward kernel
    beta,  # KL penalty weight
    B,
    N,  # vocabulary size
    L,  # completion length
    start_idx,  # number of prompt tokens preceding the completion
    BLOCK_SIZE: tl.constexpr
):

    row_idx = tl.program_id(0)  # B*L
    off_b = row_idx // L

    # 64-bit offsets: row * vocab_size can overflow int32.
    N = tl.cast(N, tl.int64)

    # Same L+1-rows-per-batch layout as the forward kernel.
    dlogits_ptr += N * (row_idx + off_b)
    base_cols = tl.arange(0, BLOCK_SIZE)
    completion_mask_ptr += row_idx
    not_skip = tl.load(completion_mask_ptr).to(tl.int1)

    if not_skip == 1:
        lse_ptr += row_idx
        dloss_ptr += row_idx
        advantages_ptr += off_b
        ref_logp_ptr += row_idx
        logits_ptr += N * (row_idx + off_b)
        input_ids_ptr += row_idx + (off_b+1) * start_idx
        dloss = tl.load(dloss_ptr).to(tl.float32)
        lse = tl.load(lse_ptr).to(tl.float32)
        idx = tl.load(input_ids_ptr)
        x = tl.load(logits_ptr+idx).to(tl.float32)
        advantage = tl.load(advantages_ptr).to(tl.float32)
        ref_logp = tl.load(ref_logp_ptr)
        logp = x - lse

        # d(loss)/d(logp) = beta * (1 - exp(ref_logp - logp)) - advantage,
        # scaled by the upstream gradient.
        dlogp = (beta * (-1.0 * exp(ref_logp - logp) + 1)
                 - advantage) * dloss

        # Softmax Jacobian: dlogits_j = (1[j == idx] - p_j) * dlogp.
        for start_n in tl.range(0, N, BLOCK_SIZE):
            cols = start_n + base_cols
            mask = cols < N
            logits = tl.load(logits_ptr+cols, mask=mask, other=-float('inf')).to(tl.float32)
            probs = exp(logits - lse)
            dlogits = tl.where(cols == idx, 1-probs, -probs) * dlogp

            tl.store(dlogits_ptr+cols, dlogits.to(dlogits_ptr.dtype.element_ty), mask=mask)
    else:
        # Masked position: write an all-zero gradient row.
        dlogits = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
        for start_n in tl.range(0, N, BLOCK_SIZE):
            cols = start_n + base_cols
            mask = cols < N

            tl.store(dlogits_ptr+cols, dlogits.to(dlogits_ptr.dtype.element_ty), mask=mask)
201
+
202
+
203
class GrpoLoss(torch.autograd.Function):
    """Autograd bridge to the fused Triton GRPO kernels.

    ``forward`` returns a ``[B, L]`` per-token loss; when ``save_kl`` is set
    the result is ``[2*B, L]`` with the per-token KL stacked below the loss
    rows.
    """

    # Fixed decorator order: `@staticmethod` must be outermost so `input_guard`
    # wraps the plain function (consistent with L2NormFunction in this
    # package). Wrapping the staticmethod descriptor breaks on Python < 3.10,
    # where staticmethod objects are not callable.
    @staticmethod
    @input_guard
    def forward(ctx, logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl):
        """Launch the fused forward kernel; one program per (batch, token).

        Args:
            logits: [B, L+1, V] raw model output (NOT pre-sliced).
            ref_logp: [B, L] reference-model per-token log-probs.
            input_ids: [B, K+L] prompt + completion ids.
            advantages: [B] per-sequence advantages.
            beta: KL penalty weight.
            completion_mask: optional [B, L] loss mask.
            save_kl: also return the per-token KL.
        """
        ctx.input_shape = logits.shape
        B, L_ADD_1, N = ctx.input_shape
        L = L_ADD_1 - 1
        M = B * L
        # input_ids holds prompt + completion; the kernel reads only the last L tokens.
        input_ids_start_index = input_ids.size(1) - L

        if not save_kl:
            loss = torch.empty(B, L, device=logits.device, dtype=torch.float32)
        else:
            # Rows [0, B) hold the loss, rows [B, 2B) hold the per-token KL.
            loss = torch.empty(B*2, L, device=logits.device, dtype=torch.float32)

        # Per-position log-sum-exp, saved for the backward kernel.
        lse = torch.empty(B, L, device=logits.device, dtype=torch.float32)

        if completion_mask is None:
            completion_mask = torch.ones(B, L, device=logits.device, dtype=torch.int32)
        else:
            # Masked positions are skipped by the kernel; pre-zero their loss.
            loss[:B].masked_fill_(completion_mask.logical_not(), 0.0)

        grpo_fwd_kernel[(M,)](
            logits_ptr=logits,
            ref_logp_ptr=ref_logp,
            input_ids_ptr=input_ids,
            advantages_ptr=advantages,
            completion_mask_ptr=completion_mask,
            loss_ptr=loss,
            lse_ptr=lse,
            beta=beta,
            save_kl=save_kl,
            B=B, M=M, N=N, L=L,
            start_idx=input_ids_start_index,
        )
        ctx.beta = beta
        ctx.save_for_backward(lse, logits, input_ids, advantages, completion_mask)
        # ref_logp needs no gradient, so it is stashed directly on ctx.
        ctx.ref_logp = ref_logp
        return loss

    @staticmethod
    @input_guard
    def backward(ctx, dloss):
        """Launch the fused backward kernel and assemble d(loss)/d(logits)."""
        # The grad of logits comes from two parts, the reward part and the kl part.
        lse, logits, input_ids, advantages, completion_mask = ctx.saved_tensors
        B, L_ADD_1, N = ctx.input_shape
        L = L_ADD_1 - 1
        M = B * L

        input_ids_start_index = input_ids.size(1) - L

        dlogits = torch.empty_like(logits)  # B, L_ADD_1, N

        grpo_bwd_kernel[(M,)](
            dloss_ptr=dloss,
            dlogits_ptr=dlogits,
            logits_ptr=logits,
            ref_logp_ptr=ctx.ref_logp,
            input_ids_ptr=input_ids,
            advantages_ptr=advantages,
            completion_mask_ptr=completion_mask,
            lse_ptr=lse,
            beta=ctx.beta,
            B=B, N=N, L=L,
            start_idx=input_ids_start_index,
        )
        # The last token in the completion is not used in the loss computation
        # and therefore its gradient should be set to 0.
        dlogits[:, -1, :].fill_(0.0)
        # One gradient slot per forward input (7 inputs after ctx).
        return dlogits.view(*ctx.input_shape), None, None, None, None, None, None
274
+
275
+
276
def fused_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False) -> torch.Tensor:
    '''
    Compute the GRPO loss with a fused Triton kernel: no additional memory
    and much faster (~6x on A800) than the eager PyTorch reference.

    Args:
        logits: Tensor, [B, L+1, vocab_size], the raw output of the model; it's not logits[:, :-1]
        ref_logp: Tensor, [B, L], per-token log-probs of the reference model
        input_ids: Tensor, [B, K+L], the prompt_completion_ids: prompt ids followed by output ids
        advantages: Tensor, [B], the advantages of each prompt
        beta: float, the weight of kl loss
        completion_mask: Tensor, loss mask
        save_kl: bool, if true will also return the per-token kl

    Return:
        loss: Tensor, [B, L], the grpo loss (advantage part plus kl part);
        if save_kl, a (loss, kl) pair of [B, L] tensors.

    NOTE: logits(ref_logits) is computed by these steps
        logits_to_keep = completion_ids.size(1)

        def get_per_token_logits(model, input_ids, attention_mask, logits_to_keep):
            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
            logits = model(
                input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
            ).logits
            return logits

        logits = get_per_token_logits(model, prompt_completion_ids, attention_mask, logits_to_keep)
    '''
    out = GrpoLoss.apply(logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl)
    if not save_kl:
        return out
    # Fixed: Tensor.chunk takes `dim`, not the NumPy-style `axis` keyword.
    # Splits the stacked [2B, L] result into (loss, kl).
    return out.chunk(2, dim=0)
309
+
310
+
311
def grpo_loss_torch(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False):
    """Pure-PyTorch reference for the fused GRPO loss.

    Mirrors `fused_grpo_loss`: returns the [B, L] per-token loss, or a
    (loss, kl) pair when ``save_kl`` is set.
    """
    # Drop the last position: its logits predict the token after the completion.
    logits = logits[:, :-1]
    # Align input_ids with the remaining logit positions (last L tokens).
    target_ids = input_ids[:, -logits.size(1):]
    # Per-token log-prob of each target token, gathered in one batched op.
    log_probs = logits.log_softmax(dim=-1)
    per_token_logps = log_probs.gather(dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)

    # k3 KL estimator against the reference policy: exp(d) - d - 1.
    delta = ref_logp - per_token_logps
    per_token_kl = torch.exp(delta) - delta - 1

    # exp(x - x.detach()) == 1 but keeps the gradient path through x.
    per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
    per_token_loss = -(per_token_loss - beta * per_token_kl)
    if completion_mask is not None:
        per_token_loss = per_token_loss * completion_mask
        if save_kl:
            per_token_kl = per_token_kl * completion_mask
    return per_token_loss if not save_kl else (per_token_loss, per_token_kl)
332
+
333
+
334
@torch.compile(fullgraph=True)
def grpo_loss_with_old_logps(
    logps: torch.Tensor,
    ref_logps: torch.Tensor,
    old_logps: torch.Tensor,
    pad_mask: torch.Tensor,
    logits_to_keep: int,
    rewards: torch.Tensor,
    beta: float = 0.2,
    epsilon: float = 0.2
):
    """
    Compute the clipped (PPO-style) GRPO loss from precomputed log-probs.

    Args:
        logps (torch.Tensor): [Batch, Token_length] Log probabilities of the current policy.
        ref_logps (torch.Tensor): [Batch, Token_length] Log probabilities of the reference policy.
        old_logps (torch.Tensor): [Batch, Token_length] Log probabilities of the old policy.
        pad_mask (torch.Tensor): [Batch, Token_length] Padding mask (True for real tokens).
        logits_to_keep (int): Number of logits to keep for masking.
        rewards (torch.Tensor): [Batch] Rewards for each generation.
        beta (float) = 0.2: A hyperparameter for weighting the KL divergence term.
        epsilon (float) = 0.2: A float hyperparameter for clipping the importance weights.

    Returns:
        torch.Tensor: The computed scalar GRPO loss (to be minimized).
    """
    B = logps.shape[0]
    assert B > 1, "Batch * Num generations should be greater than 1"

    # Group-normalize rewards into advantages.
    # NOTE(review): view(-1, B) with B = batch*num_generations treats the whole
    # batch as one group — confirm this matches the intended
    # (num_prompts, num_generations) layout.
    rewards_shaped = rewards.view(-1, B)  # B,num_generations
    advantages = (rewards_shaped - rewards_shaped.mean(dim=1, keepdim=True)) / \
        (rewards_shaped.std(dim=1, keepdim=True) + 1e-8)
    advantages = advantages.view(-1)  # B*num_generations

    # Per-token KL divergence (k3 estimator): exp(d) - d - 1 >= 0.
    per_token_kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1

    # Importance weights pi_theta / pi_theta_old and their clipped version.
    importance_weights = torch.exp(logps - old_logps)
    importance_weights_clipped = torch.clamp(importance_weights, 1 - epsilon, 1 + epsilon)

    # Valid completion positions, combined with the padding mask.
    completion_mask = torch.arange(logits_to_keep, device=logps.device)[None, :] >= 0
    completion_mask = completion_mask & pad_mask  # Ensure matching shape

    # Add an extra dimension to advantages for element-wise multiplication.
    advantages = advantages.unsqueeze(1)

    # Negative clipped surrogate plus KL penalty, masked to valid tokens.
    token_loss = -(torch.min(advantages * importance_weights, advantages *
                             importance_weights_clipped) - beta * per_token_kl) * completion_mask

    # BUG FIX: token_loss already carries the leading minus, i.e. it IS the
    # per-token loss. The previous `-token_loss.sum()` negated it a second
    # time, turning minimization of this value into maximizing the loss.
    loss = token_loss.sum() / completion_mask.sum()

    return loss
fla/modules/l2norm.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import input_guard
11
+
12
+
13
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['N']
)
@triton.jit
def l2norm_fwd_kernel(
    X,  # pointer to the input, viewed as a dense [M, N] row-major matrix
    Y,  # pointer to the output, same layout as X
    N,  # number of columns (feature dimension)
    eps,  # small constant added to the squared norm for stability
    BLOCK_N: tl.constexpr,  # power-of-two block covering the whole row (BLOCK_N >= N)
):
    # One program L2-normalizes one row: y = x / sqrt(sum(x^2) + eps).
    i_m = tl.program_id(0)
    X += i_m * N
    Y += i_m * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
    # Zero out-of-range lanes so they don't contribute to the squared sum.
    xbar = tl.where(mask, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0)
    # Reciprocal norm; eps guards against an all-zero row.
    rstd = 1 / tl.sqrt(var + eps)
    # tl.store(Rstd + i_m, rstd)
    # Normalize and apply linear transformation
    y = x * rstd
    # Write output
    tl.store(Y + cols, y, mask=mask)
43
+
44
+
45
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['N']
)
@triton.jit
def l2norm_bwd_kernel(
    X,  # pointer to the saved forward input, [M, N] row-major
    DY,  # pointer to the upstream gradient, same layout
    DX,  # pointer to the output gradient, same layout
    N,  # number of columns (feature dimension)
    eps,  # must match the eps used in the forward pass
    BLOCK_N: tl.constexpr,  # power-of-two block covering the whole row
):
    # One program handles one row of the gradient.
    i_m = tl.program_id(0)
    X += i_m * N
    DX += i_m * N
    DY += i_m * N

    # Y += i_m * stride_y_row
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
    x = tl.where(mask, x, 0.0)
    # Recompute the squared norm instead of saving it in the forward pass.
    var = tl.sum(x * x)
    rstd = 1 / tl.sqrt(var + eps)
    # tl.store(Rstd + i_m, rstd)
    # Normalize and apply linear transformation
    # y = x * rstd
    dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32)
    dy = tl.where(mask, dy, 0.0)
    # Gradient of y = x / ||x||:
    #   dx = dy * rstd - (dy . x) / (var + eps) * rstd * x
    # i.e. the upstream gradient minus its projection onto x, rescaled.
    dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x
    tl.store(DX + cols, dx, mask=mask)
80
+
81
+
82
def l2norm_fwd(
    x: torch.Tensor,
    eps: float = 1e-6,
    output_dtype: Optional[torch.dtype] = None
):
    """L2-normalize ``x`` along its last dimension with a fused Triton kernel.

    Args:
        x: input tensor of any shape; normalization runs over the last dim.
        eps: small constant added to the squared norm for stability.
        output_dtype: optional dtype for the output (defaults to ``x.dtype``).

    Returns:
        Tensor of the same shape as ``x`` with unit-L2-norm rows.

    Raises:
        RuntimeError: if the feature dimension exceeds 64KB per row.
    """
    x_shape_og = x.shape
    x = x.reshape(-1, x.shape[-1])
    # Robustness fix: the kernel indexes rows as i_m * N, i.e. it assumes a
    # dense, contiguous layout; reshape() can return a non-dense view.
    if x.stride(-1) != 1:
        x = x.contiguous()
    # allocate output
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    N = x.shape[-1]
    M = x.shape[0]
    # rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        # Message fixed: this is the L2-norm kernel, not layer norm.
        raise RuntimeError("This l2 norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    l2norm_fwd_kernel[(M,)](
        x,
        y,
        N,
        eps,
        BLOCK_N,
    )
    return y.reshape(x_shape_og)
112
+
113
+
114
def l2norm_bwd(
    x: torch.Tensor,
    dy: torch.Tensor,
    eps: float = 1e-5
):
    """Backward pass of the fused L2 normalization.

    Args:
        x: the tensor passed to ``l2norm_fwd`` (same shape as ``dy``).
        dy: upstream gradient w.r.t. the normalized output.
        eps: stability constant; should match the forward call.
            NOTE(review): default here is 1e-5 while ``l2norm_fwd`` defaults
            to 1e-6 — callers should pass eps explicitly; confirm intent.

    Returns:
        Gradient w.r.t. ``x`` with the original shape of ``x``.

    Raises:
        RuntimeError: if the feature dimension exceeds 64KB per row.
    """
    x_shape_og = x.shape
    x = x.reshape(-1, dy.shape[-1])
    dy = dy.reshape(-1, dy.shape[-1])
    if dy.stride(-1) != 1:
        dy = dy.contiguous()
    # Robustness fix: the kernel assumes dense rows for x as well, not just dy.
    if x.stride(-1) != 1:
        x = x.contiguous()
    assert dy.shape == x.shape
    # allocate output
    dx = torch.empty_like(x)
    M = x.shape[0]
    N = x.shape[-1]
    # rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        # Message fixed: this is the L2-norm kernel, not layer norm.
        raise RuntimeError("This l2 norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    l2norm_bwd_kernel[(M,)](
        x,
        dy,
        dx,
        N,
        eps,
        BLOCK_N,
    )
    return dx.reshape(x_shape_og)
145
+
146
+
147
class L2NormFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton L2-normalization kernels."""

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x,
        eps=1e-6,
        output_dtype=None
    ):
        """Normalize rows of ``x``; stash ``x`` and ``eps`` for backward."""
        ctx.eps = eps
        ctx.x_dtype = x.dtype
        ctx.save_for_backward(x)
        return l2norm_fwd(x, eps, output_dtype)

    @staticmethod
    @input_guard
    def backward(ctx, dy):
        """Propagate gradients through the row-wise normalization.

        Returns one gradient per forward input: (dx, None, None).
        """
        (x,) = ctx.saved_tensors
        return l2norm_bwd(x, dy, ctx.eps), None, None
169
+
170
+
171
def l2_norm(
    x: torch.Tensor,
    eps: float = 1e-6,
    output_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    """L2-normalize ``x`` along its last dimension via the fused Triton kernel.

    Thin public entry point over :class:`L2NormFunction`.
    """
    normalized = L2NormFunction.apply(x, eps, output_dtype)
    return normalized