Maxtimer97 committed
Commit 8a2bc5d · 1 Parent(s): 2495dfe

Added modeling files

configuration_chatglm.py ADDED
@@ -0,0 +1,72 @@
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    model_type = "chatglm"

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        block_size=64,
        kernel_size=64,
        kernel_stride=64,
        window_size=512,
        topk=16,
        init_blocks=1,
        local_blocks=2,
        hidden_dropout=0.0,
        classifier_dropout=None,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        rope_ratio=1,
        original_rope=True,  # read by ChatGLMModel when building the rotary embedding
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        **kwargs
    ):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.block_size = block_size
        self.kernel_size = kernel_size
        self.kernel_stride = kernel_stride
        self.window_size = window_size
        self.topk = topk
        self.init_blocks = init_blocks
        self.local_blocks = local_blocks
        self.hidden_dropout = hidden_dropout
        self.classifier_dropout = classifier_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.rope_ratio = rope_ratio
        self.original_rope = original_rope
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        super().__init__(**kwargs)
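
For reference, a minimal usage sketch of the config class above (nothing checkpoint-specific; the asserted values are just the defaults from the signature):

    from configuration_chatglm import ChatGLMConfig

    config = ChatGLMConfig()
    assert config.vocab_size == config.padded_vocab_size == 65024
    assert config.topk == 16 and config.block_size == 64   # NSA defaults
    config.save_pretrained("./chatglm-config")             # writes config.json
    reloaded = ChatGLMConfig.from_pretrained("./chatglm-config")
    assert reloaded.ffn_hidden_size == 13696
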
modeling_chatglm.py ADDED
@@ -0,0 +1,1382 @@
""" PyTorch ChatGLM model. """

import math
import sys
import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Dict, Any
from einops import rearrange

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging, is_torch_npu_available
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import ModelOutput
from transformers.generation.utils import GenerationMixin


try:
    from configuration_chatglm import ChatGLMConfig
    from ops.pooling import mean_pooling
    from ops.compressed_attention import compressed_attention
    from ops.topk_sparse_attention import topk_sparse_attention
except ImportError:
    from .configuration_chatglm import ChatGLMConfig
    from .ops.pooling import mean_pooling
    from .ops.compressed_attention import compressed_attention
    from .ops.topk_sparse_attention import topk_sparse_attention

try:
    from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available

    if is_flash_attn_2_available():
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
except Exception:  # flash-attn is optional; the eager and sdpa paths work without it
    pass

# flags required to enable jit fusion kernels

if sys.platform != 'darwin' and not is_torch_npu_available():
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
_CONFIG_FOR_DOC = "ChatGLMConfig"


def default_init(cls, *args, **kwargs):
    return cls(*args, **kwargs)


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 198] = 5e4
        return scores


def split_tensor_along_last_dim(
        tensor: torch.Tensor,
        num_partitions: int,
        contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.

    Returns:
        A list of Tensors
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size()[last_dim] // num_partitions
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl
        self.rope_ratio = rope_ratio

    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        base = base * self.rope_ratio
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )


@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [b, np, sq, hn]
    b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:, :sq]
    xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)
    rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)
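
A quick sanity check of the rotary embedding above (a minimal sketch; it assumes this modeling_chatglm.py and its ops/ and einops dependencies import cleanly, and the shapes are made up). Each (cos, sin) pair applies a plain 2-D rotation to the first rot_dim channels, so the transform must preserve the input's norm:

    import torch
    from modeling_chatglm import RotaryEmbedding, apply_rotary_pos_emb

    rope = RotaryEmbedding(dim=64, device="cpu", dtype=torch.float32)  # ChatGLMModel uses dim = kv_channels // 2
    cache = rope(16)                    # [16, 32, 2]: per-position (cos, sin) pairs
    x = torch.randn(1, 2, 16, 128)      # [b, np, sq, hn]; only the first 64 channels get rotated
    y = apply_rotary_pos_emb(x, cache.unsqueeze(0))
    assert torch.allclose(x.norm(), y.norm(), atol=1e-4)  # rotations are norm-preserving
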


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)


class CoreAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()
        self.config = config
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.is_causal = True

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        # [b, np, sq, sk]
        output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))

        # [b, np, sq, hn] -> [b * np, sq, hn]
        query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
        # [b, np, sk, hn] -> [b * np, sk, hn]
        key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)

        # preallocating input tensor: [b * np, sq, sk]
        matmul_input_buffer = torch.empty(
            output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
            device=query_layer.device
        )

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer,  # [b * np, sq, hn]
            key_layer.transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        if self.attention_softmax_in_fp32:
            attention_scores = attention_scores.float()
        if self.coeff is not None:
            attention_scores = attention_scores * self.coeff
        if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
            attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
                                        device=attention_scores.device, dtype=torch.bool)
            attention_mask.tril_()
            attention_mask = ~attention_mask
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = attention_probs.type_as(value_layer)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)

        # query layer shape: [b * np, sq, hn]
        # value layer shape: [b, np, sk, hn]
        # attention shape: [b, np, sq, sk]
        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
        # change view [b * np, sk, hn]
        value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer)
        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
        # [b, np, sq, hn] --> [b, sq, np, hn]
        context_layer = context_layer.transpose(1, 2).contiguous()
        # [b, sq, np, hn] --> [b, sq, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(*new_context_layer_shape)

        return context_layer


class SdpaAttention(CoreAttention):
    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
            context_layer = torch.nn.functional.scaled_dot_product_attention(
                query_layer, key_layer, value_layer,
                is_causal=True,
                dropout_p=self.config.attention_dropout if self.training else 0.0)
        else:
            if attention_mask is not None:
                attention_mask = ~attention_mask
            context_layer = torch.nn.functional.scaled_dot_product_attention(
                query_layer, key_layer, value_layer,
                attention_mask,
                dropout_p=self.config.attention_dropout if self.training else 0.0)
        context_layer = context_layer.transpose(1, 2).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(*new_context_layer_shape)
        return context_layer
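
The SDPA path should match the eager path numerically when dropout is off. A minimal equivalence sketch (toy shapes of my choosing; importing this module requires its ops/ dependencies to be available):

    import torch
    from configuration_chatglm import ChatGLMConfig
    from modeling_chatglm import CoreAttention, SdpaAttention

    cfg = ChatGLMConfig(num_attention_heads=2, kv_channels=8, attention_dropout=0.0)
    q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))  # [b, np, sq, hn]
    eager, sdpa = CoreAttention(cfg, 1), SdpaAttention(cfg, 1)
    # attention_mask=None with sq == sk sends both classes down their causal path
    assert torch.allclose(eager(q, k, v, None), sdpa(q, k, v, None), atol=1e-5)
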


def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2
class FlashAttention2(CoreAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(self, query_states, key_states, value_states, attention_mask):
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        batch_size, query_length = query_states.shape[:2]
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1
        dropout = self.config.attention_dropout if self.training else 0.0
        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=None,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal
            )
        attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
                indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )


class NativeSparseAttention(CoreAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
        self.block_size = self.config.block_size
        self.kernel_size = self.config.kernel_size
        self.kernel_stride = self.config.kernel_stride
        self.init_blocks = self.config.init_blocks
        self.local_blocks = self.config.local_blocks
        self.topk = self.config.topk
        self.window_size = self.config.window_size

    def forward(self, query_states, key_states, value_states, g_cmp, g_spa, g_swa, attention_mask):
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        batch_size, query_length = query_states.shape[:2]
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1
        dropout = self.config.attention_dropout if self.training else 0.0
        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            k_cmp, v_cmp = mean_pooling(key_states, self.block_size, cu_seq_lens), mean_pooling(value_states, self.block_size, cu_seq_lens)

            # compute seqlens after compression
            seqlens = cu_seq_lens[1:] - cu_seq_lens[:-1]
            y_seqlens = torch.floor((seqlens - self.kernel_size) / self.kernel_stride).to(torch.int32) + 1
            # corner case: if sequence_length < kernel_size, no compression for this sequence
            y_seqlens[seqlens < self.kernel_size] = 0
            cmp_seqlens = torch.cat(
                [
                    torch.zeros(1, dtype=torch.int32, device="cuda"),
                    torch.cumsum(y_seqlens, dim=0),
                ],
                dim=0,
            ).to(torch.int32)

            # attention between query and compressed key value
            compressed_seqlens = cmp_seqlens[1:] - cmp_seqlens[:-1]
            compressed_attn_output, topk_idx = compressed_attention(
                query_states,
                k_cmp,
                v_cmp,
                self.kernel_size,
                self.kernel_stride,
                self.block_size,
                self.topk,
                cu_seq_lens,
                cmp_seqlens,
                seqlens.max().item(),
                compressed_seqlens.max().item(),
                None,
                self.init_blocks,
                self.local_blocks,
                parallel_topk_compute=False,
            )

            attention_out = compressed_attn_output * g_cmp.unsqueeze(-1)

            # topk sparse attention
            seqlens = cu_seq_lens[1:] - cu_seq_lens[:-1]
            sparse_attn_output = topk_sparse_attention(
                query_states, key_states, value_states, topk_idx, self.block_size, cu_seq_lens, None
            )

            attention_out = torch.addcmul(attention_out, sparse_attn_output, g_spa.unsqueeze(-1))

            sliding_attn_output = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=None,
                causal=causal,
                window_size=(self.window_size, -1),
            )

            attn_output_unpad = torch.addcmul(attention_out, sliding_attn_output, g_swa.unsqueeze(-1))

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            cu_seq_lens = torch.arange(0, (query_states.shape[0] + 1) * query_states.shape[1], step=query_states.shape[1], device=query_states.device, dtype=torch.int32)
            # cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            # max_seq_lens = cu_seq_lens.max().item()
            # max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            k_cmp, v_cmp = mean_pooling(key_states, self.block_size, cu_seq_lens), mean_pooling(value_states, self.block_size, cu_seq_lens)
            query_states, key_states, value_states = map(lambda x: rearrange(x, 'b t h d -> (b t) h d'), (query_states, key_states, value_states))
            k_cmp, v_cmp = map(lambda x: rearrange(x, 'b t h d -> (b t) h d'), (k_cmp, v_cmp))
            # compute seqlens after compression
            seqlens = cu_seq_lens[1:] - cu_seq_lens[:-1]
            y_seqlens = torch.floor((seqlens - self.kernel_size) / self.kernel_stride).to(torch.int32) + 1
            # corner case: if sequence_length < kernel_size, no compression for this sequence
            y_seqlens[seqlens < self.kernel_size] = 0
            cmp_seqlens = torch.cat(
                [
                    torch.zeros(1, dtype=torch.int32, device="cuda"),
                    torch.cumsum(y_seqlens, dim=0),
                ],
                dim=0,
            ).to(torch.int32)

            # attention between query and compressed key value
            compressed_seqlens = cmp_seqlens[1:] - cmp_seqlens[:-1]
            compressed_attn_output, topk_idx = compressed_attention(
                query_states,
                k_cmp,
                v_cmp,
                self.kernel_size,
                self.kernel_stride,
                self.block_size,
                self.topk,
                cu_seq_lens,
                cmp_seqlens,
                seqlens.max().item(),
                compressed_seqlens.max().item(),
                None,
                self.init_blocks,
                self.local_blocks,
                parallel_topk_compute=False,
            )

            attention_out = compressed_attn_output * g_cmp.unsqueeze(-1)

            # topk sparse attention
            seqlens = cu_seq_lens[1:] - cu_seq_lens[:-1]
            sparse_attn_output = topk_sparse_attention(
                query_states, key_states, value_states, topk_idx, self.block_size, cu_seq_lens, None
            )

            attention_out = torch.addcmul(attention_out, sparse_attn_output, g_spa.unsqueeze(-1))

            query_states, key_states, value_states = map(lambda x: rearrange(x, '(b t) h d -> b t h d', b=batch_size), (query_states, key_states, value_states))
            sliding_attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout_p=dropout,
                softmax_scale=None,
                causal=causal,
                window_size=(self.window_size, -1),
            )

            attn_output = torch.addcmul(attention_out, sliding_attn_output, g_swa.unsqueeze(-1))

        attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
                indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )


CORE_ATTENTION_CLASSES = {
    "eager": CoreAttention,
    "sdpa": SdpaAttention,
    "flash_attention_2": FlashAttention2,
    "nsa": NativeSparseAttention
}
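
CORE_ATTENTION_CLASSES is the single switch point: SelfAttention.__init__ below reads config.attn_implementation and instantiates the matching backend. A hypothetical selection sketch (in a real checkpoint the value normally comes from config.json or the from_pretrained call, and the "nsa" backend additionally needs the Triton kernels under ops/ plus flash-attn):

    from configuration_chatglm import ChatGLMConfig
    from modeling_chatglm import CORE_ATTENTION_CLASSES

    cfg = ChatGLMConfig()
    cfg.attn_implementation = "nsa"  # one of "eager", "sdpa", "flash_attention_2", "nsa"
    core_attn = CORE_ATTENTION_CLASSES[cfg.attn_implementation](cfg, 1)
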


class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)
        self.attn_implementation = config.attn_implementation

        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (
                self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
            )
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         device=device, **_config_to_kwargs(config)
                                         )

        # Gate for NSA between compressed, sparse and sliding window attentions
        self.gate = nn.Linear(config.hidden_size, config.num_attention_heads * 3, bias=True)
        # Init such that the sigmoid gives 1/3
        with torch.no_grad():
            self.gate.weight.zero_()
            self.gate.bias.fill_(-math.log(2))  # sigmoid ≈ 1/3

        self.core_attention = CORE_ATTENTION_CLASSES[config.attn_implementation](config, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device, **_config_to_kwargs(config)
                               )

    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
        if self.multi_query_attention:
            num_attention_heads = self.num_multi_query_groups_per_partition
        else:
            num_attention_heads = self.num_attention_heads_per_partition
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            num_attention_heads,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=device,
        )

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
    ):
        # hidden_states: [b, sq, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                               (self.num_attention_heads_per_partition,
                                3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        # [b, sq, np, hn] -> [b, np, sq, hn]
        query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=2)
            value_layer = torch.cat((cache_v, value_layer), dim=2)
        if use_cache:
            if kv_cache is None:
                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
                                     dim=1)
            else:
                kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        # ==================================
        # core attention computation
        # ==================================

        if self.attn_implementation != "nsa":

            if self.multi_query_attention:
                key_layer = key_layer.unsqueeze(2)
                key_layer = key_layer.expand(
                    -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
                )
                key_layer = key_layer.contiguous().view(
                    key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:]
                )
                value_layer = value_layer.unsqueeze(2)
                value_layer = value_layer.expand(
                    -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
                )
                value_layer = value_layer.contiguous().view(
                    value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:]
                )

            context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

        else:
            g = rearrange(self.gate(hidden_states), '... (h d) -> ... h d', d=3)
            g_cmp, g_spa, g_swa = g.sigmoid().unbind(-1)

            context_layer = self.core_attention(query_layer, key_layer, value_layer,
                                                g_cmp, g_spa, g_swa, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache
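
On the gate initialization in SelfAttention.__init__ above: with zero weights and a bias of -log 2, every gate starts at sigmoid(-ln 2) = 1 / (1 + e^{ln 2}) = 1 / (1 + 2) = 1/3, so the compressed, sparse, and sliding-window branches are weighted equally at the start of training:

    import math
    assert abs(1 / (1 + math.exp(math.log(2))) - 1 / 3) < 1e-12  # sigmoid(-ln 2) == 1/3
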


def _config_to_kwargs(args):
    common_kwargs = {
        "dtype": args.torch_dtype,
    }
    return common_kwargs


class MLP(torch.nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output
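
The swiglu closure in MLP.__init__ splits the doubled dense_h_to_4h projection in half, gates one half with SiLU, and multiplies. A standalone shape check mirroring it (toy sizes):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 4, 16)          # stand-in for dense_h_to_4h output: last dim = 2 * ffn_hidden_size
    a, b = torch.chunk(x, 2, dim=-1)
    y = F.silu(a) * b                  # swiglu
    assert y.shape == (2, 4, 8)        # back to ffn_hidden_size
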


class GLMBlock(torch.nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm

        self.fp32_residual_connection = config.fp32_residual_connection

        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                             dtype=config.torch_dtype)

        # Self attention.
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                      dtype=config.torch_dtype)

        # MLP
        self.mlp = MLP(config, device=device)

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        output = residual + output

        return output, kv_cache


class GLMTransformer(torch.nn.Module):
    """Transformer class."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(GLMTransformer, self).__init__()

        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        def build_layer(layer_number):
            return GLMBlock(config, layer_number, device=device)

        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])

        if self.post_layer_norm:
            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                 dtype=config.torch_dtype)

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
            use_cache: Optional[bool] = True,
            output_hidden_states: Optional[bool] = False,
    ):
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]
        presents = () if use_cache else None
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_self_attentions = None
        all_hidden_states = () if output_hidden_states else None
        for index in range(self.num_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer = self._get_layer(index)
            if self.gradient_checkpointing and self.training:
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_caches[index],
                    use_cache,
                    use_reentrant=False
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_cache=kv_caches[index],
                    use_cache=use_cache
                )
            hidden_states, kv_cache = layer_ret
            if use_cache:
                # token by token decoding, use tuple format
                if kv_caches[0] is not None:
                    presents = presents + (kv_cache,)
                # prefilling in decoding, use tensor format to save cuda memory
                else:
                    if len(presents) == 0:
                        presents = kv_cache
                    else:
                        presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states, presents, all_hidden_states, all_self_attentions


class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        if self.config.attn_implementation == "flash_attention_2" or self.config.attn_implementation == "nsa":
            if padding_mask is not None and not padding_mask.all():
                return padding_mask
            return None
        batch_size, seq_length = input_ids.shape
        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
        full_attention_mask.tril_()
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].shape[2]
        if past_length:
            full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
                                                        device=input_ids.device), full_attention_mask), dim=-1)
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        full_attention_mask = (full_attention_mask < 0.5).bool()
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask

    def get_position_ids(self, input_ids, device):
        batch_size, seq_length = input_ids.shape
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        return position_ids


class Embedding(torch.nn.Module):
    """Language model embeddings."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(Embedding, self).__init__()

        self.hidden_size = config.hidden_size
        # Word embeddings (parallel).
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device
        )
        self.fp32_residual_connection = config.fp32_residual_connection

    def forward(self, input_ids):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        # If the input flag for fp32 residual connection is set, convert for float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()
        return embeddings


class ChatGLMModel(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels

        # Rotary positional embeddings
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )

        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio,
                                              original_impl=config.original_rope,
                                              device=device, dtype=config.torch_dtype)
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                        dtype=config.torch_dtype, **init_kwargs)

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def set_input_embeddings(self, value):
        self.embedding.word_embeddings = value

    def forward(
            self,
            input_ids,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.BoolTensor] = None,
            full_attention_mask: Optional[torch.BoolTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        if full_attention_mask is None:
            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

        # Rotary positional embeddings
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]

        # Run encoder.
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
        )
        if presents is not None and type(presents) is torch.Tensor:
            presents = presents.split(1, dim=0)
            presents = list(presents)
            presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
            presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
            presents = tuple(presents)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel, GenerationMixin):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
        self.config = config

    def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            is_encoder_decoder: bool = False,
            num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        # update past_key_values
        for possible_cache_name in ["past_key_values", "mems", "past_buckets_states", "cache_params"]:
            if hasattr(outputs, possible_cache_name):
                if possible_cache_name in ("past_buckets_states", "mems"):
                    cache_name = "past_key_values"
                else:
                    cache_name = possible_cache_name
                model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
                break

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        # update position ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone()
            new_position_id += 1
            model_kwargs["position_ids"] = torch.cat(
                [position_ids, new_position_id], dim=-1
            )

        model_kwargs["is_first_forward"] = False

        if model_kwargs.get("use_cache", True) and "cache_position" in model_kwargs:
            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

        return model_kwargs

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            is_first_forward: bool = True,
            **kwargs
    ) -> dict:
        # only last token for input_ids if past is not None
        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)
        if not is_first_forward:
            if past_key_values is not None:
                position_ids = position_ids[..., -1:]
                input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "return_last_logit": True,
            "use_cache": use_cache
        }

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = False,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        if return_last_logit:
            hidden_states = hidden_states[:, -1:]
        lm_logits = self.transformer.output_layer(hidden_states)

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )
1300
+ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1301
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1302
+ super().__init__(config)
1303
+
1304
+ self.num_labels = config.num_labels
1305
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1306
+
1307
+ self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=config.torch_dtype)
1308
+ if config.classifier_dropout is not None:
1309
+ self.dropout = nn.Dropout(config.classifier_dropout)
1310
+ else:
1311
+ self.dropout = None
1312
+ self.config = config
1313
+
1314
+ def forward(
1315
+ self,
1316
+ input_ids: Optional[torch.LongTensor] = None,
1317
+ position_ids: Optional[torch.LongTensor] = None,
1318
+ attention_mask: Optional[torch.Tensor] = None,
1319
+ full_attention_mask: Optional[torch.Tensor] = None,
1320
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1321
+ inputs_embeds: Optional[torch.LongTensor] = None,
1322
+ labels: Optional[torch.LongTensor] = None,
1323
+ use_cache: Optional[bool] = None,
1324
+ output_attentions: Optional[bool] = None,
1325
+ output_hidden_states: Optional[bool] = None,
1326
+ return_dict: Optional[bool] = None,
1327
+ ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
1328
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1329
+
1330
+ transformer_outputs = self.transformer(
1331
+ input_ids=input_ids,
1332
+ position_ids=position_ids,
1333
+ attention_mask=attention_mask,
1334
+ full_attention_mask=full_attention_mask,
1335
+ past_key_values=past_key_values,
1336
+ inputs_embeds=inputs_embeds,
1337
+ use_cache=use_cache,
1338
+ output_attentions=output_attentions,
1339
+ output_hidden_states=output_hidden_states,
1340
+ return_dict=return_dict,
1341
+ )
1342
+
1343
+ hidden_states = transformer_outputs[0]
1344
+ pooled_hidden_states = hidden_states[:, -1]
1345
+ if self.dropout is not None:
1346
+ pooled_hidden_states = self.dropout(pooled_hidden_states)
1347
+ logits = self.classifier_head(pooled_hidden_states)
1348
+
1349
+ loss = None
1350
+ if labels is not None:
1351
+ if self.config.problem_type is None:
1352
+ if self.num_labels == 1:
1353
+ self.config.problem_type = "regression"
1354
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1355
+ self.config.problem_type = "single_label_classification"
1356
+ else:
1357
+ self.config.problem_type = "multi_label_classification"
1358
+
1359
+ if self.config.problem_type == "regression":
1360
+ loss_fct = MSELoss()
1361
+ if self.num_labels == 1:
1362
+ loss = loss_fct(logits.squeeze().float(), labels.squeeze())
1363
+ else:
1364
+ loss = loss_fct(logits.float(), labels)
1365
+ elif self.config.problem_type == "single_label_classification":
1366
+ loss_fct = CrossEntropyLoss()
1367
+ loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
1368
+ elif self.config.problem_type == "multi_label_classification":
1369
+ loss_fct = BCEWithLogitsLoss()
1370
+ loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
1371
+
1372
+ if not return_dict:
1373
+ output = (logits,) + transformer_outputs[1:]
1374
+ return ((loss,) + output) if loss is not None else output
1375
+
1376
+ return SequenceClassifierOutputWithPast(
1377
+ loss=loss,
1378
+ logits=logits,
1379
+ past_key_values=transformer_outputs.past_key_values,
1380
+ hidden_states=transformer_outputs.hidden_states,
1381
+ attentions=transformer_outputs.attentions,
1382
+ )
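
For orientation, `ChatGLMForSequenceClassification.forward` pools the hidden state at the last position and, when `problem_type` is unset, infers it from `num_labels` and the label dtype. Below is a minimal, self-contained sketch of that pooling-and-loss logic in plain PyTorch; all shapes and label values are made up for illustration and are not part of this commit.

import torch
from torch.nn import CrossEntropyLoss, MSELoss

# made-up shapes: [batch, seq, hidden] hidden states and a 3-way classifier head
hidden_states = torch.randn(4, 128, 64)
classifier_head = torch.nn.Linear(64, 3)

pooled = hidden_states[:, -1]        # last-position pooling, as in forward()
logits = classifier_head(pooled)     # [batch, num_labels]

labels = torch.tensor([0, 2, 1, 1])  # integer labels -> single_label_classification
if logits.shape[-1] == 1:
    # num_labels == 1 -> regression with MSE
    loss = MSELoss()(logits.squeeze().float(), labels.float().squeeze())
else:
    # integer labels with num_labels > 1 -> cross entropy
    loss = CrossEntropyLoss()(logits.view(-1, logits.shape[-1]).float(), labels.view(-1))
print(loss.item())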
ops/__pycache__/compressed_attention.cpython-310.pyc ADDED
Binary file (20.9 kB)

ops/__pycache__/pooling.cpython-310.pyc ADDED
Binary file (5.59 kB)

ops/__pycache__/topk_sparse_attention.cpython-310.pyc ADDED
Binary file (19 kB)

ops/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.19 kB)
ops/compressed_attention.py ADDED
@@ -0,0 +1,1320 @@
+ # Copyright 2025 Xunhao Lai & Jianqiao Lu.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import math
+ import warnings
+ from typing import Any, Tuple, Union
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ try:
+     from ops.utils import get_num_warps_stages, is_hopper_gpu
+ except ImportError:
+     from .utils import get_num_warps_stages, is_hopper_gpu
+
+ IS_HOPPER_GPU = is_hopper_gpu()
+
+
+ @triton.jit
+ def forward_kernel(
+     q_ptr,  # Q: n x h x d
+     k_ptr,  # K: n x h x d
+     v_ptr,  # V: n x h x d
+     o_ptr,  # O: n x h x d
+     lse_ptr,  # LSE: h x n
+     # size and stride of the compression kernel
+     kernel_size,
+     kernel_stride,
+     # seqlens
+     cu_seqlens_q,
+     cu_seqlens_k,
+     # shape
+     NUM_KV_HEADS,
+     NUM_SHARE_Q_HEADS,
+     HEAD_DIM,
+     # sm_scale
+     sm_scale,
+     # stride
+     stride_qn,
+     stride_qh,
+     stride_qd,
+     stride_kn,
+     stride_kh,
+     stride_kd,
+     stride_vn,
+     stride_vh,
+     stride_vd,
+     stride_on,
+     stride_oh,
+     stride_od,
+     stride_lh,
+     stride_ln,
+     # META parameters
+     BLOCK_SIZE_Q: tl.constexpr,  # q block size
+     BLOCK_SIZE_K: tl.constexpr,  # k block size
+     BLOCK_SIZE_D: tl.constexpr,
+ ):
+     qk_scale = sm_scale * 1.44269504
+     # get batch id and head id
+     pid_b = tl.program_id(0)
+     pid_h = tl.program_id(1)
+     pid_q = tl.program_id(2)
+     pid_kh = pid_h // NUM_SHARE_Q_HEADS
+     # get q k start and len after rmpad
+     q_start = tl.load(cu_seqlens_q + pid_b)
+     q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+     k_start = tl.load(cu_seqlens_k + pid_b)
+     k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
+     # skip the first kernel_size query positions, because they do not attend to any keys
+     q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
+     if q_start_in_seq >= q_len:
+         return
+     # init qkv pointer
+     q_ptrs = tl.make_block_ptr(
+         base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
+         shape=(q_len, HEAD_DIM),
+         strides=(stride_qn, stride_qd),
+         offsets=(q_start_in_seq, 0),
+         block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+         order=(1, 0),
+     )
+     k_ptrs = tl.make_block_ptr(
+         base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
+         shape=(HEAD_DIM, k_len),
+         strides=(stride_kd, stride_kn),
+         offsets=(0, 0),
+         block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
+         order=(0, 1),
+     )
+     v_ptrs = tl.make_block_ptr(
+         base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
+         shape=(k_len, HEAD_DIM),
+         strides=(stride_vn, stride_vd),
+         offsets=(0, 0),
+         block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+         order=(1, 0),
+     )
+     # load q
+     q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
+     # init statistics
+     off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
+     off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
+     m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
+     lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
+     acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32)
+     # attention
+     lo = 0
+     hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
+     for i in range(lo, hi, BLOCK_SIZE_K):
+         i = tl.multiple_of(i, BLOCK_SIZE_K)
+         # load k
+         k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
+         # compute qk
+         qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
+         qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf"))
+         qk += tl.dot(q, k) * qk_scale
+         # compute m_ij and l_ij
+         m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
+         p = tl.exp2(qk - m_ij[:, None])
+         l_ij = tl.sum(p, axis=1)
+         # scale acc_o
+         acc_o_scale = tl.exp2(m_i - m_ij)
+         acc_o = acc_o * acc_o_scale[:, None]
+         # load v and update acc_o
+         v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
+         p = p.to(v.dtype)
+         acc_o += tl.dot(p, v)
+         # update statistics
+         m_i = m_ij
+         lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
+         # update ptrs
+         k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
+         v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
+     # final scale
+     acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
+     # save output
+     o_ptrs = tl.make_block_ptr(
+         base=o_ptr + q_start * stride_on + pid_h * stride_oh,
+         shape=(q_len, HEAD_DIM),
+         strides=(stride_on, stride_od),
+         offsets=(q_start_in_seq, 0),
+         block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+         order=(1, 0),
+     )
+     tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
+     # save lse
+     l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
+     tl.store(l_ptrs, lse_i, mask=off_q < q_len)
+
+
+ @triton.jit
+ def backward_sum_o_do(
+     o_ptr,  # O: n x h x d
+     do_ptr,  # dO: n x h x d
+     delta_ptr,  # D: h x n
+     o_len,
+     HEAD_DIM,
+     stride_on,
+     stride_oh,
+     stride_od,
+     stride_don,
+     stride_doh,
+     stride_dod,
+     stride_dh,
+     stride_dn,
+     BLOCK_SIZE_O: tl.constexpr,
+     BLOCK_SIZE_D: tl.constexpr,
+ ):
+     pid_n = tl.program_id(0)
+     pid_h = tl.program_id(1)
+     off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
+     off_d = tl.arange(0, BLOCK_SIZE_D)
+     o = tl.load(
+         o_ptr + off_n[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od,
+         mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
+         other=0,
+     ).to(tl.float32)
+     do = tl.load(
+         do_ptr + off_n[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod,
+         mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
+         other=0,
+     ).to(tl.float32)
+     delta = tl.sum(o * do, axis=1)
+     tl.store(delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len)
+
+
+ @triton.jit
+ def backward_dkdv(
+     q_ptr,  # Q: n x qh x d
+     k_ptr,  # K: n x kh x d
+     v_ptr,  # V: n x kh x d
+     lse_ptr,  # LSE: qh x n
+     d_ptr,  # Delta: qh x n
+     do_ptr,
+     dk_ptr,  # DK: sh x n x kh x d
+     dv_ptr,  # DV: sh x n x kh x d
+     kernel_size,
+     kernel_stride,
+     # seqlens
+     cu_seqlens_q,
+     cu_seqlens_k,
+     # shape
+     NUM_KV_HEADS,
+     NUM_SHARE_Q_HEADS,
+     HEAD_DIM,
+     # sm_scale
+     sm_scale,
+     # stride
+     stride_qn,
+     stride_qh,
+     stride_qd,
+     stride_kn,
+     stride_kh,
+     stride_kd,
+     stride_vn,
+     stride_vh,
+     stride_vd,
+     stride_lh,
+     stride_ln,
+     stride_dh,
+     stride_dn,
+     stride_don,
+     stride_doh,
+     stride_dod,
+     stride_dks,
+     stride_dkn,
+     stride_dkh,
+     stride_dkd,
+     stride_dvs,
+     stride_dvn,
+     stride_dvh,
+     stride_dvd,
+     # META parameters
+     BLOCK_SIZE_Q: tl.constexpr,  # q block size
+     BLOCK_SIZE_K: tl.constexpr,  # k block size
+     BLOCK_SIZE_D: tl.constexpr,
+ ):
+     qk_scale = sm_scale * 1.44269504
+     # get batch id and head id
+     pid_b = tl.program_id(0)
+     pid_h = tl.program_id(1)
+     pid_kh = pid_h // NUM_SHARE_Q_HEADS
+     pid_sh = pid_h % NUM_SHARE_Q_HEADS
+     pid_k = tl.program_id(2)
+     # get q k start and len after rmpad
+     q_start = tl.load(cu_seqlens_q + pid_b)
+     q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+     k_start = tl.load(cu_seqlens_k + pid_b)
+     k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
+     if BLOCK_SIZE_K * pid_k >= k_len:
+         return
+     # init pointers
+     k_ptrs = tl.make_block_ptr(
+         base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
+         shape=(k_len, HEAD_DIM),
+         strides=(stride_kn, stride_kd),
+         offsets=(pid_k * BLOCK_SIZE_K, 0),
+         block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+         order=(1, 0),
+     )
+     dk_ptrs = tl.make_block_ptr(
+         base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
+         shape=(k_len, HEAD_DIM),
+         strides=(stride_dkn, stride_dkd),
+         offsets=(pid_k * BLOCK_SIZE_K, 0),
+         block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+         order=(1, 0),
+     )
+     v_ptrs = tl.make_block_ptr(
+         base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
+         shape=(k_len, HEAD_DIM),
+         strides=(stride_vn, stride_vd),
+         offsets=(pid_k * BLOCK_SIZE_K, 0),
+         block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+         order=(1, 0),
+     )
+     dv_ptrs = tl.make_block_ptr(
+         base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
+         shape=(k_len, HEAD_DIM),
+         strides=(stride_dvn, stride_dvd),
+         offsets=(pid_k * BLOCK_SIZE_K, 0),
+         block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+         order=(1, 0),
+     )
+     # offsets
+     off_q = tl.arange(0, BLOCK_SIZE_Q)
+     off_k = pid_k * BLOCK_SIZE_K * kernel_stride + tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
+     # load k v and keep in SRAM
+     k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
+     v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
+     # init dk dv
+     dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
+     dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
+     q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1
+     q_ptrs = tl.make_block_ptr(
+         base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
+         shape=(HEAD_DIM, q_len),
+         strides=(stride_qd, stride_qn),
+         offsets=(0, q_lo),
+         block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
+         order=(0, 1),
+     )
+     do_ptrs = tl.make_block_ptr(
+         base=do_ptr + q_start * stride_don + pid_h * stride_doh,
+         shape=(HEAD_DIM, q_len),
+         strides=(stride_dod, stride_don),
+         offsets=(0, q_lo),
+         block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
+         order=(0, 1),
+     )
+     d_ptrs = tl.make_block_ptr(
+         base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
+         shape=(1, q_len),
+         strides=(0, stride_dn),
+         offsets=(0, q_lo),
+         block_shape=(1, BLOCK_SIZE_Q),
+         order=(1, 0),
+     )
+     lse_ptrs = tl.make_block_ptr(
+         base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
+         shape=(1, q_len),
+         strides=(0, stride_ln),
+         offsets=(0, q_lo),
+         block_shape=(1, BLOCK_SIZE_Q),
+         order=(0, 1),
+     )
+     # loop for q blocks
+     for i in range(q_lo, q_len, BLOCK_SIZE_Q):
+         # load
+         q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
+         do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
+         lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
+         d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
+         # compute qk
+         # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
+         qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf"))
+         qk += tl.dot(k, q) * qk_scale
+         # compute p, ds
+         # [BLOCK_SIZE_K, BLOCK_SIZE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
+         p = tl.exp2(qk - lse)
+         # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
+         dp = tl.dot(v, do)
+         ds = sm_scale * p * (dp - d)
+         # cast dtype
+         p = p.to(do.dtype)
+         ds = ds.to(q.dtype)
+         # update dk and dv
+         # [BLOCK_SIZE_K, BLOCK_SIZE_Q] @ [BLOCK_SIZE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM]
+         dk += tl.dot(ds, tl.trans(q))
+         dv += tl.dot(p, tl.trans(do))
+         # increment pointers
+         q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q))
+         do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q))
+         lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q))
+         d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q))
+     # save dk dv
+     tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
+     tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
+
+
+ @triton.jit
+ def backward_dq(
+     q_ptr,  # Q: n x qh x d
+     k_ptr,  # K: n x kh x d
+     v_ptr,  # V: n x kh x d
+     lse_ptr,  # LSE: qh x n
+     d_ptr,  # Delta: qh x n
+     do_ptr,
+     dq_ptr,
+     kernel_size,
+     kernel_stride,
+     # seqlens
+     cu_seqlens_q,
+     cu_seqlens_k,
+     # shape
+     NUM_KV_HEADS,
+     NUM_SHARE_Q_HEADS,
+     HEAD_DIM,
+     # sm_scale
+     sm_scale,
+     # stride
+     stride_qn,
+     stride_qh,
+     stride_qd,
+     stride_kn,
+     stride_kh,
+     stride_kd,
+     stride_vn,
+     stride_vh,
+     stride_vd,
+     stride_lh,
+     stride_ln,
+     stride_dh,
+     stride_dn,
+     stride_don,
+     stride_doh,
+     stride_dod,
+     stride_dqn,
+     stride_dqh,
+     stride_dqd,
+     # META parameters
+     BLOCK_SIZE_Q: tl.constexpr,  # q block size
+     BLOCK_SIZE_K: tl.constexpr,  # k block size
+     BLOCK_SIZE_D: tl.constexpr,
+ ):
+     qk_scale = sm_scale * 1.44269504
+     # get batch id and head id
+     pid_b = tl.program_id(0)
+     pid_h = tl.program_id(1)
+     pid_q = tl.program_id(2)
+     pid_kh = pid_h // NUM_SHARE_Q_HEADS
+     # get q k start and len after rmpad
+     q_start = tl.load(cu_seqlens_q + pid_b)
+     q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+     k_start = tl.load(cu_seqlens_k + pid_b)
+     k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
+     # skip the first kernel_size query positions, because they do not attend to any keys
+     q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
+     if q_start_in_seq >= q_len:
+         return
+     # init pointers
+     q_ptrs = tl.make_block_ptr(
+         base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
+         shape=(q_len, HEAD_DIM),
+         strides=(stride_qn, stride_qd),
+         offsets=(q_start_in_seq, 0),
+         block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+         order=(1, 0),
+     )
+     dq_ptrs = tl.make_block_ptr(
+         base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,
+         shape=(q_len, HEAD_DIM),
+         strides=(stride_dqn, stride_dqd),
+         offsets=(q_start_in_seq, 0),
+         block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+         order=(1, 0),
+     )
+     k_ptrs = tl.make_block_ptr(
+         base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
+         shape=(k_len, HEAD_DIM),
+         strides=(stride_kn, stride_kd),
+         offsets=(0, 0),
+         block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+         order=(1, 0),
+     )
+     v_ptrs = tl.make_block_ptr(
+         base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
+         shape=(HEAD_DIM, k_len),
+         strides=(stride_vd, stride_vn),
+         offsets=(0, 0),
+         block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
+         order=(0, 1),
+     )
+     do_ptrs = tl.make_block_ptr(
+         base=do_ptr + q_start * stride_don + pid_h * stride_doh,
+         shape=(q_len, HEAD_DIM),
+         strides=(stride_don, stride_dod),
+         offsets=(q_start_in_seq, 0),
+         block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+         order=(1, 0),
+     )
+     d_ptrs = tl.make_block_ptr(
+         base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
+         shape=(q_len, 1),
+         strides=(stride_dn, stride_dh),
+         offsets=(q_start_in_seq, 0),
+         block_shape=(BLOCK_SIZE_Q, 1),
+         order=(0, 1),
+     )
+     lse_ptrs = tl.make_block_ptr(
+         base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
+         shape=(q_len, 1),
+         strides=(stride_ln, stride_lh),
+         offsets=(q_start_in_seq, 0),
+         block_shape=(BLOCK_SIZE_Q, 1),
+         order=(0, 1),
+     )
+     # offsets
+     off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
+     off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
+     # load q, do, lse, delta, and keep in SRAM
+     q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
+     do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
+     lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
+     d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
+     # init dq
+     dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32)
+     lo = 0
+     hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
+     for i in range(lo, hi, BLOCK_SIZE_K):
+         # load
+         k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
+         v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
+         # compute qk
+         qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
+         qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf"))
+         qk += tl.dot(q, tl.trans(k)) * qk_scale
+         # compute p, ds
+         p = tl.exp2(qk - lse)
+         dp = tl.dot(do, v)
+         ds = sm_scale * p * (dp - d)
+         # cast dtype
+         ds = ds.to(q.dtype)
+         # update dq
+         dq += tl.dot(ds, k)
+         # increment pointers
+         k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))
+         v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K))
+     # save dq
+     tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
+
+
+ def _compressed_attention_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     kernel_size: int,
+     kernel_stride: int,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     max_seqlen_q: torch.Tensor,
+     max_seqlen_k: torch.Tensor,
+     sm_scale: float,
+ ):
+     # dtype check
+     assert k.dtype == q.dtype and v.dtype == q.dtype
+     assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
+     # shape
+     q_len, num_q_heads, head_dim = q.shape
+     k_len, num_k_heads, head_dim = k.shape
+     v_len, num_v_heads, head_dim = v.shape
+     batch_size = cu_seqlens_q.shape[0] - 1
+     assert k_len == v_len and q_len > k_len
+     # gqa
+     assert num_k_heads == num_v_heads
+     assert num_q_heads % num_k_heads == 0
+     num_share_q_heads = num_q_heads // num_k_heads
+     # output tensor
+     o = torch.zeros_like(q)
+     lse = torch.full(
+         (num_q_heads, q_len),
+         fill_value=-torch.inf,
+         dtype=torch.float32,
+         device=q.device,
+     )
+     # launch kernel
+     grid = lambda META: (
+         batch_size,
+         num_q_heads,
+         triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
+     )
+     BLOCK_SIZE_Q = 128
+     BLOCK_SIZE_K = 128
+     BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+     num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
+     forward_kernel[grid](
+         q,
+         k,
+         v,
+         o,
+         lse,
+         kernel_size,
+         kernel_stride,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         num_k_heads,
+         num_share_q_heads,
+         head_dim,
+         sm_scale,
+         q.stride(0),
+         q.stride(1),
+         q.stride(2),
+         k.stride(0),
+         k.stride(1),
+         k.stride(2),
+         v.stride(0),
+         v.stride(1),
+         v.stride(2),
+         o.stride(0),
+         o.stride(1),
+         o.stride(2),
+         lse.stride(0),
+         lse.stride(1),
+         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
+         BLOCK_SIZE_K=BLOCK_SIZE_K,
+         BLOCK_SIZE_D=BLOCK_SIZE_D,
+         num_warps=num_warps,
+         num_stages=num_stages,
+     )
+     return o, lse
+
+
+ def _compressed_attention_bwd(
+     o: torch.Tensor,
+     do: torch.Tensor,
+     lse: torch.Tensor,
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     kernel_size: int,
+     kernel_stride: int,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     max_seqlen_q: torch.Tensor,
+     max_seqlen_k: torch.Tensor,
+     sm_scale: float,
+ ):
+     q_len, num_q_heads, head_dim = q.shape
+     k_len, num_k_heads, head_dim = k.shape
+     v_len, num_v_heads, head_dim = v.shape
+     o_len, num_o_heads, head_dim = o.shape
+     num_share_q_heads = num_q_heads // num_k_heads
+     # compute D
+     delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
+     grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads)
+     BLOCK_SIZE_O = 256
+     BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+     num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)
+     backward_sum_o_do[grid](
+         o,
+         do,
+         delta,
+         o_len,
+         head_dim,
+         o.stride(0),
+         o.stride(1),
+         o.stride(2),
+         do.stride(0),
+         do.stride(1),
+         do.stride(2),
+         delta.stride(0),
+         delta.stride(1),
+         BLOCK_SIZE_O=BLOCK_SIZE_O,
+         BLOCK_SIZE_D=BLOCK_SIZE_D,
+         num_warps=num_warps,
+         num_stages=num_stages,
+     )
+     # compute dk dv
+     dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
+     dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
+     batch_size = cu_seqlens_q.shape[0] - 1
+     grid = lambda META: (
+         batch_size,
+         num_q_heads,
+         triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
+     )
+     BLOCK_SIZE_Q = 64
+     BLOCK_SIZE_K = 128
+     BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+     num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)
+     backward_dkdv[grid](
+         q,
+         k,
+         v,
+         lse,
+         delta,
+         do,
+         dk,
+         dv,
+         kernel_size,
+         kernel_stride,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         num_k_heads,
+         num_share_q_heads,
+         head_dim,
+         sm_scale,
+         q.stride(0),
+         q.stride(1),
+         q.stride(2),
+         k.stride(0),
+         k.stride(1),
+         k.stride(2),
+         v.stride(0),
+         v.stride(1),
+         v.stride(2),
+         lse.stride(0),
+         lse.stride(1),
+         delta.stride(0),
+         delta.stride(1),
+         do.stride(0),
+         do.stride(1),
+         do.stride(2),
+         dk.stride(0),
+         dk.stride(1),
+         dk.stride(2),
+         dk.stride(3),
+         dv.stride(0),
+         dv.stride(1),
+         dv.stride(2),
+         dv.stride(3),
+         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
+         BLOCK_SIZE_K=BLOCK_SIZE_K,
+         BLOCK_SIZE_D=BLOCK_SIZE_D,
+         num_warps=num_warps,
+         num_stages=num_stages,
+     )
+     dk = dk.sum(0)
+     dv = dv.sum(0)
+     # compute dq
+     dq = torch.zeros_like(q)
+     grid = lambda META: (
+         batch_size,
+         num_q_heads,
+         triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
+     )
+     BLOCK_SIZE_Q = 128
+     BLOCK_SIZE_K = 64
+     num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
+     backward_dq[grid](
+         q,
+         k,
+         v,
+         lse,
+         delta,
+         do,
+         dq,
+         kernel_size,
+         kernel_stride,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         num_k_heads,
+         num_share_q_heads,
+         head_dim,
+         sm_scale,
+         q.stride(0),
+         q.stride(1),
+         q.stride(2),
+         k.stride(0),
+         k.stride(1),
+         k.stride(2),
+         v.stride(0),
+         v.stride(1),
+         v.stride(2),
+         lse.stride(0),
+         lse.stride(1),
+         delta.stride(0),
+         delta.stride(1),
+         do.stride(0),
+         do.stride(1),
+         do.stride(2),
+         dq.stride(0),
+         dq.stride(1),
+         dq.stride(2),
+         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
+         BLOCK_SIZE_K=BLOCK_SIZE_K,
+         BLOCK_SIZE_D=BLOCK_SIZE_D,
+         num_warps=num_warps,
+         num_stages=num_stages,
+     )
+     return dq, dk, dv
+
+
+ class CompressedAttention(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         kernel_size: int,
+         kernel_stride: int,
+         cu_seqlens_q: torch.Tensor,
+         cu_seqlens_k: torch.Tensor,
+         max_seqlen_q: torch.Tensor,
+         max_seqlen_k: torch.Tensor,
+         sm_scale=None,
+     ):
+         # dtype check
+         assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
+         assert q.dtype == k.dtype and k.dtype == v.dtype
+         assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
+         # softmax scale
+         if sm_scale is None:
+             sm_scale = 1 / math.sqrt(q.shape[-1])
+
+         o, lse = _compressed_attention_fwd(
+             q,
+             k,
+             v,
+             kernel_size,
+             kernel_stride,
+             cu_seqlens_q,
+             cu_seqlens_k,
+             max_seqlen_q,
+             max_seqlen_k,
+             sm_scale,
+         )
+         ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)
+         ctx.sm_scale = sm_scale
+         ctx.max_seqlen_q = max_seqlen_q
+         ctx.max_seqlen_k = max_seqlen_k
+         ctx.kernel_size = kernel_size
+         ctx.kernel_stride = kernel_stride
+         return o, lse
+
+     @staticmethod
+     def backward(ctx, do: torch.Tensor, *args) -> Any:
+         q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+         max_seqlen_q = ctx.max_seqlen_q
+         max_seqlen_k = ctx.max_seqlen_k
+         sm_scale = ctx.sm_scale
+         kernel_size = ctx.kernel_size
+         kernel_stride = ctx.kernel_stride
+
+         dq, dk, dv = _compressed_attention_bwd(
+             o,
+             do,
+             lse,
+             q,
+             k,
+             v,
+             kernel_size,
+             kernel_stride,
+             cu_seqlens_q,
+             cu_seqlens_k,
+             max_seqlen_q,
+             max_seqlen_k,
+             sm_scale,
+         )
+         return dq, dk, dv, None, None, None, None, None, None, None
+
+
+ @triton.jit
+ def score_kernel(
+     q_ptr,
+     k_ptr,
+     lse_ptr,
+     s_ptr,
+     kernel_size,
+     kernel_stride,
+     # seqlens
+     cu_seqlens_q,
+     cu_seqlens_k,
+     # shape
+     NUM_KV_HEADS,
+     NUM_SHARE_Q_HEADS,
+     HEAD_DIM,
+     # sm_scale
+     sm_scale,
+     # stride
+     stride_qn,
+     stride_qh,
+     stride_qd,
+     stride_kn,
+     stride_kh,
+     stride_kd,
+     stride_lh,
+     stride_ln,
+     stride_sh,
+     stride_sq,
+     stride_sk,
+     # META parameters
+     BLOCK_SIZE_Q: tl.constexpr,  # q block size
+     BLOCK_SIZE_K: tl.constexpr,  # k block size
+     BLOCK_SIZE_D: tl.constexpr,
+ ):
+     qk_scale = sm_scale * 1.44269504
+     # get batch id and head id
+     pid_bkh = tl.program_id(0)
+     pid_b = pid_bkh // NUM_KV_HEADS
+     pid_kh = pid_bkh % NUM_KV_HEADS
+     pid_q = tl.program_id(1)
+     pid_k = tl.program_id(2)
+     # get q k start and len after rmpad
+     q_start = tl.load(cu_seqlens_q + pid_b)
+     q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+     k_start = tl.load(cu_seqlens_k + pid_b)
+     k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
+     if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:
+         return
+     # init k pointer and load k
+     k_ptrs = tl.make_block_ptr(
+         base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
+         shape=(HEAD_DIM, k_len),
+         strides=(stride_kd, stride_kn),
+         offsets=(0, pid_k * BLOCK_SIZE_K),
+         block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
+         order=(0, 1),
+     )
+     k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
+     # offsets
+     off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
+     off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
+     causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :]
+     # init score
+     s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
+
+     q_ptrs = tl.make_block_ptr(
+         base=q_ptr + q_start * stride_qn + pid_kh * stride_qh,
+         shape=(q_len, HEAD_DIM),
+         strides=(stride_qn, stride_qd),
+         offsets=(pid_q * BLOCK_SIZE_Q, 0),
+         block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+         order=(1, 0),
+     )
+     lse_ptrs = tl.make_block_ptr(
+         base=lse_ptr + q_start * stride_ln + pid_kh * stride_lh,
+         shape=(q_len, 1),
+         strides=(stride_ln, stride_lh),
+         offsets=(pid_q * BLOCK_SIZE_Q, 0),
+         block_shape=(BLOCK_SIZE_Q, 1),
+         order=(0, 1),
+     )
+     # load q and lse
+     q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
+     lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
+     # compute qk
+     qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
+     qk += tl.dot(q, k) * qk_scale
+     # compute score
+     s += tl.where(causal_mask, tl.exp2(qk - lse), 0)
+     # save output
+     s_ptrs = tl.make_block_ptr(
+         base=s_ptr + pid_kh * stride_sh + q_start * stride_sq,
+         shape=(q_len, k_len),
+         strides=(stride_sq, stride_sk),
+         offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K),
+         block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K),
+         order=(1, 0),
+     )
+     tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1))
+
+
+ def _get_attention_score(
+     q: torch.Tensor,  # [total_query_len, num_q_heads, head_dim]
+     k: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
+     lse: torch.Tensor,  # [num_q_heads, total_query_len]
+     kernel_size: int,
+     kernel_stride: int,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     max_seqlen_q: int,
+     max_seqlen_k: int,
+     sm_scale: float,
+ ) -> torch.Tensor:
+     # dtype check
+     assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
+     assert q.dtype == k.dtype
+     assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
+     assert lse.dtype == torch.float32  # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale)))
+     # shape
+     q_len, num_q_heads, head_dim = q.shape
+     k_len, num_k_heads, head_dim = k.shape
+     batch_size = cu_seqlens_q.shape[0] - 1
+     assert q_len > k_len
+     if sm_scale is None:
+         sm_scale = 1 / math.sqrt(head_dim)
+     # gqa
+     assert num_q_heads % num_k_heads == 0
+     num_share_q_heads = num_q_heads // num_k_heads
+     # init score
+     score = torch.zeros(num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device)
+
+     # launch kernel
+     grid = lambda META: (
+         batch_size * num_k_heads,
+         triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
+         triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
+     )
+     BLOCK_SIZE_Q = 128
+     BLOCK_SIZE_K = 128
+     BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+
+     score_kernel[grid](
+         q,
+         k,
+         lse,
+         score,
+         kernel_size,
+         kernel_stride,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         num_k_heads,
+         num_share_q_heads,
+         head_dim,
+         sm_scale,
+         q.stride(0),
+         q.stride(1),
+         q.stride(2),
+         k.stride(0),
+         k.stride(1),
+         k.stride(2),
+         lse.stride(0),
+         lse.stride(1),
+         score.stride(0),
+         score.stride(1),
+         score.stride(2),
+         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
+         BLOCK_SIZE_K=BLOCK_SIZE_K,
+         BLOCK_SIZE_D=BLOCK_SIZE_D,
+         num_warps=8,
+         num_stages=3,
+     )
+     return score
+
+
+ @triton.jit
+ def _transform_score_kernel(
+     s_ptr,  # score, shape: [num_heads, q_len, k_len]
+     bs_ptr,  # block wise score: [num_heads, q_len, num_k_block]
+     offs,
+     cu_seqlens_q,
+     # shape
+     num_heads,
+     num_offs,
+     max_k_len,
+     max_blocks,
+     pad_len,
+     # kernel & block size
+     block_size,
+     block_stride,  # block_size // kernel_stride
+     init_blocks,
+     local_blocks,
+     # stride
+     stride_sh,
+     stride_sq,
+     stride_sk,
+     stride_bsh,
+     stride_bsq,
+     stride_bsk,
+     TOTAL_QUERY_LEN: tl.constexpr,
+     BLOCK_SIZE_Q: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     BLOCK_SIZE_O: tl.constexpr,
+ ):
+     pid_bh = tl.program_id(0)
+     pid_b = pid_bh // num_heads
+     pid_h = pid_bh % num_heads
+     pid_q = tl.program_id(1)
+     pid_k = tl.program_id(2)
+     q_start = tl.load(cu_seqlens_q + pid_b)
+     q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+     k_start = pid_k * BLOCK_SIZE_K
+     if pid_q * BLOCK_SIZE_Q >= q_len:
+         return
+     # load weight
+     off_o = tl.arange(0, BLOCK_SIZE_O)
+     w = tl.load(offs + off_o, mask=off_o < num_offs, other=0)
+     # load score
+     off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
+     off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len
+     off_k = off_k[None, :] + off_o[:, None]
+     s_ptrs = (
+         s_ptr
+         + q_start * stride_sq
+         + pid_h * stride_sh
+         + off_q[:, None, None] * stride_sq
+         + off_k[None, :, :] * stride_sk
+     )
+     # weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK]
+     s = tl.load(
+         s_ptrs,
+         mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len),
+         other=0,
+     )
+     s = s * w[None, :, None]
+     s = tl.sum(s, axis=1)
+     # init mask and local mask
+     off_bq = off_q // block_size
+     off_bk = k_start + tl.arange(0, BLOCK_SIZE_K)
+     s = tl.where(
+         ((off_bq[:, None] >= off_bk[None, :]) & (off_bq[:, None] < off_bk[None, :] + local_blocks))
+         | (off_bk[None, :] < init_blocks - k_start),
+         float("inf"),
+         s,
+     )
+     # store block wise score
+     bs_ptrs = (
+         bs_ptr + q_start * stride_bsq + pid_h * stride_bsh + off_q[:, None] * stride_bsq + off_bk[None, :] * stride_bsk
+     )
+     tl.store(
+         bs_ptrs,
+         s,
+         mask=(off_q < q_len)[:, None] & (off_bk < max_blocks)[None, :],
+     )
+
+
+ def transform_score(
+     score: torch.Tensor,
+     kernel_size: int,
+     kernel_stride: int,
+     block_size: int,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     max_seqlen_q: int,
+     max_seqlen_k: int,
+     init_blocks: int = 1,
+     local_blocks: int = 2,
+ ) -> torch.Tensor:
+     num_k_heads, total_query_len, max_key_len = score.shape
+     batch_size = cu_seqlens_q.shape[0] - 1
+     pad_len = kernel_size // kernel_stride - 1
+     max_blocks = math.ceil(max_seqlen_q / block_size)
+     block_score = torch.zeros(
+         num_k_heads,
+         total_query_len,
+         max_blocks,
+         dtype=torch.float32,
+         device=score.device,
+     )
+     offs = (
+         torch.arange(kernel_size // kernel_stride, device=score.device)[:, None]
+         + torch.arange(block_size // kernel_stride, device=score.device)[None, :]
+     ).view(-1)
+
+     offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max())
+
+     num_offs = int(offs.shape[0])
+
+     BLOCK_SIZE_Q = 16
+     BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks))
+     BLOCK_SIZE_O = triton.next_power_of_2(num_offs)
+
+     def grid(meta):
+         grid = (
+             num_k_heads * batch_size,
+             triton.cdiv(total_query_len, BLOCK_SIZE_Q),
+             triton.cdiv(max_blocks, BLOCK_SIZE_K),
+         )
+         return grid
+
+     _transform_score_kernel[grid](
+         score,
+         block_score,
+         offs,
+         cu_seqlens_q,
+         num_k_heads,
+         offs.shape[0],
+         max_key_len,
+         max_blocks,
+         pad_len,
+         block_size,
+         block_size // kernel_stride,
+         init_blocks,
+         local_blocks,
+         score.stride(0),
+         score.stride(1),
+         score.stride(2),
+         block_score.stride(0),
+         block_score.stride(1),
+         block_score.stride(2),
+         TOTAL_QUERY_LEN=total_query_len,
+         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
+         BLOCK_SIZE_K=BLOCK_SIZE_K,
+         BLOCK_SIZE_O=BLOCK_SIZE_O,
+         num_warps=4,
+         num_stages=3,
+     )
+     return block_score
+
+
+ def compressed_attention(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     kernel_size: int,
+     kernel_stride: int,
+     block_size: int,
+     topk: int,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     max_seqlen_q: int,
+     max_seqlen_k: int,
+     sm_scale: float = None,
+     init_blocks: int = 1,
+     local_blocks: int = 2,
+     parallel_topk_compute: Union[str, bool] = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Attention between query and compressed key and value. Computes the attention output and the topk block indices used in topk_sparse_attention.
+
+     Args:
+         q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
+         k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
+         v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
+         kernel_size (int): kernel size in compress_key_value
+         kernel_stride (int): stride of compress_key_value
+         block_size (int): key value block size for topk sparse attention.
+         topk (int): number of blocks for each query.
+         cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
+         cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
+         max_seqlen_q (int): max q len of the batch.
+         max_seqlen_k (int): max k len of the batch.
+         sm_scale (float, optional): softmax scale. Defaults to None, meaning 1/sqrt(head_dim).
+         init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
+         local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
+         parallel_topk_compute (Union[str, bool], optional): whether to compute the topk indices for all heads in parallel.
+             Keep it False when the sequence length is very long; this avoids a current bug that will be fixed later.
+             With "auto", the parallel path is used only when the total sequence length is at most 32k. Defaults to False.
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
+     """
+
+     if max_seqlen_q is None:
+         max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
+     if max_seqlen_k is None:
+         max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
+
+     attn_output, lse = CompressedAttention.apply(
+         q,
+         k,
+         v,
+         kernel_size,
+         kernel_stride,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         max_seqlen_q,
+         max_seqlen_k,
+         sm_scale,
+     )
+
+     # do not select topk index
+     if topk <= 0:
+         warnings.warn("topk <= 0, returned topk_idx will be None")
+         return attn_output, None
+
+     assert topk >= init_blocks + local_blocks
+     with torch.no_grad():
+         num_k_heads, num_q_heads = k.shape[1], q.shape[1]
+         num_shared_q_heads = num_q_heads // num_k_heads
+         batch_size = cu_seqlens_q.shape[0] - 1
+         q_idx = torch.cat(
+             [torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) for i in range(batch_size)],
+             dim=0,
+         )
+         q_idx = q_idx // block_size
+
+         # whether to use parallel version
+         if parallel_topk_compute == "auto":
+             parallel_topk_compute = cu_seqlens_q[-1] <= 32768
+         # parallel version
+         if parallel_topk_compute:
+             # recompute score
+             score = _get_attention_score(
+                 q,
+                 k,
+                 lse,
+                 kernel_size,
+                 kernel_stride,
+                 cu_seqlens_q,
+                 cu_seqlens_k,
+                 max_seqlen_q,
+                 max_seqlen_k,
+                 sm_scale,
+             )
+             # transform score to block-wise score
+             score = transform_score(
+                 score,
+                 kernel_size,
+                 kernel_stride,
+                 block_size,
+                 cu_seqlens_q,
+                 cu_seqlens_k,
+                 max_seqlen_q,
+                 max_seqlen_k,
+                 init_blocks,
+                 local_blocks,
+             )
+             # get topk
+             topk = min(topk, score.shape[-1])
+             topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
+             topk_idx[topk_idx > q_idx[None, :, None]] = -1
+             topk_idx = topk_idx.to(torch.int32)
+         # non-parallel version, avoids a current bug when the sequence length is too long
+         # FIXME: need to fix later
+         else:
+             topk_idx_list = []
+             head_tile = 1
+             assert num_k_heads % head_tile == 0, f"Num kv heads: {num_k_heads}, head_tile: {head_tile}"
+             for h in range(num_k_heads // head_tile):
+                 # recompute score
+                 score = _get_attention_score(
+                     q[:, h * num_shared_q_heads * head_tile: (h + 1) * num_shared_q_heads * head_tile],
+                     k[:, h * head_tile: (h + 1) * head_tile],
+                     lse[h * num_shared_q_heads * head_tile: (h + 1) * num_shared_q_heads * head_tile],
+                     kernel_size,
+                     kernel_stride,
+                     cu_seqlens_q,
+                     cu_seqlens_k,
+                     max_seqlen_q,
+                     max_seqlen_k,
+                     sm_scale,
+                 )
+                 # transform score to block-wise score
+                 score = transform_score(
+                     score,
+                     kernel_size,
+                     kernel_stride,
+                     block_size,
+                     cu_seqlens_q,
+                     cu_seqlens_k,
+                     max_seqlen_q,
+                     max_seqlen_k,
+                     init_blocks,
+                     local_blocks,
+                 )
+                 # get topk
+                 topk = min(topk, score.shape[-1])
+                 if score.dtype == torch.float32:
+                     score = score.to(torch.bfloat16)
+                 topk_idx = score.topk(topk, dim=-1, sorted=False).indices
+                 topk_idx = topk_idx.sort(-1).values
+
+                 topk_idx[topk_idx > q_idx[None, :, None]] = -1
+                 topk_idx = topk_idx.to(torch.int32)
+                 topk_idx_list.append(topk_idx)
+             topk_idx = torch.cat(topk_idx_list, dim=0)
+
+     return attn_output, topk_idx
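
The `compressed_attention` entry point above takes flattened (varlen) tensors plus `cu_seqlens` offsets, and expects `k`/`v` to already be the compressed sequences. A minimal calling sketch follows; it assumes a CUDA device with Triton available, and every shape and hyperparameter below is illustrative only, not a value taken from this commit.

import torch
from ops.compressed_attention import compressed_attention

kernel_size, kernel_stride, block_size, topk = 32, 16, 64, 16
seqlen, n_heads, n_kv_heads, head_dim = 1024, 32, 2, 128
# number of compressed positions produced by a kernel of this size/stride
n_cmp = (seqlen - kernel_size) // kernel_stride + 1

q = torch.randn(seqlen, n_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(n_cmp, n_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn_like(k)
cu_seqlens_q = torch.tensor([0, seqlen], device="cuda", dtype=torch.int32)  # single sequence
cu_seqlens_k = torch.tensor([0, n_cmp], device="cuda", dtype=torch.int32)

attn_output, topk_idx = compressed_attention(
    q, k, v, kernel_size, kernel_stride, block_size, topk,
    cu_seqlens_q, cu_seqlens_k, seqlen, n_cmp,
)
# attn_output: [seqlen, n_heads, head_dim]
# topk_idx:    [n_kv_heads, seqlen, topk], block indices fed to topk_sparse_attention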
ops/pooling.py ADDED
@@ -0,0 +1,207 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from fla.ops.utils.index import prepare_chunk_indices
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
+
+
+ @triton.heuristics({
+     'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BD': BD}, num_warps=num_warps)
+         for BD in [16, 32, 64, 128]
+         for num_warps in [1, 2, 4, 8]
+     ],
+     key=['BT']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def mean_pooling_fwd_kernel(
+     x,
+     o,
+     cu_seqlens,
+     chunk_indices,
+     T,
+     H: tl.constexpr,
+     D: tl.constexpr,
+     BT: tl.constexpr,
+     BD: tl.constexpr,
+     IS_VARLEN: tl.constexpr
+ ):
+     i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     i_b, i_h = i_bh // H, i_bh % H
+     if IS_VARLEN:
+         i_tg = i_t
+         i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+         T = eos - bos
+         NT = tl.cdiv(T, BT)
+     else:
+         NT = tl.cdiv(T, BT)
+         i_tg = i_b * NT + i_t
+         bos, eos = i_b * T, i_b * T + T
+
+     p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
+     p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
+     # [BT, BD]
+     b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
+     # [BD]
+     b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT)
+     tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
+
+
+ @triton.heuristics({
+     'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BD': BD}, num_warps=num_warps)
+         for BD in [16, 32, 64, 128]
+         for num_warps in [1, 2, 4, 8]
+     ],
+     key=['BT']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def mean_pooling_bwd_kernel(
+     do,
+     dx,
+     cu_seqlens,
+     chunk_indices,
+     T,
+     H: tl.constexpr,
+     D: tl.constexpr,
+     BT: tl.constexpr,
+     BD: tl.constexpr,
+     IS_VARLEN: tl.constexpr
+ ):
+     i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     i_b, i_h = i_bh // H, i_bh % H
+     if IS_VARLEN:
+         i_tg = i_t
+         i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+         T = eos - bos
+         NT = tl.cdiv(T, BT)
+     else:
+         NT = tl.cdiv(T, BT)
+         i_tg = i_b * NT + i_t
+         bos, eos = i_b * T, i_b * T + T
+
+     p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
+     p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
+     # [BD]
+     b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32)
+     # [BT, BD]
+     b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None]
+     tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
+
+
+ def mean_pooling_fwd(
+     x: torch.Tensor,
+     chunk_size: int,
+     cu_seqlens: Optional[torch.LongTensor] = None
+ ) -> torch.Tensor:
+     B, T, H, D = x.shape
+     BT = chunk_size
+     chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+
+     o = x.new_empty(B, NT, H, D)
+     def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H)
+     mean_pooling_fwd_kernel[grid](
+         x,
+         o,
+         cu_seqlens,
+         chunk_indices,
+         T=T,
+         H=H,
+         D=D,
+         BT=BT,
+     )
+     return o
+
+
+ def mean_pooling_bwd(
+     do: torch.Tensor,
+     batch_size: int,
+     seq_len: int,
+     chunk_size: int,
+     cu_seqlens: Optional[torch.LongTensor] = None
+ ) -> torch.Tensor:
+     B, T, H, D = batch_size, seq_len, *do.shape[-2:]
+     BT = chunk_size
+     chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+
+     dx = do.new_empty(B, T, H, D)
+     def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H)
+     mean_pooling_bwd_kernel[grid](
+         do,
+         dx,
+         cu_seqlens,
+         chunk_indices,
+         T=T,
+         H=H,
+         D=D,
+         BT=BT,
+     )
+     return dx
+
+
+ class MeanPoolingFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_fwd
+     def forward(
+         ctx,
+         x: torch.Tensor,
+         chunk_size: int,
+         cu_seqlens: Optional[torch.LongTensor] = None
+     ) -> torch.Tensor:
+         o = mean_pooling_fwd(x, chunk_size, cu_seqlens)
+         ctx.batch_size = x.shape[0]
+         ctx.seq_len = x.shape[1]
+         ctx.chunk_size = chunk_size
+         ctx.cu_seqlens = cu_seqlens
+         return o
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_bwd
+     def backward(
+         ctx, do
+     ) -> Tuple[torch.Tensor, None, None]:
+         batch_size = ctx.batch_size
+         seq_len = ctx.seq_len
+         chunk_size = ctx.chunk_size
+         cu_seqlens = ctx.cu_seqlens
+         dx = mean_pooling_bwd(do, batch_size, seq_len, chunk_size, cu_seqlens)
+         return dx, None, None
+
+
+ def mean_pooling(
+     x: torch.Tensor,
+     chunk_size: int,
+     cu_seqlens: Optional[torch.LongTensor] = None,
+     head_first: bool = False
+ ) -> torch.Tensor:
+     if head_first:
+         x = x.transpose(1, 2)
+     if cu_seqlens is not None:
+         if x.shape[0] != 1:
+             raise ValueError(
+                 f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`. "
+                 f"Please flatten variable-length inputs before processing."
+             )
+     o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens)
+     if head_first:
+         o = o.transpose(1, 2)
+     return o
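
`mean_pooling` reduces each window of `chunk_size` timesteps to one vector per head, so a `[batch, seq, heads, dim]` input becomes `[batch, ceil(seq / chunk_size), heads, dim]`. An illustrative call, assuming the `fla` package is installed and a CUDA device is available; the shapes are made up:

import torch
from ops.pooling import mean_pooling

x = torch.randn(2, 256, 4, 64, device="cuda")  # [batch, seq, heads, dim]
o = mean_pooling(x, chunk_size=64)             # -> [2, 4, 4, 64], one vector per 64-step chunk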
ops/topk_sparse_attention.py ADDED
@@ -0,0 +1,1213 @@
1
+ # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Any, Optional
16
+
17
+ import torch
18
+ import triton
19
+ import triton.language as tl
20
+
21
+ try:
22
+ from ops.utils import get_num_warps_stages, is_hopper_gpu
23
+ except ImportError:
24
+ from .utils import get_num_warps_stages, is_hopper_gpu
25
+
26
+ IS_HOPPER_GPU = is_hopper_gpu()
27
+
28
+
29
+ @triton.jit
30
+ def forward_kernel_orig(
31
+ q_ptr, # Q: n x h x d
32
+ k_ptr, # K: n x kh x d
33
+ v_ptr, # V: n x kh x d
34
+ t_ptr, # topk_idx: kh x n x k
35
+ o_ptr, # O: n x h x d
36
+ lse_ptr, # LSE: h x n
37
+ # seqlens
38
+ cu_seqlens_q,
39
+ cu_seqlens_k,
40
+ # shape
41
+ NUM_KV_HEADS,
42
+ NUM_SHARE_Q_HEADS,
43
+ HEAD_DIM,
44
+ TOPK,
45
+ block_size,
46
+ # sm_scale
47
+ sm_scale,
48
+ # stride
49
+ stride_qn,
50
+ stride_qh,
51
+ stride_qd,
52
+ stride_kn,
53
+ stride_kh,
54
+ stride_kd,
55
+ stride_vn,
56
+ stride_vh,
57
+ stride_vd,
58
+ stride_th,
59
+ stride_tn,
60
+ stride_tk,
61
+ stride_on,
62
+ stride_oh,
63
+ stride_od,
64
+ stride_lh,
65
+ stride_ln,
66
+ # META parameters
67
+ # q loop num
68
+ num_q_loop: tl.constexpr,
69
+ num_k_loop: tl.constexpr,
70
+ MAX_SEQ_LEN: tl.constexpr,
71
+ BLOCK_SIZE_K: tl.constexpr, # k block size
72
+ BLOCK_SIZE_D: tl.constexpr,
73
+ BLOCK_SIZE_H: tl.constexpr,
74
+ BLOCK_SIZE_T: tl.constexpr,
75
+ ):
76
+ qk_scale = sm_scale * 1.44269504
77
+ # get batch id and head id
78
+ pid = tl.program_id(0)
79
+
80
+ # ceil-div so these match the triton.cdiv terms in the launch grid
+ Q = (MAX_SEQ_LEN + num_q_loop - 1) // num_q_loop
81
+ HK = (NUM_KV_HEADS + num_k_loop - 1) // num_k_loop
82
+
83
+ # which (b, kh_chunk, q_chunk) this program instance handles
84
+ pid_b = pid // (HK * Q)
85
+ pid_kh_chunk = (pid % (HK * Q)) // Q # each program processes num_k_loop KV heads
86
+ pid_q = pid % Q
87
+
88
+ # get q k start and len after rmpad
89
+ q_start = tl.load(cu_seqlens_q + pid_b)
90
+ q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
91
+ k_start = tl.load(cu_seqlens_k + pid_b)
92
+ k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
93
+
94
+ if pid_q * num_q_loop >= q_len:
95
+ return
96
+ real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop)
97
+
98
+ for kh_offset in range(num_k_loop):
99
+ pid_kh = pid_kh_chunk * num_k_loop + kh_offset
100
+ pid_h = pid_kh * NUM_SHARE_Q_HEADS
101
+
102
+ for j in range(real_q_loop):
103
+ pid_q_j = pid_q * num_q_loop + j
104
+ # init topk idx pointer
105
+ off_t = tl.arange(0, BLOCK_SIZE_T)
106
+ t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th
107
+ topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)
108
+
109
+ """Removed causal attention, which should be:
110
+ real_topk = tl.sum(
111
+ tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // block_size), 1, 0),
112
+ axis=0,
113
+ )
114
+ """
115
+ # real_topk = tl.sum(
116
+ # tl.where((topk_idx >= 0), 1, 0),
117
+ # axis=0,
118
+ # )
119
+ real_topk = tl.sum(
120
+ tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // block_size), 1, 0),
121
+ axis=0,
122
+ )
123
+ # init qkv pointer
124
+ q_ptrs = tl.make_block_ptr(
125
+ base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,
126
+ shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
127
+ strides=(stride_qh, stride_qd),
128
+ offsets=(0, 0),
129
+ block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
130
+ order=(1, 0),
131
+ )
132
+ k_ptrs = tl.make_block_ptr(
133
+ base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
134
+ shape=(HEAD_DIM, k_len),
135
+ strides=(stride_kd, stride_kn),
136
+ offsets=(0, 0),
137
+ block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
138
+ order=(0, 1),
139
+ )
140
+ v_ptrs = tl.make_block_ptr(
141
+ base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
142
+ shape=(k_len, HEAD_DIM),
143
+ strides=(stride_vn, stride_vd),
144
+ offsets=(0, 0),
145
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
146
+ order=(1, 0),
147
+ )
148
+ # load q
149
+ q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
150
+ # init statistics
151
+ off_h = tl.arange(0, BLOCK_SIZE_H)
152
+ off_k = tl.arange(0, BLOCK_SIZE_K)
153
+ m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32)
154
+ lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32)
155
+ acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32)
156
+ # sparse attention
157
+ for i in range(real_topk):
158
+ # get current block start index
159
+ c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K
160
+ t_ptr_j = t_ptr_j + stride_tk
161
+ # load k
162
+ k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero")
163
+ # compute qk
164
+ qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)
165
+ qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf"))
166
+ # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K]
167
+ qk += tl.dot(q, k) * qk_scale
168
+ # compute m_ij and l_ij
169
+ m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
170
+ p = tl.exp2(qk - m_ij[:, None])
171
+ l_ij = tl.sum(p, axis=1)
172
+ # scale acc_o
173
+ acc_o_scale = tl.exp2(m_i - m_ij)
174
+ acc_o = acc_o * acc_o_scale[:, None]
175
+ # load v and update acc_o
176
+ v = tl.load(tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero")
177
+ p = p.to(v.dtype)
178
+ acc_o += tl.dot(p, v)
179
+ # update statistics
180
+ m_i = m_ij
181
+ lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
182
+
183
+ # final scale
184
+ acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
185
+ # save output
186
+ o_ptrs = tl.make_block_ptr(
187
+ base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh,
188
+ shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
189
+ strides=(stride_oh, stride_od),
190
+ offsets=(0, 0),
191
+ block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
192
+ order=(1, 0),
193
+ )
194
+ tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
195
+ # save lse
196
+ lse_ptrs = lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh
197
+ tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS)
198
+
199
+
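For readers following the arithmetic in forward_kernel_orig: the kernel evaluates softmax in base 2 (exp2/log2 map to faster GPU instructions than exp/log), which is why sm_scale is folded with log2(e) ≈ 1.44269504 into qk_scale. The running statistics implement the standard online-softmax recurrence over the selected key blocks:

    m_new   = max(m_i, max_j qk_j)                        # running row maximum
    p_j     = 2^(qk_j - m_new)                            # rescaled probabilities
    acc_o   = acc_o * 2^(m_i - m_new) + sum_j p_j * v_j   # rescale, then accumulate
    lse_new = m_new + log2(2^(lse_i - m_new) + sum_j p_j) # running log-sum-exp

and the final multiply by 2^(m_i - lse_i) normalizes the accumulator by the total base-2 log-sum-exp.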
200
+ @triton.jit
201
+ def backward_sum_o_do(
202
+ o_ptr, # O: n x h x d
203
+ do_ptr, # dO: n x h x d
204
+ delta_ptr, # D: h x n
205
+ o_len,
206
+ HEAD_DIM,
207
+ stride_on,
208
+ stride_oh,
209
+ stride_od,
210
+ stride_don,
211
+ stride_doh,
212
+ stride_dod,
213
+ stride_dh,
214
+ stride_dn,
215
+ BLOCK_SIZE_O: tl.constexpr,
216
+ BLOCK_SIZE_D: tl.constexpr,
217
+ ):
218
+ pid_n = tl.program_id(0)
219
+ pid_h = tl.program_id(1)
220
+ off_o = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
221
+ off_d = tl.arange(0, BLOCK_SIZE_D)
222
+ o = tl.load(
223
+ o_ptr + off_o[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od,
224
+ mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
225
+ other=0,
226
+ ).to(tl.float32)
227
+ do = tl.load(
228
+ do_ptr + off_o[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod,
229
+ mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
230
+ other=0,
231
+ ).to(tl.float32)
232
+ delta = tl.sum(o * do, axis=1)
233
+ tl.store(delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len)
234
+
235
+
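backward_sum_o_do precomputes the per-row term D = rowsum(O ∘ dO) used by both backward kernels below; with P = softmax(QK^T · sm_scale) it gives the usual FlashAttention-style gradient of the scores,

    D_i = sum_d O[i, d] * dO[i, d]
    dS  = sm_scale * P ∘ (dO V^T - D)

which is exactly the `ds = sm_scale * p * (dp - d)` line in backward_dkdv and backward_dq.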
236
+ @triton.jit
237
+ def count_kernel(
238
+ x_ptr, # [num_kv_heads, total_len, topk]
239
+ y_ptr, # [num_kv_heads, total_blocks]
240
+ cu_seqlens, # [batch_size + 1]
241
+ cu_seqblocks, # [batch_size + 1]
242
+ topk,
243
+ stride_xh,
244
+ stride_xn,
245
+ stride_xk,
246
+ stride_yh,
247
+ stride_yn,
248
+ BLOCK_SIZE_N: tl.constexpr,
249
+ BLOCK_SIZE_K: tl.constexpr,
250
+ BLOCK_SIZE_R: tl.constexpr,
251
+ ):
252
+ pid_h = tl.program_id(0)
253
+ pid_b = tl.program_id(1)
254
+ # get start and len after rmpad
255
+ seq_start = tl.load(cu_seqlens + pid_b)
256
+ seq_len = tl.load(cu_seqlens + pid_b + 1) - seq_start
257
+ blocks_start = tl.load(cu_seqblocks + pid_b)
258
+ num_blocks = tl.load(cu_seqblocks + pid_b + 1) - blocks_start
259
+ # load x
260
+ off_k = tl.arange(0, BLOCK_SIZE_K)
261
+ off_n = tl.arange(0, BLOCK_SIZE_N)
262
+ x_ptr = x_ptr + pid_h * stride_xh + seq_start * stride_xn
263
+ x_ptrs = x_ptr + off_n[:, None] * stride_xn + off_k[None, :] * stride_xk
264
+ # init y
265
+ y = tl.zeros((BLOCK_SIZE_R,), dtype=tl.int32)
266
+ # loop
267
+ for i in range(0, seq_len, BLOCK_SIZE_N):
268
+ x = tl.load(
269
+ x_ptrs,
270
+ mask=(off_n < seq_len - i)[:, None] & (off_k < topk)[None, :],
271
+ other=-1,
272
+ )
273
+ x = tl.ravel(x)
274
+ y += tl.histogram(x, BLOCK_SIZE_R)
275
+ x_ptrs += BLOCK_SIZE_N * stride_xn
276
+ # store result
277
+ off_r = tl.arange(0, BLOCK_SIZE_R)
278
+ y_ptr = y_ptr + pid_h * stride_yh + blocks_start * stride_yn
279
+ y_ptrs = y_ptr + off_r * stride_yn
280
+ tl.store(y_ptrs, y.to(y_ptr.dtype.element_ty), mask=off_r < num_blocks)
281
+
282
+
283
+ def count_query(
284
+ topk_idx: torch.Tensor,
285
+ cu_seqlens: torch.Tensor,
286
+ cu_seqblocks: torch.Tensor,
287
+ block_size: int,
288
+ ):
289
+ num_kv_heads, total_len, topk = topk_idx.shape
290
+ seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
291
+ seqblocks = cu_seqblocks[1:] - cu_seqblocks[:-1]
292
+ batch_size = seqlens.shape[0]
293
+ BLOCK_SIZE_K = triton.next_power_of_2(topk)
294
+ BLOCK_SIZE_N = triton.next_power_of_2(4096 // BLOCK_SIZE_K)
295
+ BLOCK_SIZE_R = triton.next_power_of_2(seqblocks.max().item() + 2)
296
+ active_query_count = torch.zeros(num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device)
297
+ grid = (num_kv_heads, batch_size)
298
+ count_kernel[grid](
299
+ topk_idx,
300
+ active_query_count,
301
+ cu_seqlens,
302
+ cu_seqblocks,
303
+ topk,
304
+ topk_idx.stride(0),
305
+ topk_idx.stride(1),
306
+ topk_idx.stride(2),
307
+ active_query_count.stride(0),
308
+ active_query_count.stride(1),
309
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
310
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
311
+ BLOCK_SIZE_R=BLOCK_SIZE_R,
312
+ num_warps=4,
313
+ num_stages=3,
314
+ )
315
+ return active_query_count
316
+
317
+
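For intuition, a dense PyTorch reference of what count_query returns (a sketch only; the Triton kernel computes the same per-sequence histogram in blocks, skipping the -1 padding entries):

    import torch

    def count_query_ref(topk_idx, cu_seqlens, cu_seqblocks):
        # topk_idx: [num_kv_heads, total_len, topk], block indices local to each sequence, -1 = padding
        num_kv_heads = topk_idx.shape[0]
        out = torch.zeros(num_kv_heads, int(cu_seqblocks[-1]), dtype=torch.int32)
        for b in range(cu_seqlens.shape[0] - 1):
            s, e = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
            base = int(cu_seqblocks[b])
            n_blk = int(cu_seqblocks[b + 1]) - base
            for h in range(num_kv_heads):
                idx = topk_idx[h, s:e].reshape(-1)
                idx = idx[idx >= 0].long()
                counts = torch.bincount(idx, minlength=n_blk)[:n_blk].to(torch.int32)
                out[h, base:base + n_blk] += counts
        return out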
318
+ @triton.jit
319
+ def pad_topk_idx_kernel(
320
+ t_ptr,
321
+ p_ptr,
322
+ cu_seqlens,
323
+ topk,
324
+ stride_th,
325
+ stride_tn,
326
+ stride_tk,
327
+ stride_pb,
328
+ stride_ph,
329
+ stride_pn,
330
+ stride_pk,
331
+ BLOCK_SIZE_N: tl.constexpr,
332
+ BLOCK_SIZE_T: tl.constexpr,
333
+ ):
334
+ pid_b = tl.program_id(0)
335
+ pid_h = tl.program_id(1)
336
+ pid_n = tl.program_id(2)
337
+ # get q start and len after rmpad
338
+ q_start = tl.load(cu_seqlens + pid_b)
339
+ q_len = tl.load(cu_seqlens + pid_b + 1) - q_start
340
+ if BLOCK_SIZE_N * pid_n >= q_len:
341
+ return
342
+ # init prts
343
+ t_ptrs = tl.make_block_ptr(
344
+ base=t_ptr + pid_h * stride_th + q_start * stride_tn,
345
+ shape=(q_len, topk),
346
+ strides=(stride_tn, stride_tk),
347
+ offsets=(pid_n * BLOCK_SIZE_N, 0),
348
+ block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T),
349
+ order=(1, 0),
350
+ )
351
+ p_ptrs = tl.make_block_ptr(
352
+ base=p_ptr + pid_b * stride_pb + pid_h * stride_ph,
353
+ shape=(q_len, topk),
354
+ strides=(stride_pn, stride_pk),
355
+ offsets=(pid_n * BLOCK_SIZE_N, 0),
356
+ block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T),
357
+ order=(1, 0),
358
+ )
359
+ # load and save
360
+ idxs = tl.load(t_ptrs, boundary_check=(0, 1))
361
+ tl.store(p_ptrs, idxs, boundary_check=(0, 1))
362
+
363
+
364
+ @triton.jit
365
+ def save_topk_idx_kernel(
366
+ p_ptr,
367
+ t_ptr,
368
+ cu_seqblocks,
369
+ cu_topk_q_count,
370
+ n_len,
371
+ stride_pb,
372
+ stride_ph,
373
+ stride_pn,
374
+ stride_th,
375
+ stride_tn,
376
+ stride_ch,
377
+ stride_cn,
378
+ BLOCK_SIZE_N: tl.constexpr,
379
+ ):
380
+ pid_b = tl.program_id(0)
381
+ pid_h = tl.program_id(1)
382
+ pid_n = tl.program_id(2)
383
+ # get q start and len after rmpad
384
+ q_block_start = tl.load(cu_seqblocks + pid_b)
385
+ q_block_end = tl.load(cu_seqblocks + pid_b + 1)
386
+ c_start = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_start * stride_cn)
387
+ c_end = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_end * stride_cn)
388
+ c_len = c_end - c_start
389
+ if c_len <= 0:
390
+ return
391
+ if pid_n * BLOCK_SIZE_N >= c_len:
392
+ return
393
+ # init ptrs
394
+ p_ptrs = tl.make_block_ptr(
395
+ base=p_ptr + pid_b * stride_pb + pid_h * stride_ph + (n_len - c_len) * stride_pn,
396
+ shape=(c_len,),
397
+ strides=(stride_pn,),
398
+ offsets=(pid_n * BLOCK_SIZE_N,),
399
+ block_shape=(BLOCK_SIZE_N,),
400
+ order=(0,),
401
+ )
402
+ t_ptrs = tl.make_block_ptr(
403
+ base=t_ptr + pid_h * stride_th + c_start * stride_tn,
404
+ shape=(c_len,),
405
+ strides=(stride_tn,),
406
+ offsets=(pid_n * BLOCK_SIZE_N,),
407
+ block_shape=(BLOCK_SIZE_N,),
408
+ order=(0,),
409
+ )
410
+ # load and save
411
+ idxs = tl.load(p_ptrs, boundary_check=(0,))
412
+ tl.store(t_ptrs, idxs, boundary_check=(0,))
413
+
414
+
415
+ def reorder_topk_idx(
416
+ topk_idx: torch.Tensor,
417
+ cu_topk_q_count: torch.Tensor,
418
+ cu_seqlens: torch.Tensor,
419
+ cu_seqblocks: torch.Tensor,
420
+ block_size: int,
421
+ ):
422
+ num_kv_heads, total_len, topk = topk_idx.shape
423
+ batch_size = cu_seqlens.shape[0] - 1
424
+ seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
425
+ max_seqlen = seq_lens.max().item()
426
+ # pad shape [num_kv_heads, total_seqlen, topk] to [batch_size, num_kv_heads, max_seqlen, topk]
427
+ pad_topk_idx = torch.full(
428
+ (batch_size, num_kv_heads, max_seqlen, topk),
429
+ fill_value=-1,
430
+ device=topk_idx.device,
431
+ dtype=torch.int32,
432
+ )
433
+ BLOCK_SIZE_T = triton.next_power_of_2(topk)
434
+ BLOCK_SIZE_N = min(triton.next_power_of_2(max_seqlen), triton.next_power_of_2(8192 // BLOCK_SIZE_T))
435
+ grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, BLOCK_SIZE_N))
436
+ pad_topk_idx_kernel[grid](
437
+ topk_idx,
438
+ pad_topk_idx,
439
+ cu_seqlens,
440
+ topk,
441
+ topk_idx.stride(0),
442
+ topk_idx.stride(1),
443
+ topk_idx.stride(2),
444
+ pad_topk_idx.stride(0),
445
+ pad_topk_idx.stride(1),
446
+ pad_topk_idx.stride(2),
447
+ pad_topk_idx.stride(3),
448
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
449
+ BLOCK_SIZE_T=BLOCK_SIZE_T,
450
+ )
451
+ # argsort
452
+ pad_topk_q_idx = pad_topk_idx.view(batch_size, num_kv_heads, -1).argsort(-1) // topk
453
+ pad_topk_q_idx = pad_topk_q_idx.to(torch.int32)
454
+ # save as remove pad version
455
+ topk_q_idx = torch.full(
456
+ (num_kv_heads, cu_topk_q_count[:, -1].max().item()),
457
+ fill_value=-1,
458
+ device=topk_idx.device,
459
+ dtype=torch.int32,
460
+ )
461
+ max_len = (cu_topk_q_count[:, cu_seqblocks][:, 1:] - cu_topk_q_count[:, cu_seqblocks][:, :-1]).max().item()
462
+ BLOCK_SIZE_N = min(triton.next_power_of_2(max_len), 8192)
463
+ grid = (batch_size, num_kv_heads, triton.cdiv(max_len, BLOCK_SIZE_N))
464
+ save_topk_idx_kernel[grid](
465
+ pad_topk_q_idx,
466
+ topk_q_idx,
467
+ cu_seqblocks,
468
+ cu_topk_q_count,
469
+ pad_topk_q_idx.shape[-1],
470
+ pad_topk_q_idx.stride(0),
471
+ pad_topk_q_idx.stride(1),
472
+ pad_topk_q_idx.stride(2),
473
+ topk_q_idx.stride(0),
474
+ topk_q_idx.stride(1),
475
+ cu_topk_q_count.stride(0),
476
+ cu_topk_q_count.stride(1),
477
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
478
+ )
479
+ return topk_q_idx
480
+
481
+
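A note on the trick used by reorder_topk_idx: it inverts the per-query map "query → selected key blocks" (topk_idx) into the per-block map "key block → querying rows" without atomics. The indices are first padded to a dense [batch, heads, max_seqlen, topk] tensor with -1 in empty slots, flattened, and arg-sorted, so entries with the same block index end up contiguous; dividing each argsort position by topk recovers the query row it came from, and since the -1 padding sorts to the front, save_topk_idx_kernel copies only the last c_len valid entries. The resulting topk_q_idx together with cu_topk_q_count forms a CSR-style index that backward_dkdv walks one key block at a time.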
482
+ @triton.jit
483
+ def backward_dkdv(
484
+ q_ptr, # Q: n x qh x d
485
+ k_ptr, # K: n x kh x d
486
+ v_ptr, # V: n x kh x d
487
+ tq_ptr, # topk_q_idx: kh x N
488
+ lse_ptr, # LSE: qh x n
489
+ d_ptr, # Delta: qh x n
490
+ do_ptr,
491
+ dk_ptr, # DK: sh x n x kh x d
492
+ dv_ptr, # DV: sh x n x kh x d
493
+ # seqlens
494
+ cu_seqlens_q, # [batch_size + 1]
495
+ cu_seqlens_k, # [batch_size + 1]
496
+ cu_seqblocks, # [batch_size + 1]
497
+ cu_topk_q_count, # [kh, total_blocks]
498
+ # shape
499
+ NUM_KV_HEADS,
500
+ NUM_SHARE_Q_HEADS,
501
+ HEAD_DIM,
502
+ TOPK,
503
+ # sm_scale
504
+ sm_scale,
505
+ # stride
506
+ stride_qn,
507
+ stride_qh,
508
+ stride_qd,
509
+ stride_kn,
510
+ stride_kh,
511
+ stride_kd,
512
+ stride_vn,
513
+ stride_vh,
514
+ stride_vd,
515
+ stride_tqh,
516
+ stride_tqn,
517
+ stride_ctqh,
518
+ stride_ctqn,
519
+ stride_lh,
520
+ stride_ln,
521
+ stride_dh,
522
+ stride_dn,
523
+ stride_don,
524
+ stride_doh,
525
+ stride_dod,
526
+ stride_dks,
527
+ stride_dkn,
528
+ stride_dkh,
529
+ stride_dkd,
530
+ stride_dvs,
531
+ stride_dvn,
532
+ stride_dvh,
533
+ stride_dvd,
534
+ # META parameters
535
+ BLOCK_SIZE_Q: tl.constexpr, # q block size
536
+ BLOCK_SIZE_K: tl.constexpr, # k block size
537
+ BLOCK_SIZE_D: tl.constexpr,
538
+ ):
539
+ qk_scale = sm_scale * 1.44269504
540
+ # get batch id and head id
541
+ pid_b = tl.program_id(0)
542
+ pid_h = tl.program_id(1)
543
+ pid_kh = pid_h // NUM_SHARE_Q_HEADS
544
+ pid_sh = pid_h % NUM_SHARE_Q_HEADS
545
+ pid_k = tl.program_id(2)
546
+ # get q k start and len after rmpad
547
+ q_start = tl.load(cu_seqlens_q + pid_b)
548
+ # q_len is not needed in this kernel; active queries are gathered via topk_q_idx
549
+ k_start = tl.load(cu_seqlens_k + pid_b)
550
+ k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
551
+ if BLOCK_SIZE_K * pid_k >= k_len:
552
+ return
553
+ # get topk_q_idx
554
+ b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence
555
+ act_q_start = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn)
556
+ act_q_end = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn)
557
+ act_q_len = act_q_end - act_q_start
558
+ tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn
559
+ # init pointers
560
+ k_ptrs = tl.make_block_ptr(
561
+ base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
562
+ shape=(k_len, HEAD_DIM),
563
+ strides=(stride_kn, stride_kd),
564
+ offsets=(pid_k * BLOCK_SIZE_K, 0),
565
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
566
+ order=(1, 0),
567
+ )
568
+ dk_ptrs = tl.make_block_ptr(
569
+ base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
570
+ shape=(k_len, HEAD_DIM),
571
+ strides=(stride_dkn, stride_dkd),
572
+ offsets=(pid_k * BLOCK_SIZE_K, 0),
573
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
574
+ order=(1, 0),
575
+ )
576
+ v_ptrs = tl.make_block_ptr(
577
+ base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
578
+ shape=(k_len, HEAD_DIM),
579
+ strides=(stride_vn, stride_vd),
580
+ offsets=(pid_k * BLOCK_SIZE_K, 0),
581
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
582
+ order=(1, 0),
583
+ )
584
+ dv_ptrs = tl.make_block_ptr(
585
+ base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
586
+ shape=(k_len, HEAD_DIM),
587
+ strides=(stride_dvn, stride_dvd),
588
+ offsets=(pid_k * BLOCK_SIZE_K, 0),
589
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
590
+ order=(1, 0),
591
+ )
592
+ # offsets
593
+ off_q = tl.arange(0, BLOCK_SIZE_Q)
594
+ off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
595
+ off_d = tl.arange(0, BLOCK_SIZE_D)
596
+ # load k v and keep in SRAM
597
+ k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
598
+ v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
599
+ # init dk dv
600
+ dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
601
+ dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
602
+ # init ptrs
603
+ q_ptrs = q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd
604
+ do_ptrs = do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod
605
+ d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh
606
+ lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh
607
+ # loop for q blocks
608
+ for i in range(0, act_q_len, BLOCK_SIZE_Q):
609
+ # load
610
+ idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to(tl.int32)
611
+ q = tl.load(
612
+ q_ptrs + idx_q[:, None] * stride_qn,
613
+ mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],
614
+ other=0,
615
+ )
616
+ do = tl.load(
617
+ do_ptrs + idx_q[:, None] * stride_don,
618
+ mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],
619
+ other=0,
620
+ )
621
+ lse = tl.load(
622
+ lse_ptrs + idx_q[:, None] * stride_ln,
623
+ mask=(off_q < act_q_len - i)[:, None],
624
+ other=0,
625
+ )
626
+ d = tl.load(
627
+ d_ptrs + idx_q[:, None] * stride_dn,
628
+ mask=(off_q < act_q_len - i)[:, None],
629
+ other=0,
630
+ )
631
+ # compute qk
632
+ qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
633
+ qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf"))
634
+ qk += tl.dot(q, k.T) * qk_scale
635
+ # compute p, ds
636
+ p = tl.exp2(qk - lse)
637
+ dp = tl.dot(do, v.T)
638
+ ds = sm_scale * p * (dp - d)
639
+ # cast dtype
640
+ p = p.to(do.dtype)
641
+ ds = ds.to(q.dtype)
642
+ # update dk and dv
643
+ dk += tl.dot(ds.T, q)
644
+ dv += tl.dot(p.T, do)
645
+ # save dk dv
646
+ tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
647
+ tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
648
+
649
+
650
+ @triton.jit
651
+ def backward_dq(
652
+ q_ptr, # Q: n x qh x d
653
+ k_ptr, # K: n x kh x d
654
+ v_ptr, # V: n x kh x d
655
+ t_ptr, # topk_idx: kh x n x k
656
+ lse_ptr, # LSE: qh x n
657
+ d_ptr, # Delta: qh x n
658
+ do_ptr,
659
+ dq_ptr,
660
+ # seqlens
661
+ cu_seqlens_q,
662
+ cu_seqlens_k,
663
+ # shape
664
+ NUM_KV_HEADS,
665
+ NUM_SHARE_Q_HEADS,
666
+ HEAD_DIM,
667
+ TOPK,
668
+ # q loop num
669
+ num_q_loop,
670
+ # sm_scale
671
+ sm_scale,
672
+ # stride
673
+ stride_qn,
674
+ stride_qh,
675
+ stride_qd,
676
+ stride_kn,
677
+ stride_kh,
678
+ stride_kd,
679
+ stride_vn,
680
+ stride_vh,
681
+ stride_vd,
682
+ stride_th,
683
+ stride_tn,
684
+ stride_tk,
685
+ stride_lh,
686
+ stride_ln,
687
+ stride_dh,
688
+ stride_dn,
689
+ stride_don,
690
+ stride_doh,
691
+ stride_dod,
692
+ stride_dqn,
693
+ stride_dqh,
694
+ stride_dqd,
695
+ # META parameters
696
+ BLOCK_SIZE_K: tl.constexpr, # k block size
697
+ BLOCK_SIZE_D: tl.constexpr,
698
+ BLOCK_SIZE_H: tl.constexpr,
699
+ BLOCK_SIZE_T: tl.constexpr,
700
+ ):
701
+ qk_scale = sm_scale * 1.44269504
702
+ # get batch id and head id
703
+ pid_b = tl.program_id(0)
704
+ pid_kh = tl.program_id(1)
705
+ pid_q = tl.program_id(2)
706
+ pid_h = pid_kh * NUM_SHARE_Q_HEADS
707
+ # get q k start and len after rmpad
708
+ q_start = tl.load(cu_seqlens_q + pid_b)
709
+ q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
710
+ k_start = tl.load(cu_seqlens_k + pid_b)
711
+ k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
712
+ if pid_q * num_q_loop >= q_len:
713
+ return
714
+ real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop)
715
+ for j in range(real_q_loop):
716
+ pid_q_j = pid_q * num_q_loop + j
717
+ # init topk idx pointer
718
+ off_t = tl.arange(0, BLOCK_SIZE_T)
719
+ t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th
720
+ topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)
721
+
722
+ real_topk = tl.sum(
723
+ tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0),
724
+ axis=0,
725
+ )
726
+ # init pointers
727
+ q_ptrs = tl.make_block_ptr(
728
+ base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,
729
+ shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
730
+ strides=(stride_qh, stride_qd),
731
+ offsets=(0, 0),
732
+ block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
733
+ order=(1, 0),
734
+ )
735
+ dq_ptrs = tl.make_block_ptr(
736
+ base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh,
737
+ shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
738
+ strides=(stride_dqh, stride_dqd),
739
+ offsets=(0, 0),
740
+ block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
741
+ order=(1, 0),
742
+ )
743
+ k_ptrs = tl.make_block_ptr(
744
+ base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
745
+ shape=(k_len, HEAD_DIM),
746
+ strides=(stride_kn, stride_kd),
747
+ offsets=(0, 0),
748
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
749
+ order=(1, 0),
750
+ )
751
+ v_ptrs = tl.make_block_ptr(
752
+ base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
753
+ shape=(HEAD_DIM, k_len),
754
+ strides=(stride_vd, stride_vn),
755
+ offsets=(0, 0),
756
+ block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
757
+ order=(0, 1),
758
+ )
759
+ do_ptrs = tl.make_block_ptr(
760
+ base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh,
761
+ shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
762
+ strides=(stride_doh, stride_dod),
763
+ offsets=(0, 0),
764
+ block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
765
+ order=(1, 0),
766
+ )
767
+ d_ptrs = tl.make_block_ptr(
768
+ base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh,
769
+ shape=(NUM_SHARE_Q_HEADS, 1),
770
+ strides=(stride_dh, stride_dn),
771
+ offsets=(0, 0),
772
+ block_shape=(BLOCK_SIZE_H, 1),
773
+ order=(1, 0),
774
+ )
775
+ lse_ptrs = tl.make_block_ptr(
776
+ base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh,
777
+ shape=(NUM_SHARE_Q_HEADS, 1),
778
+ strides=(stride_lh, stride_ln),
779
+ offsets=(0, 0),
780
+ block_shape=(BLOCK_SIZE_H, 1),
781
+ order=(1, 0),
782
+ )
783
+ # offsets
784
+ off_k = tl.arange(0, BLOCK_SIZE_K)
785
+ # load q, do, lse, delta, and keep in SRAM
786
+ q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
787
+ do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
788
+ lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
789
+ d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
790
+ # init dq
791
+ dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32)
792
+ # sparse
793
+ for i in range(real_topk):
794
+ # get current block start index
795
+ c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K
796
+ t_ptr_j = t_ptr_j + stride_tk
797
+ # load
798
+ k = tl.load(tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero")
799
+ v = tl.load(tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero")
800
+ # compute qk
801
+ qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)
802
+ qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf"))
803
+ # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K]
804
+ qk += tl.dot(q, tl.trans(k)) * qk_scale
805
+ # compute p, ds
806
+ p = tl.exp2(qk - lse)
807
+ dp = tl.dot(do, v)
808
+ ds = sm_scale * p * (dp - d)
809
+ # cast dtype
810
+ ds = ds.to(q.dtype)
811
+ # update dq
812
+ dq += tl.dot(ds, k)
813
+ # save dq
814
+ tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
815
+
816
+
817
+ def _topk_sparse_attention_fwd(
818
+ q: torch.Tensor, # [total_len, num_q_heads, head_dim]
819
+ k: torch.Tensor, # [total_len, num_k_heads, head_dim]
820
+ v: torch.Tensor, # [total_len, num_k_heads, head_dim]
821
+ topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk]
822
+ block_size: int,
823
+ cu_seqlens_q: torch.Tensor,
824
+ cu_seqlens_k: torch.Tensor,
825
+ max_seqlen_q: int,
826
+ max_seqlen_k: int,
827
+ sm_scale: float,
828
+ ):
829
+ # dtype check
830
+ assert k.dtype == q.dtype and v.dtype == q.dtype
831
+ assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
832
+ assert block_size in {32, 64, 128, 256}
833
+ # shape
834
+ q_len, num_q_heads, head_dim = q.shape
835
+ k_len, num_k_heads, head_dim = k.shape
836
+ v_len, num_v_heads, head_dim = v.shape
837
+ batch_size = cu_seqlens_q.shape[0] - 1
838
+ # assert q_len == k_len and k_len == v_len
839
+ topk = topk_idx.shape[-1]
840
+ assert topk_idx.shape[0] == num_k_heads
841
+ assert topk_idx.shape[1] == q_len
842
+ # gqa
843
+ assert num_k_heads == num_v_heads
844
+ assert num_q_heads % num_k_heads == 0
845
+ num_share_q_heads = num_q_heads // num_k_heads
846
+ # output tensor
847
+ o = torch.zeros_like(q)
848
+
849
+ lse = torch.zeros(num_q_heads, q_len, dtype=torch.float32, device=q.device)
850
+
851
+ # launch kernel
852
+ num_q_loop = num_k_loop = 1
853
+ BLOCK_SIZE_K = triton.next_power_of_2(block_size)
854
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
855
+ BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads))
856
+ BLOCK_SIZE_T = triton.next_power_of_2(topk)
857
+
858
+ def grid(meta):
859
+ return (
860
+ batch_size * triton.cdiv(num_k_heads, num_k_loop) * triton.cdiv(max_seqlen_q, num_q_loop),
861
+ )
863
+
864
+ num_warps, num_stages = get_num_warps_stages(head_dim, block_size, IS_HOPPER_GPU)
865
+ forward_kernel_orig[grid](
866
+ q,
867
+ k,
868
+ v,
869
+ topk_idx,
870
+ o,
871
+ lse,
872
+ cu_seqlens_q,
873
+ cu_seqlens_k,
874
+ num_k_heads,
875
+ num_share_q_heads,
876
+ head_dim,
877
+ topk,
878
+ block_size,
879
+ # num_q_loop,
880
+ sm_scale,
881
+ q.stride(0),
882
+ q.stride(1),
883
+ q.stride(2),
884
+ k.stride(0),
885
+ k.stride(1),
886
+ k.stride(2),
887
+ v.stride(0),
888
+ v.stride(1),
889
+ v.stride(2),
890
+ topk_idx.stride(0),
891
+ topk_idx.stride(1),
892
+ topk_idx.stride(2),
893
+ o.stride(0),
894
+ o.stride(1),
895
+ o.stride(2),
896
+ lse.stride(0),
897
+ lse.stride(1),
898
+ num_q_loop=num_q_loop,
899
+ num_k_loop=num_k_loop,
900
+ MAX_SEQ_LEN=max_seqlen_q,
901
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
902
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
903
+ BLOCK_SIZE_H=BLOCK_SIZE_H,
904
+ BLOCK_SIZE_T=BLOCK_SIZE_T,
905
+ num_warps=num_warps,
906
+ num_stages=num_stages,
907
+ )
908
+ return o, lse
909
+
910
+
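A note on the launch configuration above: the forward kernel is launched over a single flattened axis of size batch_size * ceil(num_k_heads / num_k_loop) * ceil(max_seqlen_q / num_q_loop), and forward_kernel_orig decodes the flat program id back into (sequence, KV-head chunk, query position). With num_q_loop = num_k_loop = 1 as set here, each program handles exactly one query position for one KV head, covering all query heads that share that KV head via the BLOCK_SIZE_H axis.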
911
+ def _topk_sparse_attention_bwd(
912
+ o: torch.Tensor,
913
+ do: torch.Tensor,
914
+ lse: torch.Tensor,
915
+ q: torch.Tensor,
916
+ k: torch.Tensor,
917
+ v: torch.Tensor,
918
+ topk_idx: torch.Tensor,
919
+ block_size: int,
920
+ cu_seqlens_q: torch.Tensor,
921
+ cu_seqlens_k: torch.Tensor,
922
+ max_seqlen_q: int,
923
+ max_seqlen_k: int,
924
+ sm_scale: float,
925
+ ):
926
+
927
+ assert block_size in {32, 64, 128, 256}
928
+ q_len, num_q_heads, head_dim = q.shape
929
+ k_len, num_k_heads, head_dim = k.shape
930
+ v_len, num_v_heads, head_dim = v.shape
931
+ o_len, num_o_heads, head_dim = o.shape
932
+ num_share_q_heads = num_q_heads // num_k_heads
933
+ topk = topk_idx.shape[-1]
934
+ # compute D
935
+ delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
936
+ BLOCK_SIZE_O = 256
937
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
938
+ num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)
939
+ grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads)
940
+
941
+ backward_sum_o_do[grid](
942
+ o,
943
+ do,
944
+ delta,
945
+ o_len,
946
+ head_dim,
947
+ o.stride(0),
948
+ o.stride(1),
949
+ o.stride(2),
950
+ do.stride(0),
951
+ do.stride(1),
952
+ do.stride(2),
953
+ delta.stride(0),
954
+ delta.stride(1),
955
+ BLOCK_SIZE_O=BLOCK_SIZE_O,
956
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
957
+ num_warps=num_warps,
958
+ num_stages=num_stages,
959
+ )
960
+ # count active queries for each key block, shape: (num_k_heads, total_k_blocks)
961
+ seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
962
+ seqblocks = torch.ceil(seqlens / block_size).to(torch.int32)
963
+ cu_seqblocks = torch.cat(
964
+ [
965
+ torch.zeros(1, dtype=torch.int32, device=topk_idx.device),
966
+ torch.cumsum(seqblocks, dim=0),
967
+ ]
968
+ ).to(torch.int32)
969
+
970
+ topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size)
971
+
972
+ cu_topk_q_count = torch.cat(
973
+ [
974
+ torch.zeros(topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device),
975
+ torch.cumsum(topk_q_count, dim=-1),
976
+ ],
977
+ dim=-1,
978
+ ).to(torch.int32)
979
+ # active query idx for each key block
980
+ # how to get active query idx for sequence b, head h, kv block i?
981
+ topk_q_idx = reorder_topk_idx(topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size)
982
+ # compute dk dv
983
+ dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
984
+ dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
985
+ batch_size = cu_seqlens_q.shape[0] - 1
986
+ BLOCK_SIZE_K = triton.next_power_of_2(block_size)
987
+ BLOCK_SIZE_Q = 64
988
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
989
+ num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
990
+ grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K))
991
+ backward_dkdv[grid](
992
+ q,
993
+ k,
994
+ v,
995
+ topk_q_idx,
996
+ lse,
997
+ delta,
998
+ do,
999
+ dk,
1000
+ dv,
1001
+ cu_seqlens_q,
1002
+ cu_seqlens_k,
1003
+ cu_seqblocks,
1004
+ cu_topk_q_count,
1005
+ num_k_heads,
1006
+ num_share_q_heads,
1007
+ head_dim,
1008
+ topk,
1009
+ sm_scale,
1010
+ q.stride(0),
1011
+ q.stride(1),
1012
+ q.stride(2),
1013
+ k.stride(0),
1014
+ k.stride(1),
1015
+ k.stride(2),
1016
+ v.stride(0),
1017
+ v.stride(1),
1018
+ v.stride(2),
1019
+ topk_q_idx.stride(0),
1020
+ topk_q_idx.stride(1),
1021
+ cu_topk_q_count.stride(0),
1022
+ cu_topk_q_count.stride(1),
1023
+ lse.stride(0),
1024
+ lse.stride(1),
1025
+ delta.stride(0),
1026
+ delta.stride(1),
1027
+ do.stride(0),
1028
+ do.stride(1),
1029
+ do.stride(2),
1030
+ dk.stride(0),
1031
+ dk.stride(1),
1032
+ dk.stride(2),
1033
+ dk.stride(3),
1034
+ dv.stride(0),
1035
+ dv.stride(1),
1036
+ dv.stride(2),
1037
+ dv.stride(3),
1038
+ BLOCK_SIZE_Q=BLOCK_SIZE_Q,
1039
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
1040
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
1041
+ num_warps=num_warps,
1042
+ num_stages=num_stages,
1043
+ )
1044
+ dk = dk.sum(0)
1045
+ dv = dv.sum(0)
1046
+ # compute dq
1047
+ dq = torch.zeros_like(q)
1048
+ num_q_loop = max_seqlen_q // 32768 + 1 # process multiple queries per program when the sequence is very long
1049
+ grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop))
1050
+ BLOCK_SIZE_K = block_size
1051
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
1052
+ BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads))
1053
+ BLOCK_SIZE_T = triton.next_power_of_2(topk)
1054
+ num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)
1055
+
1056
+ backward_dq[grid](
1057
+ q,
1058
+ k,
1059
+ v,
1060
+ topk_idx,
1061
+ lse,
1062
+ delta,
1063
+ do,
1064
+ dq,
1065
+ cu_seqlens_q,
1066
+ cu_seqlens_k,
1067
+ num_k_heads,
1068
+ num_share_q_heads,
1069
+ head_dim,
1070
+ topk,
1071
+ num_q_loop,
1072
+ sm_scale,
1073
+ q.stride(0),
1074
+ q.stride(1),
1075
+ q.stride(2),
1076
+ k.stride(0),
1077
+ k.stride(1),
1078
+ k.stride(2),
1079
+ v.stride(0),
1080
+ v.stride(1),
1081
+ v.stride(2),
1082
+ topk_idx.stride(0),
1083
+ topk_idx.stride(1),
1084
+ topk_idx.stride(2),
1085
+ lse.stride(0),
1086
+ lse.stride(1),
1087
+ delta.stride(0),
1088
+ delta.stride(1),
1089
+ do.stride(0),
1090
+ do.stride(1),
1091
+ do.stride(2),
1092
+ dq.stride(0),
1093
+ dq.stride(1),
1094
+ dq.stride(2),
1095
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
1096
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
1097
+ BLOCK_SIZE_H=BLOCK_SIZE_H,
1098
+ BLOCK_SIZE_T=BLOCK_SIZE_T,
1099
+ num_warps=num_warps,
1100
+ num_stages=num_stages,
1101
+ )
1102
+ return dq, dk, dv
1103
+
1104
+
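A note on the dk/dv buffers above: they carry a leading num_share_q_heads axis so that, under GQA, each query head in a group writes its partial dK/dV into a private slice (the launch grid iterates over all num_q_heads, and pid_sh selects the slice inside backward_dkdv); the final dk.sum(0) / dv.sum(0) reduction then replaces atomic adds across query heads that share one KV head.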
1105
+ class TopkSparseAttention(torch.autograd.Function):
1106
+ @staticmethod
1107
+ def forward(
1108
+ ctx,
1109
+ q: torch.Tensor, # [total_len, num_q_heads, head_dim]
1110
+ k: torch.Tensor, # [total_len, num_k_heads, head_dim]
1111
+ v: torch.Tensor, # [total_len, num_k_heads, head_dim]
1112
+ topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk]
1113
+ block_size: int,
1114
+ cu_seqlens_q: torch.Tensor,
1115
+ cu_seqlens_k: torch.Tensor,
1116
+ max_seqlen_q: int,
1117
+ max_seqlen_k: int,
1118
+ sm_scale=None,
1119
+ ):
1120
+ # dtype check
1121
+ assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
1122
+ assert q.dtype == k.dtype and k.dtype == v.dtype
1123
+ assert topk_idx.dtype == torch.int32
1124
+ assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
1125
+ # softmax scale
1126
+ if sm_scale is None:
1127
+ sm_scale = 1 / math.sqrt(q.shape[-1])
1128
+
1129
+ o, lse = _topk_sparse_attention_fwd(
1130
+ q,
1131
+ k,
1132
+ v,
1133
+ topk_idx,
1134
+ block_size,
1135
+ cu_seqlens_q,
1136
+ cu_seqlens_k,
1137
+ max_seqlen_q,
1138
+ max_seqlen_k,
1139
+ sm_scale,
1140
+ )
1141
+
1142
+ ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx)
1143
+ ctx.sm_scale = sm_scale
1144
+ ctx.max_seqlen_q = max_seqlen_q
1145
+ ctx.max_seqlen_k = max_seqlen_k
1146
+ ctx.block_size = block_size
1147
+ return o
1148
+
1149
+ @staticmethod
1150
+ def backward(ctx, do: torch.Tensor, *args) -> Any:
1151
+ q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors
1152
+
1153
+ max_seqlen_q = ctx.max_seqlen_q
1154
+ max_seqlen_k = ctx.max_seqlen_k
1155
+ sm_scale = ctx.sm_scale
1156
+ block_size = ctx.block_size
1157
+ assert block_size in {32, 64, 128, 256}
1158
+
1159
+ dq, dk, dv = _topk_sparse_attention_bwd(
1160
+ o,
1161
+ do,
1162
+ lse,
1163
+ q,
1164
+ k,
1165
+ v,
1166
+ topk_idx,
1167
+ block_size,
1168
+ cu_seqlens_q,
1169
+ cu_seqlens_k,
1170
+ max_seqlen_q,
1171
+ max_seqlen_k,
1172
+ sm_scale,
1173
+ )
1174
+ # exactly one gradient per forward input: q, k, v, then None for topk_idx,
+ # block_size, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale
+ return dq, dk, dv, None, None, None, None, None, None, None
1175
+
1176
+
1177
+ def topk_sparse_attention(
1178
+ q: torch.Tensor,
1179
+ k: torch.Tensor,
1180
+ v: torch.Tensor,
1181
+ topk_idx: torch.Tensor,
1182
+ block_size: int,
1183
+ cu_seqlens: torch.Tensor,
1184
+ softmax_scale: Optional[float] = None,
1185
+ ) -> torch.Tensor:
1186
+ """Topk sparse attention varlen version implemented in triton.
1187
+
1188
+ Args:
1189
+ q (torch.Tensor): shape [total_len, num_q_heads, head_dim]
1190
+ k (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
1191
+ v (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
1192
+ topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding.
1193
+ block_size (int): key value block size.
1194
+ cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen.
1195
+ softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim).
1196
+
1197
+ Returns:
1198
+ torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim]
1199
+ """
1200
+
1201
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
1202
+ return TopkSparseAttention.apply(
1203
+ q,
1204
+ k,
1205
+ v,
1206
+ topk_idx,
1207
+ block_size,
1208
+ cu_seqlens,
1209
+ cu_seqlens,
1210
+ max_seqlen,
1211
+ max_seqlen,
1212
+ softmax_scale,
1213
+ )
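A minimal end-to-end sketch of calling topk_sparse_attention (the topk_idx here is random and purely illustrative; in the model it would come from the top-k block-selection stage):

    import torch

    device, dtype = "cuda", torch.bfloat16
    num_q_heads, num_kv_heads, head_dim = 32, 2, 128
    block_size, topk = 64, 16

    # two sequences of lengths 1000 and 2000, flattened varlen-style
    cu_seqlens = torch.tensor([0, 1000, 3000], dtype=torch.int32, device=device)
    total_len = int(cu_seqlens[-1])

    q = torch.randn(total_len, num_q_heads, head_dim, device=device, dtype=dtype)
    k = torch.randn(total_len, num_kv_heads, head_dim, device=device, dtype=dtype)
    v = torch.randn(total_len, num_kv_heads, head_dim, device=device, dtype=dtype)

    # per-query key-block indices, local to each sequence; -1 would mark padding
    topk_idx = torch.randint(0, 16, (num_kv_heads, total_len, topk), dtype=torch.int32, device=device)

    out = topk_sparse_attention(q, k, v, topk_idx, block_size, cu_seqlens)
    print(out.shape)  # torch.Size([3000, 32, 128])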
ops/utils.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+
3
+
4
+ def is_hopper_gpu():
5
+ if torch.cuda.is_available():
6
+ device_capability = torch.cuda.get_device_capability(0)
7
+ major, minor = device_capability
8
+ return major == 9
9
+ return False
10
+
11
+
12
+ def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
13
+ """
14
+ Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton.
15
+
16
+ Args:
17
+ head_dim (int): Size of the head dimension.
18
+ block_size (int): Size of the block in the attention matrix.
19
+ is_hopper_gpu (bool): True on Hopper GPUs; False selects the Ampere-style defaults.
20
+
21
+ Returns:
22
+ tuple: (num_warps, num_stages) recommended values.
23
+ """
24
+ # Determine if head_dim and block_size exceed 64
25
+ head_large = head_dim > 64
26
+ block_large = block_size > 64
27
+
28
+ if is_hopper_gpu:
29
+ # Hopper GPU recommendations
30
+ if head_large and block_large:
31
+ num_warps = 8
32
+ num_stages = 3
33
+ elif head_large or block_large:
34
+ num_warps = 4
35
+ num_stages = 3
36
+ else:
37
+ num_warps = 2
38
+ num_stages = 2
39
+ else:
40
+ # Ampere GPU recommendations
41
+ if head_large and block_large:
42
+ num_warps = 8
43
+ num_stages = 3
44
+ elif head_large or block_large:
45
+ num_warps = 8
46
+ num_stages = 3
47
+ else:
48
+ num_warps = 2
49
+ num_stages = 2
50
+ return num_warps, num_stages
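For example, with head_dim=128 and block_size=64, the helper returns (4, 3) on Hopper and (8, 3) otherwise:

    num_warps, num_stages = get_num_warps_stages(128, 64, is_hopper_gpu())
    print(num_warps, num_stages)  # (4, 3) on Hopper, (8, 3) on older GPUs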