Commit 4ee9d9e · Parent: 22ba83b
Changed to autotune triton for 48G GPU deployment

Files changed:
- compressed_attention.py (+40, -15)
- topk_sparse_attention.py (+38, -15)
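
This commit replaces hard-coded `num_warps`/`num_stages` launch arguments with `@triton.autotune` decorators. Triton benchmarks every `triton.Config` candidate the first time a kernel runs for a given combination of the `key` fields, caches the fastest one, and reuses it for later launches, so the kernels tune themselves to the deployment GPU instead of relying on values tuned by hand for one card. A minimal, self-contained sketch of the pattern (the `scale_kernel` toy below is an illustrative assumption, not code from this repository):

import torch
import triton
import triton.language as tl

# Tune num_warps only: each Config leaves the meta-parameters ({}) empty and
# varies launch geometry. The tuner times all four configs on the first call
# for a given BLOCK_SIZE value, then caches the winner.
@triton.autotune(
    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
    key=['BLOCK_SIZE'],
)
@triton.jit
def scale_kernel(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, x * 2.0, mask=mask)

x = torch.randn(1 << 16, device='cuda')
y = torch.empty_like(x)
# num_warps is not passed at launch: the autotuner owns it now, which is why
# the call sites in the diffs below carry it only as a comment.
scale_kernel[(triton.cdiv(x.numel(), 1024),)](x, y, x.numel(), BLOCK_SIZE=1024)
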
compressed_attention.py CHANGED

@@ -26,7 +26,13 @@ except ImportError:
 
 IS_HOPPER_GPU = is_hopper_gpu()
 
-
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8]
+    ],
+    key=['HEAD_DIM', 'BLOCK_SIZE_Q', 'BLOCK_SIZE_K', 'BLOCK_SIZE_V'],
+)
 @triton.jit
 def forward_kernel(
     q_ptr, # Q: n x h x d
@@ -159,6 +165,13 @@ def forward_kernel(
     tl.store(l_ptrs, lse_i, mask=off_q < q_len)
 
 
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8]
+    ],
+    key=['HEAD_DIM', 'BLOCK_SIZE_O', 'BLOCK_SIZE_D'],
+)
 @triton.jit
 def backward_sum_o_do(
     o_ptr, # O: n x h x d
@@ -194,7 +207,13 @@ def backward_sum_o_do(
     delta = tl.sum(o * do, axis=1)
     tl.store(delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len)
 
-
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8]
+    ],
+    key=['HEAD_DIM', 'BLOCK_SIZE_Q', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D'],
+)
 @triton.jit
 def backward_dkdv(
     q_ptr, # Q: n x qh x d
@@ -368,7 +387,13 @@ def backward_dkdv(
     tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
     tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
 
-
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8]
+    ],
+    key=['HEAD_DIM', 'BLOCK_SIZE_Q', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D'],
+)
 @triton.jit
 def backward_dq(
     q_ptr, # Q: n x qh x d
@@ -595,8 +620,8 @@ def _compressed_attention_fwd(
         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     return o, lse
 
@@ -643,8 +668,8 @@ def _compressed_attention_bwd(
         delta.stride(1),
         BLOCK_SIZE_O=BLOCK_SIZE_O,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     # compute dk dv
     dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
@@ -703,8 +728,8 @@ def _compressed_attention_bwd(
         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     dk = dk.sum(0)
     dv = dv.sum(0)
@@ -756,8 +781,8 @@ def _compressed_attention_bwd(
         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     return dq, dk, dv
 
@@ -1000,8 +1025,8 @@ def _get_attention_score(
         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=8,
-        num_stages=3,
+        # num_warps=8,
+        # num_stages=3,
    )
     return score
 
@@ -1155,8 +1180,8 @@ def transform_score(
         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_O=BLOCK_SIZE_O,
-        num_warps=4,
-        num_stages=3,
+        # num_warps=4,
+        # num_stages=3,
     )
     return block_score
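
Note on the call-site half of this diff: once a kernel is wrapped in `@triton.autotune`, the tuner supplies `num_warps` at launch, so passing it explicitly again would conflict with the chosen config; the commit therefore keeps the old hand-tuned values only as comments. The same pattern repeats in topk_sparse_attention.py below. Since `num_stages` is no longer passed either, a natural extension would be to let the tuner search over pipeline depth as well, e.g. (an illustrative variant, not code from this commit):

configs = [
    triton.Config({}, num_warps=nw, num_stages=ns)
    for nw in [4, 8]
    for ns in [2, 3, 4]
]
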
topk_sparse_attention.py CHANGED

@@ -25,7 +25,10 @@ except ImportError:
 
 IS_HOPPER_GPU = is_hopper_gpu()
 
-
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['HEAD_DIM', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D', 'BLOCK_SIZE_H', 'BLOCK_SIZE_T'],
+)
 @triton.jit
 def forward_kernel_orig(
     q_ptr, # Q: n x h x d
@@ -196,7 +199,10 @@ def forward_kernel_orig(
     lse_ptrs = lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh
     tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS)
 
-
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['HEAD_DIM', 'BLOCK_SIZE_O', 'BLOCK_SIZE_D'],
+)
 @triton.jit
 def backward_sum_o_do(
     o_ptr, # O: n x h x d
@@ -233,6 +239,10 @@ def backward_sum_o_do(
     tl.store(delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len)
 
 
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['BLOCK_SIZE_N', 'BLOCK_SIZE_K', 'BLOCK_SIZE_R'],
+)
 @triton.jit
 def count_kernel(
     x_ptr, # [num_kv_heads, total_len, topk]
@@ -309,12 +319,16 @@ def count_query(
         BLOCK_SIZE_N=BLOCK_SIZE_N,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_R=BLOCK_SIZE_R,
-        num_warps=4,
-        num_stages=3,
+        # num_warps=4,
+        # num_stages=3,
     )
     return active_query_count
 
 
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['topk', 'BLOCK_SIZE_N', 'BLOCK_SIZE_T'],
+)
 @triton.jit
 def pad_topk_idx_kernel(
     t_ptr,
@@ -360,7 +374,10 @@ def pad_topk_idx_kernel(
     idxs = tl.load(t_ptrs, boundary_check=(0, 1))
     tl.store(p_ptrs, idxs, boundary_check=(0, 1))
 
-
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['BLOCK_SIZE_N'],
+)
 @triton.jit
 def save_topk_idx_kernel(
     p_ptr,
@@ -478,7 +495,10 @@ def reorder_topk_idx(
     )
     return topk_q_idx
 
-
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['HEAD_DIM', 'BLOCK_SIZE_Q', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D'],
+)
 @triton.jit
 def backward_dkdv(
     q_ptr, # Q: n x qh x d
@@ -646,7 +666,10 @@ def backward_dkdv(
     tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
     tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
 
-
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['HEAD_DIM', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D', 'BLOCK_SIZE_H', 'BLOCK_SIZE_T'],
+)
 @triton.jit
 def backward_dq(
     q_ptr, # Q: n x qh x d
@@ -902,8 +925,8 @@ def _topk_sparse_attention_fwd(
         BLOCK_SIZE_D=BLOCK_SIZE_D,
         BLOCK_SIZE_H=BLOCK_SIZE_H,
         BLOCK_SIZE_T=BLOCK_SIZE_T,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     return o, lse
 
@@ -954,8 +977,8 @@ def _topk_sparse_attention_bwd(
         delta.stride(1),
         BLOCK_SIZE_O=BLOCK_SIZE_O,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     # count active querys for each key block, shape: (num_k_heads, total_k_blocks)
     seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
@@ -1038,8 +1061,8 @@ def _topk_sparse_attention_bwd(
         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     dk = dk.sum(0)
     dv = dv.sum(0)
@@ -1096,8 +1119,8 @@ def _topk_sparse_attention_bwd(
         BLOCK_SIZE_D=BLOCK_SIZE_D,
         BLOCK_SIZE_H=BLOCK_SIZE_H,
         BLOCK_SIZE_T=BLOCK_SIZE_T,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     return dq, dk, dv
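
A deployment consideration that follows from this change: autotuning is lazy, so the first launch for each new `key` combination benchmarks every candidate config before doing real work, which shows up as a one-off latency spike. Warming the kernels up with representative inputs at startup moves that cost out of the request path. A sketch reusing the toy `scale_kernel` from the first example (the shapes and the `best_config` inspection are assumptions about typical Triton usage, not part of this commit):

# Launch each representative size once so tuning runs at startup rather
# than on the first production request.
for n in (4096, 1 << 16, 1 << 20):
    x = torch.randn(n, device='cuda')
    y = torch.empty_like(x)
    scale_kernel[(triton.cdiv(n, 1024),)](x, y, n, BLOCK_SIZE=1024)
torch.cuda.synchronize()

# Recent Triton versions expose the winning config for inspection:
print(scale_kernel.best_config)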