Maxtimer97 committed
Commit 4ee9d9e · 1 parent: 22ba83b

Changed to Triton autotune for 48G GPU deployment

Files changed (2):
  1. compressed_attention.py +40 -15
  2. topk_sparse_attention.py +38 -15
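
Both files make the same two-part change: each @triton.jit kernel gains a @triton.autotune decorator that benchmarks one empty triton.Config per candidate num_warps value, keyed on the kernel's constexpr block sizes, and the hard-coded num_warps/num_stages launch arguments are commented out at the call sites so the autotuner owns them. A minimal sketch of the pattern on a toy kernel (the vector add below is illustrative only, not part of this commit):

import torch
import triton
import triton.language as tl

# Empty Config dicts leave the constexpr block sizes to the caller; only the
# launch parameter num_warps is swept. `key` lists kernel arguments whose
# values decide which cached tuning result applies.
@triton.autotune(
    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
    key=['BLOCK_SIZE'],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

x = torch.randn(4096, device='cuda')
y = torch.randn(4096, device='cuda')
out = torch.empty_like(x)
# No num_warps/num_stages at the call site: the first launch per BLOCK_SIZE
# value benchmarks the four configs and caches the winner.
add_kernel[(triton.cdiv(4096, 1024),)](x, y, out, 4096, BLOCK_SIZE=1024)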
compressed_attention.py CHANGED

@@ -26,7 +26,13 @@ except ImportError:
 
 IS_HOPPER_GPU = is_hopper_gpu()
 
-
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8]
+    ],
+    key=['HEAD_DIM', 'BLOCK_SIZE_Q', 'BLOCK_SIZE_K', 'BLOCK_SIZE_V'],
+)
 @triton.jit
 def forward_kernel(
     q_ptr,  # Q: n x h x d
@@ -159,6 +165,13 @@ def forward_kernel(
     tl.store(l_ptrs, lse_i, mask=off_q < q_len)
 
 
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8]
+    ],
+    key=['HEAD_DIM', 'BLOCK_SIZE_O', 'BLOCK_SIZE_D'],
+)
 @triton.jit
 def backward_sum_o_do(
     o_ptr,  # O: n x h x d
@@ -194,7 +207,13 @@ def backward_sum_o_do(
     delta = tl.sum(o * do, axis=1)
     tl.store(delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len)
 
-
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8]
+    ],
+    key=['HEAD_DIM', 'BLOCK_SIZE_Q', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D'],
+)
 @triton.jit
 def backward_dkdv(
     q_ptr,  # Q: n x qh x d
@@ -368,7 +387,13 @@ def backward_dkdv(
     tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
     tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
 
-
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8]
+    ],
+    key=['HEAD_DIM', 'BLOCK_SIZE_Q', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D'],
+)
 @triton.jit
 def backward_dq(
     q_ptr,  # Q: n x qh x d
@@ -595,8 +620,8 @@ def _compressed_attention_fwd(
         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     return o, lse
 
@@ -643,8 +668,8 @@ def _compressed_attention_bwd(
         delta.stride(1),
         BLOCK_SIZE_O=BLOCK_SIZE_O,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     # compute dk dv
     dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
@@ -703,8 +728,8 @@ def _compressed_attention_bwd(
         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     dk = dk.sum(0)
     dv = dv.sum(0)
@@ -756,8 +781,8 @@ def _compressed_attention_bwd(
         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     return dq, dk, dv
 
@@ -1000,8 +1025,8 @@ def _get_attention_score(
         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=8,
-        num_stages=3,
+        # num_warps=8,
+        # num_stages=3,
     )
     return score
 
@@ -1155,8 +1180,8 @@ def transform_score(
         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_O=BLOCK_SIZE_O,
-        num_warps=4,
-        num_stages=3,
+        # num_warps=4,
+        # num_stages=3,
     )
     return block_score
 
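Note that the empty triton.Config({}) entries sweep only num_warps; the previous hard-coded num_stages=3 is commented out rather than folded into the configs, so pipeline depth now falls back to Triton's default. If num_stages also matters on the target 48G GPU, the sweep could be widened; a hedged variant, not part of this commit:

import triton

# Benchmark the full num_warps x num_stages cross product once per distinct
# tuple of autotune `key` values.
configs = [
    triton.Config({}, num_warps=nw, num_stages=ns)
    for nw in [1, 2, 4, 8]
    for ns in [2, 3, 4]
]
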
topk_sparse_attention.py CHANGED

@@ -25,7 +25,10 @@ except ImportError:
 
 IS_HOPPER_GPU = is_hopper_gpu()
 
-
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['HEAD_DIM', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D', 'BLOCK_SIZE_H', 'BLOCK_SIZE_T'],
+)
 @triton.jit
 def forward_kernel_orig(
     q_ptr,  # Q: n x h x d
@@ -196,7 +199,10 @@ def forward_kernel_orig(
     lse_ptrs = lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh
     tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS)
 
-
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['HEAD_DIM', 'BLOCK_SIZE_O', 'BLOCK_SIZE_D'],
+)
 @triton.jit
 def backward_sum_o_do(
     o_ptr,  # O: n x h x d
@@ -233,6 +239,10 @@ def backward_sum_o_do(
     tl.store(delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len)
 
 
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['BLOCK_SIZE_N', 'BLOCK_SIZE_K', 'BLOCK_SIZE_R'],
+)
 @triton.jit
 def count_kernel(
     x_ptr,  # [num_kv_heads, total_len, topk]
@@ -309,12 +319,16 @@ def count_query(
         BLOCK_SIZE_N=BLOCK_SIZE_N,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_R=BLOCK_SIZE_R,
-        num_warps=4,
-        num_stages=3,
+        # num_warps=4,
+        # num_stages=3,
     )
     return active_query_count
 
 
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['topk', 'BLOCK_SIZE_N', 'BLOCK_SIZE_T'],
+)
 @triton.jit
 def pad_topk_idx_kernel(
     t_ptr,
@@ -360,7 +374,10 @@ def pad_topk_idx_kernel(
     idxs = tl.load(t_ptrs, boundary_check=(0, 1))
     tl.store(p_ptrs, idxs, boundary_check=(0, 1))
 
-
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['BLOCK_SIZE_N'],
+)
 @triton.jit
 def save_topk_idx_kernel(
     p_ptr,
@@ -478,7 +495,10 @@ def reorder_topk_idx(
     )
     return topk_q_idx
 
-
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['HEAD_DIM', 'BLOCK_SIZE_Q', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D'],
+)
 @triton.jit
 def backward_dkdv(
     q_ptr,  # Q: n x qh x d
@@ -646,7 +666,10 @@ def backward_dkdv(
     tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
     tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
 
-
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=nw) for nw in [1, 2, 4, 8]],
+    key=['HEAD_DIM', 'BLOCK_SIZE_K', 'BLOCK_SIZE_D', 'BLOCK_SIZE_H', 'BLOCK_SIZE_T'],
+)
 @triton.jit
 def backward_dq(
     q_ptr,  # Q: n x qh x d
@@ -902,8 +925,8 @@ def _topk_sparse_attention_fwd(
         BLOCK_SIZE_D=BLOCK_SIZE_D,
         BLOCK_SIZE_H=BLOCK_SIZE_H,
         BLOCK_SIZE_T=BLOCK_SIZE_T,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     return o, lse
 
@@ -954,8 +977,8 @@ def _topk_sparse_attention_bwd(
         delta.stride(1),
         BLOCK_SIZE_O=BLOCK_SIZE_O,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     # count active querys for each key block, shape: (num_k_heads, total_k_blocks)
     seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
@@ -1038,8 +1061,8 @@ def _topk_sparse_attention_bwd(
         BLOCK_SIZE_Q=BLOCK_SIZE_Q,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         BLOCK_SIZE_D=BLOCK_SIZE_D,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
    )
     dk = dk.sum(0)
     dv = dv.sum(0)
@@ -1096,8 +1119,8 @@ def _topk_sparse_attention_bwd(
         BLOCK_SIZE_D=BLOCK_SIZE_D,
         BLOCK_SIZE_H=BLOCK_SIZE_H,
         BLOCK_SIZE_T=BLOCK_SIZE_T,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        # num_warps=num_warps,
+        # num_stages=num_stages,
     )
     return dq, dk, dv
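
To check what the sweep actually picks on the deployment GPU, Triton's autotuning log can be switched on before the first launch, and the cached winner can be read off the decorated kernel object. A short sketch, assuming the TRITON_PRINT_AUTOTUNING environment variable and the Autotuner best_config attribute available in recent Triton releases:

import os

# Must be set before the first autotuned kernel launch in the process.
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'

from compressed_attention import forward_kernel  # now an Autotuner-wrapped kernel

# ... run one forward pass, then inspect the cached choice for the current
# (HEAD_DIM, BLOCK_SIZE_*) key tuple:
# print(forward_kernel.best_config)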