Maxtimer97 committed
Commit a2f57c7 · Parent(s): 28885e1

Flattened repo

Files changed (5):
  1. compressed_attention.py +1320 -0
  2. modeling_chatglm.py +8 -3
  3. pooling.py +207 -0
  4. topk_sparse_attention.py +1213 -0
  5. utils.py +50 -0
compressed_attention.py ADDED
@@ -0,0 +1,1320 @@
+# Copyright 2025 Xunhao Lai & Jianqiao Lu.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import warnings
+from typing import Any, Tuple, Union
+
+import torch
+import triton
+import triton.language as tl
+
+try:
+    from .utils import get_num_warps_stages, is_hopper_gpu
+except ImportError:
+    from ops.utils import get_num_warps_stages, is_hopper_gpu
+
+IS_HOPPER_GPU = is_hopper_gpu()
+
+
+@triton.jit
+def forward_kernel(
+    q_ptr,  # Q: n x h x d
+    k_ptr,  # K: n x h x d
+    v_ptr,  # V: n x h x d
+    o_ptr,  # O: n x h x d
+    lse_ptr,  # LSE: h x n
+    # size and stride of the compression kernel
+    kernel_size,
+    kernel_stride,
+    # seqlens
+    cu_seqlens_q,
+    cu_seqlens_k,
+    # shape
+    NUM_KV_HEADS,
+    NUM_SHARE_Q_HEADS,
+    HEAD_DIM,
+    # sm_scale
+    sm_scale,
+    # stride
+    stride_qn,
+    stride_qh,
+    stride_qd,
+    stride_kn,
+    stride_kh,
+    stride_kd,
+    stride_vn,
+    stride_vh,
+    stride_vd,
+    stride_on,
+    stride_oh,
+    stride_od,
+    stride_lh,
+    stride_ln,
+    # META parameters
+    BLOCK_SIZE_Q: tl.constexpr,  # q block size
+    BLOCK_SIZE_K: tl.constexpr,  # k block size
+    BLOCK_SIZE_D: tl.constexpr,
+):
+    qk_scale = sm_scale * 1.44269504
+    # get batch id and head id
+    pid_b = tl.program_id(0)
+    pid_h = tl.program_id(1)
+    pid_q = tl.program_id(2)
+    pid_kh = pid_h // NUM_SHARE_Q_HEADS
+    # get q k start and len after rmpad
+    q_start = tl.load(cu_seqlens_q + pid_b)
+    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+    k_start = tl.load(cu_seqlens_k + pid_b)
+    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
+    # skip the first kernel_size - 1 queries, because they do not attend to any compressed key
+    q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
+    if q_start_in_seq >= q_len:
+        return
+    # init qkv pointers
+    q_ptrs = tl.make_block_ptr(
+        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
+        shape=(q_len, HEAD_DIM),
+        strides=(stride_qn, stride_qd),
+        offsets=(q_start_in_seq, 0),
+        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+        order=(1, 0),
+    )
+    k_ptrs = tl.make_block_ptr(
+        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
+        shape=(HEAD_DIM, k_len),
+        strides=(stride_kd, stride_kn),
+        offsets=(0, 0),
+        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
+        order=(0, 1),
+    )
+    v_ptrs = tl.make_block_ptr(
+        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
+        shape=(k_len, HEAD_DIM),
+        strides=(stride_vn, stride_vd),
+        offsets=(0, 0),
+        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+        order=(1, 0),
+    )
+    # load q
+    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
+    # init statistics
+    off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
+    off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
+    m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
+    lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
+    acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32)
+    # attention
+    lo = 0
+    hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
+    for i in range(lo, hi, BLOCK_SIZE_K):
+        i = tl.multiple_of(i, BLOCK_SIZE_K)
+        # load k
+        k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
+        # compute qk
+        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
+        qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf"))
+        qk += tl.dot(q, k) * qk_scale
+        # compute m_ij and l_ij
+        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
+        p = tl.exp2(qk - m_ij[:, None])
+        l_ij = tl.sum(p, axis=1)
+        # scale acc_o
+        acc_o_scale = tl.exp2(m_i - m_ij)
+        acc_o = acc_o * acc_o_scale[:, None]
+        # load v and update acc_o
+        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
+        p = p.to(v.dtype)
+        acc_o += tl.dot(p, v)
+        # update statistics
+        m_i = m_ij
+        lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
+        # update ptrs
+        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
+        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
+    # final scale
+    acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
+    # save output
+    o_ptrs = tl.make_block_ptr(
+        base=o_ptr + q_start * stride_on + pid_h * stride_oh,
+        shape=(q_len, HEAD_DIM),
+        strides=(stride_on, stride_od),
+        offsets=(q_start_in_seq, 0),
+        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+        order=(1, 0),
+    )
+    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
+    # save lse
+    l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
+    tl.store(l_ptrs, lse_i, mask=off_q < q_len)
+
+
+@triton.jit
+def backward_sum_o_do(
+    o_ptr,  # O: n x h x d
+    do_ptr,  # dO: n x h x d
+    delta_ptr,  # D: h x n
+    o_len,
+    HEAD_DIM,
+    stride_on,
+    stride_oh,
+    stride_od,
+    stride_don,
+    stride_doh,
+    stride_dod,
+    stride_dh,
+    stride_dn,
+    BLOCK_SIZE_O: tl.constexpr,
+    BLOCK_SIZE_D: tl.constexpr,
+):
+    pid_n = tl.program_id(0)
+    pid_h = tl.program_id(1)
+    off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
+    off_d = tl.arange(0, BLOCK_SIZE_D)
+    o = tl.load(
+        o_ptr + off_n[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od,
+        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
+        other=0,
+    ).to(tl.float32)
+    do = tl.load(
+        do_ptr + off_n[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod,
+        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
+        other=0,
+    ).to(tl.float32)
+    delta = tl.sum(o * do, axis=1)
+    tl.store(delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len)
+
+
+@triton.jit
+def backward_dkdv(
+    q_ptr,  # Q: n x qh x d
+    k_ptr,  # K: n x kh x d
+    v_ptr,  # V: n x kh x d
+    lse_ptr,  # LSE: qh x n
+    d_ptr,  # Delta: qh x n
+    do_ptr,
+    dk_ptr,  # DK: sh x n x kh x d
+    dv_ptr,  # DV: sh x n x kh x d
+    kernel_size,
+    kernel_stride,
+    # seqlens
+    cu_seqlens_q,
+    cu_seqlens_k,
+    # shape
+    NUM_KV_HEADS,
+    NUM_SHARE_Q_HEADS,
+    HEAD_DIM,
+    # sm_scale
+    sm_scale,
+    # stride
+    stride_qn,
+    stride_qh,
+    stride_qd,
+    stride_kn,
+    stride_kh,
+    stride_kd,
+    stride_vn,
+    stride_vh,
+    stride_vd,
+    stride_lh,
+    stride_ln,
+    stride_dh,
+    stride_dn,
+    stride_don,
+    stride_doh,
+    stride_dod,
+    stride_dks,
+    stride_dkn,
+    stride_dkh,
+    stride_dkd,
+    stride_dvs,
+    stride_dvn,
+    stride_dvh,
+    stride_dvd,
+    # META parameters
+    BLOCK_SIZE_Q: tl.constexpr,  # q block size
+    BLOCK_SIZE_K: tl.constexpr,  # k block size
+    BLOCK_SIZE_D: tl.constexpr,
+):
+    qk_scale = sm_scale * 1.44269504
+    # get batch id and head id
+    pid_b = tl.program_id(0)
+    pid_h = tl.program_id(1)
+    pid_kh = pid_h // NUM_SHARE_Q_HEADS
+    pid_sh = pid_h % NUM_SHARE_Q_HEADS
+    pid_k = tl.program_id(2)
+    # get q k start and len after rmpad
+    q_start = tl.load(cu_seqlens_q + pid_b)
+    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+    k_start = tl.load(cu_seqlens_k + pid_b)
+    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
+    if BLOCK_SIZE_K * pid_k >= k_len:
+        return
+    # init pointers
+    k_ptrs = tl.make_block_ptr(
+        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
+        shape=(k_len, HEAD_DIM),
+        strides=(stride_kn, stride_kd),
+        offsets=(pid_k * BLOCK_SIZE_K, 0),
+        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+        order=(1, 0),
+    )
+    dk_ptrs = tl.make_block_ptr(
+        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
+        shape=(k_len, HEAD_DIM),
+        strides=(stride_dkn, stride_dkd),
+        offsets=(pid_k * BLOCK_SIZE_K, 0),
+        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+        order=(1, 0),
+    )
+    v_ptrs = tl.make_block_ptr(
+        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
+        shape=(k_len, HEAD_DIM),
+        strides=(stride_vn, stride_vd),
+        offsets=(pid_k * BLOCK_SIZE_K, 0),
+        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+        order=(1, 0),
+    )
+    dv_ptrs = tl.make_block_ptr(
+        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
+        shape=(k_len, HEAD_DIM),
+        strides=(stride_dvn, stride_dvd),
+        offsets=(pid_k * BLOCK_SIZE_K, 0),
+        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+        order=(1, 0),
+    )
+    # offsets
+    off_q = tl.arange(0, BLOCK_SIZE_Q)
+    off_k = pid_k * BLOCK_SIZE_K * kernel_stride + tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
+    # load k v and keep in SRAM
+    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
+    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
+    # init dk dv
+    dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
+    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
+    q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1
+    q_ptrs = tl.make_block_ptr(
+        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
+        shape=(HEAD_DIM, q_len),
+        strides=(stride_qd, stride_qn),
+        offsets=(0, q_lo),
+        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
+        order=(0, 1),
+    )
+    do_ptrs = tl.make_block_ptr(
+        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
+        shape=(HEAD_DIM, q_len),
+        strides=(stride_dod, stride_don),
+        offsets=(0, q_lo),
+        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
+        order=(0, 1),
+    )
+    d_ptrs = tl.make_block_ptr(
+        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
+        shape=(1, q_len),
+        strides=(0, stride_dn),
+        offsets=(0, q_lo),
+        block_shape=(1, BLOCK_SIZE_Q),
+        order=(1, 0),
+    )
+    lse_ptrs = tl.make_block_ptr(
+        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
+        shape=(1, q_len),
+        strides=(0, stride_ln),
+        offsets=(0, q_lo),
+        block_shape=(1, BLOCK_SIZE_Q),
+        order=(0, 1),
+    )
+    # loop over q blocks
+    for i in range(q_lo, q_len, BLOCK_SIZE_Q):
+        # load
+        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
+        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
+        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
+        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
+        # compute qk
+        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
+        qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf"))
+        qk += tl.dot(k, q) * qk_scale
+        # compute p, ds
+        # [BLOCK_SIZE_K, BLOCK_SIZE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
+        p = tl.exp2(qk - lse)
+        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
+        dp = tl.dot(v, do)
+        ds = sm_scale * p * (dp - d)
+        # cast dtype
+        p = p.to(do.dtype)
+        ds = ds.to(q.dtype)
+        # update dk and dv
+        # [BLOCK_SIZE_K, BLOCK_SIZE_Q] @ [BLOCK_SIZE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM]
+        dk += tl.dot(ds, tl.trans(q))
+        dv += tl.dot(p, tl.trans(do))
+        # increment pointers
+        q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q))
+        do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q))
+        lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q))
+        d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q))
+    # save dk dv
+    tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.jit
+def backward_dq(
+    q_ptr,  # Q: n x qh x d
+    k_ptr,  # K: n x kh x d
+    v_ptr,  # V: n x kh x d
+    lse_ptr,  # LSE: qh x n
+    d_ptr,  # Delta: qh x n
+    do_ptr,
+    dq_ptr,
+    kernel_size,
+    kernel_stride,
+    # seqlens
+    cu_seqlens_q,
+    cu_seqlens_k,
+    # shape
+    NUM_KV_HEADS,
+    NUM_SHARE_Q_HEADS,
+    HEAD_DIM,
+    # sm_scale
+    sm_scale,
+    # stride
+    stride_qn,
+    stride_qh,
+    stride_qd,
+    stride_kn,
+    stride_kh,
+    stride_kd,
+    stride_vn,
+    stride_vh,
+    stride_vd,
+    stride_lh,
+    stride_ln,
+    stride_dh,
+    stride_dn,
+    stride_don,
+    stride_doh,
+    stride_dod,
+    stride_dqn,
+    stride_dqh,
+    stride_dqd,
+    # META parameters
+    BLOCK_SIZE_Q: tl.constexpr,  # q block size
+    BLOCK_SIZE_K: tl.constexpr,  # k block size
+    BLOCK_SIZE_D: tl.constexpr,
+):
+    qk_scale = sm_scale * 1.44269504
+    # get batch id and head id
+    pid_b = tl.program_id(0)
+    pid_h = tl.program_id(1)
+    pid_q = tl.program_id(2)
+    pid_kh = pid_h // NUM_SHARE_Q_HEADS
+    # get q k start and len after rmpad
+    q_start = tl.load(cu_seqlens_q + pid_b)
+    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+    k_start = tl.load(cu_seqlens_k + pid_b)
+    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
+    # skip the first kernel_size - 1 queries, because they do not attend to any compressed key
+    q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
+    if q_start_in_seq >= q_len:
+        return
+    # init pointers
+    q_ptrs = tl.make_block_ptr(
+        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
+        shape=(q_len, HEAD_DIM),
+        strides=(stride_qn, stride_qd),
+        offsets=(q_start_in_seq, 0),
+        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+        order=(1, 0),
+    )
+    dq_ptrs = tl.make_block_ptr(
+        base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,
+        shape=(q_len, HEAD_DIM),
+        strides=(stride_dqn, stride_dqd),
+        offsets=(q_start_in_seq, 0),
+        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+        order=(1, 0),
+    )
+    k_ptrs = tl.make_block_ptr(
+        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
+        shape=(k_len, HEAD_DIM),
+        strides=(stride_kn, stride_kd),
+        offsets=(0, 0),
+        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+        order=(1, 0),
+    )
+    v_ptrs = tl.make_block_ptr(
+        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
+        shape=(HEAD_DIM, k_len),
+        strides=(stride_vd, stride_vn),
+        offsets=(0, 0),
+        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
+        order=(0, 1),
+    )
+    do_ptrs = tl.make_block_ptr(
+        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
+        shape=(q_len, HEAD_DIM),
+        strides=(stride_don, stride_dod),
+        offsets=(q_start_in_seq, 0),
+        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+        order=(1, 0),
+    )
+    d_ptrs = tl.make_block_ptr(
+        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
+        shape=(q_len, 1),
+        strides=(stride_dn, stride_dh),
+        offsets=(q_start_in_seq, 0),
+        block_shape=(BLOCK_SIZE_Q, 1),
+        order=(0, 1),
+    )
+    lse_ptrs = tl.make_block_ptr(
+        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
+        shape=(q_len, 1),
+        strides=(stride_ln, stride_lh),
+        offsets=(q_start_in_seq, 0),
+        block_shape=(BLOCK_SIZE_Q, 1),
+        order=(0, 1),
+    )
+    # offsets
+    off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
+    off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
+    # load q, do, lse, delta, and keep in SRAM
+    q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
+    do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
+    lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
+    d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
+    # init dq
+    dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32)
+    lo = 0
+    hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
+    for i in range(lo, hi, BLOCK_SIZE_K):
+        # load
+        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
+        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
+        # compute qk
+        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
+        qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf"))
+        qk += tl.dot(q, tl.trans(k)) * qk_scale
+        # compute p, ds
+        p = tl.exp2(qk - lse)
+        dp = tl.dot(do, v)
+        ds = sm_scale * p * (dp - d)
+        # cast dtype
+        ds = ds.to(q.dtype)
+        # update dq
+        dq += tl.dot(ds, k)
+        # increment pointers
+        k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))
+        v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K))
+    # save dq
+    tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
+
+
+def _compressed_attention_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    kernel_size: int,
+    kernel_stride: int,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: torch.Tensor,
+    max_seqlen_k: torch.Tensor,
+    sm_scale: float,
+):
+    # dtype check
+    assert k.dtype == q.dtype and v.dtype == q.dtype
+    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
+    # shape
+    q_len, num_q_heads, head_dim = q.shape
+    k_len, num_k_heads, head_dim = k.shape
+    v_len, num_v_heads, head_dim = v.shape
+    batch_size = cu_seqlens_q.shape[0] - 1
+    assert k_len == v_len and q_len > k_len
+    # gqa
+    assert num_k_heads == num_v_heads
+    assert num_q_heads % num_k_heads == 0
+    num_share_q_heads = num_q_heads // num_k_heads
+    # output tensor
+    o = torch.zeros_like(q)
+    lse = torch.full(
+        (num_q_heads, q_len),
+        fill_value=-torch.inf,
+        dtype=torch.float32,
+        device=q.device,
+    )
+    # launch kernel
+    grid = lambda META: (
+        batch_size,
+        num_q_heads,
+        triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
+    )
+    BLOCK_SIZE_Q = 128
+    BLOCK_SIZE_K = 128
+    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
+    forward_kernel[grid](
+        q,
+        k,
+        v,
+        o,
+        lse,
+        kernel_size,
+        kernel_stride,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        num_k_heads,
+        num_share_q_heads,
+        head_dim,
+        sm_scale,
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        v.stride(0),
+        v.stride(1),
+        v.stride(2),
+        o.stride(0),
+        o.stride(1),
+        o.stride(2),
+        lse.stride(0),
+        lse.stride(1),
+        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        BLOCK_SIZE_D=BLOCK_SIZE_D,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    return o, lse
+
+
+def _compressed_attention_bwd(
+    o: torch.Tensor,
+    do: torch.Tensor,
+    lse: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    kernel_size: int,
+    kernel_stride: int,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: torch.Tensor,
+    max_seqlen_k: torch.Tensor,
+    sm_scale: float,
+):
+    q_len, num_q_heads, head_dim = q.shape
+    k_len, num_k_heads, head_dim = k.shape
+    v_len, num_v_heads, head_dim = v.shape
+    o_len, num_o_heads, head_dim = o.shape
+    num_share_q_heads = num_q_heads // num_k_heads
+    # compute D
+    delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
+    grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads)
+    BLOCK_SIZE_O = 256
+    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)
+    backward_sum_o_do[grid](
+        o,
+        do,
+        delta,
+        o_len,
+        head_dim,
+        o.stride(0),
+        o.stride(1),
+        o.stride(2),
+        do.stride(0),
+        do.stride(1),
+        do.stride(2),
+        delta.stride(0),
+        delta.stride(1),
+        BLOCK_SIZE_O=BLOCK_SIZE_O,
+        BLOCK_SIZE_D=BLOCK_SIZE_D,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    # compute dk dv
+    dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
+    dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
+    batch_size = cu_seqlens_q.shape[0] - 1
+    grid = lambda META: (
+        batch_size,
+        num_q_heads,
+        triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
+    )
+    BLOCK_SIZE_Q = 64
+    BLOCK_SIZE_K = 128
+    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)
+    backward_dkdv[grid](
+        q,
+        k,
+        v,
+        lse,
+        delta,
+        do,
+        dk,
+        dv,
+        kernel_size,
+        kernel_stride,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        num_k_heads,
+        num_share_q_heads,
+        head_dim,
+        sm_scale,
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        v.stride(0),
+        v.stride(1),
+        v.stride(2),
+        lse.stride(0),
+        lse.stride(1),
+        delta.stride(0),
+        delta.stride(1),
+        do.stride(0),
+        do.stride(1),
+        do.stride(2),
+        dk.stride(0),
+        dk.stride(1),
+        dk.stride(2),
+        dk.stride(3),
+        dv.stride(0),
+        dv.stride(1),
+        dv.stride(2),
+        dv.stride(3),
+        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        BLOCK_SIZE_D=BLOCK_SIZE_D,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    dk = dk.sum(0)
+    dv = dv.sum(0)
+    # compute dq
+    dq = torch.zeros_like(q)
+    grid = lambda META: (
+        batch_size,
+        num_q_heads,
+        triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
+    )
+    BLOCK_SIZE_Q = 128
+    BLOCK_SIZE_K = 64
+    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
+    backward_dq[grid](
+        q,
+        k,
+        v,
+        lse,
+        delta,
+        do,
+        dq,
+        kernel_size,
+        kernel_stride,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        num_k_heads,
+        num_share_q_heads,
+        head_dim,
+        sm_scale,
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        v.stride(0),
+        v.stride(1),
+        v.stride(2),
+        lse.stride(0),
+        lse.stride(1),
+        delta.stride(0),
+        delta.stride(1),
+        do.stride(0),
+        do.stride(1),
+        do.stride(2),
+        dq.stride(0),
+        dq.stride(1),
+        dq.stride(2),
+        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        BLOCK_SIZE_D=BLOCK_SIZE_D,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    return dq, dk, dv
+
+
+class CompressedAttention(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kernel_size: int,
+        kernel_stride: int,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: torch.Tensor,
+        max_seqlen_k: torch.Tensor,
+        sm_scale=None,
+    ):
+        # dtype check
+        assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
+        assert q.dtype == k.dtype and k.dtype == v.dtype
+        assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
+        # softmax scale
+        if sm_scale is None:
+            sm_scale = 1 / math.sqrt(q.shape[-1])
+
+        o, lse = _compressed_attention_fwd(
+            q,
+            k,
+            v,
+            kernel_size,
+            kernel_stride,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            sm_scale,
+        )
+        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)
+        ctx.sm_scale = sm_scale
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.kernel_size = kernel_size
+        ctx.kernel_stride = kernel_stride
+        return o, lse
+
+    @staticmethod
+    def backward(ctx, do: torch.Tensor, *args) -> Any:
+        q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+        max_seqlen_q = ctx.max_seqlen_q
+        max_seqlen_k = ctx.max_seqlen_k
+        sm_scale = ctx.sm_scale
+        kernel_size = ctx.kernel_size
+        kernel_stride = ctx.kernel_stride
+
+        dq, dk, dv = _compressed_attention_bwd(
+            o,
+            do,
+            lse,
+            q,
+            k,
+            v,
+            kernel_size,
+            kernel_stride,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            sm_scale,
+        )
+        return dq, dk, dv, None, None, None, None, None, None, None
+
+
+@triton.jit
+def score_kernel(
+    q_ptr,
+    k_ptr,
+    lse_ptr,
+    s_ptr,
+    kernel_size,
+    kernel_stride,
+    # seqlens
+    cu_seqlens_q,
+    cu_seqlens_k,
+    # shape
+    NUM_KV_HEADS,
+    NUM_SHARE_Q_HEADS,
+    HEAD_DIM,
+    # sm_scale
+    sm_scale,
+    # stride
+    stride_qn,
+    stride_qh,
+    stride_qd,
+    stride_kn,
+    stride_kh,
+    stride_kd,
+    stride_lh,
+    stride_ln,
+    stride_sh,
+    stride_sq,
+    stride_sk,
+    # META parameters
+    BLOCK_SIZE_Q: tl.constexpr,  # q block size
+    BLOCK_SIZE_K: tl.constexpr,  # k block size
+    BLOCK_SIZE_D: tl.constexpr,
+):
+    qk_scale = sm_scale * 1.44269504
+    # get batch id and head id
+    pid_bkh = tl.program_id(0)
+    pid_b = pid_bkh // NUM_KV_HEADS
+    pid_kh = pid_bkh % NUM_KV_HEADS
+    pid_q = tl.program_id(1)
+    pid_k = tl.program_id(2)
+    # get q k start and len after rmpad
+    q_start = tl.load(cu_seqlens_q + pid_b)
+    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+    k_start = tl.load(cu_seqlens_k + pid_b)
+    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
+    if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:
+        return
+    # init k pointer and load k
+    k_ptrs = tl.make_block_ptr(
+        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
+        shape=(HEAD_DIM, k_len),
+        strides=(stride_kd, stride_kn),
+        offsets=(0, pid_k * BLOCK_SIZE_K),
+        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
+        order=(0, 1),
+    )
+    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
+    # offsets
+    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
+    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
+    causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :]
+    # init score
+    s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
+
+    q_ptrs = tl.make_block_ptr(
+        base=q_ptr + q_start * stride_qn + pid_kh * stride_qh,
+        shape=(q_len, HEAD_DIM),
+        strides=(stride_qn, stride_qd),
+        offsets=(pid_q * BLOCK_SIZE_Q, 0),
+        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+        order=(1, 0),
+    )
+    lse_ptrs = tl.make_block_ptr(
+        base=lse_ptr + q_start * stride_ln + pid_kh * stride_lh,
+        shape=(q_len, 1),
+        strides=(stride_ln, stride_lh),
+        offsets=(pid_q * BLOCK_SIZE_Q, 0),
+        block_shape=(BLOCK_SIZE_Q, 1),
+        order=(0, 1),
+    )
+    # load q and lse
+    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
+    lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
+    # compute qk
+    qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
+    qk += tl.dot(q, k) * qk_scale
+    # compute score
+    s += tl.where(causal_mask, tl.exp2(qk - lse), 0)
+    # save output
+    s_ptrs = tl.make_block_ptr(
+        base=s_ptr + pid_kh * stride_sh + q_start * stride_sq,
+        shape=(q_len, k_len),
+        strides=(stride_sq, stride_sk),
+        offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K),
+        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K),
+        order=(1, 0),
+    )
+    tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1))
+
+
+def _get_attention_score(
+    q: torch.Tensor,  # [total_query_len, num_q_heads, head_dim]
+    k: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
+    lse: torch.Tensor,  # [num_q_heads, total_query_len]
+    kernel_size: int,
+    kernel_stride: int,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    sm_scale: float,
+) -> torch.Tensor:
+    # dtype check
+    assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
+    assert q.dtype == k.dtype
+    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
+    assert lse.dtype == torch.float32  # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale)))
+    # shape
+    q_len, num_q_heads, head_dim = q.shape
+    k_len, num_k_heads, head_dim = k.shape
+    batch_size = cu_seqlens_q.shape[0] - 1
+    assert q_len > k_len
+    if sm_scale is None:
+        sm_scale = 1 / math.sqrt(head_dim)
+    # gqa
+    assert num_q_heads % num_k_heads == 0
+    num_share_q_heads = num_q_heads // num_k_heads
+    # init score
+    score = torch.zeros(num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device)
+
+    # launch kernel
+    grid = lambda META: (
+        batch_size * num_k_heads,
+        triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
+        triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
+    )
+    BLOCK_SIZE_Q = 128
+    BLOCK_SIZE_K = 128
+    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+
+    score_kernel[grid](
+        q,
+        k,
+        lse,
+        score,
+        kernel_size,
+        kernel_stride,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        num_k_heads,
+        num_share_q_heads,
+        head_dim,
+        sm_scale,
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        lse.stride(0),
+        lse.stride(1),
+        score.stride(0),
+        score.stride(1),
+        score.stride(2),
+        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        BLOCK_SIZE_D=BLOCK_SIZE_D,
+        num_warps=8,
+        num_stages=3,
+    )
+    return score
+
+
+@triton.jit
+def _transform_score_kernel(
+    s_ptr,  # score, shape: [num_heads, q_len, k_len]
+    bs_ptr,  # block-wise score: [num_heads, q_len, num_k_block]
+    offs,
+    cu_seqlens_q,
+    # shape
+    num_heads,
+    num_offs,
+    max_k_len,
+    max_blocks,
+    pad_len,
+    # kernel & block size
+    block_size,
+    block_stride,  # block_size // kernel_stride
+    init_blocks,
+    local_blocks,
+    # stride
+    stride_sh,
+    stride_sq,
+    stride_sk,
+    stride_bsh,
+    stride_bsq,
+    stride_bsk,
+    TOTAL_QUERY_LEN: tl.constexpr,
+    BLOCK_SIZE_Q: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    BLOCK_SIZE_O: tl.constexpr,
+):
+    pid_bh = tl.program_id(0)
+    pid_b = pid_bh // num_heads
+    pid_h = pid_bh % num_heads
+    pid_q = tl.program_id(1)
+    pid_k = tl.program_id(2)
+    q_start = tl.load(cu_seqlens_q + pid_b)
+    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+    k_start = pid_k * BLOCK_SIZE_K
+    if pid_q * BLOCK_SIZE_Q >= q_len:
+        return
+    # load weight
+    off_o = tl.arange(0, BLOCK_SIZE_O)
+    w = tl.load(offs + off_o, mask=off_o < num_offs, other=0)
+    # load score
+    off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
+    off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len
+    off_k = off_k[None, :] + off_o[:, None]
+    s_ptrs = (
+        s_ptr
+        + q_start * stride_sq
+        + pid_h * stride_sh
+        + off_q[:, None, None] * stride_sq
+        + off_k[None, :, :] * stride_sk
+    )
+    # weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK]
+    s = tl.load(
+        s_ptrs,
+        mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len),
+        other=0,
+    )
+    s = s * w[None, :, None]
+    s = tl.sum(s, axis=1)
+    # init mask and local mask
+    off_bq = off_q // block_size
+    off_bk = k_start + tl.arange(0, BLOCK_SIZE_K)
+    s = tl.where(
+        ((off_bq[:, None] >= off_bk[None, :]) & (off_bq[:, None] < off_bk[None, :] + local_blocks))
+        | (off_bk[None, :] < init_blocks - k_start),
+        float("inf"),
+        s,
+    )
+    # store block-wise score
+    bs_ptrs = (
+        bs_ptr + q_start * stride_bsq + pid_h * stride_bsh + off_q[:, None] * stride_bsq + off_bk[None, :] * stride_bsk
+    )
+    tl.store(
+        bs_ptrs,
+        s,
+        mask=(off_q < q_len)[:, None] & (off_bk < max_blocks)[None, :],
+    )
+
+
+def transform_score(
+    score: torch.Tensor,
+    kernel_size: int,
+    kernel_stride: int,
+    block_size: int,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    init_blocks: int = 1,
+    local_blocks: int = 2,
+) -> torch.Tensor:
+    num_k_heads, total_query_len, max_key_len = score.shape
+    batch_size = cu_seqlens_q.shape[0] - 1
+    pad_len = kernel_size // kernel_stride - 1
+    max_blocks = math.ceil(max_seqlen_q / block_size)
+    block_score = torch.zeros(
+        num_k_heads,
+        total_query_len,
+        max_blocks,
+        dtype=torch.float32,
+        device=score.device,
+    )
+    offs = (
+        torch.arange(kernel_size // kernel_stride, device=score.device)[:, None]
+        + torch.arange(block_size // kernel_stride, device=score.device)[None, :]
+    ).view(-1)
+
+    offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max())
+
+    num_offs = int(offs.shape[0])
+
+    BLOCK_SIZE_Q = 16
+    BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks))
+    BLOCK_SIZE_O = triton.next_power_of_2(num_offs)
+
+    def grid(meta):
+        grid = (
+            num_k_heads * batch_size,
+            triton.cdiv(total_query_len, BLOCK_SIZE_Q),
+            triton.cdiv(max_blocks, BLOCK_SIZE_K),
+        )
+        return grid
+
+    _transform_score_kernel[grid](
+        score,
+        block_score,
+        offs,
+        cu_seqlens_q,
+        num_k_heads,
+        offs.shape[0],
+        max_key_len,
+        max_blocks,
+        pad_len,
+        block_size,
+        block_size // kernel_stride,
+        init_blocks,
+        local_blocks,
+        score.stride(0),
+        score.stride(1),
+        score.stride(2),
+        block_score.stride(0),
+        block_score.stride(1),
+        block_score.stride(2),
+        TOTAL_QUERY_LEN=total_query_len,
+        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        BLOCK_SIZE_O=BLOCK_SIZE_O,
+        num_warps=4,
+        num_stages=3,
+    )
+    return block_score
+
+
+def compressed_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    kernel_size: int,
+    kernel_stride: int,
+    block_size: int,
+    topk: int,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    sm_scale: float = None,
+    init_blocks: int = 1,
+    local_blocks: int = 2,
+    parallel_topk_compute: Union[str, bool] = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Attention between query and compressed key and value. Computes the attention output and the topk block idx used in topk_sparse_attention.
+
+    Args:
+        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
+        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
+        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
+        kernel_size (int): kernel size in compress_key_value
+        kernel_stride (int): stride of compress_key_value
+        block_size (int): key value block size for topk sparse attention.
+        topk (int): number of blocks for each query.
+        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
+        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
+        max_seqlen_q (int): max q len of the batch.
+        max_seqlen_k (int): max k len of the batch.
+        sm_scale (float, optional): softmax scale. Defaults to None, meaning 1/sqrt(head_dim).
+        init_blocks (int, optional): number of initial blocks for each query. Defaults to 1.
+        local_blocks (int, optional): number of local blocks for each query. Defaults to 2.
+        parallel_topk_compute (Union[str, bool], optional): whether to compute topk scores for all heads in parallel.
+            Set it to False when the sequence length is very long; this avoids a known bug that we'll fix later.
+            Defaults to False. If set to "auto", the parallel path is used only when the total sequence length
+            is at most 32k.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
+    """
+
+    if max_seqlen_q is None:
+        max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
+    if max_seqlen_k is None:
+        max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
+
+    attn_output, lse = CompressedAttention.apply(
+        q,
+        k,
+        v,
+        kernel_size,
+        kernel_stride,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        sm_scale,
+    )
+
+    # do not select topk index
+    if topk <= 0:
+        warnings.warn("topk <= 0, returned topk_idx will be None")
+        return attn_output, None
+
+    assert topk >= init_blocks + local_blocks
+    with torch.no_grad():
+        num_k_heads, num_q_heads = k.shape[1], q.shape[1]
+        num_shared_q_heads = num_q_heads // num_k_heads
+        batch_size = cu_seqlens_q.shape[0] - 1
+        q_idx = torch.cat(
+            [torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) for i in range(batch_size)],
+            dim=0,
+        )
+        q_idx = q_idx // block_size
+
+        # whether to use the parallel version
+        if parallel_topk_compute == "auto":
+            parallel_topk_compute = cu_seqlens_q[-1] <= 32768
+        # parallel version
+        if parallel_topk_compute:
+            # recompute score
+            score = _get_attention_score(
+                q,
+                k,
+                lse,
+                kernel_size,
+                kernel_stride,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                sm_scale,
+            )
+            # transform score to block-wise score
+            score = transform_score(
+                score,
+                kernel_size,
+                kernel_stride,
+                block_size,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                init_blocks,
+                local_blocks,
+            )
+            # get topk
+            topk = min(topk, score.shape[-1])
+            topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
+            topk_idx[topk_idx > q_idx[None, :, None]] = -1
+            topk_idx = topk_idx.to(torch.int32)
+        # non-parallel version, avoids a current bug when the sequence length is too long
+        # FIXME: needs a proper fix later
+        else:
+            topk_idx_list = []
+            head_tile = 1
+            assert num_k_heads % head_tile == 0, f"Num kv heads: {num_k_heads}, head_tile: {head_tile}"
+            for h in range(num_k_heads // head_tile):
+                # recompute score
+                score = _get_attention_score(
+                    q[:, h * num_shared_q_heads * head_tile: (h + 1) * num_shared_q_heads * head_tile],
+                    k[:, h * head_tile: (h + 1) * head_tile],
+                    lse[h * num_shared_q_heads * head_tile: (h + 1) * num_shared_q_heads * head_tile],
+                    kernel_size,
+                    kernel_stride,
+                    cu_seqlens_q,
+                    cu_seqlens_k,
+                    max_seqlen_q,
+                    max_seqlen_k,
+                    sm_scale,
+                )
+                # transform score to block-wise score
+                score = transform_score(
+                    score,
+                    kernel_size,
+                    kernel_stride,
+                    block_size,
+                    cu_seqlens_q,
+                    cu_seqlens_k,
+                    max_seqlen_q,
+                    max_seqlen_k,
+                    init_blocks,
+                    local_blocks,
+                )
+                # get topk
+                topk = min(topk, score.shape[-1])
+                if score.dtype == torch.float32:
+                    score = score.to(torch.bfloat16)
+                topk_idx = score.topk(topk, dim=-1, sorted=False).indices
+                topk_idx = topk_idx.sort(-1).values
+
+                topk_idx[topk_idx > q_idx[None, :, None]] = -1
+                topk_idx = topk_idx.to(torch.int32)
+                topk_idx_list.append(topk_idx)
+            topk_idx = torch.cat(topk_idx_list, dim=0)
+
+    return attn_output, topk_idx
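Usage sketch (illustrative only, not part of the commit): a minimal varlen call with one flattened sequence of 1024 tokens, assuming k/v are the already-compressed keys/values produced upstream with kernel_size=32 and kernel_stride=16, which gives (1024 - 32) // 16 + 1 = 63 compressed positions; all shapes and values here are hypothetical.

import torch

cu_q = torch.tensor([0, 1024], dtype=torch.int32, device="cuda")
cu_k = torch.tensor([0, 63], dtype=torch.int32, device="cuda")
q = torch.randn(1024, 32, 128, dtype=torch.bfloat16, device="cuda")  # [total_q_len, num_q_heads, head_dim]
k = torch.randn(63, 2, 128, dtype=torch.bfloat16, device="cuda")     # [total_kv_len, num_kv_heads, head_dim]
v = torch.randn(63, 2, 128, dtype=torch.bfloat16, device="cuda")

attn_out, topk_idx = compressed_attention(
    q, k, v,
    kernel_size=32, kernel_stride=16,
    block_size=64, topk=16,                # topk must be >= init_blocks + local_blocks
    cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
    max_seqlen_q=1024, max_seqlen_k=63,
)
# attn_out: [1024, 32, 128]; topk_idx: [num_kv_heads, 1024, 16], with -1 marking invalid (non-causal) blocks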
modeling_chatglm.py CHANGED
@@ -25,9 +25,14 @@ from transformers.generation.utils import GenerationMixin
 
 try:
     from .configuration_chatglm import ChatGLMConfig
-    from .ops.pooling import mean_pooling
-    from .ops.compressed_attention import compressed_attention
-    from .ops.topk_sparse_attention import topk_sparse_attention
+    from .pooling import mean_pooling
+    from .compressed_attention import compressed_attention
+    from .topk_sparse_attention import topk_sparse_attention
+    # try:
+    #     from .configuration_chatglm import ChatGLMConfig
+    #     from .ops.pooling import mean_pooling
+    #     from .ops.compressed_attention import compressed_attention
+    #     from .ops.topk_sparse_attention import topk_sparse_attention
 except ImportError:
     from configuration_chatglm import ChatGLMConfig
     from ops.pooling import mean_pooling
pooling.py ADDED
@@ -0,0 +1,207 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from fla.ops.utils.index import prepare_chunk_indices
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
+
+
+@triton.heuristics({
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+})
+@triton.autotune(
+    configs=[
+        triton.Config({'BD': BD}, num_warps=num_warps)
+        for BD in [16, 32, 64, 128]
+        for num_warps in [1, 2, 4, 8]
+    ],
+    key=['BT']
+)
+@triton.jit(do_not_specialize=['T'])
+def mean_pooling_fwd_kernel(
+    x,
+    o,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    H: tl.constexpr,
+    D: tl.constexpr,
+    BT: tl.constexpr,
+    BD: tl.constexpr,
+    IS_VARLEN: tl.constexpr
+):
+    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_tg = i_t
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        T = eos - bos
+        NT = tl.cdiv(T, BT)
+    else:
+        NT = tl.cdiv(T, BT)
+        i_tg = i_b * NT + i_t
+        bos, eos = i_b * T, i_b * T + T
+
+    p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
+    p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
+    # [BT, BD]
+    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
+    # [BD]
+    b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT)
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
+
+
+@triton.heuristics({
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+})
+@triton.autotune(
+    configs=[
+        triton.Config({'BD': BD}, num_warps=num_warps)
+        for BD in [16, 32, 64, 128]
+        for num_warps in [1, 2, 4, 8]
+    ],
+    key=['BT']
+)
+@triton.jit(do_not_specialize=['T'])
+def mean_pooling_bwd_kernel(
+    do,
+    dx,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    H: tl.constexpr,
+    D: tl.constexpr,
+    BT: tl.constexpr,
+    BD: tl.constexpr,
+    IS_VARLEN: tl.constexpr
+):
+    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_tg = i_t
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        T = eos - bos
+        NT = tl.cdiv(T, BT)
+    else:
+        NT = tl.cdiv(T, BT)
+        i_tg = i_b * NT + i_t
+        bos, eos = i_b * T, i_b * T + T
+
+    p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
+    p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
+    # [BD]
+    b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32)
+    # [BT, BD]
+    b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None]
+    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
+
+
+def mean_pooling_fwd(
+    x: torch.Tensor,
+    chunk_size: int,
+    cu_seqlens: Optional[torch.LongTensor] = None
+) -> torch.Tensor:
+    B, T, H, D = x.shape
+    BT = chunk_size
+    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+
+    o = x.new_empty(B, NT, H, D)
+    def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H)
+    mean_pooling_fwd_kernel[grid](
+        x,
+        o,
+        cu_seqlens,
+        chunk_indices,
+        T=T,
+        H=H,
+        D=D,
+        BT=BT,
+    )
+    return o
+
+
+def mean_pooling_bwd(
+    do: torch.Tensor,
+    batch_size: int,
+    seq_len: int,
+    chunk_size: int,
+    cu_seqlens: Optional[torch.LongTensor] = None
+) -> torch.Tensor:
+    B, T, H, D = batch_size, seq_len, *do.shape[-2:]
+    BT = chunk_size
+    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+
+    dx = do.new_empty(B, T, H, D)
+    def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H)
+    mean_pooling_bwd_kernel[grid](
+        do,
+        dx,
+        cu_seqlens,
+        chunk_indices,
+        T=T,
+        H=H,
+        D=D,
+        BT=BT,
+    )
+    return dx
+
+
+class MeanPoolingFunction(torch.autograd.Function):
+
+    @staticmethod
+    @input_guard
+    @autocast_custom_fwd
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        chunk_size: int,
+        cu_seqlens: Optional[torch.LongTensor] = None
+    ) -> torch.Tensor:
+        o = mean_pooling_fwd(x, chunk_size, cu_seqlens)
+        ctx.batch_size = x.shape[0]
+        ctx.seq_len = x.shape[1]
+        ctx.chunk_size = chunk_size
+        ctx.cu_seqlens = cu_seqlens
+        return o
+
+    @staticmethod
+    @input_guard
+    @autocast_custom_bwd
+    def backward(
+        ctx, do
+    ) -> Tuple[torch.Tensor, None, None]:
+        batch_size = ctx.batch_size
+        seq_len = ctx.seq_len
+        chunk_size = ctx.chunk_size
+        cu_seqlens = ctx.cu_seqlens
+        dx = mean_pooling_bwd(do, batch_size, seq_len, chunk_size, cu_seqlens)
+        return dx, None, None
+
+
+def mean_pooling(
+    x: torch.Tensor,
+    chunk_size: int,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    head_first: bool = False
+) -> torch.Tensor:
+    if head_first:
+        x = x.transpose(1, 2)
+    if cu_seqlens is not None:
+        if x.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`. "
+                f"Please flatten variable-length inputs before processing."
+            )
+    o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens)
+    if head_first:
+        o = o.transpose(1, 2)
+    return o
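Usage sketch (illustrative, not part of the commit): pooling a [B, T, H, D] tensor over fixed chunks, plus the varlen form with a flattened batch; the shapes and chunk size below are hypothetical.

import torch

x = torch.randn(1, 8192, 2, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
o = mean_pooling(x, chunk_size=64)  # -> [1, 8192 // 64, 2, 128], one mean per 64-token chunk

# varlen: batch size must be 1, with sequences concatenated along dim 1
cu_seqlens = torch.tensor([0, 5000, 8192], dtype=torch.int32, device="cuda")
o_var = mean_pooling(x, chunk_size=64, cu_seqlens=cu_seqlens)  # one pooled row per per-sequence chunk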
topk_sparse_attention.py ADDED
@@ -0,0 +1,1213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Any, Optional
16
+
17
+ import torch
18
+ import triton
19
+ import triton.language as tl
20
+
21
+ try:
22
+ from .utils import get_num_warps_stages, is_hopper_gpu
23
+ except ImportError:
24
+ from ops.utils import get_num_warps_stages, is_hopper_gpu
25
+
26
+ IS_HOPPER_GPU = is_hopper_gpu()
27
+
28
+
@triton.jit
def forward_kernel_orig(
    q_ptr,  # Q: n x h x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    t_ptr,  # topk_idx: kh x n x k
    o_ptr,  # O: n x h x d
    lse_ptr,  # LSE: h x n
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    TOPK,
    block_size,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_th,
    stride_tn,
    stride_tk,
    stride_on,
    stride_oh,
    stride_od,
    stride_lh,
    stride_ln,
    # META parameters
    # q loop num
    num_q_loop: tl.constexpr,
    num_k_loop: tl.constexpr,
    MAX_SEQ_LEN: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid = tl.program_id(0)

    Q = MAX_SEQ_LEN // num_q_loop
    HK = NUM_KV_HEADS // num_k_loop

    # decompose pid into (batch, kv-head chunk, query chunk)
    pid_b = pid // (HK * Q)
    pid_kh_chunk = (pid % (HK * Q)) // Q  # each program handles num_k_loop KV heads
    pid_q = pid % Q

    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start

    if pid_q * num_q_loop >= q_len:
        return
    real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop)

    for kh_offset in range(num_k_loop):
        pid_kh = pid_kh_chunk * num_k_loop + kh_offset
        pid_h = pid_kh * NUM_SHARE_Q_HEADS

        for j in range(real_q_loop):
            pid_q_j = pid_q * num_q_loop + j
            # init topk idx pointer
            off_t = tl.arange(0, BLOCK_SIZE_T)
            t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th
            topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)

            # count valid topk entries: non-padding (>= 0) and causal
            # (block index not past the current query's block)
            real_topk = tl.sum(
                tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // block_size), 1, 0),
                axis=0,
            )
            # init qkv pointer
            q_ptrs = tl.make_block_ptr(
                base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,
                shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
                strides=(stride_qh, stride_qd),
                offsets=(0, 0),
                block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
                order=(1, 0),
            )
            k_ptrs = tl.make_block_ptr(
                base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
                shape=(HEAD_DIM, k_len),
                strides=(stride_kd, stride_kn),
                offsets=(0, 0),
                block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
                order=(0, 1),
            )
            v_ptrs = tl.make_block_ptr(
                base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
                shape=(k_len, HEAD_DIM),
                strides=(stride_vn, stride_vd),
                offsets=(0, 0),
                block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
                order=(1, 0),
            )
            # load q
            q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
            # init statistics
            off_h = tl.arange(0, BLOCK_SIZE_H)
            off_k = tl.arange(0, BLOCK_SIZE_K)
            m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32)
            lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32)
            acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32)
            # sparse attention
            for i in range(real_topk):
                # get current block start index
                c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K
                t_ptr_j = t_ptr_j + stride_tk
                # load k
                k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero")
                # compute qk
                qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)
                qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf"))
                # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K]
                qk += tl.dot(q, k) * qk_scale
                # compute m_ij and l_ij
                m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
                p = tl.exp2(qk - m_ij[:, None])
                l_ij = tl.sum(p, axis=1)
                # scale acc_o
                acc_o_scale = tl.exp2(m_i - m_ij)
                acc_o = acc_o * acc_o_scale[:, None]
                # load v and update acc_o
                v = tl.load(tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero")
                p = p.to(v.dtype)
                acc_o += tl.dot(p, v)
                # update statistics
                m_i = m_ij
                lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)

            # final scale
            acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
            # save output
            o_ptrs = tl.make_block_ptr(
                base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh,
                shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
                strides=(stride_oh, stride_od),
                offsets=(0, 0),
                block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
                order=(1, 0),
            )
            tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
            # save lse
            lse_ptrs = lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh
            tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS)


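# Added for illustration (not part of the original file): a minimal per-row
# reference of the base-2 online-softmax recurrence used in the kernel above.
# For each selected block with (already qk_scale-scaled) scores s:
#   m' = max(m, max(s)); acc' = acc * 2^(m - m') + 2^(s - m') @ v;
#   lse' = m' + log2(2^(lse - m') + sum(2^(s - m'))); output = acc * 2^(m - lse).
def _online_softmax_reference(score_blocks, value_blocks):
    m = torch.tensor(float("-inf"))
    lse = torch.tensor(float("-inf"))
    acc = torch.zeros(value_blocks[0].shape[-1])
    for s, v in zip(score_blocks, value_blocks):  # one selected block at a time
        m_new = torch.maximum(m, s.max())
        p = torch.exp2(s - m_new)
        acc = acc * torch.exp2(m - m_new) + p @ v
        lse = m_new + torch.log2(torch.exp2(lse - m_new) + p.sum())
        m = m_new
    return acc * torch.exp2(m - lse)

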
@triton.jit
def backward_sum_o_do(
    o_ptr,  # O: n x h x d
    do_ptr,  # dO: n x h x d
    delta_ptr,  # D: h x n
    o_len,
    HEAD_DIM,
    stride_on,
    stride_oh,
    stride_od,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dh,
    stride_dn,
    BLOCK_SIZE_O: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    pid_n = tl.program_id(0)
    pid_h = tl.program_id(1)
    off_o = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    off_d = tl.arange(0, BLOCK_SIZE_D)
    o = tl.load(
        o_ptr + off_o[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od,
        mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    do = tl.load(
        do_ptr + off_o[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod,
        mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    tl.store(delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len)


@triton.jit
def count_kernel(
    x_ptr,  # [num_kv_heads, total_len, topk]
    y_ptr,  # [num_kv_heads, total_blocks]
    cu_seqlens,  # [batch_size + 1]
    cu_seqblocks,  # [batch_size + 1]
    topk,
    stride_xh,
    stride_xn,
    stride_xk,
    stride_yh,
    stride_yn,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_R: tl.constexpr,
):
    pid_h = tl.program_id(0)
    pid_b = tl.program_id(1)
    # get start and len after rmpad
    seq_start = tl.load(cu_seqlens + pid_b)
    seq_len = tl.load(cu_seqlens + pid_b + 1) - seq_start
    blocks_start = tl.load(cu_seqblocks + pid_b)
    num_blocks = tl.load(cu_seqblocks + pid_b + 1) - blocks_start
    # load x
    off_k = tl.arange(0, BLOCK_SIZE_K)
    off_n = tl.arange(0, BLOCK_SIZE_N)
    x_ptr = x_ptr + pid_h * stride_xh + seq_start * stride_xn
    x_ptrs = x_ptr + off_n[:, None] * stride_xn + off_k[None, :] * stride_xk
    # init y
    y = tl.zeros((BLOCK_SIZE_R,), dtype=tl.int32)
    # loop
    for i in range(0, seq_len, BLOCK_SIZE_N):
        x = tl.load(
            x_ptrs,
            mask=(off_n < seq_len - i)[:, None] & (off_k < topk)[None, :],
            other=-1,
        )
        x = tl.ravel(x)
        y += tl.histogram(x, BLOCK_SIZE_R)
        x_ptrs += BLOCK_SIZE_N * stride_xn
    # store result
    off_r = tl.arange(0, BLOCK_SIZE_R)
    y_ptr = y_ptr + pid_h * stride_yh + blocks_start * stride_yn
    y_ptrs = y_ptr + off_r * stride_yn
    tl.store(y_ptrs, y.to(y_ptr.dtype.element_ty), mask=off_r < num_blocks)


def count_query(
    topk_idx: torch.Tensor,
    cu_seqlens: torch.Tensor,
    cu_seqblocks: torch.Tensor,
    block_size: int,
):
    num_kv_heads, total_len, topk = topk_idx.shape
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    seqblocks = cu_seqblocks[1:] - cu_seqblocks[:-1]
    batch_size = seqlens.shape[0]
    BLOCK_SIZE_K = triton.next_power_of_2(topk)
    BLOCK_SIZE_N = triton.next_power_of_2(4096 // BLOCK_SIZE_K)
    BLOCK_SIZE_R = triton.next_power_of_2(seqblocks.max().item() + 2)
    active_query_count = torch.zeros(num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device)
    grid = (num_kv_heads, batch_size)
    count_kernel[grid](
        topk_idx,
        active_query_count,
        cu_seqlens,
        cu_seqblocks,
        topk,
        topk_idx.stride(0),
        topk_idx.stride(1),
        topk_idx.stride(2),
        active_query_count.stride(0),
        active_query_count.stride(1),
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_R=BLOCK_SIZE_R,
        num_warps=4,
        num_stages=3,
    )
    return active_query_count


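# Added for illustration (not part of the original file): a naive PyTorch
# reference of what count_query computes, useful when unit-testing the kernel.
# For each kv head and each key block, it counts how many (query, slot) pairs
# in topk_idx selected that block; -1 entries are padding and are skipped.
def _count_query_reference(topk_idx, cu_seqlens, cu_seqblocks):
    num_kv_heads = topk_idx.shape[0]
    out = torch.zeros(num_kv_heads, int(cu_seqblocks[-1]), dtype=torch.int32)
    for b in range(cu_seqlens.shape[0] - 1):
        s, e = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        bs, be = int(cu_seqblocks[b]), int(cu_seqblocks[b + 1])
        nb = be - bs
        for h in range(num_kv_heads):
            idx = topk_idx[h, s:e].reshape(-1)
            idx = idx[(idx >= 0) & (idx < nb)].long()
            out[h, bs:be] += torch.bincount(idx, minlength=nb).to(torch.int32)
    return out

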
@triton.jit
def pad_topk_idx_kernel(
    t_ptr,
    p_ptr,
    cu_seqlens,
    topk,
    stride_th,
    stride_tn,
    stride_tk,
    stride_pb,
    stride_ph,
    stride_pn,
    stride_pk,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_n = tl.program_id(2)
    # get q start and len after rmpad
    q_start = tl.load(cu_seqlens + pid_b)
    q_len = tl.load(cu_seqlens + pid_b + 1) - q_start
    if BLOCK_SIZE_N * pid_n >= q_len:
        return
    # init ptrs
    t_ptrs = tl.make_block_ptr(
        base=t_ptr + pid_h * stride_th + q_start * stride_tn,
        shape=(q_len, topk),
        strides=(stride_tn, stride_tk),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T),
        order=(1, 0),
    )
    p_ptrs = tl.make_block_ptr(
        base=p_ptr + pid_b * stride_pb + pid_h * stride_ph,
        shape=(q_len, topk),
        strides=(stride_pn, stride_pk),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T),
        order=(1, 0),
    )
    # load and save
    idxs = tl.load(t_ptrs, boundary_check=(0, 1))
    tl.store(p_ptrs, idxs, boundary_check=(0, 1))


@triton.jit
def save_topk_idx_kernel(
    p_ptr,
    t_ptr,
    cu_seqblocks,
    cu_topk_q_count,
    n_len,
    stride_pb,
    stride_ph,
    stride_pn,
    stride_th,
    stride_tn,
    stride_ch,
    stride_cn,
    BLOCK_SIZE_N: tl.constexpr,
):
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_n = tl.program_id(2)
    # get q start and len after rmpad
    q_block_start = tl.load(cu_seqblocks + pid_b)
    q_block_end = tl.load(cu_seqblocks + pid_b + 1)
    c_start = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_start * stride_cn)
    c_end = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_end * stride_cn)
    c_len = c_end - c_start
    if c_len <= 0:
        return
    if pid_n * BLOCK_SIZE_N >= c_len:
        return
    # init ptrs
    p_ptrs = tl.make_block_ptr(
        base=p_ptr + pid_b * stride_pb + pid_h * stride_ph + (n_len - c_len) * stride_pn,
        shape=(c_len,),
        strides=(stride_pn,),
        offsets=(pid_n * BLOCK_SIZE_N,),
        block_shape=(BLOCK_SIZE_N,),
        order=(0,),
    )
    t_ptrs = tl.make_block_ptr(
        base=t_ptr + pid_h * stride_th + c_start * stride_tn,
        shape=(c_len,),
        strides=(stride_tn,),
        offsets=(pid_n * BLOCK_SIZE_N,),
        block_shape=(BLOCK_SIZE_N,),
        order=(0,),
    )
    # load and save
    idxs = tl.load(p_ptrs, boundary_check=(0,))
    tl.store(t_ptrs, idxs, boundary_check=(0,))


def reorder_topk_idx(
    topk_idx: torch.Tensor,
    cu_topk_q_count: torch.Tensor,
    cu_seqlens: torch.Tensor,
    cu_seqblocks: torch.Tensor,
    block_size: int,
):
    num_kv_heads, total_len, topk = topk_idx.shape
    batch_size = cu_seqlens.shape[0] - 1
    seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
    max_seqlen = seq_lens.max().item()
    # pad shape [num_kv_heads, total_seqlen, topk] to [batch_size, num_kv_heads, max_seqlen, topk]
    pad_topk_idx = torch.full(
        (batch_size, num_kv_heads, max_seqlen, topk),
        fill_value=-1,
        device=topk_idx.device,
        dtype=torch.int32,
    )
    BLOCK_SIZE_T = triton.next_power_of_2(topk)
    BLOCK_SIZE_N = min(triton.next_power_of_2(max_seqlen), triton.next_power_of_2(8192 // BLOCK_SIZE_T))
    grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, BLOCK_SIZE_N))
    pad_topk_idx_kernel[grid](
        topk_idx,
        pad_topk_idx,
        cu_seqlens,
        topk,
        topk_idx.stride(0),
        topk_idx.stride(1),
        topk_idx.stride(2),
        pad_topk_idx.stride(0),
        pad_topk_idx.stride(1),
        pad_topk_idx.stride(2),
        pad_topk_idx.stride(3),
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_T=BLOCK_SIZE_T,
    )
    # argsort: sorting the flattened block ids groups query positions by key block
    pad_topk_q_idx = pad_topk_idx.view(batch_size, num_kv_heads, -1).argsort(-1) // topk
    pad_topk_q_idx = pad_topk_q_idx.to(torch.int32)
    # save as a removed-pad (varlen) version
    topk_q_idx = torch.full(
        (num_kv_heads, cu_topk_q_count[:, -1].max().item()),
        fill_value=-1,
        device=topk_idx.device,
        dtype=torch.int32,
    )
    max_len = (cu_topk_q_count[:, cu_seqblocks][:, 1:] - cu_topk_q_count[:, cu_seqblocks][:, :-1]).max().item()
    BLOCK_SIZE_N = min(triton.next_power_of_2(max_len), 8192)
    grid = (batch_size, num_kv_heads, triton.cdiv(max_len, BLOCK_SIZE_N))
    save_topk_idx_kernel[grid](
        pad_topk_q_idx,
        topk_q_idx,
        cu_seqblocks,
        cu_topk_q_count,
        pad_topk_q_idx.shape[-1],
        pad_topk_q_idx.stride(0),
        pad_topk_q_idx.stride(1),
        pad_topk_q_idx.stride(2),
        topk_q_idx.stride(0),
        topk_q_idx.stride(1),
        cu_topk_q_count.stride(0),
        cu_topk_q_count.stride(1),
        BLOCK_SIZE_N=BLOCK_SIZE_N,
    )
    return topk_q_idx


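# Illustration of the argsort trick above (comment added; values are a toy
# example). With per-query topk block ids, e.g. blk = [[0, -1], [1, 0]]
# (q_len=2, topk=2), blk.view(-1).argsort() sorts the -1 padding to the front
# and then groups flat slots by block id; dividing the sorted flat positions
# by topk maps each slot back to its query index. cu_topk_q_count then
# delimits, per key block, the slice of query indices that selected it.

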
@triton.jit
def backward_dkdv(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    tq_ptr,  # topk_q_idx: kh x N
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dk_ptr,  # DK: sh x n x kh x d
    dv_ptr,  # DV: sh x n x kh x d
    # seqlens
    cu_seqlens_q,  # [batch_size + 1]
    cu_seqlens_k,  # [batch_size + 1]
    cu_seqblocks,  # [batch_size + 1]
    cu_topk_q_count,  # [kh, total_blocks]
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    TOPK,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_tqh,
    stride_tqn,
    stride_ctqh,
    stride_ctqn,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dks,
    stride_dkn,
    stride_dkh,
    stride_dkd,
    stride_dvs,
    stride_dvn,
    stride_dvh,
    stride_dvd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    pid_sh = pid_h % NUM_SHARE_Q_HEADS
    pid_k = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if BLOCK_SIZE_K * pid_k >= k_len:
        return
    # get topk_q_idx
    b_start = tl.load(cu_seqblocks + pid_b)  # how many blocks come before the current sequence
    act_q_start = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn)
    act_q_end = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn)
    act_q_len = act_q_end - act_q_start
    tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn
    # init pointers
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_kn, stride_kd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dk_ptrs = tl.make_block_ptr(
        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
        shape=(k_len, HEAD_DIM),
        strides=(stride_dkn, stride_dkd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dv_ptrs = tl.make_block_ptr(
        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
        shape=(k_len, HEAD_DIM),
        strides=(stride_dvn, stride_dvd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q)
    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
    off_d = tl.arange(0, BLOCK_SIZE_D)
    # load k v and keep in SRAM
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dk dv
    dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
    # init ptrs
    q_ptrs = q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd
    do_ptrs = do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod
    d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh
    lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh
    # loop over the queries that selected this key block
    for i in range(0, act_q_len, BLOCK_SIZE_Q):
        # load
        idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to(tl.int32)
        q = tl.load(
            q_ptrs + idx_q[:, None] * stride_qn,
            mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],
            other=0,
        )
        do = tl.load(
            do_ptrs + idx_q[:, None] * stride_don,
            mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],
            other=0,
        )
        lse = tl.load(
            lse_ptrs + idx_q[:, None] * stride_ln,
            mask=(off_q < act_q_len - i)[:, None],
            other=0,
        )
        d = tl.load(
            d_ptrs + idx_q[:, None] * stride_dn,
            mask=(off_q < act_q_len - i)[:, None],
            other=0,
        )
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf"))
        qk += tl.dot(q, k.T) * qk_scale
        # compute p, ds
        p = tl.exp2(qk - lse)
        dp = tl.dot(do, v.T)
        ds = sm_scale * p * (dp - d)
        # cast dtype
        p = p.to(do.dtype)
        ds = ds.to(q.dtype)
        # update dk and dv
        dk += tl.dot(ds.T, q)
        dv += tl.dot(p.T, do)
    # save dk dv
    tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))


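# Gradient identities realized by the loop above (comment added for clarity):
# with softmax rows p = 2^(qk - lse), delta = rowsum(o * do) and dp = do @ v^T,
#   ds = sm_scale * p * (dp - delta),   dk += ds^T @ q,   dv += p^T @ do,
# i.e. the standard FlashAttention-style dK/dV backward restricted to the
# queries that actually selected this key block.

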
@triton.jit
def backward_dq(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    t_ptr,  # topk_idx: kh x n x k
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dq_ptr,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    TOPK,
    # q loop num
    num_q_loop,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_th,
    stride_tn,
    stride_tk,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dqn,
    stride_dqh,
    stride_dqd,
    # META parameters
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_kh = tl.program_id(1)
    pid_q = tl.program_id(2)
    pid_h = pid_kh * NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if pid_q * num_q_loop >= q_len:
        return
    real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop)
    for j in range(real_q_loop):
        pid_q_j = pid_q * num_q_loop + j
        # init topk idx pointer
        off_t = tl.arange(0, BLOCK_SIZE_T)
        t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th
        topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)

        real_topk = tl.sum(
            tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0),
            axis=0,
        )
        # init pointers
        q_ptrs = tl.make_block_ptr(
            base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_qh, stride_qd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        dq_ptrs = tl.make_block_ptr(
            base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_dqh, stride_dqd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        k_ptrs = tl.make_block_ptr(
            base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
            shape=(k_len, HEAD_DIM),
            strides=(stride_kn, stride_kd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
            order=(1, 0),
        )
        v_ptrs = tl.make_block_ptr(
            base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
            shape=(HEAD_DIM, k_len),
            strides=(stride_vd, stride_vn),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
            order=(0, 1),
        )
        do_ptrs = tl.make_block_ptr(
            base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_doh, stride_dod),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        d_ptrs = tl.make_block_ptr(
            base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh,
            shape=(NUM_SHARE_Q_HEADS, 1),
            strides=(stride_dh, stride_dn),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, 1),
            order=(1, 0),
        )
        lse_ptrs = tl.make_block_ptr(
            base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh,
            shape=(NUM_SHARE_Q_HEADS, 1),
            strides=(stride_lh, stride_ln),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, 1),
            order=(1, 0),
        )
        # offsets
        off_k = tl.arange(0, BLOCK_SIZE_K)
        # load q, do, lse, delta, and keep in SRAM
        q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
        # init dq
        dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32)
        # sparse attention over the selected key blocks
        for i in range(real_topk):
            # get current block start index
            c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K
            t_ptr_j = t_ptr_j + stride_tk
            # load
            k = tl.load(tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero")
            v = tl.load(tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero")
            # compute qk
            qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)
            qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf"))
            # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K]
            qk += tl.dot(q, tl.trans(k)) * qk_scale
            # compute p, ds
            p = tl.exp2(qk - lse)
            dp = tl.dot(do, v)
            ds = sm_scale * p * (dp - d)
            # cast dtype
            ds = ds.to(q.dtype)
            # update dq
            dq += tl.dot(ds, k)
        # save dq
        tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))


def _topk_sparse_attention_fwd(
    q: torch.Tensor,  # [total_len, num_q_heads, head_dim]
    k: torch.Tensor,  # [total_len, num_k_heads, head_dim]
    v: torch.Tensor,  # [total_len, num_k_heads, head_dim]
    topk_idx: torch.Tensor,  # [num_kv_heads, total_len, topk]
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
):
    # dtype check
    assert k.dtype == q.dtype and v.dtype == q.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    assert block_size in {32, 64, 128, 256}
    # shape
    q_len, num_q_heads, head_dim = q.shape
    k_len, num_k_heads, head_dim = k.shape
    v_len, num_v_heads, head_dim = v.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    # assert q_len == k_len and k_len == v_len
    topk = topk_idx.shape[-1]
    assert topk_idx.shape[0] == num_k_heads
    assert topk_idx.shape[1] == q_len
    # gqa
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads
    # output tensor
    o = torch.zeros_like(q)

    lse = torch.zeros(num_q_heads, q_len, dtype=torch.float32, device=q.device)

    # launch kernel
    num_q_loop = num_k_loop = 1
    BLOCK_SIZE_K = triton.next_power_of_2(block_size)
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads))
    BLOCK_SIZE_T = triton.next_power_of_2(topk)

    def grid(meta):
        return (
            batch_size * triton.cdiv(num_k_heads, num_k_loop) * triton.cdiv(max_seqlen_q, num_q_loop),
        )

    num_warps, num_stages = get_num_warps_stages(head_dim, block_size, IS_HOPPER_GPU)
    forward_kernel_orig[grid](
        q,
        k,
        v,
        topk_idx,
        o,
        lse,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        topk,
        block_size,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        topk_idx.stride(0),
        topk_idx.stride(1),
        topk_idx.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        lse.stride(0),
        lse.stride(1),
        num_q_loop=num_q_loop,
        num_k_loop=num_k_loop,
        MAX_SEQ_LEN=max_seqlen_q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        BLOCK_SIZE_H=BLOCK_SIZE_H,
        BLOCK_SIZE_T=BLOCK_SIZE_T,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, lse


def _topk_sparse_attention_bwd(
    o: torch.Tensor,
    do: torch.Tensor,
    lse: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    topk_idx: torch.Tensor,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
):
    assert block_size in {32, 64, 128, 256}
    q_len, num_q_heads, head_dim = q.shape
    k_len, num_k_heads, head_dim = k.shape
    v_len, num_v_heads, head_dim = v.shape
    o_len, num_o_heads, head_dim = o.shape
    num_share_q_heads = num_q_heads // num_k_heads
    topk = topk_idx.shape[-1]
    # compute D
    delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
    BLOCK_SIZE_O = 256
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)
    grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads)

    backward_sum_o_do[grid](
        o,
        do,
        delta,
        o_len,
        head_dim,
        o.stride(0),
        o.stride(1),
        o.stride(2),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        delta.stride(0),
        delta.stride(1),
        BLOCK_SIZE_O=BLOCK_SIZE_O,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    # count active queries for each key block, shape: (num_k_heads, total_k_blocks)
    seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    seqblocks = torch.ceil(seqlens / block_size).to(torch.int32)
    cu_seqblocks = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=topk_idx.device),
            torch.cumsum(seqblocks, dim=0),
        ]
    ).to(torch.int32)

    topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size)

    cu_topk_q_count = torch.cat(
        [
            torch.zeros(topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device),
            torch.cumsum(topk_q_count, dim=-1),
        ],
        dim=-1,
    ).to(torch.int32)
    # active query idx for each key block; for sequence b, head h, kv block i, the
    # active query indices are
    # topk_q_idx[h, cu_topk_q_count[h, cu_seqblocks[b] + i] : cu_topk_q_count[h, cu_seqblocks[b] + i + 1]]
    topk_q_idx = reorder_topk_idx(topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size)
    # compute dk dv
    dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
    dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
    batch_size = cu_seqlens_q.shape[0] - 1
    BLOCK_SIZE_K = triton.next_power_of_2(block_size)
    BLOCK_SIZE_Q = 64
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
    grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K))
    backward_dkdv[grid](
        q,
        k,
        v,
        topk_q_idx,
        lse,
        delta,
        do,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqblocks,
        cu_topk_q_count,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        topk,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        topk_q_idx.stride(0),
        topk_q_idx.stride(1),
        cu_topk_q_count.stride(0),
        cu_topk_q_count.stride(1),
        lse.stride(0),
        lse.stride(1),
        delta.stride(0),
        delta.stride(1),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        dk.stride(0),
        dk.stride(1),
        dk.stride(2),
        dk.stride(3),
        dv.stride(0),
        dv.stride(1),
        dv.stride(2),
        dv.stride(3),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    dk = dk.sum(0)
    dv = dv.sum(0)
    # compute dq
    dq = torch.zeros_like(q)
    num_q_loop = max_seqlen_q // 32768 + 1  # handle multiple queries per program if the sequence length is very long
    grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop))
    BLOCK_SIZE_K = block_size
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads))
    BLOCK_SIZE_T = triton.next_power_of_2(topk)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)

    backward_dq[grid](
        q,
        k,
        v,
        topk_idx,
        lse,
        delta,
        do,
        dq,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        topk,
        num_q_loop,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        topk_idx.stride(0),
        topk_idx.stride(1),
        topk_idx.stride(2),
        lse.stride(0),
        lse.stride(1),
        delta.stride(0),
        delta.stride(1),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        dq.stride(0),
        dq.stride(1),
        dq.stride(2),
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        BLOCK_SIZE_H=BLOCK_SIZE_H,
        BLOCK_SIZE_T=BLOCK_SIZE_T,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return dq, dk, dv


class TopkSparseAttention(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,  # [total_len, num_q_heads, head_dim]
        k: torch.Tensor,  # [total_len, num_k_heads, head_dim]
        v: torch.Tensor,  # [total_len, num_k_heads, head_dim]
        topk_idx: torch.Tensor,  # [num_kv_heads, total_len, topk]
        block_size: int,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        sm_scale=None,
    ):
        # dtype check
        assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
        assert q.dtype == k.dtype and k.dtype == v.dtype
        assert topk_idx.dtype == torch.int32
        assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
        # softmax scale
        if sm_scale is None:
            sm_scale = 1 / math.sqrt(q.shape[-1])

        o, lse = _topk_sparse_attention_fwd(
            q,
            k,
            v,
            topk_idx,
            block_size,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )

        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx)
        ctx.sm_scale = sm_scale
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.block_size = block_size
        return o

    @staticmethod
    def backward(ctx, do: torch.Tensor, *args) -> Any:
        q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors

        max_seqlen_q = ctx.max_seqlen_q
        max_seqlen_k = ctx.max_seqlen_k
        sm_scale = ctx.sm_scale
        block_size = ctx.block_size
        assert block_size in {32, 64, 128, 256}

        dq, dk, dv = _topk_sparse_attention_bwd(
            o,
            do,
            lse,
            q,
            k,
            v,
            topk_idx,
            block_size,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )
        # one gradient per forward input: dq, dk, dv, then None for the
        # non-differentiable arguments
        return dq, dk, dv, None, None, None, None, None, None, None


def topk_sparse_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    topk_idx: torch.Tensor,
    block_size: int,
    cu_seqlens: torch.Tensor,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Topk sparse attention, varlen version, implemented in Triton.

    Args:
        q (torch.Tensor): shape [total_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
        topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding.
        block_size (int): key value block size.
        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen.
        softmax_scale (Optional[float], optional): Defaults to None, meaning 1/sqrt(head_dim).

    Returns:
        torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim]
    """
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    return TopkSparseAttention.apply(
        q,
        k,
        v,
        topk_idx,
        block_size,
        cu_seqlens,
        cu_seqlens,
        max_seqlen,
        max_seqlen,
        softmax_scale,
    )
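

# Usage sketch (illustrative, not part of the original file). All shapes and
# values below are assumptions: two packed sequences of lengths 1024 and 2048,
# 4 query heads sharing 2 KV heads, head_dim 128, block_size 64, topk 16. In
# practice topk_idx comes from a block-scoring stage (e.g. compressed
# attention) and is causally filtered; random indices here only exercise the
# shapes and the kernel launch.
if __name__ == "__main__":
    device = "cuda"
    cu_seqlens = torch.tensor([0, 1024, 3072], dtype=torch.int32, device=device)
    total_len = int(cu_seqlens[-1])
    q = torch.randn(total_len, 4, 128, dtype=torch.bfloat16, device=device)
    k = torch.randn(total_len, 2, 128, dtype=torch.bfloat16, device=device)
    v = torch.randn(total_len, 2, 128, dtype=torch.bfloat16, device=device)
    # random block indices in [0, 16); -1 would mark padding slots
    topk_idx = torch.randint(0, 16, (2, total_len, 16), dtype=torch.int32, device=device)
    out = topk_sparse_attention(q, k, v, topk_idx, block_size=64, cu_seqlens=cu_seqlens)
    print(out.shape)  # torch.Size([3072, 4, 128])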
utils.py ADDED
@@ -0,0 +1,50 @@
import torch


def is_hopper_gpu():
    if torch.cuda.is_available():
        device_capability = torch.cuda.get_device_capability(0)
        major, minor = device_capability
        return major == 9
    return False


def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
    """
    Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton.

    Args:
        head_dim (int): Size of the head dimension.
        block_size (int): Size of the block in the attention matrix.
        is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU.

    Returns:
        tuple: (num_warps, num_stages) recommended values.
    """
    # Determine if head_dim and block_size exceed 64
    head_large = head_dim > 64
    block_large = block_size > 64

    if is_hopper_gpu:
        # Hopper GPU recommendations
        if head_large and block_large:
            num_warps = 8
            num_stages = 3
        elif head_large or block_large:
            num_warps = 4
            num_stages = 3
        else:
            num_warps = 2
            num_stages = 2
    else:
        # Ampere GPU recommendations
        if head_large and block_large:
            num_warps = 8
            num_stages = 3
        elif head_large or block_large:
            num_warps = 8
            num_stages = 3
        else:
            num_warps = 2
            num_stages = 2
    return num_warps, num_stages
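

# Usage sketch (comment added for illustration; kernel name is hypothetical):
#   num_warps, num_stages = get_num_warps_stages(
#       head_dim=128, block_size=64, is_hopper_gpu=is_hopper_gpu()
#   )
#   some_attention_kernel[grid](..., num_warps=num_warps, num_stages=num_stages)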