zaydzuhri committed on
Commit
7be68db
verified
1 Parent(s): 57a6d54

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. fla/layers/__pycache__/attn.cpython-312.pyc +0 -0
  2. fla/layers/__pycache__/based.cpython-312.pyc +0 -0
  3. fla/layers/__pycache__/delta_net.cpython-312.pyc +0 -0
  4. fla/layers/__pycache__/forgetting_attn.cpython-312.pyc +0 -0
  5. fla/layers/__pycache__/gated_deltanet.cpython-312.pyc +0 -0
  6. fla/layers/__pycache__/gated_deltaproduct.cpython-312.pyc +0 -0
  7. fla/layers/__pycache__/gla.cpython-312.pyc +0 -0
  8. fla/layers/__pycache__/hgrn2.cpython-312.pyc +0 -0
  9. fla/layers/__pycache__/lightnet.cpython-312.pyc +0 -0
  10. fla/layers/__pycache__/nsa.cpython-312.pyc +0 -0
  11. fla/layers/__pycache__/rebased.cpython-312.pyc +0 -0
  12. fla/layers/__pycache__/rwkv6.cpython-312.pyc +0 -0
  13. fla/models/forgetting_transformer/__init__.py +16 -0
  14. fla/models/gated_deltanet/configuration_gated_deltanet.py +83 -0
  15. fla/models/mamba2/modeling_mamba2.py +1093 -0
  16. fla/models/nsa/modeling_nsa.py +398 -0
  17. fla/models/rwkv7/__init__.py +13 -0
  18. fla/models/transformer_top/__init__.py +13 -0
  19. fla/modules/__init__.py +30 -0
  20. fla/modules/convolution.py +434 -0
  21. fla/modules/fused_linear_listnet_loss.py +427 -0
  22. fla/modules/grpo.py +396 -0
  23. fla/modules/layernorm.py +1196 -0
  24. fla/modules/mlp.py +127 -0
  25. fla/ops/__pycache__/__init__.cpython-312.pyc +0 -0
  26. fla/ops/abc/__init__.py +7 -0
  27. fla/ops/abc/__pycache__/chunk.cpython-312.pyc +0 -0
  28. fla/ops/abc/chunk.py +1116 -0
  29. fla/ops/abc/naive.py +96 -0
  30. fla/ops/attn/__init__.py +7 -0
  31. fla/ops/attn/__pycache__/parallel.cpython-312.pyc +0 -0
  32. fla/ops/attn/parallel.py +629 -0
  33. fla/ops/based/__init__.py +9 -0
  34. fla/ops/based/__pycache__/__init__.cpython-312.pyc +0 -0
  35. fla/ops/based/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  36. fla/ops/based/__pycache__/parallel.cpython-312.pyc +0 -0
  37. fla/ops/based/naive.py +72 -0
  38. fla/ops/common/__init__.py +1 -0
  39. fla/ops/common/chunk_delta_h.py +399 -0
  40. fla/ops/common/chunk_h.py +422 -0
  41. fla/ops/common/chunk_h_parallel.py +650 -0
  42. fla/ops/common/chunk_h_split.py +677 -0
  43. fla/ops/common/chunk_o.py +668 -0
  44. fla/ops/common/chunk_scaled_dot_kkt.py +126 -0
  45. fla/ops/common/fused_recurrent.py +575 -0
  46. fla/ops/common/utils.py +69 -0
  47. fla/ops/delta_rule/README.md +90 -0
  48. fla/ops/delta_rule/__init__.py +11 -0
  49. fla/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  50. fla/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
fla/layers/__pycache__/attn.cpython-312.pyc ADDED
Binary file (9.5 kB).

fla/layers/__pycache__/based.cpython-312.pyc ADDED
Binary file (6.46 kB).

fla/layers/__pycache__/delta_net.cpython-312.pyc ADDED
Binary file (12.5 kB).

fla/layers/__pycache__/forgetting_attn.cpython-312.pyc ADDED
Binary file (5.3 kB).

fla/layers/__pycache__/gated_deltanet.cpython-312.pyc ADDED
Binary file (13.4 kB).

fla/layers/__pycache__/gated_deltaproduct.cpython-312.pyc ADDED
Binary file (14.4 kB).

fla/layers/__pycache__/gla.cpython-312.pyc ADDED
Binary file (13.3 kB).

fla/layers/__pycache__/hgrn2.cpython-312.pyc ADDED
Binary file (8.6 kB).

fla/layers/__pycache__/lightnet.cpython-312.pyc ADDED
Binary file (8.85 kB).

fla/layers/__pycache__/nsa.cpython-312.pyc ADDED
Binary file (6.55 kB).

fla/layers/__pycache__/rebased.cpython-312.pyc ADDED
Binary file (6.75 kB).

fla/layers/__pycache__/rwkv6.cpython-312.pyc ADDED
Binary file (14.5 kB).

fla/models/forgetting_transformer/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # -*- coding: utf-8 -*-
+
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+ from fla.models.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig
+ from fla.models.forgetting_transformer.modeling_forgetting_transformer import (
+     ForgettingTransformerForCausalLM,
+     ForgettingTransformerModel
+ )
+
+ AutoConfig.register(ForgettingTransformerConfig.model_type, ForgettingTransformerConfig)
+ AutoModel.register(ForgettingTransformerConfig, ForgettingTransformerModel)
+ AutoModelForCausalLM.register(ForgettingTransformerConfig, ForgettingTransformerForCausalLM)
+
+
+ __all__ = ['ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel']
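
Note: because this __init__ registers the config and model classes with the transformers Auto* factories, the model type becomes usable through the standard Auto API once the package is imported. A minimal sketch (the for_model call and the num_hidden_layers override are illustrative, not part of this commit):

import fla.models.forgetting_transformer  # noqa: F401  (runs the registrations above)
from transformers import AutoConfig, AutoModelForCausalLM

# Build a config for the registered 'forgetting_transformer' model type and
# instantiate a randomly initialized causal LM from it.
config = AutoConfig.for_model('forgetting_transformer', num_hidden_layers=2)
model = AutoModelForCausalLM.from_config(config)
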
fla/models/gated_deltanet/configuration_gated_deltanet.py ADDED
@@ -0,0 +1,83 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Dict, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class GatedDeltaNetConfig(PretrainedConfig):
+     model_type = 'gated_deltanet'
+     keys_to_ignore_at_inference = ['past_key_values']
+
+     def __init__(
+         self,
+         attn_mode: str = "chunk",
+         hidden_size: int = 2048,
+         expand_v: int = 2,
+         use_gate: bool = True,
+         use_short_conv: bool = True,
+         conv_size: int = 4,
+         head_dim: int = 256,
+         num_heads: int = 6,
+         max_position_embeddings: int = 2048,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         hidden_act: str = "swish",
+         num_hidden_layers: int = 21,
+         norm_eps: float = 1e-6,
+         attn: Optional[Dict] = None,
+         use_cache: bool = True,
+         pad_token_id: int = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         initializer_range: float = 0.006,
+         fuse_norm: bool = True,
+         fuse_swiglu: bool = True,
+         fuse_cross_entropy: bool = True,
+         vocab_size: int = 32000,
+         **kwargs
+     ):
+         self.attn_mode = attn_mode
+         self.hidden_size = hidden_size
+         self.expand_v = expand_v
+         self.use_gate = use_gate
+         self.use_short_conv = use_short_conv
+         self.conv_size = conv_size
+         self.head_dim = head_dim
+         self.num_heads = num_heads
+         self.max_position_embeddings = max_position_embeddings
+
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.num_hidden_layers = num_hidden_layers
+         self.norm_eps = norm_eps
+         self.attn = attn
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+
+         self.fuse_norm = fuse_norm
+         self.fuse_swiglu = fuse_swiglu
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.vocab_size = vocab_size
+
+         if attn is not None:
+             if not isinstance(attn, Dict):
+                 raise ValueError("attn must be a dictionary")
+             if 'layers' not in attn:
+                 raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+             if 'num_heads' not in attn:
+                 raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+             attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+             attn['qkv_bias'] = attn.get('qkv_bias', False)
+             attn['window_size'] = attn.get('window_size', None)
+             attn['rope_theta'] = attn.get('rope_theta', 10000.)
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
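
Note: the constructor above fills in defaults for any hybrid-attention keys that are not supplied, so only 'layers' and 'num_heads' are mandatory when mixing softmax-attention layers into the stack. A minimal sketch (the layer indices and head count are illustrative, not taken from this commit):

from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig

# Pure GatedDeltaNet stack using the defaults defined above.
config = GatedDeltaNetConfig()

# Hybrid variant: the listed layers use standard attention; the remaining keys
# (num_kv_heads, qkv_bias, window_size, rope_theta) fall back to the defaults
# filled in by __init__.
hybrid = GatedDeltaNetConfig(attn={'layers': [3, 7, 11], 'num_heads': 8})
assert hybrid.attn['rope_theta'] == 10000.
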
fla/models/mamba2/modeling_mamba2.py ADDED
@@ -0,0 +1,1093 @@
1
+ # Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """PyTorch MAMBA2 model."""
15
+
16
+ import math
17
+ import warnings
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from transformers.activations import ACT2FN
25
+ from transformers.generation import GenerationMixin
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.utils import ModelOutput, logging
28
+ from transformers.utils.deprecation import deprecate_kwarg
29
+
30
+ from fla.models.mamba2.configuration_mamba2 import Mamba2Config
31
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
32
+ from fla.modules.layernorm_gated import RMSNormGated
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ with warnings.catch_warnings():
37
+ warnings.simplefilter('ignore')
38
+ try:
39
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
40
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
41
+ except ImportError:
42
+ (
43
+ selective_state_update,
44
+ mamba_chunk_scan_combined,
45
+ mamba_split_conv1d_scan_combined,
46
+ ) = (None, None, None)
47
+ try:
48
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
49
+ except ImportError:
50
+ causal_conv1d_update, causal_conv1d_fn = None, None
51
+ is_fast_path_available = all((
52
+ selective_state_update,
53
+ causal_conv1d_fn,
54
+ causal_conv1d_update
55
+ ))
56
+
57
+
58
+ def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
59
+ """
60
+ Padding x tensor with `pad_size` on the seq_len dim (dim=1)
61
+
62
+ Assumes that we only have tensors of either size 4 or 3
63
+ """
64
+ pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
65
+
66
+ return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
67
+
68
+
69
+ def reshape_into_chunks(input_tensor, pad_size, chunk_size):
70
+ """
71
+ Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
72
+ simultaneously splitting it into chunk sequences.
73
+
74
+ Assumes that we only have tensors of either size 4 or 3
75
+ """
76
+ # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
77
+ input_tensor = pad_tensor_by_size(input_tensor, pad_size)
78
+
79
+ if len(input_tensor.shape) == 3:
80
+ # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
81
+ return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
82
+ else:
83
+ # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] ->
84
+ # [bsz, -1, chunk_size, num_heads, head_dim or state_size]
85
+ return input_tensor.reshape(
86
+ input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
87
+ )
88
+
89
+
90
+ def segment_sum(input_tensor):
91
+ """
92
+ More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
93
+ """
94
+ chunk_size = input_tensor.size(-1)
95
+ # 1. expand input tensor to have an additional dimension and repeat along that dimension
96
+ # [..., chunk_size] -> [..., chunk_size, chunk_size]
97
+ input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
98
+ # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
99
+ mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
100
+ input_tensor = input_tensor.masked_fill(~mask, 0)
101
+ # 3. compute actual cumsum
102
+ tensor_segsum = torch.cumsum(input_tensor, dim=-2)
103
+
104
+ # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
105
+ mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
106
+ tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
107
+ return tensor_segsum
108
+
109
+
110
+ def apply_mask_to_padding_states(hidden_states, attention_mask):
111
+ """
112
+ Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
113
+ """
114
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
115
+ dtype = hidden_states.dtype
116
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
117
+
118
+ return hidden_states
119
+
120
+
121
+ class Mamba2Cache:
122
+ """
123
+ Arguments:
124
+ config: Mamba2Config
125
+ batch_size: int
126
+ dtype: torch.dtype
127
+ device: torch.device
128
+
129
+ Attributes:
130
+ dtype: (`torch.dtype`):
131
+ The default `dtype` used to initializing the cache.
132
+ conv_kernel_size: (`int`):
133
+ Model's convolution kernel size taken from config.
134
+ n_groups: (`int`):
135
+ Model's number of groups taken from the config - similar to tensor parallel in Transformer.
136
+ state_size: (`int`):
137
+ Model's SSM state size taken from config.
138
+ num_heads: (`int`):
139
+ The number of heads used in the linear attention / SSM.
140
+ head_dim: (`int`):
141
+ The respective dimension of the heads used in the linear attention / SSM.
142
+ intermediate_size: (`int`):
143
+ Model's intermediate_size based on (expand * hidden_dim) from config.
144
+ conv_states: (`torch.Tensor`):
145
+ A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]`
146
+ that holds convolutional states.
147
+ ssm_states: (`torch.Tensor`):
148
+ A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states.
149
+ """
150
+
151
+ def __init__(
152
+ self,
153
+ config: Mamba2Config,
154
+ batch_size: int,
155
+ dtype: torch.dtype = torch.float16,
156
+ device: Optional[str] = None,
157
+ ):
158
+ self.dtype = dtype
159
+ self.conv_kernel_size = config.conv_kernel
160
+ self.n_groups = config.n_groups
161
+ self.state_size = config.state_size
162
+ self.num_heads = config.num_heads
163
+ self.head_dim = config.head_dim
164
+ self.intermediate_size = int(config.expand * config.hidden_size)
165
+
166
+ self.conv_states = torch.zeros(
167
+ config.num_hidden_layers,
168
+ batch_size,
169
+ self.intermediate_size + 2 * self.n_groups * self.state_size,
170
+ self.conv_kernel_size,
171
+ device=device,
172
+ dtype=dtype,
173
+ )
174
+ self.ssm_states = torch.zeros(
175
+ config.num_hidden_layers,
176
+ batch_size,
177
+ self.num_heads,
178
+ self.head_dim,
179
+ self.state_size,
180
+ device=device,
181
+ dtype=dtype,
182
+ )
183
+
184
+ def update_conv_state(
185
+ self,
186
+ layer_idx: int,
187
+ new_conv_state: torch.Tensor,
188
+ cache_init: bool = False
189
+ ) -> torch.Tensor:
190
+ if cache_init:
191
+ self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device)
192
+ else:
193
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
194
+ self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
195
+ return self.conv_states[layer_idx]
196
+
197
+ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
198
+ self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
199
+ return self.ssm_states[layer_idx]
200
+
201
+ def reset(self):
202
+ self.conv_states.zero_()
203
+ self.ssm_states.zero_()
204
+
205
+
206
+ class Mamba2Mixer(nn.Module):
207
+ """
208
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
209
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
210
+ ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
211
+ and is why Mamba is called **selective** state spaces)
212
+ """
213
+
214
+ def __init__(self, config: Mamba2Config, layer_idx: int):
215
+ super().__init__()
216
+ self.num_heads = config.num_heads
217
+ self.hidden_size = config.hidden_size
218
+ self.ssm_state_size = config.state_size
219
+ self.conv_kernel_size = config.conv_kernel
220
+ self.intermediate_size = int(config.expand * self.hidden_size)
221
+ self.time_step_rank = int(config.time_step_rank)
222
+ self.layer_idx = layer_idx
223
+ self.use_conv_bias = config.use_conv_bias
224
+ self.activation = config.hidden_act
225
+ self.act = ACT2FN[config.hidden_act]
226
+
227
+ self.layer_norm_epsilon = config.layer_norm_epsilon
228
+ self.rms_norm = config.rms_norm
229
+
230
+ self.n_groups = config.n_groups
231
+ self.head_dim = config.head_dim
232
+ self.chunk_size = config.chunk_size
233
+
234
+ self.time_step_limit = config.time_step_limit
235
+ self.time_step_min = config.time_step_min
236
+ self.time_step_max = config.time_step_max
237
+
238
+ self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
239
+ self.conv1d = nn.Conv1d(
240
+ in_channels=self.conv_dim,
241
+ out_channels=self.conv_dim,
242
+ bias=config.use_conv_bias,
243
+ kernel_size=config.conv_kernel,
244
+ groups=self.conv_dim,
245
+ padding=config.conv_kernel - 1,
246
+ )
247
+
248
+ # projection of the input hidden states
249
+ projection_size = self.intermediate_size + self.conv_dim + self.num_heads
250
+ self.in_proj = nn.Linear(
251
+ self.hidden_size,
252
+ projection_size,
253
+ bias=config.use_bias,
254
+ )
255
+ # selective projection used to make dt, B and C input dependant
256
+
257
+ # time step projection (discretization)
258
+ # instantiate once and copy inv_dt in init_weights of PretrainedModel
259
+ self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
260
+
261
+ # S4D real initialization. These are not discretized!
262
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
263
+ A = torch.arange(1, self.num_heads + 1)
264
+ self.A_log = nn.Parameter(torch.log(A))
265
+ self.A_log._no_weight_decay = True
266
+ self.norm = RMSNormGated(
267
+ self.intermediate_size, eps=self.layer_norm_epsilon, norm_before_gate=False
268
+ )
269
+ self.D = nn.Parameter(torch.ones(self.num_heads))
270
+ self.D._no_weight_decay = True
271
+
272
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
273
+ self.use_bias = config.use_bias
274
+
275
+ if not is_fast_path_available:
276
+ logger.warning_once(
277
+ "The fast path is not available because one of "
278
+ "`(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. "
279
+ "Falling back to the naive implementation. "
280
+ "To install follow https://github.com/state-spaces/mamba/#installation and"
281
+ "https://github.com/Dao-AILab/causal-conv1d"
282
+ )
283
+
284
+ def cuda_kernels_forward(
285
+ self,
286
+ hidden_states: torch.Tensor,
287
+ cache_params: Optional[Mamba2Cache] = None,
288
+ cache_position: Optional[torch.LongTensor] = None,
289
+ attention_mask: Optional[torch.Tensor] = None,
290
+ ):
291
+ # 1. Gated MLP's linear projection
292
+ hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
293
+ projected_states = self.in_proj(hidden_states)
294
+
295
+ # Set up dimensions for reshapes later
296
+ batch_size, seq_len, _ = hidden_states.shape
297
+ groups_time_state_size = self.n_groups * self.ssm_state_size
298
+ d_mlp = (
299
+ projected_states.shape[-1]
300
+ - 2 * self.intermediate_size
301
+ - 2 * self.n_groups * self.ssm_state_size
302
+ - self.num_heads
303
+ ) // 2
304
+
305
+ # Single step calculations via cache
306
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
307
+ _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
308
+ [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
309
+ )
310
+
311
+ # 2. Convolution sequence transformation
312
+ hidden_states_B_C = causal_conv1d_update(
313
+ hidden_states_B_C,
314
+ cache_params.conv_states[self.layer_idx],
315
+ self.conv1d.weight.squeeze(1),
316
+ self.conv1d.bias,
317
+ self.activation,
318
+ )
319
+
320
+ hidden_states, B, C = torch.split(
321
+ hidden_states_B_C,
322
+ [
323
+ self.intermediate_size,
324
+ groups_time_state_size,
325
+ groups_time_state_size,
326
+ ],
327
+ dim=-1,
328
+ )
329
+
330
+ # 3. SSM transformation
331
+ A = -torch.exp(self.A_log.float()) # (nheads,)
332
+ A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
333
+ dt = dt[:, :, None].expand(-1, -1, self.head_dim)
334
+ dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
335
+ D = self.D[:, None, ...].expand(-1, self.head_dim)
336
+ B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
337
+ C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
338
+ hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
339
+
340
+ hidden_states = selective_state_update(
341
+ cache_params.ssm_states[self.layer_idx],
342
+ hidden_states_reshaped,
343
+ dt,
344
+ A,
345
+ B,
346
+ C,
347
+ D,
348
+ z=None,
349
+ dt_bias=dt_bias,
350
+ dt_softplus=True,
351
+ )
352
+ hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
353
+ hidden_states = self.norm(hidden_states, gate)
354
+
355
+ # 4. Final linear projection
356
+ out = self.out_proj(hidden_states)[:, None, ...]
357
+
358
+ # Fused calculations or step by step if no initialized cache is found
359
+ else:
360
+ A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
361
+ dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
362
+
363
+ # 2-4. Fused kernel for conv1d, SSM, and the final projection
364
+ if self.training and cache_params is None:
365
+ out = mamba_split_conv1d_scan_combined(
366
+ projected_states,
367
+ self.conv1d.weight.squeeze(1),
368
+ self.conv1d.bias,
369
+ self.dt_bias,
370
+ A,
371
+ D=self.D,
372
+ chunk_size=self.chunk_size,
373
+ seq_idx=None, # was seq_idx
374
+ activation=self.activation,
375
+ rmsnorm_weight=self.norm.weight,
376
+ rmsnorm_eps=self.norm.eps,
377
+ outproj_weight=self.out_proj.weight,
378
+ outproj_bias=self.out_proj.bias,
379
+ headdim=self.head_dim,
380
+ ngroups=self.n_groups,
381
+ norm_before_gate=False,
382
+ return_final_states=False,
383
+ **dt_limit_kwargs,
384
+ )
385
+
386
+ else:
387
+ _, _, gate, hidden_states_B_C, dt = projected_states.split(
388
+ [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
389
+ )
390
+
391
+ # 2. Convolution sequence transformation
392
+ # Init cache
393
+ if cache_params is not None:
394
+ hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
395
+ conv_states = nn.functional.pad(
396
+ hidden_states_B_C_transposed,
397
+ (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
398
+ )
399
+ cache_params.update_conv_state(
400
+ layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True
401
+ )
402
+
403
+ if self.activation not in ["silu", "swish"]:
404
+ hidden_states_B_C = self.act(
405
+ self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
406
+ )
407
+ else:
408
+ hidden_states_B_C = causal_conv1d_fn(
409
+ x=hidden_states_B_C.transpose(1, 2),
410
+ weight=self.conv1d.weight.squeeze(1),
411
+ bias=self.conv1d.bias,
412
+ activation=self.activation,
413
+ ).transpose(1, 2)
414
+
415
+ hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
416
+ hidden_states, B, C = torch.split(
417
+ hidden_states_B_C,
418
+ [self.intermediate_size, groups_time_state_size, groups_time_state_size],
419
+ dim=-1,
420
+ )
421
+
422
+ # 3. SSM transformation
423
+ scan_output, ssm_state = mamba_chunk_scan_combined(
424
+ hidden_states.view(batch_size, seq_len, -1, self.head_dim),
425
+ dt,
426
+ A,
427
+ B.view(batch_size, seq_len, self.n_groups, -1),
428
+ C.view(batch_size, seq_len, self.n_groups, -1),
429
+ chunk_size=self.chunk_size,
430
+ D=self.D,
431
+ z=None,
432
+ seq_idx=None,
433
+ return_final_states=True,
434
+ dt_bias=self.dt_bias,
435
+ dt_softplus=True,
436
+ **dt_limit_kwargs,
437
+ )
438
+
439
+ # Init cache
440
+ if ssm_state is not None and cache_params is not None:
441
+ cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
442
+
443
+ scan_output = scan_output.view(batch_size, seq_len, -1)
444
+ # Multiply "gate" branch and apply extra normalization layer
445
+ scan_output = self.norm(scan_output, gate)
446
+
447
+ # 4. Final linear projection
448
+ out = self.out_proj(scan_output)
449
+ return out
450
+
451
+ # fmt: off
452
+ def torch_forward(
453
+ self,
454
+ input_states,
455
+ cache_params: Optional[Mamba2Cache] = None,
456
+ cache_position: Optional[torch.LongTensor] = None,
457
+ attention_mask: Optional[torch.Tensor] = None
458
+ ):
459
+ batch_size, seq_len, _ = input_states.shape
460
+ dtype = input_states.dtype
461
+
462
+ # 1. Gated MLP's linear projection
463
+ input_states = apply_mask_to_padding_states(input_states, attention_mask)
464
+ projected_states = self.in_proj(input_states)
465
+ d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size -
466
+ 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2
467
+ _, _, gate, hidden_states_B_C, dt = projected_states.split(
468
+ [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
469
+ )
470
+
471
+ # 2. Convolution sequence transformation
472
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
473
+ cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False)
474
+
475
+ # We need to guarantee that anything regarding the cache is on the same device
476
+ conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)
477
+
478
+ hidden_states_B_C = torch.sum(
479
+ conv_states * self.conv1d.weight.squeeze(1), dim=-1
480
+ )
481
+ if self.use_conv_bias:
482
+ hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
483
+ hidden_states_B_C = self.act(hidden_states_B_C)
484
+ else:
485
+ # Init cache
486
+ if cache_params is not None:
487
+ hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
488
+ conv_states = nn.functional.pad(
489
+ hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
490
+ )
491
+ cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True)
492
+
493
+ hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))
494
+
495
+ hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
496
+ hidden_states, B, C = torch.split(
497
+ hidden_states_B_C,
498
+ [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
499
+ dim=-1
500
+ )
501
+
502
+ # 3. SSM transformation
503
+ A = -torch.exp(self.A_log.float()) # [num_heads]
504
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
505
+ # We need to guarantee that anything regarding the cache is on the same device
506
+ cache_device = cache_params.ssm_states.device
507
+
508
+ # Note: there is no need to pad parameter matrices here, as there is just one new token
509
+ # for batched generation
510
+ dt = dt[:, 0, :][:, None, ...]
511
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
512
+ # [num_heads] -> [num_heads, head_dim]
513
+ dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
514
+
515
+ dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
516
+ dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
517
+ A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
518
+ # [bsz, num_heads, head_dim, state_size]
519
+ dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)
520
+
521
+ # Discretize B
522
+ # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
523
+ # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
524
+ B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
525
+ B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
526
+ B = B.reshape(batch_size, -1, B.shape[-1])
527
+ # [bsz, num_heads, head_dim, state_size]
528
+ dB = dt[..., None] * B[..., None, :]
529
+
530
+ # Discretize x into dB
531
+ # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
532
+ hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
533
+ dBx = (dB * hidden_states[..., None]).to(device=cache_device)
534
+
535
+ # State calculation
536
+ cache_params.update_ssm_state(
537
+ layer_idx=self.layer_idx,
538
+ new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx
539
+ )
540
+
541
+ # Subsequent output
542
+ # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
543
+ C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
544
+ C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
545
+ C = C.reshape(batch_size, -1, C.shape[-1])
546
+ # [bsz, num_heads, head_dim]
547
+
548
+ ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n]
549
+ # Reshape ssm_states to merge the first two dimensions
550
+ # Shape: [b*h, d, n]
551
+ ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size)
552
+ C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
553
+ y = torch.bmm(ssm_states_reshaped, C_reshaped)
554
+ y = y.view(batch_size, self.num_heads, self.head_dim)
555
+
556
+ # D skip connection
557
+ # [num_heads] -> [num_heads, head_dim]
558
+ D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
559
+ y = (y + hidden_states * D).to(y.dtype)
560
+
561
+ # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
562
+ y = y.reshape(batch_size, -1)[:, None, ...]
563
+ else:
564
+ # begin ssd naive implementation without einsums
565
+ dt = nn.functional.softplus(dt + self.dt_bias)
566
+ dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
567
+ hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
568
+ B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
569
+ C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
570
+ B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
571
+ C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
572
+ pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
573
+
574
+ D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
575
+
576
+ # Discretize x and A
577
+ hidden_states = hidden_states * dt[..., None]
578
+ A = A.to(hidden_states.dtype) * dt
579
+
580
+ # Rearrange into blocks/chunks
581
+ hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
582
+
583
+ # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
584
+ A = A.permute(0, 3, 1, 2)
585
+ A_cumsum = torch.cumsum(A, dim=-1)
586
+
587
+ # 1. Compute the output for each intra-chunk (diagonal blocks)
588
+ # This is the analog of a causal mask
589
+ L = torch.exp(segment_sum(A))
590
+
591
+ # Contraction of C and B to get G (attention-weights like)
592
+ # shape: (b, c, l, s, h, n)
593
+ G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]
594
+ G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
595
+
596
+ # Compute M, equivalent to applying attention mask to weights
597
+ M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
598
+ M = M_intermediate.sum(dim=-1)
599
+
600
+ # Compute Y_diag (apply to values)
601
+ Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
602
+
603
+ # 2. Compute the state for each intra-chunk
604
+ # (right term of low-rank factorization of off-diagonal blocks; B terms)
605
+ decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
606
+ B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
607
+ states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
608
+
609
+ # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
610
+ # (middle term of factorization of off-diag blocks; A terms)
611
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
612
+ previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
613
+ else:
614
+ previous_states = torch.zeros_like(states[:, :1])
615
+ states = torch.cat([previous_states, states], dim=1)
616
+ decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
617
+ decay_chunk = decay_chunk.transpose(1, 3)
618
+ new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
619
+ states, ssm_state = new_states[:, :-1], new_states[:, -1]
620
+
621
+ # 4. Compute state -> output conversion per chunk
622
+ # (left term of low-rank factorization of off-diagonal blocks; C terms)
623
+ state_decay_out = torch.exp(A_cumsum)
624
+ C_times_states = (C[..., None, :] * states[:, :, None, ...])
625
+ state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
626
+ Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
627
+
628
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
629
+ y = Y_diag + Y_off
630
+ # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
631
+ y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
632
+
633
+ y = y + D_residual
634
+ # Cutting off padded chunks
635
+ if pad_size > 0:
636
+ y = y[:, :seq_len, :, :]
637
+ y = y.reshape(batch_size, seq_len, -1)
638
+
639
+ # Init cache
640
+ if ssm_state is not None and cache_params is not None:
641
+ cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
642
+
643
+ scan_output = self.norm(y, gate)
644
+
645
+ # end ssd naive
646
+
647
+ # 4. Final linear projection
648
+ contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
649
+ return contextualized_states
650
+ # fmt: on
651
+
652
+ def forward(
653
+ self,
654
+ hidden_states,
655
+ cache_params: Optional[Mamba2Cache] = None,
656
+ cache_position: Optional[torch.LongTensor] = None,
657
+ attention_mask: Optional[torch.Tensor] = None,
658
+ ):
659
+ if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
660
+ return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
661
+ dtype = hidden_states.dtype
662
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
663
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
664
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
665
+
666
+ return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
667
+
668
+
669
+ class Mamba2Block(nn.Module):
670
+ def __init__(self, config, layer_idx):
671
+ super().__init__()
672
+ self.config = config
673
+ self.layer_idx = layer_idx
674
+ self.residual_in_fp32 = config.residual_in_fp32
675
+ self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
676
+ self.mixer = Mamba2Mixer(config, layer_idx=layer_idx)
677
+
678
+ def forward(
679
+ self,
680
+ hidden_states,
681
+ cache_params: Optional[Mamba2Cache] = None,
682
+ cache_position: Optional[torch.LongTensor] = None,
683
+ attention_mask: Optional[torch.Tensor] = None,
684
+ ):
685
+ residual = hidden_states
686
+ hidden_states = self.norm(hidden_states)
687
+ if self.residual_in_fp32:
688
+ residual = residual.to(torch.float32)
689
+
690
+ hidden_states = self.mixer(
691
+ hidden_states,
692
+ cache_params=cache_params,
693
+ cache_position=cache_position,
694
+ attention_mask=attention_mask,
695
+ )
696
+ hidden_states = residual + hidden_states
697
+ if self.residual_in_fp32:
698
+ hidden_states = hidden_states.to(dtype=self.norm.weight.dtype)
699
+ return hidden_states
700
+
701
+
702
+ class Mamba2PreTrainedModel(PreTrainedModel, GenerationMixin):
703
+ """
704
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
705
+ models.
706
+ """
707
+
708
+ config_class = Mamba2Config
709
+ base_model_prefix = "backbone"
710
+ _no_split_modules = ["Mamba2Block"]
711
+ supports_gradient_checkpointing = True
712
+ _is_stateful = True
713
+
714
+ def _init_weights(
715
+ self,
716
+ module: nn.Module,
717
+ num_residuals_per_layer: int = 1,
718
+ ):
719
+ """Initialize the weights."""
720
+ if isinstance(module, Mamba2Mixer):
721
+
722
+ # --- A_log ---
723
+ A = torch.arange(1, module.num_heads + 1)
724
+ with torch.no_grad():
725
+ if not isinstance(module.A_log, torch.distributed.tensor.DTensor):
726
+ module.A_log.copy_(torch.log(A))
727
+ else:
728
+ logger.warning_once("`A_log` is a DTensor, skipping initialization")
729
+ module.A_log._no_weight_decay = True
730
+
731
+ # --- D ---
732
+ nn.init.ones_(module.D)
733
+ module.D._no_weight_decay = True
734
+
735
+ # --- dt_bias ---
736
+ dt = torch.exp(
737
+ torch.rand(self.config.num_heads)
738
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
739
+ + math.log(self.config.time_step_min)
740
+ ).clamp(min=self.config.time_step_floor)
741
+
742
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
743
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
744
+ with torch.no_grad():
745
+ if not isinstance(module.dt_bias, torch.distributed.tensor.DTensor):
746
+ module.dt_bias.copy_(inv_dt)
747
+ else:
748
+ logger.warning_once("`dt_bias` is a DTensor, skipping initialization")
749
+ module.dt_bias._no_reinit = True
750
+
751
+ elif isinstance(module, (nn.Linear, nn.Conv1d)):
752
+ # Slightly different from the TF version which uses truncated_normal for initialization
753
+ # cf https://github.com/pytorch/pytorch/pull/5617
754
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
755
+ if module.bias is not None:
756
+ nn.init.zeros_(module.bias)
757
+ # guard against deprecated behavior
758
+ if hasattr(module.bias, "_no_reinit"):
759
+ raise ValueError("This is not supposed to happen")
760
+ elif isinstance(module, nn.Embedding):
761
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
762
+ elif hasattr(module, 'reset_parameters'):
763
+ module.reset_parameters()
764
+
765
+ if self.config.rescale_prenorm_residual:
766
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
767
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
768
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
769
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
770
+ #
771
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
772
+ p = None
773
+ if hasattr(module, 'o_proj'):
774
+ # p = module.o_proj.weight
775
+ # guard against deprecated behavior
776
+ raise ValueError("This is not supposed to happen")
777
+ elif hasattr(module, 'out_proj'):
778
+ p = module.out_proj.weight
779
+ elif hasattr(module, 'down_proj'):
780
+ p = module.down_proj.weight
781
+ if p is not None:
782
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
783
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
784
+ # We need to reinit p since this code could be called multiple times
785
+ # Having just p *= scale would repeatedly scale it down
786
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
787
+ with torch.no_grad():
788
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
789
+
790
+
791
+ @dataclass
792
+ # Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2
793
+ class Mamba2Output(ModelOutput):
794
+ """
795
+ Class for the MAMBA2 model outputs.
796
+
797
+ Args:
798
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
799
+ Sequence of hidden-states at the output of the last layer of the model.
800
+ cache_params (`Mamba2Cache`):
801
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
802
+ avoid providing the old `input_ids`.
803
+
804
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
805
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*,
806
+ returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
807
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
808
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
809
+
810
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
811
+ """
812
+
813
+ last_hidden_state: Optional[torch.FloatTensor] = None
814
+ cache_params: Optional[Mamba2Cache] = None
815
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
816
+
817
+
818
+ @dataclass
819
+ # Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2
820
+ class Mamba2CausalLMOutput(ModelOutput):
821
+ """
822
+ Base class for causal language model (or autoregressive) outputs.
823
+
824
+ Args:
825
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
826
+ Language modeling loss (for next-token prediction).
827
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
828
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
829
+ cache_params (`Mamba2Cache`):
830
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
831
+ avoid providing the old `input_ids`.
832
+
833
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
834
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*,
835
+ returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
836
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
837
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
838
+
839
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
840
+ """
841
+
842
+ loss: Optional[torch.FloatTensor] = None
843
+ logits: Optional[torch.FloatTensor] = None
844
+ cache_params: Optional[Mamba2Cache] = None
845
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
846
+
847
+
848
+ class Mamba2Model(Mamba2PreTrainedModel):
849
+ def __init__(self, config):
850
+ super().__init__(config)
851
+
852
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
853
+ self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
854
+
855
+ self.gradient_checkpointing = False
856
+ self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
857
+ # Initialize weights and apply final processing
858
+ self._register_load_state_dict_pre_hook(self.load_hook)
859
+ self.post_init()
860
+
861
+ def load_hook(self, state_dict, prefix, *args):
862
+ for k in state_dict:
863
+ if "embedding." in k:
864
+ state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
865
+ break
866
+
867
+ def get_input_embeddings(self):
868
+ return self.embeddings
869
+
870
+ def set_input_embeddings(self, new_embeddings):
871
+ self.embeddings = new_embeddings
872
+
873
+ def forward(
874
+ self,
875
+ input_ids: Optional[torch.LongTensor] = None,
876
+ inputs_embeds: Optional[torch.LongTensor] = None,
877
+ cache_params: Optional[Mamba2Cache] = None,
878
+ use_cache: Optional[bool] = None,
879
+ output_hidden_states: Optional[bool] = None,
880
+ return_dict: Optional[bool] = None,
881
+ cache_position: Optional[torch.LongTensor] = None,
882
+ attention_mask: Optional[torch.Tensor] = None,
883
+ **kwargs,
884
+ ) -> Union[Tuple, Mamba2Output]:
885
+ output_hidden_states = (
886
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
887
+ )
888
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
889
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
890
+
891
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
892
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
893
+
894
+ if inputs_embeds is None:
895
+ inputs_embeds = self.embeddings(input_ids)
896
+
897
+ if self.gradient_checkpointing and self.training and use_cache:
898
+ use_cache = False
899
+
900
+ if use_cache:
901
+ if cache_params is None:
902
+ cache_params = Mamba2Cache(
903
+ self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
904
+ )
905
+ cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
906
+ elif cache_position is None:
907
+ # cases when we do manual forward instead of using `model.generate` which will initiate
908
+ # `cache_position` and makes sure it is not None, throw error here instead of doing some
909
+ # hack to conjecture the current cache position
910
+ raise ValueError(
911
+ "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
912
+ "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
913
+ "be initialized for you automatically"
914
+ )
915
+ else:
916
+ cache_params = None
917
+
918
+ hidden_states = inputs_embeds
919
+ all_hidden_states = () if output_hidden_states else None
920
+ for mixer_block in self.layers:
921
+ if self.gradient_checkpointing and self.training:
922
+ hidden_states = self._gradient_checkpointing_func(
923
+ mixer_block.__call__,
924
+ hidden_states,
925
+ cache_params,
926
+ cache_position,
927
+ attention_mask,
928
+ )
929
+ else:
930
+ hidden_states = mixer_block(
931
+ hidden_states,
932
+ cache_params=cache_params,
933
+ cache_position=cache_position,
934
+ attention_mask=attention_mask,
935
+ )
936
+
937
+ if output_hidden_states:
938
+ all_hidden_states = all_hidden_states + (hidden_states,)
939
+
940
+ hidden_states = self.norm_f(hidden_states)
941
+
942
+ if output_hidden_states:
943
+ all_hidden_states = all_hidden_states + (hidden_states,)
944
+
945
+ if not return_dict:
946
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
947
+
948
+ return Mamba2Output(
949
+ last_hidden_state=hidden_states,
950
+ cache_params=cache_params if use_cache else None,
951
+ hidden_states=all_hidden_states,
952
+ )
953
+
954
+
955
+ class Mamba2ForCausalLM(Mamba2PreTrainedModel):
956
+ _tied_weights_keys = []
957
+
958
+ def __init__(self, config):
959
+ super().__init__(config)
960
+ self.backbone = Mamba2Model(config)
961
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
962
+ self.criterion = None
963
+
964
+ # Initialize weights and apply final processing
965
+ self.post_init()
966
+
967
+ def get_output_embeddings(self):
968
+ return self.lm_head
969
+
970
+ def set_output_embeddings(self, new_embeddings):
971
+ self.lm_head = new_embeddings
972
+
973
+ def get_input_embeddings(self):
974
+ return self.backbone.get_input_embeddings()
975
+
976
+ def set_input_embeddings(self, new_embeddings):
977
+ return self.backbone.set_input_embeddings(new_embeddings)
978
+
979
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
980
+ def prepare_inputs_for_generation(
981
+ self,
982
+ input_ids,
983
+ inputs_embeds=None,
984
+ use_cache=None,
985
+ cache_params: Optional[Mamba2Cache] = None,
986
+ cache_position: Optional[torch.LongTensor] = None,
987
+ attention_mask: Optional[torch.Tensor] = None,
988
+ logits_to_keep: Optional[int] = None,
989
+ **kwargs,
990
+ ):
991
+ if use_cache:
992
+ # `cache_position` should have been initialized in `generate`
993
+ if cache_position is None:
994
+ raise ValueError(
995
+ "`cache_position` should not be None as it should have been initialized in "
996
+ "`model.generate`, you are responsible for passing in a valid `cache_position` if "
997
+ "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
998
+ )
999
+ if cache_position[0] > 0:
1000
+ input_ids = input_ids[:, -1][..., None]
1001
+
1002
+ if attention_mask is not None:
1003
+ attention_mask = None
1004
+ else:
1005
+ # we initialize the `cache_position` to full size of `conv_states` at prefill stage
1006
+ # considering padding will be applied when input length is shorter, and truncation
1007
+ # will be applied when it is longer, so it will be equivalent to always have it match
1008
+ # the length of `cache_params.conv_states`, which is `config.conv_kernel`
1009
+ cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
1010
+
1011
+ if inputs_embeds is not None and cache_params is None:
1012
+ model_inputs = {"inputs_embeds": inputs_embeds}
1013
+ else:
1014
+ model_inputs = {"input_ids": input_ids}
1015
+
1016
+ if logits_to_keep is not None:
1017
+ model_inputs['logits_to_keep'] = logits_to_keep
1018
+
1019
+ model_inputs.update({
1020
+ 'attention_mask': attention_mask,
1021
+ 'cache_params': cache_params,
1022
+ 'use_cache': use_cache,
1023
+ 'cache_position': cache_position,
1024
+ 'logits_to_keep': logits_to_keep
1025
+ })
1026
+ return model_inputs
1027
+
1028
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
1029
+ def forward(
1030
+ self,
1031
+ input_ids: Optional[torch.LongTensor] = None,
1032
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1033
+ cache_params: Optional[Mamba2Cache] = None,
1034
+ labels: Optional[torch.LongTensor] = None,
1035
+ output_hidden_states: Optional[bool] = None,
1036
+ return_dict: Optional[bool] = None,
1037
+ use_cache: Optional[bool] = None,
1038
+ cache_position: Optional[torch.Tensor] = None,
1039
+ attention_mask: Optional[torch.Tensor] = None,
1040
+ logits_to_keep: Optional[int] = 0,
1041
+ **kwargs, # for now we need this for generation
1042
+ ) -> Union[Tuple, Mamba2CausalLMOutput]:
1043
+ r"""
1044
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1045
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1046
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1047
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1048
+ """
1049
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1050
+
1051
+ outputs = self.backbone(
1052
+ input_ids,
1053
+ cache_params=cache_params,
1054
+ inputs_embeds=inputs_embeds,
1055
+ output_hidden_states=output_hidden_states,
1056
+ return_dict=return_dict,
1057
+ use_cache=use_cache,
1058
+ cache_position=cache_position,
1059
+ attention_mask=attention_mask,
1060
+ )
1061
+ hidden_states = outputs[0]
1062
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
1063
+
1064
+ loss, logits = None, None
1065
+ if not fuse_linear_and_cross_entropy or labels is None:
1066
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
1067
+ if labels is not None:
1068
+ if getattr(self, 'criterion', None) is None:
1069
+ if fuse_linear_and_cross_entropy:
1070
+ criterion = FusedLinearCrossEntropyLoss()
1071
+ elif self.config.fuse_cross_entropy:
1072
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
1073
+ else:
1074
+ criterion = nn.CrossEntropyLoss()
1075
+ else:
1076
+ criterion = self.criterion
1077
+ labels = labels.to(hidden_states.device)
1078
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
1079
+ if fuse_linear_and_cross_entropy:
1080
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
1081
+ else:
1082
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
1083
+
1084
+ if not return_dict:
1085
+ output = (logits,) + outputs[1:]
1086
+ return (loss,) + output if loss is not None else output
1087
+
1088
+ return Mamba2CausalLMOutput(
1089
+ loss=loss,
1090
+ logits=logits,
1091
+ cache_params=outputs.cache_params,
1092
+ hidden_states=outputs.hidden_states,
1093
+ )
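
Note: segment_sum above builds, for a vector of per-step log-decays, a lower-triangular matrix whose (i, j) entry is the sum over the half-open segment (j, i], with zeros on the diagonal and -inf above it; exponentiating it gives the causal decay mask used in the chunked scan. A minimal sanity check (assuming the module is importable as added here):

import torch
from fla.models.mamba2.modeling_mamba2 import segment_sum

a = torch.tensor([0.1, 0.2, 0.3, 0.4])
S = segment_sum(a)  # shape [4, 4]
assert torch.isclose(S[2, 0], a[1] + a[2])  # sum over the segment (0, 2]
assert S[1, 1] == 0                         # diagonal is zero
assert torch.isinf(S[0, 1])                 # above-diagonal entries are -inf
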
fla/models/nsa/modeling_nsa.py ADDED
@@ -0,0 +1,398 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.nsa import NativeSparseAttention
19
+ from fla.models.nsa.configuration_nsa import NSAConfig
20
+ from fla.models.utils import Cache
21
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
22
+ from fla.modules import GatedMLP as NSAMLP
23
+ from fla.modules import RMSNorm
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class NSABlock(nn.Module):
32
+ def __init__(self, config: NSAConfig, layer_idx: int):
33
+ super().__init__()
34
+
35
+ self.config = config
36
+ self.layer_idx = layer_idx
37
+
38
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
39
+ self.attn = NativeSparseAttention(
40
+ hidden_size=config.hidden_size,
41
+ num_heads=config.num_heads,
42
+ num_kv_heads=config.num_kv_heads,
43
+ qkv_bias=config.qkv_bias,
44
+ block_size=config.block_size,
45
+ block_counts=config.block_counts,
46
+ window_size=config.window_size,
47
+ rope_theta=config.rope_theta,
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
52
+ self.mlp = NSAMLP(
53
+ hidden_size=config.hidden_size,
54
+ hidden_ratio=config.hidden_ratio,
55
+ intermediate_size=config.intermediate_size,
56
+ hidden_act=config.hidden_act,
57
+ fuse_swiglu=config.fuse_swiglu
58
+ )
59
+
60
+ def forward(
61
+ self,
62
+ hidden_states: torch.Tensor,
63
+ attention_mask: Optional[torch.Tensor] = None,
64
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
65
+ use_cache: Optional[bool] = False,
66
+ output_attentions: Optional[bool] = False,
67
+ **kwargs: Unpack[Dict]
68
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
69
+ residual = hidden_states
70
+ hidden_states = self.attn_norm(hidden_states)
71
+ hidden_states, attentions, past_key_values = self.attn(
72
+ hidden_states=hidden_states,
73
+ attention_mask=attention_mask,
74
+ past_key_values=past_key_values,
75
+ use_cache=use_cache,
76
+ output_attentions=output_attentions,
77
+ **kwargs
78
+ )
79
+ if self.config.fuse_norm:
80
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
81
+ else:
82
+ hidden_states = residual + hidden_states
83
+ residual = hidden_states
84
+ hidden_states = self.mlp_norm(hidden_states)
85
+ hidden_states = self.mlp(hidden_states, **kwargs)
86
+ hidden_states = residual + hidden_states
87
+
88
+ outputs = (hidden_states, attentions, past_key_values)
89
+
90
+ return outputs
91
+
92
+
93
+ class NSAPreTrainedModel(PreTrainedModel):
94
+
95
+ config_class = NSAConfig
96
+ base_model_prefix = 'model'
97
+ supports_gradient_checkpointing = True
98
+ _no_split_modules = ['NSABlock']
99
+ _supports_cache_class = True
100
+
101
+ def __init__(self, *inputs, **kwargs):
102
+ super().__init__(*inputs, **kwargs)
103
+
104
+ def _init_weights(
105
+ self,
106
+ module: nn.Module,
107
+ prenorm_residual_strategy: Optional[str] = 'rescale',
108
+ num_residuals_per_layer: int = 2,
109
+ ):
110
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
111
+ # Slightly different from the TF version which uses truncated_normal for initialization
112
+ # cf https://github.com/pytorch/pytorch/pull/5617
113
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
114
+ if module.bias is not None:
115
+ nn.init.zeros_(module.bias)
116
+ elif isinstance(module, nn.Embedding):
117
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
118
+ elif hasattr(module, 'reset_parameters'):
119
+ module.reset_parameters()
120
+
121
+ if prenorm_residual_strategy is not None:
122
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
123
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
124
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
125
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
126
+ #
127
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
128
+ p = None
129
+ if hasattr(module, 'o_proj'):
130
+ p = module.o_proj.weight
131
+ elif hasattr(module, 'down_proj'):
132
+ p = module.down_proj.weight
133
+ if p is not None:
134
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
135
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
136
+ # We need to reinit p since this code could be called multiple times
137
+ # Having just p *= scale would repeatedly scale it down
138
+ if prenorm_residual_strategy == 'rescale':
139
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
140
+ with torch.no_grad():
141
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
142
+ elif prenorm_residual_strategy == 'zero':
143
+ nn.init.zeros_(p)
144
+ else:
145
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
146
+
147
+
148
+ class NSAModel(NSAPreTrainedModel):
149
+
150
+ def __init__(self, config: NSAConfig):
151
+ super().__init__(config)
152
+ self.padding_idx = config.pad_token_id
153
+ self.vocab_size = config.vocab_size
154
+
155
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
156
+ self.layers = nn.ModuleList([NSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
157
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
158
+
159
+ self.gradient_checkpointing = False
160
+
161
+ self.post_init()
162
+
163
+ def get_input_embeddings(self):
164
+ return self.embeddings
165
+
166
+ def set_input_embeddings(self, value):
167
+ self.embeddings = value
168
+
169
+ def forward(
170
+ self,
171
+ input_ids: Optional[torch.LongTensor] = None,
172
+ attention_mask: Optional[torch.Tensor] = None, # noqa
173
+ inputs_embeds: Optional[torch.FloatTensor] = None,
174
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
175
+ use_cache: Optional[bool] = None,
176
+ output_attentions: Optional[bool] = None,
177
+ output_hidden_states: Optional[bool] = None,
178
+ return_dict: Optional[bool] = None,
179
+ **kwargs: Unpack[Dict]
180
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
181
+ if output_attentions:
182
+ warnings.warn("`NSAModel` does not support `output_attentions` for now, setting it to `False`.")
183
+ output_attentions = False
184
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
185
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
186
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
187
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
188
+
189
+ # retrieve input_ids and inputs_embeds
190
+ if input_ids is not None and inputs_embeds is not None:
191
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
192
+ if input_ids is None and inputs_embeds is None:
193
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
194
+
195
+ if inputs_embeds is None:
196
+ inputs_embeds = self.embeddings(input_ids)
197
+ hidden_states = inputs_embeds
198
+
199
+ if use_cache and not isinstance(past_key_values, Cache):
200
+ past_key_values = Cache.from_legacy_cache(past_key_values)
201
+
202
+ if self.gradient_checkpointing and self.training and use_cache:
203
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
204
+ use_cache = False
205
+
206
+ all_hidden_states = () if output_hidden_states else None
207
+ all_attns = () if output_attentions else None
208
+ for layer in self.layers:
209
+ if output_hidden_states:
210
+ all_hidden_states += (hidden_states,)
211
+
212
+ if self.gradient_checkpointing and self.training:
213
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
214
+ layer.__call__,
215
+ hidden_states,
216
+ attention_mask,
217
+ past_key_values,
218
+ use_cache,
219
+ output_attentions,
220
+ **kwargs
221
+ )
222
+ else:
223
+ hidden_states, attentions, past_key_values = layer(
224
+ hidden_states,
225
+ attention_mask=attention_mask,
226
+ past_key_values=past_key_values,
227
+ use_cache=use_cache,
228
+ output_attentions=output_attentions,
229
+ **kwargs
230
+ )
231
+
232
+ if output_attentions:
233
+ all_attns += (attentions,)
234
+
235
+ hidden_states = self.norm(hidden_states)
236
+
237
+ # add hidden states from the last decoder layer
238
+ if output_hidden_states:
239
+ all_hidden_states += (hidden_states,)
240
+
241
+ if not return_dict:
242
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
243
+ return BaseModelOutputWithPast(
244
+ last_hidden_state=hidden_states,
245
+ past_key_values=past_key_values,
246
+ hidden_states=all_hidden_states,
247
+ attentions=all_attns
248
+ )
249
+
250
+
251
+ class NSAForCausalLM(NSAPreTrainedModel, GenerationMixin):
252
+
253
+ _tied_weights_keys = ["lm_head.weight"]
254
+
255
+ def __init__(self, config):
256
+ super().__init__(config)
257
+ self.model = NSAModel(config)
258
+ self.vocab_size = config.vocab_size
259
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
260
+ self.criterion = None
261
+
262
+ # Initialize weights and apply final processing
263
+ self.post_init()
264
+
265
+ def get_input_embeddings(self):
266
+ return self.model.embeddings
267
+
268
+ def set_input_embeddings(self, value):
269
+ self.model.embeddings = value
270
+
271
+ def get_output_embeddings(self):
272
+ return self.lm_head
273
+
274
+ def set_output_embeddings(self, new_embeddings):
275
+ self.lm_head = new_embeddings
276
+
277
+ def set_decoder(self, decoder):
278
+ self.model = decoder
279
+
280
+ def get_decoder(self):
281
+ return self.model
282
+
283
+ def generate(self, *args, **kwargs):
284
+ try:
285
+ return super().generate(*args, **kwargs)
286
+ except AttributeError as exception:
287
+ if 'past_key_values' in str(exception):
288
+ raise AttributeError(
289
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
290
+ f"which is not supported for {self.__class__.__name__}. "
291
+ f"Try another generation strategy instead. "
292
+ f"For the available generation strategies, check this doc: "
293
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
294
+ )
295
+ else:
296
+ raise exception
297
+
298
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
299
+ def prepare_inputs_for_generation(
300
+ self,
301
+ input_ids: torch.LongTensor = None,
302
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
303
+ attention_mask: Optional[torch.Tensor] = None,
304
+ inputs_embeds: Optional[torch.Tensor] = None,
305
+ use_cache: bool = True,
306
+ logits_to_keep: Optional[int] = None,
307
+ **kwargs
308
+ ):
309
+ # only keep the last token in `input_ids` if `past_key_values` is not empty.
310
+ if past_key_values is not None and len(past_key_values) > 0:
311
+ input_ids = input_ids[:, -1:]
312
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
313
+ if inputs_embeds is not None and len(past_key_values) == 0:
314
+ model_inputs = {'inputs_embeds': inputs_embeds}
315
+ else:
316
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
317
+ # recompiles graphs as the stride of the inputs is a guard.
318
+ # Ref: https://github.com/huggingface/transformers/pull/29114
319
+ # TODO: use `next_tokens` directly instead.
320
+ model_inputs = {'input_ids': input_ids.contiguous()}
321
+
322
+ if logits_to_keep is not None:
323
+ model_inputs['logits_to_keep'] = logits_to_keep
324
+
325
+ model_inputs.update({
326
+ 'past_key_values': past_key_values,
327
+ 'use_cache': use_cache,
328
+ 'attention_mask': attention_mask,
329
+ })
330
+ return model_inputs
331
+
332
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
333
+ def forward(
334
+ self,
335
+ input_ids: torch.LongTensor = None,
336
+ attention_mask: Optional[torch.Tensor] = None,
337
+ inputs_embeds: Optional[torch.Tensor] = None,
338
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
339
+ labels: Optional[torch.LongTensor] = None,
340
+ use_cache: Optional[bool] = None,
341
+ output_attentions: Optional[bool] = None,
342
+ output_hidden_states: Optional[bool] = None,
343
+ return_dict: Optional[bool] = None,
344
+ logits_to_keep: Optional[int] = 0,
345
+ **kwargs: Unpack[Dict]
346
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
347
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
348
+ output_hidden_states = (
349
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
350
+ )
351
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
352
+
353
+ outputs = self.model(
354
+ input_ids=input_ids,
355
+ attention_mask=attention_mask,
356
+ inputs_embeds=inputs_embeds,
357
+ past_key_values=past_key_values,
358
+ use_cache=use_cache,
359
+ output_attentions=output_attentions,
360
+ output_hidden_states=output_hidden_states,
361
+ return_dict=return_dict,
362
+ **kwargs
363
+ )
364
+
365
+ hidden_states = outputs[0]
366
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
367
+
368
+ loss, logits = None, None
369
+ if not fuse_linear_and_cross_entropy or labels is None:
370
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
371
+ if labels is not None:
372
+ if getattr(self, 'criterion', None) is None:
373
+ if fuse_linear_and_cross_entropy:
374
+ criterion = FusedLinearCrossEntropyLoss()
375
+ elif self.config.fuse_cross_entropy:
376
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
377
+ else:
378
+ criterion = nn.CrossEntropyLoss()
379
+ else:
380
+ criterion = self.criterion
381
+ labels = labels.to(hidden_states.device)
382
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
383
+ if fuse_linear_and_cross_entropy:
384
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
385
+ else:
386
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
387
+
388
+ if not return_dict:
389
+ output = (logits,) + outputs[1:]
390
+ return (loss,) + output if loss is not None else output
391
+
392
+ return CausalLMOutputWithPast(
393
+ loss=loss,
394
+ logits=logits,
395
+ past_key_values=outputs.past_key_values,
396
+ hidden_states=outputs.hidden_states,
397
+ attentions=outputs.attentions,
398
+ )
fla/models/rwkv7/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config
6
+ from fla.models.rwkv7.modeling_rwkv7 import RWKV7ForCausalLM, RWKV7Model
7
+
8
+ AutoConfig.register(RWKV7Config.model_type, RWKV7Config, True)
9
+ AutoModel.register(RWKV7Config, RWKV7Model, True)
10
+ AutoModelForCausalLM.register(RWKV7Config, RWKV7ForCausalLM, True)
11
+
12
+
13
+ __all__ = ['RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model']
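The three `register` calls above (the trailing `True` is `exist_ok`) plug RWKV7 into the `transformers` Auto classes. A small sketch, assuming `RWKV7Config` accepts the usual size arguments (the values below are hypothetical):

from transformers import AutoModelForCausalLM
from fla.models.rwkv7 import RWKV7Config  # importing the module runs the registrations above

config = RWKV7Config(hidden_size=64, num_hidden_layers=2, vocab_size=128)
model = AutoModelForCausalLM.from_config(config)  # resolved to RWKV7ForCausalLM through the registry
print(type(model).__name__)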
fla/models/transformer_top/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.transformer_top.configuration_transformer import TOPTransformerConfig
6
+ from fla.models.transformer_top.modeling_transformer import TOPTransformerForCausalLM, TOPTransformerModel
7
+
8
+ AutoConfig.register(TOPTransformerConfig.model_type, TOPTransformerConfig)
9
+ AutoModel.register(TOPTransformerConfig, TOPTransformerModel)
10
+ AutoModelForCausalLM.register(TOPTransformerConfig, TOPTransformerForCausalLM)
11
+
12
+
13
+ __all__ = ['TOPTransformerConfig', 'TOPTransformerForCausalLM', 'TOPTransformerModel']
fla/modules/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from fla.modules.convolution import ImplicitLongConvolution, LongConvolution, ShortConvolution
4
+ from fla.modules.fused_bitlinear import BitLinear, FusedBitLinear
5
+ from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss
6
+ from fla.modules.fused_kl_div import FusedKLDivLoss
7
+ from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
8
+ from fla.modules.fused_linear_listnet_loss import FusedLinearListNetLoss
9
+ from fla.modules.fused_norm_gate import (
10
+ FusedLayerNormGated,
11
+ FusedLayerNormSwishGate,
12
+ FusedLayerNormSwishGateLinear,
13
+ FusedRMSNormGated,
14
+ FusedRMSNormSwishGate,
15
+ FusedRMSNormSwishGateLinear
16
+ )
17
+ from fla.modules.layernorm import GroupNorm, GroupNormLinear, LayerNorm, LayerNormLinear, RMSNorm, RMSNormLinear
18
+ from fla.modules.mlp import GatedMLP
19
+ from fla.modules.rotary import RotaryEmbedding
20
+
21
+ __all__ = [
22
+ 'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
23
+ 'BitLinear', 'FusedBitLinear',
24
+ 'FusedCrossEntropyLoss', 'FusedLinearCrossEntropyLoss', 'FusedKLDivLoss', 'FusedLinearListNetLoss',
25
+ 'GroupNorm', 'GroupNormLinear', 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
26
+ 'FusedLayerNormGated', 'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear',
27
+ 'FusedRMSNormGated', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
28
+ 'GatedMLP',
29
+ 'RotaryEmbedding'
30
+ ]
fla/modules/convolution.py ADDED
@@ -0,0 +1,434 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py
4
+
5
+ import math
6
+ import warnings
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import triton
13
+ import triton.language as tl
14
+ from einops import rearrange
15
+
16
+ from fla.modules.activations import ACT2FN
17
+ from fla.ops.common.utils import prepare_position_ids, prepare_sequence_ids
18
+ from fla.utils import checkpoint, input_guard
19
+
20
+ try:
21
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
22
+ except ImportError:
23
+ causal_conv1d_fn = None
24
+ causal_conv1d_update = None
25
+
26
+
27
+ def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
28
+ seqlen = u.shape[-1]
29
+ fft_size = 2 * seqlen
30
+ k_f = torch.fft.rfft(k, n=fft_size) / fft_size
31
+ if k_rev is not None:
32
+ k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
33
+ k_f = k_f + k_rev_f.conj()
34
+ u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
35
+
36
+ if len(u.shape) > 3:
37
+ k_f = k_f.unsqueeze(1)
38
+ y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
39
+
40
+ out = y + u
41
+ if gelu:
42
+ out = F.gelu(out)
43
+ if dropout_mask is not None:
44
+ return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
45
+ else:
46
+ return out.to(dtype=u.dtype)
47
+
48
+
49
+ @checkpoint
50
+ def proj_then_conv1d(
51
+ x: torch.Tensor,
52
+ proj_weight: torch.Tensor,
53
+ conv1d_weight: torch.Tensor,
54
+ conv1d_bias: Optional[torch.Tensor] = None,
55
+ cache: Optional[torch.Tensor] = None
56
+ ) -> torch.Tensor:
57
+ # We do matmul and transpose BLH -> HBL at the same time
58
+ x = rearrange(proj_weight @ rearrange(x, "b t d -> d (b t)"), "d (b t) -> b d t", t=x.shape[-2])
59
+
60
+ if causal_conv1d_fn is None:
61
+ raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
62
+ if cache is None:
63
+ x = causal_conv1d_fn(
64
+ x=x,
65
+ weight=rearrange(conv1d_weight, "d 1 w -> d w"),
66
+ bias=conv1d_bias,
67
+ activation="silu",
68
+ ).transpose(1, 2)
69
+ else:
70
+ assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
71
+ x = x.squeeze(-1)
72
+ x = causal_conv1d_update(
73
+ x=x,
74
+ weight=rearrange(conv1d_weight, "d 1 w -> d w"),
75
+ bias=conv1d_bias,
76
+ cache=cache,
77
+ activation="silu",
78
+ )
79
+ return x
80
+
81
+
82
+ @triton.jit
83
+ def causal_conv1d_varlen_states_fwd_kernel(
84
+ x,
85
+ cache,
86
+ offsets,
87
+ D,
88
+ W,
89
+ BD: tl.constexpr,
90
+ BW: tl.constexpr
91
+ ):
92
+ i_d, i_w, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2)
93
+ eos = tl.load(offsets + i_n + 1)
94
+ bos = tl.maximum(tl.load(offsets + i_n), eos - W)
95
+ o_t = eos - (i_w + 1) * BW + tl.arange(0, BW)
96
+ o_d = i_d * BD + tl.arange(0, BD)
97
+ o_w = W - (i_w + 1) * BW + tl.arange(0, BW)
98
+
99
+ b_x = tl.load(x + o_t * D + o_d[:, None], mask=(o_t >= bos) & (o_d[:, None] < D), other=0)
100
+ tl.store(cache + i_n * D*W + o_d[:, None] * W + o_w, b_x, mask=(o_d[:, None] < D) & (o_w >= 0))
101
+
102
+
103
+ @input_guard
104
+ def causal_conv1d_varlen_states_fwd(
105
+ x: torch.Tensor,
106
+ cache: torch.Tensor,
107
+ cu_seqlens: torch.Tensor,
108
+ state_len: int
109
+ ) -> torch.Tensor:
110
+ N, D, W = len(cu_seqlens) - 1, x.shape[-1], state_len
111
+ cache = torch.empty(N, D, W, dtype=x.dtype, device=x.device) if cache is None else cache
112
+ BD = min(triton.next_power_of_2(D), 256)
113
+ BW = min(triton.next_power_of_2(state_len), 16)
114
+ grid = (triton.cdiv(D, BD), triton.cdiv(W, BW), N)
115
+ with torch.cuda.device(x.device.index):
116
+ causal_conv1d_varlen_states_fwd_kernel[grid](
117
+ x=x,
118
+ cache=cache,
119
+ offsets=cu_seqlens,
120
+ D=D,
121
+ W=W,
122
+ BW=BW,
123
+ BD=BD
124
+ )
125
+ return cache
126
+
127
+
128
+ class ShortConvolution(nn.Conv1d):
129
+ """
130
+ Simple wrapper around `nn.Conv1d` that accepts dimension last.
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ hidden_size: int,
136
+ kernel_size: int,
137
+ bias: bool = False,
138
+ activation: Optional[str] = 'silu',
139
+ use_fast_conv1d: Optional[bool] = True,
140
+ device: Optional[torch.device] = None,
141
+ dtype: Optional[torch.dtype] = None,
142
+ ):
143
+ super().__init__(
144
+ in_channels=hidden_size,
145
+ out_channels=hidden_size,
146
+ kernel_size=kernel_size,
147
+ groups=hidden_size,
148
+ bias=bias,
149
+ padding=kernel_size - 1,
150
+ device=device,
151
+ dtype=dtype,
152
+ )
153
+
154
+ self.hidden_size = hidden_size
155
+ self.activation = None
156
+ if activation is not None:
157
+ assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
158
+ self.activation = activation
159
+
160
+ if causal_conv1d_fn is None:
161
+ if use_fast_conv1d:
162
+ raise RuntimeError(
163
+ "Please either install `causal-conv1d>=1.4.0` to enable fast causal short convolution CUDA kernel "
164
+ "or set `use_fast_conv1d` to False"
165
+ )
166
+ else:
167
+ warnings.warn(
168
+ "The naive PyTorch version is very slow in practice, "
169
+ "please run `pip install causal-conv1d>=1.4.0` to install the fast causal short convolution CUDA kernel",
170
+ category=ImportWarning
171
+ )
172
+ self.use_fast_conv1d = use_fast_conv1d
173
+
174
+ def extra_repr(self):
175
+ s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
176
+ ', stride={stride}')
177
+ if self.padding != (0,) * len(self.padding):
178
+ s += ', padding={padding}'
179
+ if self.dilation != (1,) * len(self.dilation):
180
+ s += ', dilation={dilation}'
181
+ if self.output_padding != (0,) * len(self.output_padding):
182
+ s += ', output_padding={output_padding}'
183
+ if self.groups != 1:
184
+ s += ', groups={groups}'
185
+ if self.bias is None:
186
+ s += ', bias=False'
187
+ if self.padding_mode != 'zeros':
188
+ s += ', padding_mode={padding_mode}'
189
+ if self.activation is not None:
190
+ s += ', activation={activation}'
191
+ if not self.use_fast_conv1d:
192
+ s += ', use_fast_conv1d={use_fast_conv1d}'
193
+ return s.format(**self.__dict__)
194
+
195
+ def forward(
196
+ self,
197
+ x: torch.Tensor,
198
+ mask: Optional[torch.Tensor] = None,
199
+ cache: Optional[torch.Tensor] = None,
200
+ output_final_state: bool = False,
201
+ cu_seqlens: Optional[torch.LongTensor] = None,
202
+ **kwargs,
203
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
204
+ """
205
+ Args:
206
+ x (`torch.Tensor`):
207
+ Tensor of shape `[B, T, D]`.
208
+ If `seq_idx` is provided, `B` must be 1.
209
+ mask (`Optional[torch.Tensor]`):
210
+ Attention mask dealing with padded positions.
211
+ cache (`Optional[torch.Tensor]`):
212
+ Previous cache tensor of shape `[N, D, W]`, where `W` is the kernel size.
213
+ If provided, the cache is updated **inplace**.
214
+ output_final_state (Optional[bool]):
215
+ Whether to output the final state of shape `[N, D, W]`. Default: `False`.
216
+ cu_seqlens (Optional[torch.LongTensor]):
217
+ Cumulative sequence lengths for each batch. Used for varlen. Default: `None`.
218
+ Shape: [B+1]
219
+
220
+ Returns:
221
+ Tensor of shape `[B, T, D]`.
222
+ """
223
+
224
+ B, T, D, W = *x.shape, self.kernel_size[0]
225
+ N = B if cu_seqlens is None else len(cu_seqlens) - 1
226
+ if mask is not None:
227
+ if cu_seqlens is not None:
228
+ raise ValueError("`mask` and `cu_seqlens` cannot be provided at the same time")
229
+ x = x.mul_(mask.unsqueeze(-1))
230
+ if output_final_state and cache is None:
231
+ cache = x.new_zeros(N, D, W)
232
+ # during the decoding phase, we assume the batch is composed of sequences of length 1
233
+ if cache is not None and B * T == N:
234
+ return self.step(x, cache, cu_seqlens)
235
+
236
+ if cache is not None:
237
+ if cu_seqlens is not None:
238
+ cache = causal_conv1d_varlen_states_fwd(x, cache, cu_seqlens, W)
239
+ else:
240
+ cache[:, :, -min(W, T):].copy_(rearrange(x[..., -min(W, T):, :], 'n w d -> n d w'))
241
+
242
+ x = rearrange(x, 'b t d -> b d t')
243
+ if self.use_fast_conv1d:
244
+ # Sequence index for each token. Used for varlen.
245
+ # Suppose a batch consists of two sequences with lengths 3 and 4,
246
+ # seq_idx=[0, 0, 0, 1, 1, 1, 1] for this batch.
247
+ # NOTE: No need to provide this arg if `cu_seqlens` is passed.
248
+ # This arg is just for BC, and will be removed in the future.
249
+ # [B, T]
250
+ seq_idx = kwargs.get('seq_idx', None)
251
+ if cu_seqlens is not None and seq_idx is None:
252
+ seq_idx = prepare_sequence_ids(prepare_position_ids(cu_seqlens)).to(torch.int32).unsqueeze(0)
253
+ x = causal_conv1d_fn(
254
+ x=x,
255
+ weight=rearrange(self.weight, "d 1 w -> d w"),
256
+ bias=self.bias,
257
+ activation=self.activation,
258
+ seq_idx=seq_idx,
259
+ )
260
+ else:
261
+ if cu_seqlens is not None:
262
+ raise ValueError("`cu_seqlens` is not supported for the naive Pytorch version")
263
+ x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
264
+ if self.activation is not None:
265
+ x = ACT2FN[self.activation](x)
266
+ return rearrange(x, "b d t -> b t d"), cache
267
+
268
+ def step(
269
+ self,
270
+ x: torch.Tensor,
271
+ cache: torch.Tensor,
272
+ cu_seqlens: Optional[torch.LongTensor] = None
273
+ ):
274
+ shape = x.shape
275
+ x = x.squeeze(0) if cu_seqlens is not None else x.squeeze(1)
276
+ if self.use_fast_conv1d:
277
+ x = causal_conv1d_update(
278
+ x=x,
279
+ conv_state=cache,
280
+ weight=rearrange(self.weight, "d 1 w -> d w"),
281
+ bias=self.bias,
282
+ activation=self.activation,
283
+ )
284
+ else:
285
+ dtype = x.dtype
286
+ # we follow the fast mode that updates the cache in-place
287
+ cache.copy_(cache.roll(shifts=-1, dims=-1))
288
+ cache[:, :, -1] = x
289
+ x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
290
+ if self.bias is not None:
291
+ x = x + self.bias
292
+ if self.activation is not None:
293
+ x = ACT2FN[self.activation](x).to(dtype=dtype)
294
+ return x.view(shape), cache
295
+
296
+ @property
297
+ def state_size(self) -> int:
298
+ return self.hidden_size * self.kernel_size[0]
299
+
300
+
301
+ class LongConvolution(nn.Module):
302
+ """
303
+ LongConvolution applies a convolution operation on the input tensor using a fixed
304
+ filter of length max_len.
305
+ The filter is learned during training and is applied using FFT convolution.
306
+ Args:
307
+ hidden_size (int): The number of expected features in the input and output.
308
+ max_len (int): The maximum sequence length.
309
+ Returns:
310
+ y: [batch_size, seq_len, hidden_size] tensor
311
+ """
312
+
313
+ def __init__(
314
+ self,
315
+ hidden_size: int,
316
+ max_len: int,
317
+ **kwargs,
318
+ ):
319
+ """
320
+ Initializes the LongConvolution module.
321
+ Args:
322
+ hidden_size (int): The number of expected features in the input and output.
323
+ max_len (int): The maximum sequence length.
324
+ """
325
+ super().__init__()
326
+ self.hidden_size = hidden_size
327
+ self.filter = nn.Parameter(torch.randn(self.hidden_size, max_len), requires_grad=True)
328
+
329
+ def forward(self, x: torch.Tensor, *args, **kwargs):
330
+ """
331
+ Applies the LongConvolution operation on the input tensor.
332
+ Args:
333
+ x: [batch_size, seq_len, hidden_size] tensor
334
+ Returns:
335
+ y: [batch_size, seq_len, hidden_size] tensor
336
+ """
337
+ x = x.transpose(1, 2)
338
+ y = fft_conv(x, self.filter, dropout_mask=None, gelu=False)
339
+ y = y.transpose(1, 2)
340
+ return y.to(dtype=x.dtype)
341
+
342
+
343
+ class PositionalEmbedding(nn.Module):
344
+ def __init__(self, emb_dim: int, seq_len: int, **kwargs):
345
+ """Complex exponential positional embeddings for implicit long convolution filters."""
346
+ super().__init__()
347
+
348
+ self.seq_len = seq_len
349
+ # The time embedding fed to the filters is normalized so that t_f = 1
350
+ t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
351
+
352
+ if emb_dim > 1:
353
+ bands = (emb_dim - 1) // 2
354
+ # To compute the right embeddings we use the "proper" linspace
355
+ t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
356
+ w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1
357
+
358
+ f = torch.linspace(1e-4, bands - 1, bands)[None, None]
359
+ z = torch.exp(-1j * f * w)
360
+ z = torch.cat([t, z.real, z.imag], dim=-1)
361
+ self.z = nn.Parameter(z, requires_grad=False)
362
+
363
+ def forward(self, L):
364
+ return self.z[:, :L]
365
+
366
+
367
+ class ImplicitLongConvolution(nn.Module):
368
+ """
369
+ Long convolution with implicit filter parameterized by an MLP.
370
+
371
+ Args:
372
+ hidden_size (int):
373
+ The number of expected features in the input and output.
374
+ max_len (int):
375
+ The maximum sequence length.
376
+ d_emb (Optional[int]):
377
+ The dimension of the positional embeddings. Must be odd and greater or equal to 3 (time, sine and cosine).
378
+ Defaults to 3.
379
+ d_hidden (Optional[int]):
380
+ The number of features in the hidden layer of the MLP. Defaults to 16.
381
+
382
+ Attributes:
383
+ pos_emb (`PositionalEmbedding`): The positional embedding layer.
384
+ mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter.
385
+
386
+ """
387
+
388
+ def __init__(
389
+ self,
390
+ hidden_size: int,
391
+ max_len: int,
392
+ d_emb: int = 3,
393
+ d_hidden: int = 16,
394
+ **kwargs,
395
+ ):
396
+ """
397
+ Long convolution with implicit filter parameterized by an MLP.
398
+
399
+
400
+ """
401
+ super().__init__()
402
+ self.hidden_size = hidden_size
403
+ self.d_emb = d_emb
404
+
405
+ assert (
406
+ d_emb % 2 != 0 and d_emb >= 3
407
+ ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)"
408
+ self.pos_emb = PositionalEmbedding(d_emb, max_len)
409
+
410
+ # final linear layer
411
+ self.mlp = nn.Sequential(
412
+ nn.Linear(d_emb, d_hidden),
413
+ torch.nn.ReLU(),
414
+ nn.Linear(d_hidden, hidden_size),
415
+ )
416
+
417
+ def filter(self, seq_len: int, *args, **kwargs):
418
+ k = self.mlp(self.pos_emb(seq_len))
419
+
420
+ return k.transpose(1, 2)
421
+
422
+ def forward(self, x: torch.Tensor, *args, **kwargs):
423
+ """
424
+ Args:
425
+ x: [batch_size, seq_len, hidden_size] tensor
426
+ Returns:
427
+ y: [batch_size, seq_len, hidden_size] tensor
428
+ """
429
+ x = x.transpose(1, 2)
430
+ k = self.filter(x.shape[-1])
431
+ y = fft_conv(x, k, dropout_mask=None, gelu=False)
432
+
433
+ y = y.transpose(1, 2)
434
+ return y.to(dtype=x.dtype)
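A rough usage sketch for `ShortConvolution` above, exercising both the prefill path and the cached single-token `step` path. It forces the naive PyTorch branch (`use_fast_conv1d=False`) so `causal-conv1d` is not required; shapes follow the forward docstring, the sizes are made up:

import torch
from fla.modules import ShortConvolution

conv = ShortConvolution(hidden_size=32, kernel_size=4, activation='silu', use_fast_conv1d=False)

x = torch.randn(2, 10, 32)                   # [B, T, D]
y, cache = conv(x, output_final_state=True)  # prefill: cache keeps the last W inputs per sequence
print(y.shape, cache.shape)                  # (2, 10, 32) and (2, 32, 4)

x_step = torch.randn(2, 1, 32)               # one new token per sequence
y_step, cache = conv(x_step, cache=cache)    # dispatched to `step` since B * T == N
print(y_step.shape)                          # (2, 1, 32)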
fla/modules/fused_linear_listnet_loss.py ADDED
@@ -0,0 +1,427 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Code adapted from
4
+ # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py
5
+
6
+ from functools import partial
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import triton
13
+ import triton.language as tl
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module
16
+ from torch.distributed.tensor.parallel import ParallelStyle
17
+
18
+ from fla.ops.utils import logsumexp_fwd
19
+ from fla.ops.utils.op import exp
20
+ from fla.utils import input_guard
21
+
22
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576
23
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
24
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
25
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
26
+ MAX_FUSED_SIZE = 65536 // 2
27
+
28
+ @triton.jit
29
+ def listnet_kernel(
30
+ logits,
31
+ targets, # Now full target distributions
32
+ lse_logits,
33
+ lse_targets,
34
+ loss,
35
+ total,
36
+ ignore_index,
37
+ logit_scale: tl.constexpr,
38
+ reduction: tl.constexpr,
39
+ V: tl.constexpr,
40
+ BV: tl.constexpr
41
+ ):
42
+ i_n = tl.program_id(0).to(tl.int64)
43
+ NV = tl.cdiv(V, BV)
44
+
45
+ # Pointers to current token's data
46
+ logits_ptr = logits + i_n * V
47
+ targets_ptr = targets + i_n * V
48
+ loss_ptr = loss + i_n
49
+
50
+ # Compute prediction softmax
51
+ b_lse_logits = tl.load(lse_logits + i_n)
52
+ b_lse_targets = tl.load(lse_targets + i_n)
53
+ b_loss = 0.0
54
+
55
+ # Compute gradient: softmax(pred) - softmax(target)
56
+ for iv in range(0, NV):
57
+ o_v = iv * BV + tl.arange(0, BV)
58
+ mask = o_v < V
59
+
60
+ # Load target and compute softmax
61
+ t_val = tl.load(targets_ptr + o_v, mask=mask, other=0.0)
62
+ p_target = tl.exp(t_val - b_lse_targets)
63
+
64
+ # Load logits and compute softmax
65
+ l_val = tl.load(logits_ptr + o_v, mask=mask, other=0.0) * logit_scale
66
+ l_val_minus_lse = l_val - b_lse_logits
67
+ p_pred = tl.exp(l_val_minus_lse)
68
+
69
+ # Gradient calculation
70
+ grad_val = p_pred - p_target
71
+ if reduction == "mean":
72
+ grad_val = grad_val / total
73
+ grad_val = tl.where(b_lse_targets == float('inf'), 0.0, grad_val)
74
+ tl.store(logits_ptr + o_v, grad_val, mask=mask)
75
+
76
+ # Cross-entropy loss
77
+ # instead of: b_loss -= tl.sum(p_target * tl.log(p_pred), axis=0)
78
+ b_loss -= tl.sum(p_target * l_val_minus_lse, axis=0)
79
+
80
+ tl.store(loss_ptr, b_loss)
81
+
82
+ @triton.jit
83
+ def elementwise_mul_kernel(
84
+ x,
85
+ g,
86
+ N: tl.constexpr,
87
+ B: tl.constexpr
88
+ ):
89
+ """
90
+ This function multiplies each element of the tensor pointed by x with the value pointed by g.
91
+ The multiplication is performed in-place on the tensor pointed by x.
92
+
93
+ Parameters:
94
+ x:
95
+ Pointer to the input tensor.
96
+ g:
97
+ Pointer to the gradient output value.
98
+ N (int):
99
+ The number of columns in the input tensor.
100
+ B (int):
101
+ The block size for Triton operations.
102
+ """
103
+
104
+ # Get the program ID and convert it to int64 to avoid overflow
105
+ i_x = tl.program_id(0).to(tl.int64)
106
+ o_x = i_x * B + tl.arange(0, B)
107
+
108
+ # Load the gradient output value
109
+ b_g = tl.load(g)
110
+ b_x = tl.load(x + o_x, mask=o_x < N)
111
+ tl.store(x + o_x, b_x * b_g, mask=o_x < N)
112
+
113
+
114
+ def fused_linear_listnet_forward(
115
+ x: torch.Tensor,
116
+ targets: torch.Tensor, # Float tensor [N, V]
117
+ weight: torch.Tensor,
118
+ bias: torch.Tensor = None,
119
+ ignore_index: int = -100,
120
+ logit_scale: float = 1.0,
121
+ num_chunks: int = 8,
122
+ reduction: str = "mean"
123
+ ):
124
+ N, H, V = *x.shape, weight.shape[0]
125
+ BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
126
+ NC = min(num_chunks, triton.cdiv(V, H))
127
+ C = triton.next_power_of_2(triton.cdiv(N, NC))
128
+ NC = triton.cdiv(N, C)
129
+
130
+ # Initialize outputs
131
+ dx = torch.zeros_like(x)
132
+ dw = torch.zeros_like(weight, dtype=torch.float) if weight is not None else None
133
+ db = torch.zeros_like(bias, dtype=torch.float) if bias is not None else None
134
+ loss = torch.zeros(N, device=x.device, dtype=torch.float)
135
+ total = N # All tokens considered
136
+
137
+ for ic in range(NC):
138
+ start, end = ic * C, min((ic + 1) * C, N)
139
+ c_x = x[start:end]
140
+ c_logits = F.linear(c_x, weight, bias)
141
+ c_targets = targets[start:end]
142
+ c_lse_logits = logsumexp_fwd(c_logits, scale=logit_scale, dtype=torch.float)
143
+ c_lse_targets = logsumexp_fwd(c_targets, dtype=torch.float).nan_to_num(nan=float("inf"))
144
+ c_loss = loss[start:end]
145
+
146
+ # Call ListNet kernel
147
+ listnet_kernel[(c_logits.shape[0],)](
148
+ logits=c_logits,
149
+ targets=c_targets, # Full target distributions
150
+ lse_logits=c_lse_logits,
151
+ lse_targets=c_lse_targets,
152
+ loss=c_loss,
153
+ total=total,
154
+ ignore_index=ignore_index,
155
+ logit_scale=logit_scale,
156
+ reduction=reduction,
157
+ V=V,
158
+ BV=BV,
159
+ num_warps=32
160
+ )
161
+
162
+ # Backward through linear layer
163
+ dx[start:end] = torch.mm(c_logits, weight)
164
+ if weight is not None:
165
+ dw += c_logits.t() @ c_x
166
+ if bias is not None:
167
+ db += c_logits.sum(0)
168
+
169
+ loss = loss.sum()
170
+ if reduction == "mean":
171
+ loss = loss / total
172
+
173
+ return loss, dx, dw, db
174
+
175
+
176
+ def fused_linear_listnet_backward(
177
+ do: torch.Tensor,
178
+ dx: torch.Tensor,
179
+ dw: torch.Tensor,
180
+ db: torch.Tensor
181
+ ):
182
+ # If cross entropy is the last layer, do is 1.0. Skip the mul to save time
183
+ if torch.ne(do, torch.tensor(1.0, device=do.device)):
184
+ # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
185
+ # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
186
+ N, H = dx.shape
187
+ B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
188
+
189
+ elementwise_mul_kernel[(triton.cdiv(N * H, B),)](
190
+ x=dx,
191
+ g=do,
192
+ N=N*H,
193
+ B=B,
194
+ num_warps=32,
195
+ )
196
+
197
+ # handle dw
198
+ if dw is not None:
199
+ V, H = dw.shape
200
+ elementwise_mul_kernel[(triton.cdiv(V * H, B),)](
201
+ x=dw,
202
+ g=do,
203
+ N=V*H,
204
+ B=B,
205
+ num_warps=32,
206
+ )
207
+
208
+ if db is not None:
209
+ V = db.shape[0]
210
+ elementwise_mul_kernel[(triton.cdiv(V, B),)](
211
+ x=db,
212
+ g=do,
213
+ N=V,
214
+ B=B,
215
+ num_warps=32,
216
+ )
217
+ return dx, dw, db
218
+
219
+
220
+ class FusedLinearListNetFunction(torch.autograd.Function):
221
+ @staticmethod
222
+ def forward(
223
+ ctx,
224
+ x: torch.Tensor,
225
+ targets: torch.Tensor, # Float targets
226
+ weight: torch.Tensor,
227
+ bias: torch.Tensor = None,
228
+ ignore_index: int = -100,
229
+ logit_scale: float = 1.0,
230
+ num_chunks: int = 8,
231
+ reduction: str = "mean"
232
+ ):
233
+ loss, dx, dw, db = fused_linear_listnet_forward(
234
+ x, targets, weight, bias, ignore_index,
235
+ logit_scale, num_chunks, reduction
236
+ )
237
+ ctx.save_for_backward(dx, dw, db)
238
+ return loss
239
+
240
+ @staticmethod
241
+ def backward(ctx, do):
242
+ dx, dw, db = ctx.saved_tensors
243
+ dx, dw, db = fused_linear_listnet_backward(do, dx, dw, db)
244
+ return dx, None, dw, db, None, None, None, None
245
+
246
+
247
+ def fused_linear_listnet_loss(
248
+ x: torch.Tensor,
249
+ target: torch.Tensor,
250
+ weight: torch.Tensor,
251
+ bias: torch.Tensor = None,
252
+ ignore_index: int = -100,
253
+ label_smoothing: float = 0.0,
254
+ logit_scale: float = 1.0,
255
+ num_chunks: int = 8,
256
+ reduction: str = "mean"
257
+ ) -> torch.Tensor:
258
+ """
259
+ Args:
260
+ x (torch.Tensor): [batch_size * seq_len, hidden_size]
261
+ target (torch.Tensor): [batch_size * seq_len, vocab_size]
262
+ target score distributions over the vocabulary (float scores, not class indices).
263
+ weight (torch.Tensor): [vocab_size, hidden_size]
264
+ where `vocab_size` is the number of classes.
265
+ bias (Optional[torch.Tensor]): [vocab_size]
266
+ where `vocab_size` is the number of classes.
267
+ ignore_index: int.
268
+ If target == ignore_index, the loss is set to 0.0.
269
+ label_smoothing: float
270
+ logit_scale: float
271
+ A scaling factor applied to the logits. Default: 1.0
272
+ num_chunks: int
273
+ The number of chunks to split the input tensor into for processing.
274
+ This can help optimize memory usage and computation speed.
275
+ Default: 8
276
+ reduction:
277
+ Specifies the reduction to apply to the output: 'mean' | 'sum'.
278
+ 'mean': the weighted mean of the output is taken,
279
+ 'sum': the output will be summed.
280
+ Default: 'mean'.
281
+ Returns:
282
+ loss: scalar tensor, reduced over tokens according to `reduction`.
283
+ """
284
+ return FusedLinearListNetFunction.apply(
285
+ x,
286
+ target,
287
+ weight,
288
+ bias,
289
+ ignore_index,
290
+ logit_scale,
291
+ num_chunks,
292
+ reduction
293
+ )
294
+
295
+
296
+ class FusedLinearListNetLoss(nn.Module):
297
+
298
+ def __init__(
299
+ self,
300
+ ignore_index: int = -100,
301
+ label_smoothing: float = 0.0,
302
+ logit_scale: float = 1.0,
303
+ num_chunks: int = 8,
304
+ reduction: str = "mean"
305
+ ):
306
+ """
307
+ Args:
308
+ ignore_index: int.
309
+ If target == ignore_index, the loss is set to 0.0.
310
+ label_smoothing: float
311
+ logit_scale: float
312
+ A scaling factor applied to the logits. Default: 1.0
313
+ num_chunks: int
314
+ The number of chunks to split the input tensor into for processing.
315
+ This can help optimize memory usage and computation speed.
316
+ Default: 8
317
+ reduction:
318
+ Specifies the reduction to apply to the output: 'mean' | 'sum'.
319
+ 'mean': the weighted mean of the output is taken,
320
+ 'sum': the output will be summed.
321
+ Default: 'mean'.
322
+ """
323
+ super().__init__()
324
+
325
+ assert reduction in ["mean", "sum"], f"reduction: {reduction} is not supported"
326
+
327
+ self.ignore_index = ignore_index
328
+ self.label_smoothing = label_smoothing
329
+ self.logit_scale = logit_scale
330
+ self.num_chunks = num_chunks
331
+ self.reduction = reduction
332
+
333
+ @torch.compiler.disable
334
+ def forward(
335
+ self,
336
+ x: torch.Tensor,
337
+ target: torch.Tensor,
338
+ weight: torch.Tensor,
339
+ bias: Optional[torch.Tensor] = None
340
+ ):
341
+ """
342
+ Args:
343
+ x (torch.Tensor): [batch_size, seq_len, hidden_size]
344
+ target (torch.Tensor): [batch_size, seq_len, vocab_size]
345
+ target score distributions over the vocabulary (float scores, not class indices).
346
+ weight (torch.Tensor): [vocab_size, hidden_size]
347
+ where `vocab_size` is the number of classes.
348
+ bias (Optional[torch.Tensor]): [vocab_size]
349
+ where `vocab_size` is the number of classes.
350
+ Returns:
351
+ loss
352
+ """
353
+ loss = fused_linear_listnet_loss(
354
+ x.view(-1, x.shape[-1]),
355
+ target.view(-1, target.shape[-1]),
356
+ weight=weight,
357
+ bias=bias,
358
+ ignore_index=self.ignore_index,
359
+ label_smoothing=self.label_smoothing,
360
+ logit_scale=self.logit_scale,
361
+ num_chunks=self.num_chunks,
362
+ reduction=self.reduction
363
+ )
364
+ return loss
365
+
366
+
367
+ class LinearLossParallel(ParallelStyle):
368
+ def __init__(
369
+ self,
370
+ *,
371
+ sequence_dim: int = 1,
372
+ use_local_output: bool = False,
373
+ ):
374
+ super().__init__()
375
+
376
+ self.sequence_sharding = (Shard(sequence_dim),)
377
+ self.use_local_output = use_local_output
378
+
379
+ @staticmethod
380
+ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
381
+ x, target, weight, bias = inputs
382
+
383
+ if not isinstance(x, DTensor):
384
+ # assume the input passed in already sharded on the sequence dim and create the DTensor
385
+ x = DTensor.from_local(x, device_mesh, sequence_sharding)
386
+ if x.placements != sequence_sharding:
387
+ x = x.redistribute(placements=sequence_sharding, async_op=True)
388
+ if not isinstance(target, DTensor):
389
+ target = DTensor.from_local(target, device_mesh, [Replicate()])
390
+ if target.placements != sequence_sharding:
391
+ target = target.redistribute(placements=sequence_sharding, async_op=True)
392
+
393
+ if not isinstance(weight, DTensor):
394
+ weight = DTensor.from_local(weight, device_mesh, [Replicate()])
395
+ if weight.placements != [Replicate()]:
396
+ # we replicate the weight/bias in FLCE
397
+ weight = weight.redistribute(placements=[Replicate()], async_op=True)
398
+
399
+ if bias is not None and not isinstance(bias, DTensor):
400
+ bias = DTensor.from_local(bias, device_mesh, [Replicate()])
401
+ if bias is not None and bias.placements != [Replicate()]:
402
+ bias = bias.redistribute(placements=[Replicate()], async_op=True)
403
+
404
+ return x.to_local(), target.to_local(), weight.to_local(), bias.to_local() if bias is not None else bias
405
+
406
+ @staticmethod
407
+ def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
408
+ return outputs.to_local() if use_local_output else outputs
409
+
410
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
411
+ return distribute_module(
412
+ module,
413
+ device_mesh,
414
+ partition_fn=None,
415
+ input_fn=partial(self._prepare_input_fn, self.sequence_sharding),
416
+ output_fn=partial(self._prepare_output_fn, self.use_local_output)
417
+ )
418
+
419
+ # Naive ListNet loss function implementation
420
+ def list_net_loss(y_pred, y_true):
421
+ """
422
+ ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".
423
+ :param y_pred: predictions from the model, shape [*, slate_length]
424
+ :param y_true: ground truth labels, shape [*, slate_length]
425
+ :return: loss value, a torch.Tensor
426
+ """
427
+ return torch.mean(-torch.sum(F.softmax(y_true, dim=-1).nan_to_num(nan=0) * F.log_softmax(y_pred, dim=-1), dim=-1))
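A rough comparison of the fused module against the naive `list_net_loss` above: the fused path takes hidden states plus the output projection weight and materializes the logits only chunk by chunk, while the naive version consumes explicit logits. Sizes below are hypothetical and a CUDA device is assumed, since the kernel is written in Triton:

import torch
from fla.modules import FusedLinearListNetLoss
from fla.modules.fused_linear_listnet_loss import list_net_loss

B, T, H, V = 2, 4, 16, 32
x = torch.randn(B, T, H, device='cuda', requires_grad=True)
weight = torch.randn(V, H, device='cuda')
target = torch.randn(B, T, V, device='cuda')  # full score distributions, not class indices

criterion = FusedLinearListNetLoss(reduction='mean')
loss = criterion(x, target, weight)           # fused projection + ListNet loss, grads stored in-place
loss.backward()

reference = list_net_loss(x.detach() @ weight.t(), target)  # naive ListNet on explicit logits
print(loss.item(), reference.item())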
fla/modules/grpo.py ADDED
@@ -0,0 +1,396 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py
4
+ """
5
+ # Get the per-token log probabilities for the completions for the model and the reference model
6
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
7
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
8
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
9
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
10
+
11
+ input_ids = input_ids[:, -logits_to_keep:]
12
+ # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
13
+ # See https://github.com/huggingface/trl/issues/2770
14
+ logits = logits[:, -logits_to_keep:]
15
+ return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
16
+
17
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
18
+ if return_outputs:
19
+ raise ValueError("The GRPOTrainer does not support returning outputs")
20
+ # Compute the per-token log probabilities for the model
21
+
22
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
23
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
24
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
25
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
26
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
27
+
28
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
29
+
30
+ # Compute the KL divergence between the model and the reference model
31
+ ref_per_token_logps = inputs["ref_per_token_logps"]
32
+ per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
33
+
34
+ # x - x.detach() allows for preserving gradients from x
35
+ advantages = inputs["advantages"]
36
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
37
+ per_token_loss = -(per_token_loss - self.beta * per_token_kl)
38
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
39
+
40
+ # Log the metrics
41
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
42
+ self._metrics["completion_length"].append(completion_length)
43
+
44
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
45
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
46
+
47
+ return loss
48
+ """
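The fused kernels below compute exactly the per-token quantity sketched in the quoted trainer code: a log-softmax for the sampled completion token, the `exp(d) - d - 1` KL estimator against the reference log-prob, and the advantage term. A naive per-token reference that follows the same indexing conventions as `fused_grpo_loss` further down (no chunking, full-precision log-softmax):

import torch

def naive_grpo_per_token_loss(logits, ref_logp, input_ids, advantages, beta=0.1):
    # logits: [B, L+1, V] raw model output; ref_logp: [B, L]; input_ids: [B, K+L]; advantages: [B]
    L = logits.shape[1] - 1
    completion_ids = input_ids[:, -L:]                                # completion tokens only
    logp = torch.log_softmax(logits[:, :-1].float(), dim=-1)          # [B, L, V]
    logp = logp.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)  # [B, L]
    diff = ref_logp - logp
    kl = torch.exp(diff) - diff - 1
    return beta * kl - advantages.unsqueeze(1)                        # per-token loss, [B, L]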
49
+
50
+
51
+ import torch
52
+ import triton
53
+ import triton.language as tl
54
+
55
+ from fla.ops.utils.op import exp, log
56
+ from fla.utils import input_guard
57
+
58
+
59
+ @triton.autotune(
60
+ [triton.Config({'BLOCK_SIZE': BLOCK_SIZE}, num_warps=NUM_WARPS, num_stages=NUM_STAGES)
61
+ for BLOCK_SIZE in [1024, 2048, 4096, 8192]
62
+ for NUM_WARPS in [8, 16, 32]
63
+ for NUM_STAGES in [1, 2, 4]
64
+ ], key=['B', 'N']
65
+ )
66
+ @triton.jit
67
+ def grpo_fwd_kernel(
68
+ logits_ptr,
69
+ ref_logp_ptr,
70
+ input_ids_ptr,
71
+ advantages_ptr,
72
+ completion_mask_ptr,
73
+ loss_ptr,
74
+ lse_ptr,
75
+ beta,
76
+ save_kl: tl.constexpr,
77
+ B,
78
+ M,
79
+ N,
80
+ L,
81
+ start_idx,
82
+ BLOCK_SIZE: tl.constexpr
83
+ ):
84
+ row_idx = tl.program_id(0)
85
+
86
+ off_b = row_idx // L
87
+ N = tl.cast(N, tl.int64)
88
+
89
+ loss_ptr += row_idx
90
+
91
+ completion_mask_ptr += row_idx
92
+ not_skip = tl.load(completion_mask_ptr).to(tl.int1)
93
+ if not_skip == 1:
94
+ ref_logp_ptr += row_idx
95
+ lse_ptr += row_idx
96
+ advantages_ptr += off_b
97
+ logits_ptr += N * (row_idx + off_b)
98
+ input_ids_ptr += row_idx + (off_b+1) * start_idx
99
+ base_cols = tl.arange(0, BLOCK_SIZE)
100
+
101
+ m_i = -float("inf")
102
+ l_i = 0.0
103
+ for start_n in tl.range(0, N, BLOCK_SIZE):
104
+ cols = start_n + base_cols
105
+ mask = cols < N
106
+ logits = tl.load(logits_ptr+cols, mask=mask, other=-float('inf')).to(tl.float32)
107
+ m_ij = tl.max(logits)
108
+ new_m_i = tl.maximum(m_i, m_ij)
109
+ l_i = l_i * exp(m_i - new_m_i) + tl.sum(exp(logits - new_m_i))
110
+ m_i = new_m_i
111
+ lse = log(l_i) + m_i
112
+
113
+ idx = tl.load(input_ids_ptr)
114
+ x = tl.load(logits_ptr+idx).to(tl.float32)
115
+ advantage = tl.load(advantages_ptr).to(tl.float32)
116
+ ref_logp = tl.load(ref_logp_ptr)
117
+ logp = x - lse
118
+ diff = ref_logp - logp
119
+ kl = exp(diff) - diff - 1
120
+ loss = kl * beta - advantage
121
+
122
+ tl.store(loss_ptr, loss.to(loss_ptr.dtype.element_ty))
123
+ tl.store(lse_ptr, lse.to(lse_ptr.dtype.element_ty))
124
+ if save_kl:
125
+ tl.store(loss_ptr+M, kl.to(loss_ptr.dtype.element_ty))
126
+ else:
127
+ # store 0
128
+ tl.store(loss_ptr, 0.0)
129
+ if save_kl:
130
+ tl.store(loss_ptr+M, 0.0)
131
+
132
+
133
+ @triton.autotune(
134
+ [triton.Config({'BLOCK_SIZE': BLOCK_SIZE}, num_warps=NUM_WARPS, num_stages=NUM_STAGES)
135
+ for BLOCK_SIZE in [1024, 2048, 4096, 8192]
136
+ for NUM_WARPS in [8, 16, 32]
137
+ for NUM_STAGES in [1, 2, 4]
138
+ ], key=['B', 'N']
139
+ )
140
+ @triton.jit
141
+ def grpo_bwd_kernel(
142
+ dloss_ptr,
143
+ dlogits_ptr,
144
+ logits_ptr,
145
+ ref_logp_ptr,
146
+ input_ids_ptr,
147
+ advantages_ptr,
148
+ completion_mask_ptr,
149
+ lse_ptr,
150
+ beta,
151
+ B,
152
+ N,
153
+ L,
154
+ start_idx,
155
+ BLOCK_SIZE: tl.constexpr
156
+ ):
157
+
158
+ row_idx = tl.program_id(0) # B*L
159
+ off_b = row_idx // L
160
+
161
+ N = tl.cast(N, tl.int64)
162
+
163
+ dlogits_ptr += N * (row_idx + off_b)
164
+ base_cols = tl.arange(0, BLOCK_SIZE)
165
+ completion_mask_ptr += row_idx
166
+ not_skip = tl.load(completion_mask_ptr).to(tl.int1)
167
+
168
+ if not_skip == 1:
169
+ lse_ptr += row_idx
170
+ dloss_ptr += row_idx
171
+ advantages_ptr += off_b
172
+ ref_logp_ptr += row_idx
173
+ logits_ptr += N * (row_idx + off_b)
174
+ input_ids_ptr += row_idx + (off_b+1) * start_idx
175
+ dloss = tl.load(dloss_ptr).to(tl.float32)
176
+ lse = tl.load(lse_ptr).to(tl.float32)
177
+ idx = tl.load(input_ids_ptr)
178
+ x = tl.load(logits_ptr+idx).to(tl.float32)
179
+ advantage = tl.load(advantages_ptr).to(tl.float32)
180
+ ref_logp = tl.load(ref_logp_ptr)
181
+ logp = x - lse
182
+
183
+ dlogp = (beta * (-1.0 * exp(ref_logp - logp) + 1)
184
+ - advantage) * dloss
185
+
186
+ for start_n in tl.range(0, N, BLOCK_SIZE):
187
+ cols = start_n + base_cols
188
+ mask = cols < N
189
+ logits = tl.load(logits_ptr+cols, mask=mask, other=-float('inf')).to(tl.float32)
190
+ probs = exp(logits - lse)
191
+ dlogits = tl.where(cols == idx, 1-probs, -probs) * dlogp
192
+
193
+ tl.store(dlogits_ptr+cols, dlogits.to(dlogits_ptr.dtype.element_ty), mask=mask)
194
+ else:
195
+ dlogits = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
196
+ for start_n in tl.range(0, N, BLOCK_SIZE):
197
+ cols = start_n + base_cols
198
+ mask = cols < N
199
+
200
+ tl.store(dlogits_ptr+cols, dlogits.to(dlogits_ptr.dtype.element_ty), mask=mask)
201
+
202
+
203
+ class GrpoLoss(torch.autograd.Function):
204
+
205
+ @input_guard
206
+ @staticmethod
207
+ def forward(ctx, logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl):
208
+ ctx.input_shape = logits.shape
209
+ B, L_ADD_1, N = ctx.input_shape
210
+ L = L_ADD_1 - 1
211
+ M = B * L
212
+ input_ids_start_index = input_ids.size(1) - L
213
+
214
+ if not save_kl:
215
+ loss = torch.empty(B, L, device=logits.device, dtype=torch.float32)
216
+ else:
217
+ loss = torch.empty(B*2, L, device=logits.device, dtype=torch.float32)
218
+
219
+ lse = torch.empty(B, L, device=logits.device, dtype=torch.float32)
220
+
221
+ if completion_mask is None:
222
+ completion_mask = torch.ones(B, L, device=logits.device, dtype=torch.int32)
223
+ else:
224
+ loss[:B].masked_fill_(completion_mask.logical_not(), 0.0)
225
+
226
+ grpo_fwd_kernel[(M,)](
227
+ logits_ptr=logits,
228
+ ref_logp_ptr=ref_logp,
229
+ input_ids_ptr=input_ids,
230
+ advantages_ptr=advantages,
231
+ completion_mask_ptr=completion_mask,
232
+ loss_ptr=loss,
233
+ lse_ptr=lse,
234
+ beta=beta,
235
+ save_kl=save_kl,
236
+ B=B, M=M, N=N, L=L,
237
+ start_idx=input_ids_start_index,
238
+ )
239
+ ctx.beta = beta
240
+ ctx.save_for_backward(lse, logits, input_ids, advantages, completion_mask)
241
+ ctx.ref_logp = ref_logp
242
+ return loss
243
+
244
+ @input_guard
245
+ @staticmethod
246
+ def backward(ctx, dloss):
247
+ # The grad of logits comes from two parts, the reward part and the kl part
248
+ lse, logits, input_ids, advantages, completion_mask = ctx.saved_tensors
249
+ B, L_ADD_1, N = ctx.input_shape
250
+ L = L_ADD_1 - 1
251
+ M = B * L
252
+
253
+ input_ids_start_index = input_ids.size(1) - L
254
+
255
+ dlogits = torch.empty_like(logits) # B, L_ADD_1, N
256
+
257
+ grpo_bwd_kernel[(M,)](
258
+ dloss_ptr=dloss,
259
+ dlogits_ptr=dlogits,
260
+ logits_ptr=logits,
261
+ ref_logp_ptr=ctx.ref_logp,
262
+ input_ids_ptr=input_ids,
263
+ advantages_ptr=advantages,
264
+ completion_mask_ptr=completion_mask,
265
+ lse_ptr=lse,
266
+ beta=ctx.beta,
267
+ B=B, N=N, L=L,
268
+ start_idx=input_ids_start_index,
269
+ )
270
+ # The last token in the completion is not used in the loss computation
271
+ # and therefore its gradient should be set to 0
272
+ dlogits[:, -1, :].fill_(0.0)
273
+ return dlogits.view(*ctx.input_shape), None, None, None, None, None, None
274
+
275
+
276
+ def fused_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False) -> torch.Tensor:
277
+ '''
278
+ Compute the GRPO loss; memory-efficient (no additional memory usage) and fast (roughly 6x speedup on an A800).
279
+
280
+ Args:
281
+ logits: Tensor, [B, L+1, vocab_size], the raw model output; note it is not logits[:, :-1]
282
+ ref_logp: Tensor, [B, L], per-token log probabilities under the reference model (not ref_logits[:, :-1])
283
+ input_ids: Tensor, [B, K+L], the prompt_completion_ids; contains both the prompt ids and the completion ids
284
+ advantages: Tensor, [B], the advantages of each prompt
285
+ beta: float, the weight of kl loss
286
+ completion_mask: Tensor, [B, L], mask applied to the per-token loss
287
+ save_kl: bool, if True, the per-token KL is returned as well
288
+
289
+ Returns:
290
+ loss: Tensor, [B, L], the GRPO loss, containing both the advantage term and the KL term
291
+
292
+ NOTE: logits (and ref_logits) are computed as follows
293
+ logits_to_keep = completion_ids.size(1)
294
+
295
+ def get_per_token_logits(model, input_ids, attention_mask, logits_to_keep):
296
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
297
+ logits = model(
298
+ input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
299
+ ).logits
300
+ return logits
301
+
302
+ logits = get_per_token_logits(model, prompt_completion_ids, attention_mask, logits_to_keep)
303
+ '''
304
+ out = GrpoLoss.apply(logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl)
305
+ if not save_kl:
306
+ return out
307
+ else:
308
+ return out.chunk(2, dim=0)
309
+
310
+
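
A minimal usage sketch of the wrapper above (shapes follow the docstring; the tensors are random placeholders rather than real model outputs, a CUDA device is assumed since the kernel is Triton-based, and the import path simply mirrors this commit's file layout):

import torch
from fla.modules.grpo import fused_grpo_loss

B, K, L, V = 2, 8, 16, 1000                        # batch, prompt length, completion length, vocab size
logits = torch.randn(B, L + 1, V, device='cuda')   # raw model logits over the last L+1 positions
ref_logp = torch.randn(B, L, device='cuda')        # placeholder per-token reference log-probs
input_ids = torch.randint(0, V, (B, K + L), device='cuda')
advantages = torch.randn(B, device='cuda')

loss = fused_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1)
print(loss.shape)  # torch.Size([2, 16])
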
311
+ def grpo_loss_torch(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False):
312
+ def get_log_probs(logits, input_ids):
313
+ per_token_logps = []
314
+ for logits_row, input_ids_row in zip(logits, input_ids[:, -logits.size(1):]):
315
+ log_probs = logits_row.log_softmax(dim=-1)
316
+ token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
317
+ per_token_logps.append(token_log_prob)
318
+ return torch.stack(per_token_logps)
319
+
320
+ logits = logits[:, :-1]
321
+ per_token_logps = get_log_probs(logits, input_ids)
322
+ ref_per_token_logps = ref_logp
323
+ per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
324
+
325
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
326
+ per_token_loss = -(per_token_loss - beta * per_token_kl)
327
+ if completion_mask is not None:
328
+ per_token_loss *= completion_mask
329
+ if save_kl:
330
+ per_token_kl *= completion_mask
331
+ return per_token_loss if not save_kl else (per_token_loss, per_token_kl)
332
+
333
+
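
Since `grpo_loss_torch` above is the eager reference for the fused kernel, a quick parity check can be sketched as below (placeholder inputs; the two results are expected to agree up to floating-point tolerance):

import torch
from fla.modules.grpo import fused_grpo_loss, grpo_loss_torch

B, K, L, V = 2, 4, 8, 512
logits = torch.randn(B, L + 1, V, device='cuda')
ref_logp = torch.randn(B, L, device='cuda')
input_ids = torch.randint(0, V, (B, K + L), device='cuda')
advantages = torch.randn(B, device='cuda')

fused = fused_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1)
ref = grpo_loss_torch(logits, ref_logp, input_ids, advantages, beta=0.1)
print(torch.allclose(fused, ref, atol=1e-4, rtol=1e-4))
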
334
+ @torch.compile(fullgraph=True)
335
+ def grpo_loss_with_old_logps(
336
+ logps: torch.Tensor,
337
+ ref_logps: torch.Tensor,
338
+ old_logps: torch.Tensor,
339
+ pad_mask: torch.Tensor,
340
+ logits_to_keep: int,
341
+ rewards: torch.Tensor,
342
+ beta: float = 0.2,
343
+ epsilon: float = 0.2
344
+ ):
345
+ """
346
+ Compute the GRPO (Group Relative Policy Optimization) loss.
347
+
348
+ Args:
349
+ logps (torch.Tensor): [Batch, Token_length] Log probabilities of the current policy.
350
+ ref_logps (torch.Tensor):[Batch, Token_length] Log probabilities of the reference policy.
351
+ old_logps (torch.Tensor): [Batch, Token_length] Log probabilities of the old policy.
352
+ pad_mask (torch.Tensor): [Batch, Token_length] Boolean mask, True for valid (non-padding) completion tokens.
354
+ logits_to_keep (int): Number of logits to keep for masking.
355
+ rewards (torch.Tensor): [Batch] Rewards for each generation.
356
+ beta (float) = 0.2: A hyperparameter for weighting the KL divergence term.
357
+ epsilon (float) = 0.2: A float hyperparameter for clipping the importance ratios.
358
+
359
+ Returns:
360
+ torch.Tensor: The computed GRPO loss.
361
+ """
362
+ B = logps.shape[0]
363
+ assert B > 1, "Batch * Num generations should be greater than 1"
364
+
365
+ rewards_shaped = rewards.view(-1, B) # B,num_generations
366
+ advantages = (rewards_shaped - rewards_shaped.mean(dim=1, keepdim=True)) / \
367
+ (rewards_shaped.std(dim=1, keepdim=True) + 1e-8)
368
+ advantages = advantages.view(-1) # B*num_generations
369
+ # Calculate the per-token KL divergence
370
+ per_token_kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1
371
+
372
+ # Calculate the ratio of probabilities (importance weights)
373
+ # Importance weights are calculated as exp(log_pi_theta - log_pi_theta_old)
374
+ importance_weights = torch.exp(logps - old_logps)
375
+
376
+ # Clip the importance weights to the range [1 - epsilon, 1 + epsilon]
377
+ importance_weights_clipped = torch.clamp(importance_weights, 1 - epsilon, 1 + epsilon)
378
+
379
+ # Create a completion mask. It checks which positions are valid based on logits_to_keep
380
+ completion_mask = torch.arange(logits_to_keep, device=logps.device)[None, :] >= 0
381
+
382
+ # Combine the completion mask and padding mask
383
+ completion_mask = completion_mask & pad_mask # Ensure matching shape
384
+
385
+ # Add an extra dimension to advantages to match the shape for element-wise multiplication
386
+ advantages = advantages.unsqueeze(1)
387
+
388
+ # Calculate the per-token loss. It takes the minimum of the unclipped and clipped importance weights
389
+ # and subtracts the KL divergence term weighted by beta, then multiplies by the completion mask
390
+ token_loss = -(torch.min(advantages * importance_weights, advantages *
391
+ importance_weights_clipped) - beta * per_token_kl) * completion_mask
392
+
393
+ # Calculate the final loss by summing the token losses and normalizing by the number of valid tokens
394
+ loss = token_loss.sum() / completion_mask.sum()  # token_loss already carries the minus sign
395
+
396
+ return loss
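
A hedged, placeholder-shaped call of `grpo_loss_with_old_logps` (the log-probability tensors below stand in for per-token log-probs gathered from the current, reference and old policies; `pad_mask` marks valid completion tokens):

import torch
from fla.modules.grpo import grpo_loss_with_old_logps

B, T = 4, 16                                     # batch * num_generations, completion length
logps = torch.randn(B, T)
ref_logps = torch.randn(B, T)
old_logps = torch.randn(B, T)
pad_mask = torch.ones(B, T, dtype=torch.bool)    # True for non-padding tokens
rewards = torch.randn(B)

loss = grpo_loss_with_old_logps(logps, ref_logps, old_logps, pad_mask, T, rewards)
print(loss)  # scalar tensor
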
fla/modules/layernorm.py ADDED
@@ -0,0 +1,1196 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023, Tri Dao.
4
+ # https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py
5
+ # Implement residual + layer_norm / rms_norm.
6
+
7
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
8
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
9
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
10
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
11
+
12
+ from __future__ import annotations
13
+
14
+ from functools import partial
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ import triton
20
+ import triton.language as tl
21
+ from einops import rearrange
22
+ from torch.distributed import DeviceMesh
23
+ from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module
24
+ from torch.distributed.tensor.parallel import ParallelStyle
25
+
26
+ from fla.utils import get_multiprocessor_count, input_guard
27
+
28
+
29
+ def layer_norm_ref(
30
+ x: torch.Tensor,
31
+ weight: torch.Tensor,
32
+ bias: torch.Tensor,
33
+ residual: torch.Tensor = None,
34
+ eps: float = 1e-5,
35
+ prenorm: bool = False,
36
+ upcast: bool = False
37
+ ):
38
+ dtype = x.dtype
39
+ if upcast:
40
+ weight = weight.float()
41
+ bias = bias.float() if bias is not None else None
42
+ if upcast:
43
+ x = x.float()
44
+ residual = residual.float() if residual is not None else residual
45
+ if residual is not None:
46
+ x = (x + residual).to(x.dtype)
47
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
48
+ dtype
49
+ )
50
+ return out if not prenorm else (out, x)
51
+
52
+
53
+ def rms_norm_ref(
54
+ x: torch.Tensor,
55
+ weight: torch.Tensor,
56
+ bias: torch.Tensor,
57
+ residual: torch.Tensor = None,
58
+ eps: float = 1e-5,
59
+ prenorm: bool = False,
60
+ upcast: bool = False
61
+ ):
62
+ dtype = x.dtype
63
+ if upcast:
64
+ weight = weight.float()
65
+ bias = bias.float() if bias is not None else None
66
+ if upcast:
67
+ x = x.float()
68
+ residual = residual.float() if residual is not None else residual
69
+ if residual is not None:
70
+ x = (x + residual).to(x.dtype)
71
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
72
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
73
+ out = out.to(dtype)
74
+ return out if not prenorm else (out, x)
75
+
76
+
77
+ def group_norm_ref(
78
+ x: torch.Tensor,
79
+ weight: torch.Tensor,
80
+ bias: torch.Tensor,
81
+ num_groups: int,
82
+ residual: torch.Tensor = None,
83
+ eps: float = 1e-5,
84
+ is_rms_norm: bool = False,
85
+ prenorm: bool = False,
86
+ upcast: bool = False
87
+ ):
88
+ dtype = x.dtype
89
+ if upcast:
90
+ weight = weight.float()
91
+ bias = bias.float() if bias is not None else None
92
+ if upcast:
93
+ x = x.float()
94
+ residual = residual.float() if residual is not None else residual
95
+ if residual is not None:
96
+ x = (x + residual).to(x.dtype)
97
+ residual = x
98
+ x, weight = [
99
+ rearrange(data, "... (g d) -> ... g d", g=num_groups) for data in (x, weight)
100
+ ]
101
+ if bias is not None:
102
+ bias = rearrange(bias, '... (g d) -> ... g d', g=num_groups)
103
+ if not is_rms_norm:
104
+ mean = x.mean(dim=-1, keepdim=True)
105
+ x = x - mean
106
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
107
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
108
+ out = rearrange(out, "... g d -> ... (g d)")
109
+ out = out.to(dtype)
110
+ return out if not prenorm else (out, residual)
111
+
112
+
113
+ class GroupNormRef(nn.Module):
114
+
115
+ def __init__(
116
+ self,
117
+ num_groups: int,
118
+ hidden_size: int,
119
+ elementwise_affine: bool = True,
120
+ bias: bool = False,
121
+ eps: float = 1e-5,
122
+ is_rms_norm: bool = False
123
+ ) -> GroupNormRef:
124
+ super().__init__()
125
+
126
+ if hidden_size % num_groups != 0:
127
+ raise ValueError('num_channels must be divisible by num_groups')
128
+
129
+ self.num_groups = num_groups
130
+ self.hidden_size = hidden_size
131
+ self.elementwise_affine = elementwise_affine
132
+ self.eps = eps
133
+ self.is_rms_norm = is_rms_norm
134
+
135
+ self.register_parameter("weight", None)
136
+ self.register_parameter("bias", None)
137
+ if elementwise_affine:
138
+ self.weight = nn.Parameter(torch.empty(hidden_size))
139
+ if bias:
140
+ self.bias = nn.Parameter(torch.empty(hidden_size))
141
+
142
+ self.reset_parameters()
143
+
144
+ def reset_parameters(self):
145
+ if self.elementwise_affine:
146
+ nn.init.ones_(self.weight)
147
+ if self.bias is not None:
148
+ nn.init.zeros_(self.bias)
149
+
150
+ def __repr__(self) -> str:
151
+ s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}"
152
+ if not self.elementwise_affine:
153
+ s += f", elementwise_affine={self.elementwise_affine}"
154
+ if self.is_rms_norm:
155
+ s += f", is_rms_norm={self.is_rms_norm}"
156
+ s += f", eps={self.eps}"
157
+ s += ")"
158
+ return s
159
+
160
+ def forward(self, x, residual=None, prenorm=False):
161
+ return group_norm_ref(
162
+ x,
163
+ self.weight,
164
+ self.bias,
165
+ num_groups=self.num_groups,
166
+ residual=residual,
167
+ eps=self.eps,
168
+ is_rms_norm=self.is_rms_norm,
169
+ prenorm=prenorm,
170
+ upcast=True
171
+ )
172
+
173
+
174
+ @triton.autotune(
175
+ configs=[
176
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
177
+ for num_warps in [1, 2, 4, 8, 16, 32]
178
+ for num_stages in [2, 3, 4]
179
+ ],
180
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
181
+ )
182
+ @triton.jit
183
+ def layer_norm_fwd_kernel(
184
+ X, # pointer to the input
185
+ Y, # pointer to the output
186
+ W, # pointer to the weights
187
+ B, # pointer to the biases
188
+ RESIDUAL, # pointer to the residual
189
+ RESIDUAL_OUT, # pointer to the residual
190
+ Mean, # pointer to the mean
191
+ Rstd, # pointer to the 1/std
192
+ N, # number of columns in X
193
+ G, # number of groups
194
+ eps, # epsilon to avoid division by zero
195
+ IS_RMS_NORM: tl.constexpr,
196
+ BLOCK_N: tl.constexpr,
197
+ HAS_RESIDUAL: tl.constexpr,
198
+ STORE_RESIDUAL_OUT: tl.constexpr,
199
+ HAS_WEIGHT: tl.constexpr,
200
+ HAS_BIAS: tl.constexpr
201
+ ):
202
+ # Map the program id to the row of X and Y it should compute.
203
+ row = tl.program_id(0)
204
+ group = row % G
205
+ X += row * N
206
+ Y += row * N
207
+ if HAS_RESIDUAL:
208
+ RESIDUAL += row * N
209
+ if STORE_RESIDUAL_OUT:
210
+ RESIDUAL_OUT += row * N
211
+ # Compute mean and variance
212
+ cols = tl.arange(0, BLOCK_N)
213
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
214
+ if HAS_RESIDUAL:
215
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
216
+ x += residual
217
+ if STORE_RESIDUAL_OUT:
218
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
219
+ if not IS_RMS_NORM:
220
+ mean = tl.sum(x, axis=0) / N
221
+ tl.store(Mean + row, mean)
222
+ xbar = tl.where(cols < N, x - mean, 0.0)
223
+ var = tl.sum(xbar * xbar, axis=0) / N
224
+ else:
225
+ xbar = tl.where(cols < N, x, 0.0)
226
+ var = tl.sum(xbar * xbar, axis=0) / N
227
+ rstd = 1 / tl.sqrt(var + eps)
228
+ tl.store(Rstd + row, rstd)
229
+ # Normalize and apply linear transformation
230
+ mask = cols < N
231
+ if HAS_WEIGHT:
232
+ w = tl.load(W + group * N + cols, mask=mask).to(tl.float32)
233
+ if HAS_BIAS:
234
+ b = tl.load(B + group * N + cols, mask=mask).to(tl.float32)
235
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
236
+
237
+ y = tl.fma(x_hat, w, b) if HAS_WEIGHT and HAS_BIAS else \
238
+ x_hat * w if HAS_WEIGHT else \
239
+ x_hat + b if HAS_BIAS else x_hat
240
+ # Write output
241
+ y = tl.cast(y, dtype=Y.dtype.element_ty, fp_downcast_rounding="rtne")
242
+ tl.store(Y + cols, y, mask=mask)
243
+
244
+
245
+ def layer_norm_fwd(
246
+ x: torch.Tensor,
247
+ weight: torch.Tensor,
248
+ bias: torch.Tensor,
249
+ eps: float,
250
+ residual: torch.Tensor = None,
251
+ out_dtype: torch.dtype = None,
252
+ residual_dtype: torch.dtype = None,
253
+ is_rms_norm: bool = False,
254
+ num_groups: int = 1
255
+ ):
256
+ if residual is not None:
257
+ residual_dtype = residual.dtype
258
+ M, N, G = *x.shape, num_groups
259
+ if residual is not None:
260
+ assert residual.shape == (M, N)
261
+ if weight is not None:
262
+ assert weight.shape == (G * N,)
263
+ if bias is not None:
264
+ assert bias.shape == (G * N,)
265
+ # allocate output
266
+ y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
267
+ if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
268
+ residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
269
+ else:
270
+ residual_out = None
271
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
272
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
273
+ # Less than 64KB per feature: enqueue fused kernel
274
+ MAX_FUSED_SIZE = 65536 // x.element_size()
275
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
276
+ if N > BLOCK_N:
277
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
278
+ # heuristics for number of warps
279
+ layer_norm_fwd_kernel[(M,)](
280
+ x,
281
+ y,
282
+ weight,
283
+ bias,
284
+ residual,
285
+ residual_out,
286
+ mean,
287
+ rstd,
288
+ N,
289
+ G,
290
+ eps,
291
+ is_rms_norm,
292
+ BLOCK_N,
293
+ residual is not None,
294
+ residual_out is not None,
295
+ weight is not None,
296
+ bias is not None,
297
+ )
298
+ # residual_out is None if residual is None and residual_dtype == input_dtype
299
+ return y, mean, rstd, residual_out if residual_out is not None else x
300
+
301
+
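
A hedged sketch of calling this fused forward helper directly on a 2D input (most callers go through the autograd `LayerNorm`/`RMSNorm` wrappers further below; CUDA is assumed):

import torch
from fla.modules.layernorm import layer_norm_fwd

x = torch.randn(1024, 2048, device='cuda', dtype=torch.bfloat16)
weight = torch.ones(2048, device='cuda', dtype=torch.bfloat16)

# returns the normalized output, per-row mean/rstd, and the (possibly new) residual stream
y, mean, rstd, residual_out = layer_norm_fwd(x, weight, bias=None, eps=1e-5)
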
302
+ @triton.heuristics({
303
+ "RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None
304
+ })
305
+ @triton.autotune(
306
+ configs=[
307
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
308
+ for num_warps in [1, 2, 4, 8, 16, 32]
309
+ for num_stages in [2, 3, 4]
310
+ ],
311
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
312
+ )
313
+ @triton.jit
314
+ def layer_norm_bwd_kernel(
315
+ X, # pointer to the input
316
+ W, # pointer to the weights
317
+ B, # pointer to the biases
318
+ Y, # pointer to the output to be recomputed
319
+ DY, # pointer to the output gradient
320
+ DX, # pointer to the input gradient
321
+ DW, # pointer to the partial sum of weights gradient
322
+ DB, # pointer to the partial sum of biases gradient
323
+ DRESIDUAL,
324
+ DRESIDUAL_IN,
325
+ Mean, # pointer to the mean
326
+ Rstd, # pointer to the 1/std
327
+ M, # number of rows in X
328
+ N, # number of columns in X
329
+ G, # number of groups
330
+ rows_per_program,
331
+ programs_per_group,
332
+ IS_RMS_NORM: tl.constexpr,
333
+ BLOCK_N: tl.constexpr,
334
+ HAS_DRESIDUAL: tl.constexpr,
335
+ STORE_DRESIDUAL: tl.constexpr,
336
+ HAS_WEIGHT: tl.constexpr,
337
+ HAS_BIAS: tl.constexpr,
338
+ RECOMPUTE_OUTPUT: tl.constexpr,
339
+ ):
340
+ row_block_id = tl.program_id(0)
341
+ group_id, program_id_in_group = row_block_id // programs_per_group, row_block_id % programs_per_group
342
+
343
+ row_start = group_id + program_id_in_group * G * rows_per_program
344
+ row_end = min(row_start + G * rows_per_program, M)
345
+
346
+ cols = tl.arange(0, BLOCK_N)
347
+ mask = cols < N
348
+
349
+ if HAS_WEIGHT:
350
+ w = tl.load(W + group_id * N + cols, mask=mask).to(tl.float32)
351
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
352
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
353
+ b = tl.load(B + group_id * N + cols, mask=mask, other=0.0).to(tl.float32)
354
+ if HAS_BIAS:
355
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
356
+
357
+ for row in range(row_start, row_end, G):
358
+ # Load data to SRAM
359
+ x = tl.load(X + row * N + cols, mask=mask, other=0).to(tl.float32)
360
+ dy = tl.load(DY + row * N + cols, mask=mask, other=0).to(tl.float32)
361
+ if not IS_RMS_NORM:
362
+ mean = tl.load(Mean + row)
363
+ rstd = tl.load(Rstd + row)
364
+ # Compute dx
365
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
366
+ xhat = tl.where(mask, xhat, 0.0)
367
+ if RECOMPUTE_OUTPUT:
368
+ y = xhat * w if HAS_WEIGHT else xhat
369
+ if HAS_BIAS:
370
+ y = y + b
371
+ tl.store(Y + row * N + cols, y, mask=mask)
372
+ wdy = dy
373
+ if HAS_WEIGHT:
374
+ wdy = dy * w
375
+ dw += dy * xhat
376
+ if HAS_BIAS:
377
+ db += dy
378
+ if not IS_RMS_NORM:
379
+ c1 = tl.sum(xhat * wdy, axis=0) / N
380
+ c2 = tl.sum(wdy, axis=0) / N
381
+ dx = (wdy - (xhat * c1 + c2)) * rstd
382
+ else:
383
+ c1 = tl.sum(xhat * wdy, axis=0) / N
384
+ dx = (wdy - xhat * c1) * rstd
385
+ if HAS_DRESIDUAL:
386
+ dres = tl.load(DRESIDUAL + row * N + cols, mask=mask, other=0).to(tl.float32)
387
+ dx += dres
388
+ # Write dx
389
+ dx = tl.cast(dx, dtype=DX.dtype.element_ty, fp_downcast_rounding="rtne")
390
+ if STORE_DRESIDUAL:
391
+ tl.store(DRESIDUAL_IN + row * N + cols, dx, mask=mask)
392
+ tl.store(DX + row * N + cols, dx, mask=mask)
393
+
394
+ if HAS_WEIGHT:
395
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
396
+ if HAS_BIAS:
397
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
398
+
399
+
400
+ def layer_norm_bwd(
401
+ dy: torch.Tensor,
402
+ x: torch.Tensor,
403
+ weight: torch.Tensor,
404
+ bias: torch.Tensor,
405
+ eps: float,
406
+ mean: torch.Tensor,
407
+ rstd: torch.Tensor,
408
+ dresidual: torch.Tensor = None,
409
+ has_residual: bool = False,
410
+ is_rms_norm: bool = False,
411
+ x_dtype: torch.dtype = None,
412
+ recompute_output: bool = False,
413
+ num_groups: int = 1
414
+ ):
415
+ M, N, G = *x.shape, num_groups
416
+ assert dy.shape == (M, N)
417
+ if dresidual is not None:
418
+ assert dresidual.shape == (M, N)
419
+ if weight is not None:
420
+ assert weight.shape == (G * N,)
421
+ if bias is not None:
422
+ assert bias.shape == (G * N,)
423
+ # allocate output
424
+ dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)
425
+ dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
426
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
427
+
428
+ # Less than 64KB per feature: enqueue fused kernel
429
+ MAX_FUSED_SIZE = 65536 // x.element_size()
430
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
431
+ if N > BLOCK_N:
432
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
433
+ # each program handles one group only
434
+ S = triton.cdiv(get_multiprocessor_count(x.device.index), G) * G
435
+ dw = torch.empty((S, N), dtype=torch.float32, device=weight.device) if weight is not None else None
436
+ db = torch.empty((S, N), dtype=torch.float32, device=bias.device) if bias is not None else None
437
+ rows_per_program = triton.cdiv(M, S)
438
+ programs_per_group = S // G
439
+ grid = (S,)
440
+ layer_norm_bwd_kernel[grid](
441
+ x,
442
+ weight,
443
+ bias,
444
+ y,
445
+ dy,
446
+ dx,
447
+ dw,
448
+ db,
449
+ dresidual,
450
+ dresidual_in,
451
+ mean,
452
+ rstd,
453
+ M,
454
+ N,
455
+ G,
456
+ rows_per_program,
457
+ programs_per_group,
458
+ is_rms_norm,
459
+ BLOCK_N,
460
+ dresidual is not None,
461
+ dresidual_in is not None,
462
+ weight is not None,
463
+ bias is not None,
464
+ )
465
+ dw = dw.view(G, -1, N).sum(1).to(weight).view_as(weight) if weight is not None else None
466
+ db = db.view(G, -1, N).sum(1).to(bias).view_as(bias) if bias is not None else None
467
+ # Don't need to compute dresidual_in separately in this case
468
+ if has_residual and dx.dtype == x.dtype:
469
+ dresidual_in = dx
470
+ return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
471
+
472
+
473
+ class LayerNormFunction(torch.autograd.Function):
474
+
475
+ @staticmethod
476
+ @input_guard
477
+ def forward(
478
+ ctx,
479
+ x,
480
+ weight,
481
+ bias,
482
+ residual=None,
483
+ eps=1e-5,
484
+ prenorm=False,
485
+ residual_in_fp32=False,
486
+ is_rms_norm=False,
487
+ num_groups=1
488
+ ):
489
+ x_shape_og = x.shape
490
+
491
+ if x.shape[-1] % num_groups != 0:
492
+ raise ValueError('num_channels must be divisible by num_groups')
493
+ # reshape input data into 2D tensor
494
+ x = x.reshape(-1, (x.shape[-1] // num_groups))
495
+ if residual is not None:
496
+ assert residual.shape == x_shape_og
497
+ residual = residual.reshape_as(x)
498
+ residual_dtype = (
499
+ residual.dtype
500
+ if residual is not None
501
+ else (torch.float32 if residual_in_fp32 else None)
502
+ )
503
+ y, mean, rstd, residual_out = layer_norm_fwd(
504
+ x,
505
+ weight,
506
+ bias,
507
+ eps,
508
+ residual,
509
+ residual_dtype=residual_dtype,
510
+ is_rms_norm=is_rms_norm,
511
+ num_groups=num_groups
512
+ )
513
+ ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
514
+ ctx.x_shape_og = x_shape_og
515
+ ctx.eps = eps
516
+ ctx.is_rms_norm = is_rms_norm
517
+ ctx.num_groups = num_groups
518
+ ctx.has_residual = residual is not None
519
+ ctx.prenorm = prenorm
520
+ ctx.x_dtype = x.dtype
521
+ y = y.reshape(x_shape_og)
522
+ return y if not prenorm else (y, residual_out.reshape(x_shape_og))
523
+
524
+ @staticmethod
525
+ @input_guard
526
+ def backward(ctx, dy, *args):
527
+ x, weight, bias, mean, rstd = ctx.saved_tensors
528
+ dy = dy.reshape(-1, (dy.shape[-1] // ctx.num_groups))
529
+ assert dy.shape == x.shape
530
+ if ctx.prenorm:
531
+ dresidual = args[0]
532
+ dresidual = dresidual.reshape(-1, x.shape[-1])
533
+ assert dresidual.shape == x.shape
534
+ else:
535
+ dresidual = None
536
+ dx, dw, db, dresidual_in = layer_norm_bwd(
537
+ dy,
538
+ x,
539
+ weight,
540
+ bias,
541
+ ctx.eps,
542
+ mean,
543
+ rstd,
544
+ dresidual,
545
+ ctx.has_residual,
546
+ ctx.is_rms_norm,
547
+ x_dtype=ctx.x_dtype,
548
+ num_groups=ctx.num_groups
549
+ )
550
+ return (
551
+ dx.reshape(ctx.x_shape_og),
552
+ dw,
553
+ db,
554
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
555
+ None,
556
+ None,
557
+ None,
558
+ None,
559
+ None
560
+ )
561
+
562
+
563
+ def layer_norm(
564
+ x: torch.Tensor,
565
+ weight: torch.Tensor,
566
+ bias: torch.Tensor,
567
+ residual: torch.Tensor = None,
568
+ eps: float = 1e-5,
569
+ prenorm: bool = False,
570
+ residual_in_fp32: bool = False,
571
+ is_rms_norm: bool = False
572
+ ):
573
+ return LayerNormFunction.apply(
574
+ x,
575
+ weight,
576
+ bias,
577
+ residual,
578
+ eps,
579
+ prenorm,
580
+ residual_in_fp32,
581
+ is_rms_norm
582
+ )
583
+
584
+
585
+ def group_norm(
586
+ x: torch.Tensor,
587
+ weight: torch.Tensor,
588
+ bias: torch.Tensor,
589
+ residual: torch.Tensor = None,
590
+ eps: float = 1e-5,
591
+ prenorm: bool = False,
592
+ residual_in_fp32: bool = False,
593
+ is_rms_norm: bool = False,
594
+ num_groups: int = 1
595
+ ):
596
+ return LayerNormFunction.apply(
597
+ x,
598
+ weight,
599
+ bias,
600
+ residual,
601
+ eps,
602
+ prenorm,
603
+ residual_in_fp32,
604
+ is_rms_norm,
605
+ num_groups
606
+ )
607
+
608
+
609
+ def rms_norm(
610
+ x: torch.Tensor,
611
+ weight: torch.Tensor,
612
+ bias: torch.Tensor,
613
+ residual: torch.Tensor = None,
614
+ eps: float = 1e-5,
615
+ prenorm: bool = False,
616
+ residual_in_fp32: bool = False
617
+ ):
618
+ return LayerNormFunction.apply(
619
+ x,
620
+ weight,
621
+ bias,
622
+ residual,
623
+ eps,
624
+ prenorm,
625
+ residual_in_fp32,
626
+ True
627
+ )
628
+
629
+
630
+ def layer_norm_linear(
631
+ x: torch.Tensor,
632
+ norm_weight: torch.Tensor,
633
+ norm_bias: torch.Tensor,
634
+ linear_weight: torch.Tensor,
635
+ linear_bias: torch.Tensor,
636
+ residual: torch.Tensor = None,
637
+ eps: float = 1e-5,
638
+ prenorm: bool = False,
639
+ residual_in_fp32: bool = False,
640
+ is_rms_norm: bool = False,
641
+ num_groups: int = 1
642
+ ):
643
+ return LayerNormLinearFunction.apply(
644
+ x,
645
+ norm_weight,
646
+ norm_bias,
647
+ linear_weight,
648
+ linear_bias,
649
+ residual,
650
+ eps,
651
+ prenorm,
652
+ residual_in_fp32,
653
+ is_rms_norm,
654
+ num_groups
655
+ )
656
+
657
+
658
+ def rms_norm_linear(
659
+ x: torch.Tensor,
660
+ norm_weight: torch.Tensor,
661
+ norm_bias: torch.Tensor,
662
+ linear_weight: torch.Tensor,
663
+ linear_bias: torch.Tensor,
664
+ residual: torch.Tensor = None,
665
+ eps: float = 1e-5,
666
+ prenorm: bool = False,
667
+ residual_in_fp32: bool = False
668
+ ):
669
+ return layer_norm_linear(
670
+ x=x,
671
+ norm_weight=norm_weight,
672
+ norm_bias=norm_bias,
673
+ linear_weight=linear_weight,
674
+ linear_bias=linear_bias,
675
+ residual=residual,
676
+ eps=eps,
677
+ prenorm=prenorm,
678
+ residual_in_fp32=residual_in_fp32,
679
+ is_rms_norm=True
680
+ )
681
+
682
+
683
+ def group_norm_linear(
684
+ x: torch.Tensor,
685
+ norm_weight: torch.Tensor,
686
+ norm_bias: torch.Tensor,
687
+ linear_weight: torch.Tensor,
688
+ linear_bias: torch.Tensor,
689
+ residual: torch.Tensor = None,
690
+ eps: float = 1e-5,
691
+ prenorm: bool = False,
692
+ residual_in_fp32: bool = False,
693
+ is_rms_norm: bool = False,
694
+ num_groups: int = 1
695
+ ):
696
+ return layer_norm_linear(
697
+ x=x,
698
+ norm_weight=norm_weight,
699
+ norm_bias=norm_bias,
700
+ linear_weight=linear_weight,
701
+ linear_bias=linear_bias,
702
+ residual=residual,
703
+ eps=eps,
704
+ prenorm=prenorm,
705
+ residual_in_fp32=residual_in_fp32,
706
+ is_rms_norm=is_rms_norm,
707
+ num_groups=num_groups
708
+ )
709
+
710
+
711
+ class LayerNorm(nn.Module):
712
+
713
+ def __init__(
714
+ self,
715
+ hidden_size: int,
716
+ elementwise_affine: bool = True,
717
+ bias: bool = False,
718
+ eps: float = 1e-5
719
+ ) -> LayerNorm:
720
+ super().__init__()
721
+
722
+ self.hidden_size = hidden_size
723
+ self.elementwise_affine = elementwise_affine
724
+ self.eps = eps
725
+
726
+ self.register_parameter("weight", None)
727
+ self.register_parameter("bias", None)
728
+ if elementwise_affine:
729
+ self.weight = nn.Parameter(torch.empty(hidden_size))
730
+ if bias:
731
+ self.bias = nn.Parameter(torch.empty(hidden_size))
732
+
733
+ self.reset_parameters()
734
+
735
+ def reset_parameters(self):
736
+ if self.elementwise_affine:
737
+ nn.init.ones_(self.weight)
738
+ if self.bias is not None:
739
+ nn.init.zeros_(self.bias)
740
+
741
+ def __repr__(self) -> str:
742
+ s = f"{self.__class__.__name__}({self.hidden_size}"
743
+ if not self.elementwise_affine:
744
+ s += f", elementwise_affine={self.elementwise_affine}"
745
+ s += f", eps={self.eps}"
746
+ s += ")"
747
+ return s
748
+
749
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
750
+ return layer_norm(
751
+ x,
752
+ self.weight,
753
+ self.bias,
754
+ residual=residual,
755
+ eps=self.eps,
756
+ prenorm=prenorm,
757
+ residual_in_fp32=residual_in_fp32
758
+ )
759
+
760
+
761
+ class GroupNorm(nn.Module):
762
+
763
+ def __init__(
764
+ self,
765
+ num_groups: int,
766
+ hidden_size: int,
767
+ elementwise_affine: bool = True,
768
+ bias: bool = False,
769
+ eps: float = 1e-5,
770
+ is_rms_norm: bool = False
771
+ ) -> GroupNorm:
772
+ super().__init__()
773
+
774
+ if hidden_size % num_groups != 0:
775
+ raise ValueError('num_channels must be divisible by num_groups')
776
+
777
+ self.num_groups = num_groups
778
+ self.hidden_size = hidden_size
779
+ self.elementwise_affine = elementwise_affine
780
+ self.eps = eps
781
+ self.is_rms_norm = is_rms_norm
782
+
783
+ self.register_parameter("weight", None)
784
+ self.register_parameter("bias", None)
785
+ if elementwise_affine:
786
+ self.weight = nn.Parameter(torch.empty(hidden_size))
787
+ if bias:
788
+ self.bias = nn.Parameter(torch.empty(hidden_size))
789
+
790
+ self.reset_parameters()
791
+
792
+ def reset_parameters(self):
793
+ if self.elementwise_affine:
794
+ nn.init.ones_(self.weight)
795
+ if self.bias is not None:
796
+ nn.init.zeros_(self.bias)
797
+
798
+ def __repr__(self) -> str:
799
+ s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}"
800
+ if not self.elementwise_affine:
801
+ s += f", elementwise_affine={self.elementwise_affine}"
802
+ if self.is_rms_norm:
803
+ s += f", is_rms_norm={self.is_rms_norm}"
804
+ s += f", eps={self.eps}"
805
+ s += ")"
806
+ return s
807
+
808
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
809
+ return group_norm(
810
+ x,
811
+ self.weight,
812
+ self.bias,
813
+ residual=residual,
814
+ eps=self.eps,
815
+ prenorm=prenorm,
816
+ residual_in_fp32=residual_in_fp32,
817
+ is_rms_norm=self.is_rms_norm,
818
+ num_groups=self.num_groups
819
+ )
820
+
821
+
822
+ class RMSNorm(nn.Module):
823
+
824
+ def __init__(
825
+ self,
826
+ hidden_size: int,
827
+ elementwise_affine: bool = True,
828
+ bias: bool = False,
829
+ eps: float = 1e-5
830
+ ) -> RMSNorm:
831
+ super().__init__()
832
+
833
+ self.hidden_size = hidden_size
834
+ self.elementwise_affine = elementwise_affine
835
+ self.eps = eps
836
+
837
+ self.register_parameter("weight", None)
838
+ self.register_parameter("bias", None)
839
+ if elementwise_affine:
840
+ self.weight = nn.Parameter(torch.empty(hidden_size))
841
+ if bias:
842
+ self.bias = nn.Parameter(torch.empty(hidden_size))
843
+
844
+ self.reset_parameters()
845
+
846
+ def reset_parameters(self):
847
+ if self.elementwise_affine:
848
+ nn.init.ones_(self.weight)
849
+ if self.bias is not None:
850
+ nn.init.zeros_(self.bias)
851
+
852
+ def __repr__(self) -> str:
853
+ s = f"{self.__class__.__name__}({self.hidden_size}"
854
+ if not self.elementwise_affine:
855
+ s += f", elementwise_affine={self.elementwise_affine}"
856
+ s += f", eps={self.eps}"
857
+ s += ")"
858
+ return s
859
+
860
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
861
+ return rms_norm(
862
+ x,
863
+ self.weight,
864
+ self.bias,
865
+ residual=residual,
866
+ eps=self.eps,
867
+ prenorm=prenorm,
868
+ residual_in_fp32=residual_in_fp32,
869
+ )
870
+
871
+
872
+ class LayerNormLinearFunction(torch.autograd.Function):
873
+
874
+ @staticmethod
875
+ @input_guard
876
+ def forward(
877
+ ctx,
878
+ x,
879
+ norm_weight,
880
+ norm_bias,
881
+ linear_weight,
882
+ linear_bias,
883
+ residual=None,
884
+ eps=1e-5,
885
+ prenorm=False,
886
+ residual_in_fp32=False,
887
+ is_rms_norm=False,
888
+ num_groups=1
889
+ ):
890
+ x_shape_og = x.shape
891
+
892
+ if x.shape[-1] % num_groups != 0:
893
+ raise ValueError('num_channels must be divisible by num_groups')
894
+ # reshape input data into 2D tensor
895
+ x = x.reshape(-1, (x.shape[-1] // num_groups))
896
+ if residual is not None:
897
+ assert residual.shape == x_shape_og
898
+ residual = residual.reshape_as(x)
899
+ residual_dtype = (
900
+ residual.dtype
901
+ if residual is not None
902
+ else (torch.float32 if residual_in_fp32 else None)
903
+ )
904
+ y, mean, rstd, residual_out = layer_norm_fwd(
905
+ x,
906
+ norm_weight,
907
+ norm_bias,
908
+ eps,
909
+ residual,
910
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
911
+ residual_dtype=residual_dtype,
912
+ is_rms_norm=is_rms_norm,
913
+ num_groups=num_groups
914
+ )
915
+ y = y.reshape(x_shape_og)
916
+ dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
917
+ linear_weight = linear_weight.to(dtype)
918
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
919
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
920
+ # We don't store y, will be recomputed in the backward pass to save memory
921
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
922
+ ctx.x_shape_og = x_shape_og
923
+ ctx.eps = eps
924
+ ctx.is_rms_norm = is_rms_norm
925
+ ctx.num_groups = num_groups
926
+ ctx.has_residual = residual is not None
927
+ ctx.prenorm = prenorm
928
+ ctx.x_dtype = x.dtype
929
+ ctx.linear_bias_is_none = linear_bias is None
930
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
931
+
932
+ @staticmethod
933
+ @input_guard
934
+ def backward(ctx, dout, *args):
935
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
936
+ dout = dout.reshape(-1, dout.shape[-1])
937
+ dy = F.linear(dout, linear_weight.t())
938
+ dy = dy.reshape(-1, (dy.shape[-1] // ctx.num_groups))
939
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
940
+ assert dy.shape == x.shape
941
+ if ctx.prenorm:
942
+ dresidual = args[0]
943
+ dresidual = dresidual.reshape(-1, x.shape[-1])
944
+ assert dresidual.shape == x.shape
945
+ else:
946
+ dresidual = None
947
+ dx, dnorm_weight, dnorm_bias, dresidual_in, y = layer_norm_bwd(
948
+ dy,
949
+ x,
950
+ norm_weight,
951
+ norm_bias,
952
+ ctx.eps,
953
+ mean,
954
+ rstd,
955
+ dresidual,
956
+ ctx.has_residual,
957
+ ctx.is_rms_norm,
958
+ x_dtype=ctx.x_dtype,
959
+ recompute_output=True,
960
+ num_groups=ctx.num_groups
961
+ )
962
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y.view(-1, linear_weight.shape[-1]))
963
+ return (
964
+ dx.reshape(ctx.x_shape_og),
965
+ dnorm_weight,
966
+ dnorm_bias,
967
+ dlinear_weight,
968
+ dlinear_bias,
969
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
970
+ None,
971
+ None,
972
+ None,
973
+ None,
974
+ None
975
+ )
976
+
977
+
978
+ class LayerNormLinear(nn.Module):
979
+
980
+ def __init__(
981
+ self,
982
+ hidden_size,
983
+ elementwise_affine: bool = True,
984
+ bias: bool = False,
985
+ eps: float = 1e-5
986
+ ) -> LayerNormLinear:
987
+ super().__init__()
988
+
989
+ self.hidden_size = hidden_size
990
+ self.elementwise_affine = elementwise_affine
991
+ self.eps = eps
992
+
993
+ self.register_parameter("weight", None)
994
+ self.register_parameter("bias", None)
995
+ if elementwise_affine:
996
+ self.weight = nn.Parameter(torch.empty(hidden_size))
997
+ if bias:
998
+ self.bias = nn.Parameter(torch.empty(hidden_size))
999
+
1000
+ self.reset_parameters()
1001
+
1002
+ def reset_parameters(self):
1003
+ if self.elementwise_affine:
1004
+ nn.init.ones_(self.weight)
1005
+ if self.bias is not None:
1006
+ nn.init.zeros_(self.bias)
1007
+
1008
+ def __repr__(self) -> str:
1009
+ s = f"{self.__class__.__name__}({self.hidden_size}"
1010
+ if not self.elementwise_affine:
1011
+ s += f", elementwise_affine={self.elementwise_affine}"
1012
+ s += f", eps={self.eps}"
1013
+ s += ")"
1014
+ return s
1015
+
1016
+ def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
1017
+ return layer_norm_linear(
1018
+ x=x,
1019
+ norm_weight=self.weight,
1020
+ norm_bias=self.bias,
1021
+ linear_weight=weight,
1022
+ linear_bias=bias,
1023
+ residual=residual,
1024
+ eps=self.eps,
1025
+ prenorm=prenorm,
1026
+ residual_in_fp32=residual_in_fp32,
1027
+ is_rms_norm=False
1028
+ )
1029
+
1030
+
1031
+ class GroupNormLinear(nn.Module):
1032
+
1033
+ def __init__(
1034
+ self,
1035
+ num_groups: int,
1036
+ hidden_size: int,
1037
+ elementwise_affine: bool = True,
1038
+ bias: bool = False,
1039
+ eps: float = 1e-5,
1040
+ is_rms_norm: bool = False
1041
+ ) -> GroupNormLinear:
1042
+ super().__init__()
1043
+
1044
+ if hidden_size % num_groups != 0:
1045
+ raise ValueError('num_channels must be divisible by num_groups')
1046
+
1047
+ self.num_groups = num_groups
1048
+ self.hidden_size = hidden_size
1049
+ self.elementwise_affine = elementwise_affine
1050
+ self.eps = eps
1051
+ self.is_rms_norm = is_rms_norm
1052
+
1053
+ self.register_parameter("weight", None)
1054
+ self.register_parameter("bias", None)
1055
+ if elementwise_affine:
1056
+ self.weight = nn.Parameter(torch.empty(hidden_size))
1057
+ if bias:
1058
+ self.bias = nn.Parameter(torch.empty(hidden_size))
1059
+
1060
+ self.reset_parameters()
1061
+
1062
+ def reset_parameters(self):
1063
+ if self.elementwise_affine:
1064
+ nn.init.ones_(self.weight)
1065
+ if self.bias is not None:
1066
+ nn.init.zeros_(self.bias)
1067
+
1068
+ def __repr__(self) -> str:
1069
+ s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}"
1070
+ if not self.elementwise_affine:
1071
+ s += f", elementwise_affine={self.elementwise_affine}"
1072
+ if self.is_rms_norm:
1073
+ s += f", is_rms_norm={self.is_rms_norm}"
1074
+ s += f", eps={self.eps}"
1075
+ s += ")"
1076
+ return s
1077
+
1078
+ def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
1079
+ return layer_norm_linear(
1080
+ x=x,
1081
+ norm_weight=self.weight,
1082
+ norm_bias=self.bias,
1083
+ linear_weight=weight,
1084
+ linear_bias=bias,
1085
+ residual=residual,
1086
+ eps=self.eps,
1087
+ prenorm=prenorm,
1088
+ residual_in_fp32=residual_in_fp32,
1089
+ is_rms_norm=self.is_rms_norm,
1090
+ num_groups=self.num_groups
1091
+ )
1092
+
1093
+
1094
+ class RMSNormLinear(nn.Module):
1095
+
1096
+ def __init__(
1097
+ self,
1098
+ hidden_size,
1099
+ elementwise_affine: bool = True,
1100
+ bias: bool = False,
1101
+ eps: float = 1e-5
1102
+ ) -> RMSNormLinear:
1103
+ super().__init__()
1104
+
1105
+ self.hidden_size = hidden_size
1106
+ self.elementwise_affine = elementwise_affine
1107
+ self.eps = eps
1108
+
1109
+ self.register_parameter("weight", None)
1110
+ self.register_parameter("bias", None)
1111
+ if elementwise_affine:
1112
+ self.weight = nn.Parameter(torch.empty(hidden_size))
1113
+ if bias:
1114
+ self.bias = nn.Parameter(torch.empty(hidden_size))
1115
+
1116
+ self.reset_parameters()
1117
+
1118
+ def reset_parameters(self):
1119
+ if self.elementwise_affine:
1120
+ nn.init.ones_(self.weight)
1121
+ if self.bias is not None:
1122
+ nn.init.zeros_(self.bias)
1123
+
1124
+ def __repr__(self) -> str:
1125
+ s = f"{self.__class__.__name__}({self.hidden_size}"
1126
+ if not self.elementwise_affine:
1127
+ s += f", elementwise_affine={self.elementwise_affine}"
1128
+ s += f", eps={self.eps}"
1129
+ s += ")"
1130
+ return s
1131
+
1132
+ def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
1133
+ return layer_norm_linear(
1134
+ x=x,
1135
+ norm_weight=self.weight,
1136
+ norm_bias=self.bias,
1137
+ linear_weight=weight,
1138
+ linear_bias=bias,
1139
+ residual=residual,
1140
+ eps=self.eps,
1141
+ prenorm=prenorm,
1142
+ residual_in_fp32=residual_in_fp32,
1143
+ is_rms_norm=True
1144
+ )
1145
+
1146
+
1147
+ class NormParallel(ParallelStyle):
1148
+
1149
+ def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False):
1150
+ super().__init__()
1151
+ self.sequence_sharding = (Shard(sequence_dim),)
1152
+ self.use_local_output = use_local_output
1153
+
1154
+ def _replicate_module_fn(
1155
+ self, name: str, module: nn.Module, device_mesh: DeviceMesh
1156
+ ):
1157
+ for p_name, param in module.named_parameters():
1158
+ # simple replication with the fixed ones_ init from LayerNorm/RMSNorm, which allows
1159
+ # us to simply use from_local
1160
+ replicated_param = torch.nn.Parameter(
1161
+ DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
1162
+ )
1163
+ module.register_parameter(p_name, replicated_param)
1164
+
1165
+ @staticmethod
1166
+ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
1167
+ input_tensor = inputs[0]
1168
+ if isinstance(input_tensor, DTensor):
1169
+ # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it
1170
+ if input_tensor.placements != sequence_sharding:
1171
+ input_tensor = input_tensor.redistribute(
1172
+ placements=sequence_sharding, async_op=True
1173
+ )
1174
+ return input_tensor
1175
+ elif isinstance(input_tensor, torch.Tensor):
1176
+ # assume the input passed in is already sharded on the sequence dim and create the DTensor
1177
+ return DTensor.from_local(
1178
+ input_tensor, device_mesh, sequence_sharding, run_check=False
1179
+ )
1180
+ else:
1181
+ raise ValueError(
1182
+ f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}"
1183
+ )
1184
+
1185
+ @staticmethod
1186
+ def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
1187
+ return outputs.to_local() if use_local_output else outputs
1188
+
1189
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
1190
+ return distribute_module(
1191
+ module,
1192
+ device_mesh,
1193
+ self._replicate_module_fn,
1194
+ partial(self._prepare_input_fn, self.sequence_sharding),
1195
+ partial(self._prepare_output_fn, self.use_local_output),
1196
+ )
fla/modules/mlp.py ADDED
@@ -0,0 +1,127 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from functools import partial
7
+ from typing import TYPE_CHECKING, Any, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.distributed import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_module
13
+ from torch.distributed.tensor.parallel import ParallelStyle
14
+
15
+ from fla.modules.activations import swiglu, swiglu_linear
16
+
17
+ if TYPE_CHECKING:
18
+ from transformers.processing_utils import Unpack
19
+
20
+
21
+ class GatedMLP(nn.Module):
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size: int,
26
+ hidden_ratio: Optional[int] = None,
27
+ intermediate_size: Optional[int] = None,
28
+ hidden_act: str = 'swish',
29
+ fuse_swiglu: bool = True
30
+ ) -> GatedMLP:
31
+ super().__init__()
32
+
33
+ self.hidden_size = hidden_size
34
+ # the final number of params is `hidden_ratio * hidden_size^2`
35
+ # `intermediate_size` is `2/3 * hidden_size * hidden_ratio`, rounded up to the next multiple of 256
36
+ if hidden_ratio is None:
37
+ hidden_ratio = 4
38
+ if intermediate_size is None:
39
+ intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
40
+ intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
41
+ self.hidden_ratio = hidden_ratio
42
+ self.intermediate_size = intermediate_size
43
+ self.hidden_act = hidden_act
44
+ self.fuse_swiglu = fuse_swiglu
45
+
46
+ if hidden_act != 'swish':
47
+ raise ValueError(f'Unsupported hidden_act: {hidden_act}')
48
+
49
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
50
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
51
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
52
+ if self.fuse_swiglu:
53
+ self.swiglu_linear = SwiGLULinear()
54
+
55
+ def forward(
56
+ self,
57
+ x: torch.Tensor,
58
+ **kwargs: Unpack[Any]
59
+ ) -> torch.Tensor:
60
+ gate, y = self.gate_proj(x), self.up_proj(x)
61
+ if self.fuse_swiglu:
62
+ return self.swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
63
+ else:
64
+ return self.down_proj(swiglu(gate, y))
65
+
66
+
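
To make the sizing comment in `__init__` concrete, a small worked example following the same two lines (for a hypothetical 2048-wide model):

hidden_size, hidden_ratio = 2048, 4
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)        # 5461
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)   # rounded up to 5632
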
67
+ class SwiGLULinear(nn.Module):
68
+
69
+ def forward(self, x, y, weight, bias):
70
+ return swiglu_linear(x, y, weight, bias)
71
+
72
+
73
+ class SwiGLULinearParallel(ParallelStyle):
74
+ def __init__(
75
+ self,
76
+ *,
77
+ input_layouts: Optional[Placement] = None,
78
+ output_layouts: Optional[Placement] = None,
79
+ use_local_output: bool = True,
80
+ ):
81
+ super().__init__()
82
+ self.input_layouts = (input_layouts or Shard(-1),)
83
+ self.output_layouts = (output_layouts or Replicate(),)
84
+ self.desired_input_layouts = (Shard(-1),)
85
+ self.use_local_output = use_local_output
86
+
87
+ @staticmethod
88
+ def _prepare_input_fn(
89
+ input_layouts, desired_input_layouts, mod, inputs, device_mesh
90
+ ):
91
+ x, y, weight, bias = inputs
92
+ if not isinstance(x, DTensor):
93
+ x = DTensor.from_local(x, device_mesh, input_layouts, run_check=False)
94
+ if x.placements != desired_input_layouts:
95
+ x = x.redistribute(placements=desired_input_layouts, async_op=True)
96
+
97
+ if not isinstance(y, DTensor):
98
+ y = DTensor.from_local(y, device_mesh, input_layouts, run_check=False)
99
+ if y.placements != desired_input_layouts:
100
+ y = y.redistribute(placements=desired_input_layouts, async_op=True)
101
+
102
+ if not isinstance(weight, DTensor):
103
+ weight = DTensor.from_local(weight, device_mesh, (Shard(1),))
104
+
105
+ if bias is not None and not isinstance(bias, DTensor):
106
+ bias = DTensor.from_local(bias, device_mesh, (Replicate(),))
107
+
108
+ return x, y, weight, bias
109
+
110
+ @staticmethod
111
+ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
112
+ # Rowwise sharding produces partial output, depending on output layouts:
113
+ # 1. to replicate -> allreduce
114
+ # 2. to shard -> reduce_scatter
115
+ if outputs.placements != output_layouts:
116
+ outputs = outputs.redistribute(placements=output_layouts, async_op=True)
117
+ # back to local tensor if use_local_output is True
118
+ return outputs.to_local() if use_local_output else outputs
119
+
120
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
121
+ return distribute_module(
122
+ module,
123
+ device_mesh,
124
+ partition_fn=None,
125
+ input_fn=partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
126
+ output_fn=partial(self._prepare_output_fn, self.output_layouts, self.use_local_output)
127
+ )
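
A brief usage sketch of `GatedMLP` (placeholder shapes; with `fuse_swiglu=True` the forward pass goes through the fused `swiglu_linear` kernel, so CUDA is assumed):

import torch
from fla.modules.mlp import GatedMLP

mlp = GatedMLP(hidden_size=1024, hidden_ratio=4).to('cuda', torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
y = mlp(x)
print(y.shape)  # torch.Size([2, 128, 1024])
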
fla/ops/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.9 kB). View file
 
fla/ops/abc/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_abc
4
+
5
+ __all__ = [
6
+ 'chunk_abc'
7
+ ]
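
For reference, the op re-exported here is imported as shown in this one-line sketch (its call signature is defined in chunk.py below):

from fla.ops.abc import chunk_abc
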
fla/ops/abc/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (72 kB). View file
 
fla/ops/abc/chunk.py ADDED
@@ -0,0 +1,1116 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils import logcumsumexp_fwd_kernel, softmax_bwd, softmax_fwd
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import input_guard
13
+
14
+
15
+ @triton.jit(do_not_specialize=['T'])
16
+ def chunk_abc_fwd_kernel_h(
17
+ k,
18
+ v,
19
+ z,
20
+ h,
21
+ h0,
22
+ ht,
23
+ T,
24
+ K: tl.constexpr,
25
+ V: tl.constexpr,
26
+ BT: tl.constexpr,
27
+ BK: tl.constexpr,
28
+ BV: tl.constexpr,
29
+ NT: tl.constexpr,
30
+ NORMK: tl.constexpr,
31
+ USE_INITIAL_STATE: tl.constexpr,
32
+ STORE_FINAL_STATE: tl.constexpr
33
+ ):
34
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
35
+
36
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
37
+ if USE_INITIAL_STATE:
38
+ p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
39
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
40
+ if NORMK:
41
+ p_z0 = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_k * BK,), (BK,), (0,))
42
+ else:
43
+ p_z0 = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_v * BV,), (BV,), (0,))
44
+ b_zp = tl.load(p_z0).to(tl.float32)
45
+ for i_t in range(NT):
46
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
47
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
48
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
49
+
50
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
51
+ # [BK, BT]
52
+ b_k = tl.load(p_k, boundary_check=(0, 1))
53
+ # [BT, BV]
54
+ b_v = tl.load(p_v, boundary_check=(0, 1))
55
+ if NORMK:
56
+ p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
57
+ # [BK,]
58
+ b_zc = tl.load(p_zc, boundary_check=(0,))
59
+ b_r, b_zp = exp(b_zp - b_zc), b_zc
60
+ # [BK, BV]
61
+ b_h = b_h * b_r[:, None]
62
+ b_k = exp(b_k - b_zc[:, None]).to(b_k.dtype)
63
+ else:
64
+ p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))
65
+ # [BV,]
66
+ b_zc = tl.load(p_zc, boundary_check=(0,))
67
+ b_r, b_zp = exp(b_zp - b_zc), b_zc
68
+ # [BK, BV]
69
+ b_h = b_h * b_r[None, :]
70
+ b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype)
71
+ # [BK, BV]
72
+ b_h += tl.dot(b_k, b_v, allow_tf32=False)
73
+
74
+ if STORE_FINAL_STATE:
75
+ p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
76
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
77
+
78
+
79
+ @triton.jit(do_not_specialize=['T'])
80
+ def chunk_abc_fwd_kernel_intra_K(
81
+ v,
82
+ z,
83
+ o,
84
+ A,
85
+ T,
86
+ V: tl.constexpr,
87
+ BT: tl.constexpr,
88
+ BC: tl.constexpr,
89
+ BV: tl.constexpr,
90
+ NC: tl.constexpr
91
+ ):
92
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
93
+ i_t, i_i = i_c // NC, i_c % NC
94
+
95
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
96
+ p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))
97
+ # [BV,]
98
+ b_zn = tl.load(p_zn, boundary_check=(0,))
99
+ # [BC, BV]
100
+ b_o = tl.zeros([BC, BV], dtype=tl.float32)
101
+ for i_j in range(0, i_i):
102
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
103
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
104
+ # [BC, BV]
105
+ b_v = tl.load(p_v, boundary_check=(0, 1))
106
+ # [BC, BC]
107
+ b_A = tl.load(p_A, boundary_check=(0, 1))
108
+ b_o += tl.dot(b_A, exp(b_v - b_zn[None, :]).to(b_v.dtype), allow_tf32=False)
109
+ b_z = tl.load(p_z, boundary_check=(0, 1))
110
+ b_o *= exp(b_zn[None, :] - b_z)
111
+
112
+ o_i = tl.arange(0, BC)
113
+ o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
114
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
115
+ for j in range(0, BC):
116
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
117
+ # [BC,]
118
+ b_A = tl.load(A + o_A + j, mask=m_A, other=0)
119
+ # [BV,]
120
+ b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)
121
+ # [BC, BV]
122
+ # avoid 0 * inf = inf
123
+ m_i = o_i[:, None] >= j
124
+ b_o += tl.where(m_i, b_A[:, None] * exp(b_v[None, :] - b_z), 0)
125
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
126
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
127
+
128
+
129
+ @triton.jit(do_not_specialize=['T'])
130
+ def chunk_abc_fwd_kernel_K(
131
+ q,
132
+ k,
133
+ z,
134
+ h,
135
+ o,
136
+ A,
137
+ scale,
138
+ T,
139
+ K: tl.constexpr,
140
+ V: tl.constexpr,
141
+ BT: tl.constexpr,
142
+ BK: tl.constexpr,
143
+ BV: tl.constexpr,
144
+ NT: tl.constexpr
145
+ ):
146
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
147
+ i_p = tl.maximum(i_t * BT - 1, 0)
148
+
149
+ o_i = tl.arange(0, BT)
150
+ m_s = o_i[:, None] >= o_i[None, :]
151
+
152
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
153
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
154
+ for i_k in range(tl.cdiv(K, BK)):
155
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
156
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
157
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
158
+
159
+ # [BT, BK]
160
+ b_q = tl.load(p_q, boundary_check=(0, 1))
161
+ b_q = (b_q * scale).to(b_q.dtype)
162
+ # [BK, BT]
163
+ b_k = tl.load(p_k, boundary_check=(0, 1))
164
+ # [BK, BV]
165
+ b_h = tl.load(p_h, boundary_check=(0, 1))
166
+ # [BT, BV]
167
+ b_o += tl.dot(b_q, b_h, allow_tf32=False)
168
+ # [BT, BT]
169
+ b_A += tl.dot(b_q, b_k, allow_tf32=False)
170
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
171
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
172
+ # [BT, BV]
173
+ b_z = tl.load(p_z, boundary_check=(0, 1))
174
+ # [BT, BV]
175
+ p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,))
176
+ b_zp = tl.load(p_zp, boundary_check=(0,))
177
+ b_o = b_o * exp(b_zp[None, :] - b_z)
178
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
179
+
180
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
181
+ # [BT, BT]
182
+ b_A = tl.where(m_s, b_A, 0.)
183
+ if i_v == 0:
184
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
185
+
186
+
187
+ @triton.jit(do_not_specialize=['T'])
188
+ def chunk_abc_fwd_kernel_intra_V(
189
+ q,
190
+ k,
191
+ z,
192
+ A,
193
+ scale,
194
+ T,
195
+ K: tl.constexpr,
196
+ BT: tl.constexpr,
197
+ BC: tl.constexpr,
198
+ BK: tl.constexpr,
199
+ NC: tl.constexpr
200
+ ):
201
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
202
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
203
+ n_bh = tl.num_programs(2)
204
+
205
+ if i_i > i_j:
206
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
207
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
208
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
209
+ p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
210
+ p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
211
+ # [BK,]
212
+ b_zn = tl.load(p_zn, boundary_check=(0,))
213
+ # [BC, BK]
214
+ b_q = tl.load(p_q, boundary_check=(0, 1))
215
+ b_z = tl.load(p_z, boundary_check=(0, 1))
216
+ b_q = (b_q * exp(b_zn[None, :] - b_z) * scale).to(b_q.dtype)
217
+ # [BK, BC]
218
+ b_k = tl.load(p_k, boundary_check=(0, 1))
219
+ b_k = exp(b_k - b_zn[:, None]).to(b_k.dtype)
220
+ # [BC, BC]
221
+ b_A = tl.dot(b_q, b_k, allow_tf32=False)
222
+ tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
223
+ elif i_i == i_j:
224
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
225
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
226
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
227
+ # [BC, BK]
228
+ b_q = tl.load(p_q, boundary_check=(0, 1))
229
+ b_z = tl.load(p_z, boundary_check=(0, 1))
230
+
231
+ o_i = tl.arange(0, BC)
232
+ o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
233
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
234
+ for j in range(0, BC):
235
+ # [BK,]
236
+ b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
237
+ # [BC,]
238
+ b_A = tl.sum(b_q * exp(b_k[None, :] - b_z) * scale, 1)
239
+ b_A = tl.where(o_i >= j, b_A, 0.)
240
+ tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)
241
+
242
+ p_k = tl.advance(p_k, (K,))
243
+
244
+
245
+ @triton.jit(do_not_specialize=['T'])
246
+ def chunk_abc_fwd_kernel_V(
247
+ q,
248
+ v,
249
+ z,
250
+ h,
251
+ o,
252
+ A,
253
+ scale,
254
+ T,
255
+ K: tl.constexpr,
256
+ V: tl.constexpr,
257
+ BT: tl.constexpr,
258
+ BK: tl.constexpr,
259
+ BV: tl.constexpr,
260
+ NT: tl.constexpr
261
+ ):
262
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
263
+ i_p = tl.maximum(i_t * BT - 1, 0)
264
+
265
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
266
+ for i_k in range(tl.cdiv(K, BK)):
267
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
268
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
269
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
270
+ p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,))
271
+
272
+ # [BT, BK]
273
+ b_q = tl.load(p_q, boundary_check=(0, 1))
274
+ b_q = (b_q * scale).to(b_q.dtype)
275
+ # [BT, BK]
276
+ b_z = tl.load(p_z, boundary_check=(0, 1))
277
+ # [BT, BK]
278
+ b_zp = tl.load(p_zp, boundary_check=(0,))
279
+ b_q = (b_q * exp(b_zp[None, :] - b_z)).to(b_q.dtype)
280
+ # [BK, BV]
281
+ b_h = tl.load(p_h, boundary_check=(0, 1))
282
+ # works but dkw (don't know why), owing to divine benevolence
283
+ # [BT, BV]
284
+ if i_k >= 0:
285
+ b_o += tl.dot(b_q, b_h, allow_tf32=False)
286
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
287
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
288
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
289
+ # [BT, BV]
290
+ b_v = tl.load(p_v, boundary_check=(0, 1))
291
+ # [BT, BT]
292
+ b_A = tl.load(p_A, boundary_check=(0, 1))
293
+ b_o += tl.dot(b_A.to(b_v.dtype), b_v, allow_tf32=False)
294
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
295
+
296
+
297
+ @triton.jit(do_not_specialize=['T'])
298
+ def chunk_abc_bwd_kernel_dh(
299
+ q,
300
+ z,
301
+ do,
302
+ dh,
303
+ scale,
304
+ T,
305
+ K: tl.constexpr,
306
+ V: tl.constexpr,
307
+ BT: tl.constexpr,
308
+ BK: tl.constexpr,
309
+ BV: tl.constexpr,
310
+ NT: tl.constexpr,
311
+ NORMK: tl.constexpr
312
+ ):
313
+ i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
314
+
315
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
316
+ b_zp = tl.full([BK if NORMK else BV], float('inf'), dtype=tl.float32)
317
+ for i_t in range(NT - 1, -1, -1):
318
+ i_p = tl.maximum(i_t * BT - 1, 0)
319
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
320
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
321
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
322
+
323
+ # [BK, BT]
324
+ b_q = tl.load(p_q, boundary_check=(0, 1))
325
+ b_q = (b_q * scale).to(b_q.dtype)
326
+ # [BT, BV]
327
+ b_do = tl.load(p_do, boundary_check=(0, 1))
328
+
329
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
330
+ if NORMK:
331
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
332
+ p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,))
333
+ # [BK,]
334
+ b_zc = tl.load(p_zc, boundary_check=(0,))
335
+ b_r, b_zp = exp(b_zc - b_zp), b_zc
336
+ # [BK, BT]
337
+ b_z = tl.load(p_z, boundary_check=(0, 1))
338
+ b_q = (b_q * exp(b_zc[:, None] - b_z)).to(b_q.dtype)
339
+ # [BK, BV]
340
+ b_dh = b_dh * b_r[:, None]
341
+ else:
342
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
343
+ p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,))
344
+ # [BV,]
345
+ b_zc = tl.load(p_zc, boundary_check=(0,))
346
+ b_r, b_zp = exp(b_zc - b_zp), b_zc
347
+ # [BT, BV]
348
+ b_z = tl.load(p_z, boundary_check=(0,))
349
+ b_do = (b_do * exp(b_zc[None, :] - b_z)).to(b_do.dtype)
350
+ # [BK, BV]
351
+ b_dh = b_dh * b_r[None, :]
352
+ # [BK, BV]
353
+ b_dh += tl.dot(b_q, b_do, allow_tf32=False)
354
+
355
+
356
+ @triton.jit(do_not_specialize=['T'])
357
+ def chunk_abc_bwd_kernel_V(
358
+ k,
359
+ v,
360
+ z,
361
+ h,
362
+ A,
363
+ do,
364
+ dh,
365
+ dq,
366
+ dk,
367
+ dv,
368
+ dA,
369
+ scale,
370
+ T,
371
+ K: tl.constexpr,
372
+ V: tl.constexpr,
373
+ BT: tl.constexpr,
374
+ BK: tl.constexpr,
375
+ BV: tl.constexpr,
376
+ NT: tl.constexpr
377
+ ):
378
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
379
+ i_p = tl.maximum(i_t * BT - 1, 0)
380
+ n_bh = tl.num_programs(2)
381
+
382
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
383
+ p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
384
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
385
+
386
+ # [BK,]
387
+ b_zc = tl.load(p_zc, boundary_check=(0,))
388
+ # [BT, BK]
389
+ b_k = tl.load(p_k, boundary_check=(0, 1))
390
+ b_k = exp(b_k - b_zc[None, :]).to(b_k.dtype)
391
+ # [BT, BT]
392
+ b_A = tl.load(p_A, boundary_check=(0, 1))
393
+
394
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
395
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
396
+ b_dA = tl.zeros([BT, BT], dtype=tl.float32)
397
+ for i_v in range(tl.cdiv(V, BV)):
398
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
399
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * V * K, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
400
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
401
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
402
+ p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
403
+
404
+ # [BT, BV]
405
+ b_v = tl.load(p_v, boundary_check=(0, 1))
406
+ # [BV, BK]
407
+ b_h = tl.load(p_h, boundary_check=(0, 1))
408
+ # [BT, BV]
409
+ b_do = tl.load(p_do, boundary_check=(0, 1))
410
+ # [BK, BV]
411
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
412
+
413
+ # [BT, BV]
414
+ b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
415
+ if i_k == 0:
416
+ b_dv += tl.dot(b_A.to(b_do.dtype), b_do, allow_tf32=False)
417
+ b_do = (b_do * scale).to(b_do.dtype)
418
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
419
+ # [BT, BT]
420
+ b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
421
+ # [BT, BK]
422
+ b_dq += tl.dot(b_do, b_h, allow_tf32=False)
423
+ # [BT, BK]
424
+ b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
425
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
426
+ p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,))
427
+ # [BK,]
428
+ b_zp = tl.load(p_zp, boundary_check=(0,))
429
+ # [BT, BK]
430
+ b_z = tl.load(p_z, boundary_check=(0, 1))
431
+ b_z = exp(b_zp[None, :] - b_z)
432
+ # [BT, BK]
433
+ b_dq = b_dq * b_z
434
+ b_dk = b_dk * b_k
435
+
436
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
437
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
438
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT,), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
439
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
440
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
441
+
442
+ o_i = tl.arange(0, BT)
443
+ m_s = o_i[:, None] >= o_i[None, :]
444
+ # [BT, BT]
445
+ b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
446
+ if i_k == 0:
447
+ tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
448
+
449
+
450
+ @triton.jit(do_not_specialize=['T'])
451
+ def chunk_abc_bwd_kernel_intra_V(
452
+ q,
453
+ k,
454
+ z,
455
+ dA,
456
+ dq,
457
+ dk,
458
+ T,
459
+ K: tl.constexpr,
460
+ BT: tl.constexpr,
461
+ BC: tl.constexpr,
462
+ BK: tl.constexpr,
463
+ NC: tl.constexpr
464
+ ):
465
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
466
+ i_t, i_i = i_c // NC, i_c % NC
467
+
468
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
469
+ p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
470
+ # [BK,]
471
+ b_zn = tl.load(p_zn, boundary_check=(0,))
472
+ # [BC, BK]
473
+ b_z = tl.load(p_z, boundary_check=(0, 1))
474
+ b_zq = exp(b_zn[None, :] - b_z)
475
+ b_dq = tl.zeros([BC, BK], dtype=tl.float32)
476
+ for i_j in range(0, i_i):
477
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
478
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
479
+ # [BC, BK]
480
+ b_k = tl.load(p_k, boundary_check=(0, 1))
481
+ b_kz = exp(b_k - b_zn[None, :]).to(b_k.dtype)
482
+ # [BC, BC]
483
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
484
+ # [BC, BK]
485
+ b_dq += tl.dot(b_dA, b_kz, allow_tf32=False)
486
+ b_dq *= b_zq
487
+
488
+ o_i = tl.arange(0, BC)
489
+ o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
490
+ m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
491
+ for j in range(0, BC):
492
+ p_kj = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
493
+ # [BC,]
494
+ b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
495
+ # [BK,]
496
+ b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
497
+ # [BC, BK]
498
+ m_i = o_i[:, None] >= j
499
+ # [BC, BK]
500
+ b_dq += tl.where(m_i, b_dA[:, None] * exp(b_kj[None, :] - b_z), 0.)
501
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
502
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
503
+
504
+ tl.debug_barrier()
505
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
506
+ p_zn = tl.make_block_ptr(z + i_bh * T*K, (T*K,), (1,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
507
+ # [BK,]
508
+ b_zn = tl.load(p_zn, boundary_check=(0,))
509
+ # [BC, BK]
510
+ b_k = tl.load(p_k, boundary_check=(0, 1))
511
+ b_kz = exp(b_k - b_zn[None, :])
512
+ b_dk = tl.zeros([BC, BK], dtype=tl.float32)
513
+ for i_j in range(i_i + 1, NC):
514
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
515
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
516
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
517
+ # [BC, BK]
518
+ b_q = tl.load(p_q, boundary_check=(0, 1))
519
+ b_z = tl.load(p_z, boundary_check=(0, 1))
520
+ b_qz = (b_q * exp(b_zn[None, :] - b_z)).to(b_q.dtype)
521
+ # [BC, BC]
522
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
523
+ # [BC, BK]
524
+ b_dk += tl.dot(tl.trans(b_dA), b_qz, allow_tf32=False)
525
+ b_dk *= b_kz
526
+
527
+ o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
528
+ for j in range(0, BC):
529
+ p_qj = tl.make_block_ptr(q + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
530
+ p_zj = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
531
+ # [BC,]
532
+ b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
533
+ # [BK,]
534
+ b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
535
+ b_zj = tl.load(p_zj, boundary_check=(0,)).to(tl.float32)
536
+ # [BC, BK]
537
+ m_i = o_i[:, None] <= j
538
+ b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_k - b_zj[None, :]), 0.)
539
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
540
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
541
+
542
+
543
+ @triton.jit(do_not_specialize=['T'])
544
+ def chunk_abc_bwd_kernel_intra_K(
545
+ v,
546
+ z,
547
+ do,
548
+ dA,
549
+ scale,
550
+ T,
551
+ V: tl.constexpr,
552
+ BT: tl.constexpr,
553
+ BC: tl.constexpr,
554
+ BV: tl.constexpr,
555
+ NC: tl.constexpr
556
+ ):
557
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
558
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
559
+ n_bh = tl.num_programs(2)
560
+
561
+ if i_i > i_j:
562
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
563
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
564
+ p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))
565
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
566
+ p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
567
+ # [BV,]
568
+ b_zn = tl.load(p_zn, boundary_check=(0,))
569
+ # [BC, BV]
570
+ b_z = tl.load(p_z, boundary_check=(0, 1))
571
+ b_do = tl.load(p_do, boundary_check=(0, 1))
572
+ b_do = (b_do * exp(b_zn[None, :] - b_z) * scale).to(b_do.dtype)
573
+ # [BV, BC]
574
+ b_v = tl.load(p_v, boundary_check=(0, 1))
575
+ b_v = exp(b_v - b_zn[:, None]).to(b_v.dtype)
576
+ # [BC, BC]
577
+ b_dA = tl.dot(b_do, b_v, allow_tf32=False)
578
+ tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
579
+ elif i_i == i_j:
580
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,))
581
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
582
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
583
+ # [BC, BV]
584
+ b_z = tl.load(p_z, boundary_check=(0, 1))
585
+ b_do = tl.load(p_do, boundary_check=(0, 1)) * scale
586
+
587
+ o_i = tl.arange(0, BC)
588
+ o_A = (i_bh + i_v * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
589
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
590
+ for j in range(0, BC):
591
+ # [BV,]
592
+ b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)
593
+ # [BC,]
594
+ b_dA = tl.sum(b_do * exp(b_v[None, :] - b_z), 1)
595
+ b_dA = tl.where(o_i >= j, b_dA, 0)
596
+ tl.store(dA + o_A + j, b_dA.to(b_do.dtype), mask=m_A)
597
+
598
+ p_v = tl.advance(p_v, (V,))
599
+
600
+
601
+ @triton.jit(do_not_specialize=['T'])
602
+ def chunk_abc_bwd_kernel_K(
603
+ q,
604
+ k,
605
+ v,
606
+ z,
607
+ h,
608
+ A,
609
+ do,
610
+ dh,
611
+ dq,
612
+ dk,
613
+ dv,
614
+ dA,
615
+ scale,
616
+ T,
617
+ K: tl.constexpr,
618
+ V: tl.constexpr,
619
+ BT: tl.constexpr,
620
+ BK: tl.constexpr,
621
+ BV: tl.constexpr,
622
+ NT: tl.constexpr
623
+ ):
624
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
625
+ i_p = tl.maximum(i_t * BT - 1, 0)
626
+ n_bh = tl.num_programs(2)
627
+
628
+ o_i = tl.arange(0, BT)
629
+ m_s = o_i[:, None] >= o_i[None, :]
630
+
631
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
632
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
633
+ p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
634
+
635
+ # [BT, BK]
636
+ b_q = tl.load(p_q, boundary_check=(0, 1))
637
+ b_k = tl.load(p_k, boundary_check=(0, 1))
638
+ # [BT, BT]
639
+ b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False)
640
+ b_A = tl.where(m_s, b_A, 0.)
641
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
642
+
643
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
644
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
645
+ for i_v in range(tl.cdiv(V, BV)):
646
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
647
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
648
+ p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,))
649
+ p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))
650
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
651
+
652
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
653
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
654
+ p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
655
+
656
+ # [BV,]
657
+ b_zp = tl.load(p_zp, boundary_check=(0,))
658
+ b_zc = tl.load(p_zc, boundary_check=(0,))
659
+ # [BT, BV]
660
+ b_v = tl.load(p_v, boundary_check=(0, 1))
661
+ b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype)
662
+ b_z = tl.load(p_z, boundary_check=(0, 1))
663
+ b_z = exp(b_zp[None, :] - b_z)
664
+ # [BV, BK]
665
+ b_h = tl.load(p_h, boundary_check=(0, 1))
666
+ # [BT, BV]
667
+ b_do = tl.load(p_do, boundary_check=(0, 1))
668
+ b_do = (b_do * b_z * scale).to(b_do.dtype)
669
+ # [BK, BV]
670
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
671
+
672
+ # [BT, BK]
673
+ b_dq += tl.dot(b_do, b_h, allow_tf32=False)
674
+ b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
675
+ # [BT, BV]
676
+ b_dv = b_v * tl.dot(b_k, b_dh, allow_tf32=False)
677
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
678
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
679
+ # [BT, BT]
680
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
681
+ # [BT, BK]
682
+ b_dq += tl.dot(b_dA, b_k, allow_tf32=False)
683
+ b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False)
684
+
685
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
686
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
687
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
688
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
689
+
690
+
691
+ @triton.jit(do_not_specialize=['T'])
692
+ def chunk_abc_bwd_kernel_intra_KV(
693
+ v,
694
+ z,
695
+ A,
696
+ do,
697
+ dv,
698
+ T,
699
+ V: tl.constexpr,
700
+ BT: tl.constexpr,
701
+ BC: tl.constexpr,
702
+ BV: tl.constexpr,
703
+ NC: tl.constexpr
704
+ ):
705
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
706
+ i_t, i_i = i_c // NC, i_c % NC
707
+
708
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
709
+ p_zn = tl.make_block_ptr(z + i_bh * T*V, (T*V,), (1,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,))
710
+ # [BV,]
711
+ b_zn = tl.load(p_zn, boundary_check=(0,))
712
+ # [BC, BV]
713
+ b_v = tl.load(p_v, boundary_check=(0, 1))
714
+ b_dv = tl.zeros([BC, BV], dtype=tl.float32)
715
+ for i_j in range(i_i + 1, NC):
716
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
717
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1))
718
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
719
+ # [BC, BV]
720
+ b_z = tl.load(p_z, boundary_check=(0, 1))
721
+ b_do = tl.load(p_do, boundary_check=(0, 1))
722
+ b_do = (b_do * exp(b_zn[None, :] - b_z)).to(b_do.dtype)
723
+ # [BC, BC]
724
+ b_A = tl.load(p_A, boundary_check=(0, 1))
725
+ b_dv += tl.dot(b_A, b_do, allow_tf32=False)
726
+ b_dv *= exp(b_v - b_zn[None, :])
727
+
728
+ o_i = tl.arange(0, BC)
729
+ for j in range(0, BC):
730
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
731
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,))
732
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
733
+ # [BC,]
734
+ b_A = tl.load(p_A, boundary_check=(0,))
735
+ # [BV,]
736
+ b_z = tl.load(p_z, boundary_check=(0,))
737
+ b_do = tl.load(p_do, boundary_check=(0,))
738
+ # [BC, BV]
739
+ m_i = o_i[:, None] <= j
740
+ b_dv += tl.where(m_i, exp(b_v - b_z[None, :]) * b_A[:, None] * b_do[None, :], 0.)
741
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
742
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
743
+
744
+
745
+ @triton.jit(do_not_specialize=['T'])
746
+ def chunk_abc_bwd_kernel_rcum_inter(
747
+ s,
748
+ z,
749
+ ss,
750
+ doo,
751
+ T,
752
+ S: tl.constexpr,
753
+ BT: tl.constexpr,
754
+ BS: tl.constexpr,
755
+ NT: tl.constexpr
756
+ ):
757
+ i_m, i_bh = tl.program_id(0), tl.program_id(1)
758
+
759
+ b_sp = tl.zeros([BS,], dtype=tl.float32)
760
+ b_zp = tl.full([BS,], float('inf'), dtype=tl.float32)
761
+ for i_t in range(NT - 1, -1, -1):
762
+ p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
763
+ p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
764
+ p_zc = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT) * S + i_m * BS,), (BS,), (0,))
765
+ p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
766
+ p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
767
+ # [BS,]
768
+ b_zc = tl.load(p_zc, boundary_check=(0,))
769
+ # [BT, BS]
770
+ b_s = tl.load(p_s, boundary_check=(0, 1))
771
+ b_z = tl.load(p_z, boundary_check=(0, 1))
772
+ b_ss = tl.load(p_ss, boundary_check=(0, 1))
773
+
774
+ b_doo = exp(b_s - b_zp[None, :]) * b_sp[None, :]
775
+ tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1))
776
+ # [BS,]
777
+ b_sp = b_sp * exp(b_zc - b_zp) + tl.sum(b_ss * exp(b_zc[None, :] - b_z), 0)
778
+ b_zp = b_zc
779
+
780
+
781
+ @triton.jit(do_not_specialize=['T'])
782
+ def chunk_abc_bwd_kernel_rcum_intra(
783
+ s,
784
+ z,
785
+ ss,
786
+ doo,
787
+ T,
788
+ S: tl.constexpr,
789
+ BT: tl.constexpr,
790
+ BC: tl.constexpr,
791
+ BS: tl.constexpr,
792
+ NC: tl.constexpr
793
+ ):
794
+ i_s, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
795
+ i_t, i_i = i_c // NC, i_c % NC
796
+
797
+ o_i = tl.arange(0, BC)
798
+ m_o = tl.full([BC, BC], 1., dtype=tl.float32)
799
+
800
+ p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0))
801
+ p_zn = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + BC - 1) * S + i_s * BS,), (BS,), (0,))
802
+ p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0))
803
+ # [BC, BS]
804
+ b_s = tl.load(p_s, boundary_check=(0, 1))
805
+ # [BS,]
806
+ b_zn = tl.load(p_zn, boundary_check=(0,))
807
+
808
+ b_doo = tl.zeros([BC, BS], dtype=tl.float32)
809
+ for i_j in range(i_i + 1, NC):
810
+ p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0))
811
+ p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0))
812
+ # [BC, BS]
813
+ b_z = tl.load(p_z, boundary_check=(0, 1))
814
+ b_ss = tl.load(p_ss, boundary_check=(0, 1))
815
+ # [BC, BS]
816
+ b_doo += b_ss * exp(b_zn[None, :] - b_z)
817
+ b_doo = exp(b_s - b_zn[None, :]) * tl.dot(m_o.to(b_s.dtype), b_doo.to(b_s.dtype), allow_tf32=False)
818
+
819
+ for j in range(0, BC):
820
+ p_z = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,))
821
+ p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,))
822
+ # [BS,]
823
+ b_z = tl.load(p_z, boundary_check=(0,))
824
+ b_ss = tl.load(p_ss, boundary_check=(0,))
825
+ # [BC, BS]
826
+ m_i = o_i[:, None] <= j
827
+ b_doo += tl.where(m_i, exp(b_s - b_z[None, :]) * b_ss[None, :], 0.)
828
+ b_doo += tl.load(p_doo, boundary_check=(0, 1))
829
+ tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1))
830
+
831
+
832
+ class ChunkABCFunction(torch.autograd.Function):
833
+
834
+ @staticmethod
835
+ @input_guard
836
+ def forward(ctx, q, k, v, s, initial_state, output_final_state):
837
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
838
+ BT, BC = 64, 16
839
+ BK = min(64, triton.next_power_of_2(K))
840
+ BV = min(64, triton.next_power_of_2(V))
841
+ BM = min(64, triton.next_power_of_2(M))
842
+ NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
843
+ NV, NM = triton.cdiv(V, BV), triton.cdiv(M, BM)
844
+ num_warps = 4 if BK == 64 else 2
845
+ num_stages = 1
846
+
847
+ def fwd_pre(s, B, H, T, S):
848
+ # keep cumulative normalizer in fp32
849
+ z = torch.empty_like(s, dtype=torch.float)
850
+ grid = (B * H,)
851
+ logcumsumexp_fwd_kernel[grid](
852
+ s, z,
853
+ T=T, S=S
854
+ )
855
+ return z
856
+
857
+ def fwd_inner(q, k, v, z, B, H, T, K, V, BT, BK, BV, NT, normk=False, h0=None, ht=None):
858
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
859
+ h = q.new_empty(B, H, NT * K, V)
860
+ grid = (NV, NK, B * H)
861
+ chunk_abc_fwd_kernel_h[grid](
862
+ k, v, z, h, h0, ht,
863
+ T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
864
+ NORMK=normk,
865
+ USE_INITIAL_STATE=h0 is not None,
866
+ STORE_FINAL_STATE=ht is not None,
867
+ num_warps=num_warps,
868
+ num_stages=num_stages
869
+ )
870
+ return h
871
+
872
+ final_state = None
873
+ if output_final_state:
874
+ final_state = (q.new_empty(B, H, K, M, dtype=torch.float),
875
+ q.new_empty(B, H, M, V, dtype=torch.float))
876
+
877
+ z = fwd_pre(s, B, H, T, M)
878
+ scale = K ** -0.5
879
+ hk = fwd_inner(
880
+ q=q, k=k, v=s, z=z,
881
+ B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
882
+ normk=False,
883
+ h0=initial_state[0] if initial_state is not None else None,
884
+ ht=final_state[0] if final_state is not None else None
885
+ )
886
+ ok1 = torch.empty_like(s)
887
+ Ak = q.new_empty(B, H, T, BT)
888
+ grid = (NM, NT, B * H)
889
+ chunk_abc_fwd_kernel_K[grid](
890
+ q, k, z, hk, ok1, Ak,
891
+ scale=scale,
892
+ T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
893
+ num_warps=num_warps,
894
+ num_stages=num_stages
895
+ )
896
+ ok0 = torch.empty_like(s)
897
+ grid = (NM, NT * NC, B * H)
898
+ chunk_abc_fwd_kernel_intra_K[grid](
899
+ s, z, ok0, Ak,
900
+ T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
901
+ num_warps=2,
902
+ num_stages=num_stages
903
+ )
904
+ ok = ok0.add_(ok1)
905
+
906
+ scale = 1.
907
+ # p is kept in fp32 for safe softmax backward
908
+ p = softmax_fwd(ok, dtype=torch.float)
909
+ qv = p.to(q.dtype)
910
+
911
+ scale = 1.
912
+ hv = fwd_inner(
913
+ q=qv, k=s, v=v, z=z,
914
+ B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
915
+ normk=True,
916
+ h0=initial_state[1] if initial_state is not None else None,
917
+ ht=final_state[1] if final_state is not None else None
918
+ )
919
+ Av = q.new_zeros(NM, B, H, T, BT)
920
+ grid = (NM, NT * NC * NC, B * H)
921
+ chunk_abc_fwd_kernel_intra_V[grid](
922
+ qv, s, z, Av,
923
+ scale=scale,
924
+ T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC,
925
+ num_warps=2,
926
+ num_stages=num_stages
927
+ )
928
+ Av = Av.sum(0)
929
+ ov = torch.empty_like(v)
930
+ grid = (NV, NT, B * H)
931
+ chunk_abc_fwd_kernel_V[grid](
932
+ qv, v, z, hv, ov, Av,
933
+ scale=scale,
934
+ T=T,
935
+ K=M,
936
+ V=V,
937
+ BT=BT,
938
+ BK=BM,
939
+ BV=BV,
940
+ NT=NT,
941
+ num_warps=num_warps,
942
+ num_stages=num_stages
943
+ )
944
+ ctx.save_for_backward(q, k, v, s, z, ok, p, hk, hv, Av)
945
+ ctx.BT = BT
946
+ return ov, final_state
947
+
948
+ @staticmethod
949
+ @input_guard
950
+ def backward(ctx, dov, dht=None):
951
+ q, k, v, s, z, ok, p, hk, hv, Av = ctx.saved_tensors
952
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
953
+ BT, BC = ctx.BT, 16
954
+ BK = min(64, triton.next_power_of_2(K))
955
+ BV = min(64, triton.next_power_of_2(V))
956
+ BM = min(64, triton.next_power_of_2(M))
957
+ NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
958
+ NK, NM = triton.cdiv(K, BK), triton.cdiv(M, BM)
959
+ num_warps = 4 if BK == 64 else 2
960
+ num_stages = 1
961
+
962
+ def bwd_inner(q, z, do, B, H, T, K, V, BT, BK, BV, NT, scale, normk=False):
963
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
964
+ dh = q.new_empty(B, H, NT * K, V)
965
+ grid = (NK, NV, B * H)
966
+ chunk_abc_bwd_kernel_dh[grid](
967
+ q, z, do, dh,
968
+ scale=scale,
969
+ T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
970
+ NORMK=normk,
971
+ num_warps=num_warps,
972
+ num_stages=num_stages
973
+ )
974
+ return dh
975
+
976
+ def bwd_post(s, z, ss, B, H, T, S, BT, BC, BS, NT, NC, NS):
977
+ doo = torch.empty_like(s)
978
+ grid = (NS, B * H)
979
+ chunk_abc_bwd_kernel_rcum_inter[grid](
980
+ s, z, ss, doo,
981
+ T=T, S=S, BT=BT, BS=BS, NT=NT,
982
+ num_warps=num_warps,
983
+ num_stages=num_stages
984
+ )
985
+ grid = (NS, NT * NC, B * H)
986
+ chunk_abc_bwd_kernel_rcum_intra[grid](
987
+ s, z, ss, doo,
988
+ T=T, S=S, BT=BT, BC=BC, BS=BS, NC=NC,
989
+ num_warps=num_warps,
990
+ num_stages=num_stages
991
+ )
992
+ return doo
993
+
994
+ scale = 1.
995
+ qv = p.to(q.dtype)
996
+ dhv = bwd_inner(
997
+ qv, z, dov,
998
+ B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
999
+ scale=scale,
1000
+ normk=True
1001
+ )
1002
+ dp1 = torch.empty_like(p)
1003
+ dsv1 = torch.empty_like(s, dtype=torch.float)
1004
+ dv = v.new_empty(NM, *v.shape)
1005
+ dAv = q.new_zeros(B, H, T, BT)
1006
+ grid = (NM, NT, B * H)
1007
+ chunk_abc_bwd_kernel_V[grid](
1008
+ s, v, z, hv, Av, dov, dhv, dp1, dsv1, dv, dAv,
1009
+ scale=scale,
1010
+ T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
1011
+ num_warps=num_warps,
1012
+ num_stages=num_stages
1013
+ )
1014
+ dv = dv.sum(0)
1015
+ dp0 = torch.empty_like(p)
1016
+ dsv0 = s.new_zeros(s.shape, dtype=torch.float)
1017
+ grid = (NM, NT * NC, B * H)
1018
+ chunk_abc_bwd_kernel_intra_V[grid](
1019
+ qv, s, z, dAv, dp0, dsv0,
1020
+ T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC,
1021
+ num_warps=2,
1022
+ num_stages=num_stages
1023
+ )
1024
+ dp = dp1.add_(dp0)
1025
+ dsv = dsv1.add_(dsv0)
1026
+
1027
+ # softmax gradient, equivalent to:
1028
+ # dok = p * (dp - (p * dp).sum(-1, True))
1029
+ dok = softmax_bwd(p, dp, dtype=ok.dtype)
1030
+
1031
+ scale = K ** -0.5
1032
+ dhk = bwd_inner(
1033
+ q, z, dok,
1034
+ B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
1035
+ scale=scale,
1036
+ normk=False
1037
+ )
1038
+ dAk = q.new_zeros(NM, B, H, T, BT)
1039
+ grid = (NM, NT * NC * NC, B * H)
1040
+ chunk_abc_bwd_kernel_intra_K[grid](
1041
+ s, z, dok, dAk,
1042
+ scale=scale,
1043
+ T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
1044
+ num_warps=2,
1045
+ num_stages=num_stages
1046
+ )
1047
+ dAk = dAk.sum(0)
1048
+
1049
+ Ak = q.new_zeros(NK, B, H, T, BT)
1050
+ dq = torch.empty_like(q)
1051
+ dk = torch.empty_like(k)
1052
+ dsk1 = s.new_empty(NK, *s.shape, dtype=torch.float)
1053
+ grid = (NK, NT, B * H)
1054
+ chunk_abc_bwd_kernel_K[grid](
1055
+ q, k, s, z, hk, Ak, dok, dhk, dq, dk, dsk1, dAk,
1056
+ scale=scale,
1057
+ T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
1058
+ num_warps=num_warps,
1059
+ num_stages=num_stages
1060
+ )
1061
+ Ak = Ak.sum(0)
1062
+ dsk1 = dsk1.sum(0)
1063
+ dsk0 = torch.empty_like(s, dtype=torch.float)
1064
+ grid = (NM, NT * NC, B * H)
1065
+ chunk_abc_bwd_kernel_intra_KV[grid](
1066
+ s, z, Ak, dok, dsk0,
1067
+ T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
1068
+ num_warps=2,
1069
+ num_stages=num_stages
1070
+ )
1071
+ ds = dsv.add_(dsk1.add_(dsk0))
1072
+ ds -= bwd_post(s, z, ok * dok + p * dp, B, H, T, M, BT, BC, BM, NT, NC, NM)
1073
+ ds = ds.to(s.dtype)
1074
+ return dq, dk, dv, ds, None, None
1075
+
1076
+
1077
+ @torch.compiler.disable
1078
+ def chunk_abc(
1079
+ q: torch.Tensor,
1080
+ k: torch.Tensor,
1081
+ v: torch.Tensor,
1082
+ s: torch.Tensor,
1083
+ initial_state: Optional[Tuple[torch.Tensor]] = None,
1084
+ output_final_state: bool = False,
1085
+ head_first: bool = True
1086
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1087
+ r"""
1088
+ Args:
1089
+ q (torch.Tensor):
1090
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
1091
+ k (torch.Tensor):
1092
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
1093
+ v (torch.Tensor):
1094
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
1095
+ s (torch.Tensor):
1096
+ slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`
1097
+ initial_state (Optional[Tuple[torch.Tensor, torch.Tensor]]):
1098
+ Initial states of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `None`.
1099
+ output_final_state (Optional[bool]):
1100
+ Whether to output the final state of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `False`.
1101
+ head_first (Optional[bool]):
1102
+ Whether the inputs are in the head-first format.
1103
+ Default: `True`.
1104
+
1105
+ Returns:
1106
+ o (torch.Tensor):
1107
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
1108
+ final_state (Tuple[torch.Tensor, torch.Tensor]):
1109
+ Final states of shape `[B, H, K, M]` and `[B, H, M, V]` if `output_final_state=True` else `None`.
1110
+ """
1111
+ if not head_first:
1112
+ q, k, v, s = map(lambda x: x.transpose(1, 2), (q, k, v, s))
1113
+ o, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state)
1114
+ if not head_first:
1115
+ o = o.transpose(1, 2)
1116
+ return o, final_state
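For reference, a minimal usage sketch of the `chunk_abc` entry point added above (assumptions: a CUDA device with Triton available and an import path inferred from this commit's file layout; shapes follow the docstring):

import torch
from fla.ops.abc.chunk import chunk_abc  # import path assumed from this commit's layout

B, H, T, K, V, M = 2, 4, 256, 64, 64, 32
q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
s = torch.randn(B, H, T, M, device='cuda', dtype=torch.bfloat16)

# head_first=True expects the [B, H, T, *] layout described in the docstring
o, (hk, hv) = chunk_abc(q, k, v, s, output_final_state=True, head_first=True)
assert o.shape == (B, H, T, V)
assert hk.shape == (B, H, K, M) and hv.shape == (B, H, M, V)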
fla/ops/abc/naive.py ADDED
@@ -0,0 +1,96 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import repeat
7
+
8
+
9
+ def naive_recurrent_abc(
10
+ q: torch.Tensor,
11
+ k: torch.Tensor,
12
+ v: torch.Tensor,
13
+ s: torch.Tensor,
14
+ g: Optional[torch.Tensor] = None,
15
+ scale: Optional[float] = None,
16
+ initial_state: Optional[torch.Tensor] = None,
17
+ output_final_state: Optional[bool] = False
18
+ ) -> torch.Tensor:
19
+ dtype = q.dtype
20
+
21
+ NG = q.shape[1]//k.shape[1]
22
+ # [batch_size, n_heads, seq_len, n_slots]
23
+ if g is None:
24
+ z = s.float().logcumsumexp(2)
25
+ g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
26
+ s = torch.exp(s - z)
27
+ q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
28
+ k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g))
29
+ if initial_state is not None:
30
+ initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state))
31
+
32
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
33
+
34
+ hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
35
+ ok = torch.zeros_like(s)
36
+
37
+ if scale is None:
38
+ scale = q.shape[-1] ** -0.5
39
+
40
+ final_state = None
41
+ if initial_state is not None:
42
+ hk += initial_state[0]
43
+
44
+ for i in range(T):
45
+ q_i = q[:, :, i] * scale
46
+ k_i = k[:, :, i]
47
+ v_i = s[:, :, i]
48
+ g_i = g[:, :, i].exp()
49
+ hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
50
+ ok[:, :, i] = (q_i[..., None] * hk).sum(-2)
51
+
52
+ qv = ok.softmax(-1)
53
+ hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
54
+ ov = torch.zeros_like(v)
55
+ if initial_state is not None:
56
+ hv += initial_state[1]
57
+
58
+ for i in range(T):
59
+ q_i = qv[:, :, i]
60
+ k_i = s[:, :, i]
61
+ v_i = v[:, :, i]
62
+ g_i = g[:, :, i].exp()
63
+ hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
64
+ ov[:, :, i] = (q_i[..., None] * hv).sum(-2)
65
+
66
+ if output_final_state:
67
+ final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0])
68
+ return ov.to(dtype), final_state
69
+
70
+
71
+ def naive_cumsum_abc(
72
+ q: torch.Tensor,
73
+ k: torch.Tensor,
74
+ v: torch.Tensor,
75
+ s: torch.Tensor
76
+ ) -> torch.Tensor:
77
+ """
78
+ A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper.
79
+ This is just for demonstration purposes, with no numerical stability guaranteed.
80
+ """
81
+
82
+ dtype = q.dtype
83
+ q, k, v, s = map(lambda x: x.float(), (q, k, v, s))
84
+
85
+ scale = q.shape[-1] ** -0.5
86
+ # [batch_size, n_heads, seq_len, n_slots]
87
+ s = (s - s.max(2, True)[0]).exp()
88
+ z = s.cumsum(2)
89
+ # [batch_size, n_heads, seq_len, n_slots, d_head]
90
+ K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
91
+ V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
92
+ # [batch_size, n_heads, seq_len, n_slots]
93
+ p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1)
94
+ # [batch_size, n_heads, seq_len, d_head]
95
+ o = torch.einsum('...m,...md->...d', p, V)
96
+ return o.to(dtype), None
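As a sketch of how the two implementations relate, one might compare the naive recurrence above against the fused `chunk_abc` kernels; the import paths and tolerances below are assumptions for illustration, not part of this commit:

import torch
from fla.ops.abc.chunk import chunk_abc            # path assumed from this commit's layout
from fla.ops.abc.naive import naive_recurrent_abc

B, H, T, K, V, M = 1, 2, 128, 32, 32, 16
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')
s = torch.randn(B, H, T, M, device='cuda')

ref, _ = naive_recurrent_abc(q, k, v, s)   # sequential reference recurrence
out, _ = chunk_abc(q, k, v, s)             # chunked Triton implementation
torch.testing.assert_close(ref, out, rtol=1e-3, atol=1e-3)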
fla/ops/attn/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .parallel import parallel_attn
4
+
5
+ __all__ = [
6
+ 'parallel_attn'
7
+ ]
fla/ops/attn/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (33.1 kB). View file
 
fla/ops/attn/parallel.py ADDED
@@ -0,0 +1,629 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils.op import exp, log
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
23
+ for num_stages in [2, 3, 4, 5]
24
+ ],
25
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
26
+ )
27
+ @triton.jit
28
+ def parallel_attn_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ o,
33
+ lse,
34
+ scale,
35
+ offsets,
36
+ indices,
37
+ T,
38
+ B: tl.constexpr,
39
+ H: tl.constexpr,
40
+ HQ: tl.constexpr,
41
+ G: tl.constexpr,
42
+ K: tl.constexpr,
43
+ V: tl.constexpr,
44
+ BT: tl.constexpr,
45
+ BS: tl.constexpr,
46
+ BK: tl.constexpr,
47
+ BV: tl.constexpr,
48
+ USE_OFFSETS: tl.constexpr
49
+ ):
50
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
51
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
52
+ i_h = i_hq // G
53
+
54
+ if USE_OFFSETS:
55
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
57
+ T = eos - bos
58
+ else:
59
+ i_n = i_b
60
+ bos, eos = i_n * T, i_n * T + T
61
+
62
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
63
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
64
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
65
+
66
+ # the Q block is kept in the shared memory throughout the whole kernel
67
+ # [BT, BK]
68
+ b_q = tl.load(p_q, boundary_check=(0, 1))
69
+ b_q = (b_q * scale).to(b_q.dtype)
70
+ # [BT, BV]
71
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
72
+
73
+ b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
74
+ b_acc = tl.zeros([BT], dtype=tl.float32)
75
+ for i_s in range(0, i_t * BT, BS):
76
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
77
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
78
+ # [BK, BS]
79
+ b_k = tl.load(p_k, boundary_check=(0, 1))
80
+ # [BS, BV]
81
+ b_v = tl.load(p_v, boundary_check=(0, 1))
82
+ # [BT, BS]
83
+ b_s = tl.dot(b_q, b_k)
84
+
85
+ # [BT, BS]
86
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
87
+ b_r = exp(b_mp - b_m)
88
+ # [BT, BS]
89
+ b_p = exp(b_s - b_m[:, None])
90
+ # [BT]
91
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
92
+ # [BT, BV]
93
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
94
+
95
+ b_mp = b_m
96
+
97
+ # [BT]
98
+ o_q = i_t * BT + tl.arange(0, BT)
99
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
100
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
101
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
102
+
103
+ # [BS]
104
+ o_k = i_s + tl.arange(0, BS)
105
+ # [BK, BS]
106
+ b_k = tl.load(p_k, boundary_check=(0, 1))
107
+ # [BS, BV]
108
+ b_v = tl.load(p_v, boundary_check=(0, 1))
109
+ # [BT, BS]
110
+ b_s = tl.dot(b_q, b_k)
111
+ b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))
112
+
113
+ # [BT]
114
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
115
+ b_r = exp(b_mp - b_m)
116
+ # [BT, BS]
117
+ b_p = exp(b_s - b_m[:, None])
118
+ # [BT]
119
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
120
+ # [BT, BV]
121
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
122
+
123
+ b_mp = b_m
124
+ b_o = b_o / b_acc[:, None]
125
+ b_m += log(b_acc)
126
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
127
+ tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
128
+
129
+
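For intuition, the forward kernel above keeps a running row-wise max `b_m`, a running normalizer `b_acc`, and a rescaled partial output `b_o`, then stores `o = b_o / b_acc` and `lse = b_m + log(b_acc)`. A plain-PyTorch sketch of that streaming-softmax recurrence (illustrative names; causal masking, scaling, grouped heads, and block pointers omitted):

import torch

def online_softmax_attn(q, k, v, block=16):
    # q: [T, K], k: [T, K], v: [T, V]; returns (output [T, V], logsumexp [T])
    T, V = q.shape[0], v.shape[-1]
    m = torch.full((T,), float('-inf'))   # running max, like b_m
    acc = torch.zeros(T)                  # running sum of exp, like b_acc
    o = torch.zeros(T, V)                 # running unnormalized output, like b_o
    for s in range(0, k.shape[0], block):
        scores = q @ k[s:s + block].t()                      # [T, block], like b_s
        m_new = torch.maximum(m, scores.max(dim=-1).values)
        r = torch.exp(m - m_new)                             # rescaling factor, like b_r
        p = torch.exp(scores - m_new[:, None])               # unnormalized probs, like b_p
        acc = acc * r + p.sum(dim=-1)
        o = o * r[:, None] + p @ v[s:s + block]
        m = m_new
    return o / acc[:, None], m + torch.log(acc)

# quick check against the reference (non-causal) softmax attention
q, k, v = torch.randn(64, 32), torch.randn(64, 32), torch.randn(64, 48)
o, lse = online_softmax_attn(q, k, v)
torch.testing.assert_close(o, torch.softmax(q @ k.t(), dim=-1) @ v, rtol=1e-4, atol=1e-4)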
130
+ @triton.jit
131
+ def parallel_attn_bwd_kernel_preprocess(
132
+ o,
133
+ do,
134
+ delta,
135
+ B: tl.constexpr,
136
+ V: tl.constexpr
137
+ ):
138
+ i_n = tl.program_id(0)
139
+ o_d = tl.arange(0, B)
140
+ m_d = o_d < V
141
+
142
+ b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
143
+ b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
144
+ b_delta = tl.sum(b_o * b_do)
145
+
146
+ tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
147
+
148
+
149
+ @triton.heuristics({
150
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
151
+ })
152
+ @triton.autotune(
153
+ configs=[
154
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
155
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
156
+ for num_stages in [2, 3, 4, 5]
157
+ ],
158
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
159
+ )
160
+ @triton.jit(do_not_specialize=['T'])
161
+ def parallel_attn_bwd_kernel_dq(
162
+ q,
163
+ k,
164
+ v,
165
+ lse,
166
+ delta,
167
+ do,
168
+ dq,
169
+ scale,
170
+ offsets,
171
+ indices,
172
+ T,
173
+ B: tl.constexpr,
174
+ H: tl.constexpr,
175
+ HQ: tl.constexpr,
176
+ G: tl.constexpr,
177
+ K: tl.constexpr,
178
+ V: tl.constexpr,
179
+ BT: tl.constexpr,
180
+ BS: tl.constexpr,
181
+ BK: tl.constexpr,
182
+ BV: tl.constexpr,
183
+ USE_OFFSETS: tl.constexpr
184
+ ):
185
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
186
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
187
+ i_h = i_hq // G
188
+
189
+ if USE_OFFSETS:
190
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
191
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
192
+ T = eos - bos
193
+ else:
194
+ i_n = i_b
195
+ bos, eos = i_n * T, i_n * T + T
196
+
197
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
198
+ p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
199
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
200
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
201
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
202
+
203
+ # [BT, BK]
204
+ b_q = tl.load(p_q, boundary_check=(0, 1))
205
+ b_q = (b_q * scale).to(b_q.dtype)
206
+ # [BT, BV]
207
+ b_do = tl.load(p_do, boundary_check=(0, 1))
208
+ # [BT]
209
+ b_lse = tl.load(p_lse, boundary_check=(0,))
210
+ b_delta = tl.load(p_delta, boundary_check=(0,))
211
+
212
+ # [BT, BK]
213
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
214
+ for i_s in range(0, i_t * BT, BS):
215
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
216
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
217
+ # [BK, BS]
218
+ b_k = tl.load(p_k, boundary_check=(0, 1))
219
+ # [BV, BS]
220
+ b_v = tl.load(p_v, boundary_check=(0, 1))
221
+
222
+ # [BT, BS]
223
+ b_s = tl.dot(b_q, b_k)
224
+ b_p = exp(b_s - b_lse[:, None])
225
+
226
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
227
+ b_dp = tl.dot(b_do, b_v)
228
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
229
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
230
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
231
+
232
+ # [BT]
233
+ o_q = i_t * BT + tl.arange(0, BT)
234
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
235
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
236
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
237
+ # [BS]
238
+ o_k = i_s + tl.arange(0, BS)
239
+ # [BK, BS]
240
+ b_k = tl.load(p_k, boundary_check=(0, 1))
241
+ # [BV, BS]
242
+ b_v = tl.load(p_v, boundary_check=(0, 1))
243
+
244
+ # [BT, BS]
245
+ b_s = tl.dot(b_q, b_k)
246
+ b_p = exp(b_s - b_lse[:, None])
247
+ b_p = tl.where(o_q[:, None] >= o_k[None, :], b_p, 0)
248
+
249
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
250
+ b_dp = tl.dot(b_do, b_v)
251
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
252
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
253
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
254
+
255
+ b_dq *= scale
256
+
257
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
258
+
259
+
260
+ @triton.heuristics({
261
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
262
+ })
263
+ @triton.autotune(
264
+ configs=[
265
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
266
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
267
+ for num_stages in [2, 3, 4, 5]
268
+ ],
269
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
270
+ )
271
+ @triton.jit(do_not_specialize=['T'])
272
+ def parallel_attn_bwd_kernel_dkv(
273
+ q,
274
+ k,
275
+ v,
276
+ lse,
277
+ delta,
278
+ do,
279
+ dk,
280
+ dv,
281
+ offsets,
282
+ indices,
283
+ scale,
284
+ T,
285
+ B: tl.constexpr,
286
+ H: tl.constexpr,
287
+ HQ: tl.constexpr,
288
+ G: tl.constexpr,
289
+ K: tl.constexpr,
290
+ V: tl.constexpr,
291
+ BT: tl.constexpr,
292
+ BS: tl.constexpr,
293
+ BK: tl.constexpr,
294
+ BV: tl.constexpr,
295
+ USE_OFFSETS: tl.constexpr
296
+ ):
297
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
298
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
299
+ i_h = i_hq // G
300
+
301
+ if USE_OFFSETS:
302
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
303
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
304
+ T = eos - bos
305
+ else:
306
+ i_n = i_b
307
+ bos, eos = i_n * T, i_n * T + T
308
+
309
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
310
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
311
+ p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
312
+ p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
313
+
314
+ # [BT, BK]
315
+ b_k = tl.load(p_k, boundary_check=(0, 1))
316
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
317
+ # [BT, BV]
318
+ b_v = tl.load(p_v, boundary_check=(0, 1))
319
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
320
+
321
+ o_k = i_t * BT + tl.arange(0, BT)
322
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
323
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
324
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
325
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
326
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
327
+
328
+ # [BS]
329
+ o_q = i_s + tl.arange(0, BS)
330
+ # [BS, BK]
331
+ b_q = tl.load(p_q, boundary_check=(0, 1))
332
+ b_q = (b_q * scale).to(b_q.dtype)
333
+ # [BS, BV]
334
+ b_do = tl.load(p_do, boundary_check=(0, 1))
335
+ # [BS]
336
+ b_lse = tl.load(p_lse, boundary_check=(0,))
337
+ b_delta = tl.load(p_delta, boundary_check=(0,))
338
+ # [BT, BS]
339
+ b_s = tl.dot(b_k, tl.trans(b_q))
340
+ b_p = exp(b_s - b_lse[None, :])
341
+ b_p = tl.where(o_k[:, None] <= o_q[None, :], b_p, 0)
342
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
343
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
344
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
345
+ b_dp = tl.dot(b_v, tl.trans(b_do))
346
+ # [BT, BS]
347
+ b_ds = b_p * (b_dp - b_delta[None, :])
348
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
349
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
350
+
351
+ for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS):
352
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
353
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
354
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
355
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
356
+
357
+ # [BS]
358
+ o_q = i_s + tl.arange(0, BS)
359
+ # [BS, BK]
360
+ b_q = tl.load(p_q, boundary_check=(0, 1))
361
+ b_q = (b_q * scale).to(b_q.dtype)
362
+ # [BS, BV]
363
+ b_do = tl.load(p_do, boundary_check=(0, 1))
364
+ # [BS]
365
+ b_lse = tl.load(p_lse, boundary_check=(0,))
366
+ b_delta = tl.load(p_delta, boundary_check=(0,))
367
+ # [BT, BS]
368
+ b_s = tl.dot(b_k, tl.trans(b_q))
369
+ b_p = exp(b_s - b_lse[None, :])
370
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
371
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
372
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
373
+ b_dp = tl.dot(b_v, tl.trans(b_do))
374
+ # [BT, BS]
375
+ b_ds = b_p * (b_dp - b_delta[None, :])
376
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
377
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
378
+
379
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
380
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
381
+
382
+
383
+ def parallel_attn_fwd(
384
+ q: torch.Tensor,
385
+ k: torch.Tensor,
386
+ v: torch.Tensor,
387
+ scale: float,
388
+ chunk_size: int = 128,
389
+ offsets: Optional[torch.LongTensor] = None,
390
+ indices: Optional[torch.LongTensor] = None,
391
+ ):
392
+ B, T, H, K, V = *k.shape, v.shape[-1]
393
+ HQ = q.shape[2]
394
+ G = HQ // H
395
+ BT = chunk_size
396
+ if check_shared_mem('hopper', q.device.index):
397
+ BS = min(64, max(16, triton.next_power_of_2(T)))
398
+ BK = min(256, max(16, triton.next_power_of_2(K)))
399
+ BV = min(256, max(16, triton.next_power_of_2(V)))
400
+ elif check_shared_mem('ampere', q.device.index):
401
+ BS = min(32, max(16, triton.next_power_of_2(T)))
402
+ BK = min(256, max(16, triton.next_power_of_2(K)))
403
+ BV = min(128, max(16, triton.next_power_of_2(V)))
404
+ else:
405
+ BS = min(32, max(16, triton.next_power_of_2(T)))
406
+ BK = min(256, max(16, triton.next_power_of_2(K)))
407
+ BV = min(64, max(16, triton.next_power_of_2(V)))
408
+ NK = triton.cdiv(K, BK)
409
+ NV = triton.cdiv(V, BV)
410
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
411
+ assert NK == 1, "The key dimension cannot be larger than 256"
412
+
413
+ o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
414
+ lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
415
+
416
+ grid = (NV, NT, B * HQ)
417
+ parallel_attn_fwd_kernel[grid](
418
+ q=q,
419
+ k=k,
420
+ v=v,
421
+ o=o,
422
+ lse=lse,
423
+ scale=scale,
424
+ offsets=offsets,
425
+ indices=indices,
426
+ B=B,
427
+ T=T,
428
+ H=H,
429
+ HQ=HQ,
430
+ G=G,
431
+ K=K,
432
+ V=V,
433
+ BT=BT,
434
+ BS=BS,
435
+ BK=BK,
436
+ BV=BV,
437
+ )
438
+ return o, lse
439
+
440
+
441
+ def parallel_attn_bwd_preprocess(
442
+ o: torch.Tensor,
443
+ do: torch.Tensor
444
+ ):
445
+ V = o.shape[-1]
446
+ delta = torch.empty_like(o[..., 0], dtype=torch.float32)
447
+ parallel_attn_bwd_kernel_preprocess[(delta.numel(),)](
448
+ o=o,
449
+ do=do,
450
+ delta=delta,
451
+ B=triton.next_power_of_2(V),
452
+ V=V,
453
+ )
454
+ return delta
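For orientation (not part of the commit): if this preprocess follows the usual FlashAttention-2 backward, `delta` holds the per-row correction term; a dense PyTorch equivalent would look like the sketch below (an assumption for readability, not the kernel itself).
# hypothetical dense equivalent of parallel_attn_bwd_kernel_preprocess
delta_ref = (o.float() * do.float()).sum(-1)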
455
+
456
+
457
+ def parallel_attn_bwd(
458
+ q: torch.Tensor,
459
+ k: torch.Tensor,
460
+ v: torch.Tensor,
461
+ o: torch.Tensor,
462
+ lse: torch.Tensor,
463
+ do: torch.Tensor,
464
+ scale: float = None,
465
+ chunk_size: int = 128,
466
+ offsets: Optional[torch.LongTensor] = None,
467
+ indices: Optional[torch.LongTensor] = None,
468
+ ):
469
+ B, T, H, K, V = *k.shape, v.shape[-1]
470
+ HQ = q.shape[2]
471
+ G = HQ // H
472
+ BT = chunk_size
473
+ BS = max(16, triton.next_power_of_2(T))
474
+ BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS)
475
+ BK = max(16, triton.next_power_of_2(K))
476
+ BV = max(16, triton.next_power_of_2(V))
477
+ NV = triton.cdiv(V, BV)
478
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
479
+
480
+ delta = parallel_attn_bwd_preprocess(o, do)
481
+
482
+ dq = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
483
+ dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
484
+ dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device)
485
+ grid = (NV, NT, B * HQ)
486
+ parallel_attn_bwd_kernel_dq[grid](
487
+ q=q,
488
+ k=k,
489
+ v=v,
490
+ lse=lse,
491
+ delta=delta,
492
+ do=do,
493
+ dq=dq,
494
+ offsets=offsets,
495
+ indices=indices,
496
+ scale=scale,
497
+ T=T,
498
+ B=B,
499
+ H=H,
500
+ HQ=HQ,
501
+ G=G,
502
+ K=K,
503
+ V=V,
504
+ BT=BT,
505
+ BS=BS,
506
+ BK=BK,
507
+ BV=BV
508
+ )
509
+ parallel_attn_bwd_kernel_dkv[grid](
510
+ q=q,
511
+ k=k,
512
+ v=v,
513
+ lse=lse,
514
+ delta=delta,
515
+ do=do,
516
+ dk=dk,
517
+ dv=dv,
518
+ offsets=offsets,
519
+ indices=indices,
520
+ scale=scale,
521
+ T=T,
522
+ B=B,
523
+ H=H,
524
+ HQ=HQ,
525
+ G=G,
526
+ K=K,
527
+ V=V,
528
+ BT=BT,
529
+ BS=BS,
530
+ BK=BK,
531
+ BV=BV
532
+ )
533
+ dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
534
+ dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
535
+ return dq, dk, dv
536
+
537
+
538
+ @torch.compile
539
+ class ParallelAttentionFunction(torch.autograd.Function):
540
+
541
+ @staticmethod
542
+ @contiguous
543
+ @autocast_custom_fwd
544
+ def forward(ctx, q, k, v, scale, offsets):
545
+ ctx.dtype = q.dtype
546
+
547
+ chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
548
+ # 2-d indices denoting the offsets of chunks in each sequence
549
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
550
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
551
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
552
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
553
+
554
+ o, lse = parallel_attn_fwd(
555
+ q=q,
556
+ k=k,
557
+ v=v,
558
+ scale=scale,
559
+ chunk_size=chunk_size,
560
+ offsets=offsets,
561
+ indices=indices
562
+ )
563
+ ctx.save_for_backward(q, k, v, o, lse)
564
+ ctx.chunk_size = chunk_size
565
+ ctx.offsets = offsets
566
+ ctx.indices = indices
567
+ ctx.scale = scale
568
+ return o.to(q.dtype)
569
+
570
+ @staticmethod
571
+ @contiguous
572
+ @autocast_custom_bwd
573
+ def backward(ctx, do):
574
+ q, k, v, o, lse = ctx.saved_tensors
575
+ dq, dk, dv = parallel_attn_bwd(
576
+ q=q,
577
+ k=k,
578
+ v=v,
579
+ o=o,
580
+ lse=lse,
581
+ do=do,
582
+ scale=ctx.scale,
583
+ chunk_size=ctx.chunk_size,
584
+ offsets=ctx.offsets,
585
+ indices=ctx.indices
586
+ )
587
+ return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
588
+
589
+
590
+ def parallel_attn(
591
+ q: torch.Tensor,
592
+ k: torch.Tensor,
593
+ v: torch.Tensor,
594
+ scale: Optional[float] = None,
595
+ cu_seqlens: Optional[torch.LongTensor] = None,
596
+ head_first: bool = False
597
+ ) -> torch.Tensor:
598
+ r"""
599
+ Args:
600
+ q (torch.Tensor):
601
+ queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
602
+ k (torch.Tensor):
603
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
604
+ GQA will be applied if HQ is divisible by H.
605
+ v (torch.Tensor):
606
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
607
+ scale (Optional[float]):
608
+ Scale factor for attention scores.
609
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
610
+ cu_seqlens (torch.LongTensor):
611
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
612
+ consistent with the FlashAttention API.
613
+ head_first (Optional[bool]):
614
+ Whether the inputs are in the head-first format. Default: `False`.
615
+
616
+ Returns:
617
+ o (torch.Tensor):
618
+ Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
619
+ """
620
+ if scale is None:
621
+ scale = k.shape[-1] ** -0.5
622
+ if cu_seqlens is not None:
623
+ assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
624
+ if head_first:
625
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
626
+ o = ParallelAttentionFunction.apply(q, k, v, scale, cu_seqlens)
627
+ if head_first:
628
+ o = rearrange(o, 'b t h d -> b h t d')
629
+ return o
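A minimal usage sketch for `parallel_attn` (not part of the files added above); it assumes the module path `fla.ops.attn.parallel` matches the file added here and that a CUDA device with Triton is available. The shapes exercise the GQA path with `HQ = 2 * H`.

import torch
from fla.ops.attn.parallel import parallel_attn

B, T, H, HQ, K, V = 2, 256, 4, 8, 64, 64
q = torch.randn(B, T, HQ, K, device='cuda', dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16, requires_grad=True)
o = parallel_attn(q, k, v)  # [B, T, HQ, V]; scale defaults to K ** -0.5
o.sum().backward()          # gradients come from the dq and dkv Triton kernels above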
fla/ops/based/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .fused_chunk import fused_chunk_based
4
+ from .parallel import parallel_based
5
+
6
+ __all__ = [
7
+ 'fused_chunk_based',
8
+ 'parallel_based'
9
+ ]
fla/ops/based/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (286 Bytes). View file
 
fla/ops/based/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (22.4 kB). View file
 
fla/ops/based/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (22.6 kB). View file
 
fla/ops/based/naive.py ADDED
@@ -0,0 +1,72 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import rearrange
7
+
8
+
9
+ def naive_parallel_based(
10
+ q: torch.Tensor,
11
+ k: torch.Tensor,
12
+ v: torch.Tensor,
13
+ scale: Optional[float] = None,
14
+ use_norm: bool = True
15
+ ):
16
+ if scale is None:
17
+ scale = q.shape[-1] ** -0.5
18
+ q = q * scale
19
+ attn = q @ k.transpose(-2, -1)
20
+ attn = 1 + attn + 1/2 * (attn ** 2)
21
+ attn.masked_fill_(~torch.tril(torch.ones(
22
+ q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
23
+ o = attn @ v
24
+ if use_norm:
25
+ z = attn.sum(-1)
26
+ return o / (z[..., None] + 1e-6)
27
+ else:
28
+ return o
29
+
30
+
31
+ def naive_chunk_based(q, k, v, chunk_size=256):
32
+ q = q * (q.shape[-1] ** -0.5)
33
+ # compute normalizer.
34
+ k_cumsum = torch.cumsum(k, dim=-2)
35
+ kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
36
+ # first-order term
37
+ z = (q * k_cumsum).sum(-1)
38
+ # second order
39
+ z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
40
+ # zero-th order
41
+ z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :]
42
+
43
+ # compute o
44
+ # constant term
45
+ _o = v.cumsum(-2)
46
+
47
+ q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size)
48
+
49
+ k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
50
+ v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
51
+
52
+ intra_chunk_attn = q @ k.transpose(-2, -1)
53
+ intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2)
54
+ intra_chunk_attn.masked_fill_(~torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device)), 0)
55
+ o = intra_chunk_attn @ v
56
+
57
+ # quadratic term
58
+ kv = torch.einsum('b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v)
59
+ kv = kv.cumsum(2)
60
+ kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
61
+
62
+ o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q)
63
+
64
+ # linear term
65
+ kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v)
66
+ kv = kv.cumsum(2)
67
+ kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
68
+ o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)
69
+
70
+ o = rearrange(o, 'b h n c d -> b h (n c) d')
71
+ o = o + _o
72
+ return o / (z[..., None] + 1e-6)
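A small self-check sketch (not part of the commit; it assumes both functions above are importable from `fla.ops.based.naive`): it compares the chunkwise decomposition against the quadratic-cost reference for the Taylor feature map `1 + x + x**2 / 2`. `T` must be divisible by `chunk_size` because of the rearrange in `naive_chunk_based`.

import torch
from fla.ops.based.naive import naive_chunk_based, naive_parallel_based

B, H, T, K, V = 2, 4, 512, 16, 32
q, k = torch.randn(B, H, T, K), torch.randn(B, H, T, K)
v = torch.randn(B, H, T, V)
ref = naive_parallel_based(q, k, v)               # O(T^2) causal reference
out = naive_chunk_based(q, k, v, chunk_size=256)  # chunkwise reformulation
print((ref - out).abs().max())                    # expected to be close to 0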
fla/ops/common/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
fla/ops/common/chunk_delta_h.py ADDED
@@ -0,0 +1,399 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, is_nvidia_hopper, use_cuda_graph
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_G': lambda args: args['g'] is not None,
19
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
20
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
21
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
22
+ })
23
+ @triton.autotune(
24
+ configs=[
25
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
26
+ for num_warps in NUM_WARPS
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
30
+ use_cuda_graph=use_cuda_graph,
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_gated_delta_rule_fwd_kernel_h(
34
+ k,
35
+ v,
36
+ d,
37
+ v_new,
38
+ g,
39
+ h,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ chunk_offsets,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BC: tl.constexpr,
50
+ BK: tl.constexpr,
51
+ BV: tl.constexpr,
52
+ NT: tl.constexpr,
53
+ USE_G: tl.constexpr,
54
+ USE_INITIAL_STATE: tl.constexpr,
55
+ STORE_FINAL_STATE: tl.constexpr,
56
+ USE_OFFSETS: tl.constexpr,
57
+ HEAD_FIRST: tl.constexpr,
58
+ ):
59
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
60
+ i_n, i_h = i_nh // H, i_nh % H
61
+ if USE_OFFSETS:
62
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
63
+ T = eos - bos
64
+ NT = tl.cdiv(T, BT)
65
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
66
+ else:
67
+ bos, eos = i_n * T, i_n * T + T
68
+ NT = tl.cdiv(T, BT)
69
+ boh = i_n * NT
70
+
71
+ # [BK, BV]
72
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
73
+ if USE_INITIAL_STATE:
74
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
75
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
76
+
77
+ for i_t in range(NT):
78
+ if HEAD_FIRST:
79
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ else:
81
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
82
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
83
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
84
+ if USE_G:
85
+ last_idx = min((i_t + 1) * BT, T) - 1
86
+ if HEAD_FIRST:
87
+ b_g_last = tl.load(g + i_nh * T + last_idx)
88
+ else:
89
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
90
+ else:
91
+ b_g_last = None
92
+ last_idx = None
93
+ # since the whole [BK, BV] chunk state must be kept in SRAM, the memory burden is severe; subchunking over BC alleviates it
94
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
95
+ if HEAD_FIRST:
96
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
97
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
98
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
99
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
100
+ p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
101
+ else:
102
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
103
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
104
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
105
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
106
+ p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT+i_c*BC, ), (BC,), (0,)) if USE_G else None
107
+ b_g = tl.load(p_g, boundary_check=(0, )) if USE_G else None
108
+ # [BK, BC]
109
+ b_k = tl.load(p_k, boundary_check=(0, 1))
110
+ b_k = (b_k * exp(b_g_last - b_g)[None, :]).to(b_k.dtype) if USE_G else b_k
111
+ # [BC, BK]
112
+ b_d = tl.load(p_d, boundary_check=(0, 1))
113
+ b_d = (b_d * exp(b_g)[:, None]).to(b_d.dtype) if USE_G else b_d
114
+ # [BC, BV]
115
+ b_v = tl.load(p_v, boundary_check=(0, 1))
116
+ b_v2 = b_v - tl.dot(b_d, b_h.to(b_d.dtype))
117
+ # [BK, BV]
118
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
119
+ b_hc += tl.dot(b_k, b_v2.to(b_k.dtype), allow_tf32=False)
120
+ b_h *= exp(b_g_last) if USE_G else 1
121
+ b_h += b_hc
122
+
123
+ if STORE_FINAL_STATE:
124
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
125
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
126
+
127
+
128
+ @triton.heuristics({
129
+ 'USE_G': lambda args: args['g'] is not None,
130
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
131
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
132
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
133
+ })
134
+ @triton.autotune(
135
+ configs=[
136
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
137
+ for num_warps in NUM_WARPS
138
+ for num_stages in [2, 3, 4]
139
+ ],
140
+ key=['BT', 'BK', 'BV', 'USE_G'],
141
+ use_cuda_graph=use_cuda_graph,
142
+ )
143
+ @triton.jit(do_not_specialize=['T'])
144
+ def chunk_gated_delta_rule_bwd_kernel_dhu(
145
+ q,
146
+ k,
147
+ d,
148
+ g,
149
+ dht,
150
+ dh0,
151
+ do,
152
+ dh,
153
+ dv,
154
+ dv2,
155
+ offsets,
156
+ chunk_offsets,
157
+ scale,
158
+ T,
159
+ H: tl.constexpr,
160
+ K: tl.constexpr,
161
+ V: tl.constexpr,
162
+ BT: tl.constexpr,
163
+ BC: tl.constexpr,
164
+ BK: tl.constexpr,
165
+ BV: tl.constexpr,
166
+ USE_G: tl.constexpr,
167
+ USE_INITIAL_STATE: tl.constexpr,
168
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
169
+ USE_OFFSETS: tl.constexpr,
170
+ HEAD_FIRST: tl.constexpr
171
+ ):
172
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
173
+ i_n, i_h = i_nh // H, i_nh % H
174
+ if USE_OFFSETS:
175
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
176
+ T = eos - bos
177
+ NT = tl.cdiv(T, BT)
178
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
179
+ else:
180
+ bos, eos = i_n * T, i_n * T + T
181
+ NT = tl.cdiv(T, BT)
182
+ boh = i_n * NT
183
+
184
+ # [BK, BV]
185
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
186
+ if USE_FINAL_STATE_GRADIENT:
187
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
188
+ b_dh += tl.load(p_dht, boundary_check=(0, 1))
189
+
190
+ for i_t in range(NT - 1, -1, -1):
191
+ if HEAD_FIRST:
192
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
193
+ else:
194
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
195
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
196
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
197
+ if USE_G:
198
+ last_idx = min((i_t + 1) * BT, T) - 1
199
+ if HEAD_FIRST:
200
+ bg_last = tl.load(g + i_nh * T + last_idx)
201
+ else:
202
+ bg_last = tl.load(g + (bos + last_idx) * H + i_h)
203
+ else:
204
+ bg_last = None
205
+ last_idx = None
206
+ for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
207
+ if HEAD_FIRST:
208
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
209
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
210
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
211
+ p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
212
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
213
+ p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
214
+ p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
215
+ else:
216
+ p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
217
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
218
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
219
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
220
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
221
+ p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT + i_c * BC,), (BC,), (0,)) if USE_G else None
222
+ p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
223
+ b_g = tl.load(p_g, boundary_check=(0,)) if USE_G else None
224
+ # [BK, BT]
225
+ b_q = tl.load(p_q, boundary_check=(0, 1))
226
+ b_q = (b_q * scale * exp(b_g)[None, :]).to(b_q.dtype) if USE_G else (b_q * scale).to(b_q.dtype)
227
+ # [BT, BK]
228
+ b_k = tl.load(p_k, boundary_check=(0, 1))
229
+ b_d = tl.load(p_d, boundary_check=(0, 1))
230
+ b_k = (b_k * exp(bg_last - b_g)[:, None]).to(b_k.dtype) if USE_G else b_k
231
+ b_d = (b_d * exp(b_g)[None, :]).to(b_d.dtype) if USE_G else b_d
232
+ # [BT, V]
233
+ b_do = tl.load(p_do, boundary_check=(0, 1))
234
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
235
+ b_dv2 = b_dv + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
236
+ tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
237
+ # [BK, BV]
238
+ b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
239
+ b_dh_tmp -= tl.dot(b_d, b_dv2.to(b_q.dtype), allow_tf32=False)
240
+ b_dh *= exp(bg_last) if USE_G else 1
241
+ b_dh += b_dh_tmp
242
+
243
+ if USE_INITIAL_STATE:
244
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
245
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
246
+
247
+
248
+ def chunk_gated_delta_rule_fwd_h(
249
+ k: torch.Tensor,
250
+ w: torch.Tensor,
251
+ u: torch.Tensor,
252
+ g: Optional[torch.Tensor] = None,
253
+ initial_state: Optional[torch.Tensor] = None,
254
+ output_final_state: bool = False,
255
+ offsets: Optional[torch.LongTensor] = None,
256
+ indices: Optional[torch.LongTensor] = None,
257
+ head_first: bool = True,
258
+ chunk_size: int = 64
259
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
260
+ if head_first:
261
+ B, H, T, K, V = *k.shape, u.shape[-1]
262
+ else:
263
+ B, T, H, K, V = *k.shape, u.shape[-1]
264
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
265
+ # N: the actual number of sequences in the batch with either equal or variable lengths
266
+ if offsets is None:
267
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
268
+ else:
269
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
270
+ BK = triton.next_power_of_2(K)
271
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
272
+ # H100 can have larger block size
273
+ if check_shared_mem('hopper', k.device.index):
274
+ BV = 64
275
+ BC = 64 if K <= 128 else 32
276
+ # A100
277
+ elif check_shared_mem('ampere', k.device.index):
278
+ BV = 32
279
+ BC = 64
280
+ else:
281
+ BV = 32
282
+ BC = 32 if K <= 128 else 16
283
+ BC = min(BT, BC)
284
+ NK = triton.cdiv(K, BK)
285
+ NV = triton.cdiv(V, BV)
286
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
287
+
288
+ if head_first:
289
+ h = k.new_empty(B, H, NT, K, V)
290
+ else:
291
+ h = k.new_empty(B, NT, H, K, V)
292
+ final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
293
+
294
+ v_new = torch.empty_like(u)
295
+ grid = (NK, NV, N * H)
296
+
297
+ chunk_gated_delta_rule_fwd_kernel_h[grid](
298
+ k=k,
299
+ v=u,
300
+ d=w,
301
+ v_new=v_new,
302
+ g=g,
303
+ h=h,
304
+ h0=initial_state,
305
+ ht=final_state,
306
+ offsets=offsets,
307
+ chunk_offsets=chunk_offsets,
308
+ T=T,
309
+ H=H,
310
+ K=K,
311
+ V=V,
312
+ BT=BT,
313
+ BC=BC,
314
+ BK=BK,
315
+ BV=BV,
316
+ NT=NT,
317
+ HEAD_FIRST=head_first
318
+ )
319
+ return h, v_new, final_state
320
+
321
+
322
+ def chunk_gated_delta_rule_bwd_dhu(
323
+ q: torch.Tensor,
324
+ k: torch.Tensor,
325
+ w: torch.Tensor,
326
+ g: torch.Tensor,
327
+ h0: torch.Tensor,
328
+ dht: Optional[torch.Tensor],
329
+ do: torch.Tensor,
330
+ dv: torch.Tensor,
331
+ scale: float,
332
+ offsets: Optional[torch.LongTensor] = None,
333
+ indices: Optional[torch.LongTensor] = None,
334
+ head_first: bool = True,
335
+ chunk_size: int = 64
336
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
337
+ if head_first:
338
+ B, H, T, K, V = *q.shape, do.shape[-1]
339
+ else:
340
+ B, T, H, K, V = *q.shape, do.shape[-1]
341
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
342
+ # N: the actual number of sequences in the batch with either equal or variable lengths
343
+ if offsets is None:
344
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
345
+ else:
346
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
347
+
348
+ BK = triton.next_power_of_2(K)
349
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
350
+
351
+ # H100
352
+ if check_shared_mem('hopper', q.device.index):
353
+ BV = 64
354
+ BC = 64 if K <= 128 else 32
355
+ # A100
356
+ elif check_shared_mem('ampere', q.device.index):
357
+ BV = 32
358
+ BC = 64 if K <= 128 else 32
359
+ else:
360
+ BV = 32 if K <= 128 else 16
361
+ BC = 16
362
+
363
+ BC = min(BT, BC)
364
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
365
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
366
+
367
+ if head_first:
368
+ dh = q.new_empty(B, H, NT, K, V)
369
+ else:
370
+ dh = q.new_empty(B, NT, H, K, V)
371
+ dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
372
+ dv2 = torch.empty_like(dv)
373
+
374
+ grid = (NK, NV, N * H)
375
+ chunk_gated_delta_rule_bwd_kernel_dhu[grid](
376
+ q=q,
377
+ k=k,
378
+ d=w,
379
+ g=g,
380
+ dht=dht,
381
+ dh0=dh0,
382
+ do=do,
383
+ dh=dh,
384
+ dv=dv,
385
+ dv2=dv2,
386
+ offsets=offsets,
387
+ chunk_offsets=chunk_offsets,
388
+ scale=scale,
389
+ T=T,
390
+ H=H,
391
+ K=K,
392
+ V=V,
393
+ BT=BT,
394
+ BC=BC,
395
+ BK=BK,
396
+ BV=BV,
397
+ HEAD_FIRST=head_first
398
+ )
399
+ return dh, dh0, dv2
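A pure-PyTorch reference sketch (not part of the commit) that mirrors the chunkwise state recurrence of `chunk_gated_delta_rule_fwd_kernel_h` above, written for readability rather than performance: head-first layout, no subchunking, no variable-length support. `g` is assumed to hold per-token log decays, and all names here are illustrative only.

import torch

def naive_chunk_gated_delta_h(k, w, u, g=None, initial_state=None, chunk_size=64):
    # k/w: [B, H, T, K], u: [B, H, T, V], g: [B, H, T] (log decays) or None
    B, H, T, K = k.shape
    V = u.shape[-1]
    NT = (T + chunk_size - 1) // chunk_size
    S = k.new_zeros(B, H, K, V) if initial_state is None else initial_state.clone()
    h = k.new_zeros(B, H, NT, K, V)
    v_new = torch.empty_like(u)
    for i in range(NT):
        s, e = i * chunk_size, min((i + 1) * chunk_size, T)
        h[:, :, i] = S                                   # state entering the chunk
        k_c, w_c, u_c = k[:, :, s:e], w[:, :, s:e], u[:, :, s:e]
        if g is not None:
            g_c = g[:, :, s:e].float()
            g_last = g_c[:, :, -1:]                      # [B, H, 1]
            k_c = k_c * (g_last - g_c).exp().unsqueeze(-1)
            w_c = w_c * g_c.exp().unsqueeze(-1)
            decay = g_last.exp().unsqueeze(-1)           # [B, H, 1, 1]
        else:
            decay = 1.
        v2 = u_c - w_c @ S                               # matches b_v2 = b_v - b_d @ b_h
        v_new[:, :, s:e] = v2
        S = decay * S + k_c.transpose(-1, -2) @ v2       # gated delta-rule state update
    return h, v_new, S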
fla/ops/common/chunk_h.py ADDED
@@ -0,0 +1,422 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem
13
+
14
+ BKV_LIST = [32, 64] if check_shared_mem() else [16, 32]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
25
+ for BK in BKV_LIST
26
+ for BV in BKV_LIST
27
+ for num_warps in [1, 2, 4, 8]
28
+ for num_stages in [2, 3, 4]
29
+ ],
30
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_fwd_kernel_h(
34
+ k,
35
+ v,
36
+ h,
37
+ g,
38
+ gk,
39
+ gv,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ split_offsets,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BS: tl.constexpr,
50
+ BK: tl.constexpr,
51
+ BV: tl.constexpr,
52
+ USE_G: tl.constexpr,
53
+ USE_GK: tl.constexpr,
54
+ USE_GV: tl.constexpr,
55
+ USE_INITIAL_STATE: tl.constexpr,
56
+ STORE_FINAL_STATE: tl.constexpr,
57
+ USE_OFFSETS: tl.constexpr,
58
+ HEAD_FIRST: tl.constexpr
59
+ ):
60
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
61
+ i_n, i_h = i_nh // H, i_nh % H
62
+ if USE_OFFSETS:
63
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
64
+ T = eos - bos
65
+ NT = tl.cdiv(T, BT)
66
+ NS = tl.cdiv(T, BS)
67
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
68
+ else:
69
+ bos, eos = i_n * T, i_n * T + T
70
+ NT = tl.cdiv(T, BT)
71
+ NS = tl.cdiv(T, BS)
72
+ boh = i_n * NS
73
+
74
+ # [BK, BV]
75
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
76
+ if USE_INITIAL_STATE:
77
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
79
+
80
+ for i_t in range(NT):
81
+ i_s = i_t // (BS // BT)
82
+ if HEAD_FIRST:
83
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
84
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
85
+
86
+ o_h = (i_nh * NS + i_s).to(tl.int64) * K*V
87
+ p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
88
+ else:
89
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
90
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
91
+
92
+ o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
93
+ p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
94
+
95
+ if i_t % (BS // BT) == 0:
96
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
97
+ # [BK, BT]
98
+ b_k = tl.load(p_k, boundary_check=(0, 1))
99
+ # [BT, BV]
100
+ b_v = tl.load(p_v, boundary_check=(0, 1))
101
+ last_idx = min((i_t + 1) * BT, T) - 1
102
+
103
+ # scalar decay
104
+ if USE_G:
105
+ if HEAD_FIRST:
106
+ b_g_last = tl.load(g + i_nh * T + last_idx)
107
+ p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT)
108
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
109
+ else:
110
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
111
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
112
+ b_h *= exp(b_g_last)
113
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
114
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
115
+
116
+ # vector decay, h = Diag(gk) @ h
117
+ if USE_GK:
118
+ if HEAD_FIRST:
119
+ p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
120
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
121
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
122
+ else:
123
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
124
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
125
+
126
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
127
+ b_h *= exp(b_gk_last)[:, None]
128
+
129
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
130
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
131
+
132
+ # vector decay, h = h @ Diag(gv)
133
+ if USE_GV:
134
+ if HEAD_FIRST:
135
+ p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
136
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
137
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
138
+ else:
139
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
140
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
141
+
142
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
143
+ b_h *= exp(b_gv_last)[None, :]
144
+
145
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
146
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
147
+
148
+ b_h += tl.dot(b_k, b_v)
149
+
150
+ if STORE_FINAL_STATE:
151
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
152
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
153
+
154
+
155
+ @triton.heuristics({
156
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
157
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
158
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
159
+ })
160
+ @triton.autotune(
161
+ configs=[
162
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
163
+ for BK in BKV_LIST
164
+ for BV in BKV_LIST
165
+ for num_warps in [1, 2, 4, 8]
166
+ for num_stages in [2, 3, 4]
167
+ ],
168
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
169
+ )
170
+ @triton.jit(do_not_specialize=['T'])
171
+ def chunk_bwd_kernel_dh(
172
+ q,
173
+ g,
174
+ gk,
175
+ gv,
176
+ do,
177
+ dh,
178
+ dht,
179
+ dh0,
180
+ offsets,
181
+ split_offsets,
182
+ scale,
183
+ T,
184
+ HQ: tl.constexpr,
185
+ H: tl.constexpr,
186
+ K: tl.constexpr,
187
+ V: tl.constexpr,
188
+ BT: tl.constexpr,
189
+ BS: tl.constexpr,
190
+ BK: tl.constexpr,
191
+ BV: tl.constexpr,
192
+ NG: tl.constexpr,
193
+ USE_G: tl.constexpr,
194
+ USE_GK: tl.constexpr,
195
+ USE_GV: tl.constexpr,
196
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
197
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
198
+ USE_OFFSETS: tl.constexpr,
199
+ HEAD_FIRST: tl.constexpr
200
+ ):
201
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
202
+ i_bg = i_nh // NG
203
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
204
+ i_h = i_hq // NG
205
+ if USE_OFFSETS:
206
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
207
+ T = eos - bos
208
+ NT = tl.cdiv(T, BT)
209
+ NS = tl.cdiv(T, BS)
210
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
211
+ else:
212
+ bos, eos = i_n * T, i_n * T + T
213
+ NT = tl.cdiv(T, BT)
214
+ NS = tl.cdiv(T, BS)
215
+ boh = i_n * NS
216
+
217
+ # [BK, BV]
218
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
219
+ if USE_FINAL_STATE_GRADIENT:
220
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
221
+ b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
222
+
223
+ for i_t in range(NT - 1, -1, -1):
224
+ i_s = i_t // (BS // BT)
225
+ if HEAD_FIRST:
226
+ o_dh = (i_nh * NS + i_s).to(tl.int64) * K*V
227
+ p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
228
+ else:
229
+ o_dh = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
230
+ p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
231
+
232
+ if i_t % (BS // BT) == 0:
233
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
234
+ last_idx = min(i_t * BT + BT, T) - 1
235
+ # [BK, BT]
236
+ if HEAD_FIRST:
237
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
238
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
239
+ else:
240
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
241
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
242
+ b_q = tl.load(p_q, boundary_check=(0, 1))
243
+ b_q = (b_q * scale).to(b_q.dtype)
244
+ # [BT, BV]
245
+ b_do = tl.load(p_do, boundary_check=(0, 1))
246
+
247
+ if USE_G:
248
+ if HEAD_FIRST:
249
+ p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
250
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
251
+ b_g_last = tl.load(g + i_bg * T + last_idx)
252
+ else:
253
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
254
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
255
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
256
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
257
+
258
+ b_dh *= exp(b_g_last)
259
+
260
+ if USE_GK:
261
+ if HEAD_FIRST:
262
+ p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
263
+ p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
264
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
265
+ else:
266
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
267
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
268
+
269
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
270
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
271
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
272
+ b_dh *= exp(b_gk_last)[:, None]
273
+
274
+ if USE_GV:
275
+ if HEAD_FIRST:
276
+ p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
277
+ p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
278
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
279
+ else:
280
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
281
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
282
+
283
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
284
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
285
+
286
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
287
+ b_dh *= exp(b_gv_last)[None, :]
288
+
289
+ b_dh += tl.dot(b_q, b_do)
290
+
291
+ if STORE_INITIAL_STATE_GRADIENT:
292
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
293
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
294
+
295
+
296
+ def chunk_fwd_h(
297
+ k: torch.Tensor,
298
+ v: torch.Tensor,
299
+ g: torch.Tensor,
300
+ gk: torch.Tensor,
301
+ gv: torch.Tensor,
302
+ h0: torch.Tensor,
303
+ output_final_state: bool,
304
+ offsets: Optional[torch.Tensor] = None,
305
+ head_first: bool = True,
306
+ chunk_size: int = 64,
307
+ split_size: Optional[int] = None,
308
+ states_in_fp32: bool = False
309
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
310
+ if head_first:
311
+ B, H, T, K, V = *k.shape, v.shape[-1]
312
+ else:
313
+ B, T, H, K, V = *k.shape, v.shape[-1]
314
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
315
+ BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
316
+ assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
317
+ # N: the actual number of sequences in the batch with either equal or variable lengths
318
+ if offsets is None:
319
+ split_offsets, N, NS = None, B, triton.cdiv(T, BS)
320
+ else:
321
+ split_offsets = prepare_chunk_offsets(offsets, BS)
322
+ N, NS = len(offsets) - 1, split_offsets[-1]
323
+
324
+ if head_first:
325
+ h = k.new_empty(B, H, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
326
+ else:
327
+ h = k.new_empty(B, NS, H, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
328
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
329
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
330
+ chunk_fwd_kernel_h[grid](
331
+ k=k,
332
+ v=v,
333
+ h=h,
334
+ g=g,
335
+ gk=gk,
336
+ gv=gv,
337
+ h0=h0,
338
+ ht=ht,
339
+ offsets=offsets,
340
+ split_offsets=split_offsets,
341
+ T=T,
342
+ H=H,
343
+ K=K,
344
+ V=V,
345
+ BT=BT,
346
+ BS=BS,
347
+ USE_G=g is not None,
348
+ USE_GK=gk is not None,
349
+ USE_GV=gv is not None,
350
+ HEAD_FIRST=head_first
351
+ )
352
+ return h, ht
353
+
354
+
355
+ def chunk_bwd_dh(
356
+ q: torch.Tensor,
357
+ k: torch.Tensor,
358
+ v: torch.Tensor,
359
+ g: torch.Tensor,
360
+ gk: torch.Tensor,
361
+ gv: torch.Tensor,
362
+ do: torch.Tensor,
363
+ h0: torch.Tensor,
364
+ dht: torch.Tensor,
365
+ scale: float,
366
+ offsets: Optional[torch.Tensor] = None,
367
+ head_first: bool = True,
368
+ chunk_size: int = 64,
369
+ split_size: Optional[int] = None,
370
+ states_in_fp32: bool = False
371
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
372
+ if head_first:
373
+ B, H, T, K, V = *k.shape, v.shape[-1]
374
+ HQ = q.shape[1]
375
+ else:
376
+ B, T, H, K, V = *k.shape, v.shape[-1]
377
+ HQ = q.shape[2]
378
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
379
+ BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
380
+ assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
381
+ # N: the actual number of sequences in the batch with either equal or variable lengths
382
+ # NG: number of groups in GQA
383
+ if offsets is None:
384
+ split_offsets, N, NS = None, B, triton.cdiv(T, BS)
385
+ else:
386
+ split_offsets = prepare_chunk_offsets(offsets, BS)
387
+ N, NS = len(offsets) - 1, split_offsets[-1]
388
+ NG = HQ // H
389
+
390
+ if head_first:
391
+ dh = k.new_empty(B, HQ, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
392
+ else:
393
+ dh = k.new_empty(B, NS, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
394
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
395
+
396
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
397
+ chunk_bwd_kernel_dh[grid](
398
+ q=q,
399
+ g=g,
400
+ gk=gk,
401
+ gv=gv,
402
+ do=do,
403
+ dh=dh,
404
+ dht=dht,
405
+ dh0=dh0,
406
+ offsets=offsets,
407
+ split_offsets=split_offsets,
408
+ scale=scale,
409
+ T=T,
410
+ HQ=HQ,
411
+ H=H,
412
+ K=K,
413
+ V=V,
414
+ BT=BT,
415
+ BS=BS,
416
+ NG=NG,
417
+ USE_G=g is not None,
418
+ USE_GK=gk is not None,
419
+ USE_GV=gv is not None,
420
+ HEAD_FIRST=head_first
421
+ )
422
+ return dh, dh0
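Along the same lines, a sketch (not part of the commit) of the scalar-decay path of `chunk_fwd_kernel_h` above: it assumes the non-head-first layout, fixed-length batches, and `split_size=None` so one state is materialized per chunk; the gk/gv decays and variable-length handling are omitted.

import torch

def naive_chunk_fwd_h(k, v, g=None, h0=None, chunk_size=64):
    # k: [B, T, H, K], v: [B, T, H, V], g: [B, T, H] (log decays) or None
    B, T, H, K = k.shape
    V = v.shape[-1]
    NT = (T + chunk_size - 1) // chunk_size
    S = k.new_zeros(B, H, K, V) if h0 is None else h0.clone()
    h = k.new_zeros(B, NT, H, K, V)
    for i in range(NT):
        s, e = i * chunk_size, min((i + 1) * chunk_size, T)
        h[:, i] = S                                    # state entering the chunk
        k_c = k[:, s:e].transpose(1, 2)                # [B, H, C, K]
        v_c = v[:, s:e].transpose(1, 2)                # [B, H, C, V]
        if g is not None:
            g_c = g[:, s:e].transpose(1, 2).float()    # [B, H, C]
            g_last = g_c[:, :, -1:]
            v_c = v_c * (g_last - g_c).exp().unsqueeze(-1)
            S = S * g_last.exp().unsqueeze(-1)         # decay the carried state
        S = S + k_c.transpose(-1, -2) @ v_c            # accumulate K^T V for the chunk
    return h, S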
fla/ops/common/chunk_h_parallel.py ADDED
@@ -0,0 +1,650 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ Fully parallelized state passing.
6
+ """
7
+
8
+ from typing import Optional, Tuple
9
+
10
+ import torch
11
+ import triton
12
+ import triton.language as tl
13
+
14
+ from fla.ops.utils.op import exp
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
25
+ for BK in [32, 64, 128]
26
+ for BV in [32, 64, 128]
27
+ for num_warps in [2, 4, 8]
28
+ for num_stages in [2, 3, 4]
29
+ ],
30
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_fwd_kernel_h_parallel(
34
+ k,
35
+ v,
36
+ h,
37
+ g,
38
+ gk,
39
+ gv,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ indices,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ USE_G: tl.constexpr,
52
+ USE_GK: tl.constexpr,
53
+ USE_GV: tl.constexpr,
54
+ USE_INITIAL_STATE: tl.constexpr,
55
+ STORE_FINAL_STATE: tl.constexpr,
56
+ USE_OFFSETS: tl.constexpr,
57
+ HEAD_FIRST: tl.constexpr
58
+ ):
59
+ i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
60
+
61
+ NV = tl.cdiv(V, BV)
62
+ # i_b: batch index
63
+ # i_h: head index
64
+ # i_n: sequence index
65
+ # i_t: chunk index within current sequence
66
+ # i_tg: (global) chunk index across all sequences
67
+ i_k, i_v = i_kv // NV, i_kv % NV
68
+ i_b, i_h = i_bh // H, i_bh % H
69
+ if USE_OFFSETS:
70
+ i_tg = i_t
71
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
72
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
73
+ T = eos - bos
74
+ NT = tl.cdiv(T, BT)
75
+ else:
76
+ bos, eos = i_b * T, i_b * T + T
77
+ NT = tl.cdiv(T, BT)
78
+ i_n, i_tg = i_b, i_b * NT + i_t
79
+ i_nh = i_n * H + i_h
80
+
81
+ if HEAD_FIRST:
82
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
83
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
84
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
85
+ else:
86
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
87
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
89
+
90
+ if i_t == 0:
91
+ if USE_INITIAL_STATE:
92
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
93
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
94
+ else:
95
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
96
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
97
+
98
+ # [BK, BT]
99
+ b_k = tl.load(p_k, boundary_check=(0, 1))
100
+ # [BT, BV]
101
+ b_v = tl.load(p_v, boundary_check=(0, 1))
102
+
103
+ last_idx = min(i_t * BT + BT, T) - 1
104
+ # scalar decay
105
+ if USE_G:
106
+ if HEAD_FIRST:
107
+ b_g_last = tl.load(g + i_bh * T + last_idx)
108
+ p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
109
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
110
+ else:
111
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
112
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
113
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
114
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
115
+
116
+ # vector decay, h = Diag(gk) @ h
117
+ if USE_GK:
118
+ if HEAD_FIRST:
119
+ p_gk = tl.make_block_ptr(gk + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
120
+ p_gk_last = gk + i_bh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
121
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
122
+ else:
123
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
124
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
125
+
126
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
127
+
128
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
129
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
130
+
131
+ # vector decay, h = h @ Diag(gv)
132
+ if USE_GV:
133
+ if HEAD_FIRST:
134
+ p_gv = tl.make_block_ptr(gv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
135
+ p_gv_last = gv + i_bh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
136
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
137
+ else:
138
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
139
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
140
+
141
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
142
+
143
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
144
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
145
+
146
+ b_h = tl.dot(b_k, b_v)
147
+ if i_t < NT - 1:
148
+ if HEAD_FIRST:
149
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t + 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
150
+ else:
151
+ p_h = tl.make_block_ptr(h + ((i_tg + 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
152
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
153
+ elif STORE_FINAL_STATE:
154
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
155
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
156
+
157
+
158
+ @triton.heuristics({
159
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
160
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
161
+ })
162
+ @triton.autotune(
163
+ configs=[
164
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
165
+ for BK in [32, 64, 128]
166
+ for BV in [32, 64, 128]
167
+ for num_warps in [2, 4, 8, 16]
168
+ for num_stages in [2, 3]
169
+ ],
170
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
171
+ )
172
+ @triton.jit(do_not_specialize=['T'])
173
+ def chunk_fwd_kernel_h_reduction(
174
+ h,
175
+ g,
176
+ gk,
177
+ gv,
178
+ kvt,
179
+ ht,
180
+ offsets,
181
+ chunk_offsets,
182
+ T,
183
+ H: tl.constexpr,
184
+ K: tl.constexpr,
185
+ V: tl.constexpr,
186
+ BT: tl.constexpr,
187
+ BK: tl.constexpr,
188
+ BV: tl.constexpr,
189
+ USE_G: tl.constexpr,
190
+ USE_GK: tl.constexpr,
191
+ USE_GV: tl.constexpr,
192
+ STORE_FINAL_STATE: tl.constexpr,
193
+ USE_OFFSETS: tl.constexpr,
194
+ HEAD_FIRST: tl.constexpr
195
+ ):
196
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
197
+ i_n, i_h = i_nh // H, i_nh % H
198
+ if USE_OFFSETS:
199
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
200
+ T = eos - bos
201
+ NT = tl.cdiv(T, BT)
202
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
203
+ else:
204
+ bos, eos = i_n * T, i_n * T + T
205
+ NT = tl.cdiv(T, BT)
206
+ boh = i_n * NT
207
+
208
+ # [BK, BV]
209
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
210
+ for i_t in range(NT):
211
+ if HEAD_FIRST:
212
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
213
+ else:
214
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
215
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
216
+ if i_t > 0:
217
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
218
+
219
+ last_idx = min(i_t * BT + BT, T) - 1
220
+ # scalar decay
221
+ if USE_G:
222
+ if HEAD_FIRST:
223
+ b_g_last = tl.load(g + i_nh * T + last_idx)
224
+ else:
225
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
226
+ b_h *= exp(b_g_last)
227
+
228
+ # vector decay, h = Diag(gk) @ h
229
+ if USE_GK:
230
+ if HEAD_FIRST:
231
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
232
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
233
+ else:
234
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
235
+
236
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
237
+ b_h *= exp(b_gk_last)[:, None]
238
+
239
+ # vector decay, h = h @ Diag(gv)
240
+ if USE_GV:
241
+ if HEAD_FIRST:
242
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
243
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
244
+ else:
245
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
246
+
247
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
248
+ b_h *= exp(b_gv_last)[None, :]
249
+
250
+ if STORE_FINAL_STATE:
251
+ p_kvt = tl.make_block_ptr(kvt + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
252
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
253
+ b_h += tl.load(p_kvt, boundary_check=(0, 1)).to(tl.float32)
254
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
255
+
256
+
257
+ @triton.heuristics({
258
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
259
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
260
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
261
+ })
262
+ @triton.autotune(
263
+ configs=[
264
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
265
+ for BK in [32, 64, 128]
266
+ for BV in [32, 64, 128]
267
+ for num_warps in [2, 4, 8]
268
+ for num_stages in [2, 3, 4]
269
+ ],
270
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
271
+ )
272
+ @triton.jit(do_not_specialize=['T'])
273
+ def chunk_bwd_kernel_dh_parallel(
274
+ q,
275
+ g,
276
+ gk,
277
+ gv,
278
+ do,
279
+ dh,
280
+ dht,
281
+ dh0,
282
+ offsets,
283
+ indices,
284
+ scale,
285
+ T,
286
+ HQ: tl.constexpr,
287
+ H: tl.constexpr,
288
+ K: tl.constexpr,
289
+ V: tl.constexpr,
290
+ BT: tl.constexpr,
291
+ BK: tl.constexpr,
292
+ BV: tl.constexpr,
293
+ NG: tl.constexpr,
294
+ USE_G: tl.constexpr,
295
+ USE_GK: tl.constexpr,
296
+ USE_GV: tl.constexpr,
297
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
298
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
299
+ USE_OFFSETS: tl.constexpr,
300
+ HEAD_FIRST: tl.constexpr
301
+ ):
302
+ i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
303
+
304
+ NV = tl.cdiv(V, BV)
305
+ i_k, i_v = i_kv // NV, i_kv % NV
306
+ i_b, i_hq, i_bg = i_bh // HQ, i_bh % HQ, i_bh // NG
307
+ i_h = i_hq // NG
308
+ if USE_OFFSETS:
309
+ i_tg = i_t
310
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
311
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
312
+ T = eos - bos
313
+ NT = tl.cdiv(T, BT)
314
+ else:
315
+ bos, eos = i_b * T, i_b * T + T
316
+ NT = tl.cdiv(T, BT)
317
+ i_n, i_tg = i_b, i_b * NT + i_t
318
+ i_nh = i_n * HQ + i_hq
319
+
320
+ if HEAD_FIRST:
321
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
322
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
323
+ p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
324
+ else:
325
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
326
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
327
+ p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
328
+
329
+ if i_t == NT - 1:
330
+ if USE_FINAL_STATE_GRADIENT:
331
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
332
+ b_dh = tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
333
+ else:
334
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
335
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
336
+
337
+ # [BK, BT]
338
+ b_q = tl.load(p_q, boundary_check=(0, 1))
339
+ b_q = (b_q * scale).to(b_q.dtype)
340
+ # [BT, BV]
341
+ b_do = tl.load(p_do, boundary_check=(0, 1))
342
+
343
+ if USE_G:
344
+ if HEAD_FIRST:
345
+ p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
346
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
347
+ else:
348
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
349
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
350
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
351
+
352
+ if USE_GK:
353
+ if HEAD_FIRST:
354
+ p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
355
+ else:
356
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
357
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
358
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
359
+
360
+ if USE_GV:
361
+ if HEAD_FIRST:
362
+ p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
363
+ else:
364
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
365
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
366
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
367
+
368
+ b_dh = tl.dot(b_q, b_do)
369
+ if i_t > 0:
370
+ if HEAD_FIRST:
371
+ p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t - 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
372
+ else:
373
+ p_dh = tl.make_block_ptr(dh + ((i_tg - 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
374
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
375
+ elif STORE_INITIAL_STATE_GRADIENT:
376
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
377
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
378
+
379
+
380
+ @triton.heuristics({
381
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
382
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
383
+ })
384
+ @triton.autotune(
385
+ configs=[
386
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
387
+ for BK in [32, 64, 128]
388
+ for BV in [32, 64, 128]
389
+ for num_warps in [2, 4, 8, 16]
390
+ for num_stages in [2, 3]
391
+ ],
392
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
393
+ )
394
+ @triton.jit(do_not_specialize=['T'])
395
+ def chunk_bwd_kernel_dh_reduction(
396
+ g,
397
+ gk,
398
+ gv,
399
+ dh,
400
+ doq0,
401
+ dh0,
402
+ offsets,
403
+ chunk_offsets,
404
+ T,
405
+ HQ: tl.constexpr,
406
+ H: tl.constexpr,
407
+ K: tl.constexpr,
408
+ V: tl.constexpr,
409
+ BT: tl.constexpr,
410
+ BK: tl.constexpr,
411
+ BV: tl.constexpr,
412
+ NG: tl.constexpr,
413
+ USE_G: tl.constexpr,
414
+ USE_GK: tl.constexpr,
415
+ USE_GV: tl.constexpr,
416
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
417
+ USE_OFFSETS: tl.constexpr,
418
+ HEAD_FIRST: tl.constexpr
419
+ ):
420
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
421
+ i_bg = i_nh // NG
422
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
423
+ i_h = i_hq // NG
424
+ if USE_OFFSETS:
425
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
426
+ T = eos - bos
427
+ NT = tl.cdiv(T, BT)
428
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
429
+ else:
430
+ bos, eos = i_n * T, i_n * T + T
431
+ NT = tl.cdiv(T, BT)
432
+ boh = i_n * NT
433
+
434
+ # [BK, BV]
435
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
436
+ for i_t in range(NT - 1, -1, -1):
437
+ if HEAD_FIRST:
438
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
439
+ else:
440
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
441
+ b_dh += tl.load(p_dh, boundary_check=(0, 1)).to(tl.float32)
442
+ if i_t < NT - 1:
443
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
444
+
445
+ last_idx = min(i_t * BT + BT, T) - 1
446
+ if USE_G:
447
+ if HEAD_FIRST:
448
+ b_g_last = tl.load(g + i_bg * T + last_idx)
449
+ else:
450
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
451
+ b_dh *= exp(b_g_last)
452
+
453
+ if USE_GK:
454
+ if HEAD_FIRST:
455
+ p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
456
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
457
+ else:
458
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
459
+
460
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
461
+ b_dh *= exp(b_gk_last)[:, None]
462
+
463
+ if USE_GV:
464
+ if HEAD_FIRST:
465
+ p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
466
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
467
+ else:
468
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
469
+
470
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
471
+ b_dh *= exp(b_gv_last)[None, :]
472
+
473
+ if STORE_INITIAL_STATE_GRADIENT:
474
+ p_doq0 = tl.make_block_ptr(doq0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
475
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
476
+ b_dh += tl.load(p_doq0, boundary_check=(0, 1)).to(tl.float32)
477
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
478
+
479
+
480
+ def chunk_fwd_h(
481
+ k: torch.Tensor,
482
+ v: torch.Tensor,
483
+ g: torch.Tensor,
484
+ gk: torch.Tensor,
485
+ gv: torch.Tensor,
486
+ h0: torch.Tensor,
487
+ output_final_state: bool,
488
+ states_in_fp32: bool = False,
489
+ offsets: Optional[torch.Tensor] = None,
490
+ indices: Optional[torch.Tensor] = None,
491
+ head_first: bool = True,
492
+ chunk_size: int = 64
493
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
494
+ if head_first:
495
+ B, H, T, K, V = *k.shape, v.shape[-1]
496
+ else:
497
+ B, T, H, K, V = *k.shape, v.shape[-1]
498
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
499
+ # N: the actual number of sequences in the batch with either equal or variable lengths
500
+ if offsets is None:
501
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
502
+ else:
503
+ if indices is None:
504
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
505
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
506
+ N, NT = len(offsets) - 1, len(indices)
507
+ chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
508
+
509
+ h = k.new_empty(B, H, NT, K, V, dtype=torch.float) if head_first else k.new_empty(B, NT, H, K, V, dtype=torch.float)
510
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
511
+ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * H)
512
+ chunk_fwd_kernel_h_parallel[grid](
513
+ k=k,
514
+ v=v,
515
+ h=h,
516
+ g=g,
517
+ gk=gk,
518
+ gv=gv,
519
+ h0=h0,
520
+ ht=ht,
521
+ offsets=offsets,
522
+ indices=indices,
523
+ T=T,
524
+ H=H,
525
+ K=K,
526
+ V=V,
527
+ BT=BT,
528
+ USE_G=g is not None,
529
+ USE_GK=gk is not None,
530
+ USE_GV=gv is not None,
531
+ HEAD_FIRST=head_first
532
+ )
533
+ kvt, ht = ht, (torch.empty_like(ht) if output_final_state else None)
534
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
535
+ chunk_fwd_kernel_h_reduction[grid](
536
+ h=h,
537
+ g=g,
538
+ gk=gk,
539
+ gv=gv,
540
+ kvt=kvt,
541
+ ht=ht,
542
+ offsets=offsets,
543
+ chunk_offsets=chunk_offsets,
544
+ T=T,
545
+ H=H,
546
+ K=K,
547
+ V=V,
548
+ BT=BT,
549
+ USE_G=g is not None,
550
+ USE_GK=gk is not None,
551
+ USE_GV=gv is not None,
552
+ HEAD_FIRST=head_first
553
+ )
554
+ h = h.to(k.dtype) if not states_in_fp32 else h
555
+ return h, ht
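The `indices`/`chunk_offsets` construction inside `chunk_fwd_h` above is easier to follow with a concrete variable-length batch. This host-side snippet simply mirrors those lines for a hypothetical pair of sequences:

```python
import torch
import triton

offsets = torch.tensor([0, 100, 228])   # two sequences of lengths 100 and 128
BT = 64
lens = offsets[1:] - offsets[:-1]
indices = torch.cat([torch.arange(n) for n in triton.cdiv(lens, BT).tolist()])
indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(lens, BT)]).cumsum(-1)

print(indices.tolist())        # [[0, 0], [0, 1], [1, 0], [1, 1]] -> (sequence id, local chunk id)
print(chunk_offsets.tolist())  # [0, 2, 4] -> first chunk slot owned by each sequence
```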
556
+
557
+
558
+ def chunk_bwd_dh(
559
+ q: torch.Tensor,
560
+ k: torch.Tensor,
561
+ v: torch.Tensor,
562
+ g: torch.Tensor,
563
+ gk: torch.Tensor,
564
+ gv: torch.Tensor,
565
+ do: torch.Tensor,
566
+ h0: torch.Tensor,
567
+ dht: torch.Tensor,
568
+ scale: float,
569
+ states_in_fp32: bool = False,
570
+ offsets: Optional[torch.Tensor] = None,
571
+ indices: Optional[torch.Tensor] = None,
572
+ head_first: bool = True,
573
+ chunk_size: int = 64
574
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
575
+ if head_first:
576
+ B, H, T, K, V = *k.shape, v.shape[-1]
577
+ HQ = q.shape[1]
578
+ else:
579
+ B, T, H, K, V = *k.shape, v.shape[-1]
580
+ HQ = q.shape[2]
581
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
582
+ # N: the actual number of sequences in the batch with either equal or variable lengths
583
+ # NG: number of groups in GQA
584
+ if offsets is None:
585
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
586
+ else:
587
+ if indices is None:
588
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
589
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
590
+ N, NT = len(offsets) - 1, len(indices)
591
+ chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
592
+ NG = HQ // H
593
+
594
+ if head_first:
595
+ dh = k.new_empty(B, HQ, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
596
+ else:
597
+ dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
598
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
599
+
600
+ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * HQ)
601
+ chunk_bwd_kernel_dh_parallel[grid](
602
+ q=q,
603
+ g=g,
604
+ gk=gk,
605
+ gv=gv,
606
+ do=do,
607
+ dh=dh,
608
+ dht=dht,
609
+ dh0=dh0,
610
+ offsets=offsets,
611
+ indices=indices,
612
+ scale=scale,
613
+ T=T,
614
+ HQ=HQ,
615
+ H=H,
616
+ K=K,
617
+ V=V,
618
+ BT=BT,
619
+ NG=NG,
620
+ USE_G=g is not None,
621
+ USE_GK=gk is not None,
622
+ USE_GV=gv is not None,
623
+ HEAD_FIRST=head_first
624
+ )
625
+
626
+ doq0, dh0 = dh0, (torch.empty_like(dh0) if dh0 is not None else None)
627
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
628
+ chunk_bwd_kernel_dh_reduction[grid](
629
+ g=g,
630
+ gk=gk,
631
+ gv=gv,
632
+ dh=dh,
633
+ doq0=doq0,
634
+ dh0=dh0,
635
+ offsets=offsets,
636
+ chunk_offsets=chunk_offsets,
637
+ T=T,
638
+ HQ=HQ,
639
+ H=H,
640
+ K=K,
641
+ V=V,
642
+ BT=BT,
643
+ NG=NG,
644
+ USE_G=g is not None,
645
+ USE_GK=gk is not None,
646
+ USE_GV=gv is not None,
647
+ HEAD_FIRST=head_first
648
+ )
649
+ dh = dh.to(q.dtype) if not states_in_fp32 else dh
650
+ return dh, dh0
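`chunk_bwd_dh` above supports GQA by letting `HQ` query heads share `H` key/value heads: `NG = HQ // H` consecutive query heads map onto one kv head, which is how the kernels derive `i_h = i_hq // NG`. A tiny illustration with made-up head counts:

```python
HQ, H = 8, 2                   # 8 query heads grouped onto 2 kv heads
NG = HQ // H                   # 4 query heads per group
print([i_hq // NG for i_hq in range(HQ)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```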
fla/ops/common/chunk_h_split.py ADDED
@@ -0,0 +1,677 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+
12
+
13
+ @triton.heuristics({
14
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
15
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
16
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
21
+ for BK in [32, 64]
22
+ for BV in [32, 64]
23
+ for num_warps in [2, 4, 8]
24
+ for num_stages in [2, 3]
25
+ ],
26
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def chunk_fwd_kernel_h_split(
30
+ k,
31
+ v,
32
+ g,
33
+ gk,
34
+ gv,
35
+ hs,
36
+ hr,
37
+ h0,
38
+ ht,
39
+ offsets,
40
+ split_indices,
41
+ T,
42
+ S: tl.constexpr,
43
+ H: tl.constexpr,
44
+ K: tl.constexpr,
45
+ V: tl.constexpr,
46
+ BT: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ BV: tl.constexpr,
49
+ USE_G: tl.constexpr,
50
+ USE_GK: tl.constexpr,
51
+ USE_GV: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr
56
+ ):
57
+ # handle one split at a time
58
+ # i_h: head index
59
+ # i_n: sequence index
60
+ # i_s: local split index inside a sequence
61
+ i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
62
+ i_ss, i_h = i_sh // H, i_sh % H
63
+ if USE_OFFSETS:
64
+ i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32)
65
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
66
+ T = eos - bos
67
+ NS = tl.cdiv(T, S)
68
+ else:
69
+ NS = tl.cdiv(T, S)
70
+ i_n, i_s = i_ss // NS, i_ss % NS
71
+ bos, eos = i_n * T, i_n * T + T
72
+ i_nh = i_n * H + i_h
73
+
74
+ # [BK, BV]
75
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
76
+     # for the first split there is nothing to reduce, so the (initial) state is stored to hr directly

77
+ if i_s == 0:
78
+ if USE_INITIAL_STATE:
79
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ b_h += tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
81
+ p_hr = tl.make_block_ptr(hr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
82
+ tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1))
83
+ for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)):
84
+ if HEAD_FIRST:
85
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
86
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
87
+ else:
88
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
89
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
90
+ # [BK, BT]
91
+ b_k = tl.load(p_k, boundary_check=(0, 1))
92
+ # [BT, BV]
93
+ b_v = tl.load(p_v, boundary_check=(0, 1))
94
+ last_idx = min(i_t * BT + BT, T) - 1
95
+
96
+ # scalar decay
97
+ if USE_G:
98
+ if HEAD_FIRST:
99
+ b_g_last = tl.load(g + i_nh * T + last_idx)
100
+ p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT)
101
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
102
+ else:
103
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
104
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
105
+ b_h *= exp(b_g_last)
106
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
107
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
108
+
109
+ # vector decay, h = Diag(gk) @ h
110
+ if USE_GK:
111
+ if HEAD_FIRST:
112
+ p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
113
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
114
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
115
+ else:
116
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
117
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
118
+
119
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
120
+ b_h *= exp(b_gk_last)[:, None]
121
+
122
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
123
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
124
+
125
+ # vector decay, h = h @ Diag(gv)
126
+ if USE_GV:
127
+ if HEAD_FIRST:
128
+ p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
129
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
130
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
131
+ else:
132
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
133
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
134
+
135
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
136
+ b_h *= exp(b_gv_last)[None, :]
137
+
138
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
139
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
140
+
141
+ b_h += tl.dot(b_k, b_v)
142
+
143
+     # if there is more than one split, we store the (unreduced) result to hs
144
+ # otherwise, we store the result to ht as the final state
145
+ if NS > 1:
146
+ p_hs = tl.make_block_ptr(hs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
147
+ tl.store(p_hs, b_h.to(p_hs.dtype.element_ty), boundary_check=(0, 1))
148
+ elif STORE_FINAL_STATE:
149
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
150
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
151
+
152
+
153
+ @triton.heuristics({
154
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
155
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
156
+ })
157
+ @triton.autotune(
158
+ configs=[
159
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
160
+ for BK in [32, 64]
161
+ for BV in [32, 64]
162
+ for num_warps in [2, 4, 8]
163
+ for num_stages in [2, 3, 4]
164
+ ],
165
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
166
+ )
167
+ @triton.jit(do_not_specialize=['T'])
168
+ def chunk_fwd_kernel_h_reduction(
169
+ g,
170
+ gk,
171
+ gv,
172
+ hs,
173
+ hr,
174
+ ht,
175
+ offsets,
176
+ split_offsets,
177
+ T,
178
+ S: tl.constexpr,
179
+ H: tl.constexpr,
180
+ K: tl.constexpr,
181
+ V: tl.constexpr,
182
+ BT: tl.constexpr,
183
+ BK: tl.constexpr,
184
+ BV: tl.constexpr,
185
+ USE_G: tl.constexpr,
186
+ USE_GK: tl.constexpr,
187
+ USE_GV: tl.constexpr,
188
+ STORE_FINAL_STATE: tl.constexpr,
189
+ USE_OFFSETS: tl.constexpr,
190
+ HEAD_FIRST: tl.constexpr
191
+ ):
192
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
193
+ i_n, i_h = i_nh // H, i_nh % H
194
+ if USE_OFFSETS:
195
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
196
+ T = eos - bos
197
+ NS = tl.cdiv(T, S)
198
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
199
+ else:
200
+ bos, eos = i_n * T, i_n * T + T
201
+ NS = tl.cdiv(T, S)
202
+ boh = i_n * NS
203
+
204
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
205
+ # skip the first split
206
+ for i_s in range(1, NS):
207
+ p_hs = tl.make_block_ptr(hs + ((boh + i_s-1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
208
+ p_hr = tl.make_block_ptr(hr + ((boh + i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
209
+ b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32)
210
+ tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1))
211
+
212
+ for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)):
213
+ last_idx = min(i_t * BT + BT, T) - 1
214
+ # scalar decay
215
+ if USE_G:
216
+ if HEAD_FIRST:
217
+ b_g_last = tl.load(g + i_nh * T + last_idx)
218
+ else:
219
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
220
+ b_h *= exp(b_g_last)
221
+
222
+ # vector decay, h = Diag(gk) @ h
223
+ if USE_GK:
224
+ if HEAD_FIRST:
225
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
226
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
227
+ else:
228
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
229
+
230
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
231
+ b_h *= exp(b_gk_last)[:, None]
232
+
233
+ # vector decay, h = h @ Diag(gv)
234
+ if USE_GV:
235
+ if HEAD_FIRST:
236
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
237
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
238
+ else:
239
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
240
+
241
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
242
+ b_h *= exp(b_gv_last)[None, :]
243
+
244
+ if NS > 1:
245
+ if STORE_FINAL_STATE:
246
+ p_hs = tl.make_block_ptr(hs + ((boh + NS-1) * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
247
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
248
+ b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32)
249
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
250
+
251
+
252
+ @triton.heuristics({
253
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
254
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
255
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
256
+ })
257
+ @triton.autotune(
258
+ configs=[
259
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
260
+ for BK in [32, 64]
261
+ for BV in [32, 64]
262
+ for num_warps in [2, 4, 8]
263
+ for num_stages in [2, 3]
264
+ ],
265
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
266
+ )
267
+ @triton.jit(do_not_specialize=['T'])
268
+ def chunk_bwd_kernel_dh_split(
269
+ q,
270
+ g,
271
+ gk,
272
+ gv,
273
+ do,
274
+ dht,
275
+ dhs,
276
+ dhr,
277
+ dh0,
278
+ offsets,
279
+ split_indices,
280
+ scale,
281
+ T,
282
+ S: tl.constexpr,
283
+ HQ: tl.constexpr,
284
+ H: tl.constexpr,
285
+ K: tl.constexpr,
286
+ V: tl.constexpr,
287
+ BT: tl.constexpr,
288
+ BK: tl.constexpr,
289
+ BV: tl.constexpr,
290
+ NG: tl.constexpr,
291
+ USE_G: tl.constexpr,
292
+ USE_GK: tl.constexpr,
293
+ USE_GV: tl.constexpr,
294
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
295
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
296
+ USE_OFFSETS: tl.constexpr,
297
+ HEAD_FIRST: tl.constexpr
298
+ ):
299
+ # handle one split at a time
300
+ # i_h: head index
301
+ # i_n: sequence index
302
+ # i_s: local split index inside a sequence
303
+ i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
304
+ i_ss, i_hq = i_sh // HQ, i_sh % HQ
305
+ if USE_OFFSETS:
306
+ i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32)
307
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
308
+ T = eos - bos
309
+ NS = tl.cdiv(T, S)
310
+ else:
311
+ NS = tl.cdiv(T, S)
312
+ i_n, i_s = i_ss // NS, i_ss % NS
313
+ bos, eos = i_n * T, i_n * T + T
314
+ i_nh = i_n * HQ + i_hq
315
+ i_ng, i_h = i_nh // NG, i_hq // NG
316
+
317
+ # [BK, BV]
318
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
319
+ if i_s == NS - 1:
320
+ if USE_FINAL_STATE_GRADIENT:
321
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
322
+ b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
323
+ p_dhr = tl.make_block_ptr(dhr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
324
+ tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1))
325
+
326
+ for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1):
327
+ if HEAD_FIRST:
328
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
329
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
330
+ else:
331
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
332
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
333
+
334
+ b_q = tl.load(p_q, boundary_check=(0, 1))
335
+ b_q = (b_q * scale).to(b_q.dtype)
336
+ # [BT, BV]
337
+ b_do = tl.load(p_do, boundary_check=(0, 1))
338
+
339
+ last_idx = min(i_t * BT + BT, T) - 1
340
+ if USE_G:
341
+ if HEAD_FIRST:
342
+ p_g = g + i_ng * T + i_t * BT + tl.arange(0, BT)
343
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
344
+ b_g_last = tl.load(g + i_ng * T + last_idx)
345
+ else:
346
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
347
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
348
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
349
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
350
+ b_dh *= exp(b_g_last)
351
+
352
+ if USE_GK:
353
+ if HEAD_FIRST:
354
+ p_gk = tl.make_block_ptr(gk + i_ng * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
355
+ p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
356
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
357
+ else:
358
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
359
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
360
+
361
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
362
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
363
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
364
+ b_dh *= exp(b_gk_last)[:, None]
365
+
366
+ if USE_GV:
367
+ if HEAD_FIRST:
368
+ p_gv = tl.make_block_ptr(gv + i_ng * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
369
+ p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
370
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
371
+ else:
372
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
373
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
374
+
375
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
376
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
377
+
378
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
379
+ b_dh *= exp(b_gv_last)[None, :]
380
+
381
+ b_dh += tl.dot(b_q, b_do)
382
+
383
+ if NS > 1:
384
+ p_dhs = tl.make_block_ptr(dhs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
385
+ tl.store(p_dhs, b_dh.to(p_dhs.dtype.element_ty), boundary_check=(0, 1))
386
+ elif STORE_INITIAL_STATE_GRADIENT:
387
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
388
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
389
+
390
+
391
+ @triton.heuristics({
392
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
393
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
394
+ })
395
+ @triton.autotune(
396
+ configs=[
397
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
398
+ for BK in [32, 64]
399
+ for BV in [32, 64]
400
+ for num_warps in [2, 4, 8]
401
+ for num_stages in [2, 3, 4]
402
+ ],
403
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
404
+ )
405
+ @triton.jit(do_not_specialize=['T'])
406
+ def chunk_bwd_kernel_dh_reduction(
407
+ g,
408
+ gk,
409
+ gv,
410
+ dhs,
411
+ dhr,
412
+ dh0,
413
+ offsets,
414
+ split_offsets,
415
+ T,
416
+ S: tl.constexpr,
417
+ H: tl.constexpr,
418
+ HQ: tl.constexpr,
419
+ K: tl.constexpr,
420
+ V: tl.constexpr,
421
+ BT: tl.constexpr,
422
+ BK: tl.constexpr,
423
+ BV: tl.constexpr,
424
+ NG: tl.constexpr,
425
+ USE_G: tl.constexpr,
426
+ USE_GK: tl.constexpr,
427
+ USE_GV: tl.constexpr,
428
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
429
+ USE_OFFSETS: tl.constexpr,
430
+ HEAD_FIRST: tl.constexpr
431
+ ):
432
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
433
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
434
+ i_ng, i_h = i_nh // NG, i_hq // NG
435
+ if USE_OFFSETS:
436
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
437
+ T = eos - bos
438
+ NS = tl.cdiv(T, S)
439
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
440
+ else:
441
+ bos, eos = i_n * T, i_n * T + T
442
+ NS = tl.cdiv(T, S)
443
+ boh = i_n * NS
444
+
445
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
446
+ for i_s in range(NS - 2, -1, -1):
447
+ p_dhs = tl.make_block_ptr(dhs + ((boh+i_s+1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
448
+ p_dhr = tl.make_block_ptr(dhr + ((boh+i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
449
+ b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32)
450
+ tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1))
451
+
452
+ for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1):
453
+ last_idx = min(i_t * BT + BT, T) - 1
454
+ # scalar decay
455
+ if USE_G:
456
+ if HEAD_FIRST:
457
+ b_g_last = tl.load(g + i_ng * T + last_idx)
458
+ else:
459
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
460
+ b_dh *= exp(b_g_last)
461
+
462
+ if USE_GK:
463
+ if HEAD_FIRST:
464
+ p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
465
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
466
+ else:
467
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
468
+
469
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
470
+ b_dh *= exp(b_gk_last)[:, None]
471
+
472
+ if USE_GV:
473
+ if HEAD_FIRST:
474
+ p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
475
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
476
+ else:
477
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
478
+
479
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
480
+ b_dh *= exp(b_gv_last)[None, :]
481
+
482
+ if NS > 1:
483
+ if STORE_INITIAL_STATE_GRADIENT:
484
+ p_dhs = tl.make_block_ptr(dhs + (boh * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
485
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
486
+ b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32)
487
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
488
+
489
+
490
+ def chunk_fwd_h(
491
+ k: torch.Tensor,
492
+ v: torch.Tensor,
493
+ g: torch.Tensor,
494
+ gk: torch.Tensor,
495
+ gv: torch.Tensor,
496
+ h0: torch.Tensor,
497
+ output_final_state: bool,
498
+ offsets: Optional[torch.LongTensor] = None,
499
+ split_offsets: Optional[torch.LongTensor] = None,
500
+ split_indices: Optional[torch.LongTensor] = None,
501
+ head_first: bool = True,
502
+ chunk_size: int = 64,
503
+ split_size: int = 256,
504
+ states_in_fp32: bool = True
505
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
506
+ if head_first:
507
+ B, H, T, K, V = *k.shape, v.shape[-1]
508
+ else:
509
+ B, T, H, K, V = *k.shape, v.shape[-1]
510
+ # B: batch size
511
+ # N: the actual number of sequences in the batch
512
+ # H: number of heads
513
+ # T: sequence length, can be variable across sequences
514
+ # S: split size, a multiple of chunk size
515
+ # BT: chunk size
516
+ S, BT = split_size, chunk_size
517
+ assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` {BT}"
518
+ if offsets is None:
519
+ N = B
520
+ NS = N * triton.cdiv(T, S)
521
+ else:
522
+ N = len(offsets) - 1
523
+ NS = split_offsets[-1]
524
+
525
+ # unreduced kv states per split
526
+ hs = k.new_empty(NS, H, K, V, dtype=torch.float)
527
+ # reduced states per split
528
+ hr = k.new_empty(NS, H, K, V, dtype=torch.float if states_in_fp32 else k.dtype)
529
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
530
+ # parallelized over splits
531
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * H)
532
+ chunk_fwd_kernel_h_split[grid](
533
+ k=k,
534
+ v=v,
535
+ g=g,
536
+ gk=gk,
537
+ gv=gv,
538
+ hs=hs,
539
+ hr=hr,
540
+ h0=h0,
541
+ ht=ht,
542
+ offsets=offsets,
543
+ split_indices=split_indices,
544
+ T=T,
545
+ S=S,
546
+ H=H,
547
+ K=K,
548
+ V=V,
549
+ BT=BT,
550
+ USE_G=g is not None,
551
+ USE_GK=gk is not None,
552
+ USE_GV=gv is not None,
553
+ HEAD_FIRST=head_first
554
+ )
555
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
556
+ chunk_fwd_kernel_h_reduction[grid](
557
+ g=g,
558
+ gk=gk,
559
+ gv=gv,
560
+ hs=hs,
561
+ hr=hr,
562
+ ht=ht,
563
+ offsets=offsets,
564
+ split_offsets=split_offsets,
565
+ T=T,
566
+ S=S,
567
+ H=H,
568
+ K=K,
569
+ V=V,
570
+ BT=BT,
571
+ USE_G=g is not None,
572
+ USE_GK=gk is not None,
573
+ USE_GV=gv is not None,
574
+ HEAD_FIRST=head_first
575
+ )
576
+ return hr, ht
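A small reference for the split-then-reduce scheme implemented by the two forward kernels above, restricted to the ungated case (`USE_G`/`USE_GK`/`USE_GV` all false), a single head and zero initial state: each split contributes a partial state `sum_t k_t^T v_t` (what `hs` holds), and the reduction pass prefix-sums those partials so that `hr[s]` is the state entering split `s`. Shapes below are illustrative only:

```python
import torch

T, S, K, V = 8, 4, 3, 2
k, v = torch.randn(T, K), torch.randn(T, V)
NS = T // S

# per-split partial states, as written to hs by the split kernel
hs = torch.stack([k[s*S:(s+1)*S].t() @ v[s*S:(s+1)*S] for s in range(NS)])
# reduced states: hr[s] = sum of partials of all earlier splits
hr = torch.cumsum(torch.cat([torch.zeros(1, K, V), hs[:-1]]), dim=0)

# the state entering split 1 equals a plain sequential scan run for the first S steps
seq = torch.zeros(K, V)
for t in range(S):
    seq += k[t, :, None] * v[t, None, :]
assert torch.allclose(hr[1], seq, atol=1e-5)
```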
577
+
578
+
579
+ def chunk_bwd_dh(
580
+ q: torch.Tensor,
581
+ k: torch.Tensor,
582
+ v: torch.Tensor,
583
+ g: torch.Tensor,
584
+ gk: torch.Tensor,
585
+ gv: torch.Tensor,
586
+ do: torch.Tensor,
587
+ h0: torch.Tensor,
588
+ dht: torch.Tensor,
589
+ scale: float,
590
+ offsets: Optional[torch.Tensor] = None,
591
+ split_offsets: Optional[torch.Tensor] = None,
592
+ split_indices: Optional[torch.Tensor] = None,
593
+ head_first: bool = True,
594
+ chunk_size: int = 64,
595
+ split_size: int = 256,
596
+ states_in_fp32: bool = True
597
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
598
+ if head_first:
599
+ B, H, T, K, V = *k.shape, v.shape[-1]
600
+ HQ = q.shape[1]
601
+ else:
602
+ B, T, H, K, V = *k.shape, v.shape[-1]
603
+ HQ = q.shape[2]
604
+ # B: batch size
605
+ # N: the actual number of sequences in the batch
606
+ # H: number of heads
607
+ # T: sequence length, can be variable across sequences
608
+ # S: split size, a multiple of chunk size
609
+ # BT: chunk size
610
+ S, BT = max(chunk_size, min(split_size, triton.next_power_of_2(T))), chunk_size
611
+ assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` {BT}"
612
+ if offsets is None:
613
+ N = B
614
+ NS = N * triton.cdiv(T, S)
615
+ else:
616
+ N = len(offsets) - 1
617
+ NS = split_offsets[-1]
618
+ # number of groups in GQA
619
+ NG = HQ // H
620
+
621
+ dhs = q.new_empty(NS, HQ, K, V, dtype=torch.float)
622
+ dhr = q.new_empty(NS, HQ, K, V, dtype=torch.float if states_in_fp32 else k.dtype)
623
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
624
+
625
+ # parallelized over splits
626
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * HQ)
627
+ chunk_bwd_kernel_dh_split[grid](
628
+ q=q,
629
+ g=g,
630
+ gk=gk,
631
+ gv=gv,
632
+ do=do,
633
+ dht=dht,
634
+ dhs=dhs,
635
+ dhr=dhr,
636
+ dh0=dh0,
637
+ offsets=offsets,
638
+ split_indices=split_indices,
639
+ scale=scale,
640
+ T=T,
641
+ S=S,
642
+ HQ=HQ,
643
+ H=H,
644
+ K=K,
645
+ V=V,
646
+ BT=BT,
647
+ NG=NG,
648
+ USE_G=g is not None,
649
+ USE_GK=gk is not None,
650
+ USE_GV=gv is not None,
651
+ HEAD_FIRST=head_first,
652
+ )
653
+
654
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
655
+ chunk_bwd_kernel_dh_reduction[grid](
656
+ g=g,
657
+ gk=gk,
658
+ gv=gv,
659
+ dhs=dhs,
660
+ dhr=dhr,
661
+ dh0=dh0,
662
+ offsets=offsets,
663
+ split_offsets=split_offsets,
664
+ T=T,
665
+ S=S,
666
+ HQ=HQ,
667
+ H=H,
668
+ K=K,
669
+ V=V,
670
+ BT=BT,
671
+ NG=NG,
672
+ USE_G=g is not None,
673
+ USE_GK=gk is not None,
674
+ USE_GV=gv is not None,
675
+ HEAD_FIRST=head_first
676
+ )
677
+ return dhr, dh0
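How the effective split size `S` is picked in `chunk_bwd_dh` above, shown for a couple of illustrative sequence lengths (the defaults `chunk_size=64`, `split_size=256` are taken from the signature):

```python
import triton

def effective_split(T, chunk_size=64, split_size=256):
    return max(chunk_size, min(split_size, triton.next_power_of_2(T)))

print(effective_split(100))   # 128 -> a single split covers the whole sequence
print(effective_split(2048))  # 256 -> long sequences are processed in 256-token splits
```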
fla/ops/common/chunk_o.py ADDED
@@ -0,0 +1,668 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, safe_exp
11
+ from fla.utils import check_shared_mem, is_nvidia_hopper
12
+
13
+ BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_G': lambda args: args['g'] is not None,
19
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
20
+ })
21
+ @triton.autotune(
22
+ configs=[
23
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
24
+ for BK in BKV_LIST
25
+ for BV in BKV_LIST
26
+ for num_warps in NUM_WARPS
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['H', 'K', 'V', 'BT'],
30
+ )
31
+ @triton.jit(do_not_specialize=['T'])
32
+ def chunk_fwd_kernel_o(
33
+ q,
34
+ k,
35
+ v,
36
+ h,
37
+ g,
38
+ o,
39
+ offsets,
40
+ indices,
41
+ scale,
42
+ T,
43
+ H: tl.constexpr,
44
+ K: tl.constexpr,
45
+ V: tl.constexpr,
46
+ BT: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ BV: tl.constexpr,
49
+ USE_G: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr,
51
+ HEAD_FIRST: tl.constexpr
52
+ ):
53
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
54
+ i_b, i_h = i_bh // H, i_bh % H
55
+
56
+ if USE_OFFSETS:
57
+ i_tg = i_t
58
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
59
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
60
+ T = eos - bos
61
+ NT = tl.cdiv(T, BT)
62
+ else:
63
+ NT = tl.cdiv(T, BT)
64
+ i_tg = i_b * NT + i_t
65
+ bos, eos = i_b * T, i_b * T + T
66
+
67
+ s_qk = K if HEAD_FIRST else H*K
68
+ s_vo = V if HEAD_FIRST else H*V
69
+ s_g = 1 if HEAD_FIRST else H
70
+ # offset calculation
71
+ q += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
72
+ k += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
73
+ v += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V)
74
+ o += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V)
75
+ h += ((i_bh * NT + i_t).to(tl.int64) * K*V) if HEAD_FIRST else ((i_tg * H + i_h).to(tl.int64) * K*V)
76
+
77
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
78
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
79
+
80
+ for i_k in range(tl.cdiv(K, BK)):
81
+ p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
82
+ p_k = tl.make_block_ptr(k, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
83
+ p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
84
+ # [BT, BK]
85
+ b_q = tl.load(p_q, boundary_check=(0, 1))
86
+ # [BK, BT]
87
+ b_k = tl.load(p_k, boundary_check=(0, 1))
88
+ # [BK, BV]
89
+ b_h = tl.load(p_h, boundary_check=(0, 1))
90
+
91
+ # [BT, BK] @ [BK, BV] -> [BT, BV]
92
+ b_o += tl.dot(b_q, b_h)
93
+ # [BT, BK] @ [BK, BT] -> [BT, BT]
94
+ b_A += tl.dot(b_q, b_k)
95
+
96
+ if USE_G:
97
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
98
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
99
+ b_g = tl.load(p_g, boundary_check=(0,))
100
+ b_o = b_o * exp(b_g)[:, None]
101
+ b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
102
+
103
+ o_i = tl.arange(0, BT)
104
+ m_A = o_i[:, None] >= o_i[None, :]
105
+ b_A = tl.where(m_A, b_A, 0)
106
+
107
+ p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
108
+ p_o = tl.make_block_ptr(o, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
109
+ b_v = tl.load(p_v, boundary_check=(0, 1))
110
+
111
+ # to fix mma -> mma layout conversion
112
+     # work around an mma -> mma layout conversion issue
112
+     # (no longer needed with Triton v3.2 or later)
114
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
115
+
116
+
117
+ @triton.heuristics({
118
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
119
+ 'USE_G': lambda args: args['g'] is not None,
120
+ 'USE_DW': lambda args: args['dw'] is not None
121
+ })
122
+ @triton.autotune(
123
+ configs=[
124
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
125
+ for num_warps in NUM_WARPS
126
+ for num_stages in [2, 3, 4]
127
+ ],
128
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_DW'],
129
+ )
130
+ @triton.jit(do_not_specialize=['T'])
131
+ def chunk_bwd_kernel_dqkwg(
132
+ q,
133
+ k,
134
+ v,
135
+ h,
136
+ g,
137
+ do,
138
+ dh,
139
+ dq,
140
+ dk,
141
+ dg,
142
+ w,
143
+ dv,
144
+ dw,
145
+ offsets,
146
+ indices,
147
+ scale,
148
+ B: tl.constexpr,
149
+ T,
150
+ H: tl.constexpr,
151
+ K: tl.constexpr,
152
+ V: tl.constexpr,
153
+ BT: tl.constexpr,
154
+ BK: tl.constexpr,
155
+ BV: tl.constexpr,
156
+ USE_G: tl.constexpr,
157
+ USE_DW: tl.constexpr,
158
+ USE_OFFSETS: tl.constexpr,
159
+ HEAD_FIRST: tl.constexpr
160
+ ):
161
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
162
+ i_b, i_h = i_bh // H, i_bh % H
163
+ if USE_G:
164
+ dg += i_k * B * H * T
165
+ if USE_OFFSETS:
166
+ i_tg = i_t
167
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
168
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
169
+ T = eos - bos
170
+ NT = tl.cdiv(T, BT)
171
+ else:
172
+ NT = tl.cdiv(T, BT)
173
+ i_tg = i_b * NT + i_t
174
+ bos, eos = i_b * T, i_b * T + T
175
+
176
+ # offset calculation
177
+ v += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
178
+ do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
179
+ h += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
180
+ dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
181
+ q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
182
+ k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
183
+ dq += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
184
+ dk += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
185
+ s_qk = K if HEAD_FIRST else H*K
186
+ s_vo = V if HEAD_FIRST else H*V
187
+ s_g = 1 if HEAD_FIRST else H
188
+
189
+ # for delta rule only
190
+ if USE_DW:
191
+ dw += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
192
+ dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
193
+ w += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
194
+
195
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
196
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
197
+ b_ds = tl.zeros([BT, BT], dtype=tl.float32)
198
+ b_dg_last = tl.zeros([1,], dtype=tl.float32) if USE_G else None
199
+ b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None
200
+
201
+ for i_v in range(tl.cdiv(V, BV)):
202
+ p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
203
+ p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
204
+ p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
205
+ p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
206
+ # [BT, BV]
207
+ b_v = tl.load(p_v, boundary_check=(0, 1))
208
+ b_do = tl.load(p_do, boundary_check=(0, 1))
209
+ # [BV, BK]
210
+ b_h = tl.load(p_h, boundary_check=(0, 1))
211
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
212
+ if USE_G:
213
+ b_dg_last += (tl.sum(b_h * b_dh))
214
+ # [BT, BV] @ [BV, BT] -> [BT, BT]
215
+ b_ds += tl.dot(b_do, tl.trans(b_v))
216
+ # [BT, BV] @ [BV, BK] -> [BT, BK]
217
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
218
+ # [BT, BV] @ [BV, BK] -> [BT, BK]
219
+ b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
220
+ if USE_DW:
221
+ p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
222
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
223
+ b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))
224
+
225
+ if USE_DW and not USE_G:
226
+ p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
227
+ tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
228
+
229
+ tl.debug_barrier()
230
+ o_i = tl.arange(0, BT)
231
+ p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
232
+ p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
233
+ b_q = tl.load(p_q, boundary_check=(0, 1))
234
+ b_k = tl.load(p_k, boundary_check=(0, 1))
235
+
236
+ p_dq = tl.make_block_ptr(dq, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
237
+ p_dk = tl.make_block_ptr(dk, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
238
+
239
+ if USE_G:
240
+ b_dg = tl.zeros([BT,], dtype=tl.float32)
241
+ g += i_bh * T if HEAD_FIRST else bos * H + i_h
242
+ dg += i_bh * T if HEAD_FIRST else bos * H + i_h
243
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
244
+ b_g = tl.load(p_g, boundary_check=(0,))
245
+ b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g)
246
+ b_dg_last *= exp(b_g_last)
247
+
248
+ if USE_DW:
249
+ p_w = tl.make_block_ptr(w, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
250
+ p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
251
+ b_w = tl.load(p_w, boundary_check=(0, 1))
252
+ b_dw = b_dw * exp(b_g)[:, None]
253
+ tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
254
+ b_dg -= tl.sum(b_w * b_dw, axis=1)
255
+
256
+ b_dq = b_dq * exp(b_g)[:, None] * scale
257
+ b_dg += tl.sum(b_dq * b_q, axis=1)
258
+
259
+ b_dk = b_dk * safe_exp(-b_g + b_g_last)[:, None]
260
+ b_dg -= tl.sum(b_k * b_dk, axis=1)
261
+ b_dg_last += tl.sum(b_dk * b_k)
262
+
263
+ b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * safe_exp(b_g[:, None] - b_g[None, :]), 0) * scale
264
+ b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k))
265
+ b_dg += tl.sum(b_ds2, axis=1)
266
+ b_dg -= tl.sum(b_ds2, axis=0)
267
+
268
+ b_ds = b_ds.to(b_k.dtype)
269
+ # [BT, BK]
270
+ b_dq += tl.dot(b_ds, b_k)
271
+ b_dk += tl.dot(tl.trans(b_ds), b_q)
272
+ p_dg = tl.make_block_ptr(dg, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
273
+ # (SY 09/21) revcumsum in a separate kernel due to strange triton compiler issue
274
+ # b_dg = tl.dot(tl.where(o_i[:, None] <= o_i[None, :], 1., 0.), b_dg, allow_tf32=False) + b_dg_last)
275
+ b_dg = tl.where(o_i < min(BT, T-i_t*BT) - 1, b_dg, b_dg + b_dg_last)
276
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
277
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
278
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
279
+ else:
280
+ b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0)
281
+ b_ds = b_ds.to(b_k.dtype)
282
+ b_dq += tl.dot(b_ds, b_k)
283
+ b_dk += tl.dot(tl.trans(b_ds), b_q) * scale
284
+ b_dq *= scale
285
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
286
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
287
+
288
+
289
+ @triton.heuristics({
290
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
291
+ 'USE_G': lambda args: args['g'] is not None,
292
+ })
293
+ @triton.autotune(
294
+ configs=[
295
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
296
+ for num_warps in [2, 4, 8]
297
+ for num_stages in [2, 3, 4]
298
+ ],
299
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
300
+ )
301
+ @triton.jit(do_not_specialize=['T'])
302
+ def chunk_bwd_kernel_dv(
303
+ q,
304
+ k,
305
+ g,
306
+ do,
307
+ dv,
308
+ dh,
309
+ offsets,
310
+ indices,
311
+ scale,
312
+ T,
313
+ H: tl.constexpr,
314
+ K: tl.constexpr,
315
+ V: tl.constexpr,
316
+ BT: tl.constexpr,
317
+ BK: tl.constexpr,
318
+ BV: tl.constexpr,
319
+ USE_G: tl.constexpr,
320
+ USE_OFFSETS: tl.constexpr,
321
+ HEAD_FIRST: tl.constexpr
322
+ ):
323
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
324
+ i_b, i_h = i_bh // H, i_bh % H
325
+ if USE_OFFSETS:
326
+ i_tg = i_t
327
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
328
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
329
+ T = eos - bos
330
+ NT = tl.cdiv(T, BT)
331
+ else:
332
+ NT = tl.cdiv(T, BT)
333
+ i_tg = i_b * NT + i_t
334
+ bos, eos = i_b * T, i_b * T + T
335
+
336
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
337
+
338
+ # offset calculation
339
+ q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
340
+ k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
341
+ do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
342
+ dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
343
+ s_qk = K if HEAD_FIRST else H*K
344
+ s_vo = V if HEAD_FIRST else H*V
345
+ s_g = 1 if HEAD_FIRST else H
346
+ dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
347
+
348
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
349
+ for i_k in range(tl.cdiv(K, BK)):
350
+ p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
351
+ p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
352
+ b_q = tl.load(p_q, boundary_check=(0, 1))
353
+ b_k = tl.load(p_k, boundary_check=(0, 1))
354
+ b_A += tl.dot(b_k, b_q)
355
+ p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
356
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
357
+ b_dv += tl.dot(b_k, b_dh.to(b_k.dtype))
358
+
359
+ if USE_G:
360
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
361
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
362
+ b_g = tl.load(p_g, boundary_check=(0,))
363
+ b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g)
364
+ b_dv *= safe_exp(-b_g + b_g_last)[:, None]
365
+
366
+ mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
367
+ if USE_G:
368
+ b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
369
+ else:
370
+ b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty)
371
+ p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
372
+ p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
373
+ b_do = tl.load(p_do, boundary_check=(0, 1))
374
+ b_dv += tl.dot(b_A.to(b_do.dtype), b_do)
375
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
376
+
377
+
378
+ @triton.heuristics({
379
+ 'USE_G': lambda args: args['g'] is not None,
380
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
381
+ })
382
+ @triton.autotune(
383
+ configs=[
384
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
385
+ for num_warps in NUM_WARPS
386
+ for num_stages in [2, 3, 4]
387
+ ],
388
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
389
+ )
390
+ @triton.jit(do_not_specialize=['T'])
391
+ def chunk_bwd_kernel_dv_local(
392
+ q,
393
+ k,
394
+ g,
395
+ do,
396
+ dv,
397
+ offsets,
398
+ indices,
399
+ scale,
400
+ T,
401
+ H: tl.constexpr,
402
+ K: tl.constexpr,
403
+ V: tl.constexpr,
404
+ BT: tl.constexpr,
405
+ BK: tl.constexpr,
406
+ BV: tl.constexpr,
407
+ USE_G: tl.constexpr,
408
+ USE_OFFSETS: tl.constexpr,
409
+ HEAD_FIRST: tl.constexpr
410
+ ):
411
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
412
+ i_b, i_h = i_bh // H, i_bh % H
413
+ if USE_OFFSETS:
414
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
415
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
416
+ T = eos - bos
417
+ else:
418
+ bos, eos = i_b * T, i_b * T + T
419
+
420
+ # offset calculation
421
+ q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
422
+ k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
423
+ do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
424
+ dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
425
+ s_qk = K if HEAD_FIRST else H*K
426
+ s_vo = V if HEAD_FIRST else H*V
427
+ s_g = 1 if HEAD_FIRST else H
428
+
429
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
430
+ for i_k in range(tl.cdiv(K, BK)):
431
+ p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
432
+ p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
433
+ b_q = tl.load(p_q, boundary_check=(0, 1))
434
+ b_k = tl.load(p_k, boundary_check=(0, 1))
435
+ b_A += tl.dot(b_k, b_q)
436
+
437
+ if USE_G:
438
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
439
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
440
+ b_g = tl.load(p_g, boundary_check=(0,))
441
+
442
+ mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
443
+ if USE_G:
444
+ b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
445
+ else:
446
+ b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty)
447
+
448
+ for i_v in range(tl.cdiv(V, BV)):
449
+ p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
450
+ p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
451
+ b_do = tl.load(p_do, boundary_check=(0, 1))
452
+ b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
453
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
454
+
455
+
456
+ def chunk_fwd_o(
457
+ q: torch.Tensor,
458
+ k: torch.Tensor,
459
+ v: torch.Tensor,
460
+ h: torch.Tensor,
461
+ g: Optional[torch.Tensor] = None, # cumsum of log decay
462
+ scale: Optional[float] = None,
463
+ offsets: Optional[torch.LongTensor] = None,
464
+ indices: Optional[torch.LongTensor] = None,
465
+ head_first: bool = True,
466
+ chunk_size: int = 64
467
+ ) -> torch.Tensor:
468
+ if head_first:
469
+ B, H, T, K, V = *q.shape, v.shape[-1]
470
+ else:
471
+ B, T, H, K, V = *q.shape, v.shape[-1]
472
+ if scale is None:
473
+ scale = k.shape[-1] ** -0.5
474
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
475
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
476
+
477
+ o = torch.empty_like(v)
478
+
479
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
480
+ chunk_fwd_kernel_o[grid](
481
+ q,
482
+ k,
483
+ v,
484
+ h,
485
+ g,
486
+ o,
487
+ offsets,
488
+ indices,
489
+ scale,
490
+ T=T,
491
+ H=H,
492
+ K=K,
493
+ V=V,
494
+ BT=BT,
495
+ HEAD_FIRST=head_first
496
+ )
497
+ return o
498
+
499
+
500
+ def chunk_bwd_dv(
501
+ q: torch.Tensor,
502
+ k: torch.Tensor,
503
+ g: torch.Tensor,
504
+ do: torch.Tensor,
505
+ dh: torch.Tensor,
506
+ scale: float,
507
+ offsets: Optional[torch.LongTensor] = None,
508
+ indices: Optional[torch.LongTensor] = None,
509
+ head_first: bool = True,
510
+ chunk_size: int = 64
511
+ ) -> torch.Tensor:
512
+ if head_first:
513
+ B, H, T, K, V = *k.shape, do.shape[-1]
514
+ else:
515
+ B, T, H, K, V = *k.shape, do.shape[-1]
516
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
517
+ # H100 can have larger block size
518
+ if check_shared_mem('hopper', k.device.index):
519
+ CONST_TILING = 128
520
+ elif check_shared_mem():
521
+ CONST_TILING = 64
522
+ else:
523
+ CONST_TILING = 32
524
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
525
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
526
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
527
+ NV = triton.cdiv(V, BV)
528
+
529
+ dv = torch.empty_like(do)
530
+ grid = (NV, NT, B * H)
531
+ chunk_bwd_kernel_dv[grid](
532
+ q,
533
+ k,
534
+ g,
535
+ do,
536
+ dv,
537
+ dh,
538
+ offsets,
539
+ indices,
540
+ scale,
541
+ T=T,
542
+ H=H,
543
+ K=K,
544
+ V=V,
545
+ BT=BT,
546
+ BK=BK,
547
+ BV=BV,
548
+ HEAD_FIRST=head_first
549
+ )
550
+ return dv
551
+
552
+
553
+ def chunk_bwd_dv_local(
554
+ q: torch.Tensor,
555
+ k: torch.Tensor,
556
+ g: torch.Tensor,
557
+ do: torch.Tensor,
558
+ dh: torch.Tensor,
559
+ scale: float,
560
+ offsets: Optional[torch.LongTensor] = None,
561
+ indices: Optional[torch.LongTensor] = None,
562
+ head_first: bool = True,
563
+ chunk_size: int = 64
564
+ ) -> torch.Tensor:
565
+ if head_first:
566
+ B, H, T, K, V = *k.shape, do.shape[-1]
567
+ else:
568
+ B, T, H, K, V = *k.shape, do.shape[-1]
569
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
570
+ # H100 can have larger block size
571
+ if check_shared_mem('hopper', k.device.index):
572
+ CONST_TILING = 128
573
+ elif check_shared_mem():
574
+ CONST_TILING = 64
575
+ else:
576
+ CONST_TILING = 32
577
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
578
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
579
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
580
+
581
+ dv = torch.empty_like(do)
582
+ grid = (NT, B * H)
583
+ chunk_bwd_kernel_dv_local[grid](
584
+ q,
585
+ k,
586
+ g,
587
+ do,
588
+ dv,
589
+ offsets,
590
+ indices,
591
+ scale,
592
+ T=T,
593
+ H=H,
594
+ K=K,
595
+ V=V,
596
+ BT=BT,
597
+ BK=BK,
598
+ BV=BV,
599
+ HEAD_FIRST=head_first
600
+ )
601
+ return dv
602
+
603
+
604
+ def chunk_bwd_dqkwg(
605
+ q: torch.Tensor,
606
+ k: torch.Tensor,
607
+ v: torch.Tensor,
608
+ g: torch.Tensor,
609
+ do: torch.Tensor,
610
+ h: torch.Tensor,
611
+ dh: torch.Tensor,
612
+ dv: Optional[torch.Tensor] = None,
613
+ w: Optional[torch.Tensor] = None,
614
+ offsets: Optional[torch.LongTensor] = None,
615
+ indices: Optional[torch.LongTensor] = None,
616
+ chunk_size: int = 64,
617
+ scale: float = 1.0,
618
+ head_first: bool = True,
619
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
620
+
621
+ if head_first:
622
+ B, H, T, K, V = *k.shape, v.shape[-1]
623
+ else:
624
+ B, T, H, K, V = *k.shape, v.shape[-1]
625
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
626
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
627
+
628
+ CONST_TILING = 64 if check_shared_mem() else 32
629
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
630
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
631
+ NK = triton.cdiv(K, BK)
632
+ dq = torch.empty_like(q)
633
+ dk = torch.empty_like(k)
634
+ dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
635
+ dw = torch.empty_like(w) if w is not None else None
636
+
637
+ grid = (NK, NT, B * H)
638
+ chunk_bwd_kernel_dqkwg[grid](
639
+ q=q,
640
+ k=k,
641
+ v=v,
642
+ h=h,
643
+ g=g,
644
+ do=do,
645
+ dh=dh,
646
+ dv=dv,
647
+ w=w,
648
+ dw=dw,
649
+ dq=dq,
650
+ dk=dk,
651
+ dg=dg,
652
+ offsets=offsets,
653
+ indices=indices,
654
+ scale=scale,
655
+ B=B,
656
+ T=T,
657
+ H=H,
658
+ K=K,
659
+ V=V,
660
+ BT=BT,
661
+ BK=BK,
662
+ BV=BV,
663
+ HEAD_FIRST=head_first
664
+ )
665
+
666
+ if dg is not None:
667
+ dg = dg.sum(0)
668
+ return dq, dk, dw, dg
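Reviewer note: the intra-chunk path of `chunk_bwd_kernel_dv_local` above is easy to cross-check against plain PyTorch. Within a chunk it computes `dv[i] = scale * sum_{j >= i} (k_i · q_j) * exp(g_j - g_i) * do_j`. The sketch below is a minimal single-chunk reference under simplifying assumptions (one head, no variable-length offsets, plain `exp` in place of `safe_exp`); the helper name `chunk_dv_local_ref` is illustrative and not part of the library.

```python
import torch

def chunk_dv_local_ref(q, k, g, do, scale):
    """Naive single-chunk reference for the intra-chunk dv computation.

    q, k: [T, K]; g: [T] cumulative log-decay (or None); do: [T, V].
    dv[i] = scale * sum_{j >= i} (k_i . q_j) * exp(g_j - g_i) * do_j
    """
    T = q.shape[0]
    A = k @ q.transpose(-1, -2)                      # A[i, j] = k_i . q_j
    if g is not None:
        A = A * torch.exp(g[None, :] - g[:, None])   # decay between steps i and j
    A = A * scale
    idx = torch.arange(T, device=q.device)
    A = torch.where(idx[:, None] <= idx[None, :], A, torch.zeros_like(A))
    return A @ do                                    # [T, V]
```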
fla/ops/common/chunk_scaled_dot_kkt.py ADDED
@@ -0,0 +1,126 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_indices
11
+
12
+
13
+ @triton.heuristics({
14
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
15
+ })
16
+ @triton.autotune(
17
+ configs=[
18
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
19
+ for BK in [32, 64, 128]
20
+ for num_warps in [2, 4, 8]
21
+ for num_stages in [2, 3, 4]
22
+ ],
23
+ key=['H', 'K', 'BT', 'USE_OFFSETS'],
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def chunk_scaled_dot_kkt_fwd_kernel(
27
+ k,
28
+ beta,
29
+ A,
30
+ offsets,
31
+ indices,
32
+ T,
33
+ H: tl.constexpr,
34
+ K: tl.constexpr,
35
+ BT: tl.constexpr,
36
+ BK: tl.constexpr,
37
+ HEAD_FIRST: tl.constexpr,
38
+ USE_OFFSETS: tl.constexpr,
39
+ ):
40
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
41
+ i_b, i_h = i_bh // H, i_bh % H
42
+ if USE_OFFSETS:
43
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
44
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
45
+ T = eos - bos
46
+ else:
47
+ bos, eos = i_b * T, i_b * T + T
48
+ o_t = tl.arange(0, BT)
49
+
50
+ if HEAD_FIRST:
51
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
52
+ else:
53
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
54
+ b_beta = tl.load(p_beta, boundary_check=(0,))
55
+
56
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
57
+ for i_k in range(tl.cdiv(K, BK)):
58
+ if HEAD_FIRST:
59
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
60
+ else:
61
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
62
+ b_k = tl.load(p_k, boundary_check=(0, 1))
63
+ b_kb = b_k * b_beta[:, None]
64
+ b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
65
+
66
+ b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
67
+ if HEAD_FIRST:
68
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
69
+ else:
70
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
71
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
72
+
73
+
74
+ def chunk_scaled_dot_kkt_fwd(
75
+ k: torch.Tensor,
76
+ beta: torch.Tensor,
77
+ cu_seqlens: Optional[torch.LongTensor],
78
+ head_first: bool = False,
79
+ chunk_size: int = 64,
80
+ output_dtype: torch.dtype = torch.float32
81
+ ) -> torch.Tensor:
82
+ r"""
83
+ Compute the strictly lower-triangular part of beta * K * K^T within each chunk.
84
+
85
+ Args:
86
+ k (torch.Tensor):
87
+ The key tensor of shape `[B, T, H, K]` if not `head_first` else `[B, H, T, K]`.
88
+ beta (torch.Tensor):
89
+ The beta tensor of shape `[B, T, H]` if not `head_first` else `[B, H, T]`.
90
+ cu_seqlens (torch.LongTensor):
91
+ The cumulative sequence lengths of the input tensor.
92
+ Default: None
93
+ head_first (bool):
94
+ If False, the input/output tensor is in the shape of `[B, T, H, K]`.
95
+ If True, the input/output tensor is in the shape of `[B, H, T, K]`.
96
+ Default: False
97
+ chunk_size (int):
98
+ The chunk size. Default: 64.
99
+ output_dtype (torch.dtype):
100
+ The dtype of the output tensor. Default: `torch.float32`
101
+
102
+ Returns:
103
+ The strictly lower-triangular part of beta * K * K^T per chunk, of shape `[B, T, H, BT]` if not `head_first` else `[B, H, T, BT]`,
104
+ where `BT` is the chunk size.
105
+ """
106
+ if head_first:
107
+ B, H, T, K = k.shape
108
+ else:
109
+ B, T, H, K = k.shape
110
+ BT = chunk_size
111
+ indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
112
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices)
113
+ A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=output_dtype)
114
+ chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
115
+ k=k,
116
+ beta=beta,
117
+ A=A,
118
+ offsets=cu_seqlens,
119
+ indices=indices,
120
+ T=T,
121
+ H=H,
122
+ K=K,
123
+ BT=BT,
124
+ HEAD_FIRST=head_first
125
+ )
126
+ return A
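For a quick sanity check, the result of `chunk_scaled_dot_kkt_fwd` can be reproduced in a few lines of eager PyTorch: per chunk, it is the strictly lower-triangular part of `(beta * K) @ K^T`. The sketch below covers the non-`head_first`, fixed-length case and assumes `T` is divisible by `chunk_size`; the helper name `scaled_dot_kkt_ref` is illustrative only.

```python
import torch

def scaled_dot_kkt_ref(k, beta, chunk_size=64):
    """Naive reference: strictly lower-triangular beta * K @ K^T per chunk.

    k: [B, T, H, K], beta: [B, T, H]; T is assumed divisible by chunk_size.
    Returns A of shape [B, T, H, chunk_size].
    """
    B, T, H, K = k.shape
    BT = chunk_size
    k_c = k.view(B, T // BT, BT, H, K)
    beta_c = beta.view(B, T // BT, BT, H)
    # A[..., i, j] = beta_i * (k_i . k_j) within each chunk
    A = torch.einsum('bnihk,bnjhk->bnhij', k_c * beta_c[..., None], k_c)
    strict_lower = torch.tril(torch.ones(BT, BT, dtype=torch.bool, device=k.device), -1)
    A = A.masked_fill(~strict_lower, 0)
    return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, BT)
```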
fla/ops/common/fused_recurrent.py ADDED
@@ -0,0 +1,575 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils import chunk_global_cumsum
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps)
23
+ for num_warps in [1, 2, 4]
24
+ ],
25
+ key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"],
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def fused_recurrent_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ g,
33
+ gk,
34
+ gv,
35
+ o,
36
+ h0,
37
+ ht,
38
+ offsets,
39
+ scale,
40
+ T,
41
+ B: tl.constexpr,
42
+ H: tl.constexpr,
43
+ K: tl.constexpr,
44
+ V: tl.constexpr,
45
+ BK: tl.constexpr,
46
+ BV: tl.constexpr,
47
+ REVERSE: tl.constexpr,
48
+ USE_G: tl.constexpr,
49
+ USE_GK: tl.constexpr,
50
+ USE_GV: tl.constexpr,
51
+ USE_INITIAL_STATE: tl.constexpr,
52
+ STORE_FINAL_STATE: tl.constexpr,
53
+ USE_OFFSETS: tl.constexpr,
54
+ HEAD_FIRST: tl.constexpr
55
+ ):
56
+ # indices
57
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
61
+ all = T
62
+ T = eos - bos
63
+ else:
64
+ bos, eos = i_n * T, i_n * T + T
65
+ all = B * T
66
+
67
+ if HEAD_FIRST:
68
+ p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
69
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
70
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
71
+ p_o = o + (i_k * B*H + i_nh) * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
72
+ if USE_G:
73
+ p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
74
+ if USE_GK:
75
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
76
+ if USE_GV:
77
+ p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
78
+ else:
79
+ p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
80
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
81
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
82
+ p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
83
+ if USE_G:
84
+ p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
85
+ if USE_GK:
86
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
87
+ if USE_GV:
88
+ p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
89
+
90
+ mask_k = (i_k * BK + tl.arange(0, BK)) < K
91
+ mask_v = (i_v * BV + tl.arange(0, BV)) < V
92
+ mask_h = mask_k[None, :] & mask_v[:, None]
93
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
94
+
95
+ if USE_INITIAL_STATE:
96
+ p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
97
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
98
+
99
+ for _ in range(0, T):
100
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
101
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
102
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
103
+ if USE_GK:
104
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
105
+ b_h = b_h * exp(b_gk[None, :])
106
+ if USE_GV:
107
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
108
+ b_h = b_h * exp(b_gv[:, None])
109
+ if USE_G:
110
+ b_g = tl.load(p_g).to(tl.float32)
111
+ b_h = b_h * exp(b_g)
112
+ b_h += b_k[None, :] * b_v[:, None]
113
+ b_o = b_h * b_q[None, :]
114
+ b_o = tl.sum(b_o, axis=1)
115
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
116
+ p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
117
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
118
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
119
+ p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
120
+ if USE_GK:
121
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
122
+ if USE_GV:
123
+ p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
124
+ if USE_G:
125
+ p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
126
+
127
+ if STORE_FINAL_STATE:
128
+ p_ht = ht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
129
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
130
+
131
+
132
+ @triton.heuristics({
133
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
134
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
135
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
136
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
137
+ })
138
+ @triton.autotune(
139
+ configs=[
140
+ triton.Config({}, num_warps=num_warps)
141
+ for num_warps in [1, 2, 4]
142
+ ],
143
+ key=['BK', 'BV', 'USE_GK', 'USE_GV', 'USE_G'],
144
+ )
145
+ @triton.jit(do_not_specialize=['T'])
146
+ def fused_recurrent_bwd_kernel(
147
+ q,
148
+ k,
149
+ v,
150
+ g,
151
+ gk,
152
+ gv,
153
+ h0,
154
+ do,
155
+ dq,
156
+ dk,
157
+ dv,
158
+ dht,
159
+ dh0,
160
+ offsets,
161
+ scale,
162
+ T,
163
+ B: tl.constexpr,
164
+ H: tl.constexpr,
165
+ K: tl.constexpr,
166
+ V: tl.constexpr,
167
+ BK: tl.constexpr,
168
+ BV: tl.constexpr,
169
+ REVERSE: tl.constexpr,
170
+ USE_G: tl.constexpr,
171
+ USE_GK: tl.constexpr,
172
+ USE_GV: tl.constexpr,
173
+ USE_INITIAL_STATE: tl.constexpr,
174
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
175
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
176
+ USE_OFFSETS: tl.constexpr,
177
+ HEAD_FIRST: tl.constexpr
178
+ ):
179
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
180
+ i_n, i_h = i_nh // H, i_nh % H
181
+ if USE_OFFSETS:
182
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
183
+ all = T
184
+ T = eos - bos
185
+ else:
186
+ bos, eos = i_n * T, i_n * T + T
187
+ all = B * T
188
+
189
+ if HEAD_FIRST:
190
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
191
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
192
+ p_do = do + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
193
+ p_dq = dq + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
194
+ if USE_G:
195
+ p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
196
+ if USE_GK:
197
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
198
+ if USE_GV:
199
+ p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
200
+ else:
201
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
202
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
203
+ p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
204
+ p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
205
+ if USE_G:
206
+ p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
207
+ if USE_GK:
208
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
209
+ if USE_GV:
210
+ p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
211
+
212
+ mask_k = i_k * BK + tl.arange(0, BK) < K
213
+ mask_v = i_v * BV + tl.arange(0, BV) < V
214
+ mask_h = mask_k[:, None] & mask_v[None, :]
215
+
216
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
217
+ if USE_INITIAL_STATE:
218
+ p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
219
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
220
+
221
+ for _ in range(0, T):
222
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
223
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
224
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
225
+ if USE_G:
226
+ b_g = tl.load(p_g).to(tl.float32)
227
+ b_h = b_h * exp(b_g)
228
+ if USE_GK:
229
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
230
+ b_h = b_h * exp(b_gk[:, None])
231
+ if USE_GV:
232
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
233
+ b_h = b_h * exp(b_gv[None, :])
234
+ b_h += b_k[:, None] * b_v[None, :]
235
+ b_dq = b_h * b_do[None, :]
236
+ b_dq = tl.sum(b_dq, axis=1) * scale
237
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k)
238
+
239
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
240
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
241
+ p_do += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
242
+ p_dq += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
243
+ if USE_G:
244
+ p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
245
+ if USE_GK:
246
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
247
+ if USE_GV:
248
+ p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
249
+
250
+ # sync threads
251
+ tl.debug_barrier()
252
+
253
+ if HEAD_FIRST:
254
+ p_q = q + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
255
+ p_k = k + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
256
+ p_v = v + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
257
+ p_do = do + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
258
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
259
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
260
+ if USE_G:
261
+ p_g = g + i_nh * T + ((T - 1) if not REVERSE else 0)
262
+ if USE_GK:
263
+ p_gk = gk + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
264
+ if USE_GV:
265
+ p_gv = gv + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
266
+ else:
267
+ p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
268
+ p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
269
+ p_v = v + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
270
+ p_do = do + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
271
+ p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
272
+ p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
273
+ if USE_G:
274
+ p_g = g + (bos + ((T - 1) if not REVERSE else 0)) * H + i_h
275
+ if USE_GK:
276
+ p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
277
+ if USE_GV:
278
+ p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
279
+
280
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
281
+ if USE_FINAL_STATE_GRADIENT:
282
+ p_dht = dht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
283
+ b_dh += tl.load(p_dht, mask=mask_h, other=0).to(tl.float32)
284
+
285
+ for _ in range(T):
286
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
287
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
288
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
289
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
290
+ b_dh += b_q[:, None] * b_do[None, :]
291
+ b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
292
+ b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
293
+ if USE_G:
294
+ b_g = tl.load(p_g).to(tl.float32)
295
+ b_dh *= exp(b_g)
296
+ if USE_GK:
297
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
298
+ b_dh *= exp(b_gk)[:, None]
299
+ if USE_GV:
300
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
301
+ b_dh *= exp(b_gv)[None, :]
302
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
303
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
304
+
305
+ p_q += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
306
+ p_k += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
307
+ p_v += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
308
+ p_do += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
309
+ p_dk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
310
+ p_dv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
311
+ if USE_G:
312
+ p_g += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H)
313
+ if USE_GK:
314
+ p_gk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
315
+ if USE_GV:
316
+ p_gv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
317
+
318
+ if STORE_INITIAL_STATE_GRADIENT:
319
+ p_dh0 = dh0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
320
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h)
321
+
322
+
323
+ def fused_recurrent_fwd(
324
+ q: torch.Tensor,
325
+ k: torch.Tensor,
326
+ v: torch.Tensor,
327
+ g: Optional[torch.Tensor] = None,
328
+ gk: Optional[torch.Tensor] = None,
329
+ gv: Optional[torch.Tensor] = None,
330
+ scale: Optional[float] = None,
331
+ initial_state: Optional[torch.Tensor] = None,
332
+ output_final_state: bool = False,
333
+ reverse: bool = False,
334
+ offsets: Optional[torch.LongTensor] = None,
335
+ head_first: bool = True
336
+ ):
337
+ if head_first:
338
+ B, H, T, K, V = *k.shape, v.shape[-1]
339
+ else:
340
+ B, T, H, K, V = *k.shape, v.shape[-1]
341
+ N = B if offsets is None else len(offsets) - 1
342
+ BK, BV = min(K, 64), min(V, 64)
343
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
344
+
345
+ h0 = initial_state
346
+ if output_final_state:
347
+ ht = q.new_empty(N, H, K, V, dtype=torch.float32)
348
+ else:
349
+ ht = None
350
+ o = q.new_empty(NK, *v.shape, dtype=torch.float32)
351
+
352
+ grid = (NV, NK, N * H)
353
+ fused_recurrent_fwd_kernel[grid](
354
+ q,
355
+ k,
356
+ v,
357
+ g,
358
+ gk,
359
+ gv,
360
+ o,
361
+ h0,
362
+ ht,
363
+ offsets,
364
+ scale,
365
+ T=T,
366
+ B=B,
367
+ H=H,
368
+ K=K,
369
+ V=V,
370
+ BK=BK,
371
+ BV=BV,
372
+ USE_G=g is not None,
373
+ USE_GK=gk is not None,
374
+ USE_GV=gv is not None,
375
+ REVERSE=reverse,
376
+ HEAD_FIRST=head_first
377
+ )
378
+ o = o.sum(0)
379
+ return o, ht
380
+
381
+
382
+ def fused_recurrent_bwd(
383
+ q: torch.Tensor,
384
+ k: torch.Tensor,
385
+ v: torch.Tensor,
386
+ g: Optional[torch.Tensor] = None,
387
+ gk: Optional[torch.Tensor] = None,
388
+ gv: Optional[torch.Tensor] = None,
389
+ o: Optional[torch.Tensor] = None,
390
+ do: Optional[torch.Tensor] = None,
391
+ dht: Optional[torch.Tensor] = None,
392
+ scale: Optional[float] = None,
393
+ initial_state: Optional[torch.Tensor] = None,
394
+ reverse: bool = False,
395
+ offsets: Optional[torch.LongTensor] = None,
396
+ head_first: bool = True
397
+ ):
398
+ if head_first:
399
+ B, H, T, K, V = *k.shape, v.shape[-1]
400
+ else:
401
+ B, T, H, K, V = *k.shape, v.shape[-1]
402
+ N = B if offsets is None else len(offsets) - 1
403
+
404
+ BK, BV = min(K, 64), min(V, 64)
405
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
406
+
407
+ dq = q.new_empty(NV, *q.shape, dtype=torch.float32)
408
+ dk = q.new_empty(NV, *k.shape, dtype=torch.float32)
409
+ dv = q.new_empty(NK, *v.shape, dtype=torch.float32)
410
+ h0 = initial_state
411
+ dh0 = torch.empty_like(initial_state) if initial_state is not None else None
412
+
413
+ grid = (NV, NK, N * H)
414
+ fused_recurrent_bwd_kernel[grid](
415
+ q,
416
+ k,
417
+ v,
418
+ g,
419
+ gk,
420
+ gv,
421
+ h0,
422
+ do,
423
+ dq,
424
+ dk,
425
+ dv,
426
+ dht,
427
+ dh0,
428
+ offsets,
429
+ scale,
430
+ B=B,
431
+ T=T,
432
+ H=H,
433
+ K=K,
434
+ V=V,
435
+ BK=BK,
436
+ BV=BV,
437
+ USE_G=g is not None,
438
+ USE_GK=gk is not None,
439
+ USE_GV=gv is not None,
440
+ REVERSE=reverse,
441
+ HEAD_FIRST=head_first
442
+ )
443
+ dq = dq.sum(0)
444
+ dk = dk.sum(0)
445
+ dv = dv.sum(0)
446
+ dg, dgk, dgv = None, None, None
447
+ if g is not None:
448
+ dg = chunk_global_cumsum(
449
+ (dq * q.float() - dk * k.float()).sum(-1),
450
+ reverse=not reverse,
451
+ offsets=offsets,
452
+ head_first=head_first
453
+ )
454
+ if gk is not None:
455
+ dgk = chunk_global_cumsum(
456
+ dq * q.float() - dk * k.float(),
457
+ reverse=not reverse,
458
+ offsets=offsets,
459
+ head_first=head_first
460
+ )
461
+ if gv is not None:
462
+ dgv = chunk_global_cumsum(
463
+ do.float() * o.float() - dv * v.float(),
464
+ reverse=not reverse,
465
+ offsets=offsets,
466
+ head_first=head_first
467
+ )
468
+
469
+ return dq, dk, dv, dg, dgk, dgv, dh0
470
+
471
+
472
+ class FusedRecurrentFunction(torch.autograd.Function):
473
+
474
+ @staticmethod
475
+ @input_guard
476
+ @autocast_custom_fwd
477
+ def forward(
478
+ ctx,
479
+ q: torch.Tensor,
480
+ k: torch.Tensor,
481
+ v: torch.Tensor,
482
+ g: Optional[torch.Tensor] = None,
483
+ gk: Optional[torch.Tensor] = None,
484
+ gv: Optional[torch.Tensor] = None,
485
+ scale: Optional[float] = None,
486
+ initial_state: Optional[torch.Tensor] = None,
487
+ output_final_state: bool = False,
488
+ reverse: bool = False,
489
+ offsets: Optional[torch.LongTensor] = None,
490
+ head_first: bool = True
491
+ ):
492
+ o, ht = fused_recurrent_fwd(
493
+ q=q,
494
+ k=k,
495
+ v=v,
496
+ g=g,
497
+ gk=gk,
498
+ gv=gv,
499
+ scale=scale,
500
+ initial_state=initial_state,
501
+ output_final_state=output_final_state,
502
+ reverse=reverse,
503
+ offsets=offsets,
504
+ head_first=head_first
505
+ )
506
+ ctx.save_for_backward(q, k, v, g, gk, gv, initial_state, o)
507
+ ctx.scale = scale
508
+ ctx.reverse = reverse
509
+ ctx.offsets = offsets
510
+ ctx.head_first = head_first
511
+ return o.to(q.dtype), ht
512
+
513
+ @staticmethod
514
+ @input_guard
515
+ @autocast_custom_bwd
516
+ def backward(ctx, do, dht):
517
+ q, k, v, g, gk, gv, initial_state, o = ctx.saved_tensors
518
+ # not supported yet.
519
+ if dht is not None:
520
+ if not dht.eq(0).all():
521
+ if g is not None:
522
+ assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
523
+ if gk is not None:
524
+ assert gk.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
525
+ if gv is not None:
526
+ assert gv.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
527
+ dq, dk, dv, dg, dgk, dgv, dh0 = fused_recurrent_bwd(
528
+ q=q,
529
+ k=k,
530
+ v=v,
531
+ g=g,
532
+ gk=gk,
533
+ gv=gv,
534
+ o=o,
535
+ do=do,
536
+ dht=dht,
537
+ scale=ctx.scale,
538
+ initial_state=initial_state,
539
+ reverse=ctx.reverse,
540
+ offsets=ctx.offsets,
541
+ head_first=ctx.head_first
542
+ )
543
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None, None
544
+
545
+
546
+ def fused_recurrent(
547
+ q: torch.Tensor,
548
+ k: torch.Tensor,
549
+ v: torch.Tensor,
550
+ g: Optional[torch.Tensor] = None,
551
+ gk: Optional[torch.Tensor] = None,
552
+ gv: Optional[torch.Tensor] = None,
553
+ scale: Optional[float] = None,
554
+ initial_state: Optional[torch.Tensor] = None,
555
+ output_final_state: bool = False,
556
+ reverse: bool = False,
557
+ cu_seqlens: Optional[torch.LongTensor] = None,
558
+ head_first: bool = True
559
+ ):
560
+ if scale is None:
561
+ scale = k.shape[-1] ** -0.5
562
+ return FusedRecurrentFunction.apply(
563
+ q,
564
+ k,
565
+ v,
566
+ g,
567
+ gk,
568
+ gv,
569
+ scale,
570
+ initial_state,
571
+ output_final_state,
572
+ reverse,
573
+ cu_seqlens,
574
+ head_first
575
+ )
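As a readability-oriented reference, the recurrence implemented by `fused_recurrent_fwd` (forward direction, no `cu_seqlens`, `[B, T, H, ...]` layout) can be written as a plain Python loop: the state is decayed first, the new key/value outer product is added, and the scaled query is contracted against the state. The helper name `fused_recurrent_ref` below is illustrative and not part of the library.

```python
import torch

def fused_recurrent_ref(q, k, v, g=None, gk=None, gv=None,
                        scale=None, initial_state=None):
    """Naive recurrence, layout [B, T, H, ...].

    h_t = exp(g_t) * exp(gk_t)[:, None] * exp(gv_t)[None, :] * h_{t-1} + k_t v_t^T
    o_t = scale * q_t @ h_t
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    if scale is None:
        scale = K ** -0.5
    h = (q.new_zeros(B, H, K, V, dtype=torch.float32) if initial_state is None
         else initial_state.float().clone())
    o = torch.empty(B, T, H, V, dtype=torch.float32, device=q.device)
    for t in range(T):
        if g is not None:   # scalar decay per head
            h = h * torch.exp(g[:, t].float()).view(B, H, 1, 1)
        if gk is not None:  # per-key-dimension decay
            h = h * torch.exp(gk[:, t].float()).view(B, H, K, 1)
        if gv is not None:  # per-value-dimension decay
            h = h * torch.exp(gv[:, t].float()).view(B, H, 1, V)
        h = h + torch.einsum('bhk,bhv->bhkv', k[:, t].float(), v[:, t].float())
        o[:, t] = torch.einsum('bhk,bhkv->bhv', q[:, t].float() * scale, h)
    return o, h
```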
fla/ops/common/utils.py ADDED
@@ -0,0 +1,69 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from fla.utils import tensor_cache
9
+
10
+
11
+ @triton.autotune(
12
+ configs=[
13
+ triton.Config({}, num_warps=num_warps)
14
+ for num_warps in [4, 8, 16, 32]
15
+ ],
16
+ key=['B'],
17
+ )
18
+ @triton.jit
19
+ def prepare_position_ids_kernel(
20
+ y,
21
+ offsets,
22
+ B: tl.constexpr
23
+ ):
24
+ i_n = tl.program_id(0)
25
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
26
+ T = eos - bos
27
+
28
+ o = tl.arange(0, B)
29
+ for i in range(0, tl.cdiv(T, B) * B, B):
30
+ o_i = o + i
31
+ tl.store(y + bos + o_i, o_i, o_i < T)
32
+
33
+
34
+ @tensor_cache
35
+ def prepare_lens(offsets: torch.LongTensor) -> torch.LongTensor:
36
+ return offsets[1:] - offsets[:-1]
37
+
38
+
39
+ @tensor_cache
40
+ def prepare_position_ids(offsets: torch.LongTensor) -> torch.LongTensor:
41
+ return torch.cat([torch.arange(n, dtype=offsets.dtype, device=offsets.device) for n in prepare_lens(offsets).unbind()])
42
+
43
+
44
+ @tensor_cache
45
+ def prepare_sequence_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
46
+ return position_ids.eq(0).cumsum(0) - 1
47
+
48
+
49
+ @tensor_cache
50
+ def prepare_token_indices(offsets: torch.LongTensor) -> torch.LongTensor:
51
+ position_ids = prepare_position_ids(offsets)
52
+ return torch.stack([prepare_sequence_ids(position_ids), position_ids], 1).to(offsets)
53
+
54
+
55
+ @tensor_cache
56
+ def prepare_chunk_indices(
57
+ offsets: torch.LongTensor,
58
+ chunk_size: int
59
+ ) -> torch.LongTensor:
60
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(offsets), chunk_size).tolist()])
61
+ return torch.stack([prepare_sequence_ids(indices), indices], 1).to(offsets)
62
+
63
+
64
+ @tensor_cache
65
+ def prepare_chunk_offsets(
66
+ offsets: torch.LongTensor,
67
+ chunk_size: int
68
+ ) -> torch.LongTensor:
69
+ return torch.cat([offsets.new_tensor([0]), triton.cdiv(prepare_lens(offsets), chunk_size)]).cumsum(-1)
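A small worked example (values chosen purely for illustration) makes the helpers above concrete: with two packed sequences of lengths 5 and 7 and a chunk size of 4, each sequence spans two chunks, and `prepare_chunk_indices` emits one `(sequence_id, chunk_id)` row per chunk.

```python
import torch

offsets = torch.tensor([0, 5, 12])
# prepare_lens(offsets)             -> tensor([5, 7])
# prepare_position_ids(offsets)     -> tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6])
# prepare_chunk_indices(offsets, 4) -> tensor([[0, 0],
#                                              [0, 1],
#                                              [1, 0],
#                                              [1, 1]])
# prepare_chunk_offsets(offsets, 4) -> tensor([0, 2, 4])
```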
fla/ops/delta_rule/README.md ADDED
@@ -0,0 +1,90 @@
1
+ # Chunkwise-form Parallelism of DeltaNet
2
+
3
+ This section expands on the formulation presented in Appendix B of the DeltaNet paper.[^1]
4
+
5
+ To reduce notational clutter, we focus on the first chunk, denoting $\mathbf{S}^r=\mathbf{S}_{[1]}^r$. By partially expanding the recurrence, we have:
6
+ ```math
7
+ \begin{equation}
8
+ \begin{aligned}
9
+ \mathbf{S}^r &= \underbrace{\left(\prod_{i=1}^r \mathbf{I} - \beta^i \boldsymbol{k}^i \boldsymbol{k}^{i\top} \right)}_{:= \mathbf{P}^r} \cdot\mathbf{S}^{0} + \overbrace{\sum_{i=1}^{r} \underbrace{\left(\prod_{j=i+1}^r \mathbf{I} - \beta^j \boldsymbol{k}^j \boldsymbol{k}^{j\top} \right)}_{:= \mathbf{P}_{i+1}^r}\beta^i \boldsymbol{k}^i\boldsymbol{v}^{i\top}}^{:=\mathbf{H}^r} \\
10
+ &=\mathbf{P}^r \cdot \mathbf{S}^{0} + \mathbf{H}^r
11
+ \end{aligned}
12
+ \end{equation}
13
+ ```
14
+
15
+ where $\mathbf{P}_i^r$ involves cumulative products of generalized Householder matrices.
16
+ We abbreviate $\mathbf{P}_1^r$ as $\mathbf{P}^r$.
17
+ This can be optimized using the classical WY representation:
18
+ ```math
19
+ \begin{equation}
20
+ \mathbf{P}^{r} = \mathbf{I} - \sum_{i=1}^{r}\boldsymbol{k}^i\boldsymbol{w}^{i\top} \in \mathbb{R}^{d_k \times d_k};\qquad
21
+ \boldsymbol{w}^r = \beta^r \left(\boldsymbol{k}^r - \sum_{i=1}^{r-1} \left(\boldsymbol{k}^{r\top}\boldsymbol{k}^i \right)\boldsymbol{w}^i \right) \in \mathbb{R}^{d_k}
22
+ \end{equation}
23
+ ```
24
+
25
+ We prove this by induction:
26
+ ```math
27
+ \begin{align*}
28
+ \mathbf{P}^{r} &= \prod_{i=1}^r \mathbf{I} - \beta^i \boldsymbol{k}^i \boldsymbol{k}^{i\top} \\
29
+ &= \left(\mathbf{I} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top}\right)\mathbf{P}^{r-1} \\
30
+ &= \left(\mathbf{I} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top}\right)\left(\mathbf{I} - \sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top}\right) \\
31
+ &= \mathbf{I} - \sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top} + \beta^r\boldsymbol{k}^r \boldsymbol{k}^{r\top} \left(\sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top}\right) \\
32
+ &= \mathbf{I} - \sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top} - \beta^r \boldsymbol{k}^r \left(\boldsymbol{k}^{r} - \left(\sum_{i=1}^{r-1}\left(\boldsymbol{k}^{r\top} \boldsymbol{k}^i\right)\boldsymbol{w}^{i}\right) \right)^\top \\
33
+ &= \mathbf{I} - \sum_{i=1}^{r}\boldsymbol{k}^i\boldsymbol{w}^{i\top}
34
+ \end{align*}
35
+ ```
36
+
37
+ Similarly, $\mathbf{H}^r$ can be represented as:
38
+ ```math
39
+ \begin{equation}
40
+ \mathbf{H}^{r} = \sum_{i=1}^{r} \boldsymbol{k}^i \boldsymbol{u}^{i\top} \in \mathbb{R}^{d_k \times d_v};\qquad \boldsymbol{u}^r = \beta^r \left(\boldsymbol{v}^r - \sum_{i=1}^{r-1} \left(\boldsymbol{k}^{r\top}\boldsymbol{k}^i\right) \boldsymbol{u}^i \right)\in \mathbb{R}^{d_v}
41
+ \end{equation}
42
+ ```
43
+
44
+ This can also be proven by induction:
45
+ ```math
46
+ \begin{align*}
47
+ \mathbf{H}^{r} &= \sum_{i=1}^{r} \mathbf{P}_{i+1}^r \beta^i \boldsymbol{k}^i \boldsymbol{v}^{i\top}\\
48
+ &= \left(\mathbf{I} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top}\right) \mathbf{H}^{r-1} + \beta^r \boldsymbol{k}^r \boldsymbol{v}^{r\top}\\
49
+ &= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top} \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} +\beta^r \boldsymbol{k}^r \boldsymbol{v}^{r\top}\\
50
+ &= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} + \boldsymbol{k}^r \left(\beta^r \boldsymbol{v}^{r\top}-\beta^r \boldsymbol{k}^{r\top} \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top}\right) \\
51
+ &= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} + \boldsymbol{k}^r \beta^r\left(\boldsymbol{v}^{r}-\sum_{i=1}^{r-1}\left(\boldsymbol{k}^{r\top}\boldsymbol{k}^{i}\right)\boldsymbol{u}^{i} \right)^\top \\
52
+ &=\sum_{i=1}^{r} \boldsymbol{k}^i \boldsymbol{u}^{i\top}
53
+ \end{align*}
54
+ ```
55
+
56
+ In matrix form, $\mathbf{P}$ and $\mathbf{H}$ can be written as:
57
+ ```math
58
+ \begin{equation}
59
+ \mathbf{P}=\mathbf{I}-\mathbf{K}^\top\mathbf{W} \in \mathbb{R}^{d_k \times d_k}, \qquad\mathbf{H}=\mathbf{K}^\top\mathbf{U} \in \mathbb{R}^{d_k\times d_v}
60
+ \end{equation}
61
+ ```
62
+
63
+ Now we can derive the matrix form of $\mathbf{W}$ and $\mathbf{U}$:
64
+ ```math
65
+ \begin{align*}
66
+ \mathbf{W} &= \mathrm{diag}(\beta) \mathbf{K} - \mathrm{tril}(\mathrm{diag}(\beta) \mathbf{K}\mathbf{K}^\top, -1)\mathbf{W}\\
67
+ \left(\mathbf{I} + \mathrm{tril}(\mathrm{diag}(\beta) \mathbf{K}\mathbf{K}^\top, -1)\right) \mathbf{W} &= \mathrm{diag}(\beta) \mathbf{K}
68
+ \end{align*}
69
+ ```
70
+ A similar process holds for $\mathbf{U}$. We can further write $\mathbf{W}$ and $\mathbf{U}$ in matrix form:
71
+ ```math
72
+ \begin{align*}
73
+ \mathbf{T} &= \left(\mathbf{I} + \mathrm{tril}\left(\mathrm{diag}(\beta)\mathbf{K} \mathbf{K}^\top,-1\right)\right)^{-1}\mathrm{diag}\left(\beta\right)\in \mathbb{R}^{C \times C}\\
74
+ \mathbf{W} &= \mathbf{T} \mathbf{K}\in \mathbb{R}^{C \times d_k}\\
75
+ \mathbf{U} &= \mathbf{T}\mathbf{V}\in \mathbb{R}^{C \times d_v}
76
+ \end{align*}
77
+ ```
78
+
79
+ Substituting these back into the original equations yields a hardware-efficient chunkwise algorithm for DeltaNet that leverages matrix multiplications, enabling tensor-core-based GPU optimization:
80
+ ```math
81
+ \begin{equation}
82
+ \begin{aligned}
83
+ \mathbf{S} &= \mathbf{P}\cdot\mathbf{S}^0 + \mathbf{H} \\
84
+ &= \mathbf{S}^0 + \mathbf{K}^\top (\mathbf{U} -\mathbf{W} \mathbf{S}^0) \in \mathbb{R}^{d_k \times d_v}\\
85
+ \mathbf{O} &= \mathbf{Q} \mathbf{S}^0 + (\mathbf{Q} \mathbf{K}^{\top} \odot \mathbf{M}) \left(\mathbf{U} - \mathbf{W} \mathbf{S}^0\right) \in \mathbb{R}^{C \times d_v}
86
+ \end{aligned}
87
+ \end{equation}
88
+ ```
89
+
90
+ [^1]: https://arxiv.org/abs/2406.06484
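The closed-form expressions above translate almost line for line into eager PyTorch. The sketch below processes a single chunk of size C, following the equations in this README: it builds T, W, U, then the updated state and the chunk output. It is a readability-oriented reference, not the optimized Triton kernel; the function name `delta_rule_chunk_ref` and the explicit `linalg.solve` are illustrative choices.

```python
import torch

def delta_rule_chunk_ref(q, k, v, beta, S0):
    """One DeltaNet chunk via the WY representation derived above.

    q, k: [C, d_k]; v: [C, d_v]; beta: [C]; S0: [d_k, d_v] (incoming state).
    Returns the chunk output O [C, d_v] and the updated state S [d_k, d_v].
    """
    C = k.shape[0]
    eye = torch.eye(C, dtype=k.dtype, device=k.device)
    # T = (I + tril(diag(beta) K K^T, -1))^{-1} diag(beta)
    A = torch.tril(torch.diag(beta) @ k @ k.t(), -1)
    T = torch.linalg.solve(eye + A, torch.diag(beta))
    W = T @ k                                   # [C, d_k]
    U = T @ v                                   # [C, d_v]
    # S = S0 + K^T (U - W S0);  O = Q S0 + (Q K^T ⊙ M)(U - W S0)
    U_minus_WS0 = U - W @ S0
    S = S0 + k.t() @ U_minus_WS0
    M = torch.tril(torch.ones(C, C, dtype=k.dtype, device=k.device))
    O = q @ S0 + (q @ k.t() * M) @ U_minus_WS0
    return O, S
```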
fla/ops/delta_rule/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_delta_rule
4
+ from .fused_chunk import fused_chunk_delta_rule
5
+ from .fused_recurrent import fused_recurrent_delta_rule
6
+
7
+ __all__ = [
8
+ 'fused_chunk_delta_rule',
9
+ 'fused_recurrent_delta_rule',
10
+ 'chunk_delta_rule'
11
+ ]
fla/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (392 Bytes). View file
 
fla/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (34 kB). View file