Update triton_flash_blocksparse_attn.py
Browse files
Wrap all applicable kernel-launch sections in ```with torch.cuda.device(q.device.index):``` to support multi-GPU execution.
- triton_flash_blocksparse_attn.py +64 -62
triton_flash_blocksparse_attn.py
CHANGED
|
@@ -992,37 +992,38 @@ def blocksparse_flash_attn_padded_fwd(
|
|
| 992 |
|
| 993 |
grid = (len(q_start_sids), n_heads)
|
| 994 |
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
-
|
| 1005 |
-
|
| 1006 |
-
|
| 1007 |
-
|
| 1008 |
-
|
| 1009 |
-
|
| 1010 |
-
|
| 1011 |
-
|
| 1012 |
-
|
| 1013 |
-
|
| 1014 |
-
|
| 1015 |
-
|
| 1016 |
-
|
| 1017 |
-
|
| 1018 |
-
|
| 1019 |
-
|
| 1020 |
-
|
| 1021 |
-
|
| 1022 |
-
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
|
|
|
| 1026 |
|
| 1027 |
return out
|
| 1028 |
|
|
@@ -1094,37 +1095,38 @@ def blocksparse_flash_attn_varlen_fwd(
|
|
| 1094 |
|
| 1095 |
grid = (len(q_start_sids), n_heads)
|
| 1096 |
|
| 1097 |
-
|
| 1098 |
-
|
| 1099 |
-
|
| 1100 |
-
|
| 1101 |
-
|
| 1102 |
-
|
| 1103 |
-
|
| 1104 |
-
|
| 1105 |
-
|
| 1106 |
-
|
| 1107 |
-
|
| 1108 |
-
|
| 1109 |
-
|
| 1110 |
-
|
| 1111 |
-
|
| 1112 |
-
|
| 1113 |
-
|
| 1114 |
-
|
| 1115 |
-
|
| 1116 |
-
|
| 1117 |
-
|
| 1118 |
-
|
| 1119 |
-
|
| 1120 |
-
|
| 1121 |
-
|
| 1122 |
-
|
| 1123 |
-
|
| 1124 |
-
|
| 1125 |
-
|
| 1126 |
-
|
| 1127 |
-
|
|
|
|
| 1128 |
|
| 1129 |
return out
|
| 1130 |
|
|
|
|
| 992 |
|
| 993 |
grid = (len(q_start_sids), n_heads)
|
| 994 |
|
| 995 |
+
with torch.cuda.device(q.device.index):
|
| 996 |
+
_fwd_kernel_batch_inference[grid](
|
| 997 |
+
q, k, v, out,
|
| 998 |
+
sm_scale,
|
| 999 |
+
q_batch_starts,
|
| 1000 |
+
q_batch_ends,
|
| 1001 |
+
k_batch_starts,
|
| 1002 |
+
k_batch_ends,
|
| 1003 |
+
q_batch_ids,
|
| 1004 |
+
q_start_sids,
|
| 1005 |
+
|
| 1006 |
+
*q.stride(),
|
| 1007 |
+
*k.stride(),
|
| 1008 |
+
*v.stride(),
|
| 1009 |
+
*out.stride(),
|
| 1010 |
+
|
| 1011 |
+
layout_crow_indices,
|
| 1012 |
+
layout_col_indices,
|
| 1013 |
+
*layout_crow_indices.stride(),
|
| 1014 |
+
*layout_col_indices.stride(),
|
| 1015 |
+
|
| 1016 |
+
q_k_ratio,
|
| 1017 |
+
HAS_BATCH_DIM = True,
|
| 1018 |
+
D_HEAD = head_size,
|
| 1019 |
+
BLOCK_M = block_size,
|
| 1020 |
+
BLOCK_N = block_size,
|
| 1021 |
+
BLOCK_D = block_d,
|
| 1022 |
+
BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
|
| 1023 |
+
EVEN_D = block_d == head_size,
|
| 1024 |
+
num_warps = 1 if q_len == 1 else 4,
|
| 1025 |
+
num_stages = 3
|
| 1026 |
+
)
|
| 1027 |
|
| 1028 |
return out
|
| 1029 |
|
|
|
|
| 1095 |
|
| 1096 |
grid = (len(q_start_sids), n_heads)
|
| 1097 |
|
| 1098 |
+
with torch.cuda.device(q.device.index):
|
| 1099 |
+
_fwd_kernel_batch_inference[grid](
|
| 1100 |
+
q, k, v, out,
|
| 1101 |
+
sm_scale,
|
| 1102 |
+
cu_seqlens_q[:-1],
|
| 1103 |
+
cu_seqlens_q[1:],
|
| 1104 |
+
cu_seqlens_k[:-1],
|
| 1105 |
+
cu_seqlens_k[1:],
|
| 1106 |
+
q_batch_ids,
|
| 1107 |
+
q_start_sids,
|
| 1108 |
+
|
| 1109 |
+
0, *q.stride(),
|
| 1110 |
+
0, *k.stride(),
|
| 1111 |
+
0, *v.stride(),
|
| 1112 |
+
0, *out.stride(),
|
| 1113 |
+
|
| 1114 |
+
layout_crow_indices,
|
| 1115 |
+
layout_col_indices,
|
| 1116 |
+
*layout_crow_indices.stride(),
|
| 1117 |
+
*layout_col_indices.stride(),
|
| 1118 |
+
|
| 1119 |
+
q_k_ratio,
|
| 1120 |
+
HAS_BATCH_DIM = False,
|
| 1121 |
+
D_HEAD = head_size,
|
| 1122 |
+
BLOCK_M = block_size,
|
| 1123 |
+
BLOCK_N = block_size,
|
| 1124 |
+
BLOCK_D = block_d,
|
| 1125 |
+
BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
|
| 1126 |
+
EVEN_D = block_d == head_size,
|
| 1127 |
+
num_warps = 1 if decoding_only else 4,
|
| 1128 |
+
num_stages = 3
|
| 1129 |
+
)
|
| 1130 |
|
| 1131 |
return out
|
| 1132 |
|