Build uploaded using `kernels`.
Browse files- build/torch-cuda/__init__.py +24 -0
- build/torch-cuda/_ops.py +8 -0
- build/torch-cuda/ampere_helpers.py +103 -0
- build/torch-cuda/barrier.py +71 -0
- build/torch-cuda/benchmark.py +268 -0
- build/torch-cuda/blackwell_helpers.py +1089 -0
- build/torch-cuda/block_info.py +108 -0
- build/torch-cuda/block_sparse_utils.py +1476 -0
- build/torch-cuda/block_sparsity.py +440 -0
- build/torch-cuda/cache_utils.py +307 -0
- build/torch-cuda/compute_block_sparsity.py +378 -0
- build/torch-cuda/copy_utils.py +372 -0
- build/torch-cuda/cute_dsl_ptxas.py +151 -0
- build/torch-cuda/cute_dsl_utils.py +167 -0
- build/torch-cuda/fast_math.py +21 -0
- build/torch-cuda/flash_attn4/__init__.py +26 -0
- build/torch-cuda/flash_bwd.py +1264 -0
- build/torch-cuda/flash_bwd_postprocess.py +585 -0
- build/torch-cuda/flash_bwd_preprocess.py +361 -0
- build/torch-cuda/flash_bwd_sm100.py +0 -0
- build/torch-cuda/flash_bwd_sm90.py +1591 -0
- build/torch-cuda/flash_fwd.py +0 -0
- build/torch-cuda/flash_fwd_combine.py +692 -0
- build/torch-cuda/flash_fwd_sm100.py +0 -0
- build/torch-cuda/interface.py +1855 -0
- build/torch-cuda/mask.py +653 -0
- build/torch-cuda/metadata.json +8 -0
- build/torch-cuda/mma_sm100_desc.py +296 -0
- build/torch-cuda/named_barrier.py +32 -0
- build/torch-cuda/pack_gqa.py +165 -0
- build/torch-cuda/paged_kv.py +214 -0
- build/torch-cuda/pipeline.py +440 -0
- build/torch-cuda/quack/__init__.py +0 -0
- build/torch-cuda/quack/activation.py +568 -0
- build/torch-cuda/quack/compile_utils.py +19 -0
- build/torch-cuda/quack/copy_utils.py +1007 -0
- build/torch-cuda/quack/cute_dsl_utils.py +165 -0
- build/torch-cuda/quack/layout_utils.py +297 -0
- build/torch-cuda/quack/sm90_utils.py +161 -0
- build/torch-cuda/seqlen_info.py +138 -0
- build/torch-cuda/softmax.py +592 -0
- build/torch-cuda/testing.py +456 -0
- build/torch-cuda/tile_scheduler.py +727 -0
- build/torch-cuda/utils.py +698 -0
build/torch-cuda/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Flash Attention CUTE (CUDA Template Engine) implementation."""
|
| 2 |
+
|
| 3 |
+
from importlib.metadata import PackageNotFoundError, version
|
| 4 |
+
|
| 5 |
+
# Update when syncing again.
|
| 6 |
+
__version__ = "4.0.0.beta4"
|
| 7 |
+
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
|
| 10 |
+
from .interface import (
|
| 11 |
+
flash_attn_func,
|
| 12 |
+
flash_attn_varlen_func,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from .cute_dsl_utils import cute_compile_patched
|
| 16 |
+
|
| 17 |
+
# Patch cute.compile to optionally dump SASS
|
| 18 |
+
cute.compile = cute_compile_patched
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
"flash_attn_func",
|
| 23 |
+
"flash_attn_varlen_func",
|
| 24 |
+
]
|
build/torch-cuda/_ops.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch

# Handle to the compiled extension's op namespace registered under torch.ops.
# The "_c07a63b" suffix is a build-specific hash; it must match the prefix
# produced by add_op_namespace_prefix below.
ops = torch.ops._flash_attn4_c07a63b
|
| 3 |
+
|
| 4 |
+
def add_op_namespace_prefix(op_name: str):
    """Return *op_name* qualified with this extension's torch op namespace."""
    namespace = "_flash_attn4_c07a63b"
    return namespace + "::" + op_name
|
build/torch-cuda/ampere_helpers.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
from typing import Type, Callable, Optional
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout:
    """Build a swizzled shared-memory layout atom for an operand with K extent ``k_dim``.

    The swizzle width is chosen from how many bytes per row are contiguous:
    the largest of 128/64/32/16 bytes that divides ``k_dim * sizeof(dtype)``.

    Args:
        dtype: element type (a cutlass.Numeric subclass; ``width`` is in bits).
        k_dim: size of the contiguous (K) dimension in elements.

    Returns:
        A composed (swizzle ∘ ordered) layout usable as an smem layout atom.
    """
    dtype_byte = cutlass.const_expr(dtype.width // 8)
    bytes_per_row = cutlass.const_expr(k_dim * dtype_byte)
    # Largest aligned byte chunk per row, converted back to elements.
    smem_k_block_size = (
        cutlass.const_expr(
            128
            if bytes_per_row % 128 == 0
            else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16))
        )
        // dtype_byte
    )
    # Swizzle bit count shrinks with the block size (4 → 128B ... 1 → 16B chunks).
    swizzle_bits = (
        4
        if smem_k_block_size == 128
        else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1))
    )
    # Base shift depends on element width: fp32 → 2, fp16/bf16 → 3, else 4.
    swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4)
    return cute.make_composed_layout(
        cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base),
        0,
        cute.make_ordered_layout(
            # Fewer rows per atom when K is a multiple of 32 elements —
            # presumably to keep the atom within one swizzle period; confirm.
            (8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0)
        ),
    )
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsA: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_A: cute.TiledCopy,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
    A_in_regs: cutlass.Constexpr[bool] = False,
    B_in_regs: cutlass.Constexpr[bool] = False,
    swap_AB: cutlass.Constexpr[bool] = False,
) -> None:
    """K-blocked MMA with software-pipelined smem→register operand copies.

    For each k-block, the copy of block k+1 is issued before the MMA on
    block k, hiding smem load latency. Operands already in registers
    (``A_in_regs`` / ``B_in_regs``) skip the copy. ``hook_fn`` (if given)
    runs once, after the first k-block's MMA. ``swap_AB`` exchanges the
    roles of A and B via a single self-call.
    """
    if cutlass.const_expr(swap_AB):
        # Re-enter with A/B (tensors, copies, in-regs flags) exchanged.
        gemm(
            tiled_mma,
            acc,
            tCrB,
            tCrA,
            tCsB,
            tCsA,
            smem_thr_copy_B,
            smem_thr_copy_A,
            hook_fn,
            A_in_regs=B_in_regs,
            B_in_regs=A_in_regs,
            swap_AB=False,
        )
    else:
        # Views of the register fragments shaped for the smem copy atoms.
        tCrA_copy_view = smem_thr_copy_A.retile(tCrA)
        tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
        # Prefetch k-block 0 before entering the main loop.
        if cutlass.const_expr(not A_in_regs):
            cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0])
        if cutlass.const_expr(not B_in_regs):
            cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
        for k in cutlass.range_constexpr(cute.size(tCsA.shape[2])):
            # Issue the copy for the NEXT k-block ahead of this block's MMA.
            if k < cute.size(tCsA.shape[2]) - 1:
                if cutlass.const_expr(not A_in_regs):
                    cute.copy(
                        smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1]
                    )
                if cutlass.const_expr(not B_in_regs):
                    cute.copy(
                        smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]
                    )
            cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
            # One-shot callback after the first MMA (e.g. to release a buffer).
            if cutlass.const_expr(k == 0 and hook_fn is not None):
                hook_fn()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@cute.jit
def gemm_rs(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
) -> None:
    """K-blocked MMA where A is already in registers and only B is staged from smem.

    ("rs" presumably = register-source A / smem-source B — confirm with naming
    used elsewhere in the project.) B's copy for block k+1 is issued before the
    MMA on block k; ``hook_fn`` runs once after the first k-block's MMA.
    """
    tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
    # Prefetch B's k-block 0.
    cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        # Copy the next B block ahead of this block's MMA to hide latency.
        if cutlass.const_expr(k < cute.size(tCrA.shape[2]) - 1):
            cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1])
        cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
        if cutlass.const_expr(k == 0 and hook_fn is not None):
            hook_fn()
|
build/torch-cuda/barrier.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cutlass
|
| 2 |
+
import cutlass.cute as cute
|
| 3 |
+
from cutlass import Int32
|
| 4 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 5 |
+
from cutlass._mlir.dialects import llvm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dsl_user_op
def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
    """Load a 32-bit flag from global memory with acquire semantics (GPU scope).

    Emits ``ld.global.acquire.gpu.b32`` so that reads after this load observe
    writes made before the matching release store/reduction.

    Args:
        lock_ptr: pointer to the 32-bit flag in global memory.
        loc, ip: MLIR source-location plumbing injected by ``dsl_user_op``.

    Returns:
        The current flag value as ``cutlass.Int32``.
    """
    # Inline PTX takes the raw 64-bit address as an IR value.
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    state = llvm.inline_asm(
        T.i32(),
        [lock_ptr_i64],
        "ld.global.acquire.gpu.b32 $0, [$1];",
        "=r,l",
        has_side_effects=True,  # keep the load from being hoisted or elided
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(state)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dsl_user_op
def red_relaxed(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    """Atomically add ``val`` to a 32-bit global flag with relaxed ordering.

    Emits ``red.relaxed.gpu.global.add.s32`` — a one-way reduction (no return
    value) that imposes no memory ordering on surrounding accesses.
    """
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,  # reduction returns nothing
        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
        "red.relaxed.gpu.global.add.s32 [$0], $1;",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dsl_user_op
def red_release(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    """Atomically add ``val`` to a 32-bit global flag with release semantics.

    Emits ``red.release.gpu.global.add.s32``: prior writes by this thread are
    visible to any thread that subsequently acquires the flag (see
    ``ld_acquire``).
    """
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,  # reduction returns nothing
        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
        "red.release.gpu.global.add.s32 [$0], $1;",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@cute.jit
def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None:
    """Spin until the flag at ``lock_ptr + flag_offset`` equals ``val``.

    Only thread 0 polls; all other threads return immediately — callers
    presumably follow with a barrier so the whole group observes the flag
    (confirm at call sites).

    NOTE(review): ``read_val`` starts at 0, so if ``val == 0`` the loop body
    never runs and no acquire load is performed.
    """
    flag_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        read_val = Int32(0)
        while read_val != val:
            # Acquire load so writes ordered before the matching release
            # become visible once the expected value is seen.
            read_val = ld_acquire(flag_ptr)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@cute.jit
def arrive_inc(
    lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32]
) -> None:
    """Signal arrival: thread 0 atomically adds ``val`` to the flag (release).

    Counterpart of ``wait_eq`` — the release ordering makes this thread's
    prior writes visible to waiters that acquire the flag.
    """
    flag_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        red_release(flag_ptr, val)
    # red_relaxed(flag_ptr, val)  # relaxed variant kept for reference
|
build/torch-cuda/benchmark.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao.
|
| 2 |
+
"""Useful functions for writing test code."""
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils.benchmark as benchmark
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function.

    Returns the (Timer, Measurement) pair from torch.utils.benchmark.
    """
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        # Run under autocast so AMP timings reflect real mixed-precision use.
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    timer = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals=dict(fn_amp=amp_wrapper, inputs=inputs, kwinputs=kwinputs),
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the backward pass of an arbitrary function.

    Runs one forward pass (under autocast when ``amp``) to build the graph,
    then times ``y.backward(grad)`` alone. Returns (Timer, Measurement).
    """
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError("Grad shape does not match output shape")

    def run_backward(*inputs, y, grad):
        # Clear .grad so gradient accumulation doesn't add extra work.
        for tensor in inputs:
            if isinstance(tensor, torch.Tensor):
                tensor.grad = None
        y.backward(grad, retain_graph=True)

    timer = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals=dict(f=run_backward, inputs=inputs, y=y, grad=grad),
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.

    Each timed iteration re-runs the forward pass (under autocast when
    ``amp``) and then backward. Returns (Timer, Measurement).
    """
    if verbose:
        print(desc, "- Forward + Backward pass")
    # One untimed forward pass to determine the output shape for ``grad``.
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError("Grad shape does not match output shape")

    def run_fwd_bwd(grad, *inputs, **kwinputs):
        # Clear .grad so gradient accumulation doesn't add extra work.
        for tensor in inputs:
            if isinstance(tensor, torch.Tensor):
                tensor.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        out.backward(grad, retain_graph=True)

    timer = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals=dict(f=run_fwd_bwd, fn=fn, inputs=inputs, grad=grad, kwinputs=kwinputs),
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.

    Convenience wrapper: returns a 2-tuple of the (Timer, Measurement)
    results of benchmark_forward and benchmark_backward.
    """
    shared = dict(
        repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype
    )
    fwd = benchmark_forward(fn, *inputs, **shared, **kwinputs)
    bwd = benchmark_backward(fn, *inputs, grad=grad, **shared, **kwinputs)
    return fwd, bwd
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.

    Convenience wrapper: returns a 3-tuple of the (Timer, Measurement)
    results of benchmark_forward, benchmark_backward, and benchmark_combined.
    """
    shared = dict(
        repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype
    )
    fwd = benchmark_forward(fn, *inputs, **shared, **kwinputs)
    bwd = benchmark_backward(fn, *inputs, grad=grad, **shared, **kwinputs)
    combined = benchmark_combined(fn, *inputs, grad=grad, **shared, **kwinputs)
    return fwd, bwd, combined
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap benchmark functions in Pytorch profiler to see CUDA information.

    Runs 30 warm-up iterations of fn (and backward when requested), then one
    profiled iteration. Optionally exports a Chrome trace.

    Args:
        fn: callable to profile; may return a tensor or a tuple whose first
            element is the tensor used for backward.
        trace_filename: if set, write a Chrome trace to this path.
        backward: also run ``out.backward`` with a random gradient.
        amp / amp_dtype: run the forward pass under ``torch.autocast``.
        cpu: additionally record CPU profiler activity.
        verbose: print the profiler summary table.
    """
    if backward:
        # Untimed forward pass just to get an output to build the gradient from.
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        g = torch.randn_like(out)
    for _ in range(30):  # Warm up
        if backward:
            # Clear .grad so accumulation doesn't add extra work.
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        if backward:
            out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    """Measure peak CUDA memory allocated while running ``fn(*inputs, **kwinputs)``.

    Resets the CUDA peak-memory counter, runs ``fn`` once, and returns the
    peak allocation.

    NOTE(review): the divisor mixes units — bytes / (2**20 * 1000) is
    MiB/1000, not true GiB (2**30) nor decimal GB (1e9). The printed "GB"
    label is therefore approximate; confirm whether this is intentional
    before changing, since downstream comparisons may rely on it.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
    if verbose:
        print(f"{desc} max memory: {mem}GB")
    torch.cuda.empty_cache()
    return mem
|
build/torch-cuda/blackwell_helpers.py
ADDED
|
@@ -0,0 +1,1089 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
from cutlass import Int32, Boolean, const_expr
|
| 7 |
+
from cutlass.cute.nvgpu import tcgen05
|
| 8 |
+
from cutlass._mlir.dialects import llvm
|
| 9 |
+
|
| 10 |
+
from . import mma_sm100_desc as sm100_desc
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@cute.jit
def gemm_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    swap_AB: bool = False,
    num_unroll_groups: int = 1,
) -> None:
    """Issue a K-blocked tcgen05 MMA, optionally selecting a pipeline stage of A/B.

    Args:
        tiled_mma: tiled MMA whose op is used to build the MMA atom.
        acc: accumulator tensor.
        tCrA, tCrB: operand fragments; when ``A_idx``/``B_idx`` is given the
            corresponding operand is sliced on its 4th mode (stage index).
        zero_init: if true, the first k-block overwrites ``acc`` instead of
            accumulating into it.
        swap_AB: exchange the roles of A and B (operands and stage indices).
        num_unroll_groups: divide the k-loop unroll factor into this many groups.
    """
    if const_expr(swap_AB):
        # Re-enter once with A/B exchanged.
        # Bugfix: forward num_unroll_groups — previously the swap path dropped
        # it, silently falling back to the default of 1 (full unroll).
        return gemm_w_idx(
            tiled_mma,
            acc,
            tCrB,
            tCrA,
            B_idx,
            A_idx,
            zero_init=zero_init,
            swap_AB=False,
            num_unroll_groups=num_unroll_groups,
        )
    else:
        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]

        mma_atom = cute.make_mma_atom(tiled_mma.op)
        for k in cutlass.range(
            cute.size(tCrA.shape[2]), unroll=cute.size(tCrA.shape[2]) // num_unroll_groups
        ):
            # First k-block may clear the accumulator; later blocks always accumulate.
            mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)
            cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@cute.jit
def gemm_ptx_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    cta_group: int = 1,
    **kwargs,
) -> None:
    """PTX-level variant of ``gemm_w_idx``: slice the stage, then delegate.

    Selects the A/B operand (and matching smem tensor) by pipeline-stage
    index on the 4th mode when ``A_idx``/``B_idx`` is given, resolves the
    accumulator's tmem address, and hands everything to ``gemm_ptx_partial``
    (defined elsewhere in this module). Extra ``kwargs`` pass through.
    """
    rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
    rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
    # sA may be absent (e.g. A sourced from tmem — TODO confirm with gemm_ptx).
    sA_cur = None
    if const_expr(sA is not None):
        sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx]
    sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx]
    mma_atom = cute.make_mma_atom(tiled_mma.op)
    # Raw tmem address of the accumulator, as required by the PTX path.
    acc_tmem_addr = acc.iterator.toint()
    gemm_ptx_partial(
        mma_atom.op,
        acc_tmem_addr,
        rA,
        rB,
        sA_cur,
        sB_cur,
        zero_init=zero_init,
        cta_group=cta_group,
        **kwargs,
    )
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Simple K-blocked tcgen05 MMA: acc (+)= A @ B over all k-blocks.

    When ``zero_init`` is true the first k-block overwrites ``acc``;
    subsequent blocks always accumulate.
    """
    mma_atom = cute.make_mma_atom(tiled_mma.op)
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        # ACCUMULATE is off only for k == 0 with zero_init.
        mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)
        cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def i64_to_i32x2(i: int) -> Tuple[int, int]:
    """Split a 64-bit integer into its (low 32 bits, high 32 bits) halves."""
    mask = 0xFFFF_FFFF
    low = i & mask
    high = (i >> 32) & mask
    return low, high
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@cute.jit
def gemm_ptx(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Issue tcgen05 MMA instructions via hand-written inline PTX, one
    ``llvm.inline_asm`` call per K tile.

    Two operand-sourcing modes, chosen at trace time from ``op.a_src``:
      * SS (A from smem): both A and B are addressed through 64-bit SMEM
        descriptors built with ``sm100_desc`` helpers.
      * TS (A from TMEM): A is passed as a TMEM address operand ``[$1]``;
        only B uses an SMEM descriptor.

    The instruction kind is hard-coded to ``cta_group::1.kind::f16`` in the
    asm strings. Each per-k asm block is guarded by ``cute.arch.elect_one()``
    so a single thread issues the MMA.

    Args:
        op: The tcgen05 MMA op; supplies dtypes, major modes, and the
            instruction descriptor (via ``sm100_desc.mma_op_to_idesc``).
        acc: Accumulator tensor; its iterator address is the PTX ``[$0]``
            destination (TMEM accumulator — assumed from tcgen05 usage).
        tCrA, tCrB: Partitioned operands; mode 2 is the K-tile mode.
        sA: Smem tensor for A; required unless A is sourced from TMEM.
        sB: Smem tensor for B (always required).
        zero_init: If truthy, the k == 0 MMA overwrites the accumulator
            (enable-input predicate false); all later tiles accumulate.
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else None
    sB_layout = sB.layout
    # Constant instruction descriptor for the tcgen05 MMA.
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = sA.iterator.type.swizzle_type
        # Base SMEM descriptor for A (address bits filled in per-stage below).
        # The layout is recast to 128-bit granularity for descriptor encoding.
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        # Split the 64-bit descriptor into lo/hi 32-bit halves; only the lo
        # half varies per K tile (it carries the address bits).
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = sB.iterator.type.swizzle_type
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # OR the runtime smem start address of stage 0 into the descriptor lo half.
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr(
            sA[None, None, 0].iterator
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr(
        sB[None, None, 0].iterator
    )
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        # Per-k descriptor lo: byte offset of tile k, scaled to the
        # descriptor's 16-byte address granularity (>> 4).
        if const_expr(not is_ts):
            smem_desc_a_lo = smem_desc_start_a_lo + (
                (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
            )
        # B descriptor is needed by both the SS and TS asm paths below.
        smem_desc_b_lo = smem_desc_start_b_lo + (
            (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
        )
        # with cute.arch.elect_one():
        #     cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo)
        #     cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct)
        with cute.arch.elect_one():
            if const_expr(not is_ts):
                # SS path: A and B both via 64-bit smem descriptors; the hi
                # halves and idesc are baked into the asm as immediates.
                llvm.inline_asm(
                    None,
                    [
                        acc.iterator.toint().ir_value(),
                        smem_desc_a_lo.ir_value(),
                        smem_desc_b_lo.ir_value(),
                        # Enable-input predicate: accumulate unless this is the
                        # zero-init first tile.
                        Int32(not zero_init or k != 0).ir_value(),
                    ],
                    "{\n\t"
                    ".reg .pred p;\n\t"
                    ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
                    ".reg .b32 idesc;\n\t"
                    f"mov.b32 idesc, {hex(idesc)};\n\t"
                    f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t"
                    f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
                    "setp.ne.b32 p, $3, 0;\n\t"
                    f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t"
                    "}\n",
                    "r,r,r,r",
                    has_side_effects=True,
                    is_align_stack=False,
                    asm_dialect=llvm.AsmDialect.AD_ATT,
                )
            else:
                # TS path: A is a TMEM address operand [$1]; B still uses an
                # smem descriptor.
                llvm.inline_asm(
                    None,
                    [
                        acc.iterator.toint().ir_value(),
                        tCrA[None, None, k].iterator.toint().ir_value(),
                        smem_desc_b_lo.ir_value(),
                        Int32(not zero_init or k != 0).ir_value(),
                    ],
                    "{\n\t"
                    ".reg .pred p;\n\t"
                    ".reg .b64 smem_desc_b;\n\t"
                    f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
                    "setp.ne.b32 p, $3, 0;\n\t"
                    f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t"
                    "}\n",
                    "r,r,r,r",
                    has_side_effects=True,
                    is_align_stack=False,
                    asm_dialect=llvm.AsmDialect.AD_ATT,
                )
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@cute.jit
def gemm_ptx_loop(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Variant of ``gemm_ptx`` that emits the whole K-tile loop as ONE inline
    PTX string, unrolled at trace time.

    Rather than one ``llvm.inline_asm`` per K tile, per-tile descriptor
    updates are folded into the asm itself: constant byte-offset deltas
    between consecutive tiles (``offset_*_diff``) are added to the descriptor
    lo register with ``add.u32`` between MMAs. A single leader thread is
    elected inside the asm (``elect.sync``) and predicates every MMA.

    The instruction kind is hard-coded to ``cta_group::1.kind::f16``.

    Args:
        op: tcgen05 MMA op (dtypes, major modes, idesc source).
        acc: Accumulator tensor; its iterator address is the asm destination.
        tCrA, tCrB: Partitioned operands; mode 2 is the K-tile mode.
        sA: Smem tensor for A; required unless A is sourced from TMEM.
        sB: Smem tensor for B.
        zero_init: Controls the enable-input predicate of the FIRST MMA only.
            If it is a runtime ``Boolean``, predicate register ``p`` (set from
            $3) is used; if it is a Python bool, the predicate is baked in as
            the literal "0"/"1".
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    # In the TS case A offsets are computed from the TMEM fragment layout.
    sA_layout = sA.layout if sA is not None else tCrA.layout
    sB_layout = sB.layout
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = sA.iterator.type.swizzle_type
        # Constant SMEM descriptor base for A; lo half carries address bits.
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = sB.iterator.type.swizzle_type
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # Per-K-tile operand offsets, in the unit each asm path consumes:
    #  - smem descriptor path: bytes >> 4 (descriptor 16-byte granularity);
    #  - TMEM path: 32-bit-word offsets (width // 32).
    if const_expr(not is_ts):
        offset_a = [
            (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
        ]
    else:
        offset_a = [
            cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
        ]
    offset_a_diff = [
        offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
    ]
    offset_b = [
        (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
        for k in cutlass.range_constexpr(cute.size(tCrB.shape[2]))
    ]
    offset_b_diff = [
        offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2]))
    ]

    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(
            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(
        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
    )
    # Runtime Boolean -> use predicate "p" (set from $3); compile-time bool ->
    # bake the enable-input predicate as a literal.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # SS path: first MMA uses the start descriptors; subsequent MMAs bump
        # the lo halves by the constant per-tile deltas inside the asm.
        llvm.inline_asm(
            None,
            [
                acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 smem_desc_a_lo, $1;\n\t"
            "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # Tiles k > 0 always accumulate (predicate literal 1).
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # TS path: A comes from TMEM; per-tile A offsets are folded into the
        # address expression [tmem_a + offset_a[k]] instead of updating tmem_a.
        llvm.inline_asm(
            None,
            [
                acc.iterator.toint().ir_value(),
                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
                Int32(smem_desc_start_b_lo).ir_value(),
                Int32(not zero_init).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 tmem_a, $1;\n\t"
            "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
@cute.jit
def gemm_ptx_partial(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc_tmem_addr: Int32,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    mbar_ptr: Optional[cutlass.Pointer] = None,
    mbar_phase: Optional[Int32] = None,
    split_arrive: Optional[int] = None,
    zero_init: bool | Boolean = False,
    # sA_offset: Int32 = 0,
    # acc_offset: Int32 = 0,
    tA_addr: Optional[Int32] = None,
    cta_group: int = 1,
) -> None:
    """Unrolled-inline-PTX MMA over the K-tile mode, with an optional
    mid-sequence mbarrier wait (TS path only).

    Like ``gemm_ptx_loop``, the entire K loop is one inline asm string with a
    single elected leader thread issuing every
    ``tcgen05.mma.cta_group::{cta_group}.kind::f16``. Differences:
      * the accumulator is passed as a raw address (``acc_tmem_addr``), not a
        tensor;
      * per-tile offsets are taken from the partitioned tensor layouts
        (``tCrA.layout`` / ``tCrB.layout``) and added to the *start* lo value
        (``add.u32 ..., *_lo_start, offset``) rather than accumulated;
      * in the TS path, if ``mbar_ptr`` is given, MMAs for tiles before
        ``split_arrive_idx`` are issued, then the asm spins on
        ``mbarrier.try_wait.parity`` ($4 = barrier address, $5 = phase) before
        issuing the remaining tiles.

    Args:
        op: tcgen05 MMA op (dtypes, major modes, idesc, ``shape_mnk``).
        acc_tmem_addr: Accumulator address used as the asm destination
            (TMEM — assumed from tcgen05 usage; confirm against callers).
        tCrA, tCrB: Partitioned A/B operands; mode 2 is the K-tile mode.
        sA: Smem tensor for A; required unless A is sourced from TMEM.
        sB: Smem tensor for B.
        mbar_ptr: Optional mbarrier pointer; only supported on the TS path
            (asserted None otherwise).
        mbar_phase: Barrier phase to wait on; required with ``mbar_ptr``.
        split_arrive: K-element index at which to insert the barrier wait;
            converted to a K-tile index by dividing by ``op.shape_mnk[2]``.
        zero_init: Controls the enable-input predicate of the first MMA only
            (runtime Boolean -> predicate register; Python bool -> literal).
        tA_addr: Explicit TMEM address for A (see workaround note in the TS
            branch).
        cta_group: CTA-group width baked into the MMA instruction text.
    """
    # acc_tmem_addr += acc_offset
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else tCrA.layout
    sB_layout = sB.layout
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = sA.iterator.type.swizzle_type
        # Constant SMEM descriptor base for A; lo half carries address bits.
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = sB.iterator.type.swizzle_type
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # For TS, recast the A fragment layout to 32-bit granularity so offsets
    # are in TMEM 32-bit-word units; for SS, use the layout directly (offsets
    # are then in descriptor-address units — assumed; confirm vs callers).
    tCrA_layout = (
        tCrA.layout
        if const_expr(not is_ts)
        else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout)
    )
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))]
    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]
    offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))]
    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]

    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(
            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
        )
        # ) + sA_offset
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(
        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
    )
    # Runtime Boolean -> predicate register "p"; compile-time bool -> literal.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # SS path does not support the mid-sequence barrier wait.
        assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM"
        llvm.inline_asm(
            None,
            [
                # acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 smem_desc_a_lo_start, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    # Offsets are added to the *start* value, not accumulated.
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            # "r,r,r",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to
        # explicitly pass in the tA_addr for correctness.
        tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr
        input_args = [
            # Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(),
            Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(),
            Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
            Int32(not zero_init).ir_value(),
            Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
        ]
        if const_expr(mbar_ptr is not None):
            assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None"
            assert split_arrive is not None, (
                "split_arrive must be provided when mbar_ptr is not None"
            )
            # Convert the K-element split point into a K-tile index.
            split_arrive_idx = split_arrive // op.shape_mnk[2]
            input_args.append(mbar_ptr.toint().ir_value())
            input_args.append(Int32(mbar_phase).ir_value())
            # Spin-wait on the mbarrier phase before the remaining MMAs.
            mbar_wait_str = (
                ".reg .pred P1; \n\t"
                "LAB_WAIT: \n\t"
                "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t"
                "@P1 bra DONE; \n\t"
                "bra LAB_WAIT; \n\t"
                "DONE: \n\t"
            )
        else:
            mbar_wait_str = ""
        llvm.inline_asm(
            None,
            # [
            #     # acc.iterator.toint().ir_value(),
            #     Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
            #     Int32(smem_desc_start_b_lo).ir_value(),
            #     Int32(not zero_init).ir_value(),
            # ],
            input_args,
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            f"mov.b32 tmem_a, $0;\n\t"
            f"mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            # Tiles 1 .. (all, or split_arrive_idx when a barrier is used).
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(
                    1,
                    cute.size(tCrA.shape[2]) if const_expr(mbar_ptr is None) else split_arrive_idx,
                )
            )
            + mbar_wait_str
            # Remaining tiles after the barrier wait (barrier mode only).
            # NOTE(review): this tail uses cumulative offset_b_diff increments
            # on smem_desc_b_lo, whose last write above was start + offset —
            # looks consistent only if split_arrive_idx > 1; confirm.
            + (
                "".join(
                    (
                        f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                        f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                        f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                    )
                    for k in range(split_arrive_idx, cute.size(tCrA.shape[2]))
                )
                if const_expr(mbar_ptr is not None)
                else ""
            )
            + "}\n",
            "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
@cute.jit
def gemm_ptx_partial1(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc_tmem_addr: cutlass.Constexpr[int],
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA_base_addr_for_desc: Int32,
    sA_addr_offset_for_desc: cutlass.Constexpr[int],
    sA_stage: Int32,
    sB_base_addr_for_desc: Int32,
    sB_addr_offset_for_desc: cutlass.Constexpr[int],
    sB_stage: Int32,
    sA_layout: Optional[cute.Layout],
    sB_layout: Optional[cute.Layout],
    sA_swizzle: Optional[cute.Swizzle],
    sB_swizzle: cute.Swizzle,
    zero_init: bool | Boolean = False,
) -> None:
    """Experimental inline-PTX MMA variant that computes per-stage descriptor
    addresses inside the asm (``mad.lo.u32 base + stage * stride``) and passes
    a 4-register sparsity/disable mask (``{$.., $.., $.., $..}``) to
    ``tcgen05.mma``; the mask here is all zeros.

    Unlike ``gemm_ptx_partial``, layouts and swizzles are passed explicitly
    rather than taken from smem tensors, and the accumulator address is a
    compile-time constant baked into the asm.

    NOTE(review): in the TS branch below the input list has 7 values (the
    accumulator operand is commented out) but the asm references $0..$7 and
    the constraint string has 8 "r"s — this branch looks out of sync with the
    SS branch; confirm before use.

    Args:
        op: tcgen05 MMA op (dtypes, major modes, idesc, src of A).
        acc_tmem_addr: Compile-time accumulator address (asm immediate).
        tCrA, tCrB: Partitioned operands; mode 2 is the K-tile mode.
        sA_base_addr_for_desc, sA_addr_offset_for_desc, sA_stage: Base
            descriptor-lo value, per-stage stride (constexpr), and runtime
            stage index for A; combined via mad.lo.u32 in the asm.
        sB_base_addr_for_desc, sB_addr_offset_for_desc, sB_stage: Same for B.
        sA_layout, sB_layout: Operand layouts used for descriptor bases and
            per-K-tile offsets. ``sA_layout`` is asserted non-None only in the
            SS path, yet the TS offset computation also indexes it — assumes
            TS callers pass the TMEM layout here; confirm.
        sA_swizzle, sB_swizzle: Swizzle modes for descriptor encoding.
        zero_init: Controls the first MMA's enable-input predicate (runtime
            Boolean -> predicate register; Python bool -> literal).
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM"
        assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM"
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)
    # All-zero 4x32-bit mask passed to the masked tcgen05.mma form.
    mask = [Int32(0)] * 4

    # Per-K-tile offsets: descriptor 16-byte units for smem, 32-bit-word
    # units for TMEM.
    if const_expr(not is_ts):
        offset_a = [
            (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4
            for k in range(cute.size(tCrA.shape[2]))
        ]
    else:
        offset_a = [
            cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
            for k in range(cute.size(tCrA.shape[2]))
        ]
    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]
    offset_b = [
        (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4
        for k in range(cute.size(tCrB.shape[2]))
    ]
    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]

    if const_expr(not is_ts):
        # smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator))
        # Start address is folded in by the mad.lo.u32 in the asm instead.
        smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo)
    else:
        smem_desc_start_a_lo = None
    # smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator))
    smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo)
    # Runtime Boolean -> predicate register "p"; compile-time bool -> literal.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        llvm.inline_asm(
            None,
            [
                # acc.iterator.toint().ir_value(),
                # Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(sA_base_addr_for_desc).ir_value(),
                Int32(sA_stage).ir_value(),
                # Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(sB_base_addr_for_desc).ir_value(),
                Int32(sB_stage).ir_value(),
                Int32(not zero_init).ir_value(),
                mask[0].ir_value(),
                mask[1].ir_value(),
                mask[2].ir_value(),
                mask[3].ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            # "mov.b32 smem_desc_a_lo, $0;\n\t"
            # f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t"
            # lo = stage * per-stage stride + base, per operand.
            f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t"
            # "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $4, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t"
                )
                for k in range(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # NOTE(review): 7 inputs here vs 8 "r" constraints and $0..$7 used in
        # the asm (the accumulator input is commented out) — apparent operand
        # mismatch; verify this branch before relying on it.
        llvm.inline_asm(
            None,
            [
                # acc.iterator.toint().ir_value(),
                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
                Int32(smem_desc_start_b_lo).ir_value(),
                Int32(not zero_init).ir_value(),
                mask[0].ir_value(),
                mask[1].ir_value(),
                mask[2].ir_value(),
                mask[3].ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            f"mov.b32 tmem_a, $1;\n\t"
            f"mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t"
                )
                for k in range(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
@cute.jit
def gemm_ptx_precomputed(
    acc_tmem_addr: Int32,
    smem_desc_start_a: Int32,  # If TS, then this is the tmem start address for A
    smem_desc_start_b: Int32,
    idesc: int,
    smem_desc_base_a: Optional[int],
    smem_desc_base_b: int,
    tCrA_layout: cute.Layout,
    tCrB_layout: cute.Layout,
    mbar_ptr: Optional[cutlass.Pointer] = None,
    mbar_phase: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    cta_group: int = 1,
) -> None:
    """Emit the whole K-tile MMA loop as one inline-PTX block of tcgen05.mma ops.

    The per-k descriptor/address offsets are precomputed at trace time from the
    layouts (via ``cute.crd2idx``), so the generated PTX is a fully unrolled
    sequence of ``add``/``mov``/``tcgen05.mma`` instructions guarded by an
    elected leader thread.

    Two operand-source modes, selected at trace time by ``smem_desc_base_a``:
      * SS (``smem_desc_base_a`` is an int): both A and B come from SMEM
        descriptors.
      * TS (``smem_desc_base_a is None``): A comes from TMEM
        (``smem_desc_start_a`` is then the TMEM start address, per the
        parameter comment above); only B uses an SMEM descriptor.

    Args:
        acc_tmem_addr: TMEM address of the accumulator tile.
        smem_desc_start_a: Low 32 bits of A's SMEM descriptor start, or the
            TMEM address of A in TS mode.
        smem_desc_start_b: Low 32 bits of B's SMEM descriptor start.
        idesc: Instruction descriptor, baked into the PTX as an immediate.
        smem_desc_base_a: 64-bit descriptor base for A, or None for TS mode.
        smem_desc_base_b: 64-bit descriptor base for B.
        tCrA_layout / tCrB_layout: Per-MMA layouts; mode 2 is the K-tile mode
            used to derive per-k offsets.
        mbar_ptr / mbar_phase: Optional mbarrier to wait on partway through the
            unrolled loop (TS mode only); the wait is inserted after
            ``num_k_tile // 4 * 3`` tiles, presumably to overlap the remaining
            MMAs with the barrier wait — TODO confirm intent.
        zero_init: If a trace-time bool, the accumulate flag is baked in
            ("0"/"1"); if a runtime Boolean, predicate register ``p`` is set
            from operand $2 at run time.
        cta_group: cta_group modifier for the tcgen05.mma instruction.
    """
    # acc_tmem_addr += acc_offset
    is_ts = const_expr(smem_desc_base_a is None)
    num_k_tile = cute.size(tCrA_layout.shape[2])
    if const_expr(not is_ts):
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
    else:
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)

    # In TS mode the A layout is recast so the trace-time offsets below are in
    # the units TMEM addressing expects.
    tCrA_layout = (
        tCrA_layout
        if const_expr(not is_ts)
        # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout)
        # currently hard-coding the width to 16
        else cute.recast_layout(32, 16, tCrA_layout)
    )
    # Trace-time (Python-level) offsets of each K tile; these become hex
    # immediates in the PTX string.
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]
    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, num_k_tile)]
    offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)]
    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, num_k_tile)]

    smem_desc_start_a_lo = None
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a)
        # smem_desc_start_a_lo = smem_desc_start_a
    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b)
    # Accumulate-flag operand of tcgen05.mma: runtime predicate "p" when
    # zero_init is a dynamic Boolean, otherwise a baked-in 0/1 immediate.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # SS path: both operands come from SMEM descriptors.
        assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM"
        llvm.inline_asm(
            None,
            [
                # acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            # Only one elected thread issues the MMAs.
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 smem_desc_a_lo_start, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            # k == 0 uses pred_str (accumulate-or-init); later k always accumulate (1).
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.s32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t"
                    f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "}\n",
            # "r,r,r",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # TS path: A is read from TMEM, B from an SMEM descriptor.
        input_args = [
            Int32(cute.arch.make_warp_uniform(smem_desc_start_a)).ir_value(),
            Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
            Int32(not zero_init).ir_value(),
            Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
        ]
        if const_expr(mbar_ptr is not None):
            assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None"
            # Operands $4/$5: mbarrier address and expected phase for the
            # mid-loop try_wait spin loop below.
            input_args.append(mbar_ptr.toint().ir_value())
            input_args.append(Int32(mbar_phase).ir_value())
            mbar_wait_str = (
                ".reg .pred P1; \n\t"
                "LAB_WAIT: \n\t"
                "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t"
                "@P1 bra DONE; \n\t"
                "bra LAB_WAIT; \n\t"
                "DONE: \n\t"
            )
        else:
            mbar_wait_str = ""
        llvm.inline_asm(
            None,
            # [
            #     # acc.iterator.toint().ir_value(),
            #     Int32(tCrA_layout[None, None, 0].iterator.toint()).ir_value(),
            #     Int32(smem_desc_start_b_lo).ir_value(),
            #     Int32(not zero_init).ir_value(),
            # ],
            input_args,
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            f"mov.b32 tmem_a, $0;\n\t"
            f"mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            # First 3/4 of the K tiles (all of them when no mbarrier) ...
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(
                    1,
                    num_k_tile if const_expr(mbar_ptr is None) else num_k_tile // 4 * 3,
                )
            )
            # ... then the optional mbarrier wait, then the remaining 1/4.
            + mbar_wait_str
            + (
                "".join(
                    (
                        # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                        f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                        f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                        f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                    )
                    for k in range(num_k_tile // 4 * 3, num_k_tile)
                )
                if const_expr(mbar_ptr is not None)
                else ""
            )
            + "}\n",
            "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
| 949 |
+
|
| 950 |
+
|
| 951 |
+
@cute.jit
def declare_ptx_smem_desc(
    smem_desc_start_a: Int32,  # If TS, then this is the tmem start address for A
    smem_desc_base_a: Optional[int],
    tCrA_layout: cute.Layout,
    var_name_prefix: str = "smem_desc",
) -> None:
    """Pre-declare and fill per-K-tile SMEM descriptor registers in PTX.

    Emits one ``.b64`` register per K tile, named ``{var_name_prefix}_<k>``,
    each holding the 64-bit SMEM descriptor for that tile (low 32 bits =
    start + trace-time offset, high 32 bits baked in as an immediate). Later
    inline-asm blocks (e.g. ``gemm_ptx_precomputed_varname``) reference these
    registers by name, so this relies on both asm blocks ending up in the same
    PTX scope — NOTE(review): confirm the compiler keeps the blocks adjacent.

    In TS mode (``smem_desc_base_a is None``) nothing is emitted, since A is
    addressed via TMEM rather than descriptors.
    """
    is_ts = const_expr(smem_desc_base_a is None)
    num_k_tile = cute.size(tCrA_layout.shape[2])
    smem_desc_base_a_lo, smem_desc_a_hi = None, None
    if const_expr(not is_ts):
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
    tCrA_layout = (
        tCrA_layout
        if const_expr(not is_ts)
        # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout)
        # currently hard-coding the width to 16
        else cute.recast_layout(32, 16, tCrA_layout)
    )
    # Trace-time per-tile offsets, baked into the PTX as immediates.
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]
    smem_desc_start_a_lo = None
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a)
    if const_expr(not is_ts):
        llvm.inline_asm(
            None,
            [Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value()],
            f".reg .b32 {var_name_prefix}_lo;\n\t"
            # Declares {var_name_prefix}_0 .. {var_name_prefix}_{num_k_tile-1}.
            f".reg .b64 {var_name_prefix}_<{num_k_tile}>;\n\t"
            f"mov.b64 {var_name_prefix}_0, {{$0, {hex(smem_desc_a_hi)}}};\n\t"
            + "".join(
                (
                    f"add.s32 {var_name_prefix}_lo, $0, {hex(offset_a[k])};\n\t"
                    f"mov.b64 {var_name_prefix}_{k}, {{{var_name_prefix}_lo, {hex(smem_desc_a_hi)}}};\n\t"
                )
                for k in range(1, num_k_tile)
            ),
            "r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
@cute.jit
def declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = "idesc") -> None:
    """Declare a named PTX ``.b32`` register preloaded with the MMA instruction
    descriptor of ``op``.

    Companion to ``declare_ptx_smem_desc``: subsequent inline-asm blocks refer
    to the register by ``var_name`` instead of re-materializing the immediate.
    The descriptor value is computed at trace time via
    ``sm100_desc.mma_op_to_idesc``.
    """
    idesc = const_expr(sm100_desc.mma_op_to_idesc(op))
    llvm.inline_asm(
        None,
        [],
        f".reg .b32 {var_name};\n\t"  # noqa
        f"mov.b32 {var_name}, {hex(idesc)};\n\t",
        constraints="",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
|
| 1008 |
+
|
| 1009 |
+
|
| 1010 |
+
@cute.jit
def gemm_ptx_precomputed_varname(
    acc_tmem_addr: Int32,
    smem_desc_start_b: Int32,
    # idesc: int,
    smem_desc_base_b: int,
    tCrB_layout: cute.Layout,
    smem_var_name_prefix: str,
    idesc_var_name: str,
    smem_offset: int,
    zero_init: bool | Boolean = False,
    cta_group: int = 1,
) -> None:
    """Variant of ``gemm_ptx_precomputed`` (SS path) that consumes pre-declared
    PTX registers by name instead of materializing A descriptors locally.

    The A-operand descriptors are the ``{smem_var_name_prefix}_<k>`` ``.b64``
    registers created by ``declare_ptx_smem_desc``; each is unpacked, advanced
    in place by ``smem_offset`` (so repeated calls walk through SMEM stages),
    and repacked. The instruction descriptor is the register declared by
    ``declare_ptx_idesc`` under ``idesc_var_name``. B descriptors are built
    here from ``smem_desc_start_b`` plus trace-time per-k offsets.

    NOTE(review): correctness depends on the named registers being visible to
    this asm block (same PTX scope as the declaring blocks) — confirm.

    Args:
        acc_tmem_addr: TMEM address of the accumulator tile.
        smem_desc_start_b: Low 32 bits of B's SMEM descriptor start.
        smem_desc_base_b: 64-bit descriptor base for B.
        tCrB_layout: B layout; mode 2 is the K-tile mode.
        smem_var_name_prefix: Prefix of the pre-declared A descriptor registers.
        idesc_var_name: Name of the pre-declared instruction-descriptor register.
        smem_offset: Immediate added in place to each A descriptor's low word.
        zero_init: Baked-in or runtime accumulate-vs-init flag (see pred_str).
        cta_group: cta_group modifier for the tcgen05.mma instruction.
    """
    is_ts = False  # this variant only supports the SMEM-sourced (SS) A path
    num_k_tile = cute.size(tCrB_layout.shape[2])
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    # Trace-time per-tile B offsets, baked into the PTX as immediates.
    offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)]

    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b)
    # Accumulate-flag operand: runtime predicate "p" for a dynamic Boolean,
    # otherwise a baked-in 0/1 immediate.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        llvm.inline_asm(
            None,
            [
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            # ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            # ".reg .b64 smem_desc_b;\n\t"
            f".reg .b64 smem_desc_b_<{num_k_tile}>;\n\t"
            # Only one elected thread issues the MMAs.
            "elect.sync _|leader_thread, -1;\n\t"
            # f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $2;\n\t"
            "mov.b32 smem_desc_b_lo_start, $0;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            # Advance the pre-declared A descriptor for tile 0 in place.
            f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_0;\n\t"
            f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
            f"mov.b64 {smem_var_name_prefix}_0, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b_0, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            # Build all B descriptors and advance all A descriptors up front,
            # before any MMA is issued.
            + "".join(
                (
                    f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t"
                    f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
                    f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b_{k}, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "setp.ne.b32 p, $1, 0;\n\t"
            # k == 0 uses pred_str (accumulate-or-init); later k always accumulate.
            # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t"
            + "".join(
                (
                    # f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t"
                    # f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
                    # f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    # f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    # f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "}\n",
            "r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
|
build/torch-cuda/block_info.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
+
from typing import Tuple, Optional
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import cutlass
|
| 6 |
+
import cutlass.cute as cute
|
| 7 |
+
from cutlass import Int32, const_expr
|
| 8 |
+
|
| 9 |
+
from .seqlen_info import SeqlenInfoQK
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass(frozen=True)
class BlockInfo:
    """Trace-time helper that computes the range of K/V (n) and Q (m) tile
    indices a given tile must iterate over, accounting for causal masking,
    local (sliding-window) attention, packed-GQA row scaling, and split-KV.

    Index convention visible in the methods below: query row ``i`` attends to
    key position ``n_idx = i + seqlen_k - seqlen_q`` (bottom-right aligned
    causal diagonal), widened by ``window_size_left``/``window_size_right``
    in the local case.
    """

    tile_m: cutlass.Constexpr[int]   # Q-tile size along the sequence dim
    tile_n: cutlass.Constexpr[int]   # K/V-tile size along the sequence dim
    is_causal: cutlass.Constexpr[bool]
    is_local: cutlass.Constexpr[bool] = False
    is_split_kv: cutlass.Constexpr[bool] = False
    window_size_left: Optional[Int32] = None    # local-attention left window (None = unbounded)
    window_size_right: Optional[Int32] = None   # local-attention right window (None = unbounded)
    # >1 when PackGQA packs multiple query heads per KV head into the m dim;
    # m indices are then divided back to logical query rows.
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1

    @cute.jit
    def get_n_block_min_max(
        self,
        seqlen_info: SeqlenInfoQK,
        m_block: Int32,
        split_idx: cutlass.Int32 = 0,
        num_splits: cutlass.Int32 = 1,
    ) -> Tuple[Int32, Int32]:
        """Return the half-open range [n_block_min, n_block_max) of KV tiles
        that Q tile ``m_block`` must visit; optionally narrowed to this
        split's share when split-KV is enabled."""
        n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
        # Upper bound from causal / right-window masking: the last row of the
        # Q tile fixes the rightmost visible key.
        if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
            m_idx_max = (m_block + 1) * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa > 1):
                m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
            n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
            n_idx_right = n_idx if const_expr(self.is_causal) else n_idx + self.window_size_right
            n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.tile_n))
        n_block_min = 0
        # Lower bound from the left window: the first row of the Q tile fixes
        # the leftmost visible key.
        if const_expr(self.is_local and self.window_size_left is not None):
            m_idx_min = m_block * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa > 1):
                m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa
            n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q
            n_idx_left = n_idx - self.window_size_left
            n_block_min = cutlass.max(n_idx_left // self.tile_n, 0)
        # Split-KV: carve the [min, max) range into num_splits contiguous
        # chunks and keep only this split's chunk (runtime branch-free math).
        if cutlass.const_expr(self.is_split_kv):
            num_n_blocks_per_split = (
                cutlass.Int32(0)
                if n_block_max <= n_block_min
                else (n_block_max - n_block_min + num_splits - 1) // num_splits
            )
            n_block_min = n_block_min + split_idx * num_n_blocks_per_split
            n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max)
        return n_block_min, n_block_max

    @cute.jit
    def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]:
        """Return the half-open range [m_block_min, m_block_max) of Q tiles
        that attend to KV tile ``n_block`` (the transpose of
        ``get_n_block_min_max``; used by the backward pass)."""
        m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m)
        m_block_min = 0
        # Causal / right-window: keys before the tile's first column bound the
        # earliest Q tile that can see them.
        if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
            n_idx_min = n_block * self.tile_n
            m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k
            m_idx_right = m_idx if const_expr(self.is_causal) else m_idx - self.window_size_right
            m_block_min = max(m_block_min, m_idx_right // self.tile_m)
        # Left window: keys after the tile's last column bound the latest Q
        # tile that can see them.
        if const_expr(self.is_local and self.window_size_left is not None):
            n_idx_max = (n_block + 1) * self.tile_n
            m_idx = n_idx_max + seqlen_info.seqlen_q - seqlen_info.seqlen_k
            m_idx_left = m_idx + self.window_size_left
            m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m))
        return m_block_min, m_block_max

    @cute.jit
    def get_n_block_min_causal_local_mask(
        self,
        seqlen_info: SeqlenInfoQK,
        m_block: Int32,
        n_block_min: Int32,
    ) -> Int32:
        """If we have separate iterations with causal or local masking at the start, where do we stop"""
        # KV tiles strictly left of the diagonal through the tile's FIRST row
        # need no causal/right-window mask; tiles at/after it do.
        m_idx_min = m_block * self.tile_m
        if const_expr(self.qhead_per_kvhead_packgqa > 1):
            m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa
        n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q
        n_idx_right = (
            n_idx
            if const_expr(not self.is_local or self.window_size_right is None)
            else n_idx + self.window_size_right
        )
        return cutlass.max(n_block_min, n_idx_right // self.tile_n)

    @cute.jit
    def get_n_block_min_before_local_mask(
        self,
        seqlen_info: SeqlenInfoQK,
        m_block: Int32,
        n_block_min: Int32,
    ) -> Int32:
        """If we have separate iterations with local masking at the end, where do we stop the non-masked iterations"""
        if const_expr(not self.is_local or self.window_size_left is None):
            return n_block_min
        else:
            # Tiles fully right of the left-window edge through the tile's
            # LAST row need no left-window mask.
            m_idx_max = (m_block + 1) * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa > 1):
                m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
            n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
            n_idx_left = n_idx - self.window_size_left
            return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.tile_n))
|
build/torch-cuda/block_sparse_utils.py
ADDED
|
@@ -0,0 +1,1476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Block-sparse runtime utilities for CUTE DSL kernels.

This module contains runtime execution functions for block-sparse attention kernels.
These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads.
"""

from typing import Callable, Optional
from functools import partial
import math
import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32, const_expr

from .quack import copy_utils

# Import data structures from block_sparsity
from .block_sparsity import BlockSparseTensors
from .named_barrier import NamedBarrierBwd


# NOTE [SM100 block-sparse empty tiles: mbarrier contract]
#
# For block-sparse SM100 forward, a given (m_block, stage) Q tile can have zero active
# KV blocks (total_block_cnt == 0). In that case there is no seqlen_kv iteration, so
# the softmax warp-group has no row stats to publish.
#
# The correction warp-group seeds fully-masked-row stats and runs the usual correction
# epilogue so output/LSE have well-defined values. Both warp-groups must still perform
# the softmax<->correction mbarrier handshake so phases advance correctly across
# empty->empty and empty->non-empty tile sequences.
#
# In the no-sink case, this corresponds to the usual fully-masked-row convention:
# output is zero and LSE is -inf.
#
# Barrier contract (each is `mbar_ptr + <offset> + stage`):
#
# Producer/consumer pairs:
# - `mbar_softmax_corr_full`  : softmax arrive -> correction wait
# - `mbar_softmax_corr_empty` : correction arrive -> softmax wait
# - `mbar_P_full_O_rescaled`  : softmax arrive (+ correction arrive) -> MMA wait
# - `mbar_P_full_2`           : softmax arrive -> MMA wait
# - `mbar_corr_epi_full_/empty` : correction <-> epilogue (only when epilogue is separate)
#
# Empty tile (`total_block_cnt == 0`):
# - Softmax: skips the seqlen_kv softmax path entirely (no P stores, no `mbar_P_full_*`).
#   It only arrives `mbar_softmax_corr_full` once per stage as a synthetic "no work" signal.
#   At the `softmax_loop` level, softmax unconditionally waits `mbar_softmax_corr_empty`
#   before each tile (when block-sparse) to drain a prior correction arrival and keep
#   phases aligned across non-empty -> empty transitions.
# - Correction: waits `mbar_softmax_corr_full`, seeds stats + runs `correction_epilogue(scale=0)`,
#   and arrives `mbar_softmax_corr_empty` (and `mbar_corr_epi_full_/empty` when applicable).
# - No `mbar_P_full_*` barriers are arrived (no P, no MMA O); only the softmax<->correction
#   (and correction<->epilogue) handshakes advance phases.
#
# Non-empty tile:
# - Softmax: runs `softmax_step` (produces P) and uses `mbar_softmax_corr_full/empty` to
#   publish row_max (during seqlen_kv) and final row stats (once per tile), and to advance phases;
#   arrives `mbar_P_full_*` when P is stored.
# - Correction: waits `mbar_softmax_corr_full`, may rescale/release O, arrives `mbar_softmax_corr_empty`
#   to ack/advance, and arrives `mbar_P_full_O_rescaled` when MMA can proceed.
#
# Backward (SM100):
# - Empty KV tile: for a given `n_block`, `total_m_block_cnt == 0` means no Q tiles contribute.
# - Both the load and compute loops guard all pipeline work on `process_tile`, so empty tiles
#   skip producer/consumer operations entirely (no per-tile mbarrier phase handshake like forward).
# - In the `not dKV_postprocess` path, dK/dV for empty KV tiles are explicitly written as zeros
#   even when `process_tile == False` (see `flash_bwd_sm100.py` `should_zero_dKV`).
@cute.jit
def load_block_list(
    block_indices: cute.Tensor,
    block_count,
    load_q_with_first: cutlass.Constexpr,
    first_block_preloaded: cutlass.Constexpr,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    use_tma_q: cutlass.Constexpr,
    tma_q_bytes: cutlass.Constexpr,
    intra_wg_overlap: cutlass.Constexpr,
):
    """Iterate over the sparse blocks and load K, V (and Q) into the pipeline.

    For the intra_wg_overlap case, we overlap the loads of K and V. And this
    means we need to pipeline the last V load from the partial block case,
    with the loads for the full blocks. Set first_block_preloaded when the
    caller has already issued the first K load for the list.

    Note:
        we iterate along the block_n indices in reverse, i.e. starting from
        `block_indices[block_count - 1]` down to `block_indices[0]`.

    Returns:
        Updated kv_producer_state after processing the block list.

    """
    if block_count > 0:
        if const_expr(not intra_wg_overlap):
            # Peel first iteration: the first block may need to load Q alongside K,
            # Parameters are already Constexpr, so no need to wrap in const_expr()
            n_block_first = block_indices[block_count - 1]
            # When Q rides on the same TMA barrier as the first K, the barrier must
            # expect the extra Q bytes in its transaction count.
            extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
            pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)

            if const_expr(load_q_with_first and use_tma_q):
                # Q completion is signalled on the first K stage's barrier.
                load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))

            load_K(src_idx=n_block_first, producer_state=kv_producer_state)
            pipeline_v.producer_acquire(kv_producer_state)
            load_V(src_idx=n_block_first, producer_state=kv_producer_state)
            kv_producer_state.advance()

            # Remaining blocks: plain K-then-V per pipeline stage, no overlap.
            for offset in cutlass.range(1, block_count):
                n_block = block_indices[block_count - 1 - offset]
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state)
                load_V(src_idx=n_block, producer_state=kv_producer_state)
                kv_producer_state.advance()
        else:
            n_block_first = block_indices[block_count - 1]
            if const_expr(not first_block_preloaded):
                extra_tx = (
                    tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
                )
                pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)

                if const_expr(load_q_with_first and use_tma_q):
                    load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))

                load_K(src_idx=n_block_first, producer_state=kv_producer_state)

            # Overlapped schedule: each iteration issues K for the next block on a
            # fresh stage while completing V for the previous block on the prior
            # stage. The V of the final block is intentionally left pending; the
            # caller drains it via `finish_overlap_v_load`.
            for idx in cutlass.range(block_count - 1, unroll=1):
                n_block_prev = block_indices[block_count - 1 - idx]
                n_block = block_indices[block_count - 2 - idx]
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)

    return kv_producer_state
@cute.jit
def finish_overlap_v_load(
    block_indices: cute.Tensor,
    block_count,
    load_V,
    pipeline_v,
    kv_producer_state,
):
    """Load the final V block after overlapped K/V loads.

    In the intra-warp-group overlap schedule, `load_block_list` leaves the V
    load of the last block outstanding. This helper acquires the V stage for
    that block, issues the copy, and advances the producer state. No-op when
    `block_count == 0`.

    Returns:
        Updated kv_producer_state.
    """
    if block_count > 0:
        # The block lists are walked in reverse, so index 0 holds the last block.
        n_block_last = block_indices[0]
        pipeline_v.producer_acquire(kv_producer_state)
        load_V(src_idx=n_block_last, producer_state=kv_producer_state)
        kv_producer_state.advance()

    return kv_producer_state
@cute.jit
def sparse_tensor_m_block(
    m_block,
    qhead_per_kvhead: cutlass.Constexpr[int],
    q_subtile_factor: cutlass.Constexpr[int],
):
    """Translate a (possibly packed / subtiled) m_block into the index used for
    the block-sparse metadata tensors.

    Both divisions are guarded by compile-time checks so the common factor-1
    case traces to a no-op.
    """
    sparse_idx = m_block
    if const_expr(qhead_per_kvhead != 1):
        # Pack-GQA: several packed m blocks share one sparse entry.
        sparse_idx = sparse_idx // qhead_per_kvhead
    if const_expr(q_subtile_factor != 1):
        # Q subtiling: collapse subtiles onto their parent sparse entry.
        sparse_idx = sparse_idx // q_subtile_factor
    return sparse_idx
@cute.jit
def produce_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    use_tma_q: cutlass.Constexpr,
    tma_q_bytes: cutlass.Constexpr,
    intra_wg_overlap: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Iterate over the mask and full block lists for a single tile.

    The masked (partial) list may leave the last V load pending when intra-warp-group
    overlap is enabled. The first full block must consume that pending V while
    issuing its own K load on the next pipeline stage.

    In the intra-wg-overlap path, the last masked block leaves its V copy in flight
    while we advance the producer state to start the next full K. Either the full list
    overlaps that pending V load, or, if no full blocks exist, we explicitly drain it.

    Args:
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.
        q_subtile_factor: additional divisor applied to m_block before indexing
            the sparse metadata tensors (see `sparse_tensor_m_block`).

    Returns:
        Updated kv_producer_state after all K/V (and Q) loads for this tile.
    """

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    # The full-block tensors are optional: when absent there are no full blocks.
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0

    if mask_empty:
        # No masked blocks: the full list owns the initial Q+K load.
        kv_producer_state = load_block_list(
            curr_full_block_idx,
            curr_full_block_cnt,
            load_q_with_first=True,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            use_tma_q=use_tma_q,
            tma_q_bytes=tma_q_bytes,
            intra_wg_overlap=intra_wg_overlap,
        )

        # Overlap mode leaves the last full block's V pending; drain it here.
        if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0:
            kv_producer_state = finish_overlap_v_load(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_V,
                pipeline_v,
                kv_producer_state,
            )
    else:
        # Masked blocks present: load Q together with the first masked K so consumers can
        # start immediately. When overlap is disabled this fully drains the list.
        kv_producer_state = load_block_list(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            load_q_with_first=True,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            use_tma_q=use_tma_q,
            tma_q_bytes=tma_q_bytes,
            intra_wg_overlap=intra_wg_overlap,
        )

        if full_empty:
            # No full list to absorb the pending masked V; drain it explicitly.
            if const_expr(intra_wg_overlap):
                kv_producer_state = finish_overlap_v_load(
                    curr_mask_block_idx,
                    curr_mask_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
        else:
            if const_expr(intra_wg_overlap):
                # Bridge the masked list to the full list by overlapping the pending masked V
                # with the first full K load.
                n_block_mask_last = curr_mask_block_idx[0]
                n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1]
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block_full_first, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev)

                # Continue the full list; its first K was just issued above.
                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_q_with_first=False,
                    first_block_preloaded=True,
                    kv_producer_state=kv_producer_state,
                    load_Q=load_Q,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    use_tma_q=use_tma_q,
                    tma_q_bytes=tma_q_bytes,
                    intra_wg_overlap=intra_wg_overlap,
                )

                kv_producer_state = finish_overlap_v_load(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
            else:
                # Non-overlap path with both lists: run the full list normally (skipping the Q
                # reload because the masked list already issued it).
                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_q_with_first=False,
                    first_block_preloaded=False,
                    kv_producer_state=kv_producer_state,
                    load_Q=load_Q,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    use_tma_q=use_tma_q,
                    tma_q_bytes=tma_q_bytes,
                    intra_wg_overlap=intra_wg_overlap,
                )

    return kv_producer_state
@cute.jit
def consume_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    seqlen,
    kv_consumer_state,
    mma_pv_fn,
    mma_one_n_block,
    process_first_half_block,
    process_last_half_block,
    mask_fn,
    score_mod_fn,
    O_should_accumulate,
    mask_mod,
    fastdiv_mods,
    intra_wg_overlap: cutlass.Constexpr,
    warp_scheduler_barrier_sync: Callable,
    warp_scheduler_barrier_arrive: Callable,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Consume the mask and full block lists for a single tile on the consumer side.

    Mirrors `produce_block_sparse_loads` so that the consumer pipeline uses
    the same sparse tensor indexing. Masked blocks apply `mask_mod`; full
    blocks are processed with `mask_mod=None` (no per-element masking beyond
    the sequence-length bound on the first block).

    Args:
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.
        q_subtile_factor: divisor applied to m_block before sparse-tensor
            indexing (see `sparse_tensor_m_block`).

    Returns:
        Tuple ``(kv_consumer_state, O_should_accumulate, processed_any)`` where
        ``processed_any`` is True when at least one masked or full block existed
        for this tile.
    """

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]

    processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0

    if const_expr(not intra_wg_overlap):
        if curr_mask_block_cnt > 0:
            # First masked block: sync the warp scheduler and apply the
            # sequence-length bound in addition to mask_mod.
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            warp_scheduler_barrier_sync()
            kv_consumer_state = mma_one_n_block(
                kv_consumer_state,
                n_block=mask_n_block,
                mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
                ),
                is_first_n_block=True,
            )
            O_should_accumulate = True
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
            # If there are no full blocks, this tile ends here: release the
            # scheduler slot now.
            if curr_full_block_cnt == 0:
                warp_scheduler_barrier_arrive()

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                # Full list runs first: it owns the scheduler sync and the
                # first-block seqlen masking.
                warp_scheduler_barrier_sync()
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_seqlen=True),
                    is_first_n_block=True,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            else:
                # Masked list already ran: no mask_mod for full blocks, and the
                # first-n_block bookkeeping already happened.
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            warp_scheduler_barrier_arrive()
    else:
        if curr_mask_block_cnt > 0:
            # Overlap path: the first block is split into half-block phases.
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            kv_consumer_state = process_first_half_block(
                n_block=mask_n_block,
                seqlen=seqlen,
                kv_consumer_state=kv_consumer_state,
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
                ),
                score_mod_fn=score_mod_fn,
                is_first_block=True,
            )
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
                )
                O_should_accumulate = True

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                # Full list starts the tile: run the half-block prologue here.
                kv_consumer_state = process_first_half_block(
                    n_block=full_n_block,
                    seqlen=seqlen,
                    kv_consumer_state=kv_consumer_state,
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    score_mod_fn=score_mod_fn,
                    is_first_block=True,
                )
            else:
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                )
                O_should_accumulate = True
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                )
                O_should_accumulate = True

        # Close the overlapped schedule: the last half-block is still pending
        # whenever any block was processed.
        if curr_mask_block_cnt + curr_full_block_cnt > 0:
            kv_consumer_state = process_last_half_block(
                kv_consumer_state=kv_consumer_state,
                zero_init=not O_should_accumulate,
            )
            O_should_accumulate = True

    return kv_consumer_state, O_should_accumulate, processed_any
@cute.jit
def load_block_list_sm100(
    block_indices: cute.Tensor,
    block_count,
    load_q_with_first: cutlass.Constexpr,
    q_stage: cutlass.Constexpr,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
):
    """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).

    K and V share a single KV pipeline: each load advances the producer state,
    so one logical block consumes two pipeline stages (K, then V). Blocks are
    walked in reverse index order, matching `load_block_list`.

    Returns:
        Updated kv_producer_state (unchanged when block_count == 0).
    """
    if block_count > 0:
        # First iteration: load Q alongside K if requested
        n_block_first = block_indices[block_count - 1]

        if const_expr(load_q_with_first):
            # SM100 loads Q0 and optionally Q1
            load_Q(block=0, stage=0)
            if const_expr(q_stage == 2):
                load_Q(block=1, stage=1)

        # SM100 doesn't use producer_acquire for pipeline_kv in load path
        # The pipeline barriers are handled inside load_KV
        load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()
        load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()

        # Remaining blocks
        for offset in cutlass.range(1, block_count):
            n_block = block_indices[block_count - 1 - offset]
            load_K(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()
            load_V(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()

    return kv_producer_state
# SM100-specific tile processor using SM100 helpers
@cute.jit
def produce_block_sparse_loads_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
    q_stage: cutlass.Constexpr,
    q_producer_phase: Int32,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr,
):
    """SM100 entry point for sparse block iteration.

    SM100 uses PipelineTmaUmma which doesn't support extra_tx_count, so we use
    simplified block processing that just calls producer_acquire without extras.

    Args:
        m_block: which tile of m we are processing
        qhead_per_kvhead: Constexpr pack factor

    Returns:
        Tuple of (updated kv_producer_state, updated q_producer_phase). The Q
        producer phase is flipped only when a non-empty block list actually
        issued the Q load; fully empty tiles leave the phase untouched.
    """
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    # The full-block tensors are optional; treat a missing tensor as zero blocks.
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0

    q_phase_flipped = False

    if mask_empty:
        # No masked blocks: process full list with Q loading
        kv_producer_state = load_block_list_sm100(
            curr_full_block_idx,
            curr_full_block_cnt,
            load_q_with_first=True,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        # Q was only issued if the full list was non-empty.
        q_phase_flipped = not full_empty
    else:
        # Process masked blocks with Q loading
        kv_producer_state = load_block_list_sm100(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            load_q_with_first=True,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        q_phase_flipped = True

        if not full_empty:
            # Process full blocks without Q loading
            kv_producer_state = load_block_list_sm100(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_q_with_first=False,
                q_stage=q_stage,
                kv_producer_state=kv_producer_state,
                load_Q=load_Q,
                load_K=load_K,
                load_V=load_V,
                pipeline_kv=pipeline_kv,
            )

    if q_phase_flipped:
        q_producer_phase ^= 1

    return kv_producer_state, q_producer_phase
@cute.jit
def get_total_block_count(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr,
):
    """Return the combined masked + full KV block count for one query tile.

    When the full-block metadata is absent (compile-time None), only the
    masked-block count contributes.
    """
    sparse_m = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    mask_cnt_tensor, _, full_cnt_tensor, _ = blocksparse_tensors
    total = mask_cnt_tensor[batch_idx, head_idx, sparse_m]
    if const_expr(full_cnt_tensor is not None):
        total = total + full_cnt_tensor[batch_idx, head_idx, sparse_m]
    return total
@cute.jit
def handle_block_sparse_empty_tile_correction_sm100(
    tidx: Int32,
    q_stage: cutlass.Constexpr,
    m_block_size: cutlass.Constexpr,
    qhead_per_kvhead,
    pack_gqa: cutlass.Constexpr,
    is_split_kv: cutlass.Constexpr,
    learnable_sink,
    mLSE,
    seqlen,
    m_block: Int32,
    head_idx: Int32,
    batch_idx: Int32,
    split_idx: Int32,
    sScale: cute.Tensor,
    stats: list,
    correction_epilogue: Callable,
    thr_mma_pv: cute.core.ThrMma,
    tOtO: cute.Tensor,
    sO: cute.Tensor,
    pipeline_sm_stats: cutlass.pipeline.PipelineAsync,
    sm_stats_barrier: cutlass.pipeline.NamedBarrier,
    pipeline_o_epi: cutlass.pipeline.PipelineAsync,
    sm_stats_consumer_phase: Int32,
    o_corr_consumer_phase: Int32,
    corr_epi_producer_phase: Int32,
    softmax_scale_log2: Float32,
    mO_cur: Optional[cute.Tensor] = None,
    gO: Optional[cute.Tensor] = None,
    gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
):
    """Handle SM100 forward block-sparse tiles with no active KV blocks.

    This path is taken when `total_block_cnt == 0`. The softmax warp-group still
    arrives `mbar_softmax_corr_full` (synthetic "no work") so the correction
    warp-group can:

    - seed fully-masked-row stats (row_sum=1; row_max=-inf when tracked) for LSE
    - run `correction_epilogue` with `scale=0` so the output tile is written as zeros
      (independent of any prior tmem contents)
    - wait on `mbar_softmax_corr_full` and arrive `mbar_softmax_corr_empty`
      (and `mbar_corr_epi_*` when applicable) so phases stay aligned across tiles

    This helper intentionally does not touch `mbar_P_full_*` since no P is produced.
    See NOTE [SM100 block-sparse empty tiles: mbarrier contract].

    Returns:
        Tuple (sm_stats_consumer_phase, o_corr_consumer_phase,
        corr_epi_producer_phase). Note o_corr_consumer_phase is returned
        unchanged; it is threaded through so callers can treat all phases
        uniformly.
    """
    LOG2_E = Float32(math.log2(math.e))
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4

    for stage in cutlass.range_constexpr(q_stage):
        # Fully-masked-row convention: row_sum = 1 and (when tracked) row_max = -inf.
        row_sum_value = Float32(1.0)
        row_max_value = (
            -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None
        )
        if const_expr(learnable_sink is not None):
            sink_val = -Float32.inf
            if const_expr(not pack_gqa):
                sink_val = Float32(learnable_sink[head_idx])
            elif tidx < m_block_size:
                # Pack-GQA: recover this row's Q head index from the packed m_block.
                q_head_idx = (
                    (q_stage * m_block + stage) * m_block_size + tidx
                ) % qhead_per_kvhead + head_idx * qhead_per_kvhead
                sink_val = Float32(learnable_sink[q_head_idx])
            # Only split 0 contributes the sink when KV is split, so it is
            # counted exactly once across splits.
            if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0):
                if row_max_value == -Float32.inf:
                    # Sink becomes the row max (converted to the log2 domain used
                    # by the softmax scale).
                    row_max_value = sink_val * (LOG2_E / softmax_scale_log2)
                    row_sum_value = Float32(1.0)
                else:
                    row_sum_value = row_sum_value + cute.math.exp2(
                        sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True
                    )
        if tidx < m_block_size:
            # Publish per-row stats through shared memory: sums in the first
            # q_stage*m_block_size slots, maxes in the following slots.
            scale_row_idx = tidx + stage * m_block_size
            sScale[scale_row_idx] = row_sum_value
            if const_expr(mLSE is not None or learnable_sink is not None):
                sScale[scale_row_idx + q_stage * m_block_size] = row_max_value
        # Flag rows whose sum is zero or NaN (x != x is the NaN test).
        acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value
        stats[stage] = (row_sum_value, row_max_value, acc_flag)

        # See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
        # pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase)
        sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx)
        pipeline_sm_stats.consumer_release_w_index(stage)

        if const_expr(gmem_tiled_copy_O is None):
            # Separate epilogue warp-group: reserve the O staging slot first.
            pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase)
        correction_epilogue(
            thr_mma_pv,
            tOtO[None, None, None, stage],
            tidx,
            stage,
            m_block,
            seqlen.seqlen_q,
            Float32(0.0),  # zero scale ensures empty tile writes zeros into staged outputs
            sO[None, None, stage],
            mO_cur,
            gO[None, None, stage],
            gmem_tiled_copy_O,
        )
        if const_expr(gmem_tiled_copy_O is None):
            pipeline_o_epi.producer_commit_w_index(stage)

    # Flip once per tile (not per stage) so phases stay aligned with the
    # non-empty-tile path.
    sm_stats_consumer_phase ^= 1
    corr_epi_producer_phase ^= 1

    return (
        sm_stats_consumer_phase,
        o_corr_consumer_phase,
        corr_epi_producer_phase,
    )
@cute.jit
def softmax_block_sparse_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    softmax_step: Callable,
    mask_fn: Callable,
    mask_fn_none: Callable,
    mma_si_consumer_phase: Int32,
    si_corr_producer_phase: Int32,
    s0_s1_sequence_phase: Int32,
    pipeline_sm_stats: cutlass.pipeline.PipelineAsync,
    sm_stats_barrier: cutlass.pipeline.NamedBarrier,
    q_stage: cutlass.Constexpr,
    stage_idx: Int32,
    check_m_boundary: bool,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Run the SM100 softmax loop over the block-sparse KV tiles of one Q tile.

    Iterates partial ("mask") blocks first with ``mask_fn`` applied, then full
    blocks with ``mask_fn_none`` (no mask_mod), invoking ``softmax_step`` once
    per KV tile. Blocks are walked from the last listed index downwards
    (``cnt - 1 - i``); the very first step taken gets ``is_first=True`` and
    ``mask_seqlen=True``.

    Returns:
        (mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase,
        is_empty) where ``is_empty`` is True when this Q tile has no KV blocks
        at all (in which case only the sm_stats named barrier is signalled).
    """
    # Per-warp index within the 4-warp group; used to address the barrier slot.
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
    # Map the kernel m_block to the (possibly coarser) sparse-tensor row.
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        # No separate full-block lists supplied: everything is in the mask list.
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt

    if total_block_cnt == 0:
        # Empty tile: no softmax_step runs, but the stats barrier must still be
        # signalled so the correction side does not deadlock.
        # See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
        # pipeline_sm_stats.producer_commit_w_index(stage_idx)
        sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx)
    else:
        if curr_mask_block_cnt > 0:
            # First masked block: last entry in the index list; applies the
            # seqlen mask exactly once.
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            (
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
            ) = softmax_step(
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
                mask_n_block,
                is_first=True,
                mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary),
            )
            # Remaining masked blocks, walking the index list backwards.
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    mask_n_block,
                    mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary),
                )

        if curr_full_block_cnt > 0:
            # Full blocks use mask_fn_none (no mask_mod). If no masked block
            # ran, the first full block carries is_first/mask_seqlen duties.
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    is_first=True,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary
                    ),
                )
            else:
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    is_first=False,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
                    ),
                )
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
                    ),
                )

    return (
        mma_si_consumer_phase,
        si_corr_producer_phase,
        s0_s1_sequence_phase,
        total_block_cnt == 0,
    )
|
| 921 |
+
|
| 922 |
+
|
| 923 |
+
# =============================================================================
|
| 924 |
+
# Backward-specific block-sparse helpers (SM100)
|
| 925 |
+
# =============================================================================
|
| 926 |
+
#
|
| 927 |
+
# In backward, iteration is transposed compared to forward:
|
| 928 |
+
# - Forward: outer loop over m_blocks (Q tiles), inner loop over n_blocks (KV tiles)
|
| 929 |
+
# - Backward: outer loop over n_blocks (KV tiles), inner loop over m_blocks (Q tiles)
|
| 930 |
+
#
|
| 931 |
+
# The backward block-sparse tensors use "Q direction" indexing:
|
| 932 |
+
# - q_block_cnt[batch, head, n_block] → count of m_blocks to process for this KV tile
|
| 933 |
+
# - q_block_idx[batch, head, n_block, :] → indices of m_blocks to process
|
| 934 |
+
#
|
| 935 |
+
|
| 936 |
+
|
| 937 |
+
@cute.jit
def get_total_q_block_count_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Count total tile iterations for given n_block (KV tile) in backward.

    Sums the partial (q) block count and, when present, the full block count
    for (batch_idx, head_idx, n_block), scaled by ``subtile_factor`` to match
    the number of subtiled loop iterations.

    NOTE(review): ``m_block_max`` is accepted for signature parity with the
    other backward helpers but is not used here.
    """
    q_block_cnt, _, full_block_cnt, _ = blocksparse_tensors
    total = q_block_cnt[batch_idx, head_idx, n_block]
    if const_expr(full_block_cnt is not None):
        # Full blocks are tracked in a separate list; include their count too.
        total = total + full_block_cnt[batch_idx, head_idx, n_block]
    return total * subtile_factor
|
| 952 |
+
|
| 953 |
+
|
| 954 |
+
@cute.jit
def produce_block_sparse_q_loads_bwd_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    # Pipeline states (will be returned after advancing)
    producer_state_Q_LSE,
    producer_state_dO_dPsum,
    # Pipelines
    pipeline_Q,
    pipeline_LSE,
    pipeline_dO,
    pipeline_dPsum,
    # Load functions
    load_K,
    load_V,
    load_Q,
    load_dO,
    copy_stats,
    # Global tensors for LSE/dPsum
    gLSE,
    sLSE,
    gdPsum,
    sdPsum,
    # TMA copy bytes for extra_tx_count
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    # Flags for which loads to perform
    should_load_Q: cutlass.Constexpr,
    should_load_dO: cutlass.Constexpr,
    # Subtiling factor and bounds
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """SM100 backward block sparse loading with subtiling.

    Returns updated (producer_state_Q_LSE, producer_state_dO_dPsum).
    First iteration loads K/V alongside Q/dO; subsequent iterations load only
    Q/dO. For each m_block, the Q (resp. dO) pipeline stage is acquired,
    filled, and committed, then the matching LSE (resp. dPsum) stage is filled
    by a single elected thread via ``copy_stats``.
    """
    (
        curr_q_cnt,
        curr_q_idx,
        curr_full_cnt,
        curr_full_idx,
        loop_count,
    ) = get_block_sparse_iteration_info_bwd(
        blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max
    )

    for iter_idx in cutlass.range(loop_count, unroll=1):
        m_block, _ = get_m_block_from_iter_bwd(
            iter_idx,
            curr_q_cnt,
            curr_q_idx,
            curr_full_cnt,
            curr_full_idx,
            subtile_factor,
            m_block_max,
        )
        # Clamp the block index so out-of-range subtiles still issue a valid
        # (if redundant) load instead of reading past the tensor.
        m_block_safe = m_block
        if m_block_max > 0:
            m_block_safe = cutlass.min(m_block, m_block_max - 1)

        if iter_idx == 0:
            # First block: load K/V alongside Q/dO. extra_tx_count accounts for
            # the K (resp. V) bytes landing on the same transaction barrier.
            if const_expr(should_load_Q):
                pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K)
                load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block_safe],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                pipeline_dO.producer_acquire(
                    producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V
                )
                load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum))
                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block_safe],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()
        else:
            # Subsequent blocks: just load Q/dO (K/V already loaded)
            if const_expr(should_load_Q):
                pipeline_Q.producer_acquire(producer_state_Q_LSE)
                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block_safe],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                pipeline_dO.producer_acquire(producer_state_dO_dPsum)
                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block_safe],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()

    return producer_state_Q_LSE, producer_state_dO_dPsum
|
| 1076 |
+
|
| 1077 |
+
|
| 1078 |
+
@cute.jit
def get_block_sparse_iteration_info_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Extract block-sparse iteration info for backward pass.

    Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count)
    for the given (batch_idx, head_idx, n_block). ``total_count`` is the
    combined partial + full block count scaled by ``subtile_factor``.

    NOTE(review): ``m_block_max`` is accepted for signature parity with the
    other backward helpers but is not used here.
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        # No full-block lists supplied: only partial blocks exist.
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    sparse_block_count = curr_q_cnt
    if const_expr(full_cnt is not None):
        sparse_block_count = sparse_block_count + curr_full_cnt
    # Each sparse block expands to subtile_factor kernel-sized iterations.
    total_count = sparse_block_count * subtile_factor

    return curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count
|
| 1108 |
+
|
| 1109 |
+
|
| 1110 |
+
@cute.jit
def get_m_block_from_iter_bwd(
    iter_idx,
    curr_q_cnt,
    curr_q_idx: cute.Tensor,
    curr_full_cnt,
    curr_full_idx: Optional[cute.Tensor],
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Derive m_block index and is_full_block flag from iteration index.

    The flattened ``iter_idx`` first walks all partial (q) blocks, then all
    full blocks, with each sparse block expanded into ``subtile_factor``
    consecutive kernel-sized m_blocks.

    Returns (m_block, is_full_block):
    - m_block: The actual Q-tile block index
    - is_full_block: True if this is a full block (no mask_mod needed);
      always False when no full-block index tensor is provided.

    NOTE(review): ``m_block_max`` is unused here; clamping is done by callers.
    """
    # Split the flat iteration index into (sparse list position, subtile lane).
    sparse_iter_idx = iter_idx // subtile_factor
    subtile_offset = iter_idx % subtile_factor

    sparse_m_block = Int32(0)
    is_full_block = False
    if const_expr(curr_full_idx is not None):
        if sparse_iter_idx < curr_q_cnt:
            sparse_m_block = curr_q_idx[sparse_iter_idx]
        else:
            # Past the partial list: index into the full-block list.
            sparse_m_block = curr_full_idx[sparse_iter_idx - curr_q_cnt]
            is_full_block = True
    else:
        sparse_m_block = curr_q_idx[sparse_iter_idx]

    return sparse_m_block * subtile_factor + subtile_offset, is_full_block
|
| 1141 |
+
|
| 1142 |
+
|
| 1143 |
+
@cute.jit
def _load_q_do_block_sm90(
    m_block,
    producer_state_Q,
    producer_state_dO,
    pipeline_Q,
    pipeline_dO,
    load_K,
    load_V,
    load_Q,
    load_dO,
    load_LSE,
    load_dPsum,
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    Q_stage_eq_dO_stage: cutlass.Constexpr,
    load_kv: bool,
):
    """Load one Q/dO block, optionally loading K/V on first iteration.

    When ``load_kv`` is set, the Q (resp. dO) stage acquire also budgets the
    K (resp. V) transaction bytes and K/V are loaded on the same barrier.
    When ``Q_stage_eq_dO_stage`` is set at compile time, the dO/dPsum loads
    reuse the Q pipeline state instead of the dedicated dO state.
    """
    if load_kv:
        # Acquire with extra transaction bytes so the K load completes the
        # same barrier as the Q load.
        pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=tma_copy_bytes_K)
        load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
    else:
        pipeline_Q.producer_acquire(producer_state_Q)
    load_Q(m_block, producer_state=producer_state_Q)
    load_LSE(m_block, producer_state=producer_state_Q)

    # Shared-stage configs drive the dO pipeline with the Q state.
    producer_state_dO_cur = (
        producer_state_dO if const_expr(not Q_stage_eq_dO_stage) else producer_state_Q
    )
    if load_kv:
        pipeline_dO.producer_acquire(producer_state_dO_cur, extra_tx_count=tma_copy_bytes_V)
        load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
    else:
        pipeline_dO.producer_acquire(producer_state_dO_cur)
    load_dO(m_block, producer_state=producer_state_dO_cur)
    load_dPsum(m_block, producer_state=producer_state_dO_cur)

    # Both states advance unconditionally; in the shared-stage case the dO
    # state is kept in lockstep even though it was not used for the loads.
    producer_state_Q.advance()
    producer_state_dO.advance()
    return producer_state_Q, producer_state_dO
|
| 1184 |
+
|
| 1185 |
+
|
| 1186 |
+
@cute.jit
def produce_block_sparse_q_loads_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    producer_state_Q,
    producer_state_dO,
    pipeline_Q,
    pipeline_dO,
    load_K,
    load_V,
    load_Q,
    load_dO,
    load_LSE,
    load_dPsum,
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    Q_stage_eq_dO_stage: cutlass.Constexpr,
    subtile_factor: cutlass.Constexpr,
    m_block_max: int,
):
    """SM90 backward block sparse loading with separate partial/full loops.

    K/V are loaded with the first valid block. Iterates partial blocks first,
    then full blocks, matching consumer order. Blocks whose subtiled m_block
    falls at or beyond ``m_block_max`` are skipped entirely (no pipeline
    stage is consumed for them).

    Returns updated (producer_state_Q, producer_state_dO).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    # K/V piggyback on the first in-bounds block's barrier, wherever it falls.
    kv_loaded = False

    # Partial (masked) blocks first, expanded by the subtile factor.
    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset

        if m_block < m_block_max:
            producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
                m_block,
                producer_state_Q,
                producer_state_dO,
                pipeline_Q,
                pipeline_dO,
                load_K,
                load_V,
                load_Q,
                load_dO,
                load_LSE,
                load_dPsum,
                tma_copy_bytes_K,
                tma_copy_bytes_V,
                Q_stage_eq_dO_stage,
                load_kv=not kv_loaded,
            )
            kv_loaded = True

    # Then full blocks (when a separate full-block list exists).
    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset

            if m_block < m_block_max:
                producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
                    m_block,
                    producer_state_Q,
                    producer_state_dO,
                    pipeline_Q,
                    pipeline_dO,
                    load_K,
                    load_V,
                    load_Q,
                    load_dO,
                    load_LSE,
                    load_dPsum,
                    tma_copy_bytes_K,
                    tma_copy_bytes_V,
                    Q_stage_eq_dO_stage,
                    load_kv=not kv_loaded,
                )
                kv_loaded = True

    return producer_state_Q, producer_state_dO
|
| 1280 |
+
|
| 1281 |
+
|
| 1282 |
+
@cute.jit
def consume_block_sparse_mma_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    consumer_state_Q,
    consumer_state_dO,
    mma_one_m_block_fn,
    mask,
    mask_mod,
    is_causal: cutlass.Constexpr,
    is_local: cutlass.Constexpr,
    thr_mma_SdP,
    score_mod_fn=None,
    score_mod_bwd_fn=None,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
    aux_tensors=None,
    fastdiv_mods=(None, None),
):
    """SM90 backward block sparse MMA consumption with separate partial/full loops.

    Partial blocks are processed first (with mask_mod applied), then full blocks
    (without mask_mod). This ensures mask_mod is only applied where needed, and
    matches the order in which the producer filled the pipelines. The first
    in-bounds block starts a fresh dKV accumulation (``dKV_accumulate=False``);
    every subsequent one accumulates on top.

    Returns updated (consumer_state_Q, consumer_state_dO).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    dKV_accumulate = False

    # Mask applied to partial blocks: includes mask_mod.
    mask_fn_partial = partial(
        mask.apply_mask,
        batch_idx=batch_idx,
        head_idx=head_idx,
        n_block=n_block,
        thr_mma=thr_mma_SdP,
        mask_seqlen=True,
        mask_causal=is_causal,
        mask_local=is_local,
        mask_mod=mask_mod,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )

    # Mask applied to full blocks: same, but without mask_mod.
    mask_fn_full = partial(
        mask.apply_mask,
        batch_idx=batch_idx,
        head_idx=head_idx,
        n_block=n_block,
        thr_mma=thr_mma_SdP,
        mask_seqlen=True,
        mask_causal=is_causal,
        mask_local=is_local,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )

    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset

        # Skip out-of-range subtiles; the producer skipped them too.
        if m_block < m_block_max:
            consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
                m_block,
                consumer_state_Q,
                consumer_state_dO,
                mask_fn=mask_fn_partial,
                score_mod_fn=score_mod_fn,
                score_mod_bwd_fn=score_mod_bwd_fn,
                dKV_accumulate=dKV_accumulate,
            )
            dKV_accumulate = True

    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset

            if m_block < m_block_max:
                consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
                    m_block,
                    consumer_state_Q,
                    consumer_state_dO,
                    mask_fn=mask_fn_full,
                    score_mod_fn=score_mod_fn,
                    score_mod_bwd_fn=score_mod_bwd_fn,
                    dKV_accumulate=dKV_accumulate,
                )
                dKV_accumulate = True

    return consumer_state_Q, consumer_state_dO
|
| 1386 |
+
|
| 1387 |
+
|
| 1388 |
+
@cute.jit
def _store_one_dQaccum_sm90(
    m_block,
    sdQaccum: cute.Tensor,
    gdQaccum: cute.Tensor,
    num_mma_warp_groups: cutlass.Constexpr,
    num_threads_per_warp_group: cutlass.Constexpr,
    tma_copy_bytes_dQ,
):
    """Store dQaccum for a single m_block.

    Per MMA warp group: wait for the older cp.async.bulk groups to drain,
    arrive on the dQ-empty named barrier (presumably freeing the smem slice
    for the MMA warps to refill — confirm against the consumer side), then
    wait on the dQ-full barrier and have one elected thread issue a bulk
    reduce-add of the smem accumulator into global dQaccum.
    """
    for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
        # Oldest group first: allow (num_groups - 1 - i) outstanding groups.
        cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True)
        cute.arch.barrier_arrive(
            barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
            number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
        )
    for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
        cute.arch.barrier(
            barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
            number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
        )
        with cute.arch.elect_one():
            # Single thread issues the f32 bulk reduce-add smem -> gmem.
            copy_utils.cpasync_reduce_bulk_add_f32(
                sdQaccum[None, warp_group_idx].iterator,
                gdQaccum[None, warp_group_idx, m_block].iterator,
                tma_copy_bytes_dQ,
            )
        cute.arch.cp_async_bulk_commit_group()
|
| 1416 |
+
|
| 1417 |
+
|
| 1418 |
+
@cute.jit
def dQaccum_store_block_sparse_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    sdQaccum: cute.Tensor,
    gdQaccum: cute.Tensor,
    subtile_factor: cutlass.Constexpr,
    m_block_max: int,
    num_mma_warp_groups: cutlass.Constexpr,
    num_threads_per_warp_group: cutlass.Constexpr,
    tma_copy_bytes_dQ,
):
    """SM90 backward block sparse dQaccum store with separate partial/full loops.

    Iterates partial blocks first, then full blocks, matching producer/consumer
    order. Out-of-range subtiled m_blocks are skipped, mirroring the skip in
    the load/MMA paths so barrier participation stays balanced.
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    # Partial blocks, expanded by the subtile factor.
    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset

        if m_block < m_block_max:
            _store_one_dQaccum_sm90(
                m_block,
                sdQaccum,
                gdQaccum,
                num_mma_warp_groups,
                num_threads_per_warp_group,
                tma_copy_bytes_dQ,
            )

    # Then full blocks (when a separate full-block list exists).
    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset

            if m_block < m_block_max:
                _store_one_dQaccum_sm90(
                    m_block,
                    sdQaccum,
                    gdQaccum,
                    num_mma_warp_groups,
                    num_threads_per_warp_group,
                    tma_copy_bytes_dQ,
                )
|
build/torch-cuda/block_sparsity.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Block-sparsity utilities for FlexAttention
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Callable, NamedTuple, Tuple
|
| 6 |
+
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from .cute_dsl_utils import get_broadcast_dims, to_cute_tensor
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def ceildiv(a: int, b: int) -> int:
    """Return the ceiling of a / b using integer arithmetic (b assumed positive)."""
    biased = a + b - 1
    return biased // b
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BlockSparseTensors(NamedTuple):
    """Device-side block-sparsity metadata as cute tensors.

    Holds, per (batch, head, block) coordinate, the count and indices of
    partially-masked blocks, plus optional count/indices of fully-unmasked
    blocks. The two "full" fields are None when no separate full-block lists
    are supplied.
    """

    # Count of partially-masked blocks per (batch, head, block row).
    mask_block_cnt: cute.Tensor
    # Indices of those partially-masked blocks (extra trailing list dim).
    mask_block_idx: cute.Tensor
    # Optional count of fully-unmasked blocks (no mask_mod needed).
    full_block_cnt: cute.Tensor | None
    # Optional indices of fully-unmasked blocks.
    full_block_idx: cute.Tensor | None

    def __new_from_mlir_values__(self, values):
        # DSL reconstruction hook: rebuild the tuple from MLIR values.
        # When the full-block tensors were None they produce no values, so a
        # 2-element list is padded back out to the 4-field shape.
        if len(values) == 2:
            values = (*values, None, None)
        return BlockSparseTensors(*values)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class BlockSparseTensorsTorch(NamedTuple):
    """Host-side (torch) counterpart of BlockSparseTensors.

    Same count/index layout as BlockSparseTensors but holding torch tensors,
    plus the (m, n) block size the sparsity pattern was computed at.
    """

    # Count of partially-masked blocks per (batch, head, block row).
    mask_block_cnt: torch.Tensor
    # Indices of those partially-masked blocks.
    mask_block_idx: torch.Tensor
    # Optional count of fully-unmasked blocks.
    full_block_cnt: torch.Tensor | None = None
    # Optional indices of fully-unmasked blocks.
    full_block_idx: torch.Tensor | None = None
    # (m_block_size, n_block_size) the sparsity was computed with, if known.
    block_size: tuple[int, int] | None = None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _expand_sparsity_tensor(
|
| 38 |
+
tensor: torch.Tensor,
|
| 39 |
+
expected_shape: Tuple[int, ...],
|
| 40 |
+
tensor_name: str,
|
| 41 |
+
context: str | None,
|
| 42 |
+
hint: str | Callable[[], str] | None,
|
| 43 |
+
) -> torch.Tensor:
|
| 44 |
+
"""Check if we need to expand the tensor to expected shape, and do so if possible."""
|
| 45 |
+
needs_expand = tensor.shape != expected_shape
|
| 46 |
+
if not needs_expand:
|
| 47 |
+
return tensor
|
| 48 |
+
can_expand = all(map(lambda cur, tgt: cur == tgt or cur == 1, tensor.shape, expected_shape))
|
| 49 |
+
if not can_expand:
|
| 50 |
+
context_clause = f" ({context})" if context else ""
|
| 51 |
+
resolved_hint = hint() if callable(hint) else hint
|
| 52 |
+
hint_clause = f" Hint: {resolved_hint}" if resolved_hint else ""
|
| 53 |
+
raise ValueError(
|
| 54 |
+
f"{tensor_name}{context_clause} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}."
|
| 55 |
+
f"{hint_clause}"
|
| 56 |
+
)
|
| 57 |
+
return tensor.expand(*expected_shape)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _check_and_expand_block(
|
| 61 |
+
name: str,
|
| 62 |
+
cnt: torch.Tensor | None,
|
| 63 |
+
idx: torch.Tensor | None,
|
| 64 |
+
expected_count_shape: Tuple[int, int, int],
|
| 65 |
+
expected_index_shape: Tuple[int, int, int, int],
|
| 66 |
+
context: str | None,
|
| 67 |
+
hint: str | Callable[[], str] | None,
|
| 68 |
+
) -> Tuple[torch.Tensor | None, torch.Tensor | None]:
|
| 69 |
+
if (cnt is None) != (idx is None):
|
| 70 |
+
raise ValueError(
|
| 71 |
+
f"{name}_block_cnt and {name}_block_idx must both be provided or both be None"
|
| 72 |
+
)
|
| 73 |
+
if cnt is None or idx is None:
|
| 74 |
+
return None, None
|
| 75 |
+
if cnt.dtype != torch.int32 or idx.dtype != torch.int32:
|
| 76 |
+
raise ValueError(f"{name}_block tensors must have dtype torch.int32")
|
| 77 |
+
if cnt.device != idx.device:
|
| 78 |
+
raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device")
|
| 79 |
+
if not cnt.is_cuda or not idx.is_cuda:
|
| 80 |
+
raise ValueError(f"{name}_block tensors must live on CUDA")
|
| 81 |
+
expanded_cnt = _expand_sparsity_tensor(
|
| 82 |
+
cnt, expected_count_shape, f"{name}_block_cnt", context, hint
|
| 83 |
+
)
|
| 84 |
+
expanded_idx = _expand_sparsity_tensor(
|
| 85 |
+
idx, expected_index_shape, f"{name}_block_idx", context, hint
|
| 86 |
+
)
|
| 87 |
+
return expanded_cnt, expanded_idx
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_block_sparse_expected_shapes(
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    q_stage: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
    """Return (expected_count_shape, expected_index_shape) for block sparse normalization."""
    # Q rows are tiled in units of q_stage * m_block_size.
    num_m = ceildiv(seqlen_q, q_stage * m_block_size)
    num_n = ceildiv(seqlen_k, n_block_size)
    return (batch_size, num_head, num_m), (batch_size, num_head, num_m, num_n)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def infer_block_sparse_expected_shapes(
    tensors: BlockSparseTensorsTorch,
    *,
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    q_stage: int,
    context: str,
    sparse_block_size_q: int | None = None,
    sparse_block_size_kv: int | None = None,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int], int]:
    """Infer shapes and scaling for block-sparse tensors.

    Expectations:
    - mask_block_cnt is (B, H, M) and mask_block_idx is (B, H, M, N).
    - Batch/head dims may be 1 for broadcast, or match the requested sizes.
    - sparse_block_size_kv must match tile_n.
    - sparse_block_size_q must be a multiple of q_stage * tile_m.
    - If sparse_block_size_q is omitted and seqlen_q/num_m_blocks is ambiguous,
      the caller must provide block_size to disambiguate. TODO will make this required in a future PR.

    Returns:
        (expected_count_shape, expected_index_shape, q_subtile_factor) where
        q_subtile_factor = sparse_block_size_q // (q_stage * m_block_size).

    Raises:
        ValueError: on KV block-size mismatch, missing mask tensors, ambiguous
        or misaligned Q block size, or shape/dim mismatches.
    """
    # Base tile granularity of the kernel: Q rows per kernel tile and KV
    # columns per kernel tile.
    base_m_block = q_stage * m_block_size
    base_n_block = n_block_size
    if sparse_block_size_kv is None:
        sparse_block_size_kv = base_n_block
    # The KV sparse block size must equal the kernel tile_n exactly.
    if sparse_block_size_kv != base_n_block:
        raise ValueError(f"Block sparse tensors{context} require BLOCK_SIZE_KV={base_n_block}.")
    if tensors.mask_block_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")
    num_m_blocks = tensors.mask_block_idx.shape[2]

    if sparse_block_size_q is None:
        # Infer the Q block size from seqlen_q and the tensor's M dimension.
        # min/max bracket the block sizes consistent with num_m_blocks; when
        # they differ the inference is ambiguous and the caller must supply
        # block_size explicitly.
        min_block_size = ceildiv(seqlen_q, num_m_blocks)
        if num_m_blocks == 1:
            max_block_size = seqlen_q
        else:
            max_block_size = (seqlen_q - 1) // (num_m_blocks - 1)
        if max_block_size != min_block_size and base_m_block != 1:
            raise ValueError(
                f"Block sparse tensors{context} require explicit sparse_block_size[0] "
                f"to disambiguate block size for seqlen_q={seqlen_q} and num_m_blocks={num_m_blocks}."
            )
        sparse_block_size_q = min_block_size

    # One sparse Q block must cover a whole number of kernel Q tiles.
    if sparse_block_size_q % base_m_block != 0:
        raise ValueError(
            f"Block sparse tensors{context} have block size {sparse_block_size_q}, "
            f"which must be a multiple of {base_m_block}."
        )

    expected_m_blocks = ceildiv(seqlen_q, sparse_block_size_q)
    expected_n_blocks = ceildiv(seqlen_k, sparse_block_size_kv)
    # How many kernel Q tiles fit in one sparse Q block.
    q_subtile_factor = sparse_block_size_q // base_m_block
    expected_count_shape = (batch_size, num_head, expected_m_blocks)
    expected_index_shape = (batch_size, num_head, expected_m_blocks, expected_n_blocks)

    mask_block_cnt = tensors.mask_block_cnt
    mask_block_idx = tensors.mask_block_idx
    if mask_block_cnt is None or mask_block_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")
    if mask_block_cnt.ndim != 3 or mask_block_idx.ndim != 4:
        raise ValueError(
            f"Block sparse tensors{context} must have shapes (B, H, M) and (B, H, M, N)."
        )
    # Batch/head dims of the count tensor: exact match or broadcast (==1).
    for dim_name, cur, tgt in (
        ("batch", mask_block_cnt.shape[0], expected_count_shape[0]),
        ("head", mask_block_cnt.shape[1], expected_count_shape[1]),
    ):
        if cur != tgt and cur != 1:
            raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
    # Same check for the index tensor.
    for dim_name, cur, tgt in (
        ("batch", mask_block_idx.shape[0], expected_index_shape[0]),
        ("head", mask_block_idx.shape[1], expected_index_shape[1]),
    ):
        if cur != tgt and cur != 1:
            raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
    if mask_block_cnt.shape[2] != mask_block_idx.shape[2]:
        raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.")
    if mask_block_idx.shape[3] != expected_n_blocks:
        raise ValueError(
            f"Block sparse tensors{context} n-block dimension must be {expected_n_blocks}."
        )
    # Cross-check the (possibly inferred) Q block size against the tensor.
    if expected_m_blocks != num_m_blocks:
        raise ValueError(
            f"Block sparse tensors{context} m-block dimension {num_m_blocks} does not match "
            f"sparse_block_size_q={sparse_block_size_q}. "
            f"Set BlockSparseTensorsTorch.block_size to match the BlockMask BLOCK_SIZE."
        )
    return expected_count_shape, expected_index_shape, q_subtile_factor
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def get_block_sparse_expected_shapes_bwd(
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    subtile_factor: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
    """Return (expected_count_shape, expected_index_shape) for backward block sparse normalization.

    Backward uses Q-direction indexing (transposed from forward), where shapes
    are indexed by N-blocks first, then M-blocks. The sparse Q block size is
    subtile_factor * m_block_size.
    """
    n_blocks = ceildiv(seqlen_k, n_block_size)
    m_blocks = ceildiv(seqlen_q, subtile_factor * m_block_size)
    return (batch_size, num_head, n_blocks), (batch_size, num_head, n_blocks, m_blocks)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def normalize_block_sparse_tensors(
    tensors: BlockSparseTensorsTorch,
    *,
    expected_count_shape: Tuple[int, int, int],
    expected_index_shape: Tuple[int, int, int, int],
    context: str | None = None,
    hint: str | Callable[[], str] | None = None,
) -> BlockSparseTensorsTorch:
    """Validate and broadcast all block-sparse tensors to their expected shapes.

    The mask pair is mandatory. The full pair is optional but, when present,
    must live on the same device as the mask pair.
    """
    if tensors.mask_block_cnt is None or tensors.mask_block_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")

    mask_cnt, mask_idx = _check_and_expand_block(
        "mask",
        tensors.mask_block_cnt,
        tensors.mask_block_idx,
        expected_count_shape,
        expected_index_shape,
        context,
        hint,
    )
    # Defensive re-check (also narrows types for static analysis).
    if mask_cnt is None or mask_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")

    full_cnt, full_idx = _check_and_expand_block(
        "full",
        tensors.full_block_cnt,
        tensors.full_block_idx,
        expected_count_shape,
        expected_index_shape,
        context,
        hint,
    )
    if full_cnt is not None and full_cnt.device != mask_cnt.device:
        raise ValueError("All block sparse tensors must be on the same device")

    return BlockSparseTensorsTorch(
        mask_block_cnt=mask_cnt,
        mask_block_idx=mask_idx,
        full_block_cnt=full_cnt,
        full_block_idx=full_idx,
        block_size=tensors.block_size,
    )
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:
    """True when at least one count tensor (full or mask) is present."""
    return tensors.full_block_cnt is not None or tensors.mask_block_cnt is not None
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def get_block_sparse_broadcast_pattern(
    tensors: BlockSparseTensorsTorch,
) -> Tuple[Tuple[bool, ...], ...] | None:
    """Return broadcast patterns (stride==0 dims) for block sparse tensors.

    One pattern per tensor, in the fixed order (mask_cnt, mask_idx, full_cnt,
    full_idx), with None for absent tensors. Used in compile keys: CuTe's
    mark_layout_dynamic() keeps stride=0 static, so kernels must be recompiled
    when the broadcast pattern changes. Tensors are assumed to already be
    expanded/normalized. Returns None if block sparsity is not enabled.
    """
    if not is_block_sparsity_enabled(tensors):
        return None

    ordered = (
        tensors.mask_block_cnt,
        tensors.mask_block_idx,
        tensors.full_block_cnt,
        tensors.full_block_idx,
    )
    return tuple(None if t is None else get_broadcast_dims(t) for t in ordered)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def normalize_block_sparse_config(
    tensors: BlockSparseTensorsTorch,
    *,
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    block_size: tuple[int, int],
    q_stage: int,
) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]:
    """Validate, infer, and broadcast forward block-sparse tensors.

    Returns the normalized tensors, their broadcast pattern (for compile
    keys), and the Q subtile factor implied by the sparse Q block size.
    """
    tile_m, tile_n = block_size
    if tensors.block_size is not None:
        sparse_q, sparse_kv = tensors.block_size
    else:
        # Default: one sparse block per kernel tile.
        sparse_q, sparse_kv = q_stage * tile_m, tile_n
    if sparse_kv != tile_n:
        raise ValueError(
            f"Block sparsity requires sparse_block_size[1]={tile_n} to match tile_n."
        )

    count_shape, index_shape, q_subtile_factor = infer_block_sparse_expected_shapes(
        tensors,
        batch_size=batch_size,
        num_head=num_head,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        m_block_size=tile_m,
        n_block_size=tile_n,
        q_stage=q_stage,
        context="forward",
        sparse_block_size_q=sparse_q,
        sparse_block_size_kv=sparse_kv,
    )
    normalized = normalize_block_sparse_tensors(
        tensors,
        expected_count_shape=count_shape,
        expected_index_shape=index_shape,
    )
    return normalized, get_block_sparse_broadcast_pattern(normalized), q_subtile_factor
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def normalize_block_sparse_config_bwd(
    tensors: BlockSparseTensorsTorch,
    *,
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    block_size: tuple[int, int],
    subtile_factor: int,
) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None]:
    """Validate and broadcast backward (Q-direction) block-sparse tensors.

    Backward tensors are indexed by N-blocks first; the sparse Q block size
    must equal subtile_factor * tile_m and the sparse KV block size must
    equal tile_n.
    """
    tile_m, tile_n = block_size
    required_q = subtile_factor * tile_m
    if tensors.block_size is not None:
        sparse_q, sparse_kv = tensors.block_size
    else:
        sparse_q, sparse_kv = required_q, tile_n
    if sparse_q != required_q:
        raise ValueError(
            f"Block sparsity expects sparse_block_size_q={subtile_factor * tile_m} "
            f"for subtile_factor={subtile_factor}."
        )
    if sparse_kv != tile_n:
        raise ValueError(
            f"Block sparsity expects sparse_block_size[1]={tile_n} to match tile_n."
        )

    count_shape, index_shape = get_block_sparse_expected_shapes_bwd(
        batch_size,
        num_head,
        seqlen_q,
        seqlen_k,
        tile_m,
        tile_n,
        subtile_factor,
    )
    normalized = normalize_block_sparse_tensors(
        tensors,
        expected_count_shape=count_shape,
        expected_index_shape=index_shape,
        context="_flash_attn_bwd",
        hint=lambda: (
            f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, "
            f"and optionally full_q_cnt/full_q_idx). Regenerate the backward BlockMask with "
            f"BLOCK_SIZE=({subtile_factor * tile_m}, {tile_n})."
        ),
    )
    return normalized, get_block_sparse_broadcast_pattern(normalized)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def to_cute_block_sparse_tensors(
    tensors: BlockSparseTensorsTorch, enable_tvm_ffi: bool = True
) -> BlockSparseTensors | None:
    """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
    if not is_block_sparsity_enabled(tensors):
        return None

    def convert(t):
        # All sparsity metadata tensors are int32, hence 4-byte alignment;
        # the innermost dim is the contiguous (leading) one.
        return to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)

    # The mask pair is mandatory and converted unconditionally; the full pair
    # is optional and passed through as None when absent.
    return BlockSparseTensors(
        convert(tensors.mask_block_cnt),
        convert(tensors.mask_block_idx),
        None if tensors.full_block_cnt is None else convert(tensors.full_block_cnt),
        None if tensors.full_block_idx is None else convert(tensors.full_block_idx),
    )
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def fast_sampling(mask_mod):
    """Convenience decorator to mark mask_mod as safe for 5-point fast sampling"""
    setattr(mask_mod, "use_fast_sampling", True)
    return mask_mod
|
build/torch-cuda/cache_utils.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Manage Ahead-of-Time (AOT) compiled kernels
|
| 2 |
+
import fcntl
|
| 3 |
+
import hashlib
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import pickle
|
| 7 |
+
import sys
|
| 8 |
+
import tempfile
|
| 9 |
+
import time
|
| 10 |
+
from distutils.ccompiler import CCompiler, new_compiler
|
| 11 |
+
from functools import lru_cache
|
| 12 |
+
from getpass import getuser
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Hashable, TypeAlias
|
| 15 |
+
|
| 16 |
+
import cutlass
|
| 17 |
+
import cutlass.cute as cute
|
| 18 |
+
import tvm_ffi
|
| 19 |
+
from cutlass.cutlass_dsl import JitCompiledFunction
|
| 20 |
+
|
| 21 |
+
CompileKeyType: TypeAlias = tuple[Hashable, ...]
|
| 22 |
+
CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function
|
| 23 |
+
|
| 24 |
+
# Module-level logger; a StreamHandler is attached so warnings are visible
# even when the host application never configures logging.
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.WARNING)


# Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`
# (any other value, or unset, disables the persistent cache).
CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1"


# Customize cache dir via `FLASH_ATTENTION_CUTE_DSL_CACHE_DIR`, default is
# `/tmp/${USER}/flash_attention_cute_dsl_cache``
CUTE_DSL_CACHE_DIR: str | None = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_DIR", None)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_cache_path() -> Path:
    """Resolve (and create) the root directory for the persistent JIT cache.

    Honors FLASH_ATTENTION_CUTE_DSL_CACHE_DIR when set; otherwise defaults to
    <system tmpdir>/<user>/flash_attention_cute_dsl_cache.
    """
    root = (
        Path(CUTE_DSL_CACHE_DIR)
        if CUTE_DSL_CACHE_DIR is not None
        else Path(tempfile.gettempdir()) / getuser() / "flash_attention_cute_dsl_cache"
    )
    root.mkdir(parents=True, exist_ok=True)
    return root
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@lru_cache(maxsize=1)
def _compute_source_fingerprint() -> str:
    """
    Hash all CuTe Python sources plus runtime ABI stamps into a short fingerprint.

    The fingerprint changes whenever:
    - Any .py file under flash_attn/cute is added, removed, renamed, or modified.
    - The Python minor version changes (e.g. 3.13 -> 3.14).
    - The cutlass or tvm_ffi package version changes.

    Computed once per process and cached.
    """
    digest = hashlib.sha256()

    # ABI stamps: interpreter minor version plus dependency versions.
    for stamp in (
        f"py{sys.version_info.major}.{sys.version_info.minor}",
        f"cutlass={cutlass.__version__}",
        f"tvm_ffi={tvm_ffi.__version__}",
    ):
        digest.update(stamp.encode())

    package_root = Path(__file__).resolve().parent
    # Sorted traversal makes the hash independent of filesystem order. Mixing
    # in the relative path and the content length means renames and boundary
    # shifts between files also change the fingerprint.
    for source_file in sorted(package_root.rglob("*.py")):
        digest.update(source_file.relative_to(package_root).as_posix().encode())
        payload = source_file.read_bytes()
        digest.update(len(payload).to_bytes(8, "little"))
        digest.update(payload)

    return digest.hexdigest()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class FileLock:
    """Context manager for advisory file locks using fcntl.flock.

    Supports exclusive (write) and shared (read) locks.
    Always blocks with polling until the lock is acquired or timeout is reached.

    Usage:
        with FileLock(lock_path, exclusive=True, timeout=15, label="abc"):
            # do work under lock
    """

    def __init__(
        self,
        lock_path: Path,
        exclusive: bool,
        timeout: float = 15,
        label: str = "",
    ):
        """
        Args:
            lock_path: Path to the lock file on disk.
            exclusive: True for exclusive (write) lock, False for shared (read) lock.
            timeout: Max seconds to wait for lock acquisition before raising RuntimeError.
            label: Optional human-readable label for error messages.
        """
        self.lock_path: Path = lock_path
        self.exclusive: bool = exclusive
        self.timeout: float = timeout
        self.label: str = label
        # File descriptor of the open lock file; None whenever no lock is
        # held. (Fix: previously initialized to -1 while annotated `int` and
        # later reassigned None — a single None sentinel keeps __exit__ safe
        # even if __enter__ never succeeded.)
        self._fd: int | None = None

    @property
    def _lock_label(self) -> str:
        # Human-readable description of the lock for error messages.
        kind = "exclusive" if self.exclusive else "shared"
        return f"{kind} {self.label}" if self.label else kind

    def __enter__(self) -> "FileLock":
        open_flags = (
            os.O_WRONLY | os.O_CREAT if self.exclusive else os.O_RDONLY | os.O_CREAT
        )
        lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH

        fd = os.open(str(self.lock_path), open_flags)

        deadline = time.monotonic() + self.timeout
        while True:
            try:
                # Non-blocking attempt; we poll so the timeout is enforced.
                fcntl.flock(fd, lock_type | fcntl.LOCK_NB)
                break
            except OSError:
                # Fix: attempt at least once even when timeout <= 0 (the old
                # `while time.monotonic() < deadline` loop could skip the
                # acquisition entirely), and don't sleep after the deadline.
                if time.monotonic() >= deadline:
                    os.close(fd)
                    # `from None`: the EWOULDBLOCK from flock is expected and
                    # would only add noise to the traceback.
                    raise RuntimeError(
                        f"Timed out after {self.timeout}s waiting for "
                        f"{self._lock_label} lock: {self.lock_path}"
                    ) from None
                time.sleep(0.1)

        self._fd = fd
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if self._fd is not None:
            fcntl.flock(self._fd, fcntl.LOCK_UN)
            os.close(self._fd)
            self._fd = None
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class JITCache:
    """
    In-memory cache for compiled functions.
    """

    def __init__(self):
        # Maps an opaque hashable compile key to its compiled callable.
        self.cache: dict[CompileKeyType, CallableFunction] = {}

    def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
        # Store (or overwrite) the compiled function for this key.
        self.cache[key] = fn

    def __getitem__(self, key: CompileKeyType) -> CallableFunction:
        # Raises KeyError if the key has never been compiled/cached.
        return self.cache[key]

    def __contains__(self, key: CompileKeyType) -> bool:
        return key in self.cache

    def clear(self) -> None:
        """
        Clear in-memory cache of compiled functions
        """
        self.cache.clear()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class JITPersistentCache(JITCache):
    """
    In-memory cache for compiled functions, which is also backed by persistent storage.
    Use cutedsl ahead-of-time (AOT) compilation, only supporting enable_tvm_ffi=True
    """

    # Symbol name each compiled function is exported under in its .so file.
    EXPORT_FUNCTION_PREFIX = "func"
    # Max seconds to wait on a per-artifact file lock before giving up.
    LOCK_TIMEOUT_SECONDS = 15

    # Lazily created, shared across all instances (class attribute); used to
    # link exported .o files into .so files.
    _compiler: CCompiler | None = None

    def __init__(self, cache_path: Path):
        super().__init__()
        cache_path.mkdir(parents=True, exist_ok=True)
        # Directory holding <sha256>.so artifacts and <sha256>.lock files.
        self.cache_path: Path = cache_path

    def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
        # Store in memory first, then best-effort export to disk.
        JITCache.__setitem__(self, key, fn)
        self._try_export_to_storage(key, fn)

    def __getitem__(self, key: CompileKeyType) -> CallableFunction:
        # Use __contains__ to try populating in-memory cache with persistent storage
        self.__contains__(key)
        return JITCache.__getitem__(self, key)

    def __contains__(self, key: CompileKeyType) -> bool:
        # Checks in-memory cache first, then tries loading from storage.
        # When returning True, guarantees the in-memory cache is populated.
        if JITCache.__contains__(self, key):
            return True
        return self._try_load_from_storage(key)

    def _try_load_from_storage(self, key: CompileKeyType) -> bool:
        """
        Try to load a function from persistent storage into in-memory cache.
        Returns True if loaded successfully, False if not found on disk.
        Holds a shared lock during loading to prevent concurrent writes.
        """
        sha256_hex = self._key_to_hash(key)
        so_path = self.cache_path / f"{sha256_hex}.so"
        # Shared (read) lock: many processes may load concurrently, but a
        # writer holding the exclusive lock blocks us until export finishes.
        with FileLock(
            self._lock_path(sha256_hex),
            exclusive=False,
            timeout=self.LOCK_TIMEOUT_SECONDS,
            label=sha256_hex,
        ):
            if so_path.exists():
                logger.debug(
                    "Loading compiled function from disk: %s", so_path
                )
                m = cute.runtime.load_module(
                    str(so_path), enable_tvm_ffi=True
                )
                fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
                JITCache.__setitem__(self, key, fn)
                return True
            else:
                logger.debug(
                    "Cache miss on disk for key hash %s", sha256_hex
                )
                return False

    def _try_export_to_storage(
        self, key: CompileKeyType, fn: JitCompiledFunction
    ) -> None:
        """Export a compiled function to persistent storage under exclusive lock."""
        sha256_hex = self._key_to_hash(key)
        with FileLock(
            self._lock_path(sha256_hex),
            exclusive=True,
            timeout=self.LOCK_TIMEOUT_SECONDS,
            label=sha256_hex,
        ):
            so_path = self.cache_path / f"{sha256_hex}.so"
            if so_path.exists():
                # Another process already exported.
                logger.debug(
                    "Skipping export, already on disk: %s", so_path
                )
                return
            obj_path = self.cache_path / f"{sha256_hex}.o"
            logger.debug(
                "Exporting compiled function to disk: %s", so_path
            )
            fn.export_to_c(
                object_file_path=str(obj_path),
                function_name=self.EXPORT_FUNCTION_PREFIX,
            )
            # TODO: as of cutedsl 4.4.0, `export_to_c` only supports exporting
            # "relocatable" .o files. But tvm_ffi expects "shared library" .so
            # files. Link ourselves to workaround.
            # NOTE(review): this relies on distutils' new_compiler(), which is
            # removed from the stdlib in Python 3.12+ — confirm a setuptools
            # shim provides it in the supported environments.
            if JITPersistentCache._compiler is None:
                JITPersistentCache._compiler = new_compiler()
            JITPersistentCache._compiler.link_shared_object(
                [str(obj_path)], str(so_path)
            )
            # The intermediate object file is no longer needed once linked.
            obj_path.unlink()
            logger.debug(
                "Successfully exported compiled function to disk: %s", so_path
            )

    def _key_to_hash(self, key: CompileKeyType) -> str:
        # Stable content hash of the (picklable) compile key; names the
        # on-disk .so/.lock artifacts.
        return hashlib.sha256(pickle.dumps(key)).hexdigest()

    def _lock_path(self, sha256_hex: str) -> Path:
        # One lock file per artifact, next to the .so it guards.
        return self.cache_path / f"{sha256_hex}.lock"

    def clear(self) -> None:
        """
        Not only clear the in-memory cache. Also purge persistent compilation cache.
        """
        logger.debug(
            "Clearing persistent cache at %s", self.cache_path
        )
        super().clear()
        # NOTE(review): unlink() raises on subdirectories — assumes this
        # cache directory contains only flat files; confirm no nested dirs
        # can appear here.
        for child in self.cache_path.iterdir():
            child.unlink()
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def get_jit_cache(name: str | None = None) -> JITCache:
    """
    JIT cache factory.
    `name` is an optional identifier to create subdirectories to manage cache.

    When persistent caching is enabled, artifacts are namespaced under a
    source fingerprint directory so that code or dependency changes
    automatically invalidate stale entries.
    """
    if not CUTE_DSL_CACHE_ENABLED:
        logger.debug("Persistent cache disabled, using in-memory JIT cache")
        return JITCache()

    cache_dir = get_cache_path() / _compute_source_fingerprint()
    if name:
        cache_dir = cache_dir / name
    logger.debug(
        "Creating persistent JIT cache at %s", cache_dir
    )
    return JITPersistentCache(cache_dir)
|
build/torch-cuda/compute_block_sparsity.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
from typing import Callable, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
import torch
|
| 7 |
+
from cutlass import Boolean, Int8, Int32, const_expr
|
| 8 |
+
|
| 9 |
+
from .block_sparsity import (
|
| 10 |
+
BlockSparseTensors,
|
| 11 |
+
BlockSparseTensorsTorch,
|
| 12 |
+
to_cute_block_sparse_tensors,
|
| 13 |
+
)
|
| 14 |
+
from .utils import hash_callable, scalar_to_ssa, ssa_to_scalar
|
| 15 |
+
from .seqlen_info import SeqlenInfoQK
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class BlockSparsityKernel:
    """Block sparsity kernel for FlexAttention.

    This kernel computes `mask_mod` for every token of each block
    to determine if an n block is full, masked, or neither.

    Writes block counts and indices to a BlockSparseTensors object.

    When use_fast_sampling=True, uses 5-point sampling (4 corners + center)
    which is much faster but only suitable for masks where this is sufficient.

    TODO:
    - optimize mask_mod evaluation
    - varlen support
    - transposed tensors for bwd pass
    """

    def __init__(
        self,
        mask_mod: Callable,
        tile_mn: Tuple[int, int],
        compute_full_blocks: bool = True,
        use_aux_tensors: bool = False,
        use_fast_sampling: bool = False,
    ):
        # mask_mod(batch, head, q_idx, kv_idx, seqlen, aux) -> Boolean, evaluated per token.
        self.mask_mod = mask_mod
        # (tile_m, tile_n) block shape used to partition the score matrix.
        self.tile_mn = tile_mn
        self.compute_full_blocks = compute_full_blocks
        # NOTE(review): stored but never read in this class's visible code — confirm callers.
        self.use_aux_tensors = use_aux_tensors
        self.use_fast_sampling = use_fast_sampling

    @cute.jit
    def __call__(
        self,
        blocksparse_tensors: BlockSparseTensors,
        seqlen_q: Int32,
        seqlen_k: Int32,
        aux_tensors: Optional[list] = None,
    ):
        """Configure the launch geometry and dispatch the device kernel."""
        self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors

        if const_expr(self.compute_full_blocks):
            assert self.full_cnt is not None and self.full_idx is not None, (
                "full block tensors must be provided when computing full blocks"
            )

        batch_size, num_heads, num_m_blocks, num_n_blocks = self.mask_idx.shape
        # launch 1 CTA per m block
        grid = [num_m_blocks, num_heads, batch_size]

        if const_expr(self.use_fast_sampling):
            # Only 5 sample points per block, so a single partial warp suffices.
            num_threads = 5
            self.num_warps = 1
        else:
            # One thread per row of the tile.
            num_threads = self.tile_mn[0]
            self.num_warps = (num_threads + 32 - 1) // 32

        self.kernel(
            self.mask_cnt,
            self.mask_idx,
            self.full_cnt,
            self.full_idx,
            num_n_blocks,
            seqlen_q,
            seqlen_k,
            aux_tensors,
        ).launch(grid=grid, block=[num_threads, 1, 1])

    @cute.kernel
    def kernel(
        self,
        mask_cnt: cute.Tensor,
        mask_idx: cute.Tensor,
        full_cnt: cute.Tensor,
        full_idx: cute.Tensor,
        num_n_blocks: Int32,
        seqlen_q: Int32,
        seqlen_k: Int32,
        aux_tensors: Optional[list] = None,
    ):
        """Device kernel: one CTA classifies every n block of its m block.

        Each n block is classified as partial (some tokens masked, some not),
        full (no tokens masked), or skipped (all masked); indices of partial
        and full blocks are appended to mask_idx/full_idx and counts written
        to mask_cnt/full_cnt by thread 0.
        """
        tidx, _, _ = cute.arch.thread_idx()
        warp_idx = cute.arch.warp_idx()
        # NOTE(review): this lane_id is shadowed by `lane_id = tidx % 32` in the
        # full-sampling path below and is otherwise unused.
        lane_id = cute.arch.lane_idx()
        m_block, head_idx, batch_idx = cute.arch.block_idx()

        # Shorthand to lift scalars into SSA values expected by mask_mod.
        ssa = partial(scalar_to_ssa, dtype=Int32)

        seqlen = SeqlenInfoQK.create(
            batch_idx,
            seqlen_q,
            seqlen_k,
            mCuSeqlensQ=None,
            mCuSeqlensK=None,
            mSeqUsedQ=None,
            mSeqUsedK=None,
        )

        @cute.struct
        class SharedStorage:
            # Two Int8 flags (has_unmasked, has_masked) per warp for cross-warp reduction.
            reduction_buffer_smem: cute.struct.Align[
                cute.struct.MemRange[cutlass.Int8, 2 * self.num_warps], 1024
            ]

        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(SharedStorage, 16)

        reduction_buffer = storage.reduction_buffer_smem.get_tensor(
            cute.make_layout((self.num_warps, 2))
        )

        # Running counts of partial ("mask") and full blocks found so far.
        num_mask_blocks = Int32(0)
        num_full_blocks = Int32(0)

        for n_block in cutlass.range(num_n_blocks, unroll_full=True):
            m_base = m_block * self.tile_mn[0]
            n_base = n_block * self.tile_mn[1]

            if const_expr(self.use_fast_sampling):
                # Fast path: 5-point sampling (4 corners + center)
                # Clamps OOB indices to nearest in bounds.
                thread_result = Boolean(False)
                thread_is_valid = Boolean(False)
                q_idx = Int32(0)
                kv_idx = Int32(0)

                if tidx == 0:
                    # Top-left corner (0, 0); always in bounds
                    q_idx = m_base
                    kv_idx = n_base
                elif tidx == 1:
                    # Top-right corner
                    q_idx = m_base
                    kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1)
                elif tidx == 2:
                    # Bottom-left corner
                    q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1)
                    kv_idx = n_base
                elif tidx == 3:
                    # Bottom-right corner
                    q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1)
                    kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1)
                elif tidx == 4:
                    # Center point
                    q_idx = m_base + (cutlass.min(seqlen_q - m_base, self.tile_mn[0])) // 2
                    kv_idx = n_base + (cutlass.min(seqlen_k - n_base, self.tile_mn[1])) // 2
                else:
                    thread_is_valid = Boolean(False)

                # Check bounds and determine if this thread has a valid index pair
                if tidx < 5 and q_idx < seqlen_q and kv_idx < seqlen_k:
                    thread_is_valid = Boolean(True)
                    q_idx_ssa = ssa(q_idx)
                    kv_idx_ssa = ssa(kv_idx)
                    thread_result = ssa_to_scalar(
                        self.mask_mod(
                            ssa(batch_idx),
                            ssa(head_idx),
                            q_idx_ssa,
                            kv_idx_ssa,
                            seqlen,
                            aux_tensors,
                        )
                    )
                else:
                    thread_is_valid = Boolean(False)

                # Use vote_any_sync to see if any valid thread found unmasked or masked
                # Only count results from threads that checked valid indices
                has_unmasked = cute.arch.vote_any_sync(thread_result & thread_is_valid)
                has_masked = cute.arch.vote_any_sync((Boolean(not thread_result)) & thread_is_valid)

            else:
                # Full path: check all elements in the block
                # Track if this thread's row has any masked or unmasked elements
                thread_has_unmasked = Boolean(False)
                thread_has_masked = Boolean(False)
                thread_is_valid = Boolean(False)

                # Each thread handles 1 row
                q_idx = m_base + tidx
                kv_idx = Int32(0)
                if tidx < self.tile_mn[0] and q_idx < seqlen_q:
                    thread_is_valid = Boolean(True)
                    q_idx_ssa = ssa(q_idx)

                    # Loop over all columns in this row
                    for c in cutlass.range(self.tile_mn[1], unroll_full=True):
                        kv_idx = n_base + c
                        kv_idx_ssa = ssa(kv_idx)

                        # Only check elements within valid sequence bounds
                        if kv_idx < seqlen_k:
                            # Direct scalar call
                            mask_val = ssa_to_scalar(
                                self.mask_mod(
                                    ssa(batch_idx),
                                    ssa(head_idx),
                                    q_idx_ssa,
                                    kv_idx_ssa,
                                    seqlen,
                                    aux_tensors,
                                )
                            )

                            # Update tracking flags
                            if mask_val:
                                thread_has_unmasked = Boolean(True)
                            else:
                                thread_has_masked = Boolean(True)

                # Block-level reduction to combine results across all threads
                # Only count votes from threads that checked valid indices
                warp_has_unmasked_mask = cute.arch.vote_any_sync(
                    thread_has_unmasked & thread_is_valid
                )
                warp_has_masked_mask = cute.arch.vote_any_sync(thread_has_masked & thread_is_valid)

                # lane 0 writes the ballot mask to shared memory
                lane_id = tidx % 32
                if lane_id == 0:
                    # Store as Int8
                    reduction_buffer[warp_idx, 0] = Int8(1) if warp_has_unmasked_mask else Int8(0)
                    reduction_buffer[warp_idx, 1] = Int8(1) if warp_has_masked_mask else Int8(0)

                cute.arch.sync_threads()

                # Thread 0 ORs all warp results together
                has_unmasked = Boolean(False)
                has_masked = Boolean(False)
                if tidx == 0:
                    for w in cutlass.range(self.num_warps):
                        if reduction_buffer[w, 0]:
                            has_unmasked = Boolean(True)
                        if reduction_buffer[w, 1]:
                            has_masked = Boolean(True)

            # Only thread 0 updates the output arrays (common to both paths)
            if tidx == 0:
                # Block classification based on what we found:
                # - If has_masked and has_unmasked: partial block (needs masking)
                # - If only has_unmasked: full block (no masking needed)
                # - If only has_masked: skip this block entirely
                is_partial = Boolean(has_masked and has_unmasked)
                is_full = Boolean(has_unmasked and (not has_masked))

                if is_partial:
                    mask_idx[batch_idx, head_idx, m_block, num_mask_blocks] = n_block
                    num_mask_blocks += 1
                elif is_full and const_expr(self.compute_full_blocks):
                    full_idx[batch_idx, head_idx, m_block, num_full_blocks] = n_block
                    num_full_blocks += 1

        # Only thread 0 writes back the counts
        if tidx == 0:
            mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks
            if const_expr(self.compute_full_blocks):
                full_cnt[batch_idx, head_idx, m_block] = num_full_blocks
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def compute_block_sparsity(
    tile_m: int,
    tile_n: int,
    batch_size: int,
    num_heads: int,
    seqlen_q: int,
    seqlen_k: int,
    mask_mod: Callable,
    aux_tensors: Optional[list],  # list[cute.Tensor]
    device,  # presumably a torch.device or device string — verify against callers
    compute_full_blocks: bool = True,
    use_fast_sampling: bool = False,
) -> Tuple[BlockSparseTensors, BlockSparseTensorsTorch]:
    """
    Computes block sparsity for a given `mask_mod`.

    Args:
        tile_m: The tile size for the m dimension.
        tile_n: The tile size for the n dimension.
        batch_size: The batch size.
        num_heads: The number of heads.
        seqlen_q: The sequence length for the query.
        seqlen_k: The sequence length for the key.
        mask_mod: The `mask_mod` callable to use.
        aux_tensors: A list of auxiliary tensors.
        device: The device to use.
        compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed.
        use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient.

    Returns:
        A tuple of `BlockSparseTensors` and `BlockSparseTensorsTorch`.
    """
    # Check if mask_mod is marked as suitable for 5-point fast sampling
    use_fast_sampling = getattr(mask_mod, "use_fast_sampling", use_fast_sampling)

    # Ceil-divide sequence lengths into block counts.
    num_m_blocks = (seqlen_q + tile_m - 1) // tile_m
    num_n_blocks = (seqlen_k + tile_n - 1) // tile_n

    # Output buffers the kernel fills in: per-(batch, head, m-block) counts and
    # per-(batch, head, m-block) lists of n-block indices.
    mask_block_cnt = torch.zeros(
        (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32
    )
    mask_block_idx = torch.zeros(
        (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32
    )
    full_block_cnt = (
        torch.zeros((batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32)
        if compute_full_blocks
        else None
    )
    full_block_idx = (
        torch.zeros(
            (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32
        )
        if compute_full_blocks
        else None
    )

    blocksparse_tensors_torch = BlockSparseTensorsTorch(
        mask_block_cnt=mask_block_cnt,
        mask_block_idx=mask_block_idx,
        full_block_cnt=full_block_cnt,
        full_block_idx=full_block_idx,
        block_size=(tile_m, tile_n),
    )

    mask_mod_hash = hash_callable(mask_mod)
    blocksparse_tensors = to_cute_block_sparse_tensors(
        blocksparse_tensors_torch, enable_tvm_ffi=True
    )

    # Compilation key: everything that is baked into the generated kernel.
    # NOTE(review): the cache is unbounded and keyed per mask_mod hash — it can
    # grow with every distinct mask/tile configuration seen by the process.
    compile_key = (
        tile_m,
        tile_n,
        mask_mod_hash,
        compute_full_blocks,
        aux_tensors is not None,
        use_fast_sampling,
    )
    if compile_key not in compute_block_sparsity.compile_cache:
        kernel = BlockSparsityKernel(
            mask_mod,
            tile_mn=(tile_m, tile_n),
            compute_full_blocks=compute_full_blocks,
            use_aux_tensors=aux_tensors is not None,
            use_fast_sampling=use_fast_sampling,
        )

        compute_block_sparsity.compile_cache[compile_key] = cute.compile(
            kernel, blocksparse_tensors, seqlen_q, seqlen_k, aux_tensors, options="--enable-tvm-ffi"
        )

    # Launch: the first four fields of BlockSparseTensorsTorch are the tensors
    # the compiled kernel consumes (block_size stays host-side).
    compute_block_sparsity.compile_cache[compile_key](
        blocksparse_tensors_torch[:4],
        seqlen_q,
        seqlen_k,
        aux_tensors,
    )

    return blocksparse_tensors, blocksparse_tensors_torch


# Process-wide cache of compiled kernels, attached as a function attribute.
compute_block_sparsity.compile_cache = {}
|
build/torch-cuda/copy_utils.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional, Type, Callable
|
| 5 |
+
|
| 6 |
+
import cutlass
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
from cutlass import Float32, Int32, const_expr
|
| 9 |
+
from cutlass.cute.nvgpu import cpasync
|
| 10 |
+
import cutlass.utils.blackwell_helpers as sm100_utils
|
| 11 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 12 |
+
from cutlass._mlir.dialects import llvm
|
| 13 |
+
import cutlass.pipeline
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dsl_user_op
def cvt_copy(
    atom: cute.CopyAtom,
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy `src` to `dst`, converting the element type in registers first if they differ."""
    # Source must live in registers so a converted fragment can be materialized.
    assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
    needs_conversion = const_expr(src.element_type != dst.element_type)
    if needs_conversion:
        converted = cute.make_fragment_like(src, dst.element_type, loc=loc, ip=ip)
        converted.store(src.load().to(dst.element_type))
        src = converted
    cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dsl_user_op
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    """Load `src` into a freshly allocated register fragment and return it."""
    regs = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)
    cute.autovec_copy(src, regs, loc=loc, ip=ip)
    return regs
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dsl_user_op
def get_copy_atom(
    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
) -> cute.CopyAtom:
    """Build a copy atom for `dtype`, capped at 128 bits per copy instruction."""
    if is_async:
        op = cpasync.CopyG2SOp()
    else:
        op = cute.nvgpu.CopyUniversalOp()
    bits = const_expr(min(128, num_copy_elems * dtype.width))
    return cute.make_copy_atom(op, dtype, num_bits_per_copy=bits)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dsl_user_op
def make_tmem_copy(
    tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None
) -> cute.TiledCopy:
    """Wrap a TMEM copy atom into a tiled copy spanning `num_wg` warpgroups.

    Only atoms with 32 data-path lanes and 32-bit transfers are supported
    (asserted below). The TV layout replicates the atom `num_rep` times and
    interleaves the warpgroups along the value mode — TODO confirm the
    intended thread/value decomposition against the SM100 helper docs.
    """
    num_dp, num_bits, num_rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom)
    assert num_dp == 32
    assert num_bits == 32
    # One (128 * num_rep * num_wg / 32, 32) tile with the 32-wide mode contiguous.
    tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), stride=(32, 1)),)
    layout_tv = cute.make_layout(
        ((32, 4, num_wg), (num_rep, 32)), stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg))
    )
    return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dsl_user_op
def copy(
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    num_copy_elems: int = 1,
    is_async: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy `src` to `dst` with an atom sized for `num_copy_elems` elements per instruction."""
    atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
    cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def tiled_copy_1d(
    dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
) -> cute.TiledCopy:
    """Make a 1-D tiled copy: each of `num_threads` threads moves `num_copy_elems` elements."""
    if is_async:
        op = cpasync.CopyG2SOp()
    else:
        op = cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=num_copy_elems * dtype.width)
    thread_layout = cute.make_layout(num_threads)
    value_layout = cute.make_layout(num_copy_elems)
    return cute.make_tiled_copy_tv(atom, thread_layout, value_layout)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def tiled_copy_2d(
    dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
) -> cute.TiledCopy:
    """Make a 2-D tiled copy whose vector width is the widest <=128-bit divisor of a row."""
    # Widest per-thread vector (in elements) that both divides the row and fits 128 bits.
    elems_per_copy = math.gcd(major_mode_size, 128 // dtype.width)
    bits_per_copy = elems_per_copy * dtype.width
    threads_per_row = major_mode_size // elems_per_copy
    assert num_threads % threads_per_row == 0
    op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=bits_per_copy)
    # Row-major over threads-per-row so consecutive threads cover one row.
    thread_layout = cute.make_ordered_layout(
        (num_threads // threads_per_row, threads_per_row),
        order=(1, 0),
    )
    value_layout = cute.make_layout((1, elems_per_copy))
    return cute.make_tiled_copy_tv(atom, thread_layout, value_layout)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@dsl_user_op
def atomic_add_fp32x4(
    a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None
) -> None:
    """Issue a vectorized `red.global.add.v4.f32` of (a, b, c, d) at `gmem_ptr`.

    Fire-and-forget reduction: no value is returned and no result register is
    written. The commented-out variants below experiment with L2 cache hints.
    """
    gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
    # cache_hint = cutlass.Int64(0x12F0000000000000)
    llvm.inline_asm(
        None,
        [
            gmem_ptr_i64,
            Float32(a).ir_value(loc=loc, ip=ip),
            Float32(b).ir_value(loc=loc, ip=ip),
            Float32(c).ir_value(loc=loc, ip=ip),
            Float32(d).ir_value(loc=loc, ip=ip),
        ],
        # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()],
        "{\n\t"
        # ".reg .b128 abcd;\n\t"
        # "mov.b128 abcd, {$1, $2, $3, $4};\n\t"
        ".reg .v4 .f32 abcd;\n\t"
        # "mov.b128 abcd, {$1, $2, $3, $4};\n\t"
        "mov.f32 abcd.x, $1;\n\t"
        "mov.f32 abcd.y, $2;\n\t"
        "mov.f32 abcd.z, $3;\n\t"
        "mov.f32 abcd.w, $4;\n\t"
        "red.global.add.v4.f32 [$0], abcd;\n\t"
        # "red.global.add.L2::cache_hint.v4.f32 [$0], abcd, 0x14F0000000000000;\n\t"
        "}\n",
        # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;",
        # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;",
        "l,f,f,f,f",
        # "l,f,l",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@dsl_user_op
def set_block_rank(
    smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
) -> Int32:
    """Map the given smem pointer to the address at another CTA rank in the cluster.

    Uses `mapa.shared::cluster.u32` to translate the local shared-memory
    address into the peer CTA's shared-memory window; the mapped address is
    returned as a raw 32-bit integer.
    """
    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    return Int32(
        llvm.inline_asm(
            T.i32(),
            [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
            "mapa.shared::cluster.u32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@dsl_user_op
def store_shared_remote_fp32x4(
    a: Float32,
    b: Float32,
    c: Float32,
    d: Float32,
    smem_ptr: cute.Pointer,
    mbar_ptr: cute.Pointer,
    peer_cta_rank_in_cluster: Int32,
    *,
    loc=None,
    ip=None,
) -> None:
    """Store four f32 values into a peer CTA's shared memory via `st.async`.

    Both `smem_ptr` and `mbar_ptr` are local shared-memory pointers; they are
    mapped to the peer CTA with `set_block_rank` first. Completion is tracked
    on the peer's mbarrier (`mbarrier::complete_tx::bytes`).
    """
    remote_smem_ptr_i32 = set_block_rank(
        smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    remote_mbar_ptr_i32 = set_block_rank(
        mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    llvm.inline_asm(
        None,
        [
            remote_smem_ptr_i32,
            remote_mbar_ptr_i32,
            Float32(a).ir_value(loc=loc, ip=ip),
            Float32(b).ir_value(loc=loc, ip=ip),
            Float32(c).ir_value(loc=loc, ip=ip),
            Float32(d).ir_value(loc=loc, ip=ip),
        ],
        "{\n\t"
        ".reg .v4 .f32 abcd;\n\t"
        "mov.f32 abcd.x, $2;\n\t"
        "mov.f32 abcd.y, $3;\n\t"
        "mov.f32 abcd.z, $4;\n\t"
        "mov.f32 abcd.w, $5;\n\t"
        "st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\n\t"
        "}\n",
        "r,r,f,f,f,f",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@dsl_user_op
def cpasync_bulk_s2cluster(
    smem_src_ptr: cute.Pointer,
    smem_dst_ptr: cute.Pointer,
    mbar_ptr: cute.Pointer,
    size: int | Int32,
    peer_cta_rank_in_cluster: Int32,
    *,
    loc=None,
    ip=None,
):
    """Bulk-copy `size` bytes of local shared memory to a peer CTA's shared memory.

    `smem_dst_ptr` and `mbar_ptr` are local shared-memory pointers mapped to
    the peer CTA via `set_block_rank`; completion is signaled on the peer's
    mbarrier (`mbarrier::complete_tx::bytes`).
    """
    smem_src_ptr_i32 = smem_src_ptr.toint(loc=loc, ip=ip).ir_value()
    smem_dst_ptr_i32 = set_block_rank(
        smem_dst_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    mbar_ptr_i32 = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [
            smem_dst_ptr_i32,
            smem_src_ptr_i32,
            mbar_ptr_i32,
            Int32(size).ir_value(loc=loc, ip=ip),
        ],
        "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [$0], [$1], $3, [$2];",
        "r,r,r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@dsl_user_op
def cpasync_bulk_g2s(
    gmem_ptr: cute.Pointer,
    smem_ptr: cute.Pointer,
    tma_bar_ptr: cute.Pointer,
    size: int | Int32,
    *,
    loc=None,
    ip=None,
):
    """Bulk-copy `size` bytes from global memory to this CTA's shared memory.

    Completion is signaled on the mbarrier at `tma_bar_ptr`
    (`mbarrier::complete_tx::bytes`), so the consumer must wait on that
    barrier before reading the destination.
    """
    gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    mbar_ptr_i32 = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [gmem_ptr_i64, smem_ptr_i32, mbar_ptr_i32, Int32(size).ir_value()],
        "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];",
        "l,r,r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
@dsl_user_op
def cpasync_reduce_bulk_add_f32(
    smem_ptr: cute.Pointer,
    gmem_ptr: cute.Pointer,
    store_bytes: int | Int32,
    *,
    loc=None,
    ip=None,
):
    """Bulk f32 add-reduce of `store_bytes` bytes from shared memory into global memory.

    Issued into the current bulk-group (`cp.reduce.async.bulk...bulk_group`);
    the caller is responsible for committing/waiting on the bulk group.
    The commented-out variant experiments with an L2 cache hint.
    """
    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    # cache_hint = cutlass.Int64(0x14F0000000000000)  # EVICT_LAST
    llvm.inline_asm(
        None,
        [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()],
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
        "l,r,r",
        # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()],
        # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;",
        # "l,r,r,l",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def cpasync_bulk_get_copy_fn(
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    """Return a closure that issues a bulk `cp.async` G2S copy between the two tensors.

    When `single_stage` is False the last mode of each tensor is treated as a
    stage index and the returned closure takes (src_idx, dst_idx); otherwise
    the whole tensors are copied and the closure takes no indices. `kwargs`
    (e.g. the mbarrier pointer) are captured and forwarded to `cpasync_bulk_g2s`.
    """
    # src_is_smem = const_expr(
    #     isinstance(src_tensor.iterator, cute.Pointer)
    #     and src_tensor.memspace == cute.AddressSpace.smem
    # )
    group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0))
    group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0))
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    src = cute.group_modes(src_tensor, 0, group_rank_src)
    dst = cute.group_modes(dst_tensor, 0, group_rank_dst)

    def copy_bulk(src_idx, dst_idx, **new_kwargs):
        # Bytes per stage: product of all non-stage modes times element width.
        size = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8)
        cpasync_bulk_g2s(
            src[None, src_idx].iterator,
            dst[None, dst_idx].iterator,
            size=size,
            **new_kwargs,
            **kwargs,
        )

    def copy_bulk_single_stage(**new_kwargs):
        size = const_expr(cute.size(src.shape) * src.element_type.width // 8)
        cpasync_bulk_g2s(src.iterator, dst.iterator, size=size, **new_kwargs, **kwargs)

    return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def tma_get_copy_fn(
    atom: cute.CopyAtom,
    cta_coord: cute.Coord,
    cta_layout: cute.Layout,
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    filter_zeros: bool = False,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    """Partition src/dst for TMA and build a copy closure.

    The copy direction is inferred: whichever tensor's iterator lives in
    shared memory is the smem side of ``tma_partition``.  Returns
    ``(copy_fn, s, g)`` where ``s``/``g`` are the TMA-partitioned smem/gmem
    views.  NOTE(review): the ``-> Callable`` annotation undersells the
    actual 3-tuple return.
    """
    src_is_smem = const_expr(
        isinstance(src_tensor.iterator, cute.Pointer)
        and src_tensor.memspace == cute.AddressSpace.smem
    )
    smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
    # Group all but the trailing stage mode (kept when single_stage) before partitioning.
    group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))
    group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0))
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    s, g = cpasync.tma_partition(
        atom,
        cta_coord,
        cta_layout,
        cute.group_modes(smem_tensor, 0, group_rank_smem),
        cute.group_modes(gmem_tensor, 0, group_rank_gmem),
    )
    if const_expr(filter_zeros):
        # Drop zero-stride/zero-size modes from both partitioned views.
        s = cute.filter_zeros(s)
        g = cute.filter_zeros(g)
    # Restore the caller's src/dst orientation after the smem/gmem partition.
    src, dst = (s, g) if src_is_smem else (g, s)

    def copy_tma(src_idx, dst_idx, **new_kwargs):
        cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)

    def copy_tma_single_stage(**new_kwargs):
        cute.copy(atom, src, dst, **new_kwargs, **kwargs)

    return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
    """Bind a TMA copy closure to a producer pipeline.

    The wrapper derives the destination stage index and the TMA barrier from
    the producer state, so callers only supply the source index.
    """

    def _pipelined_copy(src_idx, producer_state: cutlass.pipeline.PipelineState, **extra_kwargs):
        barrier = pipeline.producer_get_barrier(producer_state)
        copy(
            src_idx=src_idx,
            dst_idx=producer_state.index,
            tma_bar_ptr=barrier,
            **extra_kwargs,
        )

    return _pipelined_copy
|
build/torch-cuda/cute_dsl_ptxas.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
System ptxas replacement for CUTLASS DSL.
|
| 3 |
+
Environment variables:
|
| 4 |
+
CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
|
| 5 |
+
CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import re
|
| 11 |
+
import ctypes
|
| 12 |
+
import subprocess
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import cutlass
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Path to the system ptxas binary to use instead of the DSL's embedded one;
# None means the hook has nothing to run (patch() asserts it is set).
CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None)
# Diagnostic output toggle: CUTE_DSL_PTXAS_VERBOSE=1 enables _log() messages.
VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1"

# The unpatched loader, saved by patch(); used as the fallback path whenever
# PTX lookup, compilation, or cubin loading fails.
_original_load_cuda_library = None
_user_wanted_ptx = False  # True if user originally set CUTE_DSL_KEEP_PTX=1
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _log(msg):
    """Write a diagnostic line to stderr when CUTE_DSL_PTXAS_VERBOSE=1."""
    if not VERBOSE:
        return
    print(f"[ptxas] {msg}", file=sys.stderr)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _get_ptx(compiled_func) -> tuple[str, Path] | None:
|
| 31 |
+
"""Find and read PTX file, stripping null bytes."""
|
| 32 |
+
func_name = getattr(compiled_func, "function_name", None)
|
| 33 |
+
if not func_name:
|
| 34 |
+
return None
|
| 35 |
+
|
| 36 |
+
dump_dir = os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd())
|
| 37 |
+
for ptx_path in Path(dump_dir).glob(f"*{func_name}*.ptx"):
|
| 38 |
+
content = ptx_path.read_text().rstrip("\x00")
|
| 39 |
+
if ".entry " in content and content.rstrip().endswith("}"):
|
| 40 |
+
_log(f"Found PTX: {ptx_path}")
|
| 41 |
+
return content, ptx_path
|
| 42 |
+
return None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _compile_ptx(ptx_path: Path, ptx_content: str) -> bytes:
    """Compile PTX to cubin using system ptxas.

    *ptx_content* is the (possibly NUL-stripped) text of *ptx_path*; the file
    is rewritten first if the two differ so ptxas sees the cleaned text.
    Raises RuntimeError when ptxas exits non-zero; the temporary cubin file is
    always removed.
    """
    # Extract arch from PTX (e.g. "sm_90a"); fall back to sm_90a if absent.
    match = re.search(r"\.target\s+(sm_\d+[a-z]?)", ptx_content)
    arch = match.group(1) if match else "sm_90a"

    # Write stripped content back if needed
    if ptx_path.read_text() != ptx_content:
        ptx_path.write_text(ptx_content)

    # Compile via the user-supplied ptxas; output goes to a temp file so a
    # partial write never masquerades as a finished .cubin.
    cubin_tmp = ptx_path.with_suffix(".cubin.tmp")
    try:
        assert CUTE_DSL_PTXAS_PATH is not None
        result = subprocess.run(
            [CUTE_DSL_PTXAS_PATH, f"-arch={arch}", "-O3", "-o", str(cubin_tmp), str(ptx_path)],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            raise RuntimeError(f"ptxas failed: {result.stderr}")

        cubin_data = cubin_tmp.read_bytes()
        _log(f"Compiled {ptx_path.name} -> {len(cubin_data)} bytes ({arch})")

        # Save cubin if CUTE_DSL_KEEP_CUBIN is set
        if os.environ.get("CUTE_DSL_KEEP_CUBIN", "0") == "1":
            cubin_out = ptx_path.with_suffix(".cubin")
            cubin_out.write_bytes(cubin_data)
            _log(f"Saved: {cubin_out}")

        return cubin_data
    finally:
        cubin_tmp.unlink(missing_ok=True)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _patched_load_cuda_library(self):
    """Replacement for _load_cuda_library that uses system ptxas.

    Finds the dumped PTX for this compiled function, recompiles it with the
    configured ptxas, loads the resulting cubin, and registers it on every
    device.  Any failure at any step falls back to the original loader.
    """

    result = _get_ptx(self)
    if not result:
        _log("PTX not found, falling back to embedded ptxas")
        return _original_load_cuda_library(self)

    ptx_content, ptx_path = result

    try:
        cubin = _compile_ptx(ptx_path, ptx_content)
    except Exception as e:
        _log(f"Compilation failed ({e}), falling back to embedded ptxas")
        return _original_load_cuda_library(self)

    # Load cubin
    import cuda.bindings.runtime as cuda_runtime

    err, library = cuda_runtime.cudaLibraryLoadData(cubin, None, None, 0, None, None, 0)
    if err != cuda_runtime.cudaError_t.cudaSuccess:
        _log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas")
        return _original_load_cuda_library(self)

    # Register kernels on all devices
    _, cuda_load_to_device = self._get_cuda_init_and_load()
    # The loader callback takes a (void* lib, int32 device, int32* err) triple
    # marshalled as an array of three pointers; the same buffers are reused per
    # device with only dev_id rewritten each iteration.
    lib_ptr = ctypes.c_void_p(int(library))
    dev_id = ctypes.c_int32(0)
    err_val = ctypes.c_int32(0)
    args = (ctypes.c_void_p * 3)(
        ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p),
        ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),
        ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),
    )

    for dev in range(self.num_devices):
        dev_id.value = dev
        cuda_load_to_device(args)
        if err_val.value != 0:
            _log("cuda_load_to_device failed, falling back to embedded ptxas")
            return _original_load_cuda_library(self)

    _log(f"Loaded kernel from {ptx_path.name}")

    # Delete PTX if user didn't originally want it kept
    if not _user_wanted_ptx:
        ptx_path.unlink(missing_ok=True)

    return [cuda_runtime.cudaLibrary_t(lib_ptr.value)]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def patch():
    """Install the system-ptxas hook on the CUTLASS DSL JIT executor.

    Preconditions (both raise RuntimeError when unmet):
      * CUTE_DSL_PTXAS_PATH points at an executable ptxas binary, and
      * CUTE_DSL_KEEP_PTX=1, so the DSL dumps PTX for the hook to recompile.

    Plain ``assert`` is deliberately avoided for these checks: asserts are
    stripped under ``python -O``, which would silently disable validation.
    """
    global _original_load_cuda_library, _user_wanted_ptx

    if CUTE_DSL_PTXAS_PATH is None:
        raise RuntimeError("CUTE_DSL_PTXAS_PATH is not set")
    if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK):
        raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}")

    # Track if user originally wanted PTX kept; if not, the patched loader
    # deletes each PTX file once it has been compiled.
    _user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1"
    # os.environ['CUTE_DSL_KEEP_PTX'] = '1'
    if os.environ.get("CUTE_DSL_KEEP_PTX", "0") != "1":
        raise RuntimeError("Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas")

    cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction
    _original_load_cuda_library = cls._load_cuda_library
    cls._load_cuda_library = _patched_load_cuda_library
    _log("Patch applied")
|
build/torch-cuda/cute_dsl_utils.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import pathlib
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
from functools import partial, lru_cache
|
| 7 |
+
from dataclasses import dataclass, fields
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from triton.tools.disasm import extract
|
| 13 |
+
except ImportError:
|
| 14 |
+
extract = None
|
| 15 |
+
|
| 16 |
+
import cutlass
|
| 17 |
+
import cutlass.cute as cute
|
| 18 |
+
from cutlass.base_dsl.typing import JitArgument
|
| 19 |
+
from cutlass.cutlass_dsl import NumericMeta
|
| 20 |
+
from cutlass.cute.runtime import from_dlpack
|
| 21 |
+
|
| 22 |
+
# Types treated as compile-time (static) when ArgumentsBase marshals dataclass
# fields to MLIR: such fields are excluded from runtime pointer/type lists.
StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))


# Originals saved so the monkeypatching helpers below can delegate to and
# restore them (see load_cubin_module_data_patched / cute_compile_patched).
load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
cute_compile_og = cute.compile


# torch dtype -> cutlass numeric type mapping for kernel argument conversion.
torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
    torch.float32: cutlass.Float32,
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@lru_cache
def get_max_active_clusters(cluster_size):
    """Return the max number of concurrently active clusters for *cluster_size*.

    Cached per cluster size since the hardware query result does not change
    within a process.
    """
    return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@lru_cache
def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
    """Return the CUDA compute capability (major, minor) of *device*.

    ``None`` selects torch's current device.  Cached per device argument, so
    the driver is queried at most once per key.
    """
    return torch.cuda.get_device_capability(device)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
class ArgumentsBase(JitArgument):
    """Dataclass base implementing the JitArgument marshalling protocol.

    Fields whose values are instances of StaticTypes are treated as
    compile-time constants; the remaining fields are flattened (in dataclass
    declaration order) into C pointers / MLIR types by delegating to each
    field value's own protocol methods.
    """

    def __c_pointers__(self):
        """Collect the C pointers of all non-static fields, in field order."""
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        c_ptrs = []
        for obj in non_constexpr_fields:
            # Fields without the protocol contribute no runtime pointers.
            if hasattr(obj, "__c_pointers__"):
                c_ptrs.extend(obj.__c_pointers__())
        return c_ptrs

    def __get_mlir_types__(self):
        """Collect MLIR types of non-static fields; record per-field value counts.

        Side effect: populates ``self._values_pos`` with one count per
        non-static field so __new_from_mlir_values__ can slice the flat value
        list back into per-field chunks.
        """
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        types, self._values_pos = [], []
        for obj in non_constexpr_fields:
            if hasattr(obj, "__get_mlir_types__"):
                obj_types = obj.__get_mlir_types__()
                types.extend(obj_types)
                self._values_pos.append(len(obj_types))
            else:
                # Field consumes no MLIR values.
                self._values_pos.append(0)
        return types

    def __new_from_mlir_values__(self, values):
        """Rebuild an instance from a flat MLIR value list.

        Static fields are carried over unchanged; each non-static field takes
        the next ``_values_pos`` chunk of *values* (same order as
        __get_mlir_types__ produced them).
        """
        all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
        constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
        non_constexpr_fields = {
            n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
        }
        for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
            non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
            values = values[n_items:]
        return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def load_cubin_module_data_patched(cubin_data, filepath):
    """Persist the cubin bytes to *filepath*, then load them with the original loader."""
    out_path = pathlib.Path(filepath)
    out_path.write_bytes(cubin_data)
    return load_cubin_module_data_og(cubin_data)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def cute_compile_patched(*args, **kwargs):
    """A patched version of cute.compile that dumps the cubin (and, when
    triton's disassembler is available, annotated SASS) to the path given by
    the CUTE_CUBIN_PATH environment variable.

    Forwards all arguments to the original ``cute.compile`` and returns its
    result unchanged.  When CUTE_CUBIN_PATH is unset this is a plain
    pass-through.
    """
    cubin_path = os.getenv("CUTE_CUBIN_PATH", None)
    if cubin_path is None:
        return cute_compile_og(*args, **kwargs)

    # Temporarily swap in a loader that also writes the cubin to disk.
    cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
        load_cubin_module_data_patched, filepath=cubin_path
    )
    try:
        output = cute_compile_og(*args, **kwargs)
    finally:
        # Always restore the original loader, even when compilation raises;
        # otherwise every subsequent compile would keep overwriting cubin_path.
        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
    if extract is not None:
        sass = extract(cubin_path, None)
        pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
    return output
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def assume_strides_aligned(t):
    """Assume all strides except the last are divisible by 128 bits.

    Python int strides (e.g., stride=0 from GQA expand) are kept as-is
    since they're static and don't need alignment assumptions.
    """
    # Divisibility expressed in elements: 128 bits / element width.
    divby = 128 // t.element_type.width
    # Attach a divisibility assumption to every dynamic stride; static (plain
    # int) strides pass through unchanged.
    strides = tuple(s if isinstance(s, int) else cute.assume(s, divby=divby) for s in t.stride[:-1])
    # The last-mode stride is left untouched.
    return (*strides, t.stride[-1])
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def assume_tensor_aligned(t):
    """Rebuild a tensor with 128-bit aligned stride assumptions; None passes through."""
    if t is None:
        return None
    aligned_layout = cute.make_layout(t.shape, stride=assume_strides_aligned(t))
    return cute.make_tensor(t.iterator, aligned_layout)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):
    """Convert a torch tensor to a cute tensor for TVM FFI.

    A ``leading_dim`` of -1 means the last dimension (``t.ndim - 1``).  When
    ``fully_dynamic`` is set the whole layout is marked dynamic and
    ``leading_dim`` is ignored.
    """
    cute_t = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
    if fully_dynamic:
        return cute_t.mark_layout_dynamic()
    dim = t.ndim - 1 if leading_dim == -1 else leading_dim
    return cute_t.mark_layout_dynamic(leading_dim=dim)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def to_cute_aux_tensor(t, enable_tvm_ffi=True):
    """Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors.
    This allows the user to specify alignment and leading dimension for aux tensors used in
    custom score_mod callables.
    """
    # Optional per-tensor overrides attached by the caller as dunder attributes.
    assumed_align: int | None = getattr(t, "__assumed_align__", None)
    leading_dim: int | None = getattr(t, "__leading_dim__", None)
    # No declared leading dim -> mark the whole layout dynamic.
    fully_dynamic: bool = leading_dim is None

    # NOTE(review): when __assumed_align__ is absent, None is forwarded below,
    # overriding to_cute_tensor's default of 16 — presumably from_dlpack
    # accepts assumed_align=None and picks its own default; confirm.
    return to_cute_tensor(
        t,
        assumed_align=assumed_align,
        leading_dim=leading_dim,
        fully_dynamic=fully_dynamic,
        enable_tvm_ffi=enable_tvm_ffi,
    )
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def get_aux_tensor_metadata(aux_tensors):
    """Collect ``(assumed_align, leading_dim, has_leading_dim)`` per aux tensor.

    Missing dunder attributes default to 0 / -1, so the result is a plain
    hashable tuple-of-tuples suitable for use in a compile key.
    """
    metadata = []
    for t in aux_tensors:
        align = getattr(t, "__assumed_align__", 0)
        lead = getattr(t, "__leading_dim__", -1)
        metadata.append((align, lead, hasattr(t, "__leading_dim__")))
    return tuple(metadata)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]:
    """Return one bool per dimension, True where the stride is 0 (broadcast).

    Useful for compile keys: CuTe's mark_layout_dynamic() keeps stride=0 as
    static, so kernels compiled for different broadcast patterns are not
    interchangeable.
    """
    flags = []
    for stride in tensor.stride():
        flags.append(stride == 0)
    return tuple(flags)
|
build/torch-cuda/fast_math.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import cutlass
|
| 4 |
+
import cutlass.cute as cute
|
| 5 |
+
from cutlass import Int32
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@cute.jit
def clz(x: Int32) -> Int32:
    """Count leading zeros of a 32-bit value; returns 32 when x == 0."""
    # for i in cutlass.range_constexpr(32):
    #     if (1 << (31 - i)) & x:
    #         return Int32(i)
    # return Int32(32)
    # Early exit is not supported yet
    res = Int32(32)
    done = False
    # Scan from the most-significant bit down; `done` emulates `break`, which
    # the traced `cutlass.range` loop cannot express.
    for i in cutlass.range(32):
        if ((1 << (31 - i)) & x) and not done:
            res = Int32(i)
            done = True
    return res
|
build/torch-cuda/flash_attn4/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import importlib
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from types import ModuleType
|
| 7 |
+
|
| 8 |
+
def _import_from_path(file_path: Path) -> ModuleType:
|
| 9 |
+
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
| 10 |
+
# it would also be used for other imports. So, we make a module name that
|
| 11 |
+
# depends on the path for it to be unique using the hex-encoded hash of
|
| 12 |
+
# the path.
|
| 13 |
+
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
|
| 14 |
+
module_name = path_hash
|
| 15 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
| 16 |
+
if spec is None:
|
| 17 |
+
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
|
| 18 |
+
module = importlib.util.module_from_spec(spec)
|
| 19 |
+
if module is None:
|
| 20 |
+
raise ImportError(f"Cannot load module {module_name} from spec")
|
| 21 |
+
sys.modules[module_name] = module
|
| 22 |
+
spec.loader.exec_module(module) # type: ignore
|
| 23 |
+
return module
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Load the package __init__.py two directories up under a unique module name
# and merge all of its names into this module's namespace (re-export shim).
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
|
build/torch-cuda/flash_bwd.py
ADDED
|
@@ -0,0 +1,1264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
+
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_bwd_sm80.hpp
|
| 3 |
+
# from Cutlass C++ to Cute-DSL.
|
| 4 |
+
import math
|
| 5 |
+
from types import SimpleNamespace
|
| 6 |
+
from typing import Type, Callable, Optional
|
| 7 |
+
from functools import partial
|
| 8 |
+
|
| 9 |
+
import cuda.bindings.driver as cuda
|
| 10 |
+
|
| 11 |
+
import cutlass
|
| 12 |
+
import cutlass.cute as cute
|
| 13 |
+
from cutlass.cute.nvgpu import cpasync, warp
|
| 14 |
+
from cutlass import Float32, Int32
|
| 15 |
+
import cutlass.utils as utils_basic
|
| 16 |
+
|
| 17 |
+
from .quack import layout_utils
|
| 18 |
+
from . import ampere_helpers as sm80_utils
|
| 19 |
+
from .cute_dsl_utils import assume_tensor_aligned
|
| 20 |
+
from . import utils
|
| 21 |
+
from .mask import AttentionMask
|
| 22 |
+
from .seqlen_info import SeqlenInfoQK
|
| 23 |
+
from .quack.cute_dsl_utils import ParamsBase
|
| 24 |
+
from .tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class FlashAttentionBackwardSm80:
|
| 28 |
+
    def __init__(
        self,
        dtype: Type[cutlass.Numeric],
        head_dim: int,
        head_dim_v: Optional[int] = None,
        qhead_per_kvhead: int = 1,
        m_block_size: int = 64,
        n_block_size: int = 128,
        num_stages_Q: int = 2,
        num_stages_dO: int = 2,
        num_threads: int = 256,
        pack_gqa: bool = False,
        is_causal: bool = False,
        SdP_swapAB: bool = False,
        dKV_swapAB: bool = False,
        dQ_swapAB: bool = False,
        AtomLayoutMSdP: int = 1,
        AtomLayoutNdKV: int = 8,
        AtomLayoutMdQ: int = 1,
        V_in_regs: bool = False,
    ):
        """Initializes the configuration for the SM80 flash attention backward kernel.

        All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension
        should be a multiple of 8.

        :param dtype: element type of Q/K/V (e.g. Float16/BFloat16)
        :param head_dim: head dimension of Q/K
        :param head_dim_v: head dimension of V; defaults to ``head_dim`` when None
        :param qhead_per_kvhead: number of query heads per KV head (GQA ratio)
        :param m_block_size: m block size (query tile)
        :param n_block_size: n block size (key/value tile)
        :param num_stages_Q: pipeline stages for the Q tiles
        :param num_stages_dO: pipeline stages for the dO tiles
        :param num_threads: number of threads per CTA
        :param pack_gqa: pack GQA query heads into the m dimension
        :param is_causal: apply a causal mask
        :param SdP_swapAB: swap A/B operands of the S/dP matmuls
        :param dKV_swapAB: swap A/B operands of the dK/dV matmuls
        :param dQ_swapAB: swap A/B operands of the dQ matmul
        :param AtomLayoutMSdP: MMA atom layout along M for S/dP
        :param AtomLayoutNdKV: MMA atom layout along N for dK/dV
        :param AtomLayoutMdQ: MMA atom layout along M for dQ
        :param V_in_regs: keep V in registers (then Q and V share smem)
        """
        self.dtype = dtype
        # Pad head_dim up to a multiple of hdim_multiple_of (32) for the tile sizes.
        hdim_multiple_of = 32
        self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
        head_dim_v = head_dim_v if head_dim_v is not None else head_dim
        self.same_hdim_kv = head_dim == head_dim_v
        self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
        # Can save registers (and hence be faster) if we don't have to check hdim predication
        self.check_hdim_oob = head_dim != self.head_dim_padded
        self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded
        self.qhead_per_kvhead = qhead_per_kvhead
        self.m_block_size = m_block_size
        self.n_block_size = n_block_size
        self.num_threads = num_threads
        self.pack_gqa = pack_gqa
        self.is_causal = is_causal
        self.num_stages_Q = num_stages_Q
        self.num_stages_dO = num_stages_dO
        self.SdP_swapAB = SdP_swapAB
        self.dKV_swapAB = dKV_swapAB
        self.dQ_swapAB = dQ_swapAB
        self.AtomLayoutMSdP = AtomLayoutMSdP
        self.AtomLayoutNdKV = AtomLayoutNdKV
        self.AtomLayoutMdQ = AtomLayoutMdQ
        num_mma_warps = self.num_threads // cute.arch.WARP_SIZE
        # dKV MMA can read its A operand straight from registers only for this
        # particular layout/swap combination.
        self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB
        self.V_in_regs = V_in_regs
        self.share_QV_smem = V_in_regs
|
| 92 |
+
|
| 93 |
+
@staticmethod
def can_implement(
    dtype, head_dim, head_dim_v, m_block_size, n_block_size, num_stages_Q, num_stages_dO,
    num_threads, is_causal,
    V_in_regs=False
) -> bool:
    """Check if the kernel can be implemented with the given parameters.

    Rejects unsupported dtypes, misaligned head dimensions / block sizes /
    thread counts, and tile configurations whose estimated shared-memory
    footprint exceeds the sm_80 capacity.

    :param dtype: data type; only Float16 / BFloat16 pass
    :type dtype: cutlass.Numeric
    :param head_dim: head dimension (QK); must be a multiple of 8
    :type head_dim: int
    :param head_dim_v: head dimension of V; must be a multiple of 8
    :type head_dim_v: int
    :param m_block_size: m block size
    :type m_block_size: int
    :param n_block_size: n block size; must be a multiple of 16
    :type n_block_size: int
    :param num_stages_Q: number of smem pipeline stages for Q tiles
    :type num_stages_Q: int
    :param num_stages_dO: number of smem pipeline stages for dO tiles
    :type num_stages_dO: int
    :param num_threads: number of threads; must be a multiple of 32 (warp size)
    :type num_threads: int
    :param is_causal: is causal (not used by any check in this method)
    :type is_causal: bool
    :param V_in_regs: if True, V is held in registers so the Q and V smem
        tiles can share storage (only the larger of the two is counted)
    :type V_in_regs: bool

    :return: True if the kernel can be implemented, False otherwise
    :rtype: bool
    """
    if dtype not in [cutlass.Float16, cutlass.BFloat16]:
        return False
    if head_dim % 8 != 0:
        return False
    if head_dim_v % 8 != 0:
        return False
    if n_block_size % 16 != 0:
        return False
    if num_threads % 32 != 0:
        return False
    # Check if block size setting is out of shared memory capacity.
    # Estimate: multi-stage Q and dO tiles plus single-stage K and V tiles;
    # the `* 2` is bytes per element (fp16/bf16).
    # NOTE(review): this uses the raw head dims, while the kernel allocates
    # padded head dims (multiples of 32) — may underestimate usage; confirm.
    smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2
    smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2
    smem_usage_K = n_block_size * head_dim * 2
    smem_usage_V = n_block_size * head_dim_v * 2
    # With V_in_regs, Q and V overlap in smem, so only the larger one counts.
    smem_usage_QV = (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V)
    smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K
    smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80")
    if smem_usage > smem_capacity:
        return False
    return True
|
| 139 |
+
|
| 140 |
+
def _check_type(
    self,
    mQ_type: Type[cutlass.Numeric],
    mK_type: Type[cutlass.Numeric],
    mV_type: Type[cutlass.Numeric],
    mdO_type: Type[cutlass.Numeric],
    mLSE_type: Type[cutlass.Numeric],
    mdPsum_type: Type[cutlass.Numeric],
    mdQaccum_type: Type[cutlass.Numeric],
    mdK_type: Type[cutlass.Numeric],
    mdV_type: Type[cutlass.Numeric],
    mCuSeqlensQ_type: Type[cutlass.Numeric] | None,
    mCuSeqlensK_type: Type[cutlass.Numeric] | None,
    mSeqUsedQ_type: Type[cutlass.Numeric] | None,
    mSeqUsedK_type: Type[cutlass.Numeric] | None,
):
    """Validate the element types of all input/output tensors.

    All checks are wrapped in ``cutlass.const_expr`` so they are resolved at
    trace/compile time. Raises ``TypeError`` on the first violated constraint.

    Constraints enforced:
    - Q, K, V, dO share one data type, which must be Float16 or BFloat16.
    - dK/dV match Q's type when qhead_per_kvhead == 1; otherwise they are
      Float32 accumulators (GQA path).
    - LSE, dPsum, dQaccum are Float32.
    - Optional cu_seqlens / seq_used tensors, when present, are Int32.
    """
    if cutlass.const_expr(not (mQ_type == mK_type == mV_type == mdO_type)):
        raise TypeError("All tensors must have the same data type")
    if cutlass.const_expr(self.qhead_per_kvhead == 1):
        if cutlass.const_expr(not (mdK_type == mdV_type == mQ_type)):
            raise TypeError("mdK and mdV tensors must have the same data type as mQ")
    else:
        # GQA: dK/dV are accumulated in Float32 across query-head groups.
        if cutlass.const_expr(not (mdK_type == mdV_type == cutlass.Float32)):
            raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32")
    if cutlass.const_expr(not mQ_type in [cutlass.Float16, cutlass.BFloat16]):
        raise TypeError("Only Float16 or BFloat16 is supported")
    if cutlass.const_expr(not mLSE_type in [cutlass.Float32]):
        raise TypeError("LSE tensor must be Float32")
    if cutlass.const_expr(not mdPsum_type in [cutlass.Float32]):
        raise TypeError("dPsum tensor must be Float32")
    if cutlass.const_expr(not mdQaccum_type in [cutlass.Float32]):
        raise TypeError("dQaccum tensor must be Float32")
    if cutlass.const_expr(mCuSeqlensQ_type not in [None, cutlass.Int32]):
        raise TypeError("cuSeqlensQ tensor must be Int32")
    if cutlass.const_expr(mCuSeqlensK_type not in [None, cutlass.Int32]):
        raise TypeError("cuSeqlensK tensor must be Int32")
    if cutlass.const_expr(mSeqUsedQ_type not in [None, cutlass.Int32]):
        raise TypeError("SeqUsedQ tensor must be Int32")
    if cutlass.const_expr(mSeqUsedK_type not in [None, cutlass.Int32]):
        raise TypeError("SeqUsedK tensor must be Int32")
    # self.dtype was fixed at construction time; the traced tensors must agree.
    assert mQ_type == self.dtype
|
| 181 |
+
|
| 182 |
+
def _setup_attributes(self):
    """Derive shared-memory layouts and gmem tiled copies from the config.

    Populates, on ``self``:
    - smem layouts: ``sQ_layout``, ``sK_layout``, ``sV_layout``, ``sdO_layout``,
      ``sPdS_layout``, ``sLSE_layout``, ``sLSEMma_layout``
    - predication flags: ``is_even_{m,n}_smem_*``
    - gmem tiled copies: ``gmem_tiled_copy_{QK,VdO,dK,dV,LSE,dQaccum}``

    Must be called after ``self.varlen_q`` is set (done in ``__call__``).

    Fix vs. previous revision: removed a dead assignment to
    ``async_copy_elems_accum`` that was unconditionally overwritten by both
    branches of the ``varlen_q`` const_expr below.
    """
    # ///////////////////////////////////////////////////////////////////////////////
    # Shared memory layout: Q/K/V
    # ///////////////////////////////////////////////////////////////////////////////
    sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_padded)
    # Q is multi-stage: (m_block, head_dim, num_stages_Q).
    self.sQ_layout = cute.tile_to_shape(
        sQ_layout_atom, (self.m_block_size, self.head_dim_padded, self.num_stages_Q), (0, 1, 2),
    )
    # K shares Q's atom since both tile over head_dim_padded.
    sK_layout_atom = sQ_layout_atom
    self.sK_layout = cute.tile_to_shape(
        sK_layout_atom, (self.n_block_size, self.head_dim_padded), (0, 1),
    )
    sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_v_padded)
    self.sV_layout = cute.tile_to_shape(
        sV_layout_atom, (self.n_block_size, self.head_dim_v_padded), (0, 1),
    )
    # dO shares V's atom since both tile over head_dim_v_padded.
    sdO_layout_atom = sV_layout_atom
    self.sdO_layout = cute.tile_to_shape(
        sdO_layout_atom, (self.m_block_size, self.head_dim_v_padded, self.num_stages_dO), (0, 1, 2),
    )
    # TODO: do we set swizzle to be 3 here explicitly?
    sPdS_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.n_block_size)
    self.sPdS_layout = cute.tile_to_shape(
        sPdS_layout_atom, (self.m_block_size, self.n_block_size), (0, 1),
    )
    # We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds,
    # it's still a valid smem address.
    self.sLSE_layout = cute.make_layout(
        (self.m_block_size, self.num_stages_Q),
        stride=(1, cute.round_up(self.m_block_size, 64)),
    )
    # MMA views of LSE/dPsum: stride 0 broadcasts over the n (resp. m) dim.
    sLSEMma_layout = cute.make_layout(
        (self.m_block_size, self.n_block_size, self.num_stages_Q),
        stride=(1, 0, cute.round_up(self.m_block_size, 64)),
    )
    sLSEMma_layout_transposed = cute.make_layout(
        (self.n_block_size, self.m_block_size, self.num_stages_Q),
        stride=(0, 1, cute.round_up(self.m_block_size, 64)),
    )
    self.sLSEMma_layout = sLSEMma_layout if not self.SdP_swapAB else sLSEMma_layout_transposed

    # ///////////////////////////////////////////////////////////////////////////////
    # GMEM Tiled copy:
    # ///////////////////////////////////////////////////////////////////////////////
    # Thread layouts for copies: 128-bit vectorized accesses.
    universal_copy_bits = 128
    async_copy_elems = universal_copy_bits // self.dtype.width
    # atom_async_copy: cp.async atom for Q/K/V/dO loads (gmem -> smem)
    atom_async_copy = cute.make_copy_atom(
        cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
        self.dtype,
        num_bits_per_copy=universal_copy_bits,
    )
    # atom_universal_copy: universal copy atom for the dK/dV stores
    atom_universal_copy = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits,
    )
    # tQK_layout: thread layout for QK load
    tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems
    assert self.num_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1"
    tQK_layout = cute.make_ordered_layout(
        (self.num_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0),
    )
    # Do we need to check if we overshot kBlockM when we load Q?
    self.is_even_m_smem_q = self.m_block_size % tQK_layout.shape[0] == 0
    # Do we need to check if we overshot kBlockN when we load K?
    self.is_even_n_smem_k = self.n_block_size % tQK_layout.shape[0] == 0
    tVdO_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems
    assert self.num_threads % tVdO_shape_dim_1 == 0, "num_threads must be divisible by tVdO_shape_dim_1"
    tVdO_layout = cute.make_ordered_layout(
        (self.num_threads // tVdO_shape_dim_1, tVdO_shape_dim_1), order=(1, 0),
    )
    # Do we need to check if we overshot kBlockN when we load V?
    self.is_even_n_smem_v = self.n_block_size % tVdO_layout.shape[0] == 0
    self.is_even_m_smem_do = self.m_block_size % tVdO_layout.shape[0] == 0

    # Value layouts for copies: each thread moves (1, async_copy_elems) values.
    vQKVdO_layout = cute.make_layout((1, async_copy_elems))

    # gmem_tiled_copy_QK: tiled copy for QK load
    self.gmem_tiled_copy_QK = cute.make_tiled_copy_tv(atom_async_copy, tQK_layout, vQKVdO_layout)
    self.gmem_tiled_copy_VdO = cute.make_tiled_copy_tv(atom_async_copy, tVdO_layout, vQKVdO_layout)
    self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv(atom_universal_copy, tQK_layout, vQKVdO_layout)
    self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv(atom_universal_copy, tVdO_layout, vQKVdO_layout)

    # LSE/dPsum copies: vectorized cp.async only when seqlens are fixed;
    # varlen offsets break the 128-bit alignment guarantee, so fall back to
    # scalar universal copies.
    # I think we wouldn't require this with smarter padding.
    if cutlass.const_expr(not self.varlen_q):
        async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width
        atom_async_copy_accum = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            cutlass.Float32,
            num_bits_per_copy=universal_copy_bits,
        )
    else:
        async_copy_elems_accum = 1
        atom_async_copy_accum = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            cutlass.Float32,
            num_bits_per_copy=cutlass.Float32.width,
        )
    self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(
        atom_async_copy_accum,
        cute.make_layout(self.num_threads),
        cute.make_layout(async_copy_elems_accum),
    )
    # dQaccum is read/written as scalar Float32 (one element per thread).
    self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
        cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=cutlass.Float32.width
        ),
        cute.make_layout(self.num_threads),
        cute.make_layout(1)
    )
    # GQA: dK/dV are Float32 accumulators, so reuse the scalar dQaccum copy.
    if cutlass.const_expr(self.qhead_per_kvhead > 1):
        self.gmem_tiled_copy_dK = self.gmem_tiled_copy_dQaccum
        self.gmem_tiled_copy_dV = self.gmem_tiled_copy_dQaccum
|
| 298 |
+
|
| 299 |
+
def _get_tiled_mma(self):
    """Build the three tiled MMAs used by the backward pass.

    Each MMA uses the 16x8x16 fp16/bf16 HMMA atom with Float32 accumulation;
    they differ only in how the MMA warps are distributed over (M, N) and in
    whether A/B are swapped.

    :return: ``(tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq)``
    """
    num_mma_warps = self.num_threads // 32
    # (warps along the atom-major dim, swapAB flag) for S/dP, dK/dV, dQ.
    warp_configs = (
        (self.AtomLayoutMSdP, self.SdP_swapAB),
        (self.AtomLayoutNdKV, self.dKV_swapAB),
        (self.AtomLayoutMdQ, self.dQ_swapAB),
    )
    tiled_mmas = []
    for atom_m, swap_ab in warp_configs:
        # swapAB transposes the warp grid between the M and N dimensions.
        if cutlass.const_expr(not swap_ab):
            atom_layout = (atom_m, num_mma_warps // atom_m, 1)
        else:
            atom_layout = (num_mma_warps // atom_m, atom_m, 1)
        tiled_mmas.append(
            cute.make_tiled_mma(
                warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)),
                atom_layout,
                permutation_mnk=(atom_layout[0] * 16, atom_layout[1] * 16, 16),
            )
        )
    return tiled_mmas[0], tiled_mmas[1], tiled_mmas[2]
|
| 320 |
+
|
| 321 |
+
def _get_shared_storage_cls(self):
    """Build the ``@cute.struct`` class describing this kernel's smem layout.

    Two variants exist: one with separate Q and V regions, and one where Q
    and V alias the same region (used when V is kept in registers, so its
    smem tile is only needed transiently). Selection follows
    ``self.share_QV_smem``.

    :return: the shared-storage struct class to allocate in the kernel
    """
    # Main tiles are 1024-byte aligned for cp.async / swizzled smem access.
    sQ_struct, sK_struct, sV_struct, sdO_struct = [
        cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024]
        for layout in (self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout)
    ]
    # Shared Q/V region must hold whichever of the two is larger.
    cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
    sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]
    # LSE and dPsum use the same layout (per-row Float32 statistics).
    sLSE_struct, sdPsum_struct = [
        cute.struct.Align[cute.struct.MemRange[cutlass.Float32, cute.cosize(layout)], 128]
        for layout in (self.sLSE_layout, self.sLSE_layout)
    ]
    # P and dS staging tiles share the sPdS layout.
    sP_struct, sdS_struct = [
        cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 128]
        for layout in (self.sPdS_layout, self.sPdS_layout)
    ]

    @cute.struct
    class SharedStorageSeparateQV:
        sK: sK_struct
        sV: sV_struct
        sQ: sQ_struct
        sdO: sdO_struct
        sLSE: sLSE_struct
        sdPsum: sdPsum_struct
        sP: sP_struct
        sdS: sdS_struct
        # TODO: the case where there's no sP

    @cute.struct
    class SharedStorageSharedQV:
        sK: sK_struct
        sV: sV_struct
        # sQ doubles as the V region; sized as max(Q, V) via sQV_struct.
        sQ: sQV_struct
        sdO: sdO_struct
        sLSE: sLSE_struct
        sdPsum: sdPsum_struct
        sP: sP_struct
        sdS: sdS_struct

    return SharedStorageSeparateQV if cutlass.const_expr(not self.share_QV_smem) else SharedStorageSharedQV
|
| 361 |
+
|
| 362 |
+
@cute.jit
def __call__(
    self,
    mQ: cute.Tensor,
    mK: cute.Tensor,
    mV: cute.Tensor,
    mdO: cute.Tensor,
    mLSE: cute.Tensor,
    mdPsum: cute.Tensor,
    mdQaccum: cute.Tensor,
    mdK: cute.Tensor,
    mdV: cute.Tensor,
    softmax_scale: cutlass.Float32,
    stream: cuda.CUstream,
    mCuSeqlensQ: Optional[cute.Tensor] = None,
    mCuSeqlensK: Optional[cute.Tensor] = None,
    mSeqUsedQ: Optional[cute.Tensor] = None,
    mSeqUsedK: Optional[cute.Tensor] = None,
    softcap: Float32 | float | None = None,
    window_size_left: Int32 | int | None = None,
    window_size_right: Int32 | int | None = None,
    mdQ_semaphore: Optional[cute.Tensor] = None,
):
    """Host-side entry point: validate inputs, build layouts/copies/MMAs,
    and launch the backward kernel on ``stream``.

    :param mQ/mK/mV/mdO: input tensors (fp16/bf16); layout is batched
        (batch, seqlen, head, dim) or, with cu_seqlens, packed
        (total_seq, head, dim) — inferred from which shape dims are read below
    :param mLSE/mdPsum/mdQaccum: Float32 softmax stats and dQ accumulator
    :param mdK/mdV: output gradients (fp16/bf16, or Float32 accumulators for GQA)
    :param softmax_scale: scale applied to QK^T logits
    :param stream: CUDA stream to launch on
    :param mCuSeqlensQ/mCuSeqlensK: optional Int32 cumulative seqlens (varlen)
    :param mSeqUsedQ/mSeqUsedK: optional Int32 per-batch used lengths
    :param softcap, window_size_left, window_size_right: accepted but not used
        anywhere in this body — NOTE(review): presumably not implemented yet
    :param mdQ_semaphore: must be None (unsupported)
    """
    assert mdQ_semaphore is None, "semaphore not supported yet"
    # Get the data type and check if it is fp16 or bf16
    self._check_type(*(t.element_type if t is not None else None
                       for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))
    mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
        assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
    ]
    # varlen_q must be set before _setup_attributes (it selects the LSE copy atom).
    self.varlen_q = (mCuSeqlensQ is not None)
    self._setup_attributes()
    SharedStorage = self._get_shared_storage_cls()
    tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq = self._get_tiled_mma()

    # Varlen layout puts head at dim 1; batched layout puts it at dim 2.
    num_head = mQ.shape[1] if cutlass.const_expr(mCuSeqlensQ is not None) else mQ.shape[2]

    if cutlass.const_expr(mCuSeqlensK is not None):
        TileScheduler = SingleTileVarlenScheduler
        num_batch = mCuSeqlensK.shape[0] - 1
    else:
        TileScheduler = SingleTileScheduler
        num_batch = mK.shape[0]

    # Uses seqlen k, etc. since main bwd kernel's blocks are over n —
    # hence K-side quantities are deliberately passed through the Q-named
    # scheduler fields (mCuSeqlensQ=mCuSeqlensK, mSeqUsedQ=mSeqUsedK).
    # NOTE(review): seqlen_k=0 and total_q=mK.shape[0] — looks intentional
    # for this n-major scheduling, but confirm against the scheduler impl.
    tile_sched_args = TileSchedulerArguments(
        num_block=cute.ceil_div(mK.shape[1], self.n_block_size),
        num_head=num_head,
        num_batch=num_batch,
        num_splits=1,
        seqlen_k=0,
        headdim=mK.shape[2],
        headdim_v=mV.shape[2],
        total_q=mK.shape[0],
        tile_shape_mn=(self.n_block_size, self.m_block_size),
        qhead_per_kvhead_packgqa=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1,
        mCuSeqlensQ=mCuSeqlensK,
        mSeqUsedQ=mSeqUsedK,
    )

    tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
    grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

    # Pre-scale by log2(e) so the kernel can use exp2 instead of exp.
    softmax_scale_log2 = softmax_scale * math.log2(math.e)
    self.kernel(
        mQ,
        mK,
        mV,
        mdO,
        mLSE,
        mdPsum,
        mdQaccum,
        mdK,
        mdV,
        mCuSeqlensQ,
        mCuSeqlensK,
        mSeqUsedQ,
        mSeqUsedK,
        softmax_scale,
        softmax_scale_log2,
        self.sQ_layout,
        self.sK_layout,
        self.sV_layout,
        self.sdO_layout,
        self.sPdS_layout,
        self.sLSE_layout,
        self.sLSEMma_layout,
        self.gmem_tiled_copy_QK,
        self.gmem_tiled_copy_VdO,
        self.gmem_tiled_copy_dK,
        self.gmem_tiled_copy_dV,
        self.gmem_tiled_copy_LSE,
        self.gmem_tiled_copy_dQaccum,
        tiled_mma_sdp,
        tiled_mma_dkv,
        tiled_mma_dq,
        SharedStorage,
        tile_sched_params,
        TileScheduler,
    ).launch(
        grid=grid_dim,
        block=[self.num_threads, 1, 1],
        smem=SharedStorage.size_in_bytes(),
        stream=stream,
    )
|
| 467 |
+
|
| 468 |
+
@cute.kernel
|
| 469 |
+
def kernel(
|
| 470 |
+
self,
|
| 471 |
+
mQ: cute.Tensor,
|
| 472 |
+
mK: cute.Tensor,
|
| 473 |
+
mV: cute.Tensor,
|
| 474 |
+
mdO: cute.Tensor,
|
| 475 |
+
mLSE: cute.Tensor,
|
| 476 |
+
mdPsum: cute.Tensor,
|
| 477 |
+
mdQaccum: cute.Tensor,
|
| 478 |
+
mdK: cute.Tensor,
|
| 479 |
+
mdV: cute.Tensor,
|
| 480 |
+
mCuSeqlensQ: Optional[cute.Tensor],
|
| 481 |
+
mCuSeqlensK: Optional[cute.Tensor],
|
| 482 |
+
mSeqUsedQ: Optional[cute.Tensor],
|
| 483 |
+
mSeqUsedK: Optional[cute.Tensor],
|
| 484 |
+
softmax_scale: cutlass.Float32,
|
| 485 |
+
softmax_scale_log2: cutlass.Float32,
|
| 486 |
+
sQ_layout: cute.ComposedLayout,
|
| 487 |
+
sK_layout: cute.ComposedLayout,
|
| 488 |
+
sV_layout: cute.ComposedLayout,
|
| 489 |
+
sdO_layout: cute.ComposedLayout,
|
| 490 |
+
sPdS_layout: cute.ComposedLayout,
|
| 491 |
+
sLSE_layout: cute.Layout,
|
| 492 |
+
sLSEMma_layout: cute.Layout,
|
| 493 |
+
gmem_tiled_copy_QK: cute.TiledCopy,
|
| 494 |
+
gmem_tiled_copy_VdO: cute.TiledCopy,
|
| 495 |
+
gmem_tiled_copy_dK: cute.TiledCopy,
|
| 496 |
+
gmem_tiled_copy_dV: cute.TiledCopy,
|
| 497 |
+
gmem_tiled_copy_LSE: cute.TiledCopy,
|
| 498 |
+
gmem_tiled_copy_dQaccum: cute.TiledCopy,
|
| 499 |
+
tiled_mma_sdp: cute.TiledMma,
|
| 500 |
+
tiled_mma_dkv: cute.TiledMma,
|
| 501 |
+
tiled_mma_dq: cute.TiledMma,
|
| 502 |
+
SharedStorage: cutlass.Constexpr,
|
| 503 |
+
tile_sched_params: ParamsBase,
|
| 504 |
+
TileScheduler: cutlass.Constexpr[Callable],
|
| 505 |
+
):
|
| 506 |
+
# Thread index, block index
|
| 507 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 508 |
+
|
| 509 |
+
tile_scheduler = TileScheduler.create(tile_sched_params)
|
| 510 |
+
work_tile = tile_scheduler.initial_work_tile_info()
|
| 511 |
+
|
| 512 |
+
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
| 513 |
+
|
| 514 |
+
if work_tile.is_valid_tile:
|
| 515 |
+
seqlen = SeqlenInfoQK.create(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK)
|
| 516 |
+
|
| 517 |
+
m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size)
|
| 518 |
+
m_block_min = 0
|
| 519 |
+
if cutlass.const_expr(self.is_causal):
|
| 520 |
+
m_block_min = max(
|
| 521 |
+
(n_block * self.n_block_size + seqlen.seqlen_q - seqlen.seqlen_k) // self.m_block_size,
|
| 522 |
+
m_block_min,
|
| 523 |
+
)
|
| 524 |
+
# TODO: return early if m_block_max == 0
|
| 525 |
+
|
| 526 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 527 |
+
# Get the appropriate tiles for this thread block.
|
| 528 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 529 |
+
blkQ_shape = (self.m_block_size, self.head_dim_padded)
|
| 530 |
+
blkK_shape = (self.n_block_size, self.head_dim_padded)
|
| 531 |
+
blkV_shape = (self.n_block_size, self.head_dim_v_padded)
|
| 532 |
+
blkdO_shape = (self.m_block_size, self.head_dim_v_padded)
|
| 533 |
+
|
| 534 |
+
if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
|
| 535 |
+
mQ_cur = mQ[batch_idx, None, head_idx, None]
|
| 536 |
+
mLSE_cur = mLSE[batch_idx, head_idx, None]
|
| 537 |
+
mdO_cur = mdO[batch_idx, None, head_idx, None]
|
| 538 |
+
mdPsum_cur = mdPsum[batch_idx, head_idx, None]
|
| 539 |
+
mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
|
| 540 |
+
else:
|
| 541 |
+
padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
|
| 542 |
+
mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None])
|
| 543 |
+
mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None])
|
| 544 |
+
mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
|
| 545 |
+
mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
|
| 546 |
+
mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None])
|
| 547 |
+
head_idx_kv = head_idx // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else head_idx
|
| 548 |
+
|
| 549 |
+
if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
|
| 550 |
+
mK_cur, mV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mK, mV)]
|
| 551 |
+
else:
|
| 552 |
+
mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mK, mV)]
|
| 553 |
+
|
| 554 |
+
# (m_block_size, head_dim, m_block)
|
| 555 |
+
gQ = cute.local_tile(mQ_cur, blkQ_shape, (None, 0))
|
| 556 |
+
# (n_block_size, head_dim)
|
| 557 |
+
gK = cute.local_tile(mK_cur, blkK_shape, (n_block, 0))
|
| 558 |
+
# (n_block_size, head_dim_v)
|
| 559 |
+
gV = cute.local_tile(mV_cur, blkV_shape, (n_block, 0))
|
| 560 |
+
# (m_block_size, head_dim_v, m_block)
|
| 561 |
+
gdO = cute.local_tile(mdO_cur, blkdO_shape, (None, 0))
|
| 562 |
+
gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,))
|
| 563 |
+
gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,))
|
| 564 |
+
gdQaccum = cute.local_tile(mdQaccum_cur, (self.m_block_size * self.head_dim_padded,), (None,))
|
| 565 |
+
|
| 566 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 567 |
+
# Get shared memory buffer
|
| 568 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 569 |
+
smem = cutlass.utils.SmemAllocator()
|
| 570 |
+
storage = smem.allocate(SharedStorage)
|
| 571 |
+
sQ = storage.sQ.get_tensor(sQ_layout)
|
| 572 |
+
sK = storage.sK.get_tensor(sK_layout)
|
| 573 |
+
if cutlass.const_expr(not self.share_QV_smem):
|
| 574 |
+
sV = storage.sV.get_tensor(sV_layout)
|
| 575 |
+
else:
|
| 576 |
+
sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout)
|
| 577 |
+
sdO = storage.sdO.get_tensor(sdO_layout)
|
| 578 |
+
sP = storage.sP.get_tensor(sPdS_layout)
|
| 579 |
+
sdS = storage.sdS.get_tensor(sPdS_layout)
|
| 580 |
+
sLSE = storage.sLSE.get_tensor(sLSE_layout)
|
| 581 |
+
sdPsum = storage.sdPsum.get_tensor(sLSE_layout)
|
| 582 |
+
sLSEMma = storage.sLSE.get_tensor(sLSEMma_layout)
|
| 583 |
+
sdPsumMma = storage.sdPsum.get_tensor(sLSEMma_layout)
|
| 584 |
+
|
| 585 |
+
# Transpose view of tensors for tiled mma
|
| 586 |
+
sQt, sdOt, sKt, sPt, sdSt = [layout_utils.transpose_view(t) for t in (sQ, sdO, sK, sP, sdS)]
|
| 587 |
+
|
| 588 |
+
gmem_thr_copy_QK = gmem_tiled_copy_QK.get_slice(tidx)
|
| 589 |
+
gmem_thr_copy_VdO = gmem_tiled_copy_VdO.get_slice(tidx)
|
| 590 |
+
gmem_thr_copy_lse = gmem_tiled_copy_LSE.get_slice(tidx)
|
| 591 |
+
gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
|
| 592 |
+
# (CPY_Atom, CPY_M, CPY_K, m_block)
|
| 593 |
+
tQgQ = gmem_thr_copy_QK.partition_S(gQ)
|
| 594 |
+
tQsQ = gmem_thr_copy_QK.partition_D(sQ)
|
| 595 |
+
# (CPY_Atom, CPY_N, CPY_K)
|
| 596 |
+
tKgK = gmem_thr_copy_QK.partition_S(gK)
|
| 597 |
+
tKsK = gmem_thr_copy_QK.partition_D(sK)
|
| 598 |
+
# (CPY_Atom, CPY_N, CPY_K)
|
| 599 |
+
tVgV = gmem_thr_copy_VdO.partition_S(gV)
|
| 600 |
+
tVsV = gmem_thr_copy_VdO.partition_D(sV)
|
| 601 |
+
# (CPY_Atom, CPY_M, CPY_K, m_block)
|
| 602 |
+
tdOgdO = gmem_thr_copy_VdO.partition_S(gdO)
|
| 603 |
+
tdOsdO = gmem_thr_copy_VdO.partition_D(sdO)
|
| 604 |
+
tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE)
|
| 605 |
+
tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE)
|
| 606 |
+
tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum)
|
| 607 |
+
tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum)
|
| 608 |
+
tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)
|
| 609 |
+
|
| 610 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 611 |
+
# Tile MMA compute thread partitions and allocate accumulators
|
| 612 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 613 |
+
thr_mma_sdp = tiled_mma_sdp.get_slice(tidx)
|
| 614 |
+
thr_mma_dkv = tiled_mma_dkv.get_slice(tidx)
|
| 615 |
+
thr_mma_dq = tiled_mma_dq.get_slice(tidx)
|
| 616 |
+
acc_shape_dK = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_padded))
|
| 617 |
+
acc_shape_dV = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_v_padded))
|
| 618 |
+
acc_dK = cute.make_fragment(acc_shape_dK, cutlass.Float32)
|
| 619 |
+
acc_dV = cute.make_fragment(acc_shape_dV, cutlass.Float32)
|
| 620 |
+
acc_dK.fill(0.0)
|
| 621 |
+
acc_dV.fill(0.0)
|
| 622 |
+
|
| 623 |
+
tSrQ = utils.mma_make_fragment_A(sQ[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB)
|
| 624 |
+
tSrK = utils.mma_make_fragment_B(sK, thr_mma_sdp, swapAB=self.SdP_swapAB)
|
| 625 |
+
tdPrdO = utils.mma_make_fragment_A(sdO[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB)
|
| 626 |
+
tdPrV = utils.mma_make_fragment_B(sV, thr_mma_sdp, swapAB=self.SdP_swapAB)
|
| 627 |
+
tdVrP = utils.mma_make_fragment_A(sPt, thr_mma_dkv, swapAB=self.dKV_swapAB)
|
| 628 |
+
tdVrdO = utils.mma_make_fragment_B(sdOt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB)
|
| 629 |
+
tdKrdS = utils.mma_make_fragment_A(sdSt, thr_mma_dkv, swapAB=self.dKV_swapAB)
|
| 630 |
+
tdKrQ = utils.mma_make_fragment_B(sQt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB)
|
| 631 |
+
tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB)
|
| 632 |
+
tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB)
|
| 633 |
+
|
| 634 |
+
LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None)
|
| 635 |
+
tSsLSEMma = layout_utils.reshape_acc_to_mn(thr_mma_sdp.partition_C(sLSEMma))[LSEslice]
|
| 636 |
+
tSsdPsumMma = layout_utils.reshape_acc_to_mn(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice]
|
| 637 |
+
|
| 638 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 639 |
+
# Smem copy atom tiling
|
| 640 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 641 |
+
smem_copy_atom = cute.make_copy_atom(
|
| 642 |
+
warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype,
|
| 643 |
+
)
|
| 644 |
+
smem_copy_atom_transposed = cute.make_copy_atom(
|
| 645 |
+
warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype,
|
| 646 |
+
)
|
| 647 |
+
smem_thr_copy_QdO = utils.make_tiled_copy_A(
|
| 648 |
+
smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB
|
| 649 |
+
).get_slice(tidx)
|
| 650 |
+
smem_thr_copy_KV = utils.make_tiled_copy_B(
|
| 651 |
+
smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB
|
| 652 |
+
).get_slice(tidx)
|
| 653 |
+
# TODO: should this be smem_copy_atom_transposed?
|
| 654 |
+
smem_thr_copy_PdSt = utils.make_tiled_copy_A(
|
| 655 |
+
smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB
|
| 656 |
+
).get_slice(tidx)
|
| 657 |
+
smem_thr_copy_QdOt = utils.make_tiled_copy_B(
|
| 658 |
+
smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB
|
| 659 |
+
).get_slice(tidx)
|
| 660 |
+
smem_thr_copy_dS = utils.make_tiled_copy_A(
|
| 661 |
+
smem_copy_atom, tiled_mma_dq, swapAB=self.dQ_swapAB
|
| 662 |
+
).get_slice(tidx)
|
| 663 |
+
smem_thr_copy_Kt = utils.make_tiled_copy_B(
|
| 664 |
+
smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB
|
| 665 |
+
).get_slice(tidx)
|
| 666 |
+
# TODO: what's the number of bits? What if SdP_swapAB
|
| 667 |
+
r2s_thr_copy_PdS = cute.make_tiled_copy_C(
|
| 668 |
+
cute.make_copy_atom(
|
| 669 |
+
cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width
|
| 670 |
+
),
|
| 671 |
+
tiled_mma_sdp,
|
| 672 |
+
).get_slice(tidx)
|
| 673 |
+
|
| 674 |
+
tSsQ = smem_thr_copy_QdO.partition_S(sQ)
|
| 675 |
+
tdPsdO = smem_thr_copy_QdO.partition_S(sdO)
|
| 676 |
+
tSsK = smem_thr_copy_KV.partition_S(sK)
|
| 677 |
+
tdPsV = smem_thr_copy_KV.partition_S(sV)
|
| 678 |
+
tdVsPt = smem_thr_copy_PdSt.partition_S(sPt)
|
| 679 |
+
tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt)
|
| 680 |
+
tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt)
|
| 681 |
+
tdKsQt = smem_thr_copy_QdOt.partition_S(sQt)
|
| 682 |
+
tdQsdS = smem_thr_copy_dS.partition_S(sdS)
|
| 683 |
+
tdQsKt = smem_thr_copy_Kt.partition_S(sKt)
|
| 684 |
+
tPsP = r2s_thr_copy_PdS.partition_D(sP)
|
| 685 |
+
tdSsdS = r2s_thr_copy_PdS.partition_D(sdS)
|
| 686 |
+
|
| 687 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 688 |
+
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
|
| 689 |
+
# of tile_shape
|
| 690 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 691 |
+
# Construct identity layout for KV
|
| 692 |
+
cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
|
| 693 |
+
tQcQ = gmem_thr_copy_QK.partition_S(cQ)
|
| 694 |
+
t0QcQ = gmem_thr_copy_QK.get_slice(0).partition_S(cQ)
|
| 695 |
+
if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded):
|
| 696 |
+
tdOcdO = tQcQ
|
| 697 |
+
t0dOcdO = t0QcQ
|
| 698 |
+
else:
|
| 699 |
+
cdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
|
| 700 |
+
tdOcdO = gmem_thr_copy_VdO.partition_S(cdO)
|
| 701 |
+
t0dOcdO = gmem_thr_copy_VdO.get_slice(0).partition_S(cdO)
|
| 702 |
+
cLSE = cute.make_identity_tensor((self.m_block_size,))
|
| 703 |
+
tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE)
|
| 704 |
+
|
| 705 |
+
# Allocate predicate tensors for m and n, here we only allocate the tile of k, and
|
| 706 |
+
# use "if" on the mn dimension.
|
| 707 |
+
# This is to reduce register pressure and gets 2-3% performance gain.
|
| 708 |
+
|
| 709 |
+
d_head = mQ.shape[cute.rank(mQ) - 1]
|
| 710 |
+
d_head_v = mdO.shape[cute.rank(mdO) - 1]
|
| 711 |
+
|
| 712 |
+
tQpQ = utils.predicate_k(tQcQ, limit=d_head)
|
| 713 |
+
if cutlass.const_expr(self.same_hdim_kv):
|
| 714 |
+
tdOpdO = tQpQ
|
| 715 |
+
else:
|
| 716 |
+
tdOpdO = utils.predicate_k(tdOcdO, limit=d_head_v)
|
| 717 |
+
|
| 718 |
+
# group parameters for compute_one_m_block
|
| 719 |
+
mma_params = SimpleNamespace(
|
| 720 |
+
thr_mma_sdp=thr_mma_sdp, thr_mma_dkv=thr_mma_dkv, thr_mma_dq=thr_mma_dq,
|
| 721 |
+
tSrQ=tSrQ, tSrK=tSrK, tdPrdO=tdPrdO, tdPrV=tdPrV,
|
| 722 |
+
tdVrP=tdVrP, tdVrdO=tdVrdO, tdKrdS=tdKrdS, tdKrQ=tdKrQ,
|
| 723 |
+
tdQrdS=tdQrdS, tdQrK=tdQrK,
|
| 724 |
+
acc_dK=acc_dK, acc_dV=acc_dV,
|
| 725 |
+
)
|
| 726 |
+
smem_copy_params = SimpleNamespace(
|
| 727 |
+
smem_thr_copy_QdO=smem_thr_copy_QdO,
|
| 728 |
+
smem_thr_copy_KV=smem_thr_copy_KV,
|
| 729 |
+
smem_thr_copy_PdSt=smem_thr_copy_PdSt,
|
| 730 |
+
smem_thr_copy_QdOt=smem_thr_copy_QdOt,
|
| 731 |
+
smem_thr_copy_dS=smem_thr_copy_dS,
|
| 732 |
+
smem_thr_copy_Kt=smem_thr_copy_Kt,
|
| 733 |
+
r2s_thr_copy_PdS=r2s_thr_copy_PdS,
|
| 734 |
+
tSsQ=tSsQ, tSsK=tSsK, tdPsdO=tdPsdO, tdPsV=tdPsV,
|
| 735 |
+
tSsLSEMma=tSsLSEMma, tSsdPsumMma=tSsdPsumMma,
|
| 736 |
+
tPsP=tPsP, tdSsdS=tdSsdS,
|
| 737 |
+
tdVsPt=tdVsPt, tdVsdOt=tdVsdOt, tdKsdSt=tdKsdSt, tdKsQt=tdKsQt,
|
| 738 |
+
tdQsdS=tdQsdS, tdQsKt=tdQsKt,
|
| 739 |
+
)
|
| 740 |
+
gmem_copy_params = SimpleNamespace(
|
| 741 |
+
gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum
|
| 742 |
+
)
|
| 743 |
+
load_Q_LSE = partial(
|
| 744 |
+
self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE,
|
| 745 |
+
tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ,
|
| 746 |
+
tLSEgLSE, tLSEsLSE, tLSEcLSE, seqlen=seqlen.seqlen_q
|
| 747 |
+
)
|
| 748 |
+
load_dO_dPsum = partial(
|
| 749 |
+
self.load_dO_dPsum, gmem_tiled_copy_VdO, gmem_tiled_copy_LSE,
|
| 750 |
+
tdOgdO, tdOsdO, tdOcdO, t0dOcdO, tdOpdO,
|
| 751 |
+
tLSEgdPsum, tLSEsdPsum, tLSEcLSE, seqlen=seqlen.seqlen_q
|
| 752 |
+
)
|
| 753 |
+
compute_one_m_block = partial(
|
| 754 |
+
self.compute_one_m_block, mma_params=mma_params,
|
| 755 |
+
smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params,
|
| 756 |
+
load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum,
|
| 757 |
+
m_block_max=m_block_max,
|
| 758 |
+
softmax_scale_log2=softmax_scale_log2,
|
| 759 |
+
)
|
| 760 |
+
|
| 761 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 762 |
+
# Prologue
|
| 763 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 764 |
+
# Start async loads of the last mn-tile, where we take care of the mn residue
|
| 765 |
+
self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k,
|
| 766 |
+
headdim=d_head_v)
|
| 767 |
+
if cutlass.const_expr(self.V_in_regs):
|
| 768 |
+
cute.arch.cp_async_commit_group()
|
| 769 |
+
self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k,
|
| 770 |
+
headdim=d_head)
|
| 771 |
+
cute.arch.cp_async_commit_group()
|
| 772 |
+
|
| 773 |
+
if cutlass.const_expr(self.V_in_regs):
|
| 774 |
+
cute.arch.cp_async_wait_group(1)
|
| 775 |
+
cute.arch.barrier()
|
| 776 |
+
tdPrV_copy_view = smem_thr_copy_KV.retile(tdPrV)
|
| 777 |
+
cute.copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view)
|
| 778 |
+
# Sync to avoid loading Q to smem_q, which overlaps with smem_v
|
| 779 |
+
cute.arch.barrier()
|
| 780 |
+
|
| 781 |
+
m_block = m_block_min
|
| 782 |
+
assert self.num_stages_Q >= self.num_stages_dO
|
| 783 |
+
for stage in cutlass.range_constexpr(self.num_stages_Q):
|
| 784 |
+
if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1):
|
| 785 |
+
if stage == 0 or m_block + stage < m_block_max:
|
| 786 |
+
load_Q_LSE(m_block + stage, smem_pipe_write_q=stage)
|
| 787 |
+
cute.arch.cp_async_commit_group()
|
| 788 |
+
if cutlass.const_expr(stage < self.num_stages_dO):
|
| 789 |
+
if stage == 0 or m_block + stage < m_block_max:
|
| 790 |
+
load_dO_dPsum(m_block + stage, smem_pipe_write_q=stage)
|
| 791 |
+
cute.arch.cp_async_commit_group()
|
| 792 |
+
|
| 793 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 794 |
+
# Mainloop
|
| 795 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 796 |
+
# Start processing of the first n-block.
|
| 797 |
+
mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k)
|
| 798 |
+
mask_fn = partial(
|
| 799 |
+
mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp,
|
| 800 |
+
mask_seqlen=True, mask_causal=self.is_causal
|
| 801 |
+
)
|
| 802 |
+
smem_pipe_read_q = cutlass.Int32(0)
|
| 803 |
+
smem_pipe_read_do = cutlass.Int32(0)
|
| 804 |
+
smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1)
|
| 805 |
+
smem_pipe_write_do = cutlass.Int32(0)
|
| 806 |
+
for m_tile in cutlass.range(m_block_min, m_block_max, unroll=1):
|
| 807 |
+
compute_one_m_block(
|
| 808 |
+
m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do,
|
| 809 |
+
mask_fn=mask_fn,
|
| 810 |
+
)
|
| 811 |
+
smem_pipe_read_q = self.advance_pipeline(smem_pipe_read_q, self.num_stages_Q)
|
| 812 |
+
smem_pipe_read_do = self.advance_pipeline(smem_pipe_read_do, self.num_stages_dO)
|
| 813 |
+
smem_pipe_write_q = self.advance_pipeline(smem_pipe_write_q, self.num_stages_Q)
|
| 814 |
+
smem_pipe_write_do = self.advance_pipeline(smem_pipe_write_do, self.num_stages_dO)
|
| 815 |
+
|
| 816 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 817 |
+
# Epilogue
|
| 818 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 819 |
+
# If GQA, we scale dK in the postprocessing kernel instead
|
| 820 |
+
if cutlass.const_expr(self.qhead_per_kvhead == 1):
|
| 821 |
+
acc_dK.store(acc_dK.load() * softmax_scale)
|
| 822 |
+
# reuse sK and sV data iterator
|
| 823 |
+
sdK = cute.make_tensor(sK.iterator, sK_layout)
|
| 824 |
+
sdV = cute.make_tensor(sV.iterator, sV_layout)
|
| 825 |
+
self.epilogue(
|
| 826 |
+
acc_dK, acc_dV, mdK, mdV, sdK, sdV,
|
| 827 |
+
gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv,
|
| 828 |
+
tidx, n_block, head_idx, batch_idx, seqlen, d_head, d_head_v
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
@cute.jit
def compute_one_m_block(
    self,
    m_block: cutlass.Int32,
    smem_pipe_read_q: cutlass.Int32,
    smem_pipe_read_do: cutlass.Int32,
    smem_pipe_write_q: cutlass.Int32,
    smem_pipe_write_do: cutlass.Int32,
    mma_params: SimpleNamespace,
    smem_copy_params: SimpleNamespace,
    gmem_copy_params: SimpleNamespace,
    load_Q_LSE: Callable,
    load_dO_dPsum: Callable,
    m_block_max: cutlass.Int32,
    softmax_scale_log2: cutlass.Float32,
    mask_fn: Optional[Callable] = None,
):
    """Process one m-block of the backward pass against the current n-block.

    Pipeline of gemms, in order:
      1. S  = Q @ K^T (recompute scores), then P = exp2(S * softmax_scale_log2 - LSE)
      2. dP = dO @ V^T, then dS = P * (dP - dPsum)
      3. dV += P^T @ dO   (accumulated in mma_params.acc_dV)
      4. dQ  = dS @ K     (atomically added into the gmem dQaccum buffer)
      5. dK += dS^T @ Q   (accumulated in mma_params.acc_dK)
    Gmem cp.async loads for the next m-block's Q/LSE and dO/dPsum are
    overlapped with the gemms via the hook functions below.

    Fixes vs. previous revision:
    - In the Mma_dKV_is_RS path the dS fragment for the dK gemm was assigned
      to ``tdVrP`` (clobbering the P fragment) and left ``tdKrdS`` undefined;
      it is now correctly assigned to ``tdKrdS``.
    - The comment before the dV gemm mislabeled it "MMA dK".
    """
    def load_Q_next():
        # With multi-stage Q we prefetch (num_stages_Q - 1) blocks ahead;
        # with a single stage we prefetch the immediately next block.
        m_block_next = m_block + (self.num_stages_Q - 1 if cutlass.const_expr(self.num_stages_Q > 1) else 1)
        if m_block_next < m_block_max:
            load_Q_LSE(m_block_next, smem_pipe_write_q)
        # Commit unconditionally (an empty group is harmless) so the
        # outstanding-group count stays statically predictable for wait_group.
        cute.arch.cp_async_commit_group()

    def load_dO_next():
        if m_block + self.num_stages_dO < m_block_max:
            load_dO_dPsum(m_block + self.num_stages_dO, smem_pipe_write_do)
        cute.arch.cp_async_commit_group()

    # MMA S: recompute attention scores S = Q @ K^T for this m-block.
    acc_shape_SdP = mma_params.thr_mma_sdp.partition_shape_C(
        (self.m_block_size, self.n_block_size) if cutlass.const_expr(not self.SdP_swapAB) else (self.n_block_size, self.m_block_size)
    )
    acc_S = cute.make_fragment(acc_shape_SdP, cutlass.Float32)
    acc_S.fill(0.0)
    # Wait until this stage's Q/LSE tiles have landed in smem
    # (keep 1 group in flight when Q is multi-buffered).
    cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_Q > 1) else 0)
    cute.arch.barrier()
    sm80_utils.gemm(
        mma_params.thr_mma_sdp, acc_S, mma_params.tSrQ, mma_params.tSrK,
        smem_copy_params.tSsQ[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
        smem_copy_params.tSsK,
        smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,
        swap_AB=self.SdP_swapAB,
    )
    # Pull the per-row LSE from smem into registers.
    tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0])
    cute.autovec_copy(
        smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE
    )
    if cutlass.const_expr(mask_fn is not None):
        mask_fn(acc_S, m_block=m_block)
    acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
    bidx = 0  # block index used only by the debug prints below
    # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)
    # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE)
    assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE)
    # P = exp2(S * softmax_scale_log2 - LSE), row-wise.
    for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True):
        acc_S_mn[r, None].store(cute.math.exp2(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r], fastmath=True))
    # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)

    # MMA dP: dP = dO @ V^T.
    acc_dP = cute.make_fragment(acc_shape_SdP, cutlass.Float32)
    acc_dP.fill(0.0)
    cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_dO > 1) else 0)
    cute.arch.barrier()
    sm80_utils.gemm(
        mma_params.thr_mma_sdp, acc_dP, mma_params.tdPrdO, mma_params.tdPrV,
        smem_copy_params.tdPsdO[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0],
        smem_copy_params.tdPsV,
        smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,
        hook_fn=load_Q_next if cutlass.const_expr(self.num_stages_Q > 1) else None,
        swap_AB=self.SdP_swapAB,
    )
    tLSErdPsum = cute.make_fragment_like(smem_copy_params.tSsdPsumMma[None, 0])
    cute.autovec_copy(
        smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], tLSErdPsum
    )
    acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP)
    # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
    assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum)
    # dS = P * (dP - dPsum), row-wise.
    for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True):
        acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]))
    # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
    # Down-convert P to the compute dtype; stage through smem unless the
    # dKV MMA consumes operand A straight from registers (Mma_dKV_is_RS).
    rP = cute.make_fragment_like(acc_S, self.dtype)
    rP.store(acc_S.load().to(self.dtype))
    if cutlass.const_expr(not self.Mma_dKV_is_RS):
        tPrP = smem_copy_params.r2s_thr_copy_PdS.retile(rP)  # ((Atom,AtomNum), MMA_N, MMA_N)
        cute.copy(smem_copy_params.r2s_thr_copy_PdS, tPrP, smem_copy_params.tPsP)
    rdS = cute.make_fragment_like(acc_dP, self.dtype)
    rdS.store(acc_dP.load().to(self.dtype))
    if cutlass.const_expr(not self.Mma_dKV_is_RS):
        cute.arch.barrier()  # Make sure P is written
    # For hdim 64, It's faster to write to smem_dS first before the dV gemm
    if cutlass.const_expr(not self.Mma_dKV_is_RS):
        tdSrdS = smem_copy_params.r2s_thr_copy_PdS.retile(rdS)
        cute.copy(smem_copy_params.r2s_thr_copy_PdS, tdSrdS, smem_copy_params.tdSsdS)
    if cutlass.const_expr(self.Mma_dKV_is_RS):
        tdVrP = layout_utils.reshape_acc_to_frgA(rP)
    else:
        tdVrP = mma_params.tdVrP

    # MMA dV: acc_dV += P^T @ dO.  (Comment previously mislabeled "MMA dK".)
    sm80_utils.gemm(
        mma_params.thr_mma_dkv, mma_params.acc_dV, tdVrP, mma_params.tdVrdO,
        smem_copy_params.tdVsPt,
        smem_copy_params.tdVsdOt[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0],
        smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt,
        A_in_regs=self.Mma_dKV_is_RS,
        swap_AB=self.dKV_swapAB,
    )
    # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(mma_params.acc_dV)
    cute.arch.barrier()  # Make sure dS is written

    # MMA dQ: dQ = dS @ K, accumulated into gmem via fp32 atomics.
    def dQ_mma(hook_fn):
        acc_shape_dQ = mma_params.thr_mma_dq.partition_shape_C(
            (self.m_block_size, self.head_dim_padded) if cutlass.const_expr(not self.dQ_swapAB) else (self.head_dim_padded, self.m_block_size)
        )
        acc_dQ = cute.make_fragment(acc_shape_dQ, cutlass.Float32)
        acc_dQ.fill(0.0)
        sm80_utils.gemm(
            mma_params.thr_mma_dq, acc_dQ, mma_params.tdQrdS, mma_params.tdQrK,
            smem_copy_params.tdQsdS, smem_copy_params.tdQsKt,
            smem_copy_params.smem_thr_copy_dS, smem_copy_params.smem_thr_copy_Kt,
            swap_AB=self.dQ_swapAB,
            hook_fn=hook_fn
        )
        # ((1, 1), num_elements)
        acc_dQ_atomic = gmem_copy_params.gmem_thr_copy_dQaccum.retile(acc_dQ)
        tdQgdQaccum_atomic = gmem_copy_params.tdQgdQaccum[None, None, m_block]
        assert cute.size(acc_dQ_atomic) == cute.size(tdQgdQaccum_atomic)
        for i in cutlass.range(cute.size(acc_dQ_atomic), unroll_full=True):
            utils.atomic_add_fp32(acc_dQ_atomic[i], utils.elem_pointer(tdQgdQaccum_atomic, i))
        # utils.atomic_add_fp32(acc_dQ[i], tdQgdQaccum_atomic.iterator + i * tdQgdQaccum_atomic.stride[1])
        # if cute.arch.thread_idx()[0] == 64 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dQ)

    # If num_stages_Q == 1, we want to do Mma_dK first so we can start loading Q for the next iteration
    if cutlass.const_expr(self.num_stages_Q > 1):
        dQ_mma(load_dO_next)

    # MMA dK: acc_dK += dS^T @ Q.
    if cutlass.const_expr(self.Mma_dKV_is_RS):
        # FIX: was `tdVrP = ...`, which clobbered the P fragment and left
        # tdKrdS undefined for the gemm below.
        tdKrdS = layout_utils.reshape_acc_to_frgA(rdS)
    else:
        tdKrdS = mma_params.tdKrdS
    sm80_utils.gemm(
        mma_params.thr_mma_dkv, mma_params.acc_dK, tdKrdS, mma_params.tdKrQ,
        smem_copy_params.tdKsdSt,
        smem_copy_params.tdKsQt[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
        smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt,
        A_in_regs=self.Mma_dKV_is_RS,
        swap_AB=self.dKV_swapAB,
        hook_fn=load_dO_next if cutlass.const_expr(self.num_stages_Q == 1) else None,
    )
    # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(mma_params.acc_dK)
    if cutlass.const_expr(self.num_stages_Q == 1):
        cute.arch.barrier()
        dQ_mma(load_Q_next)
|
| 987 |
+
|
| 988 |
+
@cute.jit
def epilogue(
    self,
    acc_dK: cute.Tensor,        # fp32 dK accumulator fragment for this n-block
    acc_dV: cute.Tensor,        # fp32 dV accumulator fragment for this n-block
    mdK: cute.Tensor,           # gmem dK (or flat dK accumulation buffer when qhead_per_kvhead > 1)
    mdV: cute.Tensor,           # gmem dV (or flat dV accumulation buffer when qhead_per_kvhead > 1)
    sdK: cute.Tensor,           # smem staging tile for dK (caller builds it on sK's iterator)
    sdV: cute.Tensor,           # smem staging tile for dV (caller builds it on sV's iterator)
    gmem_tiled_copy_dK: cute.TiledCopy,
    gmem_tiled_copy_dV: cute.TiledCopy,
    tiled_mma: cute.TiledMma,
    tidx: cutlass.Int32,        # thread index within the CTA
    n_block: cutlass.Int32,     # n-block this CTA processed
    num_head: cutlass.Int32,    # NOTE(review): used as a head *index*; misleadingly named
    batch_size: cutlass.Int32,  # NOTE(review): used as a batch *index*; misleadingly named
    seqlen: SeqlenInfoQK,
    d_head: cutlass.Int32,      # actual (unpadded) K head dimension
    d_head_v: cutlass.Int32     # actual (unpadded) V head dimension
):
    """Write the accumulated dK/dV for this n-block back to global memory.

    Two paths:
    - qhead_per_kvhead == 1: down-convert the fp32 accumulators to
      ``self.dtype``, stage them through shared memory (sdK/sdV) so the
      gmem stores can be widely vectorized, then store with seqlen/headdim
      predication.
    - qhead_per_kvhead > 1: atomically add the fp32 accumulators directly
      into flat gmem accumulation buffers (several q-heads reduce into one
      kv-head); a postprocessing kernel handles scaling/conversion.
    """
    # Down-convert accumulators to the output dtype in registers.
    rdV = cute.make_fragment_like(acc_dV, self.dtype)
    rdV.store(acc_dV.load().to(self.dtype))
    rdK = cute.make_fragment_like(acc_dK, self.dtype)
    rdK.store(acc_dK.load().to(self.dtype))
    gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx)
    gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx)

    batch_idx = batch_size
    # Map the q-head index to its kv-head index (unless pack_gqa already did).
    head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head

    if cutlass.const_expr(self.qhead_per_kvhead == 1):
        # Make sure all threads have finished reading K and V, otherwise we get racy dQ
        # because smem_q could be changed.
        cute.arch.barrier()
        # smem copy atom for dKV (two elements per copy for wider stores)
        smem_copy_atom_dKV = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width
        )
        smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma).get_slice(tidx)
        taccdVrdV = smem_thr_copy_dKV.retile(rdV)
        taccdKrdK = smem_thr_copy_dKV.retile(rdK)
        taccdVsdV = smem_thr_copy_dKV.partition_D(sdV)
        taccdKsdK = smem_thr_copy_dKV.partition_D(sdK)
        # copy acc O from rmem to smem with the smem copy atom
        cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV)
        cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK)


        # Select this (batch, kv-head)'s dK/dV slice; with varlen, shift by the
        # cumulative-sequence offset instead of indexing by batch.
        if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
            mdK_cur, mdV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mdK, mdV)]
        else:
            mdK_cur, mdV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mdK, mdV)]

        blkdK_shape = (self.n_block_size, self.head_dim_padded)
        blkdV_shape = (self.n_block_size, self.head_dim_v_padded)
        gdK = cute.local_tile(mdK_cur, blkdK_shape, (n_block, 0))
        gdV = cute.local_tile(mdV_cur, blkdV_shape, (n_block, 0))
        tdKsdK = gmem_thr_copy_dK.partition_S(sdK)
        tdKgdK = gmem_thr_copy_dK.partition_D(gdK)
        tdVsdV = gmem_thr_copy_dV.partition_S(sdV)
        tdVgdV = gmem_thr_copy_dV.partition_D(gdV)
        tdKrdK = cute.make_fragment_like(tdKgdK, self.dtype)
        tdVrdV = cute.make_fragment_like(tdVgdV, self.dtype)
        # sync before all smem stores are done.
        cute.arch.barrier()
        # load acc dK and dV from smem to rmem for wider vectorization
        # Need to check OOB when reading from smem if kBlockN isn't evenly tiled
        # TODO
        cute.autovec_copy(tdKsdK, tdKrdK)
        cute.autovec_copy(tdVsdV, tdVrdV)

        # Coordinate tensors + per-column predicates for bounds-checking the
        # gmem stores (dV reuses dK's when the padded head dims match).
        cdK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded))
        tdKcdK = gmem_thr_copy_dK.partition_S(cdK)
        t0dKcdK = gmem_tiled_copy_dK.get_slice(0).partition_S(cdK)
        if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded):
            tdVcdV = tdKcdK
            t0dVcdV = t0dKcdK
        else:
            cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded))
            tdVcdV = gmem_thr_copy_dV.partition_S(cdV)
            t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV)
        tdKpdK = utils.predicate_k(tdKcdK, limit=d_head)
        if cutlass.const_expr(self.same_hdim_kv):
            tdVpdV = tdKpdK
        else:
            tdVpdV = utils.predicate_k(tdVcdV, limit=d_head_v)
        # copy acc dK and acc_dV from rmem to gmem
        # Row bound uses thread-0 coordinates (compile-time known) with the
        # per-thread offset folded into the limit.
        for rest_m in cutlass.range_constexpr(cute.size(tdKrdK.shape[1])):
            if t0dKcdK[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdKcdK[0][0]:
                cute.copy(
                    gmem_tiled_copy_dK,
                    tdKrdK[None, rest_m, None],
                    tdKgdK[None, rest_m, None],
                    pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None,
                )
        for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])):
            if t0dVcdV[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdVcdV[0][0]:
                cute.copy(
                    gmem_tiled_copy_dV,
                    tdVrdV[None, rest_m, None],
                    tdVgdV[None, rest_m, None],
                    pred=tdVpdV[None, rest_m, None] if cutlass.const_expr(self.check_hdim_v_oob) else None,
                )

    else:  # qhead_per_kvhead > 1, do atomic add
        # For Sm90, we need to sync to avoid racy writes to smem_q
        # For Sm80, we don't need to sync since we're not touching smem
        # NOTE(review): redundant recomputation — identical to head_idx_kv above.
        head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head

        if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
            mdK_cur, mdV_cur = [t[batch_idx, head_idx_kv, None] for t in (mdK, mdV)]
        else:
            # Varlen path: the accum buffers appear 1-D per head; presumably each
            # batch gets n_block_size rows of padding — TODO confirm against the
            # buffer allocation in the interface code.
            padded_offset_k = seqlen.offset_k + batch_idx * self.n_block_size
            mdK_cur = cute.domain_offset((padded_offset_k * self.head_dim_padded,), mdK[head_idx_kv, None])
            mdV_cur = cute.domain_offset((padded_offset_k * self.head_dim_v_padded,), mdV[head_idx_kv, None])

        gdV = cute.local_tile(mdV_cur, (self.n_block_size * self.head_dim_v_padded,), (n_block,))
        gdK = cute.local_tile(mdK_cur, (self.n_block_size * self.head_dim_padded,), (n_block,))
        tdVgdVaccum = gmem_thr_copy_dV.partition_S(gdV)
        tdKgdKaccum = gmem_thr_copy_dK.partition_S(gdK)
        acc_dV_atomic = gmem_thr_copy_dV.retile(acc_dV)
        acc_dK_atomic = gmem_thr_copy_dK.retile(acc_dK)
        assert cute.size(acc_dV_atomic) == cute.size(tdVgdVaccum)
        assert cute.size(acc_dK_atomic) == cute.size(tdKgdKaccum)
        # Element-wise fp32 atomic accumulation into gmem.
        for i in cutlass.range(cute.size(acc_dV_atomic), unroll_full=True):
            utils.atomic_add_fp32(acc_dV_atomic[i], utils.elem_pointer(tdVgdVaccum, i))
        for i in cutlass.range(cute.size(acc_dK_atomic), unroll_full=True):
            utils.atomic_add_fp32(acc_dK_atomic[i], utils.elem_pointer(tdKgdKaccum, i))
|
| 1116 |
+
|
| 1117 |
+
@cute.jit
def advance_pipeline(self, pipeline_index, num_stages: cutlass.Constexpr):
    """Advance a circular smem-pipeline stage index by one, wrapping to 0.

    ``num_stages`` is a compile-time constant; for a single-stage pipeline
    this always returns 0.
    """
    return pipeline_index + 1 if pipeline_index < num_stages - 1 else 0
|
| 1120 |
+
|
| 1121 |
+
@cute.jit
def load_K(
    self,
    gmem_thr_copy: cute.TiledCopy,
    tKgK: cute.Tensor,      # per-thread gmem source partition of the K tile
    tKsK: cute.Tensor,      # per-thread smem destination partition
    block: cutlass.Int32,   # n-block index to load
    seqlen: cutlass.Int32,  # actual (unpadded) key sequence length
    headdim: cutlass.Int32, # actual (unpadded) K head dimension
):
    """Issue predicated gmem->smem copies of one K n-block.

    Rows past ``seqlen`` and columns past ``headdim`` are masked out by a
    per-element predicate so no out-of-bounds gmem is touched. (The caller
    commits these as a cp.async group.)
    """
    # Identity tensor gives each tile element its (row, col) coordinate.
    cK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded))
    tKcK = gmem_thr_copy.partition_S(cK)
    # Thread-0 slice: its coordinates are compile-time constants, used for the
    # cheaper row-bound comparison below.
    t0KcK = gmem_thr_copy.get_slice(0).partition_S(cK)
    # Column (head-dim) predicate, shared by all rows of this thread.
    tKpK = utils.predicate_k(tKcK, limit=headdim)
    for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
        # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
        if self.is_even_n_smem_k or n < cute.size(tKsK.shape[1]) - 1 or tKcK[0, n, 0][0] < self.n_block_size:
            # Instead of using tKcK, we using t0KcK and subtract the offset from the limit
            # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time.
            predicate_n = t0KcK[0, n, 0][0] < seqlen - block * self.n_block_size - tKcK[0][0]
            predicate = cute.make_fragment_like(tKpK[None, 0, None])
            for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                    # Combine head-dim masking (only when hdim can be ragged) with the row bound.
                    predicate[i, k] = (tKpK[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n
            cute.copy(
                gmem_thr_copy, tKgK[None, n, None], tKsK[None, n, None], pred=predicate,
            )
    # We need to clear the sK smem tiles since we'll use sKt for mma_dq
    # NOTE(review): no clearing code follows the comment above — confirm whether
    # out-of-bounds sK rows are zeroed elsewhere before sKt feeds mma_dq.
|
| 1149 |
+
|
| 1150 |
+
@cute.jit
def load_V(
    self,
    gmem_thr_copy: cute.TiledCopy,
    tVgV: cute.Tensor,      # per-thread gmem source partition of the V tile
    tVsV: cute.Tensor,      # per-thread smem destination partition
    block: cutlass.Int32,   # n-block index to load
    seqlen: cutlass.Int32,  # actual (unpadded) key sequence length
    headdim: cutlass.Int32, # actual (unpadded) V head dimension (d_head_v at the call site)
):
    """Issue predicated gmem->smem copies of one V n-block.

    Rows past ``seqlen`` and columns past ``headdim`` are masked out by a
    per-element predicate so no out-of-bounds gmem is touched. (The caller
    commits these as a cp.async group.)

    Fix vs. previous revision: the head-dim predicate was gated on
    ``self.check_hdim_oob`` (the Q/K head-dim flag) even though V's head
    dimension is the V head dim — the caller passes ``headdim=d_head_v``.
    It is now gated on ``self.check_hdim_v_oob``, consistent with the dV
    store path in the epilogue.
    """
    # Identity tensor gives each tile element its (row, col) coordinate.
    cV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded))
    tVcV = gmem_thr_copy.partition_S(cV)
    # Thread-0 slice: coordinates known at compile time, used for the row bound.
    t0VcV = gmem_thr_copy.get_slice(0).partition_S(cV)
    # Column (V head-dim) predicate, shared by all rows of this thread.
    tVpV = utils.predicate_k(tVcV, limit=headdim)
    for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):
        # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
        if self.is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size:
            # Instead of using tVcV, we use t0VcV and subtract the offset from the limit
            # (seqlen - block * kBlockN), because the entries of t0VcV are known at compile time.
            predicate_n = t0VcV[0, n, 0][0] < seqlen - block * self.n_block_size - tVcV[0][0]
            predicate = cute.make_fragment_like(tVpV[None, 0, None])
            for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                    # FIX: gate on check_hdim_v_oob (V head dim), not check_hdim_oob.
                    predicate[i, k] = (tVpV[i, n, k] if cutlass.const_expr(self.check_hdim_v_oob) else True) and predicate_n
            cute.copy(
                gmem_thr_copy, tVgV[None, n, None], tVsV[None, n, None], pred=predicate,
            )
|
| 1177 |
+
|
| 1178 |
+
@cute.jit
def load_Q_LSE(
    self,
    gmem_tiled_copy_Q: cute.TiledCopy,
    gmem_tiled_copy_LSE: cute.TiledCopy,
    tQgQ: cute.Tensor,      # per-thread gmem source partition of Q (m-blocked)
    tQsQ: cute.Tensor,      # per-thread smem destination partition (staged)
    tQcQ: cute.Tensor,      # this thread's coordinate partition
    t0QcQ: cute.Tensor,     # thread-0 coordinate partition (compile-time entries)
    tQpQ: cute.Tensor,      # head-dim predicate (limit = d_head)
    tLSEgLSE: cute.Tensor,  # per-thread gmem LSE partition
    tLSEsLSE: cute.Tensor,  # per-thread smem LSE partition (staged)
    tLSEcLSE: cute.Tensor,  # LSE coordinate partition
    block: cutlass.Int32,   # m-block index to load
    smem_pipe_write_q: cutlass.Int32,  # Q pipeline stage to write
    seqlen: cutlass.Int32,  # actual (unpadded) query sequence length
):
    """Issue predicated gmem->smem copies of one Q m-block and its LSE rows.

    Rows past ``seqlen`` are masked by a per-element predicate; columns past
    the true head dim are masked when ``check_hdim_oob`` is set. The smem
    stage written is ``smem_pipe_write_q`` when Q is multi-buffered, else 0.

    Fix vs. previous revision: the stage-select condition was written
    ``cutlass.const_expr(self.num_stages_Q) > 1`` (comparison outside the
    const_expr); it now reads ``cutlass.const_expr(self.num_stages_Q > 1)``,
    matching every other call site (including the LSE copy below).
    """
    for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
        # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
        if self.is_even_m_smem_q or m < cute.size(tQsQ.shape[1]) - 1 or tQcQ[0, m, 0][0] < self.m_block_size:
            # Instead of using tQcQ, we use t0QcQ and subtract the offset from the limit
            # (seqlen - block * kBlockM), because the entries of t0QcQ are known at compile time.
            predicate_m = t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0]
            predicate = cute.make_fragment_like(tQpQ[None, 0, None])
            for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                    predicate[i, k] = (tQpQ[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m
            cute.copy(
                gmem_tiled_copy_Q,
                tQgQ[None, m, None, block],
                tQsQ[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
                pred=predicate,
            )
    # NOTE(review): a comment here claimed sQ tiles must be cleared since sQt
    # feeds mma_dK, but no clearing code exists — confirm it happens elsewhere.
    # We made sure LSE length is padded so we read `kBlockM` elements so that all
    # elements in sLSE are filled. Without this we might have uninitialized sLSE values.
    for m in cutlass.range_constexpr(cute.size(tLSEsLSE.shape[1])):
        if tLSEcLSE[0, m][0] < self.m_block_size:
            cute.copy(
                gmem_tiled_copy_LSE,
                tLSEgLSE[None, m, block],
                tLSEsLSE[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
            )
|
| 1221 |
+
|
| 1222 |
+
@cute.jit
def load_dO_dPsum(
    self,
    gmem_tiled_copy_dO: cute.TiledCopy,
    gmem_tiled_copy_dPsum: cute.TiledCopy,
    tdOgdO: cute.Tensor,      # per-thread gmem source partition of dO (m-blocked)
    tdOsdO: cute.Tensor,      # per-thread smem destination partition (staged)
    tdOcdO: cute.Tensor,      # this thread's coordinate partition
    t0dOcdO: cute.Tensor,     # thread-0 coordinate partition (compile-time entries)
    tdOpdO: cute.Tensor,      # head-dim predicate (caller builds it with limit=d_head_v)
    tdPsumgdPsum: cute.Tensor,  # per-thread gmem dPsum partition
    tdPsumsdPsum: cute.Tensor,  # per-thread smem dPsum partition (staged)
    tdPsumcdPsum: cute.Tensor,  # dPsum coordinate partition
    block: cutlass.Int32,     # m-block index to load
    smem_pipe_write_q: cutlass.Int32,  # stage to write; indexes the dO pipeline (num_stages_dO) despite the name
    seqlen: cutlass.Int32,    # actual (unpadded) query sequence length
):
    """Issue predicated gmem->smem copies of one dO m-block and its dPsum rows.

    Mirrors load_Q_LSE but targets the dO/dPsum staging buffers and the
    dO pipeline depth (num_stages_dO).
    """
    for m in cutlass.range_constexpr(cute.size(tdOsdO.shape[1])):
        # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
        if self.is_even_m_smem_do or m < cute.size(tdOsdO.shape[1]) - 1 or tdOcdO[0, m, 0][0] < self.m_block_size:
            # Instead of using tdOcdO, we using t0dOcdO and subtract the offset from the limit
            # (seqlen - block * kBlockM). This is because the entries of t0dOcdO are known at compile time.
            predicate_m = t0dOcdO[0, m, 0][0] < seqlen - block * self.m_block_size - tdOcdO[0][0]
            predicate = cute.make_fragment_like(tdOpdO[None, 0, None])
            for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                    # NOTE(review): gate is check_hdim_oob although dO's head dim is the
                    # V head dim (tdOpdO is built with limit=d_head_v) — possibly should
                    # be check_hdim_v_oob; confirm against the Q/K vs V hdim flags.
                    predicate[i, k] = (tdOpdO[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m
            cute.copy(
                gmem_tiled_copy_dO,
                tdOgdO[None, m, None, block],
                tdOsdO[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0],
                pred=predicate,
            )
    # dPsum is padded so we can read a full `kBlockM` rows, leaving no
    # uninitialized smem entries. (Original comment referenced sQ/sLSE — it was
    # copied from load_Q_LSE; the same rationale applies to sdPsum here.)
    for m in cutlass.range_constexpr(cute.size(tdPsumgdPsum.shape[1])):
        if tdPsumcdPsum[0, m][0] < self.m_block_size:
            cute.copy(
                gmem_tiled_copy_dPsum,
                tdPsumgdPsum[None, m, block],
                tdPsumsdPsum[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0],
            )
|
build/torch-cuda/flash_bwd_postprocess.py
ADDED
|
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
+
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h
|
| 3 |
+
# from Cutlass C++ to Cute-DSL.
|
| 4 |
+
import math
|
| 5 |
+
from typing import Callable, Optional, Type, Literal
|
| 6 |
+
|
| 7 |
+
import cuda.bindings.driver as cuda
|
| 8 |
+
|
| 9 |
+
import cutlass
|
| 10 |
+
import cutlass.cute as cute
|
| 11 |
+
import cutlass.utils.hopper_helpers as sm90_utils_basic
|
| 12 |
+
import cutlass.utils.blackwell_helpers as sm100_utils_basic
|
| 13 |
+
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
|
| 14 |
+
from cutlass import Float32, const_expr
|
| 15 |
+
from cutlass.utils import LayoutEnum
|
| 16 |
+
|
| 17 |
+
from .quack import copy_utils
|
| 18 |
+
from .quack import layout_utils
|
| 19 |
+
from .quack import sm90_utils
|
| 20 |
+
|
| 21 |
+
from . import utils
|
| 22 |
+
from .cute_dsl_utils import assume_tensor_aligned
|
| 23 |
+
from . import ampere_helpers as sm80_utils
|
| 24 |
+
from .seqlen_info import SeqlenInfoQK
|
| 25 |
+
import cutlass.cute.nvgpu.tcgen05 as tcgen05
|
| 26 |
+
from .quack.cute_dsl_utils import ParamsBase
|
| 27 |
+
from .tile_scheduler import (
|
| 28 |
+
SingleTileScheduler,
|
| 29 |
+
SingleTileVarlenScheduler,
|
| 30 |
+
TileSchedulerArguments,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class FlashAttentionBackwardPostprocess:
    """Postprocess kernel for the FlashAttention backward pass.

    Reads the fp32 dQ accumulator (``mdQaccum``, written by the main backward
    kernel), scales it by ``scale``, converts it to fp16/bf16, and writes the
    final ``mdQ`` tensor.  Supports Ampere (sm80), Hopper (sm90) and Blackwell
    (sm100/sm110) code paths, including varlen (cu_seqlens) batches.
    """

    def __init__(
        self,
        dtype: Type[cutlass.Numeric],
        head_dim: int,
        arch: Literal[80, 90, 100],
        tile_m: int = 128,
        num_threads: int = 256,
        AtomLayoutMdQ: int = 1,
        dQ_swapAB: bool = False,
        use_2cta_instrs: bool = False,
        cluster_size: int = 1,  # for varlen offsets
    ):
        """
        :param dtype: output element type (Float16 or BFloat16)
        :type dtype: Type[cutlass.Numeric]
        :param head_dim: head dimension
        :type head_dim: int
        :param arch: target SM architecture (80, 90 or 100)
        :type arch: int
        :param tile_m: m block size
        :type tile_m: int
        :param num_threads: number of threads per CTA
        :type num_threads: int
        """
        self.dtype = dtype
        self.tile_m = tile_m
        assert arch // 10 in [8, 9, 10, 11], (
            "Only Ampere (8.x), Hopper (9.x), and Blackwell (10.x, 11.x) are supported"
        )
        self.arch = arch
        # padding head_dim to a multiple of 32 as k_block_size
        hdim_multiple_of = 32
        self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
        self.check_hdim_oob = head_dim != self.tile_hdim
        self.num_threads = num_threads
        self.AtomLayoutMdQ = AtomLayoutMdQ
        self.dQ_swapAB = dQ_swapAB
        # 2-CTA MMA instructions only make sense on Blackwell and for hdim != 64.
        self.use_2cta_instrs = use_2cta_instrs and arch == 100 and head_dim != 64
        self.cluster_size = cluster_size

    @staticmethod
    def can_implement(dtype, head_dim, tile_m, num_threads) -> bool:
        """Check if the kernel can be implemented with the given parameters.

        :param dtype: data type
        :type dtype: cutlass.Numeric
        :param head_dim: head dimension
        :type head_dim: int
        :param tile_m: m block size
        :type tile_m: int

        :return: True if the kernel can be implemented, False otherwise
        :rtype: bool
        """
        if dtype not in [cutlass.Float16, cutlass.BFloat16]:
            return False
        # Contiguous dims must be 16-byte aligned -> head_dim multiple of 8.
        if head_dim % 8 != 0:
            return False
        if num_threads % 32 != 0:
            return False
        return True

    def _get_tiled_mma(self):
        """Build the per-arch tiled MMA used only to derive accumulator layouts.

        The MMA itself is never issued by this kernel; its C-fragment
        partitioning must match the one used by the main backward kernel so
        the register<->smem copies line up.
        """
        if const_expr(self.arch == 80):
            num_mma_warps = self.num_threads // 32
            atom_layout_dQ = (
                (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1)
                if const_expr(not self.dQ_swapAB)
                else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1)
            )
            tiled_mma = cute.make_tiled_mma(
                warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)),
                atom_layout_dQ,
                permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16),
            )
        elif const_expr(self.arch == 90):
            num_mma_warp_groups = self.num_threads // 128
            atom_layout_dQ = (self.AtomLayoutMdQ, num_mma_warp_groups // self.AtomLayoutMdQ)
            tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
            tiled_mma = sm90_utils_basic.make_trivial_tiled_mma(
                self.dtype,
                self.dtype,
                warpgroup.OperandMajorMode.K,  # These don't matter, we only care about the accum
                warpgroup.OperandMajorMode.K,
                Float32,
                atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1])
                + (1,),
                tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1],
            )
        else:
            cta_group = tcgen05.CtaGroup.ONE
            tiled_mma = sm100_utils_basic.make_trivial_tiled_mma(
                self.dtype,
                tcgen05.OperandMajorMode.MN,  # dS_major_mode
                tcgen05.OperandMajorMode.MN,  # Kt_major_mode
                Float32,
                cta_group,
                (self.tile_m, self.tile_hdim),
            )
        if const_expr(self.arch in [80, 90]):
            assert self.num_threads == tiled_mma.size
        return tiled_mma

    def _setup_attributes(self):
        """Derive tiled-copy objects and smem layouts from the tiled MMA.

        Sets: ``g2s_tiled_copy_dQaccum``, ``s2r_tiled_copy_dQaccum``,
        ``sdQaccum_layout``, ``gmem_tiled_copy_dQ``, ``sdQ_layout``
        (and ``dQ_reduce_ncol`` on sm100).  Must run after ``self.tiled_mma``
        is assigned.
        """
        # ///////////////////////////////////////////////////////////////////////////////
        # GMEM Tiled copy:
        # ///////////////////////////////////////////////////////////////////////////////
        # Thread layouts for copies
        universal_copy_bits = 128
        async_copy_elems_accum = universal_copy_bits // Float32.width
        atom_async_copy_accum = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            Float32,
            num_bits_per_copy=universal_copy_bits,
        )
        # We don't do bound checking for the gmem -> smem load so we just assert here.
        assert (self.tile_m * self.tile_hdim // async_copy_elems_accum) % self.num_threads == 0
        self.g2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
            atom_async_copy_accum,
            cute.make_layout(self.num_threads),
            cute.make_layout(async_copy_elems_accum),
        )
        num_s2r_copy_elems = 1 if const_expr(self.arch == 80) else 4
        if const_expr(self.arch == 80):
            self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
                Float32, self.num_threads, num_s2r_copy_elems
            )
            self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim)
        elif const_expr(self.arch == 90):
            num_threads_per_warp_group = 128
            num_mma_warp_groups = self.num_threads // 128
            self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
                cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
                cute.make_layout((num_threads_per_warp_group, num_mma_warp_groups)),  # thr_layout
                cute.make_layout(128 // Float32.width),  # val_layout
            )
            self.sdQaccum_layout = cute.make_layout(
                (self.tile_m * self.tile_hdim // num_mma_warp_groups, num_mma_warp_groups)
            )
        else:
            # sm100: dQaccum is reduced in stages of dQ_reduce_ncol tmem columns.
            self.dQ_reduce_ncol = 32
            dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol
            assert self.num_threads == 128  # TODO: currently hard-coded
            self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
                Float32, self.num_threads, num_s2r_copy_elems
            )
            self.sdQaccum_layout = cute.make_layout(
                (self.tile_m * self.tile_hdim // dQaccum_reduce_stage, dQaccum_reduce_stage)
            )

        num_copy_elems = 128 // self.dtype.width
        threads_per_row = math.gcd(128, self.tile_hdim) // num_copy_elems
        self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d(
            self.dtype, threads_per_row, self.num_threads, num_copy_elems
        )
        # ///////////////////////////////////////////////////////////////////////////////
        # Shared memory layout: dQ
        # ///////////////////////////////////////////////////////////////////////////////
        # We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
        # then setting kBlockKSmem to 32 will cause "Static shape_div failure".
        # We want to treat it as 64 x 48, so kBlockKSmem should be 16.
        mma_shape_n = self.tiled_mma.get_tile_size(1)
        if const_expr(self.arch == 80):
            sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n)
            self.sdQ_layout = cute.tile_to_shape(
                sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1)
            )
        elif const_expr(self.arch == 90):
            self.sdQ_layout = sm90_utils.make_smem_layout(
                self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim)
            )
        else:
            # TODO: this is hard-coded for hdim 128
            self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi(
                self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim), 1
            )

    @cute.jit
    def __call__(
        self,
        mdQaccum: cute.Tensor,
        mdQ: cute.Tensor,
        scale: cutlass.Float32,
        mCuSeqlensQ: Optional[cute.Tensor],
        mSeqUsedQ: Optional[cute.Tensor],
        stream: cuda.CUstream,
    ):
        """Host-side entry: validate inputs, build layouts/copies, launch kernel.

        :param mdQaccum: fp32 dQ accumulator tensor produced by the bwd kernel
        :param mdQ: output dQ tensor (fp16/bf16)
        :param scale: scalar multiplied into dQ before down-conversion
        :param mCuSeqlensQ: cumulative seqlens for varlen, or None for fixed-len
        :param mSeqUsedQ: per-batch used seqlens, or None
        :param stream: CUDA stream to launch on
        :raises TypeError: if mdQ is not fp16/bf16 or mdQaccum is not fp32
        """
        # Get the data type and check if it is fp16 or bf16
        if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]):
            raise TypeError("Only Float16 or BFloat16 is supported")
        if const_expr(mdQaccum is not None):
            if const_expr(mdQaccum.element_type not in [cutlass.Float32]):
                raise TypeError("dQaccum tensor must be Float32")

        mdQaccum, mdQ = [assume_tensor_aligned(t) for t in (mdQaccum, mdQ)]

        self.tiled_mma = self._get_tiled_mma()
        self._setup_attributes()

        # The dQaccum staging buffer and the converted-dQ buffer reuse the same
        # smem allocation, so size it for the larger of the two.
        smem_size = max(
            cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout),
            cute.size_in_bytes(self.dtype, self.sdQ_layout),
        )

        # Varlen layout is (total_q, num_head, hdim); fixed-len is (batch, seqlen, num_head, hdim).
        if const_expr(mCuSeqlensQ is not None):
            TileScheduler = SingleTileVarlenScheduler
            num_head = mdQ.shape[1]
            num_batch = mCuSeqlensQ.shape[0] - 1
            num_block = cute.ceil_div(mdQ.shape[0], self.tile_m)
        else:
            TileScheduler = SingleTileScheduler
            num_head = mdQ.shape[2]
            num_batch = mdQ.shape[0]
            num_block = cute.ceil_div(mdQ.shape[1], self.tile_m)

        tile_sched_args = TileSchedulerArguments(
            num_block=num_block,
            num_head=num_head,
            num_batch=num_batch,
            num_splits=1,
            seqlen_k=0,
            headdim=mdQ.shape[2],
            headdim_v=0,
            total_q=mdQ.shape[0],
            tile_shape_mn=(self.tile_m, 1),
            mCuSeqlensQ=mCuSeqlensQ,
            mSeqUsedQ=mSeqUsedQ,
        )

        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

        # grid_dim: (m_block, num_head, batch_size)
        self.kernel(
            mdQaccum,
            mdQ,
            mCuSeqlensQ,
            mSeqUsedQ,
            scale,
            self.tiled_mma,
            self.dQ_swapAB,
            self.sdQaccum_layout,
            self.sdQ_layout,
            self.g2s_tiled_copy_dQaccum,
            self.s2r_tiled_copy_dQaccum,
            self.gmem_tiled_copy_dQ,
            tile_sched_params,
            TileScheduler,
        ).launch(
            grid=grid_dim,
            block=[self.num_threads, 1, 1],
            smem=smem_size,
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mdQaccum: cute.Tensor,
        mdQ: cute.Tensor,
        mCuSeqlensQ: Optional[cute.Tensor],
        mSeqUsedQ: Optional[cute.Tensor],
        scale: cutlass.Float32,
        tiled_mma: cute.TiledMma,
        dQ_swapAB: cutlass.Constexpr,
        sdQaccum_layout: cute.Layout,
        sdQ_layout: cute.ComposedLayout,
        g2s_tiled_copy_dQaccum: cute.TiledCopy,
        s2r_tiled_copy_dQaccum: cute.TiledCopy,
        gmem_tiled_copy_dQ: cute.TiledCopy,
        tile_sched_params: ParamsBase,
        TileScheduler: cutlass.Constexpr[Callable],
    ):
        """Device kernel: gmem dQaccum -> smem -> registers (scale + convert)
        -> smem -> gmem dQ, one (m_block, head, batch) tile per CTA.
        """
        # ///////////////////////////////////////////////////////////////////////////////
        # Get shared memory buffer
        # ///////////////////////////////////////////////////////////////////////////////
        smem = cutlass.utils.SmemAllocator()
        sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024)
        sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum)))
        # sdQ aliases the sdQaccum storage, reinterpreted as the output dtype.
        if const_expr(self.arch in [80, 90]):
            sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout)
        else:
            # extra stage dimension
            sdQ = cute.make_tensor(
                cute.recast_ptr(sdQaccum.iterator, sdQ_layout.inner, dtype=self.dtype),
                sdQ_layout.outer,
            )[None, None, 0]
        sdQt = layout_utils.transpose_view(sdQ)

        # Thread index, block index
        tidx, _, _ = cute.arch.thread_idx()

        tile_scheduler = TileScheduler.create(tile_sched_params)
        work_tile = tile_scheduler.initial_work_tile_info()

        m_block, head_idx, batch_idx, _ = work_tile.tile_idx

        if work_tile.is_valid_tile:
            # ///////////////////////////////////////////////////////////////////////////////
            # Get the appropriate tiles for this thread block.
            # ///////////////////////////////////////////////////////////////////////////////

            seqlen = SeqlenInfoQK.create(
                batch_idx,
                mdQ.shape[1],
                0,
                mCuSeqlensQ=mCuSeqlensQ,
                mCuSeqlensK=None,
                mSeqUsedQ=mSeqUsedQ,
                mSeqUsedK=None,
                tile_m=self.tile_m * self.cluster_size,
            )
            if const_expr(not seqlen.has_cu_seqlens_q):
                mdQ_cur = mdQ[batch_idx, None, head_idx, None]
                mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
                head_dim = mdQ.shape[3]
            else:
                # Varlen: dQaccum is stored flat per head with padded offsets.
                if cutlass.const_expr(self.arch >= 90):
                    padded_offset_q = seqlen.padded_offset_q
                else:
                    padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m
                mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None])
                mdQaccum_cur = cute.domain_offset(
                    (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None]
                )
                head_dim = mdQ.shape[2]

            # HACK: Compiler doesn't seem to recognize that padding
            # by padded_offset_q * self.tile_hdim keeps alignment
            # since statically divisible by 4

            mdQaccum_cur_ptr = cute.make_ptr(
                dtype=mdQaccum_cur.element_type,
                value=mdQaccum_cur.iterator.toint(),
                mem_space=mdQaccum_cur.iterator.memspace,
                assumed_align=mdQaccum.iterator.alignment,
            )
            mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout)

            gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,))
            gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))

            seqlen_q = seqlen.seqlen_q
            seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m)

            if const_expr(self.arch == 100 and self.use_2cta_instrs):
                # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ
                num_reduce_threads = self.num_threads
                thr_mma_dsk = tiled_mma.get_slice(tidx)
                dQacc_shape = thr_mma_dsk.partition_shape_C((self.tile_m, self.tile_hdim))
                tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape)
                tdQtdQ = cute.make_tensor(tdQtdQ.iterator, tdQtdQ.layout)

                tmem_load_atom = cute.make_copy_atom(
                    tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32
                )
                tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ)
                thr_tmem_ld = tiled_tmem_ld.get_slice(tidx)

                cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))
                tdQcdQ = thr_mma_dsk.partition_C(cdQ)
                tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout)
                tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor)

                tiled_copy_accum = s2r_tiled_copy_dQaccum
                g2s_thr_copy = tiled_copy_accum.get_slice(tidx)

                # S -> R
                tdQrdQ_fp32 = cute.make_fragment(tdQrdQ.shape, cutlass.Float32)
                tdQrdQ_s2r = cute.make_tensor(tdQrdQ_fp32.iterator, tdQrdQ_fp32.shape)

                smem_copy_atom = sm100_utils_basic.get_smem_store_op(
                    LayoutEnum.ROW_MAJOR, self.dtype, cutlass.Float32, tiled_tmem_ld
                )
                r2s_tiled_copy = cute.make_tiled_copy(
                    smem_copy_atom,
                    layout_tv=tiled_tmem_ld.layout_dst_tv_tiled,
                    tiler_mn=tiled_tmem_ld.tiler_mn,
                )
                tdQsdQ_r2s = thr_tmem_ld.partition_D(thr_mma_dsk.partition_C(sdQ))
                tdQrdQ_r2s = cute.make_fragment(tdQsdQ_r2s.shape, self.dtype)

                # The smem buffer is smaller than a full tile, so the reduction
                # is staged: two row groups of threads ping-pong over smem bufs.
                num_stages = cute.size(tdQrdQ_fp32, mode=[1])
                stage_stride = self.dQ_reduce_ncol
                row_groups = 2
                assert num_stages % row_groups == 0
                assert num_reduce_threads % row_groups == 0
                stage_groups = num_stages // row_groups
                threads_per_row_group = num_reduce_threads // row_groups
                stage_loads = tuple((row_group, row_group) for row_group in range(row_groups))
                stage_iters = tuple(
                    (row_group, row_group * threads_per_row_group)
                    for row_group in range(row_groups)
                )
                s2r_lane = tidx % threads_per_row_group
                s2r_buf = tidx // threads_per_row_group

                gdQaccum_layout_g2s = cute.make_layout(
                    shape=(self.tile_m * self.dQ_reduce_ncol, 1), stride=(1, 0)
                )
                sdQaccum_g2s = g2s_thr_copy.partition_D(sdQaccum)

                # G -> S
                for stage_group in cutlass.range_constexpr(stage_groups):
                    for stage_offset, smem_buf in stage_loads:
                        stage_idx = stage_group + stage_offset * stage_groups
                        gdQaccum_stage = cute.local_tile(
                            gdQaccum,
                            (self.tile_m * self.dQ_reduce_ncol,),
                            (stage_idx,),
                        )
                        gdQaccum_stage_g2s = cute.make_tensor(
                            gdQaccum_stage.iterator,
                            gdQaccum_layout_g2s,
                        )
                        tdQgdQ = g2s_thr_copy.partition_S(gdQaccum_stage_g2s)
                        cute.copy(
                            g2s_thr_copy,
                            tdQgdQ[None, None, 0],
                            sdQaccum_g2s[None, None, smem_buf],
                        )

                    cute.arch.fence_view_async_shared()
                    cute.arch.barrier(barrier_id=6, number_of_threads=num_reduce_threads)

                    # S -> R
                    for stage_offset, lane_offset in stage_iters:
                        stage_idx = stage_group + stage_offset * stage_groups
                        s2r_src_tidx = s2r_lane + lane_offset
                        s2r_thr_copy = tiled_copy_accum.get_slice(s2r_src_tidx)
                        sdQaccum_src = s2r_thr_copy.partition_S(sdQaccum)[None, None, s2r_buf]

                        tdQrdQ_s2r_cpy = tdQrdQ_s2r[None, stage_idx, None, None]
                        tdQrdQ_r2s_cpy = cute.make_tensor(
                            tdQrdQ_s2r_cpy.iterator, cute.make_layout(sdQaccum_src.shape)
                        )
                        cute.copy(s2r_thr_copy, sdQaccum_src, tdQrdQ_r2s_cpy)
                        cute.arch.fence_view_async_shared()
                        cute.arch.barrier(barrier_id=7, number_of_threads=num_reduce_threads)

                        # R -> S
                        stage_lo = stage_idx % stage_stride
                        stage_hi = stage_idx // stage_stride
                        tdQrdQ_r2s_cpy = cute.make_tensor(
                            cute.recast_ptr(tdQrdQ_r2s_cpy.iterator),
                            tdQrdQ_r2s[((None, 0), (stage_lo, stage_hi), 0, 0)].shape,
                        )
                        dQ_vec = tdQrdQ_r2s_cpy.load() * scale
                        tdQrdQ_r2s[((None, 0), (stage_lo, stage_hi), 0, 0)].store(
                            dQ_vec.to(self.dtype)
                        )

                # R -> S
                # NOTE(review): placed after the stage loops — the full converted
                # fragment is written to smem once; confirm against the C++ original.
                cute.copy(
                    r2s_tiled_copy,
                    tdQrdQ_r2s[None, None, None, 0],
                    tdQsdQ_r2s[None, None, None, 0],
                )
                cute.arch.fence_view_async_shared()
                cute.arch.barrier(barrier_id=8, number_of_threads=num_reduce_threads)
            else:
                # Step 1: load dQaccum from gmem to smem
                g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx)
                tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum)
                tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum_flat)
                cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s)
                cute.arch.cp_async_commit_group()
                cute.arch.cp_async_wait_group(0)
                cute.arch.barrier()

                # Step 2: load dQ from smem to rmem
                s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx)
                tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum)
                tile_shape = (self.tile_m, self.tile_hdim)
                acc = None
                tiled_copy_t2r = None
                if const_expr(self.arch in [80, 90]):
                    acc_shape = tiled_mma.partition_shape_C(
                        tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1]
                    )
                    acc = cute.make_fragment(acc_shape, cutlass.Float32)
                    assert cute.size(acc) == cute.size(tdQsdQaccum)
                else:
                    thr_mma = tiled_mma.get_slice(0)  # 1-CTA
                    dQacc_shape = tiled_mma.partition_shape_C((self.tile_m, self.tile_hdim))
                    tdQtdQ = tiled_mma.make_fragment_C(dQacc_shape)
                    tdQcdQ = thr_mma.partition_C(
                        cute.make_identity_tensor((self.tile_m, self.tile_hdim))
                    )
                    tmem_load_atom = cute.make_copy_atom(
                        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)),
                        Float32,
                    )
                    tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ)
                    thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
                    tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape
                    acc = cute.make_fragment(tdQrdQ_t2r_shape, Float32)
                tdQrdQaccum = cute.make_tensor(acc.iterator, cute.make_layout(tdQsdQaccum.shape))
                cute.autovec_copy(tdQsdQaccum, tdQrdQaccum)
                # Convert tdQrdQaccum from fp32 to fp16/bf16
                rdQ = cute.make_fragment_like(acc, self.dtype)
                rdQ.store((acc.load() * scale).to(self.dtype))

                # Step 3: Copy dQ from register to smem
                cute.arch.barrier()  # make sure all threads have finished loading dQaccum
                if const_expr(self.arch in [80, 90]):
                    copy_atom_r2s_dQ = utils.get_smem_store_atom(
                        self.arch, self.dtype, transpose=self.dQ_swapAB
                    )
                    tiled_copy_r2s_dQ = cute.make_tiled_copy_C(copy_atom_r2s_dQ, tiled_mma)
                else:
                    # copy_atom_r2s_dQ = sm100_utils_basic.get_smem_store_op(
                    #     LayoutEnum.ROW_MAJOR, self.dtype, Float32, tiled_copy_t2r,
                    # )
                    # tiled_copy_r2s_dQ = cute.make_tiled_copy_D(copy_atom_r2s_dQ, tiled_copy_t2r)
                    thr_layout_r2s_dQ = cute.make_layout((self.num_threads, 1))  # 128 threads
                    val_layout_r2s_dQ = cute.make_layout((1, 128 // self.dtype.width))
                    copy_atom_r2s_dQ = cute.make_copy_atom(
                        cute.nvgpu.CopyUniversalOp(),
                        self.dtype,
                        num_bits_per_copy=128,
                    )
                    tiled_copy_r2s_dQ = cute.make_tiled_copy_tv(
                        copy_atom_r2s_dQ, thr_layout_r2s_dQ, val_layout_r2s_dQ
                    )
                thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx)
                cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))
                if const_expr(self.arch in [80, 90]):
                    taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ)
                else:
                    taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape
                    taccdQrdQ = cute.make_tensor(rdQ.iterator, taccdQcdQ_shape)
                taccdQsdQ = thr_copy_r2s_dQ.partition_D(
                    sdQ if const_expr(not self.dQ_swapAB) else sdQt
                )
                cute.copy(thr_copy_r2s_dQ, taccdQrdQ, taccdQsdQ)

            # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
            cute.arch.barrier()  # make sure all smem stores are done
            gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx)
            tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ)
            tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ)
            tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype)
            # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled
            cute.autovec_copy(tdQsdQ, tdQrdQ)

            # Step 5: Copy dQ from register to gmem
            # Predicate on head_dim (columns) and on remaining seqlen (rows).
            tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ)
            tdQpdQ = utils.predicate_k(tdQcdQ, limit=head_dim)
            for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True):
                if tdQcdQ[0, rest_m, 0][0] < seqlen_q - m_block * self.tile_m:
                    cute.copy(
                        gmem_tiled_copy_dQ,
                        tdQrdQ[None, rest_m, None],
                        tdQgdQ[None, rest_m, None],
                        pred=tdQpdQ[None, rest_m, None],
                    )
|
build/torch-cuda/flash_bwd_preprocess.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
+
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h
|
| 3 |
+
# from Cutlass C++ to Cute-DSL.
|
| 4 |
+
import math
|
| 5 |
+
import operator
|
| 6 |
+
from typing import Callable, Type, Optional, Literal
|
| 7 |
+
|
| 8 |
+
import cuda.bindings.driver as cuda
|
| 9 |
+
|
| 10 |
+
import cutlass
|
| 11 |
+
import cutlass.cute as cute
|
| 12 |
+
from cutlass import Float32
|
| 13 |
+
|
| 14 |
+
from .quack import copy_utils
|
| 15 |
+
|
| 16 |
+
from . import utils
|
| 17 |
+
from .cute_dsl_utils import assume_tensor_aligned
|
| 18 |
+
from .seqlen_info import SeqlenInfoQK
|
| 19 |
+
from .quack.cute_dsl_utils import ParamsBase
|
| 20 |
+
from .tile_scheduler import (
|
| 21 |
+
SingleTileScheduler,
|
| 22 |
+
SingleTileVarlenScheduler,
|
| 23 |
+
TileSchedulerArguments,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class FlashAttentionBackwardPreprocess:
|
| 28 |
+
def __init__(
    self,
    dtype: Type[cutlass.Numeric],
    head_dim: int,
    head_dim_v: int,
    arch: Literal[80, 90, 100],
    m_block_size: int = 128,
    num_threads: int = 128,
):
    """Configure the backward-pass preprocess kernel.

    All contiguous dimensions must be at least 16-byte aligned, which
    means the head dimension has to be a multiple of 8.

    :param head_dim: head dimension
    :type head_dim: int
    :param m_block_size: m block size
    :type m_block_size: int
    :param num_threads: number of threads
    :type num_threads: int
    """

    def _round_up(value: int, multiple: int) -> int:
        # Smallest multiple of `multiple` that is >= value.
        return int(math.ceil(value / multiple) * multiple)

    self.dtype = dtype
    self.m_block_size = m_block_size
    self.arch = arch
    self.num_threads = num_threads
    # Head dims are padded up to a multiple of 32, which serves as the k block size.
    self.head_dim_padded = _round_up(head_dim, 32)
    self.head_dim_v_padded = _round_up(head_dim_v, 32)
    # Predication along head_dim_v is only needed when padding actually occurred.
    self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded
|
| 57 |
+
|
| 58 |
+
@staticmethod
def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool:
    """Check whether the kernel supports the given configuration.

    :param dtype: data type
    :type dtype: cutlass.Numeric
    :param head_dim: head dimension
    :type head_dim: int
    :param m_block_size: m block size
    :type m_block_size: int
    :param num_threads: number of threads
    :type num_threads: int

    :return: True if the kernel can be implemented, False otherwise
    :rtype: bool
    """
    supported_dtype = dtype in (cutlass.Float16, cutlass.BFloat16)
    # 16-byte gmem alignment requires head_dim divisible by 8.
    aligned_hdim = head_dim % 8 == 0
    full_warps = num_threads % 32 == 0
    # One thread per LSE row when multiplying lse by log2(e), so we need
    # at least m_block_size threads.
    enough_threads = num_threads >= m_block_size
    return supported_dtype and aligned_hdim and full_warps and enough_threads
|
| 83 |
+
|
| 84 |
+
def _setup_attributes(self):
    """Build the GMEM tiled-copy objects used by the kernel."""
    # ///////////////////////////////////////////////////////////////////////////////
    # GMEM Tiled copy:
    # ///////////////////////////////////////////////////////////////////////////////
    # Thread layouts for copies.
    # kBlockKGmem must be a power of 2 so that the row-sum reduction stays
    # between threads of the same warp; pick the largest of 128/64/32 that
    # divides the padded head dim (16 otherwise).
    gmem_k_block_size = 16
    for candidate in (128, 64, 32):
        if self.head_dim_v_padded % candidate == 0:
            gmem_k_block_size = candidate
            break
    # 128-bit vectorized loads: elements per copy depend on the dtype width.
    elems_per_copy = 128 // self.dtype.width
    threads_per_row = gmem_k_block_size // elems_per_copy
    self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d(
        self.dtype, threads_per_row, self.num_threads, elems_per_copy
    )
    universal_copy_bits = 128
    elems_per_copy_dQaccum = universal_copy_bits // Float32.width
    # The dQaccum tile must be evenly divisible across all threads.
    assert (
        self.m_block_size * self.head_dim_padded // elems_per_copy_dQaccum
    ) % self.num_threads == 0
    self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
        Float32, self.num_threads, elems_per_copy_dQaccum
    )
|
| 113 |
+
|
| 114 |
+
@cute.jit
def __call__(
    self,
    mO: cute.Tensor,
    mdO: cute.Tensor,
    mdPsum: cute.Tensor,
    mLSE: Optional[cute.Tensor],
    mLSElog2: Optional[cute.Tensor],
    mdQaccum: Optional[cute.Tensor],
    mCuSeqlensQ: Optional[cute.Tensor],
    mSeqUsedQ: Optional[cute.Tensor],
    stream: cuda.CUstream,
):
    """Validate dtypes, set up the copy atoms and tile scheduler, and launch
    the preprocess kernel on ``stream``.

    :param mO: forward-pass output tensor O (Float16/BFloat16)
    :param mdO: incoming gradient dO, same dtype as mO
    :param mdPsum: Float32 output buffer for rowsum(O * dO)
    :param mLSE: Float32 log-sum-exp from the forward pass (optional)
    :param mLSElog2: Float32 output for LSE * log2(e); required iff mLSE is given
    :param mdQaccum: Float32 dQ accumulator to be zero-initialized (optional)
    :param mCuSeqlensQ: cumulative Q sequence lengths for varlen mode (optional)
    :param mSeqUsedQ: per-batch used Q sequence lengths (optional)
    :param stream: CUDA stream to launch the kernel on
    """
    # Get the data type and check if it is fp16 or bf16
    if cutlass.const_expr(not (mO.element_type == mdO.element_type)):
        raise TypeError("All tensors must have the same data type")
    if cutlass.const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]):
        raise TypeError("Only Float16 or BFloat16 is supported")
    if cutlass.const_expr(mdPsum.element_type not in [Float32]):
        raise TypeError("dPsum tensor must be Float32")
    if cutlass.const_expr(mdQaccum is not None):
        if cutlass.const_expr(mdQaccum.element_type not in [Float32]):
            raise TypeError("dQaccum tensor must be Float32")
    if cutlass.const_expr(mLSE is not None):
        assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided"
        if cutlass.const_expr(mLSE.element_type not in [Float32]):
            raise TypeError("LSE tensor must be Float32")
        if cutlass.const_expr(mLSElog2.element_type not in [Float32]):
            raise TypeError("LSElog2 tensor must be Float32")

    # NOTE(review): mdQaccum is Optional but goes through assume_tensor_aligned
    # unconditionally — confirm the helper tolerates None.
    mO, mdO, mdQaccum = [assume_tensor_aligned(t) for t in (mO, mdO, mdQaccum)]

    self._setup_attributes()

    if cutlass.const_expr(mCuSeqlensQ is not None):
        # Varlen layout: mO is (total_q, num_head, headdim_v).
        TileScheduler = SingleTileVarlenScheduler
        num_head = mO.shape[1]
        num_batch = mCuSeqlensQ.shape[0] - 1
    else:
        # Fixed-length layout: mO is (batch, seqlen, num_head, headdim_v).
        TileScheduler = SingleTileScheduler
        num_head = mO.shape[2]
        num_batch = mO.shape[0]

    # NOTE(review): num_block/headdim_v/total_q below index mO with the varlen
    # layout in mind; in the fixed-length case the scheduler presumably ignores
    # the fields that would be mis-indexed — verify against TileSchedulerArguments.
    tile_sched_args = TileSchedulerArguments(
        num_block=cute.ceil_div(mO.shape[1], self.m_block_size),
        num_head=num_head,
        num_batch=num_batch,
        num_splits=1,
        seqlen_k=0,
        headdim=0,
        headdim_v=mO.shape[2],
        total_q=mO.shape[0],
        tile_shape_mn=(self.m_block_size, 1),
        mCuSeqlensQ=mCuSeqlensQ,
        mSeqUsedQ=mSeqUsedQ,
    )

    tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
    grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

    # One thread block per scheduler work tile.
    self.kernel(
        mO,
        mdO,
        mdPsum,
        mLSE,
        mLSElog2,
        mdQaccum,
        mCuSeqlensQ,
        mSeqUsedQ,
        self.gmem_tiled_copy_O,
        self.gmem_tiled_copy_dQaccum,
        tile_sched_params,
        TileScheduler,
    ).launch(
        grid=grid_dim,
        block=[self.num_threads, 1, 1],
        stream=stream,
    )
|
| 192 |
+
|
| 193 |
+
@cute.kernel
def kernel(
    self,
    mO: cute.Tensor,
    mdO: cute.Tensor,
    mdPsum: cute.Tensor,
    mLSE: Optional[cute.Tensor],
    mLSElog2: Optional[cute.Tensor],
    mdQaccum: Optional[cute.Tensor],
    mCuSeqlensQ: Optional[cute.Tensor],
    mSeqUsedQ: Optional[cute.Tensor],
    gmem_tiled_copy_O: cute.TiledCopy,
    gmem_tiled_copy_dQaccum: cute.TiledCopy,
    tile_sched_params: ParamsBase,
    TileScheduler: cutlass.Constexpr[Callable],
):
    """Per-tile preprocess work for the attention backward pass.

    For each (m_block, head, batch) work tile this kernel:
      1. loads the O and dO tiles from gmem,
      2. computes dPsum = rowsum(O * dO) via an in-warp reduction and writes it,
      3. zero-fills the dQaccum slice for this tile (if provided),
      4. writes LSE * log2(e) (0 for padded/-inf rows) into mLSElog2 (if provided).
    """
    # Thread index, block index
    tidx, _, _ = cute.arch.thread_idx()

    tile_scheduler = TileScheduler.create(tile_sched_params)
    work_tile = tile_scheduler.initial_work_tile_info()
    m_block, head_idx, batch_idx, _ = work_tile.tile_idx

    if work_tile.is_valid_tile:
        # ///////////////////////////////////////////////////////////////////////////////
        # Get the appropriate tiles for this thread block.
        # ///////////////////////////////////////////////////////////////////////////////
        seqlen = SeqlenInfoQK.create(
            batch_idx,
            mO.shape[1],
            0,
            mCuSeqlensQ=mCuSeqlensQ,
            mCuSeqlensK=None,
            mSeqUsedQ=mSeqUsedQ,
            mSeqUsedK=None,
        )

        if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
            # Fixed-length: slice out this (batch, head) plane.
            mO_cur = mO[batch_idx, None, head_idx, None]
            mdO_cur = mdO[batch_idx, None, head_idx, None]
            mdPsum_cur = mdPsum[batch_idx, head_idx, None]
            headdim_v = mO.shape[3]
        else:
            # Varlen: offset into the packed (total_q, ...) tensors.
            mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None])
            mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])

            # Each batch gets up to one extra m_block of padding in dPsum.
            padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
            if cutlass.const_expr(self.arch >= 90):
                # Round down to an m_block boundary — presumably to match the
                # SM90 backward kernel's dPsum addressing; confirm upstream.
                padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size
            mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
            headdim_v = mO.shape[2]

        blkOdO_shape = (self.m_block_size, self.head_dim_v_padded)
        # (m_block_size, head_dim_v)
        gO = cute.local_tile(mO_cur, blkOdO_shape, (m_block, 0))
        gdO = cute.local_tile(mdO_cur, blkOdO_shape, (m_block, 0))

        gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
        # (CPY_Atom, CPY_M, CPY_K)
        tOgO = gmem_thr_copy_O.partition_S(gO)
        tOgdO = gmem_thr_copy_O.partition_S(gdO)

        # ///////////////////////////////////////////////////////////////////////////////
        # Predicate: Mark indices that need to copy when problem_shape isn't a multiple
        # of tile_shape
        # ///////////////////////////////////////////////////////////////////////////////
        # Construct identity layout for KV
        cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
        tOcO = gmem_thr_copy_O.partition_S(cO)
        t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO)
        tOpO = utils.predicate_k(tOcO, limit=headdim_v)
        tOpdO = utils.predicate_k(tOcO, limit=headdim_v)

        seqlen_q = seqlen.seqlen_q
        seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size)

        if cutlass.const_expr(mLSE is not None):
            if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
                mLSE_cur = mLSE[batch_idx, head_idx, None]
            else:
                mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[head_idx, None])

            gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,))
            # One LSE row per thread; out-of-range rows read as +inf.
            lse = Float32.inf
            if tidx < seqlen_q - m_block * self.m_block_size:
                lse = gLSE[tidx]

        tOrO = cute.make_fragment_like(tOgO)
        tOrdO = cute.make_fragment_like(tOgdO)
        assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0])
        assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1])
        assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2])
        for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True):
            # Instead of using tOcO, we using t0OcO and subtract the offset from the limit
            # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time.
            if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]:
                cute.copy(
                    gmem_thr_copy_O,
                    tOgO[None, m, None],
                    tOrO[None, m, None],
                    pred=tOpO[None, m, None]
                    if cutlass.const_expr(self.check_hdim_v_oob)
                    else None,
                )
                cute.copy(
                    gmem_thr_copy_O,
                    tOgdO[None, m, None],
                    tOrdO[None, m, None],
                    pred=tOpdO[None, m, None]
                    if cutlass.const_expr(self.check_hdim_v_oob)
                    else None,
                )
        # Sum across the "k" dimension
        dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce(
            cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1)
        )
        threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0
        # Threads in the same row finish the reduction via warp shuffles.
        dpsum = utils.warp_reduce(dpsum, operator.add, width=threads_per_row)
        dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), Float32)
        dP_sum.store(dpsum)

        # Write dPsum from rmem -> gmem
        gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (m_block,))
        # Only the thread corresponding to column 0 writes out the dPsum to gmem
        if tOcO[0, 0, 0][1] == 0:
            for m in cutlass.range(cute.size(dP_sum), unroll_full=True):
                row = tOcO[0, m, 0][0]
                gdPsum[row] = dP_sum[m] if row < seqlen_q - m_block * self.m_block_size else 0.0

        # Clear dQaccum
        if cutlass.const_expr(mdQaccum is not None):
            if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
                mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
            else:
                mdQaccum_cur = cute.domain_offset(
                    (padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None]
                )

                # HACK: Compiler doesn't seem to recognize that padding
                # by padded_offset_q * self.head_dim_padded keeps alignment
                # since statically divisible by 4
                # NOTE(review): nesting reconstructed from a de-indented dump;
                # the re-pointer trick appears to belong to the varlen branch
                # only — confirm against the upstream source.
                mdQaccum_cur_ptr = cute.make_ptr(
                    dtype=mdQaccum_cur.element_type,
                    value=mdQaccum_cur.iterator.toint(),
                    mem_space=mdQaccum_cur.iterator.memspace,
                    assumed_align=mdQaccum.iterator.alignment,
                )
                mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout)

            blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,)
            gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,))
            gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
            tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)
            zero = cute.make_fragment_like(tdQgdQaccum)
            zero.fill(0.0)
            cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum)

        if cutlass.const_expr(mLSE is not None):
            if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
                mLSElog2_cur = mLSElog2[batch_idx, head_idx, None]
            else:
                mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[head_idx, None])

            gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,))
            LOG2_E = math.log2(math.e)
            # Padded rows (up to seqlen_q_rounded) get 0 so the backward kernel
            # can read them unconditionally; -inf LSE (fully-masked row) also
            # maps to 0.
            if tidx < seqlen_q_rounded - m_block * self.m_block_size:
                gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0
|
build/torch-cuda/flash_bwd_sm100.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch-cuda/flash_bwd_sm90.py
ADDED
|
@@ -0,0 +1,1591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Callable, Optional, Type
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
import cuda.bindings.driver as cuda
|
| 6 |
+
|
| 7 |
+
import cutlass
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
import cutlass.utils.hopper_helpers as sm90_utils_basic
|
| 10 |
+
from cutlass.cute.nvgpu import cpasync, warpgroup
|
| 11 |
+
from cutlass.cute import FastDivmodDivisor
|
| 12 |
+
from cutlass import Float32, Int32, Boolean, const_expr
|
| 13 |
+
from cutlass.utils import LayoutEnum
|
| 14 |
+
|
| 15 |
+
from .quack import copy_utils
|
| 16 |
+
from .quack import layout_utils
|
| 17 |
+
from .quack import sm90_utils
|
| 18 |
+
from .quack.sm90_utils import gemm_zero_init, gemm_w_idx
|
| 19 |
+
|
| 20 |
+
from .cute_dsl_utils import assume_tensor_aligned
|
| 21 |
+
from . import utils
|
| 22 |
+
from .mask import AttentionMask
|
| 23 |
+
from .seqlen_info import SeqlenInfoQK
|
| 24 |
+
from .block_info import BlockInfo
|
| 25 |
+
from . import pipeline
|
| 26 |
+
from .quack.cute_dsl_utils import ParamsBase
|
| 27 |
+
from .tile_scheduler import TileSchedulerArguments, SingleTileScheduler
|
| 28 |
+
from .named_barrier import NamedBarrierBwd
|
| 29 |
+
from .softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
|
| 30 |
+
from .block_sparsity import BlockSparseTensors
|
| 31 |
+
from .block_sparse_utils import (
|
| 32 |
+
get_total_q_block_count_bwd,
|
| 33 |
+
produce_block_sparse_q_loads_bwd_sm90,
|
| 34 |
+
consume_block_sparse_mma_bwd_sm90,
|
| 35 |
+
dQaccum_store_block_sparse_bwd_sm90,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class FlashAttentionBackwardSm90:
|
| 40 |
+
arch = 90
|
| 41 |
+
|
| 42 |
+
def __init__(
    self,
    dtype: Type[cutlass.Numeric],
    head_dim: int,
    head_dim_v: Optional[int] = None,
    qhead_per_kvhead: int = 1,
    is_causal: bool = False,
    tile_m: int = 64,
    tile_n: int = 128,
    Q_stage: int = 2,
    dO_stage: int = 2,
    PdS_stage: int = 2,
    SdP_swapAB: bool = False,
    dKV_swapAB: bool = False,
    dQ_swapAB: bool = False,
    AtomLayoutMSdP: int = 1,
    AtomLayoutNdKV: int = 2,
    AtomLayoutMdQ: int = 1,
    num_threads: int = 384,
    V_in_regs: bool = False,
    score_mod: cutlass.Constexpr | None = None,
    score_mod_bwd: cutlass.Constexpr | None = None,
    mask_mod: cutlass.Constexpr | None = None,
    has_aux_tensors: cutlass.Constexpr = False,
    subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Configure the SM90 (Hopper) flash-attention backward kernel.

    Stores tile sizes, pipeline stage counts, MMA atom layouts and the
    swap-A/B flags used to build the kernel, and derives padded head dims
    plus predication/register-pressure settings from them.
    """
    self.dtype = dtype
    # padding head_dim to a multiple of 16 as k_block_size
    hdim_multiple_of = 16
    self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
    # head_dim_v defaults to head_dim when not given.
    head_dim_v = head_dim_v if head_dim_v is not None else head_dim
    self.same_hdim_kv = head_dim == head_dim_v
    self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
    # Can save registers (and hence be faster) if we don't have to check hdim predication
    self.check_hdim_oob = head_dim != self.tile_hdim
    self.check_hdim_v_oob = head_dim_v != self.tile_hdimv
    self.qhead_per_kvhead = qhead_per_kvhead
    self.is_causal = is_causal
    self.is_local = False
    self.tile_m = tile_m
    self.tile_n = tile_n
    self.num_threads = num_threads
    self.Q_stage = Q_stage
    self.dO_stage = dO_stage
    self.PdS_stage = PdS_stage
    # dO and PdS pipelines either share the Q depth or are single-stage.
    assert self.dO_stage in [1, self.Q_stage]
    assert self.PdS_stage in [1, self.Q_stage]
    self.SdP_swapAB = SdP_swapAB
    self.dKV_swapAB = dKV_swapAB
    self.dQ_swapAB = dQ_swapAB
    self.AtomLayoutMSdP = AtomLayoutMSdP
    self.AtomLayoutNdKV = AtomLayoutNdKV
    self.AtomLayoutMdQ = AtomLayoutMdQ
    # One warp group (128 threads) is the producer; the rest run MMAs.
    self.num_mma_warp_groups = (self.num_threads // 128) - 1
    # Whether the dKV MMA can take its A operand from registers (RS MMA).
    self.mma_dkv_is_rs = (
        AtomLayoutMSdP == 1
        and AtomLayoutNdKV == self.num_mma_warp_groups
        and SdP_swapAB
        and not dKV_swapAB
    )
    self.V_in_regs = V_in_regs
    if qhead_per_kvhead > 1:
        assert self.same_hdim_kv, "GQA backward requires head_dim == head_dim_v"
        assert self.num_mma_warp_groups == 2, "GQA backward assumes 2 warp groups"
    # These are tuned for speed
    # Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share
    # them and then shuffle to get the value whenever we need? This can reduce register
    # pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4)
    # rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows.
    # TODO: impl these for hdim 64
    self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64
    self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64

    self.buffer_align_bytes = 1024

    self.score_mod = score_mod
    self.score_mod_bwd = score_mod_bwd
    self.mask_mod = mask_mod
    self.has_aux_tensors = has_aux_tensors
    self.subtile_factor = subtile_factor
    # Aux tensors force scalar (vec_size=1) score-mod application.
    if cutlass.const_expr(has_aux_tensors):
        self.vec_size: cutlass.Constexpr = 1
    else:
        self.vec_size: cutlass.Constexpr = 4
    self.qk_acc_dtype = Float32
|
| 127 |
+
|
| 128 |
+
@staticmethod
def can_implement(
    dtype,
    head_dim,
    head_dim_v,
    tile_m,
    tile_n,
    Q_stage,
    num_threads,
    V_in_regs=False,
) -> bool:
    """Return True if the SM90 backward kernel supports this configuration.

    Note: ``Q_stage`` and ``V_in_regs`` are accepted but not currently
    checked by this function.
    """
    if dtype not in [cutlass.Float16, cutlass.BFloat16]:
        return False
    # 16-byte gmem alignment requires head dims divisible by 8.
    if head_dim % 8 != 0:
        return False
    if head_dim_v % 8 != 0:
        return False
    if tile_n % 16 != 0:
        return False
    # Whole warps only.
    if num_threads % 32 != 0:
        return False
    # NOTE(review): this rejects the class defaults (tile_m=64,
    # num_threads=384) since 128 % 384 != 0; the intended condition may be
    # num_threads % (tile_m * 2) == 0 — confirm against callers before relying
    # on this check.
    if (tile_m * 2) % num_threads != 0:
        return False
    return True
|
| 152 |
+
|
| 153 |
+
    def _check_type(
        self,
        mQ_type: Type[cutlass.Numeric],
        mK_type: Type[cutlass.Numeric],
        mV_type: Type[cutlass.Numeric],
        mdO_type: Type[cutlass.Numeric],
        mLSE_type: Type[cutlass.Numeric],
        mdPsum_type: Type[cutlass.Numeric],
        mdQaccum_type: Type[cutlass.Numeric],
        mdK_type: Type[cutlass.Numeric],
        mdV_type: Type[cutlass.Numeric],
    ):
        """Validate element types of all input/output tensors at trace time.

        Raises TypeError when the dtype combination is unsupported. All checks
        are wrapped in `const_expr` so they are resolved during DSL tracing,
        not at device runtime.
        """
        # Get the data type and check if it is fp16 or bf16
        if const_expr(not (mQ_type == mK_type == mV_type == mdO_type)):
            raise TypeError("All tensors must have the same data type")
        if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]):
            raise TypeError("Only Float16 or BFloat16 is supported")
        # LSE / dPsum / dQaccum are accumulated statistics and must be fp32.
        if const_expr(mLSE_type not in [Float32]):
            raise TypeError("LSE tensor must be Float32")
        if const_expr(mdPsum_type not in [Float32]):
            raise TypeError("dPsum tensor must be Float32")
        if const_expr(mdQaccum_type not in [Float32]):
            raise TypeError("dQaccum tensor must be Float32")
        if const_expr(self.qhead_per_kvhead == 1):
            # MHA: dK/dV are written directly in the compute dtype.
            if const_expr(not (mdK_type == mdV_type == mQ_type)):
                raise TypeError("mdK and mdV tensors must have the same data type as mQ")
        else:
            # GQA: dK/dV are fp32 accumulators (reduced across query heads later).
            if const_expr(not (mdK_type == mdV_type == Float32)):
                raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32")
        assert mQ_type == self.dtype
|
| 183 |
+
|
| 184 |
+
    def _setup_attributes(self):
        """Build the shared-memory layouts and the register->smem copy for dQaccum.

        Populates: sQ/sK/sV/sdO/sPdS smem layouts (row-major, per-tile shapes,
        with staging for Q, dO, and P/dS), the flat sdQaccum layout split across
        MMA warp groups, and the tiled copy used to move dQ accumulator values
        from registers into shared memory.
        """
        # One smem layout per operand; K and V are single-stage (stage=None),
        # Q / dO / PdS are multi-buffered by their respective stage counts.
        self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout, self.sPdS_layout = [
            sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage)
            for shape, stage in [
                ((self.tile_m, self.tile_hdim), self.Q_stage),
                ((self.tile_n, self.tile_hdim), None),
                ((self.tile_n, self.tile_hdimv), None),
                ((self.tile_m, self.tile_hdimv), self.dO_stage),
                ((self.tile_m, self.tile_n), self.PdS_stage),
            ]
        ]
        # Flat fp32 dQ accumulator tile, partitioned so each MMA warp group
        # owns an equal contiguous chunk.
        self.sdQaccum_layout = cute.make_layout(
            (self.tile_m * self.tile_hdim // self.num_mma_warp_groups, self.num_mma_warp_groups)
        )
        # dQaccum R->S
        # 128-bit vectorized universal copy: each thread moves 4 fp32 values.
        self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
            cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
            # thr_layout
            cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)),
            cute.make_layout(128 // Float32.width),  # val_layout
        )
        # dKVaccum for GQA epilogue - reuses sV+sK memory recast as f32
        # TODO: assert that sVaccum and sKaccum don't overflow smem
|
| 207 |
+
|
| 208 |
+
    def _get_tiled_mma(self):
        """Construct the four tiled MMAs used by the backward pass.

        Returns (tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ) for:
          * S = Q @ K^T and dP = dO @ V^T   (shared shape -> one MMA)
          * dK = dS^T @ Q
          * dV = P^T @ dO
          * dQ = dS @ K
        Each MMA optionally swaps its A/B operands (the *_swapAB flags), which
        transposes both the atom layout and the tiler.
        """
        # S = Q @ K.T, dP = dO @ V.T
        atom_layout_SdP = (self.AtomLayoutMSdP, self.num_mma_warp_groups // self.AtomLayoutMSdP)
        tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1])
        tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma(
            self.dtype,
            self.dtype,
            warpgroup.OperandMajorMode.K,
            warpgroup.OperandMajorMode.K,
            Float32,
            atom_layout_mnk=(atom_layout_SdP if not self.SdP_swapAB else atom_layout_SdP[::-1])
            + (1,),
            tiler_mn=tiler_mn_SdP if not self.SdP_swapAB else tiler_mn_SdP[::-1],
        )
        # dV = P.T @ dO, dK = dS.T @ Q
        atom_layout_dKV = (self.AtomLayoutNdKV, self.num_mma_warp_groups // self.AtomLayoutNdKV)
        tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1])
        tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1])
        # dK and dV share the atom layout but differ in head dim (hdim vs hdimv).
        # When mma_dkv_is_rs, operand A comes from registers (K-major); otherwise
        # from smem (MN-major).
        tiled_mma_dK, tiled_mma_dV = [
            sm90_utils_basic.make_trivial_tiled_mma(
                self.dtype,
                self.dtype,
                warpgroup.OperandMajorMode.MN
                if not self.mma_dkv_is_rs
                else warpgroup.OperandMajorMode.K,
                warpgroup.OperandMajorMode.MN,
                Float32,
                atom_layout_mnk=(atom_layout_dKV if not self.dKV_swapAB else atom_layout_dKV[::-1])
                + (1,),
                tiler_mn=tiler_mn_d if not self.dKV_swapAB else tiler_mn_d[::-1],
                a_source=warpgroup.OperandSource.RMEM
                if self.mma_dkv_is_rs
                else warpgroup.OperandSource.SMEM,
            )
            for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV)
        ]
        # dQ = dS @ K
        atom_layout_dQ = (self.AtomLayoutMdQ, self.num_mma_warp_groups // self.AtomLayoutMdQ)
        tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
        tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma(
            self.dtype,
            self.dtype,
            warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN,
            warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K,
            Float32,
            atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + (1,),
            tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1],
        )
        return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ
|
| 257 |
+
|
| 258 |
+
    def _get_shared_storage_cls(self):
        """Define and return the @cute.struct describing this kernel's shared memory.

        The struct covers: pipeline mbarriers for Q and dO, staged LSE/dPsum
        row statistics, the Q/K/V/dO operand buffers, the P and dS scratch
        tiles, and the fp32 dQ accumulator. Field order matters: it determines
        smem addresses (and therefore which regions can later be aliased).
        """
        # Operand buffers, each aligned to self.buffer_align_bytes (1024).
        sQ_struct, sK_struct, sV_struct, sdO_struct, sdQaccum_struct = [
            cute.struct.Align[cute.struct.MemRange[t, cute.cosize(layout)], self.buffer_align_bytes]
            for (layout, t) in [
                (self.sQ_layout, self.dtype),
                (self.sK_layout, self.dtype),
                (self.sV_layout, self.dtype),
                (self.sdO_layout, self.dtype),
                (self.sdQaccum_layout, Float32),
            ]
        ]

        cosize_sdS = cute.cosize(self.sPdS_layout)
        # When dK/dV MMAs read P from registers (mma_dkv_is_rs), no smem P
        # buffer is needed -> size 0.
        cosize_sP = cute.cosize(self.sPdS_layout) if const_expr(not self.mma_dkv_is_rs) else 0
        # LSE/dPsum rows are padded up to a multiple of 64 per stage.
        sLSE_struct = cute.struct.Align[
            cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.Q_stage], 128
        ]
        sdPsum_struct = cute.struct.Align[
            cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.dO_stage], 128
        ]

        @cute.struct
        class SharedStorageQKV:
            # Pipeline mbarriers: 2 (full/empty) per stage.
            mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.Q_stage * 2]
            mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.dO_stage * 2]
            sLSE: sLSE_struct
            sdPsum: sdPsum_struct
            sQ: sQ_struct
            sV: sV_struct
            sK: sK_struct
            sdO: sdO_struct
            sP: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]
            sdS: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sdS], 1024]
            sdQaccum: sdQaccum_struct

        return SharedStorageQKV
|
| 294 |
+
|
| 295 |
+
    @cute.jit
    def __call__(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mdO: cute.Tensor,
        mLSE: cute.Tensor,
        mdPsum: cute.Tensor,
        mdQaccum: cute.Tensor,
        mdK: cute.Tensor,
        mdV: cute.Tensor,
        softmax_scale: Float32,
        stream: cuda.CUstream,
        mCuSeqlensQ: Optional[cute.Tensor] = None,
        mCuSeqlensK: Optional[cute.Tensor] = None,
        mSeqUsedQ: Optional[cute.Tensor] = None,
        mSeqUsedK: Optional[cute.Tensor] = None,
        softcap: Float32 | float | None = None,
        window_size_left: Int32 | int | None = None,
        window_size_right: Int32 | int | None = None,
        mdQ_semaphore: Optional[cute.Tensor] = None,
        mdK_semaphore: Optional[cute.Tensor] = None,
        mdV_semaphore: Optional[cute.Tensor] = None,
        aux_tensors: Optional[list] = None,
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
    ):
        """Host-side JIT entry point: validate inputs, build layouts/TMA atoms,
        and launch the SM90 backward kernel on `stream`.

        Tensors arrive as (b, s, n, h) [batch, seqlen, heads, head_dim] and are
        re-viewed (not copied) into (s, h, n, b) for the kernel — assumed from
        the transpose comments below; verify against the caller.
        """
        # Deterministic (semaphore-ordered) accumulation is not implemented here.
        assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, (
            "determinism not supported yet for Sm90"
        )

        self._check_type(
            *(
                t.element_type if t is not None else None
                for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
            )
        )

        mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
            assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
        ]

        layout_transpose = [1, 3, 2, 0]  # (b, s, n, h) --> (s, h, n, b)
        mQ, mK, mV, mdO = [layout_utils.select(t, layout_transpose) for t in (mQ, mK, mV, mdO)]
        if const_expr(self.qhead_per_kvhead == 1):
            mdK, mdV = [layout_utils.select(t, layout_transpose) for t in (mdK, mdV)]
        else:
            # GQA accumulators use a flattened (b, n, s*h) layout.
            accum_transpose = [2, 1, 0]  # (b, n, s*h) -> (s*h, n, b)
            mdK, mdV = [layout_utils.select(t, accum_transpose) for t in (mdK, mdV)]
        LSE_dPsum_dQaccum_transpose = [2, 1, 0]  # (b, n, s) -> (s, n, b)
        mLSE, mdPsum, mdQaccum = [
            layout_utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
        ]

        tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma()

        self.num_mma_threads = tiled_mma_SdP.size
        # One extra warp group (128 threads) is the producer (TMA load + dQ store).
        assert self.num_mma_threads + 128 == self.num_threads

        self.num_threads_per_warp_group = 128
        self.num_producer_threads = 32

        # Register budgets for the warp-specialized setmaxnreg split.
        self.num_mma_regs = 240
        self.num_producer_regs = 24
        # self.num_mma_regs = 232
        # self.num_producer_regs = 40

        self._setup_attributes()
        SharedStorage = self._get_shared_storage_cls()

        # Byte counts per TMA transaction, used for mbarrier tx_count arithmetic.
        self.tma_copy_bytes = {
            name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))
            for name, mX, layout in [
                ("Q", mQ, self.sQ_layout),
                ("K", mK, self.sK_layout),
                ("V", mV, self.sV_layout),
                ("dO", mdO, self.sdO_layout),
            ]
        }
        self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8
        self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8
        self.tma_copy_bytes["dQ"] = (
            self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_mma_warp_groups
        )
        self.tma_copy_bytes["dKacc"] = self.tile_n * self.tile_hdim * Float32.width // 8
        self.tma_copy_bytes["dVacc"] = self.tile_n * self.tile_hdimv * Float32.width // 8

        # G2S (load) TMA atoms for the four inputs.
        tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            mQ,
            cute.select(self.sQ_layout, mode=[0, 1]),
            (self.tile_m, self.tile_hdim),
        )
        tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            mK,
            cute.select(self.sK_layout, mode=[0, 1]),
            (self.tile_n, self.tile_hdim),
        )
        tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            mV,
            cute.select(self.sV_layout, mode=[0, 1]),
            (self.tile_n, self.tile_hdimv),
        )
        tma_atom_dO, tma_tensor_dO = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            mdO,
            cute.select(self.sdO_layout, mode=[0, 1]),
            (self.tile_m, self.tile_hdimv),
        )
        if const_expr(self.qhead_per_kvhead == 1):
            # S2G (store) TMA atoms for dK/dV; only for MHA — GQA writes
            # accumulators through mdK/mdV directly.
            tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom(
                cpasync.CopyBulkTensorTileS2GOp(),
                mdK,
                cute.select(self.sK_layout, mode=[0, 1]),
                (self.tile_n, self.tile_hdim),
            )
            tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom(
                cpasync.CopyBulkTensorTileS2GOp(),
                mdV,
                cute.select(self.sV_layout, mode=[0, 1]),
                (self.tile_n, self.tile_hdimv),
            )
        else:
            tma_atom_dK = tma_atom_dV = tma_tensor_dK = tma_tensor_dV = None

        TileScheduler = SingleTileScheduler
        tile_sched_args = TileSchedulerArguments(
            cute.ceil_div(cute.size(mK.shape[0]), self.tile_n),
            cute.size(mQ.shape[2]),
            cute.size(mQ.shape[3]),
            1,  # num_splits
            cute.size(mK.shape[0]),
            mQ.shape[1],
            mV.shape[1],
            total_q=cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
            tile_shape_mn=(self.tile_m, self.tile_n),
            mCuSeqlensQ=None,
            mSeqUsedQ=None,
            qhead_per_kvhead_packgqa=1,
            element_size=self.dtype.width // 8,
            is_persistent=False,
            lpt=False,
        )

        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

        # Softmax works in base-2 exponentials; fold log2(e) into the scale.
        # With a score_mod, the user scale is applied inside the mod instead.
        LOG2_E = math.log2(math.e)
        if const_expr(self.score_mod is None):
            softmax_scale_log2 = softmax_scale * LOG2_E
        else:
            softmax_scale_log2 = LOG2_E

        fastdiv_mods = None
        if const_expr(aux_tensors is not None):
            seqlen_q = cute.size(mQ.shape[0])
            seqlen_k = cute.size(mK.shape[0])
            seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
            seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
            fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)

        qhead_per_kvhead_divmod = None
        if const_expr(self.qhead_per_kvhead > 1):
            qhead_per_kvhead_divmod = FastDivmodDivisor(self.qhead_per_kvhead)

        self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)

        self.kernel(
            tma_tensor_Q,
            tma_tensor_K,
            tma_tensor_V,
            tma_tensor_dO,
            tma_tensor_dK if const_expr(self.qhead_per_kvhead == 1) else mdK,
            tma_tensor_dV if const_expr(self.qhead_per_kvhead == 1) else mdV,
            tma_atom_Q,
            tma_atom_K,
            tma_atom_V,
            tma_atom_dO,
            tma_atom_dK,
            tma_atom_dV,
            mLSE,
            mdPsum,
            mdQaccum,
            self.sQ_layout,
            self.sK_layout,
            self.sV_layout,
            self.sPdS_layout,
            self.sdO_layout,
            self.sdQaccum_layout,
            self.r2s_tiled_copy_dQaccum,
            tiled_mma_SdP,
            tiled_mma_dK,
            tiled_mma_dV,
            tiled_mma_dQ,
            softmax_scale_log2,
            softmax_scale,
            tile_sched_params,
            TileScheduler,
            SharedStorage,
            aux_tensors,
            fastdiv_mods,
            blocksparse_tensors,
            qhead_per_kvhead_divmod,
        ).launch(
            grid=grid_dim,
            block=[self.num_threads, 1, 1],
            stream=stream,
            min_blocks_per_mp=1,
        )
|
| 506 |
+
|
| 507 |
+
    @cute.kernel
    def kernel(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mdO: cute.Tensor,
        mdK: cute.Tensor,
        mdV: cute.Tensor,
        tma_atom_Q: cute.CopyAtom,
        tma_atom_K: cute.CopyAtom,
        tma_atom_V: cute.CopyAtom,
        tma_atom_dO: cute.CopyAtom,
        tma_atom_dK: cute.CopyAtom,
        tma_atom_dV: cute.CopyAtom,
        mLSE: cute.Tensor,
        mdPsum: cute.Tensor,
        mdQaccum: cute.Tensor,
        sQ_layout: cute.ComposedLayout,
        sK_layout: cute.ComposedLayout,
        sV_layout: cute.ComposedLayout,
        sPdS_layout: cute.ComposedLayout,
        sdO_layout: cute.ComposedLayout,
        sdQaccum_layout: cute.Layout,
        r2s_tiled_copy_dQaccum: cute.TiledCopy,
        tiled_mma_SdP: cute.TiledMma,
        tiled_mma_dK: cute.TiledMma,
        tiled_mma_dV: cute.TiledMma,
        tiled_mma_dQ: cute.TiledMma,
        softmax_scale_log2,
        softmax_scale,
        tile_sched_params: ParamsBase,
        TileScheduler: cutlass.Constexpr[Callable],
        SharedStorage: cutlass.Constexpr[Callable],
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
        qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
    ):
        """Device kernel: warp-specialized SM90 attention backward pass.

        Warp specialization (of self.num_threads total):
          * warp 0 (producer warp group, warps 0-3): TMA loads of Q/K/V/dO and
            LSE/dPsum via `self.load`;
          * warp 1: dQ accumulator smem->gmem store via `self.dQaccum_store`;
          * warps 4+ (MMA warp groups): the main `self.mma` compute loop.
        Producer warps shrink their register allocation and MMA warps grow
        theirs via setmaxregister.
        """
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

        # prefetch TMA descriptors
        if warp_idx == 0:
            cpasync.prefetch_descriptor(tma_atom_Q)
            cpasync.prefetch_descriptor(tma_atom_K)
            cpasync.prefetch_descriptor(tma_atom_V)
            cpasync.prefetch_descriptor(tma_atom_dO)

        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(SharedStorage)

        # Q and dO flow through TMA-async pipelines: single-thread producer,
        # all MMA warps as consumers. The Q barrier also accounts for the LSE
        # bytes (and dO's for dPsum) via the combined tx_count.
        pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread)
        pipeline_consumer_group = cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE
        )
        pipeline_Q = pipeline.PipelineTmaAsync.create(
            barrier_storage=storage.mbar_ptr_Q.data_ptr(),
            num_stages=self.Q_stage,
            producer_group=pipeline_producer_group,
            consumer_group=pipeline_consumer_group,
            tx_count=self.tma_copy_bytes["Q"] + self.tma_copy_bytes["LSE"],
            defer_sync=True,
        )
        pipeline_dO = pipeline.PipelineTmaAsync.create(
            barrier_storage=storage.mbar_ptr_dO.data_ptr(),
            num_stages=self.dO_stage,
            producer_group=pipeline_producer_group,
            consumer_group=pipeline_consumer_group,
            tx_count=self.tma_copy_bytes["dO"] + self.tma_copy_bytes["dPsum"],
            defer_sync=False,
        )

        # Shared-memory tensor views (swizzled layouts split into outer+inner).
        sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
        sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner)
        sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
        sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
        sP = None
        if const_expr(not self.mma_dkv_is_rs):
            sP = storage.sP.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner)
        sdS = storage.sdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner)
        # LSE/dPsum are stored per-stage with rows padded to a multiple of 64
        # (matching the struct allocation in _get_shared_storage_cls).
        sLSE = storage.sLSE.get_tensor(
            cute.make_layout(
                (self.tile_m, self.Q_stage),
                stride=(1, cute.round_up(self.tile_m, 64)),
            )
        )
        sdPsum = storage.sdPsum.get_tensor(
            cute.make_layout(
                (self.tile_m, self.dO_stage),
                stride=(1, cute.round_up(self.tile_m, 64)),
            )
        )
        sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout)

        block_info = BlockInfo(
            self.tile_m,
            self.tile_n,
            self.is_causal,
            self.is_local,
            False,  # is_split_kv
            None,
            None,
            qhead_per_kvhead_packgqa=1,
        )
        SeqlenInfoCls = partial(
            SeqlenInfoQK.create,
            seqlen_q_static=mQ.shape[0],
            seqlen_k_static=mK.shape[0],
            mCuSeqlensQ=None,
            mCuSeqlensK=None,
            mSeqUsedQ=None,
            mSeqUsedK=None,
        )
        AttentionMaskCls = partial(
            AttentionMask,
            self.tile_m,
            self.tile_n,
            window_size_left=None,
            window_size_right=None,
            swap_AB=self.SdP_swapAB,
        )
        TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)

        if warp_idx < 4:
            # Producer warp group: give back registers, then specialize.
            cute.arch.setmaxregister_decrease(self.num_producer_regs)
            if warp_idx == 0:
                self.load(
                    mQ,
                    mK,
                    mV,
                    mdO,
                    mLSE,
                    mdPsum,
                    sQ,
                    sK,
                    sV,
                    sdO,
                    sLSE,
                    sdPsum,
                    tma_atom_Q,
                    tma_atom_K,
                    tma_atom_V,
                    tma_atom_dO,
                    pipeline_Q,
                    pipeline_dO,
                    block_info,
                    SeqlenInfoCls,
                    TileSchedulerCls,
                    blocksparse_tensors,
                    qhead_per_kvhead_divmod,
                )
            if warp_idx == 1:
                self.dQaccum_store(
                    mdQaccum,
                    sdQaccum,
                    block_info,
                    TileSchedulerCls,
                    SeqlenInfoCls,
                    blocksparse_tensors,
                )
        else:
            # MMA warp groups: claim the freed registers and run the main loop.
            cute.arch.setmaxregister_increase(self.num_mma_regs)
            tidx, _, _ = cute.arch.thread_idx()
            # Re-base thread index past the 128 producer threads.
            tidx = tidx - 128
            self.mma(
                tiled_mma_SdP,
                tiled_mma_dK,
                tiled_mma_dV,
                tiled_mma_dQ,
                mdK,
                mdV,
                mdQaccum,
                sQ,
                sK,
                sV,
                sdO,
                sP,
                sdS,
                sLSE,
                sdPsum,
                sdQaccum,
                pipeline_Q,
                pipeline_dO,
                tidx,
                tma_atom_dK,
                tma_atom_dV,
                r2s_tiled_copy_dQaccum,
                softmax_scale_log2,
                softmax_scale,
                block_info,
                SeqlenInfoCls,
                AttentionMaskCls,
                TileSchedulerCls,
                aux_tensors,
                fastdiv_mods,
                blocksparse_tensors,
                qhead_per_kvhead_divmod,
            )
|
| 705 |
+
|
| 706 |
+
    @cute.jit
    def load(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mdO: cute.Tensor,
        mLSE: cute.Tensor,
        mdPsum: cute.Tensor,
        sQ: cute.Tensor,
        sK: cute.Tensor,
        sV: cute.Tensor,
        sdO: cute.Tensor,
        sLSE: cute.Tensor,
        sdPsum: cute.Tensor,
        tma_atom_Q: cute.CopyAtom,
        tma_atom_K: cute.CopyAtom,
        tma_atom_V: cute.CopyAtom,
        tma_atom_dO: cute.CopyAtom,
        pipeline_Q: cutlass.pipeline.PipelineAsync,
        pipeline_dO: cutlass.pipeline.PipelineAsync,
        block_info: BlockInfo,
        SeqlenInfoCls: Callable,
        TileSchedulerCls: Callable,
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
        qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
    ):
        """Producer loop: issue TMA loads of K/V (once per tile) and of
        Q/LSE and dO/dPsum (per m_block) into the smem pipelines.

        Only the first warp's lane-0-elected thread of the producer warp group
        effectively drives this; the K bytes piggyback on the first Q-pipeline
        barrier (extra_tx_count) and the V bytes on the first dO-pipeline
        barrier.
        """
        warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4

        if warp_idx_in_wg == 0:
            producer_state_Q = cutlass.pipeline.make_pipeline_state(
                cutlass.pipeline.PipelineUserType.Producer, self.Q_stage
            )
            producer_state_dO = cutlass.pipeline.make_pipeline_state(
                cutlass.pipeline.PipelineUserType.Producer, self.dO_stage
            )
            tile_scheduler = TileSchedulerCls()
            work_tile = tile_scheduler.initial_work_tile_info()
            while work_tile.is_valid_tile:
                n_block, head_idx, batch_idx, _ = work_tile.tile_idx
                seqlen = SeqlenInfoCls(batch_idx)
                # GQA: several query heads share one KV head.
                head_idx_kv = (
                    head_idx
                    if const_expr(self.qhead_per_kvhead == 1)
                    else head_idx // qhead_per_kvhead_divmod
                )
                # Global-memory tiles for this (n_block, head, batch).
                # K/V are fixed for the whole tile; Q/dO/LSE/dPsum are indexed
                # by m_block later (tile coord None).
                mK_cur = mK[None, None, head_idx_kv, batch_idx]
                gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
                mV_cur = mV[None, None, head_idx_kv, batch_idx]
                gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))

                mQ_cur = mQ[None, None, head_idx, batch_idx]
                gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0))
                mdO_cur = mdO[None, None, head_idx, batch_idx]
                gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0))
                mLSE_cur = mLSE[None, head_idx, batch_idx]
                gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
                mdPsum_cur = mdPsum[None, head_idx, batch_idx]
                gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))

                # Bind copy closures; Q/dO/LSE/dPsum are wrapped as pipeline
                # producers so each call targets the acquired stage.
                load_K, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_K, 0, cute.make_layout(1), gK, sK, single_stage=True
                )
                load_V, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_V, 0, cute.make_layout(1), gV, sV, single_stage=True
                )
                load_Q, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_Q, 0, cute.make_layout(1), gQ, sQ
                )
                load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q)
                load_dO, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_dO, 0, cute.make_layout(1), gdO, sdO
                )
                load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO)
                load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE)
                load_LSE = copy_utils.tma_producer_copy_fn(load_LSE, pipeline_Q)
                load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum)
                load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_dO)

                m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)

                if const_expr(not self.use_block_sparsity):
                    total_m_block_cnt = m_block_max - m_block_min
                    # Local (windowed) attention can yield empty tiles.
                    process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
                else:
                    total_m_block_cnt = get_total_q_block_count_bwd(
                        blocksparse_tensors,
                        batch_idx,
                        head_idx,
                        n_block,
                        subtile_factor=self.subtile_factor,
                        m_block_max=m_block_max,
                    )
                    process_tile = total_m_block_cnt > Int32(0)

                if process_tile:
                    if const_expr(not self.use_block_sparsity):
                        # First iteration also brings in K (on the Q barrier)
                        # and V (on the dO barrier) via extra_tx_count.
                        first_m_block = m_block_min
                        pipeline_Q.producer_acquire(
                            producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"]
                        )
                        load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
                        load_Q(first_m_block, producer_state=producer_state_Q)
                        load_LSE(first_m_block, producer_state=producer_state_Q)
                        # When Q and dO share a stage count, they share one
                        # pipeline state as well.
                        producer_state_dO_cur = (
                            producer_state_dO
                            if const_expr(self.Q_stage != self.dO_stage)
                            else producer_state_Q
                        )
                        pipeline_dO.producer_acquire(
                            producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"]
                        )
                        load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
                        load_dO(first_m_block, producer_state=producer_state_dO_cur)
                        load_dPsum(first_m_block, producer_state=producer_state_dO_cur)
                        producer_state_Q.advance()
                        producer_state_dO.advance()

                        # Remaining m_blocks: plain Q/LSE + dO/dPsum loads.
                        for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):
                            pipeline_Q.producer_acquire(producer_state_Q)
                            load_Q(m_block, producer_state=producer_state_Q)
                            load_LSE(m_block, producer_state=producer_state_Q)
                            producer_state_dO_cur = (
                                producer_state_dO
                                if const_expr(self.Q_stage != self.dO_stage)
                                else producer_state_Q
                            )
                            pipeline_dO.producer_acquire(producer_state_dO_cur)
                            load_dO(m_block, producer_state=producer_state_dO_cur)
                            load_dPsum(m_block, producer_state=producer_state_dO_cur)
                            producer_state_Q.advance()
                            producer_state_dO.advance()
                    else:
                        # Block-sparse path: helper walks only the active
                        # m_blocks and returns the advanced pipeline states.
                        producer_state_Q, producer_state_dO = produce_block_sparse_q_loads_bwd_sm90(
                            blocksparse_tensors,
                            batch_idx,
                            head_idx,
                            n_block,
                            producer_state_Q,
                            producer_state_dO,
                            pipeline_Q,
                            pipeline_dO,
                            load_K,
                            load_V,
                            load_Q,
                            load_dO,
                            load_LSE,
                            load_dPsum,
                            self.tma_copy_bytes["K"],
                            self.tma_copy_bytes["V"],
                            Q_stage_eq_dO_stage=(self.Q_stage == self.dO_stage),
                            subtile_factor=self.subtile_factor,
                            m_block_max=m_block_max,
                        )

                tile_scheduler.prefetch_next_work()
                tile_scheduler.advance_to_next_work()
                work_tile = tile_scheduler.get_current_work()
|
| 864 |
+
|
| 865 |
+
    @cute.jit
    def apply_score_mod(
        self,
        acc_S: cute.Tensor,
        thr_mma_SdP: cute.core.ThrMma,
        batch_idx,
        head_idx,
        m_block,
        n_block,
        softmax_scale,
        seqlen_info: SeqlenInfoQK,
        aux_tensors=None,
        fastdiv_mods=(None, None),
    ):
        """Apply the user score_mod to the attention-score accumulator acc_S.

        Builds an identity (coordinate) tensor over this tile, offsets it to
        global (q, k) indices, partitions it like the accumulator, and
        delegates to apply_score_mod_inner.
        """
        # [NOTE] SdP_swapAB: swapAB transposes the tile, so use (n, m) indexing
        cS = cute.make_identity_tensor(
            (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
        )
        # Shift tile-local coordinates to global positions.
        cS = cute.domain_offset(
            (n_block * self.tile_n, m_block * self.tile_m)
            if self.SdP_swapAB
            else (m_block * self.tile_m, n_block * self.tile_n),
            cS,
        )
        # Per-thread coordinate fragment matching acc_S's partitioning.
        tScS = thr_mma_SdP.partition_C(cS)

        apply_score_mod_inner(
            acc_S,
            tScS,
            self.score_mod,
            batch_idx,
            head_idx,
            softmax_scale,
            self.vec_size,
            self.qk_acc_dtype,
            aux_tensors,
            fastdiv_mods,
            seqlen_info,
            constant_q_idx=None,
            qhead_per_kvhead=self.qhead_per_kvhead,
            transpose_indices=self.SdP_swapAB,
        )
|
| 907 |
+
|
| 908 |
+
    @cute.jit
    def apply_score_mod_bwd(
        self,
        grad_tensor: cute.Tensor,
        score_tensor: cute.Tensor,
        thr_mma_SdP: cute.core.ThrMma,
        batch_idx,
        head_idx,
        m_block,
        n_block,
        softmax_scale,
        seqlen_info: SeqlenInfoQK,
        aux_tensors=None,
        fastdiv_mods=(None, None),
    ):
        """Apply the user score_mod backward (gradient) to grad_tensor, given
        the forward scores in score_tensor.

        Mirrors apply_score_mod: builds the globally-offset coordinate tensor
        for this tile (transposed when SdP_swapAB) and delegates to
        apply_score_mod_bwd_inner.
        """
        cS = cute.make_identity_tensor(
            (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
        )
        cS = cute.domain_offset(
            (n_block * self.tile_n, m_block * self.tile_m)
            if self.SdP_swapAB
            else (m_block * self.tile_m, n_block * self.tile_n),
            cS,
        )
        # Per-thread coordinate fragment matching the accumulator partitioning.
        tScS = thr_mma_SdP.partition_C(cS)

        apply_score_mod_bwd_inner(
            grad_tensor,
            score_tensor,
            tScS,
            self.score_mod_bwd,
            batch_idx,
            head_idx,
            softmax_scale,
            self.vec_size,
            self.qk_acc_dtype,
            aux_tensors,
            fastdiv_mods,
            seqlen_info,
            constant_q_idx=None,
            qhead_per_kvhead=self.qhead_per_kvhead,
            transpose_indices=self.SdP_swapAB,
        )
|
| 951 |
+
|
| 952 |
+
@cute.jit
def mma(
    self,
    tiled_mma_SdP: cute.TiledMma,
    tiled_mma_dK: cute.TiledMma,
    tiled_mma_dV: cute.TiledMma,
    tiled_mma_dQ: cute.TiledMma,
    mdK: cute.Tensor,
    mdV: cute.Tensor,
    mdQaccum: cute.Tensor,
    sQ: cute.Tensor,
    sK: cute.Tensor,
    sV: cute.Tensor,
    sdO: cute.Tensor,
    sP: Optional[cute.Tensor],
    sdS: cute.Tensor,
    sLSE: cute.Tensor,
    sdPsum: cute.Tensor,
    sdQaccum: cute.Tensor,
    pipeline_Q: cutlass.pipeline.PipelineAsync,
    pipeline_dO: cutlass.pipeline.PipelineAsync,
    tidx: Int32,
    tma_atom_dK: cute.CopyAtom,
    tma_atom_dV: cute.CopyAtom,
    r2s_tiled_copy_dQaccum: cute.TiledCopy,
    softmax_scale_log2: Float32,
    softmax_scale: Float32,
    block_info: BlockInfo,
    SeqlenInfoCls: Callable,
    AttentionMaskCls: Callable,
    TileSchedulerCls: Callable,
    aux_tensors: Optional[list] = None,
    fastdiv_mods=(None, None),
    blocksparse_tensors: Optional[BlockSparseTensors] = None,
    qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
):
    """Main MMA (consumer) warp-group loop of the SM90 backward pass.

    For each KV tile handed out by the tile scheduler this method:
      1. Sets up the five GEMMs of the backward pass
         (S = Q@K^T, dP = dO@V^T, dV += P^T@dO, dK += dS^T@Q, dQ = dS@K)
         as partial functions over shared-memory operands,
      2. iterates over the Q (m) blocks via ``mma_one_m_block`` — either the
         dense [m_block_min, m_block_max) range or the block-sparse schedule,
      3. runs the dK/dV epilogue for the tile.
    ``m`` prefixes denote global-memory tensors, ``s`` prefixes shared-memory
    staging tiles. dQ is accumulated via ``sdQaccum``/``mdQaccum`` and stored
    by the separate ``dQaccum_store`` warp.
    """
    # Each MMA warp group gets its own slice of the tiled MMAs; thr_mma_SdP is
    # the per-thread slice used for masking/score-mod coordinate partitioning.
    warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
    warp_group_thread_layout = cute.make_layout(
        self.num_mma_warp_groups, stride=self.num_threads_per_warp_group
    )
    thr_mma_SdP = tiled_mma_SdP.get_slice(tidx)
    wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx))
    wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx))
    wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx))
    wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(warp_group_idx))
    # S = Q @ K.T
    shape_mnk_S = (self.tile_m, self.tile_n, self.tile_hdim)
    _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
        wg_mma_SdP, shape_mnk_S, sQ, sK, swap_AB=self.SdP_swapAB
    )
    mma_qk_fn = partial(
        gemm_zero_init, tiled_mma_SdP, shape_mnk_S[:2], tSrQ, tSrK, swap_AB=self.SdP_swapAB
    )
    # dP = dO @ V.T
    shape_mnk_dP = (self.tile_m, self.tile_n, self.tile_hdimv)
    _, tdPrdO, tdPrV = sm90_utils.partition_fragment_ABC(
        wg_mma_SdP, shape_mnk_dP, sdO, sV, swap_AB=self.SdP_swapAB
    )
    mma_dov_fn = partial(
        gemm_zero_init, tiled_mma_SdP, shape_mnk_dP[:2], tdPrdO, tdPrV, swap_AB=self.SdP_swapAB
    )
    # dV += P.T @ dO
    sPt = layout_utils.transpose_view(sP) if sP is not None else None
    sdOt = layout_utils.transpose_view(sdO)
    shape_mnk_dV = (self.tile_n, self.tile_hdimv, self.tile_m)
    acc_dV, tdVrPt, tdVrdOt = sm90_utils.partition_fragment_ABC(
        wg_mma_dV, shape_mnk_dV, sPt, sdOt, swap_AB=self.dKV_swapAB
    )
    # mma_dkv_is_rs: the A operand (P / dS) is fed from registers instead of smem,
    # so the A fragment is supplied per-call (tCrA) rather than fixed here.
    if const_expr(not self.mma_dkv_is_rs):
        mma_pdo_fn = partial(
            gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt, swap_AB=self.dKV_swapAB
        )
    else:
        mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tCrB=tdVrdOt)
    # dK += dS.T @ Q
    sdSt = layout_utils.transpose_view(sdS)
    sQt = layout_utils.transpose_view(sQ)
    shape_mnk_dK = (self.tile_n, self.tile_hdim, self.tile_m)
    acc_dK, tdKrdSt, tdKrQt = sm90_utils.partition_fragment_ABC(
        wg_mma_dK, shape_mnk_dK, sdSt, sQt, swap_AB=self.dKV_swapAB
    )
    if const_expr(not self.mma_dkv_is_rs):
        mma_dsq_fn = partial(
            gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt, swap_AB=self.dKV_swapAB
        )
    else:
        mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tCrB=tdKrQt)
    # dQ = dS @ K
    sKt = layout_utils.transpose_view(sK)
    shape_mnk_dQ = (self.tile_m, self.tile_hdim, self.tile_n)
    _, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC(
        wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB
    )
    mma_dsk_fn = partial(
        gemm_zero_init, tiled_mma_dQ, shape_mnk_dQ[:2], tdQrdS, tdQrKt, swap_AB=self.dQ_swapAB
    )

    # Smem copy atom tiling
    # Register->smem stores for P and dS (consumed later by the dKV/dQ GEMMs).
    copy_P_r2s = None
    if const_expr(sP is not None):
        sP_cpy = sP if const_expr(not self.SdP_swapAB) else sPt
        copy_P_r2s, _, _ = copy_utils.get_smem_store_C(
            tiled_mma_SdP, sP_cpy, tidx, self.arch, transpose=self.SdP_swapAB
        )
    sdS_cpy = sdS if const_expr(not self.SdP_swapAB) else sdSt
    copy_dS_r2s, _, _ = copy_utils.get_smem_store_C(
        tiled_mma_SdP, sdS_cpy, tidx, self.arch, transpose=self.SdP_swapAB
    )

    # Row-vector views of LSE and dPsum matched to the S/dP accumulator layout.
    tLSEsLSE = layout_utils.mma_partition_C_vec(
        sLSE, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB
    )
    tLSEsdPsum = layout_utils.mma_partition_C_vec(
        sdPsum, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB
    )

    smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)
    tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)

    # Barrier synchronizing all MMA threads around the shared P/dS buffers.
    PdS_barrier = cutlass.pipeline.NamedBarrier(
        barrier_id=int(NamedBarrierBwd.PdS), num_threads=self.num_mma_threads
    )
    score_mod_fn = partial(
        self.apply_score_mod,
        thr_mma_SdP=thr_mma_SdP,
        softmax_scale=softmax_scale,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )
    score_mod_bwd_fn = partial(
        self.apply_score_mod_bwd,
        thr_mma_SdP=thr_mma_SdP,
        softmax_scale=softmax_scale,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )

    # Everything that does not depend on the current work tile is bound once.
    mma_one_m_block_all = partial(
        self.mma_one_m_block,
        warp_group_idx=warp_group_idx,
        mma_qk_fn=mma_qk_fn,
        mma_dov_fn=mma_dov_fn,
        mma_pdo_fn=mma_pdo_fn,
        mma_dsq_fn=mma_dsq_fn,
        mma_dsk_fn=mma_dsk_fn,
        copy_P_r2s=copy_P_r2s,
        copy_dS_r2s=copy_dS_r2s,
        pipeline_Q=pipeline_Q,
        pipeline_dO=pipeline_dO,
        tLSEsLSE=tLSEsLSE,
        tLSEsdPsum=tLSEsdPsum,
        tdQsdQaccum=tdQsdQaccum,
        softmax_scale_log2=softmax_scale_log2,
        PdS_barrier=PdS_barrier,
        # acc_dV=acc_dV,
        # acc_dK=acc_dK,
    )

    consumer_state_Q = cutlass.pipeline.make_pipeline_state(
        cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage
    )
    consumer_state_dO = cutlass.pipeline.make_pipeline_state(
        cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage
    )
    tile_scheduler = TileSchedulerCls()
    work_tile = tile_scheduler.initial_work_tile_info()
    while work_tile.is_valid_tile:
        n_block, head_idx, batch_idx, _ = work_tile.tile_idx
        seqlen = SeqlenInfoCls(batch_idx)
        mask = AttentionMaskCls(seqlen)
        score_mod_fn_cur = partial(
            score_mod_fn,
            batch_idx=batch_idx,
            head_idx=head_idx,
            n_block=n_block,
            seqlen_info=seqlen,
        )
        score_mod_bwd_fn_cur = partial(
            score_mod_bwd_fn,
            batch_idx=batch_idx,
            head_idx=head_idx,
            n_block=n_block,
            seqlen_info=seqlen,
        )
        m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)

        # Decide whether this KV tile has any Q blocks to process at all.
        if const_expr(not self.use_block_sparsity):
            # Only local (sliding-window) attention can produce an empty m range.
            process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
        else:
            total_m_block_cnt = get_total_q_block_count_bwd(
                blocksparse_tensors,
                batch_idx,
                head_idx,
                n_block,
                subtile_factor=self.subtile_factor,
                m_block_max=m_block_max,
            )
            process_tile = total_m_block_cnt > Int32(0)

        if process_tile:
            if const_expr(not self.use_block_sparsity):
                mask_fn = partial(
                    mask.apply_mask,
                    batch_idx=batch_idx,
                    head_idx=head_idx,
                    n_block=n_block,
                    thr_mma=thr_mma_SdP,
                    mask_seqlen=True,
                    mask_causal=self.is_causal,
                    mask_local=self.is_local,
                    mask_mod=self.mask_mod,
                    aux_tensors=aux_tensors,
                    fastdiv_mods=fastdiv_mods,
                )
                # First m block must zero-init acc_dK/acc_dV; later ones accumulate.
                dKV_accumulate = False
                for m_block in cutlass.range(m_block_min, m_block_max, unroll=1):
                    consumer_state_Q, consumer_state_dO = mma_one_m_block_all(
                        m_block,
                        consumer_state_Q,
                        consumer_state_dO,
                        mask_fn=mask_fn,
                        score_mod_fn=score_mod_fn_cur,
                        score_mod_bwd_fn=score_mod_bwd_fn_cur,
                        dKV_accumulate=dKV_accumulate,
                    )
                    dKV_accumulate = True
            else:
                # Block-sparse path: the helper walks the sparse m-block schedule
                # and advances the pipeline states for us.
                consumer_state_Q, consumer_state_dO = consume_block_sparse_mma_bwd_sm90(
                    blocksparse_tensors,
                    batch_idx,
                    head_idx,
                    n_block,
                    consumer_state_Q,
                    consumer_state_dO,
                    mma_one_m_block_all,
                    mask,
                    self.mask_mod,
                    is_causal=self.is_causal,
                    is_local=self.is_local,
                    thr_mma_SdP=thr_mma_SdP,
                    score_mod_fn=score_mod_fn_cur,
                    score_mod_bwd_fn=score_mod_bwd_fn_cur,
                    subtile_factor=self.subtile_factor,
                    m_block_max=m_block_max,
                    aux_tensors=aux_tensors,
                    fastdiv_mods=fastdiv_mods,
                )

            # With no GQA (1 q head per kv head) dK is written directly, so fold
            # the softmax scale in here; otherwise scaling happens elsewhere.
            if const_expr(self.qhead_per_kvhead == 1):
                acc_dK.store(acc_dK.load() * softmax_scale)
            self.epilogue_dKV(
                acc_dV,
                mdV,
                sV,
                acc_dK,
                mdK,
                sK,
                seqlen,
                tma_atom_dK,
                tma_atom_dV,
                tiled_mma_dK,
                tiled_mma_dV,
                tidx,
                n_block,
                head_idx,
                batch_idx,
                qhead_per_kvhead_divmod,
            )
        else:
            # Block sparsity: KV tile with zero Q blocks produces no dK/dV; write zeros.
            if const_expr(self.use_block_sparsity):
                acc_dK.fill(0.0)
                acc_dV.fill(0.0)
                self.epilogue_dKV(
                    acc_dV,
                    mdV,
                    sV,
                    acc_dK,
                    mdK,
                    sK,
                    seqlen,
                    tma_atom_dK,
                    tma_atom_dV,
                    tiled_mma_dK,
                    tiled_mma_dV,
                    tidx,
                    n_block,
                    head_idx,
                    batch_idx,
                    qhead_per_kvhead_divmod,
                )

        tile_scheduler.advance_to_next_work()
        work_tile = tile_scheduler.get_current_work()

    # Drain outstanding bulk-async (TMA) reads before the kernel exits.
    # NOTE(review): only warp 4 issued the bulk stores in the epilogue, hence the guard.
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
    if warp_idx == 4:
        cute.arch.cp_async_bulk_wait_group(0, read=True)
|
| 1250 |
+
|
| 1251 |
+
@cute.jit
def mma_one_m_block(
    self,
    m_block: Int32,
    consumer_state_Q: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
    consumer_state_dO: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
    warp_group_idx: Int32,
    mma_qk_fn: Callable,
    mma_dov_fn: Callable,
    mma_pdo_fn: Callable,
    mma_dsq_fn: Callable,
    mma_dsk_fn: Callable,
    copy_P_r2s: Optional[Callable],
    copy_dS_r2s: Callable,
    pipeline_Q: cutlass.pipeline.PipelineAsync,
    pipeline_dO: cutlass.pipeline.PipelineAsync,
    tLSEsLSE: cute.Tensor,
    tLSEsdPsum: cute.Tensor,
    tdQsdQaccum: cute.Tensor,
    softmax_scale_log2: Float32,
    PdS_barrier: cutlass.pipeline.NamedBarrier,
    mask_fn: Optional[Callable] = None,
    score_mod_fn: Optional[Callable] = None,
    score_mod_bwd_fn: Optional[Callable] = None,
    dKV_accumulate: Boolean = True,
):
    """Process one Q (m) block: the five backward GEMMs plus pointwise steps.

    Pipeline per block:
      (1) S = Q @ K^T          (2) dP = dO @ V^T
      (3) P = exp2(S*scale_log2 - LSE)
      (4) dS = P * (dP - dPsum)  [+ optional backward score-mod]
      (5) dV += P^T @ dO       (6) dQ = dS @ K       (7) dK += dS^T @ Q
    dQ is staged into smem (tdQsdQaccum) and handed to the dQaccum warp via
    the dQEmpty/dQFull named barriers. Returns the advanced
    (consumer_state_Q, consumer_state_dO) pipeline states.

    :param dKV_accumulate: False on the first m block of a KV tile so the
        dK/dV GEMMs zero-initialize their accumulators.
    """
    # NOTE(review): when Q_stage != dO_stage the dO pipeline is indexed with the
    # Q consumer state — presumably the two pipelines then advance in lockstep;
    # confirm against the pipeline setup.
    consumer_state_dO_cur = (
        consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q
    )
    smem_idx_Q = consumer_state_Q.index
    smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0
    smem_idx_PdS = smem_idx_Q if const_expr(self.PdS_stage > 1) else 0
    # (1) [GEMM 1] S = Q @ K^T
    pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q))
    # wg_wait=-1: issue the WGMMA without waiting; we overlap with the LSE load.
    acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1)
    tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q])
    # (2) [GEMM 2] dP = dO @ V.T
    pipeline_dO.consumer_wait(
        consumer_state_dO_cur, pipeline_dO.consumer_try_wait(consumer_state_dO_cur)
    )
    # wg_wait=1: returns once GEMM 1 (S) has completed; dP may still be in flight.
    acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1)

    # Backward score-mod needs the pre-mod scores, so snapshot acc_S first.
    if const_expr(self.score_mod_bwd is not None):
        acc_S_pre = cute.make_fragment_like(acc_S)
        cute.autovec_copy(acc_S, acc_S_pre)

    if const_expr(self.score_mod is not None):
        score_mod_fn(acc_S, m_block=m_block)

    # (3) [Pointwise 1] P = exp(S - LSE)
    if cutlass.const_expr(mask_fn is not None):
        mask_fn(acc_S, m_block=m_block)
    acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.SdP_swapAB)
    for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])):
        for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True):
            # exp2 with log2-scaled inputs avoids a per-element exp().
            acc_S_mn[r, c] = cute.math.exp2(
                acc_S_mn[r, c] * softmax_scale_log2 - tLSErLSE[r], fastmath=True
            )
    tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO])

    # Convert P from f32 -> f16
    tdVrP = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_S), self.dtype)
    # R2S for P
    if const_expr(not self.mma_dkv_is_rs):
        # sync to ensure P has already been used in the previous iteration before overwriting
        if const_expr(self.PdS_stage == 1):
            PdS_barrier.arrive_and_wait()
        copy_P_r2s(tdVrP, dst_idx=smem_idx_PdS)

    # (4) [Pointwise 2] dS = P*(dP-dPsum)
    # Wait for the dP WGMMA issued in step (2) before reading acc_dP.
    warpgroup.wait_group(0)
    acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP, transpose=self.SdP_swapAB)
    for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])):
        for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):
            acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r])

    if const_expr(self.score_mod_bwd is not None):
        score_mod_bwd_fn(acc_dP, acc_S_pre, m_block=m_block)

    # Convert dS from f32 -> f16
    tdKrdS = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_dP), self.dtype)

    # If there's double buffering on dS, we don't need to sync here.
    # Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ.
    # But because both WGs have to sync at the end of the loop and double buffering,
    # this race condition is not possible.
    # This sync is to ensure (1) P is written in case of !mma_dkv_is_rs and
    # (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs.
    if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)):
        cute.arch.fence_view_async_shared()
        PdS_barrier.arrive_and_wait()

    # R2S for dS
    copy_dS_r2s(tdKrdS, dst_idx=smem_idx_PdS)

    # (5) [GEMM 3] dV += P.T @ dO
    if const_expr(not self.mma_dkv_is_rs):
        mma_pdo_fn(
            A_idx=smem_idx_PdS, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1
        )
    else:
        mma_pdo_fn(tCrA=tdVrP, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1)

    # smem fence to make sure sdS is written before it's read by WGMMA
    cute.arch.fence_view_async_shared()
    PdS_barrier.arrive_and_wait()
    # (6) [GEMM 4] dQ = dS @ K
    acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1)
    # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV)
    pipeline_dO.consumer_release(consumer_state_dO_cur)  # release dO as dV mma is done

    # (7) [GEMM 5] dK += dS.T @ Q
    if const_expr(not self.mma_dkv_is_rs):
        mma_dsq_fn(
            A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1
        )
    else:
        mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1)
    # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ)

    # Hand dQ to the dQaccum warp: wait until our smem slot is empty, write it,
    # fence, then signal "full". The extra WARP_SIZE accounts for that warp.
    cute.arch.barrier(
        barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
        number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
    )
    tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape))
    cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum)
    cute.arch.fence_view_async_shared()
    cute.arch.barrier_arrive(
        barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
        number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
    )

    # Wait for GEMM 5 before releasing the Q buffer it reads from.
    warpgroup.wait_group(0)
    # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dK)
    pipeline_Q.consumer_release(consumer_state_Q)
    # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block = {}, after pipeline_Q consumer release", cute.arch.thread_idx()[0], m_block)

    consumer_state_Q.advance()
    consumer_state_dO.advance()
    return consumer_state_Q, consumer_state_dO
|
| 1391 |
+
|
| 1392 |
+
@cute.jit
def epilogue_dKV(
    self,
    acc_dV: cute.Tensor,
    mdV: cute.Tensor,
    sV: cute.Tensor,
    acc_dK: cute.Tensor,
    mdK: cute.Tensor,
    sK: cute.Tensor,
    seqlen: SeqlenInfoQK,
    tma_atom_dK: cute.CopyAtom,
    tma_atom_dV: cute.CopyAtom,
    tiled_mma_dK: cute.TiledMma,
    tiled_mma_dV: cute.TiledMma,
    tidx: Int32,
    n_block: Int32,
    head_idx: Int32,
    batch_idx: Int32,
    qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
):
    """Write the per-tile dK/dV accumulators to global memory.

    Two paths:
      * qhead_per_kvhead == 1: stage acc_dV/acc_dK into the (now free) sV/sK
        smem tiles and TMA-store them directly to mdV/mdK.
      * GQA (qhead_per_kvhead > 1): multiple q heads contribute to the same
        kv head, so stage f32 partials into smem (reusing sV's storage) and
        use bulk-async reduce-add into the mdK/mdV accumulation buffers.
    In both paths, only warp 4 issues the bulk/TMA stores; the epi_barrier
    orders smem writes against those stores across all MMA threads.
    """
    epi_barrier = cutlass.pipeline.NamedBarrier(
        barrier_id=int(NamedBarrierBwd.Epilogue), num_threads=self.num_mma_threads
    )
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

    if const_expr(self.qhead_per_kvhead == 1):
        mdV_cur = mdV[None, None, head_idx, batch_idx]
        mdK_cur = mdK[None, None, head_idx, batch_idx]
        gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
        gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
        store_dK, _, _ = copy_utils.tma_get_copy_fn(
            tma_atom_dK, 0, cute.make_layout(1), sK, gdK, single_stage=True
        )
        store_dV, _, _ = copy_utils.tma_get_copy_fn(
            tma_atom_dV, 0, cute.make_layout(1), sV, gdV, single_stage=True
        )
        # dK/dV are staged into the K/V smem tiles (transposed view if needed).
        sdV = sV if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sV)
        sdK = sK if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sK)
        copy_dV_r2s, _, _ = copy_utils.get_smem_store_C(
            tiled_mma_dV, sdV, tidx, self.arch, transpose=self.dKV_swapAB
        )
        copy_dK_r2s, _, _ = copy_utils.get_smem_store_C(
            tiled_mma_dK, sdK, tidx, self.arch, transpose=self.dKV_swapAB
        )
        # Make sure prior bulk-async reads of sV have drained before overwriting it.
        cute.arch.cp_async_bulk_wait_group(1, read=True)
        epi_barrier.arrive_and_wait()
        copy_dV_r2s(acc_dV, dst_idx=None)
        cute.arch.fence_view_async_shared()
        epi_barrier.arrive_and_wait()
        if warp_idx == 4:
            store_dV()
            cute.arch.cp_async_bulk_commit_group()
        # Same dance for dK via sK.
        cute.arch.cp_async_bulk_wait_group(1, read=True)
        epi_barrier.arrive_and_wait()
        copy_dK_r2s(acc_dK, dst_idx=None)
        cute.arch.fence_view_async_shared()
        epi_barrier.arrive_and_wait()
        if warp_idx == 4:
            store_dK()
            cute.arch.cp_async_bulk_commit_group()
    else:
        # GQA path: dK/dV partials are flattened to 1-D f32 and split across
        # the MMA warp groups for the bulk reduce-add.
        sdKaccum_shape0 = self.tile_n * self.tile_hdim // self.num_mma_warp_groups
        sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.num_mma_warp_groups
        sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.num_mma_warp_groups))
        sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.num_mma_warp_groups))
        # Map this q head to its kv head via the fast divmod divisor.
        head_idx_kv = head_idx // qhead_per_kvhead_divmod
        mdKaccum_cur = mdK[None, head_idx_kv, batch_idx]
        gdKaccum_ = cute.local_tile(mdKaccum_cur, (self.tile_n * self.tile_hdim,), (n_block,))
        gdKaccum = cute.flat_divide(gdKaccum_, (sdKaccum_shape0,))
        mdVaccum_cur = mdV[None, head_idx_kv, batch_idx]
        gdVaccum_ = cute.local_tile(mdVaccum_cur, (self.tile_n * self.tile_hdimv,), (n_block,))
        gdVaccum = cute.flat_divide(gdVaccum_, (sdVaccum_shape0,))
        # These two overlap each other
        # (both alias sV's smem, reinterpreted as f32; dK then dV are staged
        #  sequentially through the same storage, separated by barriers).
        sVaccum_ptr = cute.recast_ptr(sV.iterator, dtype=Float32)
        sdKaccum = cute.make_tensor(sVaccum_ptr, sdKaccum_layout)
        sdVaccum = cute.make_tensor(sVaccum_ptr, sdVaccum_layout)
        # 128-bit vectorized register->smem stores.
        tiled_copy_dKVaccum_r2s = cute.make_tiled_copy_tv(
            cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
            cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)),
            cute.make_layout(128 // Float32.width),
        )
        thr_copy_dKVaccum_r2s = tiled_copy_dKVaccum_r2s.get_slice(tidx)
        tdKsdKaccum = thr_copy_dKVaccum_r2s.partition_D(sdKaccum)
        tdVsdVaccum = thr_copy_dKVaccum_r2s.partition_D(sdVaccum)

        # dK: drain prior bulk-async reads of the aliased smem, stage, reduce-add.
        cute.arch.cp_async_bulk_wait_group(0, read=True)
        epi_barrier.arrive_and_wait()
        tdKrdKaccum_flat = cute.make_tensor(acc_dK.iterator, tdKsdKaccum.shape)
        cute.autovec_copy(tdKrdKaccum_flat, tdKsdKaccum)
        cute.arch.fence_view_async_shared()
        epi_barrier.arrive_and_wait()
        if warp_idx == 4:
            with cute.arch.elect_one():
                for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
                    copy_utils.cpasync_reduce_bulk_add_f32(
                        sdKaccum[None, wg_idx].iterator,
                        gdKaccum[None, wg_idx].iterator,
                        self.tma_copy_bytes["dKacc"] // self.num_mma_warp_groups,
                    )
            cute.arch.cp_async_bulk_commit_group()

        # dV: same sequence through the same (aliased) smem.
        cute.arch.cp_async_bulk_wait_group(0, read=True)
        epi_barrier.arrive_and_wait()
        tdVrdVaccum_flat = cute.make_tensor(acc_dV.iterator, tdVsdVaccum.shape)
        cute.autovec_copy(tdVrdVaccum_flat, tdVsdVaccum)
        cute.arch.fence_view_async_shared()
        epi_barrier.arrive_and_wait()
        if warp_idx == 4:
            with cute.arch.elect_one():
                for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
                    copy_utils.cpasync_reduce_bulk_add_f32(
                        sdVaccum[None, wg_idx].iterator,
                        gdVaccum[None, wg_idx].iterator,
                        self.tma_copy_bytes["dVacc"] // self.num_mma_warp_groups,
                    )
            cute.arch.cp_async_bulk_commit_group()
|
| 1508 |
+
|
| 1509 |
+
@cute.jit
def dQaccum_store(
    self,
    mdQaccum: cute.Tensor,
    sdQaccum: cute.Tensor,
    block_info: BlockInfo,
    TileSchedulerCls: cutlass.Constexpr[Callable],
    SeqlenInfoCls: cutlass.Constexpr[Callable],
    blocksparse_tensors: Optional[BlockSparseTensors] = None,
):
    """dQ-store warp loop: drain per-m-block dQ partials from smem to gmem.

    Mirrors the scheduler loop of ``mma``: for every KV tile and every m
    block processed there, this warp waits on the dQFull barrier the MMA
    warp groups signal in ``mma_one_m_block``, bulk-reduce-adds the smem
    slice into mdQaccum, and signals dQEmpty so the slot can be reused.
    The per-tile process/skip decision must match ``mma`` exactly, or the
    barrier handshakes would deadlock.
    """
    tile_scheduler = TileSchedulerCls()
    work_tile = tile_scheduler.initial_work_tile_info()
    while work_tile.is_valid_tile:
        n_block, head_idx, batch_idx, _ = work_tile.tile_idx
        seqlen = SeqlenInfoCls(batch_idx)
        mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
        gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,))
        # (M * K / WG, WG, _)
        gdQaccum = cute.flat_divide(
            gdQaccum_, (self.tile_m * self.tile_hdim // self.num_mma_warp_groups,)
        )
        m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
        # Same process/skip predicate as in mma() — keep the two in sync.
        if const_expr(not self.use_block_sparsity):
            process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
            loop_count = m_block_max - m_block_min
        else:
            total_block_cnt = get_total_q_block_count_bwd(
                blocksparse_tensors,
                batch_idx,
                head_idx,
                n_block,
                subtile_factor=self.subtile_factor,
                m_block_max=m_block_max,
            )
            process_tile = total_block_cnt > Int32(0)

        if process_tile:
            if const_expr(not self.use_block_sparsity):
                for iter_idx in cutlass.range(loop_count, unroll=1):
                    m_block = m_block_min + iter_idx
                    m_block_safe = m_block

                    # Signal "slot empty" to each MMA warp group once its
                    # previous bulk store has drained (staggered wait counts).
                    for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
                        cute.arch.cp_async_bulk_wait_group(
                            self.num_mma_warp_groups - 1 - warp_group_idx, read=True
                        )
                        cute.arch.barrier_arrive(
                            barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
                            number_of_threads=self.num_threads_per_warp_group
                            + cute.arch.WARP_SIZE,
                        )

                    # Wait for each warp group's dQ slice, then reduce-add it
                    # into global memory (single thread issues the bulk op).
                    for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
                        cute.arch.barrier(
                            barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
                            number_of_threads=self.num_threads_per_warp_group
                            + cute.arch.WARP_SIZE,
                        )
                        with cute.arch.elect_one():
                            copy_utils.cpasync_reduce_bulk_add_f32(
                                sdQaccum[None, warp_group_idx].iterator,
                                gdQaccum[None, warp_group_idx, m_block_safe].iterator,
                                self.tma_copy_bytes["dQ"],
                            )
                        cute.arch.cp_async_bulk_commit_group()
            else:
                # Block-sparse path: helper walks the same sparse m-block
                # schedule as consume_block_sparse_mma_bwd_sm90.
                dQaccum_store_block_sparse_bwd_sm90(
                    blocksparse_tensors,
                    batch_idx,
                    head_idx,
                    n_block,
                    sdQaccum,
                    gdQaccum,
                    subtile_factor=self.subtile_factor,
                    m_block_max=m_block_max,
                    num_mma_warp_groups=self.num_mma_warp_groups,
                    num_threads_per_warp_group=self.num_threads_per_warp_group,
                    tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"],
                )
        tile_scheduler.advance_to_next_work()
        work_tile = tile_scheduler.get_current_work()

    # Drain all outstanding bulk-async operations before exiting.
    cute.arch.cp_async_bulk_wait_group(0, read=True)
|
build/torch-cuda/flash_fwd.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch-cuda/flash_fwd_combine.py
ADDED
|
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
+
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_combine_kernel.h
|
| 3 |
+
# from Cutlass C++ to Cute-DSL.
|
| 4 |
+
import math
|
| 5 |
+
from typing import Type, Optional
|
| 6 |
+
from functools import partial
|
| 7 |
+
|
| 8 |
+
import cuda.bindings.driver as cuda
|
| 9 |
+
|
| 10 |
+
import cutlass
|
| 11 |
+
import cutlass.cute as cute
|
| 12 |
+
from cutlass.cute.nvgpu import cpasync
|
| 13 |
+
from cutlass import Float32, Int32, const_expr
|
| 14 |
+
|
| 15 |
+
from . import utils
|
| 16 |
+
from .cute_dsl_utils import assume_tensor_aligned
|
| 17 |
+
from .seqlen_info import SeqlenInfo
|
| 18 |
+
from cutlass.cute import FastDivmodDivisor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class FlashAttentionForwardCombine:
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
dtype: Type[cutlass.Numeric],
|
| 25 |
+
dtype_partial: Type[cutlass.Numeric],
|
| 26 |
+
head_dim: int,
|
| 27 |
+
m_block_size: int = 8,
|
| 28 |
+
k_block_size: int = 64,
|
| 29 |
+
log_max_splits: int = 4,
|
| 30 |
+
num_threads: int = 256,
|
| 31 |
+
stages: int = 4,
|
| 32 |
+
):
|
| 33 |
+
"""
|
| 34 |
+
Forward combine kernel for split attention computation.
|
| 35 |
+
|
| 36 |
+
:param dtype: output data type
|
| 37 |
+
:param dtype_partial: partial accumulation data type
|
| 38 |
+
:param head_dim: head dimension
|
| 39 |
+
:param m_block_size: m block size
|
| 40 |
+
:param k_block_size: k block size
|
| 41 |
+
:param log_max_splits: log2 of maximum splits
|
| 42 |
+
:param num_threads: number of threads
|
| 43 |
+
:param varlen: whether using variable length sequences
|
| 44 |
+
:param stages: number of pipeline stages
|
| 45 |
+
"""
|
| 46 |
+
self.dtype = dtype
|
| 47 |
+
self.dtype_partial = dtype_partial
|
| 48 |
+
self.head_dim = head_dim
|
| 49 |
+
self.m_block_size = m_block_size
|
| 50 |
+
self.k_block_size = k_block_size
|
| 51 |
+
self.max_splits = 1 << log_max_splits
|
| 52 |
+
self.num_threads = num_threads
|
| 53 |
+
self.is_even_k = head_dim % k_block_size == 0
|
| 54 |
+
self.stages = stages
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
def can_implement(
|
| 58 |
+
dtype,
|
| 59 |
+
dtype_partial,
|
| 60 |
+
head_dim,
|
| 61 |
+
m_block_size,
|
| 62 |
+
k_block_size,
|
| 63 |
+
log_max_splits,
|
| 64 |
+
num_threads,
|
| 65 |
+
) -> bool:
|
| 66 |
+
"""Check if the kernel can be implemented with the given parameters."""
|
| 67 |
+
if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]:
|
| 68 |
+
return False
|
| 69 |
+
if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]:
|
| 70 |
+
return False
|
| 71 |
+
if head_dim % 8 != 0:
|
| 72 |
+
return False
|
| 73 |
+
if num_threads % 32 != 0:
|
| 74 |
+
return False
|
| 75 |
+
if m_block_size % 8 != 0:
|
| 76 |
+
return False
|
| 77 |
+
max_splits = 1 << log_max_splits
|
| 78 |
+
if max_splits > 256:
|
| 79 |
+
return False
|
| 80 |
+
if (m_block_size * max_splits) % num_threads != 0:
|
| 81 |
+
return False
|
| 82 |
+
return True
|
| 83 |
+
|
    def _setup_attributes(self):
        """Build the tiled-copy atoms and shared-memory layouts used by the kernel.

        Populates: gmem_tiled_copy_O_partial, gmem_tiled_copy_O,
        gmem_tiled_copy_LSE, s2r_tiled_copy_LSE, smem_threads_per_col_lse,
        smem_layout_lse, smem_layout_o.
        """
        # GMEM copy setup for O partial: 128-bit vectorized cp.async loads.
        universal_copy_bits = 128
        async_copy_elems = universal_copy_bits // self.dtype_partial.width
        # k tile must decompose into whole vectorized loads.
        assert self.k_block_size % async_copy_elems == 0

        # Widest power-of-two row width (<= 128) that divides the k tile.
        k_block_gmem = (
            128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32)
        )
        gmem_threads_per_row = k_block_gmem // async_copy_elems
        assert self.num_threads % gmem_threads_per_row == 0

        # Async copy atom for O partial load (GLOBAL cache mode: streamed once).
        atom_async_copy_partial = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            self.dtype_partial,
            num_bits_per_copy=universal_copy_bits,
        )
        # Thread layout: rows vary fastest across threads within a row group.
        tOpartial_layout = cute.make_ordered_layout(
            (self.num_threads // gmem_threads_per_row, gmem_threads_per_row),
            order=(1, 0),
        )
        # async_copy_elems values per load (e.g. 4 for fp32 partials).
        vOpartial_layout = cute.make_layout((1, async_copy_elems))
        self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv(
            atom_async_copy_partial, tOpartial_layout, vOpartial_layout
        )

        # GMEM copy setup for final O (plain universal copy for the store path;
        # same thread/value tiling as the partial loads).
        atom_universal_copy = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            self.dtype,
            num_bits_per_copy=async_copy_elems * self.dtype.width,
        )
        self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
            atom_universal_copy,
            tOpartial_layout,
            vOpartial_layout,  # async_copy_elems vals per store
        )

        # LSE copy setup with async copy, one Float32 element per copy
        # (LSE has no alignment guarantee beyond a single element).
        lse_copy_bits = Float32.width  # width is in bits
        # Widest power-of-two (8..128) that divides the m tile.
        m_block_smem = (
            128
            if self.m_block_size % 128 == 0
            else (
                64
                if self.m_block_size % 64 == 0
                else (
                    32
                    if self.m_block_size % 32 == 0
                    else (16 if self.m_block_size % 16 == 0 else 8)
                )
            )
        )
        gmem_threads_per_row_lse = m_block_smem
        assert self.num_threads % gmem_threads_per_row_lse == 0

        # Async copy atom for LSE load (ALWAYS cache mode: values are re-read).
        atom_async_copy_lse = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),
            Float32,
            num_bits_per_copy=lse_copy_bits,
        )
        tLSE_layout = cute.make_ordered_layout(
            (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse),
            order=(1, 0),
        )
        vLSE_layout = cute.make_layout(1)
        self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(
            atom_async_copy_lse, tLSE_layout, vLSE_layout
        )

        # ///////////////////////////////////////////////////////////////////////////////
        # Shared memory
        # ///////////////////////////////////////////////////////////////////////////////

        # Shared memory to register copy for LSE: threads covering one column
        # of the (splits, m) tile cooperate in the later warp reductions.
        self.smem_threads_per_col_lse = self.num_threads // m_block_smem
        assert 32 % self.smem_threads_per_col_lse == 0  # Must divide warp size

        s2r_layout_atom_lse = cute.make_ordered_layout(
            (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse),
            order=(0, 1),
        )
        self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv(
            cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32),
            s2r_layout_atom_lse,
            cute.make_layout(1),
        )

        # LSE shared memory layout with swizzling to avoid bank conflicts.
        # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts.
        if const_expr(m_block_smem == 8):
            smem_lse_swizzle = cute.make_swizzle(5, 0, 5)
        elif const_expr(m_block_smem == 16):
            smem_lse_swizzle = cute.make_swizzle(4, 0, 4)
        else:
            smem_lse_swizzle = cute.make_swizzle(3, 2, 3)
        smem_layout_atom_lse = cute.make_composed_layout(
            smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
        )
        # Full (max_splits, m_block_size) LSE staging buffer.
        self.smem_layout_lse = cute.tile_to_shape(
            smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1)
        )

        # O partial shared memory layout: k-major (m, k) tile replicated across
        # the cp.async pipeline stages.
        self.smem_layout_o = cute.make_ordered_layout(
            (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2)
        )
    @cute.jit
    def __call__(
        self,
        mO_partial: cute.Tensor,
        mLSE_partial: cute.Tensor,
        mO: cute.Tensor,
        mLSE: Optional[cute.Tensor] = None,
        cu_seqlens: Optional[cute.Tensor] = None,
        seqused: Optional[cute.Tensor] = None,
        num_splits_dynamic_ptr: Optional[cute.Tensor] = None,
        semaphore_to_reset: Optional[cute.Tensor] = None,
        stream: cuda.CUstream = None,
    ):
        """Host-side entry: validate tensors, permute layouts to the kernel's
        internal order, size shared memory and the launch grid, and launch
        `self.kernel`.

        Tensors arrive in user order (batched: leading batch/num_splits dims;
        varlen: total_q-leading when cu_seqlens is given). All validation runs
        at trace time via const_expr.
        """
        # Type checking (trace-time; raises before any launch).
        if const_expr(not (mO_partial.element_type == self.dtype_partial)):
            raise TypeError("O partial tensor must match dtype_partial")
        if const_expr(not (mO.element_type == self.dtype)):
            raise TypeError("O tensor must match dtype")
        if const_expr(mLSE_partial.element_type not in [Float32]):
            raise TypeError("LSE partial tensor must be Float32")
        if const_expr(mLSE is not None and mLSE.element_type not in [Float32]):
            raise TypeError("LSE tensor must be Float32")

        # Shape validation - input tensors are in user format, need to be converted to kernel format
        if const_expr(len(mO_partial.shape) not in [4, 5]):
            raise ValueError(
                "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)"
            )
        if const_expr(len(mLSE_partial.shape) not in [3, 4]):
            raise ValueError(
                "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
            )
        if const_expr(len(mO.shape) not in [3, 4]):
            raise ValueError(
                "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)"
            )
        if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]):
            raise ValueError(
                "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
            )

        mO_partial, mO = [assume_tensor_aligned(t) for t in (mO_partial, mO)]
        # Reorder modes so the kernel always sees seqlen-major tensors:
        # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
        # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)
        O_partial_layout_transpose = (
            [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
        )
        mO_partial = cute.make_tensor(
            mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)
        )
        # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h)
        O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1]
        mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
        # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b)
        # or (num_splits, total_q, h) -> (total_q, num_splits, h)
        LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2]
        mLSE_partial = cute.make_tensor(
            mLSE_partial.iterator,
            cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose),
        )
        # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h)
        LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1]
        mLSE = (
            cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
            if mLSE is not None
            else None
        )

        # Determine if we have variable length sequences
        varlen = const_expr(cu_seqlens is not None or seqused is not None)

        # Derive copy atoms and smem layouts for the current configuration.
        self._setup_attributes()

        # Per-CTA shared memory: LSE staging, per-row max-valid-split, and the
        # multi-stage O_partial pipeline buffer. 128-byte aligned members.
        @cute.struct
        class SharedStorage:
            sLSE: cute.struct.Align[
                cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
            ]
            sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128]
            sO: cute.struct.Align[
                cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
            ]

        smem_size = SharedStorage.size_in_bytes()

        # Grid dimensions: (ceil_div(seqlen * nheads, m_block), ceil_div(head_dim, k_block), batch)
        seqlen = mO_partial.shape[0]
        num_head = mO_partial.shape[3]
        batch_size = (
            mO_partial.shape[4]
            if const_expr(cu_seqlens is None)
            else Int32(cu_seqlens.shape[0] - 1)
        )

        # Create FastDivmodDivisor objects for efficient division on device.
        seqlen_divmod = FastDivmodDivisor(seqlen)
        head_divmod = FastDivmodDivisor(num_head)

        grid_dim = (
            cute.ceil_div(seqlen * num_head, self.m_block_size),
            cute.ceil_div(self.head_dim, self.k_block_size),
            batch_size,
        )

        self.kernel(
            mO_partial,
            mLSE_partial,
            mO,
            mLSE,
            cu_seqlens,
            seqused,
            num_splits_dynamic_ptr,
            semaphore_to_reset,
            SharedStorage,
            self.smem_layout_lse,
            self.smem_layout_o,
            self.gmem_tiled_copy_O_partial,
            self.gmem_tiled_copy_O,
            self.gmem_tiled_copy_LSE,
            self.s2r_tiled_copy_LSE,
            seqlen_divmod,
            head_divmod,
            varlen,
        ).launch(
            grid=grid_dim,
            block=[self.num_threads, 1, 1],
            smem=smem_size,
            stream=stream,
        )
    @cute.kernel
    def kernel(
        self,
        mO_partial: cute.Tensor,
        mLSE_partial: cute.Tensor,
        mO: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        cu_seqlens: Optional[cute.Tensor],
        seqused: Optional[cute.Tensor],
        num_splits_dynamic_ptr: Optional[cute.Tensor],
        semaphore_to_reset: Optional[cute.Tensor],
        SharedStorage: cutlass.Constexpr,
        smem_layout_lse: cute.Layout | cute.ComposedLayout,
        smem_layout_o: cute.Layout,
        gmem_tiled_copy_O_partial: cute.TiledCopy,
        gmem_tiled_copy_O: cute.TiledCopy,
        gmem_tiled_copy_LSE: cute.TiledCopy,
        s2r_tiled_copy_LSE: cute.TiledCopy,
        seqlen_divmod: FastDivmodDivisor,
        head_divmod: FastDivmodDivisor,
        varlen: cutlass.Constexpr[bool],
    ):
        """Device kernel: combine per-split partial outputs into the final O/LSE.

        Each CTA owns an (m_block, k_block) tile of one batch element. It
        (1) stages all per-split LSE values in smem, (2) pipelines O_partial
        tiles through smem via cp.async, (3) computes the global LSE and
        per-split softmax rescaling factors, (4) accumulates the rescaled
        partials in fp32 registers, and (5) writes the converted result to gmem.
        Note: head_divmod is currently accepted but unused in the body.
        """
        # Thread and block indices
        tidx, _, _ = cute.arch.thread_idx()
        m_block, k_block, batch_idx = cute.arch.block_idx()

        # ///////////////////////////////////////////////////////////////////////////////
        # Get shared memory buffer
        # ///////////////////////////////////////////////////////////////////////////////
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(SharedStorage)
        sLSE = storage.sLSE.get_tensor(smem_layout_lse)
        sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.m_block_size,))
        sO = storage.sO.get_tensor(smem_layout_o)

        # Handle semaphore reset: the last CTA in the grid clears the
        # scheduler semaphore for the next kernel invocation.
        if const_expr(semaphore_to_reset is not None):
            if (
                tidx == 0
                and m_block == cute.arch.grid_dim()[0] - 1
                and k_block == cute.arch.grid_dim()[1] - 1
                and batch_idx == cute.arch.grid_dim()[2] - 1
            ):
                semaphore_to_reset[0] = 0

        # Get number of splits: dynamic per-batch count if provided, else the
        # static split dimension of the LSE_partial layout.
        num_splits = (
            num_splits_dynamic_ptr[batch_idx]
            if const_expr(num_splits_dynamic_ptr is not None)
            else mLSE_partial.shape[1]
        )
        # Handle variable length sequences using SeqlenInfo
        seqlen_info = SeqlenInfo.create(
            batch_idx=batch_idx,
            seqlen_static=mO_partial.shape[0],
            cu_seqlens=cu_seqlens,
            seqused=seqused,
        )
        seqlen, offset = seqlen_info.seqlen, seqlen_info.offset

        # Extract number of heads (head index will be determined dynamically);
        # rows are linearized over (head, seqlen) up to max_idx.
        num_head = mO_partial.shape[3]
        max_idx = seqlen * num_head

        # Early exit: skip work when a dynamic split count says there is
        # nothing to combine, or (varlen) when this m tile is out of range.
        if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
            const_expr(not varlen) or m_block * self.m_block_size < max_idx
        ):
            # ===============================
            # Step 1: Load LSE_partial from gmem to shared memory
            # ===============================

            if const_expr(cu_seqlens is None):
                mLSE_partial_cur = mLSE_partial[None, None, None, batch_idx]
            else:
                mLSE_partial_cur = cute.domain_offset((offset, 0, 0), mLSE_partial)
            mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,))

            gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)
            tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)

            # Create identity tensor for coordinate tracking
            cLSE = cute.make_identity_tensor((self.max_splits, self.m_block_size))
            tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)

            # Load LSE partial values; out-of-range splits are filled with -inf
            # so they contribute exp(...) == 0 to the combine.
            for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
                mi = tLSEcLSE[0, 0, m][1]  # Get m coordinate
                idx = m_block * self.m_block_size + mi
                if idx < max_idx:
                    # Calculate actual sequence position and head using FastDivmodDivisor
                    if const_expr(not varlen):
                        head_idx, m_idx = divmod(idx, seqlen_divmod)
                    else:
                        head_idx = idx // seqlen
                        m_idx = idx - head_idx * seqlen
                    mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx]
                    for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
                        si = tLSEcLSE[0, s, 0][0]  # Get split coordinate
                        if si < num_splits:
                            cute.copy(
                                gmem_thr_copy_LSE,
                                mLSE_partial_cur_copy[None, si],
                                tLSEsLSE[None, s, m],
                            )
                        else:
                            tLSEsLSE[None, s, m].fill(-Float32.inf)
                # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem
            cute.arch.cp_async_commit_group()

            # ===============================
            # Step 2: Load O_partial for pipeline stages
            # ===============================

            gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)
            cO = cute.make_identity_tensor((self.m_block_size, self.k_block_size))
            tOcO = gmem_thr_copy_O_partial.partition_D(cO)
            tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)
            if const_expr(cu_seqlens is None):
                mO_partial_cur = mO_partial[None, None, None, None, batch_idx]
            else:
                mO_partial_cur = cute.domain_offset((offset, 0, 0, 0), mO_partial)

            # Precompute these values to avoid recomputing them in the loop:
            # per-row seq index, head index, and a raw gmem base pointer.
            num_rows = const_expr(cute.size(tOcO, mode=[1]))
            tOmidx = cute.make_fragment(num_rows, cutlass.Int32)
            tOhidx = cute.make_fragment(num_rows, cutlass.Int32)
            tOrOptr = cute.make_fragment(num_rows, cutlass.Int64)
            for m in cutlass.range(num_rows, unroll_full=True):
                mi = tOcO[0, m, 0][0]  # m coordinate
                idx = m_block * self.m_block_size + mi
                if const_expr(not varlen):
                    tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod)
                else:
                    tOhidx[m] = idx // seqlen
                    tOmidx[m] = idx - tOhidx[m] * seqlen
                tOrOptr[m] = utils.elem_pointer(
                    mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])
                ).toint()
                if idx >= max_idx:
                    tOhidx[m] = -1  # sentinel: row out of range, skip load/store

            # Per-column predicate for a ragged final k block.
            tOpO = cute.make_fragment(cute.size(tOcO, [2]), cutlass.Boolean)
            if const_expr(not self.is_even_k):
                for k in cutlass.range(cute.size(tOpO), unroll_full=True):
                    tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size
            # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO)

            load_O_partial = partial(
                self.load_O_partial,
                gmem_tiled_copy_O_partial,
                tOrOptr,
                tOsO_partial,
                tOhidx,
                tOpO,
                tOcO,
                mO_partial_cur.layout,
            )

            # Load first few stages of O_partial to prime the cp.async pipeline.
            for stage in cutlass.range(self.stages - 1, unroll_full=True):
                if stage < num_splits:
                    load_O_partial(stage, stage)
                cute.arch.cp_async_commit_group()

            # ===============================
            # Step 3: Load and transpose LSE from smem to registers
            # ===============================

            # Wait for LSE and initial O partial stages to complete
            cute.arch.cp_async_wait_group(self.stages - 1)
            cute.arch.sync_threads()
            # if cute.arch.thread_idx()[0] == 0:
            #     # cute.print_tensor(sLSE)
            #     for i in range(64):
            #         cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0])
            # cute.arch.sync_threads()

            s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)
            ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)
            ts2rrLSE = cute.make_fragment_like(ts2rsLSE)
            cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)

            # ===============================
            # Step 4: Compute final LSE along split dimension
            # ===============================

            lse_sum = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Float32)
            ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)
            # We compute the max valid split for each row to short-circuit the computation later
            max_valid_split = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Int32)
            assert cute.size(ts2rrLSE, mode=[0]) == 1
            # Compute max, scales, and final LSE for each row (numerically
            # stable log-sum-exp across the splits this thread group holds).
            for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
                # Find max LSE value across splits
                threads_per_col = const_expr(self.smem_threads_per_col_lse)
                lse_max = cute.arch.warp_reduction_max(
                    ts2rrLSE[None, None, m]
                    .load()
                    .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
                    threads_in_group=threads_per_col,
                )
                # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max)
                # Find max valid split index
                max_valid_idx = -1
                for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
                    if ts2rrLSE[0, s, m] != -Float32.inf:
                        max_valid_idx = ts2rcLSE[0, s, 0][0]  # Get split coordinate
                # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx)
                max_valid_split[m] = cute.arch.warp_reduction_max(
                    max_valid_idx, threads_in_group=threads_per_col
                )
                # Compute exp scales and sum
                lse_max_cur = (
                    0.0 if lse_max == -Float32.inf else lse_max
                )  # In case all local LSEs are -inf
                LOG2_E = math.log2(math.e)
                lse_sum_cur = 0.0
                for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
                    scale = cute.math.exp2(
                        ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E), fastmath=True
                    )
                    lse_sum_cur += scale
                    ts2rrLSE[0, s, m] = scale  # Store scale for later use
                lse_sum_cur = cute.arch.warp_reduction_sum(
                    lse_sum_cur, threads_in_group=threads_per_col
                )
                lse_sum[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max
                # Normalize scales; NaN check (x != x) guards a 0/0 or inf case.
                inv_sum = (
                    0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
                )
                ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
            # Store the scales exp(lse - lse_logsum) back to smem
            cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)

            # Store max valid split to smem
            for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
                if ts2rcLSE[0, 0, m][0] == 0:  # Only thread responsible for s=0 writes
                    mi = ts2rcLSE[0, 0, m][1]
                    if mi < self.m_block_size:
                        sMaxValidSplit[mi] = max_valid_split[m]

            # ===============================
            # Step 5: Store final LSE to gmem
            # ===============================

            if const_expr(mLSE is not None):
                if const_expr(cu_seqlens is None):
                    mLSE_cur = mLSE[None, None, batch_idx]
                else:
                    mLSE_cur = cute.domain_offset((offset, 0), mLSE)
                if k_block == 0:  # Only first k_block writes LSE when mLSE is provided
                    for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
                        if ts2rcLSE[0, 0, m][0] == 0:  # Only thread responsible for s=0 writes
                            mi = ts2rcLSE[0, 0, m][1]
                            idx = m_block * self.m_block_size + mi
                            if idx < max_idx:
                                if const_expr(not varlen):
                                    head_idx, m_idx = divmod(idx, seqlen_divmod)
                                else:
                                    head_idx = idx // seqlen
                                    m_idx = idx - head_idx * seqlen
                                mLSE_cur[m_idx, head_idx] = lse_sum[m]

            # ===============================
            # Step 6: Read O_partial and accumulate final O
            # ===============================

            # Barrier: scales and sMaxValidSplit must be visible to all threads.
            cute.arch.sync_threads()

            # Get max valid split for this thread (max over the rows it owns).
            thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]
            for m in cutlass.range(1, cute.size(tOcO, mode=[1])):
                thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]])

            tOrO_partial = cute.make_fragment_like(tOsO_partial[None, None, None, 0])
            tOrO = cute.make_fragment_like(tOrO_partial, Float32)
            tOrO.fill(0.0)

            # Software pipeline: load stage runs `stages - 1` ahead of compute.
            stage_load = self.stages - 1
            stage_compute = 0

            # Main accumulation loop
            for s in cutlass.range(thr_max_valid_split + 1, unroll=4):
                # Get scales for this split
                scale = cute.make_fragment(num_rows, Float32)
                for m in cutlass.range(num_rows, unroll_full=True):
                    scale[m] = sLSE[s, tOcO[0, m, 0][0]]  # Get scale from smem

                # Load next stage if needed
                split_to_load = s + self.stages - 1
                if split_to_load <= thr_max_valid_split:
                    load_O_partial(split_to_load, stage_load)
                cute.arch.cp_async_commit_group()
                stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1

                # Wait for the current stage to be ready
                cute.arch.cp_async_wait_group(self.stages - 1)
                # We don't need __syncthreads() because each thread is just reading its own data from smem
                # Copy from smem to registers
                cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial)
                stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1

                # Accumulate scaled partial results (skip dead rows and
                # zero-weight splits).
                for m in cutlass.range(num_rows, unroll_full=True):
                    if tOhidx[m] >= 0 and scale[m] > 0.0:
                        tOrO[None, m, None].store(
                            tOrO[None, m, None].load()
                            + scale[m] * tOrO_partial[None, m, None].load().to(Float32)
                        )

            # ===============================
            # Step 7: Write final O to gmem
            # ===============================

            rO = cute.make_fragment_like(tOrO, self.dtype)
            rO.store(tOrO.load().to(self.dtype))
            if const_expr(cu_seqlens is None):
                mO_cur = mO[None, None, None, batch_idx]
            else:
                mO_cur = cute.domain_offset((offset, 0, 0), mO)
            mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur)
            elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1]))
            # mO_cur_copy = cute.tiled_divide(mO_cur, (1, elems_per_store,))
            gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
            # Write final results (predicated on valid row and, for ragged k,
            # on the in-bounds column predicate).
            for m in cutlass.range(num_rows, unroll_full=True):
                if tOhidx[m] >= 0:
                    mO_cur_copy = cute.tiled_divide(
                        mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)
                    )
                    for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
                        k_idx = tOcO[0, 0, k][1] // elems_per_store
                        if const_expr(self.is_even_k) or tOpO[k]:
                            cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx])
    @cute.jit
    def load_O_partial(
        self,
        gmem_tiled_copy_O_partial: cute.TiledCopy,
        tOrOptr: cute.Tensor,
        tOsO_partial: cute.Tensor,
        tOhidx: cute.Tensor,
        tOpO: cute.Tensor,
        tOcO: cute.Tensor,
        mO_cur_partial_layout: cute.Layout,
        split: Int32,
        stage: Int32,
    ) -> None:
        """Issue cp.async copies of one split's O_partial tile into one smem
        pipeline stage.

        :param gmem_tiled_copy_O_partial: tiled async copy used for the loads
        :param tOrOptr: per-row precomputed gmem base addresses (Int64)
        :param tOsO_partial: smem destination, staged along the last mode
        :param tOhidx: per-row head index; -1 marks an out-of-range row to skip
        :param tOpO: per-column in-bounds predicate for a ragged final k block
        :param tOcO: identity-tensor partition giving (m, k) tile coordinates
        :param mO_cur_partial_layout: layout of the current O_partial view,
            sliced per row to a (d, num_splits) layout
        :param split: which split to load
        :param stage: which smem pipeline stage to fill
        """
        elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1]))
        tOsO_partial_cur = tOsO_partial[None, None, None, stage]
        for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True):
            if tOhidx[m] >= 0:
                # Rebuild a gmem tensor from the row's precomputed raw address;
                # assumed_align=16 matches the 128-bit vectorized copy atom.
                o_gmem_ptr = cute.make_ptr(
                    tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16
                )
                mO_partial_cur = cute.make_tensor(
                    o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))
                )
                mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
                for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
                    k_idx = tOcO[0, 0, k][1] // elems_per_load
                    if const_expr(self.is_even_k) or tOpO[k]:
                        cute.copy(
                            gmem_tiled_copy_O_partial,
                            mO_partial_cur_copy[None, k_idx, split],
                            tOsO_partial_cur[None, m, k],
                        )
build/torch-cuda/flash_fwd_sm100.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch-cuda/interface.py
ADDED
|
@@ -0,0 +1,1855 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
+
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0.
|
| 3 |
+
|
| 4 |
+
# Supported features:
|
| 5 |
+
# - BF16 & FP16 dtype
|
| 6 |
+
# - noncausal & causal attention
|
| 7 |
+
# - MHA, GQA, MQA
|
| 8 |
+
# - hdim 64, 96, 128.
|
| 9 |
+
# - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape)
|
| 10 |
+
# - varlen
|
| 11 |
+
# - sliding window
|
| 12 |
+
# - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow)
|
| 13 |
+
|
| 14 |
+
# Features not supported yet:
|
| 15 |
+
# - split (i.e. FlashDecoding)
|
| 16 |
+
# - tuned block sizes
|
| 17 |
+
# - paged KV
|
| 18 |
+
# - append KV to existing KV cache
|
| 19 |
+
# - FP8
|
| 20 |
+
# - bwd pass optimized for Hopper/Blackwell
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import math
|
| 24 |
+
from functools import lru_cache
|
| 25 |
+
from typing import Optional, Tuple, Callable
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
import cuda.bindings.driver as cuda
|
| 31 |
+
|
| 32 |
+
import cutlass
|
| 33 |
+
import cutlass.cute as cute
|
| 34 |
+
from .cache_utils import get_jit_cache
|
| 35 |
+
from .testing import is_fake_mode
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
|
| 39 |
+
from . import cute_dsl_ptxas # noqa: F401
|
| 40 |
+
|
| 41 |
+
# Patch to dump ptx and then use system ptxas to compile to cubin
|
| 42 |
+
cute_dsl_ptxas.patch()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
from . import utils
|
| 46 |
+
from .cute_dsl_utils import (
|
| 47 |
+
to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims,
|
| 48 |
+
)
|
| 49 |
+
from .flash_fwd import FlashAttentionForwardSm90
|
| 50 |
+
from .flash_fwd_sm100 import FlashAttentionForwardSm100
|
| 51 |
+
from .flash_bwd_preprocess import FlashAttentionBackwardPreprocess
|
| 52 |
+
from .flash_bwd import FlashAttentionBackwardSm80
|
| 53 |
+
from .flash_bwd_sm90 import FlashAttentionBackwardSm90
|
| 54 |
+
from .flash_bwd_sm100 import FlashAttentionBackwardSm100
|
| 55 |
+
from .flash_bwd_postprocess import FlashAttentionBackwardPostprocess
|
| 56 |
+
from .flash_fwd_combine import FlashAttentionForwardCombine
|
| 57 |
+
|
| 58 |
+
from .block_sparsity import (
|
| 59 |
+
BlockSparseTensorsTorch,
|
| 60 |
+
to_cute_block_sparse_tensors,
|
| 61 |
+
normalize_block_sparse_config,
|
| 62 |
+
normalize_block_sparse_config_bwd,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
@lru_cache(maxsize=None)
def _get_device_arch():
    """Return the current CUDA device's compute capability as one integer.

    E.g. SM 9.0 -> 90.  Cached so the driver is only queried once per process.
    """
    capability = torch.cuda.get_device_capability()
    return capability[0] * 10 + capability[1]
|
| 70 |
+
|
| 71 |
+
def maybe_contiguous(x):
    """Return ``x`` with unit stride in the last dimension.

    ``None`` passes through unchanged; tensors that already have a unit
    last-dim stride are returned as-is (no copy is made).
    """
    if x is None:
        return None
    if x.stride(-1) == 1:
        return x
    return x.contiguous()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device):
|
| 76 |
+
assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}"
|
| 77 |
+
assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}"
|
| 78 |
+
assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}"
|
| 79 |
+
assert t.is_cuda, f"{name} must be on CUDA"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Mapping from torch dtypes to the corresponding cutlass scalar types used
# when building cute tensors for the kernels.  Only dtypes listed here are
# accepted as kernel inputs/outputs.
torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
    torch.float32: cutlass.Float32,
}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits):
    """Pick how many KV splits (FlashDecoding-style) to use for a forward call.

    Args:
        total_mblocks: total number of M-tile work items (batch * kv-heads * m-blocks).
        num_SMs: number of streaming multiprocessors on the device.
        num_n_blocks: number of KV blocks each work item iterates over.
        max_splits: upper bound on the number of splits.

    Returns:
        The number of splits to use; always >= 1 (1 means "no split").
    """
    # If num_n_blocks is too small, use 1 split. For example, we never split
    # for hdim = 128 and seqlen_k = 512.
    if num_n_blocks <= 4:
        return 1
    # Guard against an empty schedule (e.g. zero-length sequences): splitting
    # is meaningless and the division below would raise ZeroDivisionError.
    if total_mblocks <= 0:
        return 1

    # NOTE: We should revisit this heuristic after persistence is supported for split KV.
    # Sometimes, it's ideal to over-schedule splits for better efficiency.
    # Clamp to >= 1: when total_mblocks >= num_SMs the integer division yields 0,
    # and a split count of 0 is not a valid schedule.
    return max(1, min(num_SMs // total_mblocks, max_splits, num_n_blocks))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _flash_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    softcap: Optional[float] = None,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    learnable_sink: Optional[torch.Tensor] = None,
    # m_block_size: int = 128,
    # n_block_size: int = 64,
    # num_threads: int = 128,
    m_block_size: int = 128,
    n_block_size: int = 128,
    num_threads: int = 384,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    _arch: Optional[int] = None,
    score_mod: Optional[Callable] = None,
    mask_mod: Optional[Callable] = None,
    block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
    return_lse: bool = False,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    aux_tensors: Optional[list[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward pass for FlashAttention.

    Args:
        ...
        score_mod: A callable that takes the attention scores and applies a modification.
        mask_mod: A callable that takes token position information and selectively masks
            out score entries.
        block_sparse_tensors: A tuple of tensors used for block sparsity.
        return_lse: Whether to return the log softmax of the attention scores. If set to
            True the LSE will always be calculated.
            Note: the returned LSE currently does not support taking gradient.
        out: Optional pre-allocated output tensor. If None, will be allocated internally.
        lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.
        aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel.

    Returns:
        ``(out, lse)`` — the attention output, and the log-sum-exp tensor
        (``None`` when it was neither requested nor needed for backward).
    """
    # --- Shape / dtype validation --------------------------------------------
    q, k, v = [maybe_contiguous(t) for t in (q, k, v)]
    num_head, head_dim = q.shape[-2:]
    if cu_seqlens_q is None:
        batch_size, seqlen_q = q.shape[:2]
        total_q = batch_size * seqlen_q
    else:
        batch_size = cu_seqlens_q.shape[0] - 1
        seqlen_q = None
        total_q = q.shape[0]
    if page_table is not None:
        assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k"
        assert page_table.dtype == torch.int32, "page_table must be int32"
        assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension"
        max_num_pages_per_seq = page_table.shape[1]
        assert page_table.shape == (batch_size, max_num_pages_per_seq)
        num_pages, page_size = k.shape[:2]
        seqlen_k = num_pages * page_size
    else:
        num_pages, page_size = None, None
        seqlen_k = k.shape[-3]
    num_head_kv = k.shape[-2]
    head_dim_v = v.shape[-1]
    if cu_seqlens_k is None:
        if page_table is None:
            assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
            assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
        else:
            assert k.shape == (num_pages, page_size, num_head_kv, head_dim)
            assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v)
    else:
        assert k.shape == (seqlen_k, num_head_kv, head_dim)
        assert v.shape == (seqlen_k, num_head_kv, head_dim_v)
        assert cu_seqlens_k.shape == (batch_size + 1,), (
            "cu_seqlens_k must have shape (batch_size + 1,)"
        )

    if cu_seqlens_q is not None:
        assert cu_seqlens_q.shape == (batch_size + 1,), (
            "cu_seqlens_q must have shape (batch_size + 1,)"
        )
    assert seqused_q is None or seqused_q.shape == (batch_size,), (
        "seqused_q must have shape (batch_size,)"
    )
    assert seqused_k is None or seqused_k.shape == (batch_size,), (
        "seqused_k must have shape (batch_size,)"
    )
    assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16"
    assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype"
    for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]:
        if t is not None:
            assert t.dtype == torch.int32, (
                "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32"
            )
            assert t.stride(0) == 1, (
                "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous"
            )
    if learnable_sink is not None:
        assert learnable_sink.shape == (num_head,)
        assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"

    assert all(
        t is None or t.is_cuda
        for t in (
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            page_table,
            learnable_sink,
        )
    ), "inputs must be on CUDA device"
    assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
    assert head_dim <= 256, "head_dim must be less than or equal to 256"
    alignment = 16 // q.element_size()
    assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
    assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)
    if softcap == 0.0:
        softcap = None
    qhead_per_kvhead = num_head // num_head_kv
    if pack_gqa is None:
        pack_gqa = qhead_per_kvhead > 1

    # --- Output / LSE allocation ---------------------------------------------
    out_torch_dtype = q.dtype
    device = q.device
    q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,)
    lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q)
    requires_grad = q.requires_grad or k.requires_grad or v.requires_grad

    if out is None:
        out = torch.empty(
            *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device
        )
    else:
        _validate_tensor(out, "out", (*q_batch_seqlen_shape, num_head, head_dim_v), out_torch_dtype, device)

    if lse is None:
        # LSE is only materialized when it is needed for backward or requested.
        lse = (
            torch.empty(lse_shape, dtype=torch.float32, device=device)
            if requires_grad or return_lse
            else None
        )
    else:  # fixed: was a redundant `elif lse is not None:`
        _validate_tensor(lse, "lse", lse_shape, torch.float32, device)

    dtype = torch2cute_dtype_map[q.dtype]
    arch = _get_device_arch() if _arch is None else _arch

    assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"

    use_block_sparsity = block_sparse_tensors is not None

    # --- Normalize causal / local window flags -------------------------------
    if mask_mod is None:
        if causal:
            window_size_right = 0
        if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0:
            window_size_left = None
            window_size_right = None
        local = window_size_left is not None or window_size_right is not None
        if window_size_left is not None or window_size_right is not None:
            # A pure left-window with right == 0 is just causal attention.
            if window_size_left is None and window_size_right == 0:
                causal, local = True, False
                window_size_right = None
            else:
                causal, local = False, True
    else:
        # mask_mod takes over all masking; causal/local paths are disabled.
        causal, local = False, False

    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    if arch // 10 == 9:  # TODO: tune block size according to hdim.
        if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity:
            n_block_size = 192

    if arch // 10 in [10, 11]:
        if (
            pack_gqa
            and (128 % qhead_per_kvhead != 0)
        ):
            pack_gqa = False
        # TODO: fix GQA + SplitKV + non-varlen
        if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
            pack_gqa = False

    # --- SplitKV heuristic ----------------------------------------------------
    if max_seqlen_q is None:
        max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q
    if max_seqlen_k is None:
        max_seqlen_k = seqlen_k
    seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead
    if arch // 10 == 10:
        q_stage = 2 if seqlen_q_packgqa > m_block_size else 1
    else:
        q_stage = 1

    if num_splits < 1:
        m_block_size_effective = q_stage * m_block_size
        seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size))
        num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size
        num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective
        total_mblocks = batch_size * num_head_kv * num_m_blocks
        num_splits = num_splits_heuristic(
            total_mblocks,
            torch.cuda.get_device_properties(device).multi_processor_count,
            num_n_blocks,
            128,
        )

    is_split_kv = num_splits > 1
    if is_split_kv:
        # Per-split partial results, combined at the end by _flash_attn_fwd_combine.
        out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device)
        lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device)

    # hash score and mask mods for compile cache
    score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False

    if softcap is not None:
        assert score_mod is None, "softcap and score_mod cannot be used together"
        score_mod = utils.create_softcap_scoremod(softcap)

    is_varlen = (
        cu_seqlens_q is not None
        or cu_seqlens_k is not None
        or seqused_q is not None
        or seqused_k is not None
    )

    if mask_mod is not None:
        if is_varlen:
            raise NotImplementedError(
                "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR."
            )

    if use_block_sparsity:
        if is_varlen:
            raise NotImplementedError(
                "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR."
            )
        # NB: pack_gqa requires block sparse head dim == 1 (broadcasted)
        if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1:
            pack_gqa = False
        if is_split_kv:
            raise NotImplementedError(
                "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split."
            )

    # See get_broadcast_dims for why this is needed in compile key
    block_sparse_broadcast_pattern = None
    normalized_block_sparse_tensors = None
    q_subtile_factor = None
    if block_sparse_tensors is not None:
        if seqlen_q is None:
            raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).")
        (
            normalized_block_sparse_tensors,
            block_sparse_broadcast_pattern,
            q_subtile_factor,
        ) = normalize_block_sparse_config(
            block_sparse_tensors,
            batch_size=batch_size,
            num_head=num_head,
            seqlen_q=seqlen_q,
            seqlen_k=seqlen_k,
            block_size=(m_block_size, n_block_size),
            q_stage=q_stage,
        )
    if aux_tensors is not None:
        aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors)
    else:
        aux_tensor_metadata = None

    # --- Compile-cache key: everything that affects kernel specialization ----
    compile_key = (
        dtype,
        head_dim,
        head_dim_v,
        qhead_per_kvhead,
        causal,
        score_mod_hash,
        mask_mod_hash,
        use_block_sparsity,
        block_sparse_broadcast_pattern,
        aux_tensor_metadata,
        lse is None,
        cu_seqlens_q is None,
        cu_seqlens_k is None,
        seqused_q is None,
        seqused_k is None,
        page_table is not None,
        window_size_left is not None,
        window_size_right is not None,
        learnable_sink is not None,
        m_block_size,
        n_block_size,
        q_stage,
        num_threads,
        is_split_kv,
        pack_gqa,
        arch,
        page_size not in [None, 128],  # paged KV non-TMA
        q_subtile_factor,
    )
    if compile_key not in _flash_attn_fwd.compile_cache:
        (
            cu_seqlens_q_tensor,
            cu_seqlens_k_tensor,
            seqused_q_tensor,
            seqused_k_tensor,
            learnable_sink_tensor,
        ) = [
            to_cute_tensor(t, assumed_align=4, leading_dim=0)
            if t is not None
            else None
            for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
        ]
        page_table_tensor = (
            to_cute_tensor(page_table, assumed_align=4, leading_dim=1)
            if page_table is not None
            else None
        )
        q_tensor, k_tensor, v_tensor, o_tensor = [
            to_cute_tensor(t) for t in (q, k, v, out if not is_split_kv else out_partial)
        ]
        if is_split_kv:
            lse_tensor = to_cute_tensor(lse_partial, assumed_align=4)
        elif lse is not None:
            lse_tensor = to_cute_tensor(lse, assumed_align=4)
        else:
            lse_tensor = None

        sparse_tensors = None
        if normalized_block_sparse_tensors is not None:
            sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)

        # fixed: removed a dead `aux_tensor_metadata = None` reassignment here;
        # nothing below reads it and it shadowed the value used in compile_key.
        cute_aux_tensors = None
        if aux_tensors is not None:
            cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors]

        if arch // 10 == 9:
            assert page_table is None, "paged KV not supported on SM 9.0"
            assert not is_split_kv, "SplitKV not supported on SM 9.0"
            # fa_fwd = FlashAttentionForwardSm80(
            fa_fwd = FlashAttentionForwardSm90(
                dtype,
                head_dim,
                head_dim_v,
                qhead_per_kvhead,
                is_causal=causal,
                is_local=local,
                pack_gqa=pack_gqa,
                tile_m=m_block_size,
                tile_n=n_block_size,
                # num_stages=1,
                num_stages=2,
                num_threads=num_threads,
                Q_in_regs=False,
                intra_wg_overlap=True,
                mma_pv_is_rs=True,
                mask_mod=mask_mod,
                score_mod=score_mod,
                has_aux_tensors=aux_tensors is not None,
                q_subtile_factor=q_subtile_factor,
            )
        elif arch // 10 in [10, 11]:
            head_dim_padded = int(math.ceil(head_dim / 16) * 16)
            # fixed: pad from head_dim_v (was head_dim) — matters when hdim_qk != hdim_v,
            # e.g. (192, 128), where the old value wrongly gated use_2cta_instrs.
            head_dim_v_padded = int(math.ceil(head_dim_v / 16) * 16)
            use_2cta_instrs = (
                not causal
                and not local
                and not is_split_kv
                and cu_seqlens_q is None
                and seqused_q is None
                and not use_block_sparsity
                and page_size in [None, 128]
                and head_dim_padded == 128
                and head_dim_v_padded == 128
            )
            fa_fwd = FlashAttentionForwardSm100(
                head_dim,
                head_dim_v,
                qhead_per_kvhead=qhead_per_kvhead,
                is_causal=causal,
                is_local=local,
                is_split_kv=is_split_kv,
                pack_gqa=pack_gqa,
                m_block_size=m_block_size,
                n_block_size=n_block_size,
                q_stage=q_stage,
                is_persistent=not causal
                and not local
                and cu_seqlens_q is None
                and seqused_q is None
                and not is_split_kv,
                score_mod=score_mod,
                mask_mod=mask_mod,
                has_aux_tensors=aux_tensors is not None,
                paged_kv_non_tma=page_size not in [None, 128],
                is_varlen_q=cu_seqlens_q is not None or seqused_q is not None,
                q_subtile_factor=q_subtile_factor,
                use_2cta_instrs=use_2cta_instrs,
            )
        else:
            raise ValueError(
                f"Unsupported compute capability: {arch}. Supported: 9.x, 10.x, 11.x"
            )
        # TODO: check @can_implement
        _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
            fa_fwd,
            q_tensor,
            k_tensor,
            v_tensor,
            o_tensor,
            lse_tensor,
            softmax_scale,
            current_stream,
            cu_seqlens_q_tensor,
            cu_seqlens_k_tensor,
            seqused_q_tensor,
            seqused_k_tensor,
            page_table_tensor,
            window_size_left,
            window_size_right,
            learnable_sink_tensor,
            sparse_tensors,
            cute_aux_tensors,
            options="--enable-tvm-ffi",
        )

    # In "fake mode", we will take torch fake tensors as input and the expected behaviors are:
    # - Use those fake metadata to populate compilation cache
    # - Return "fake" output tensors, which could be needed in follow-up fake operations
    # Thus, we skip the actual kernel invocation here.
    if not is_fake_mode():
        _flash_attn_fwd.compile_cache[compile_key](
            q.detach(),
            k.detach(),
            v.detach(),
            out.detach() if not is_split_kv else out_partial,
            lse_partial if is_split_kv else lse,
            softmax_scale,
            current_stream,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            page_table,
            window_size_left,
            window_size_right,
            learnable_sink,
            normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,
            aux_tensors,
        )
        if is_split_kv:
            # Reduce the per-split partial outputs/LSEs into the final result.
            _flash_attn_fwd_combine(
                out_partial,
                lse_partial.transpose(-1, -2),
                out,
                lse.transpose(-1, -2) if lse is not None else None,
                cu_seqlens_q,
                seqused_q,
            )
    return out, lse


_flash_attn_fwd.compile_cache = get_jit_cache("fwd")
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def _flash_attn_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    dout: torch.Tensor,
    lse: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    softcap: float = 0.0,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    m_block_size: int = 64,
    n_block_size: int = 128,
    num_threads: int = 256,
    pack_gqa: bool = False,
    num_stages_Q: int = 2,
    num_stages_dO: int = 2,
    SdP_swapAB: bool = False,
    dKV_swapAB: bool = False,
    dQ_swapAB: bool = False,
    AtomLayoutMSdP: int = 2,
    AtomLayoutNdKV: int = 2,
    AtomLayoutMdQ: int = 2,
    V_in_regs: bool = False,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    deterministic: bool = False,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    score_mod: Optional[Callable] = None,
    score_mod_bwd: Optional[Callable] = None,
    mask_mod: Optional[Callable] = None,
    aux_tensors: Optional[list[torch.Tensor]] = None,
    block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Flash-attention backward pass: compute dq, dk, dv from (q, k, v, out, dout, lse).

    Runs three JIT-compiled kernels in sequence, each with its own compile cache
    keyed on a tuple of every compile-time-relevant parameter:
      1. preprocess  — computes (out * dout).sum(-1), lse * log2(e), zeroes dq_accum;
      2. main bwd    — accumulates dq (fp32 workspace) and dk/dv (directly, or into
                       fp32 accumulators when qhead_per_kvhead > 1, i.e. GQA/MQA);
      3. postprocess — converts the fp32 accumulators back to the input dtype,
                       applying softmax_scale to dq/dk (dv is unscaled: scale 1.0).

    Many tuning arguments (m/n_block_size, num_stages_*, *_swapAB, AtomLayout*)
    are overridden below based on the detected compute capability; the values
    passed in are honored only where not overwritten.

    Note: `dq`/`dk`/`dv` may be supplied as preallocated outputs; otherwise they
    are allocated here. Returns the (dq, dk, dv) tensors.
    """
    arch = _get_device_arch()
    assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"

    num_head, head_dim = q.shape[-2:]

    # Normalize the (causal, local window) configuration: causal is expressed as
    # window_size_right == 0; a fully-negative window means "no windowing at all".
    if causal:
        window_size_right = 0
    if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0:
        window_size_left = None
        window_size_right = None
    local = window_size_left is not None or window_size_right is not None
    if local:
        if window_size_left is None and window_size_right == 0:
            # (None, 0) is exactly causal attention, not a general local window.
            causal, local = True, False
            window_size_right = None
        else:
            causal, local = False, True

    # Arch-specific kernel tuning; these overwrite the corresponding arguments.
    if arch // 10 == 9:
        # SM90 (Hopper) configuration.
        m_block_size = 80 if not causal else 64
        n_block_size = 128
        num_stages_Q = 2
        num_stages_dO = 2
        num_stages_PdS = 2
        SdP_swapAB = True
        dKV_swapAB = False
        dQ_swapAB = not causal
        AtomLayoutMSdP = 1
        AtomLayoutNdKV = 2
        AtomLayoutMdQ = 1
        cluster_size = 1
        use_2cta_instrs = False
        assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x"
        is_varlen = (
            cu_seqlens_q is not None
            or cu_seqlens_k is not None
            or seqused_q is not None
            or seqused_k is not None
        )
        assert not is_varlen, "varlen backward is not yet supported on sm90"
    else:
        # SM100/SM110 (Blackwell) configuration.
        m_block_size = 128
        n_block_size = 128
        dQ_swapAB = False
        dKV_swapAB = False
        AtomLayoutMdQ = 1
        AtomLayoutNdKV = 1
        # 2-CTA clusters are incompatible with local windows and user mods.
        disable_2cta = (
            local
            or score_mod is not None
            or score_mod_bwd is not None
            or mask_mod is not None
        )
        cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1
        use_2cta_instrs = cluster_size==2

    q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [
        maybe_contiguous(t)
        for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
    ]
    # Derive batch/sequence extents for both fixed-length (batched) and
    # varlen (packed, cu_seqlens-indexed) layouts.
    if cu_seqlens_q is None:
        batch_size, seqlen_q = q.shape[:2]
        total_q = batch_size * seqlen_q
    else:
        batch_size = cu_seqlens_q.shape[0] - 1
        total_q = q.shape[0]
        # Fall back to total_q as a conservative bound when max_seqlen_q not given.
        seqlen_q = max_seqlen_q if max_seqlen_q is not None else total_q

    if cu_seqlens_k is None:
        batch_size, seqlen_k = k.shape[:2]
        total_k = batch_size * seqlen_k
    else:
        batch_size = cu_seqlens_k.shape[0] - 1
        total_k = k.shape[0]
        seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k

    num_head_kv = k.shape[-2]
    head_dim_v = v.shape[-1]

    use_block_sparsity = block_sparse_tensors is not None

    # SM90 block-sparse backward: tile_m=64 is the GCD between a m_block_size that fits,
    # the base block_m of 128 from forward, and block-sparse size for subtiling.
    if arch // 10 == 9 and use_block_sparsity:
        m_block_size = 64
        # dQ_swapAB tuning: use False when m_block_size=64 (same as causal case)
        dQ_swapAB = False

    # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2
    subtile_factor = 2
    # Round sequence lengths up to whole tiles for workspace sizing.
    seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
    seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
    num_n_blocks = seqlen_k_rounded // n_block_size
    # With 2-CTA clusters the number of n-tiles must be even; pad by one tile.
    if cluster_size == 2 and num_n_blocks % cluster_size != 0:
        seqlen_k_rounded = seqlen_k_rounded + n_block_size

    # --- Shape/dtype/device validation -------------------------------------
    if cu_seqlens_k is None:
        assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
        assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
    else:
        assert k.shape == (total_k, num_head_kv, head_dim)
        assert v.shape == (total_k, num_head_kv, head_dim_v)
        assert cu_seqlens_k.shape == (batch_size + 1,), (
            "cu_seqlens_k must have shape (batch_size + 1,)"
        )

    if cu_seqlens_q is not None:
        assert cu_seqlens_q.shape == (batch_size + 1,), (
            "cu_seqlens_q must have shape (batch_size + 1,)"
        )

        assert out.shape == (total_q, num_head, head_dim_v)
        assert dout.shape == (total_q, num_head, head_dim_v)
        assert lse.shape == (num_head, total_q), "lse must have shape (num_head, total_q)"
    else:
        assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v)
        assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v)
        assert lse.shape == (batch_size, num_head, seqlen_q), (
            "lse must have shape (batch_size, num_head, seqlen_q)"
        )

    assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16"
    assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, (
        "inputs must have the same dtype"
    )
    for t in [cu_seqlens_q, cu_seqlens_k]:
        if t is not None:
            assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32"
    assert lse.dtype == torch.float32, "lse must be float32"
    assert all(
        t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
    ), "inputs must be on CUDA device"
    assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
    assert head_dim <= 256, "head_dim must be less than or equal to 256"
    # Tensor-core copies require 16-byte alignment along the head dimension.
    alignment = 16 // q.element_size()
    assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
    assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)
    qhead_per_kvhead = num_head // num_head_kv
    if pack_gqa is None:
        pack_gqa = qhead_per_kvhead > 1
    # pack_gqa backward not yet supported in bwd, so force it off regardless.
    pack_gqa = False
    if arch // 10 not in [10, 11]:
        assert deterministic is False, "bwd deterministic only supported for sm100/sm110 for now"

    if score_mod is not None:
        assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided"
        assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)"
        assert cu_seqlens_q is None and cu_seqlens_k is None, (
            "varlen + score_mod not supported in bwd yet"
        )

    device = q.device
    out_torch_dtype = q.dtype

    # Allocate (or validate user-provided) gradient outputs.
    if dq is None:
        dq = torch.empty_like(q)
    else:
        _validate_tensor(dq, "dq", q.shape, out_torch_dtype, device)

    if dk is None:
        dk = torch.empty_like(k)
    else:
        _validate_tensor(dk, "dk", k.shape, out_torch_dtype, device)

    if dv is None:
        dv = torch.empty_like(v)
    else:
        _validate_tensor(dv, "dv", v.shape, out_torch_dtype, device)

    head_dim_rounded = (head_dim + 32 - 1) // 32 * 32

    # fp32 workspaces for the main kernel: dq accumulator, (out*dout).sum, and
    # lse in log2 space. Shapes differ between batched and varlen layouts; the
    # varlen variants pad with one extra m-tile per batch entry.
    if cu_seqlens_q is None:
        dq_accum = torch.empty(
            batch_size,
            num_head,
            seqlen_q_rounded * head_dim_rounded,
            dtype=torch.float32,
            device=device,
        )
        dpsum = torch.empty(
            batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device
        )
        lse_log2 = torch.empty(
            batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device
        )
    else:
        total_q_rounded_padded = (
            (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size
        )
        dq_accum = torch.empty(
            num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device
        )
        dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
        lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)

    # GQA/MQA: multiple q-heads accumulate into one kv-head, so dk/dv also go
    # through fp32 accumulators plus a postprocess conversion pass.
    dKV_postprocess = qhead_per_kvhead > 1
    if dKV_postprocess:
        head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32
        if cu_seqlens_k is None:
            dk_accum = torch.zeros(
                batch_size,
                num_head_kv,
                seqlen_k_rounded * head_dim_rounded,
                dtype=torch.float32,
                device=device,
            )
            dv_accum = torch.zeros(
                batch_size,
                num_head_kv,
                seqlen_k_rounded * head_dim_v_rounded,
                dtype=torch.float32,
                device=device,
            )
        else:
            cluster_tile_n = cluster_size * n_block_size
            total_k_rounded_padded = (
                (total_k + cu_seqlens_k.shape[0] * cluster_tile_n - 1) // cluster_tile_n * cluster_tile_n
            )
            dk_accum = torch.zeros(
                num_head_kv,
                total_k_rounded_padded * head_dim_rounded,
                dtype=torch.float32,
                device=device,
            )
            dv_accum = torch.zeros(
                num_head_kv,
                total_k_rounded_padded * head_dim_v_rounded,
                dtype=torch.float32,
                device=device,
            )

    dtype = torch2cute_dtype_map[q.dtype]
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Deterministic mode serializes cross-tile accumulation via semaphores.
    if deterministic:
        dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size, dtype=torch.int32, device="cuda")
    else:
        dQ_semaphore = None

    if deterministic and qhead_per_kvhead > 1:
        dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda")
        dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda")
    else:
        dK_semaphore = None
        dV_semaphore = None

    # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum.
    compile_key_pre = (
        arch,
        dtype,
        head_dim,
        head_dim_v,
        m_block_size,
        num_threads,
        cu_seqlens_q is None,
        seqused_q is None,
        get_broadcast_dims(out),
        get_broadcast_dims(dout),
    )
    if compile_key_pre not in _flash_attn_bwd.compile_cache_pre:
        o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)]
        dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
            to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
        ]
        lse_tensor = to_cute_tensor(lse, assumed_align=4)
        cu_seqlens_q_tensor, seqused_q_tensor = [
            to_cute_tensor(t, assumed_align=4) if t is not None else None
            for t in (cu_seqlens_q, seqused_q)
        ]
        fa_bwd_pre = FlashAttentionBackwardPreprocess(
            dtype,
            head_dim,
            head_dim_v,
            arch,
            m_block_size,
            num_threads=num_threads,
        )
        # TODO: check @can_implement
        _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile(
            fa_bwd_pre,
            o_tensor,
            do_tensor,
            dpsum_tensor,
            lse_tensor,
            lse_log2_tensor,
            dq_accum_tensor,
            cu_seqlens_q_tensor,
            seqused_q_tensor,
            current_stream,
            options="--enable-tvm-ffi",
        )
    # In fake mode we only populate the compile cache; skip real execution.
    if not is_fake_mode():
        _flash_attn_bwd.compile_cache_pre[compile_key_pre](
            out,
            dout,
            dpsum,
            lse,
            lse_log2,
            dq_accum,
            cu_seqlens_q,
            seqused_q,
            current_stream,
        )

    # NB num_threads application for 3 kernels
    # There are pre, main, post processing kernels; currently num_threads is only actually
    # used for the pre proc, and then we hard code to 384 for the main and post proc, and we do
    # this before cache key gen
    num_threads = 384

    # Backward kernel: compute dk, dv, dq_accum.
    score_mod_hash = utils.hash_callable(score_mod) if score_mod else False
    score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False
    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod else False
    num_aux_tensors = len(aux_tensors) if aux_tensors else 0
    cute_aux_tensors = None
    if aux_tensors is not None:
        cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors]

    block_sparse_broadcast_pattern = None
    normalized_block_sparse_tensors = None
    if block_sparse_tensors is not None:
        (
            normalized_block_sparse_tensors,
            block_sparse_broadcast_pattern,
        ) = normalize_block_sparse_config_bwd(
            block_sparse_tensors,
            batch_size=batch_size,
            num_head=num_head,
            seqlen_q=seqlen_q,
            seqlen_k=seqlen_k,
            block_size=(m_block_size, n_block_size),
            subtile_factor=subtile_factor,
        )

    # Compile keys differ per arch: sm90 keys on the Sm90-specific tuning knobs,
    # sm100+ on cluster/deterministic/local-window configuration.
    if arch // 10 == 9:
        compile_key = (
            arch,
            dtype,
            head_dim,
            head_dim_v,
            qhead_per_kvhead,
            causal,
            softcap != 0.0,
            m_block_size,
            n_block_size,
            num_threads,
            pack_gqa,
            num_stages_Q,
            num_stages_dO,
            SdP_swapAB,
            dKV_swapAB,
            dQ_swapAB,
            AtomLayoutMSdP,
            AtomLayoutNdKV,
            AtomLayoutMdQ,
            V_in_regs,
            cu_seqlens_q is None,
            cu_seqlens_k is None,
            seqused_q is None,
            seqused_k is None,
            score_mod_hash,
            score_mod_bwd_hash,
            mask_mod_hash,
            num_aux_tensors,
            use_block_sparsity,
            block_sparse_broadcast_pattern,
            get_broadcast_dims(q),
            get_broadcast_dims(k),
            get_broadcast_dims(v),
            get_broadcast_dims(dout),
        )
    else:
        compile_key = (
            arch,
            dtype,
            head_dim,
            head_dim_v,
            qhead_per_kvhead,
            causal,
            window_size_left is not None,
            window_size_right is not None,
            softcap != 0.0,
            m_block_size,
            n_block_size,
            num_threads,
            pack_gqa,
            cluster_size,
            use_2cta_instrs,
            deterministic,
            score_mod_hash,
            score_mod_bwd_hash,
            mask_mod_hash,
            num_aux_tensors,
            use_block_sparsity,
            block_sparse_broadcast_pattern,
            cu_seqlens_q is None,
            cu_seqlens_k is None,
            seqused_q is None,
            seqused_k is None,
            get_broadcast_dims(q),
            get_broadcast_dims(k),
            get_broadcast_dims(v),
            get_broadcast_dims(dout),
        )
    if compile_key not in _flash_attn_bwd.compile_cache:
        q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [
            to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv)
        ]
        dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
            to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
        ]
        if dKV_postprocess:
            dk_accum_tensor, dv_accum_tensor = [
                to_cute_tensor(t) for t in (dk_accum, dv_accum)
            ]
        cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [
            to_cute_tensor(t, assumed_align=4) if t is not None else None
            for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ]
        dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [
            utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order())
            if t is not None else None
            for t in (dQ_semaphore, dK_semaphore, dV_semaphore)
        ]
        # NOTE(review): fa_bwd_sm80 is constructed but never used below (fa_bwd_obj
        # is built from Sm90/Sm100 instead) — looks like dead code; confirm and remove.
        fa_bwd_sm80 = FlashAttentionBackwardSm80(
            dtype,
            head_dim,
            head_dim_v,
            qhead_per_kvhead,
            m_block_size,
            n_block_size,
            num_stages_Q,
            num_stages_dO,
            num_threads,
            pack_gqa,
            causal,
            SdP_swapAB,
            dKV_swapAB,
            dQ_swapAB,
            AtomLayoutMSdP,
            AtomLayoutNdKV,
            AtomLayoutMdQ,
            V_in_regs=V_in_regs,
        )
        if arch // 10 == 9:
            fa_bwd_obj = FlashAttentionBackwardSm90(
                dtype,
                head_dim,
                head_dim_v,
                qhead_per_kvhead,
                causal,
                m_block_size,
                n_block_size,
                num_stages_Q,
                num_stages_dO,
                num_stages_PdS,
                SdP_swapAB,
                dKV_swapAB,
                dQ_swapAB,
                AtomLayoutMSdP,
                AtomLayoutNdKV,
                AtomLayoutMdQ,
                num_threads,
                V_in_regs=V_in_regs,
                score_mod=score_mod,
                score_mod_bwd=score_mod_bwd,
                mask_mod=mask_mod,
                has_aux_tensors=aux_tensors is not None,
                subtile_factor=subtile_factor,
            )
        else:
            fa_bwd_obj = FlashAttentionBackwardSm100(
                head_dim,
                head_dim_v,
                is_causal=causal,
                is_local=local,
                qhead_per_kvhead=qhead_per_kvhead,
                tile_m=m_block_size,
                tile_n=n_block_size,
                cluster_size=cluster_size,
                use_2cta_instrs=use_2cta_instrs,
                deterministic=deterministic,
                score_mod=score_mod,
                score_mod_bwd=score_mod_bwd,
                mask_mod=mask_mod,
                has_aux_tensors=aux_tensors is not None,
                subtile_factor=subtile_factor,
            )

        # Block sparse tensors for backward use Q-direction indexing (transposed from forward).
        sparse_tensors_compile = None
        if normalized_block_sparse_tensors is not None:
            sparse_tensors_compile = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)

        # TODO: check @can_implement
        _flash_attn_bwd.compile_cache[compile_key] = cute.compile(
            fa_bwd_obj,
            q_tensor,
            k_tensor,
            v_tensor,
            do_tensor,
            lse_log2_tensor,
            dpsum_tensor,
            dq_accum_tensor,
            dk_tensor if not dKV_postprocess else dk_accum_tensor,
            dv_tensor if not dKV_postprocess else dv_accum_tensor,
            softmax_scale,
            current_stream,
            cu_seqlens_q_tensor,
            cu_seqlens_k_tensor,
            seqused_q_tensor,
            seqused_k_tensor,
            None,  # softcap - not yet supported in backward
            window_size_left,
            window_size_right,
            dQ_semaphore_tensor,
            dK_semaphore_tensor,
            dV_semaphore_tensor,
            cute_aux_tensors,
            sparse_tensors_compile,
            options="--enable-tvm-ffi",
        )
    if not is_fake_mode():
        _flash_attn_bwd.compile_cache[compile_key](
            q.detach(),
            k.detach(),
            v.detach(),
            dout,
            lse_log2,
            dpsum,
            dq_accum,
            dk if not dKV_postprocess else dk_accum,
            dv if not dKV_postprocess else dv_accum,
            softmax_scale,
            current_stream,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            None,  # softcap - not yet supported in backward
            window_size_left,
            window_size_right,
            dQ_semaphore,
            dK_semaphore,
            dV_semaphore,
            aux_tensors,
            normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,
        )

    num_threads = 256 if arch // 10 == 9 else 128
    # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16
    compile_key_post = (
        arch,
        dtype,
        head_dim,
        m_block_size,
        num_threads,
        AtomLayoutMdQ,
        dQ_swapAB,
        cu_seqlens_q is None,
        seqused_q is None,
        use_2cta_instrs,
        1,  # no cluster for tile_m
        get_broadcast_dims(dq_accum),
        get_broadcast_dims(dq),
    )
    if compile_key_post not in _flash_attn_bwd.compile_cache_post:
        dq_accum_tensor = to_cute_tensor(dq_accum)
        dq_tensor = to_cute_tensor(dq)
        cu_seqlens_q_tensor, seqused_q_tensor = [
            to_cute_tensor(t, assumed_align=4) if t is not None else None
            for t in (cu_seqlens_q, seqused_q)
        ]
        fa_bwd_post = FlashAttentionBackwardPostprocess(
            dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB,
            use_2cta_instrs=use_2cta_instrs,
        )
        # TODO: check @can_implement
        _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
            fa_bwd_post,
            dq_accum_tensor,
            dq_tensor,
            softmax_scale,
            cu_seqlens_q_tensor,
            seqused_q_tensor,
            current_stream,
            options="--enable-tvm-ffi",
        )

    if not is_fake_mode():
        _flash_attn_bwd.compile_cache_post[compile_key_post](
            dq_accum,
            dq,
            softmax_scale,
            cu_seqlens_q,
            seqused_q,
            current_stream,
        )

    if dKV_postprocess:
        # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16
        compile_key_post = (
            arch,
            dtype,
            head_dim,
            n_block_size,
            num_threads,
            AtomLayoutNdKV,
            dKV_swapAB,
            cu_seqlens_k is None,
            seqused_k is None,
            False,  # even for 2cta, is split along hdim, so always False
            cluster_size,  # cluster is for tile_n
            get_broadcast_dims(dk_accum),
            get_broadcast_dims(dk),
        )
        if compile_key_post not in _flash_attn_bwd.compile_cache_post:
            dk_accum_tensor = to_cute_tensor(dk_accum)
            dk_tensor = to_cute_tensor(dk)
            cu_seqlens_k_tensor, seqused_k_tensor = [
                to_cute_tensor(t, assumed_align=4) if t is not None else None
                for t in (cu_seqlens_k, seqused_k)
            ]
            fa_bwd_post = FlashAttentionBackwardPostprocess(
                dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB,
                cluster_size=cluster_size,
            )
            # TODO: check @can_implement
            _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
                fa_bwd_post,
                dk_accum_tensor,
                dk_tensor,
                softmax_scale,
                cu_seqlens_k_tensor,
                seqused_k_tensor,
                current_stream,
                options="--enable-tvm-ffi",
            )
        if not is_fake_mode():
            _flash_attn_bwd.compile_cache_post[compile_key_post](
                dk_accum,
                dk,
                softmax_scale,
                cu_seqlens_k,
                seqused_k,
                current_stream,
            )
        compile_key_post = (
            arch,
            dtype,
            head_dim_v,
            n_block_size,
            num_threads,
            AtomLayoutNdKV,
            dKV_swapAB,
            cu_seqlens_k is None,
            seqused_k is None,
            False,
            cluster_size,
            get_broadcast_dims(dv_accum),
            get_broadcast_dims(dv),
        )
        if compile_key_post not in _flash_attn_bwd.compile_cache_post:
            dv_accum_tensor = to_cute_tensor(dv_accum)
            dv_tensor = to_cute_tensor(dv)
            cu_seqlens_k_tensor, seqused_k_tensor = [
                to_cute_tensor(t, assumed_align=4) if t is not None else None
                for t in (cu_seqlens_k, seqused_k)
            ]
            fa_bwd_post = FlashAttentionBackwardPostprocess(
                dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB,
                cluster_size=cluster_size,
            )
            # TODO: check @can_implement
            _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
                fa_bwd_post,
                dv_accum_tensor,
                dv_tensor,
                # dv is not scaled by softmax_scale, hence scale factor 1.0.
                cutlass.Float32(1.0),
                cu_seqlens_k_tensor,
                seqused_k_tensor,
                current_stream,
                options="--enable-tvm-ffi",
            )
        if not is_fake_mode():
            _flash_attn_bwd.compile_cache_post[compile_key_post](
                dv_accum,
                dv,
                1.0,
                cu_seqlens_k,
                seqused_k,
                current_stream,
            )

    return dq, dk, dv
|
| 1317 |
+
|
| 1318 |
+
|
| 1319 |
+
# Process-wide JIT compile caches for the three backward kernels (preprocess,
# main backward, postprocess), attached to the function object so they persist
# across calls and are shared by every invocation in this process.
_flash_attn_bwd.compile_cache_pre = get_jit_cache("bwd_pre")
_flash_attn_bwd.compile_cache = get_jit_cache("bwd")
_flash_attn_bwd.compile_cache_post = get_jit_cache("bwd_post")
|
| 1322 |
+
|
| 1323 |
+
|
| 1324 |
+
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for the fixed-length (batched) flash-attention kernels.

    forward() dispatches to _flash_attn_fwd and stashes what backward() needs
    on ctx; backward() dispatches to _flash_attn_bwd and returns gradients for
    (q, k, v) only — every other forward argument is non-differentiable.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[Optional[int], Optional[int]] = (None, None),
        learnable_sink: Optional[torch.Tensor] = None,
        softcap: float = 0.0,
        num_splits: int = 1,
        pack_gqa: Optional[bool] = None,
        deterministic: bool = False,
        mask_mod: Optional[Callable] = None,
        full_block_cnt: Optional[torch.Tensor] = None,
        full_block_idx: Optional[torch.Tensor] = None,
        mask_block_cnt: Optional[torch.Tensor] = None,
        mask_block_idx: Optional[torch.Tensor] = None,
        block_size: Optional[Tuple[int, int]] = None,
        return_lse: bool = False,
    ):
        # Only create block sparse tensors if at least one block sparse parameter is provided
        block_sparse_tensors = None
        if any(t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]):
            block_sparse_tensors = BlockSparseTensorsTorch(
                full_block_cnt=full_block_cnt,
                full_block_idx=full_block_idx,
                mask_block_cnt=mask_block_cnt,
                mask_block_idx=mask_block_idx,
                block_size=block_size,
            )
        out, lse = _flash_attn_fwd(
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            learnable_sink=learnable_sink,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            mask_mod=mask_mod,
            block_sparse_tensors=block_sparse_tensors,
            return_lse=return_lse,
        )
        # NOTE(review): lse may be None when return_lse=False; backward() passes it
        # straight into _flash_attn_bwd, which requires an lse tensor — confirm that
        # callers needing gradients always request the LSE.
        ctx.save_for_backward(q, k, v, out, lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        # LSE gradient is not supported yet
        if lse is not None:
            ctx.mark_non_differentiable(lse)
        return out, lse

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs the (unused) incoming gradient for the non-differentiable lse.
        q, k, v, out, lse = ctx.saved_tensors
        # NOTE(review): mask_mod and the block-sparse configuration from forward are
        # not forwarded to the backward pass — verify this is intentional.
        dq, dk, dv = _flash_attn_bwd(
            q,
            k,
            v,
            out,
            dout,
            lse,
            ctx.softmax_scale,
            ctx.causal,
            ctx.softcap,
            window_size_left=ctx.window_size[0],
            window_size_right=ctx.window_size[1],
            deterministic=ctx.deterministic,
        )
        # One None per non-differentiable forward argument; extras are ignored.
        return dq, dk, dv, *((None,) * 20)  # Extra Nones is fine
|
| 1402 |
+
|
| 1403 |
+
|
| 1404 |
+
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd Function for variable-length (packed / varlen) flash attention.

    forward dispatches to ``_flash_attn_fwd`` with cu_seqlens/seqused metadata
    and saves everything needed for backward; backward dispatches to
    ``_flash_attn_bwd``. LSE gradients are not supported, so the LSE output is
    marked non-differentiable.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        cu_seqlens_k: Optional[torch.Tensor],
        seqused_q: Optional[torch.Tensor] = None,
        seqused_k: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_k: Optional[int] = None,
        page_table: Optional[torch.Tensor] = None,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[Optional[int], Optional[int]] = (None, None),
        learnable_sink: Optional[torch.Tensor] = None,
        softcap: float = 0.0,
        num_splits: int = 1,
        pack_gqa: Optional[bool] = None,
        deterministic: bool = False,
        score_mod: Optional[Callable] = None,
        aux_tensors: Optional[list] = None,
        return_lse: bool = False,
    ):
        # Run the fused forward kernel. lse is always computed here because
        # the backward pass needs it, regardless of return_lse.
        out, lse = _flash_attn_fwd(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            page_table=page_table,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            learnable_sink=learnable_sink,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            score_mod=score_mod,
            aux_tensors=aux_tensors,
            return_lse=return_lse,
        )
        # Stash tensors via save_for_backward and plain config on ctx.
        ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        # LSE gradient is not supported yet
        if lse is not None:
            ctx.mark_non_differentiable(lse)
        return out, lse

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs the gradient slot of the non-differentiable lse output.
        q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        # NOTE(review): the varlen backward path appears to require
        # softcap == 0 — confirm against _flash_attn_bwd support.
        assert ctx.softcap == 0.0
        dq, dk, dv = _flash_attn_bwd(
            q,
            k,
            v,
            out,
            dout,
            lse,
            ctx.softmax_scale,
            ctx.causal,
            ctx.softcap,
            window_size_left=ctx.window_size[0],
            window_size_right=ctx.window_size[1],
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_q=seqused_q,
            seqused_k=seqused_k,
            max_seqlen_q=ctx.max_seqlen_q,
            max_seqlen_k=ctx.max_seqlen_k,
            deterministic=ctx.deterministic,
        )

        # Gradients only for q, k, v; every other forward argument gets None
        # (extra Nones beyond the argument count are ignored by autograd).
        return dq, dk, dv, *((None,) * 20)
|
| 1492 |
+
|
| 1493 |
+
|
| 1494 |
+
def flash_attn_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    deterministic: bool = False,
    mask_mod: Optional[Callable] = None,
    full_block_cnt: Optional[torch.Tensor] = None,
    full_block_idx: Optional[torch.Tensor] = None,
    mask_block_cnt: Optional[torch.Tensor] = None,
    mask_block_idx: Optional[torch.Tensor] = None,
    block_size: Optional[Tuple[int, int]] = None,
    return_lse: bool = False,
):
    """User-facing fixed-length flash attention entry point.

    Thin wrapper that forwards all arguments positionally to
    ``FlashAttnFunc.apply`` (autograd ``apply`` accepts positional args only,
    so the tuple order below must match the ``forward`` signature exactly).
    """
    fn_args = (
        q,
        k,
        v,
        softmax_scale,
        causal,
        window_size,
        learnable_sink,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        mask_mod,
        full_block_cnt,
        full_block_idx,
        mask_block_cnt,
        mask_block_idx,
        block_size,
        return_lse,
    )
    return FlashAttnFunc.apply(*fn_args)
|
| 1534 |
+
|
| 1535 |
+
|
| 1536 |
+
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    deterministic: bool = False,
    score_mod: Optional[Callable] = None,
    aux_tensors: Optional[list] = None,
    return_lse: bool = False,
):
    """User-facing variable-length (varlen) flash attention entry point.

    Thin wrapper around ``FlashAttnVarlenFunc.apply``. Note the positional
    order passed to ``apply`` differs from this signature's keyword order:
    ``seqused_q``/``seqused_k`` come before ``max_seqlen_q``/``max_seqlen_k``
    to match ``FlashAttnVarlenFunc.forward``.
    """
    fn_args = (
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        softmax_scale,
        causal,
        window_size,
        learnable_sink,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        score_mod,
        aux_tensors,
        return_lse,
    )
    return FlashAttnVarlenFunc.apply(*fn_args)
|
| 1582 |
+
|
| 1583 |
+
|
| 1584 |
+
def _flash_attn_fwd_combine(
    out_partial: torch.Tensor,
    lse_partial: torch.Tensor,
    out: torch.Tensor,
    lse: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    seqused: Optional[torch.Tensor] = None,
    num_splits_dynamic_ptr: Optional[torch.Tensor] = None,
    semaphore_to_reset: Optional[torch.Tensor] = None,
) -> None:
    """Forward combine kernel for split attention computation.

    Combines partial outputs and log-sum-exp values from multiple splits
    of attention computation into final outputs, writing into ``out`` (and
    ``lse`` if given) in place. Compiled kernels are memoized per
    configuration in ``_flash_attn_fwd_combine.compile_cache``.

    Args:
        out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim) or
            (num_splits, total_q, nheads, headdim) if there's cu_seqlens
        lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads) or
            (num_splits, total_q, nheads) if there's cu_seqlens
        out: Output tensor (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim) if there's cu_seqlens
        lse: Output LSE tensor (batch, seqlen, nheads) or (total_q, nheads) if there's cu_seqlens.
        cu_seqlens: Cumulative sequence lengths for variable length sequences
        seqused: Used sequence lengths for each batch
        num_splits_dynamic_ptr: Dynamic number of splits per batch
        semaphore_to_reset: Semaphore for synchronization

    Returns:
        None
    """
    # Input validation
    assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
    assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
    assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], (
        "out_partial must be fp16, bf16, or fp32"
    )
    assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
    assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device"
    assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension"
    assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension"
    assert lse_partial.shape == out_partial.shape[:-1]

    # Determine if this is variable length based on dimensions
    # (4-D means the batch dim is folded into total_q).
    is_varlen = out_partial.dim() == 4

    # Validate output tensor shapes and types
    assert out.shape == out_partial.shape[1:], "out shape mismatch"
    if lse is not None:
        assert lse.shape == lse_partial.shape[1:], "lse shape mismatch"
        assert lse.dtype == torch.float32, "lse must be fp32"

    # Validate optional tensors
    for t, name in [
        (cu_seqlens, "cu_seqlens"),
        (seqused, "seqused"),
        (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
    ]:
        if t is not None:
            assert t.dtype == torch.int32, f"{name} must be int32"
            assert t.is_cuda, f"{name} must be on CUDA device"
            assert t.is_contiguous(), f"{name} must be contiguous"

    head_dim = out_partial.shape[-1]
    num_splits = out_partial.shape[0]
    assert num_splits <= 256
    # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
    # so that kBlockM is smaller and we have more parallelism.
    k_block_size = 64 if head_dim <= 64 else 128
    # We want kBlockM to be as small as possible to maximize parallelism.
    # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
    m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32)
    # Kernel reduces over up to 2**log_max_splits splits; floor of 16 (=2**4).
    log_max_splits = max(math.ceil(math.log2(num_splits)), 4)
    if m_block_size == 8:
        # If kBlockM == 8 then the minimum number of splits is 32.
        # TODO: we can deal w this by using 128 threads instead
        log_max_splits = max(log_max_splits, 5)

    # Launch on the caller's current CUDA stream.
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Create combine kernel configuration
    dtype = torch2cute_dtype_map[out.dtype]
    dtype_partial = torch2cute_dtype_map[out_partial.dtype]

    # Cache key: everything that affects codegen, including which optional
    # tensors are present (their presence changes the compiled signature).
    compile_key = (
        dtype,
        dtype_partial,
        head_dim,
        m_block_size,
        k_block_size,
        log_max_splits,
        cu_seqlens is not None,
        seqused is not None,
        lse is not None,
    )

    if compile_key not in _flash_attn_fwd_combine.compile_cache:
        # Wrap torch tensors as cute tensors purely to drive compilation;
        # the cached callable is later invoked with the torch tensors directly.
        out_partial_tensor = to_cute_tensor(
            out_partial, leading_dim=4 if not is_varlen else 3
        )
        lse_partial_tensor = to_cute_tensor(
            lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2
        )
        out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2)
        lse_tensor = (
            to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2)
            if lse is not None
            else None
        )

        optional_tensors = [
            to_cute_tensor(t, assumed_align=4, leading_dim=0)
            if t is not None
            else None
            for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset)
        ]
        cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = (
            optional_tensors
        )
        fa_combine = FlashAttentionForwardCombine(
            dtype=dtype,
            dtype_partial=dtype_partial,
            head_dim=head_dim,
            m_block_size=m_block_size,
            k_block_size=k_block_size,
            log_max_splits=log_max_splits,
        )

        # Check if implementation is supported
        if not fa_combine.can_implement(
            dtype,
            dtype_partial,
            head_dim,
            m_block_size,
            k_block_size,
            log_max_splits,
            num_threads=256,
        ):
            raise RuntimeError(
                "FlashAttention combine kernel cannot be implemented with given parameters"
            )

        _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile(
            fa_combine,
            out_partial_tensor,
            lse_partial_tensor,
            out_tensor,
            lse_tensor,
            cu_seqlens_tensor,
            seqused_tensor,
            num_splits_dynamic_tensor,
            semaphore_tensor,
            current_stream,
            options="--enable-tvm-ffi",
        )
    # Skip the actual launch under fake-tensor tracing (compile cache is
    # still populated above).
    if not is_fake_mode():
        _flash_attn_fwd_combine.compile_cache[compile_key](
            out_partial,
            lse_partial,
            out,
            lse,
            cu_seqlens,
            seqused,
            num_splits_dynamic_ptr,
            semaphore_to_reset,
            current_stream,
        )


# Per-configuration cache of compiled combine kernels (see compile_key above).
_flash_attn_fwd_combine.compile_cache = get_jit_cache("fwd_combine")
|
| 1754 |
+
|
| 1755 |
+
|
| 1756 |
+
def flash_attn_combine(
    out_partial: torch.Tensor,
    lse_partial: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    seqused: Optional[torch.Tensor] = None,
    return_lse: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Flash Attention combine function for split attention computation.

    Combines partial outputs and log-sum-exp values from multiple splits
    of attention computation into final outputs. This is the main user-facing
    interface for the combine kernel.

    Args:
        out_partial: Partial outputs tensor with shape:
            - (num_splits, batch_size, seqlen, num_heads, head_size) for regular batched input
            - (num_splits, total_q, num_heads, head_size) for variable length input
        lse_partial: Partial LSE tensor with shape:
            - (num_splits, batch_size, seqlen, num_heads) for regular batched input
            - (num_splits, total_q, num_heads) for variable length input
        out: Optional output tensor. If None, will be created automatically.
        out_dtype: Optional output dtype. If None, the output keeps
            ``out_partial.dtype`` (fp32, per the assertion below).
        cu_seqlens: Cumulative sequence lengths for variable length sequences
        seqused: Used sequence lengths for each batch
        return_lse: Whether to return the combined LSE tensor. Default is True.

    Returns:
        Tuple of (out, lse) where:
        - out: Combined output tensor with shape (batch_size, seqlen, num_heads, head_size)
          or (total_q, num_heads, head_size) for varlen
        - lse: Combined log-sum-exp tensor with shape (batch_size, seqlen, num_heads)
          or (total_q, num_heads) for varlen. None if return_lse=False

    Note:
        This function expects the input tensors to be in the format produced by
        split attention computation, where the first dimension is num_splits.
        The permuting from user format to kernel format is now done inside the kernel.
    """
    # Input validation
    assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
    assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
    assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)"
    assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"

    # Determine if this is variable length based on dimensions
    is_varlen = out_partial.dim() == 4

    if is_varlen:
        # Variable length: (num_splits, total_q, num_heads, head_size)
        num_splits, total_q, num_heads, head_size = out_partial.shape
        assert lse_partial.shape == (num_splits, total_q, num_heads), (
            "lse_partial shape mismatch for varlen"
        )
        batch_size = 1  # Treat as single batch for varlen
        seqlen = total_q
    else:
        # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size)
        num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape
        assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), (
            "lse_partial shape mismatch"
        )

    # Determine output dtype: keep the (fp32) accumulation dtype by default.
    if out_dtype is None:
        out_dtype = out_partial.dtype

    # Create output if not provided
    device = out_partial.device
    if out is None:
        if is_varlen:
            out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device)
        else:
            out = torch.empty(
                batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device
            )

    # Create lse output only if requested. Allocated (heads, seq) then
    # transposed so the seqlen dimension is contiguous, as the combine
    # kernel requires (see the stride(-2) == 1 assertion there).
    if return_lse:
        if is_varlen:
            lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose(
                0, 1
            )
        else:
            lse = torch.empty(
                batch_size, num_heads, seqlen, dtype=torch.float32, device=device
            ).transpose(1, 2)
    else:
        lse = None

    _flash_attn_fwd_combine(
        out_partial,
        lse_partial,
        out,
        lse,
        cu_seqlens,
        seqused,
    )
    return out, lse
|
build/torch-cuda/mask.py
ADDED
|
@@ -0,0 +1,653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Callable
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
import cutlass
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
from cutlass import Float32, Int32, const_expr
|
| 9 |
+
|
| 10 |
+
from .quack import layout_utils
|
| 11 |
+
from . import utils
|
| 12 |
+
from .seqlen_info import SeqlenInfoQK
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@cute.jit
def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None:
    """Set entries at column index >= ``col_limit`` to -inf, via bitmask compare.

    Bit manipulation, compiles down to the R2P instruction.
    ``X`` is a per-thread register fragment; ``rank1`` selects whether the
    column axis is X's only mode or its last mode.
    """
    # For sm100: we know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using.
    # For sm90: instead of comparing limit to 0, 1, 8, 9, 16, 17, ...,
    # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ...
    if const_expr(arch == 90):
        # Remap the sm90 fragment column pattern (pairs every 8) onto a
        # dense 0, 1, 2, ... index so a contiguous bitmask works.
        col_limit_transformed = col_limit // 8 * 2 + min(col_limit % 8, 2)
    else:
        col_limit_transformed = col_limit
    ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))
    # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31
    for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
        # Don't need to clamp to 32 since the shr.u32 instruction does that already
        col_limit_right_s = max(col_limit_transformed - s * 24, 0)
        # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11
        mask = (1 << col_limit_right_s) - 1
        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
        for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
            in_bound = cutlass.Boolean(mask & (1 << i))
            c = s * 24 + i
            if const_expr(rank1):
                X[c] = X[c] if in_bound else -Float32.inf
            # This is the equivalent of:
            # X[s * 24 + i] = X[s * 24 + i] if col_limit_right_s <= i else -Float32.inf
            else:
                # Apply the same column mask to every row of the fragment.
                for r in cutlass.range_constexpr(cute.size(X.shape[0])):
                    X[r, c] = X[r, c] if in_bound else -Float32.inf
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@cute.jit
def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> None:
    """Mask out (set to -inf) entries whose row index lies below ``row_limit_top``.

    Bit manipulation, compiles down to the R2P instruction.
    Note the polarity is opposite to :func:`mask_r2p`: here a SET mask bit
    means the element is masked (``out_bound``), not kept.
    """
    # For sm100: we know that tScS_t2r[i][0] has the form 0, 1, ..., 31, 64, ..., 127
    # or 0, 1, ..., 15, 32, ..., 47, 64, ...
    # We compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ...
    # Here we hardcode for the case of 2 warp groups.
    num_wg = 2
    # Collapse the strided per-warp-group row pattern onto a dense index
    # (analogous to the sm90 remap in mask_r2p).
    row_limit_top_transformed = row_limit_top // (num_rep * num_wg) * num_rep + min(
        row_limit_top % (num_rep * num_wg), num_rep
    )
    ncol = cute.size(X.shape)
    # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31
    for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
        row_limit_top_s = max(row_limit_top_transformed - s * 24, 0)
        # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11
        mask = (1 << row_limit_top_s) - 1
        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
        for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
            out_bound = cutlass.Boolean(mask & (1 << i))
            c = s * 24 + i
            X[c] = -Float32.inf if out_bound else X[c]
            # Debug aid (kept for reference):
            # tidx = cute.arch.thread_idx()[0] % 256
            # if tidx == 128:
            #     cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@cute.jit
def mask_r2p_dual_bound(
    X: cute.Tensor,
    col_limit_left: Int32,  # Inclusive lower bound
    col_limit_right: Int32,  # Exclusive upper bound
) -> None:
    """
    Dual-bound masking using two bitmasks for SM100, following mask_r2p.
    Masks elements where: NOT (col_limit_left <= col < col_limit_right)

    Uses bit manipulation to create a range mask:
        mask_right = (1 << right) - 1  -> bits (right-1)..0 are 1
        mask_left  = (1 << left) - 1   -> bits (left-1)..0 are 1
        mask_range = mask_right & ~mask_left -> bits (right-1)..left are 1
    """
    ncol = const_expr(cute.size(X.shape))

    # Process 24 columns per step (not 32: mask >> i is wrong for i == 31,
    # see mask_r2p).
    for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
        # Shift both bounds into this 24-column window.
        right_s = max(col_limit_right - s * 24, 0)
        left_s = max(col_limit_left - s * 24, 0)

        # otherwise cute dsl complains about python int too large to convert into c long
        right_s = min(right_s, 24)
        left_s = min(left_s, 24)

        # bits (right-1)..left are 1
        mask_right = (1 << right_s) - 1
        mask_left = (1 << left_s) - 1
        mask_range = mask_right & ~mask_left

        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
        for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
            in_bound = cutlass.Boolean(mask_range & (1 << i))
            c = s * 24 + i
            X[c] = X[c] if in_bound else -Float32.inf
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@dataclass(frozen=True)
|
| 110 |
+
class AttentionMask:
|
| 111 |
+
tile_m: cutlass.Constexpr[int]
|
| 112 |
+
tile_n: cutlass.Constexpr[int]
|
| 113 |
+
seqlen_info: SeqlenInfoQK
|
| 114 |
+
window_size_left: Optional[Int32] = None
|
| 115 |
+
window_size_right: Optional[Int32] = None
|
| 116 |
+
qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 # only pass in if we're doing PackGQA
|
| 117 |
+
swap_AB: cutlass.Constexpr[bool] = False
|
| 118 |
+
|
| 119 |
+
@property
|
| 120 |
+
def seqlen_q(self) -> Int32:
|
| 121 |
+
return self.seqlen_info.seqlen_q
|
| 122 |
+
|
| 123 |
+
@property
|
| 124 |
+
def seqlen_k(self) -> Int32:
|
| 125 |
+
return self.seqlen_info.seqlen_k
|
| 126 |
+
|
| 127 |
+
    @cute.jit
    def apply_mask(
        self,
        acc_S: cute.Tensor,
        batch_idx: cutlass.Int32,
        head_idx: cutlass.Int32,
        m_block: cutlass.Int32,
        n_block: cutlass.Int32,
        thr_mma: cute.TiledMma,
        mask_seqlen: cutlass.Constexpr[bool],
        mask_causal: cutlass.Constexpr[bool],
        mask_local: cutlass.Constexpr[bool] = False,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
    ) -> None:
        """Overwrite masked-out entries of the score tile ``acc_S`` with -inf.

        Exactly one of three compile-time paths runs, selected by the
        ``Constexpr`` flags (resolved at trace time via ``const_expr``):
          1. seqlen-only masking (no causal/local, no ``mask_mod``),
          2. FlexAttention-style element-wise ``mask_mod`` masking,
          3. causal or local (sliding-window) masking.

        NOTE(review): assumes ``acc_S`` is the MMA accumulator fragment for a
        (tile_m, tile_n) score tile partitioned by ``thr_mma`` — confirm
        against the callers.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.swap_AB)
        acc_shape = (self.tile_m, self.tile_n)
        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
        tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cS), transpose=self.swap_AB)
        # We use t0ScS as these indices are known at compile time. We then must subtract the
        # column limit by the thread column offset.
        t0ScS_mn = layout_utils.reshape_acc_to_mn(
            thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB
        )
        # Which coordinate component is Q (ROW) vs KV (COL) depends on swap_AB.
        ROW = 0 if const_expr(not self.swap_AB) else 1
        COL = 1 if const_expr(not self.swap_AB) else 0
        thr_col_offset = tScS_mn[0][COL]
        # To handle edge cases of completely masked out rows where n_block_max = 0,
        # we treat negative n_blocks as 0th n_block
        # TODO: find more transparent solution
        if n_block < 0:
            n_block = 0
        # Per-thread limit: columns at or beyond this (in t0 coordinates) are past seqlen_k.
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
        if const_expr(not mask_causal and not mask_local and mask_mod is None):
            if const_expr(mask_seqlen):
                # The compiler now choses not to use R2P
                r2p = const_expr(False and not self.swap_AB)
                if const_expr(not r2p):
                    # traverse column index.
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c]
                else:
                    mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90)

        elif const_expr(
            not mask_causal and not mask_local and mask_mod is not None
        ):  # FlexAttention mask mod
            nrow = const_expr(cute.size(tScS_mn.shape[0]))
            ncol = const_expr(cute.size(tScS_mn.shape[1]))
            has_fastdiv = const_expr(
                fastdiv_mods is not None
                and fastdiv_mods[0] is not None
                and fastdiv_mods[1] is not None
            )
            # When fastdiv mods are supplied, indices passed to mask_mod are wrapped
            # (divmod remainder) before the call; seqlen bounds still use the raw indices.
            wrap_aux_indices = const_expr(
                has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
            )

            for r in cutlass.range_constexpr(nrow):
                # Respect swap_AB: ROW/COL determine which coordinate component corresponds to Q/KV.
                local_row = tScS_mn[r, 0][ROW]
                global_row_idx = local_row + m_block * self.tile_m
                row_for_mod = global_row_idx
                head_idx_for_mod = head_idx
                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    # PackGQA: rows interleave query heads; recover (head, row).
                    head_offset = global_row_idx % self.qhead_per_kvhead_packgqa
                    head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
                    row_for_mod = global_row_idx // self.qhead_per_kvhead_packgqa
                row_for_seqlen = row_for_mod
                if const_expr(wrap_aux_indices):
                    _, row_for_mod = divmod(row_for_mod, fastdiv_mods[0])

                for col in cutlass.range_constexpr(ncol):
                    col_idx_local = t0ScS_mn[0, col][COL]
                    # Convert to absolute column index
                    global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n
                    col_for_mod = global_col_idx
                    if const_expr(wrap_aux_indices):
                        _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1])

                    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
                    head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
                    q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32)
                    kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32)
                    mask_value = mask_mod(
                        batch_idx_ssa,
                        head_idx_ssa,
                        q_idx_ssa,
                        kv_idx_ssa,
                        self.seqlen_info,
                        aux_tensors,
                    )
                    cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                    if const_expr(mask_seqlen):
                        # Seqlen bound overrides mask_mod: OOB elements are always -inf.
                        out_of_bounds = (row_for_seqlen >= self.seqlen_q) or (
                            global_col_idx >= self.seqlen_k
                        )
                        if out_of_bounds:
                            acc_S_mn[r, col] = -cutlass.Float32.inf
                        else:
                            acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf
                    else:
                        acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf

        else:  # Causal or local
            if const_expr(not self.swap_AB):
                # If PackGQA, we split the work of compute divmod among threads in the same row
                threads_per_row = thr_mma.tv_layout_C.shape[0][0]
                mma_m_idx = None
                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    assert not self.swap_AB, "swap_AB with PackGQA not supported yet"
                    assert cute.arch.WARP_SIZE % threads_per_row == 0, (
                        "threads_per_row must divide WARP_SIZE"
                    )
                    assert cute.size(acc_S_mn.shape[0]) <= threads_per_row
                    tidx = thr_mma.thr_idx
                    # Each thread computes the divmod for one accumulator row; results are
                    # exchanged below via shuffle_sync.
                    mma_m_idx = (
                        m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0]
                    ) // self.qhead_per_kvhead_packgqa
                causal_row_offset = (
                    1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset
                )
                if const_expr(mask_causal):
                    r2p = const_expr(not self.swap_AB)  # R2P trick, see apply_mask_sm100
                    for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                        # get the column index limit based on current row. Only consider the row index, so the column index sets to 0.
                        if const_expr(self.qhead_per_kvhead_packgqa == 1):
                            row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
                        else:
                            row_idx = utils.shuffle_sync(
                                mma_m_idx, r % threads_per_row, width=threads_per_row
                            )
                        col_limit_right = row_idx + causal_row_offset
                        if const_expr(mask_seqlen):
                            col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                        if const_expr(not r2p):
                            # traverse column index.
                            for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                                acc_S_mn[r, c] = (
                                    -Float32.inf
                                    if t0ScS_mn[0, c][1] >= col_limit_right
                                    else acc_S_mn[r, c]
                                )
                        else:
                            mask_r2p(acc_S_mn[r, None], col_limit_right, arch=90, rank1=True)
                else:  # Local
                    local_row_offset_right = (
                        causal_row_offset + self.window_size_right
                        if const_expr(self.window_size_right is not None)
                        else None
                    )
                    local_row_offset_left = (
                        causal_row_offset - 1 - self.window_size_left
                        if const_expr(self.window_size_left is not None)
                        else None
                    )
                    for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                        if const_expr(self.qhead_per_kvhead_packgqa == 1):
                            row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
                        else:
                            row_idx = utils.shuffle_sync(
                                mma_m_idx, r % threads_per_row, width=threads_per_row
                            )
                        if const_expr(self.window_size_right is not None):
                            col_limit_right = row_idx + local_row_offset_right
                        else:
                            col_limit_right = self.tile_n
                        if const_expr(mask_seqlen):
                            col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                        col_limit_left = (
                            row_idx + local_row_offset_left
                            if const_expr(self.window_size_left is not None)
                            else 0
                        )
                        # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left)
                        # traverse column index.
                        for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                            col_idx = t0ScS_mn[0, c][1]
                            # only consider the column index, so the row index sets to 0.
                            if col_idx >= col_limit_right or col_idx < col_limit_left:
                                acc_S_mn[r, c] = -Float32.inf
            else:  # swap_AB
                assert self.qhead_per_kvhead_packgqa == 1
                thr_row_offset = tScS_mn[0][ROW]
                causal_row_offset = (
                    seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset
                )
                if const_expr(mask_causal):
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        col0 = t0ScS_mn[0, c][COL]
                        # If col0 is beyond the column limit, we want to mask out the entire
                        # column, by setting row limit to be self.tile_m.
                        row_limit_top = (
                            self.tile_m
                            if col0 >= seqlenk_col_limit and mask_seqlen
                            else col0 - causal_row_offset
                        )
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            acc_S_mn[r, c] = (
                                -Float32.inf
                                if t0ScS_mn[r, 0][ROW] < row_limit_top
                                else acc_S_mn[r, c]
                            )
                else:
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        col0 = t0ScS_mn[0, c][COL]
                        # If col0 is beyond the column limit, we want to mask out the entire
                        # column, by setting row limit to be self.tile_m.
                        row_limit_top = (
                            self.tile_m
                            if col0 >= seqlenk_col_limit
                            else col0 - causal_row_offset - self.window_size_right
                        )
                        # TODO: do we need col_limit_sink?
                        row_limit_bot = col0 - causal_row_offset + self.window_size_left
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            row_idx = t0ScS_mn[r, 0][ROW]
                            acc_S_mn[r, c] = (
                                -Float32.inf
                                if row_idx < row_limit_top or row_idx > row_limit_bot
                                else acc_S_mn[r, c]
                            )
|
| 353 |
+
|
| 354 |
+
    @cute.jit
    def apply_mask_sm100(
        self,
        acc_S: cute.Tensor,
        m_block: Int32,
        n_block: Int32,
        thr_mma: cute.TiledMma,
        thr_tmem_load: cute.TiledCopy,
        mask_seqlen: cutlass.Constexpr[bool],
        mask_causal: cutlass.Constexpr[bool],
        mask_local: cutlass.Constexpr[bool] = False,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        batch_idx: Int32 = None,
        head_idx: Int32 = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        head_divmod=None,
        check_q_boundary: bool = False,
    ) -> None:
        """SM100 variant of ``apply_mask``: masks ``acc_S`` fragments partitioned
        through the TMEM-load copy ``thr_tmem_load``.

        Same three compile-time paths as ``apply_mask`` (seqlen-only, mask_mod,
        causal/local); the fragment is rank-1 here, so a single index ``i``
        walks the elements and coordinates come from ``tScS_t2r[i]``.
        ``r2p = True`` routes the seqlen/causal paths through the bit-mask
        (R2P) helpers.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        acc_shape = (self.tile_m, self.tile_n)
        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
        tScS = thr_mma.partition_C(cS)
        tScS = tScS[(None, None), 0, 0]
        tScS_t2r = thr_tmem_load.partition_D(tScS)
        # To handle edge cases of completely masked out rows where n_block_max = 0,
        # we treat negative n_blocks as 0th n_block
        # TODO: find more transparent solution
        if n_block < 0:
            n_block = 0
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n
        r2p = True
        if const_expr(not mask_causal and not mask_local and mask_mod is None):
            if const_expr(mask_seqlen):
                if const_expr(not r2p):
                    for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
                        # if tScS_t2r[i][1] >= seqlenk_col_limit:
                        #     acc_S[i] = -Float32.inf
                        # For some reason the 2 lines above generate really bad SASS
                        acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i]
                else:
                    mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True)

        elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
            # Block sparse case w/ mask_mod
            has_fastdiv = const_expr(
                fastdiv_mods is not None
                and fastdiv_mods[0] is not None
                and fastdiv_mods[1] is not None
            )
            batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)

            ncol = const_expr(cute.size(tScS_t2r.shape))
            for i in cutlass.range_constexpr(ncol):
                # Q coordinate vs KV coordinate depends on swap_AB.
                row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1]
                col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0]
                global_row = row_coord + m_block * self.tile_m
                global_col = col_coord + n_block * self.tile_n

                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    # PackGQA: split the interleaved row into (row, head_offset).
                    assert head_divmod is not None
                    mask_row, head_offset = divmod(global_row, head_divmod)
                    head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
                else:
                    head_idx_for_mod = head_idx
                    mask_row = global_row

                # Indices handed to mask_mod may be wrapped by the fastdiv moduli;
                # seqlen bounds below still use the raw mask_row / global_col.
                mask_row_for_mod = mask_row
                if const_expr(has_fastdiv and aux_tensors is not None):
                    if check_q_boundary:
                        _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])
                global_col_for_mod = global_col
                if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None):
                    _, global_col_for_mod = divmod(global_col, fastdiv_mods[1])

                head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
                mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)
                kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32)
                mask_value = mask_mod(
                    batch_idx_ssa,
                    head_idx_ssa,
                    mask_row_ssa,
                    kv_idx_ssa,
                    self.seqlen_info,
                    aux_tensors,
                )
                cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                acc_S[i] = acc_S[i] if cond else -Float32.inf
                if const_expr(mask_seqlen):
                    acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i]
                    if check_q_boundary:
                        acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]

        else:  # Causal or local
            causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q
            row_idx = tScS_t2r[0][0] + m_block * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa != 1):
                row_idx = row_idx // self.qhead_per_kvhead_packgqa
            if const_expr(mask_causal):
                col_limit_right = row_idx + causal_row_offset
                if const_expr(mask_seqlen):
                    col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                # if cute.arch.thread_idx()[0] % 32 == 0:
                #     cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset)
                ncol = const_expr(cute.size(tScS_t2r.shape))
                if const_expr(not r2p):
                    for i in cutlass.range(ncol, unroll_full=True):
                        acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i]
                else:
                    mask_r2p(acc_S, col_limit_right, arch=100, rank1=True)
            else:
                local_row_offset_right = (
                    causal_row_offset + self.window_size_right
                    if const_expr(self.window_size_right is not None)
                    else None
                )
                local_row_offset_left = (
                    causal_row_offset - 1 - self.window_size_left
                    if const_expr(self.window_size_left is not None)
                    else None
                )
                if const_expr(self.window_size_right is not None):
                    col_limit_right = row_idx + local_row_offset_right
                else:
                    col_limit_right = self.tile_n
                if const_expr(mask_seqlen):
                    col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                col_limit_left = (
                    row_idx + local_row_offset_left
                    if const_expr(self.window_size_left is not None)
                    else 0
                )
                if const_expr(not r2p):
                    # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left)
                    for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
                        col_idx = tScS_t2r[i][1]
                        acc_S[i] = (
                            -Float32.inf
                            if col_idx >= col_limit_right or col_idx < col_limit_left
                            else acc_S[i]
                        )
                else:
                    # XOR-based R2P dual bound masking
                    mask_r2p_dual_bound(acc_S, col_limit_left, col_limit_right)
|
| 498 |
+
|
| 499 |
+
    @cute.jit
    def apply_mask_sm100_transposed(
        self,
        acc_S: cute.Tensor,
        tScS_t2r: cute.Tensor,
        t0ScS_t2r: cute.Tensor,
        m_block: cutlass.Int32,
        n_block: cutlass.Int32,
        mask_seqlen: cutlass.Constexpr,
        mask_causal: cutlass.Constexpr,
        mask_local: cutlass.Constexpr,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        batch_idx: Int32 = None,
        head_idx: Int32 = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        is_full_block: bool = False,
        check_m_boundary: bool = True,
    ) -> None:
        """
        Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q.

        Coordinate convention:
        - ROW corresponds to Q (m_block)
        - COL corresponds to KV (n_block)

        is_full_block: If True, skip mask_mod (all elements valid). Only apply seqlen masking.
        check_m_boundary: If False, skip seqlen_q boundary check (optimization for non-boundary m_blocks).
            When iterating m_blocks in forward order, only the last m_block may be partial.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        ROW = 0 if const_expr(not self.swap_AB) else 1
        COL = 1 if const_expr(not self.swap_AB) else 0
        # assert t0ScS_t2r[0][COL] == 0, "col0 == 0" # tmp comment for 2-cta bwd
        thr_col_offset = tScS_t2r[0][COL]
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset

        if const_expr(not mask_causal and not mask_local and mask_mod is not None):
            # Block sparse case with mask_mod (backward)
            #
            # Coordinate convention: ROW → Q (m_block), COL → KV (n_block).
            # These already account for swap_AB.
            #
            # FULL blocks: mask_mod returns True for all elements, so skip it.
            # Still need seqlen bounds check (elements may be OOB on last m_block).
            # PARTIAL blocks: apply mask_mod element-wise, then seqlen bounds.
            if is_full_block:
                if const_expr(mask_seqlen):
                    if seqlenk_col_limit <= 0:
                        # Entire tile is OOB for K
                        for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                            acc_S[i] = -cutlass.Float32.inf
                    elif check_m_boundary:
                        # Last m_block: check Q and K boundaries
                        ncol = const_expr(cute.size(tScS_t2r.shape))
                        for i in cutlass.range_constexpr(ncol):
                            row_coord = tScS_t2r[i][ROW]
                            col_coord = tScS_t2r[i][COL]
                            global_q = row_coord + m_block * self.tile_m
                            global_kv = col_coord + n_block * self.tile_n
                            q_out_of_bounds = global_q >= self.seqlen_q
                            kv_out_of_bounds = global_kv >= self.seqlen_k
                            out_of_bounds = q_out_of_bounds or kv_out_of_bounds
                            acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]
            else:
                # Partial block
                has_fastdiv = const_expr(
                    fastdiv_mods is not None
                    and fastdiv_mods[0] is not None
                    and fastdiv_mods[1] is not None
                )
                # Wrap the indices fed to mask_mod with the fastdiv moduli; the seqlen
                # bounds below still use the raw global_q / global_kv.
                wrap_aux_indices = const_expr(
                    has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
                )
                batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
                head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32)

                ncol = const_expr(cute.size(tScS_t2r.shape))
                for i in cutlass.range_constexpr(ncol):
                    row_coord = tScS_t2r[i][ROW]
                    col_coord = tScS_t2r[i][COL]
                    global_q = row_coord + m_block * self.tile_m
                    global_kv = col_coord + n_block * self.tile_n

                    q_idx_for_mod = global_q
                    kv_idx_for_mod = global_kv
                    if const_expr(wrap_aux_indices):
                        _, q_idx_for_mod = divmod(global_q, fastdiv_mods[0])
                        _, kv_idx_for_mod = divmod(global_kv, fastdiv_mods[1])

                    q_idx_ssa = utils.scalar_to_ssa(q_idx_for_mod, cutlass.Int32)
                    kv_idx_ssa = utils.scalar_to_ssa(kv_idx_for_mod, cutlass.Int32)

                    mask_value = mask_mod(
                        batch_idx_ssa,
                        head_idx_ssa,
                        q_idx_ssa,
                        kv_idx_ssa,
                        self.seqlen_info,
                        aux_tensors,
                    )
                    cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                    acc_S[i] = acc_S[i] if cond else -cutlass.Float32.inf

                    if const_expr(mask_seqlen):
                        # check_m_boundary=False skips q check for non-boundary m_blocks
                        q_out_of_bounds = check_m_boundary and (global_q >= self.seqlen_q)
                        kv_out_of_bounds = global_kv >= self.seqlen_k
                        out_of_bounds = q_out_of_bounds or kv_out_of_bounds
                        acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]

        elif const_expr(not mask_causal and not mask_local):
            if const_expr(mask_seqlen):
                if seqlenk_col_limit <= 0:
                    # Whole tile past seqlen_k: mask everything.
                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                        acc_S[i] = -cutlass.Float32.inf
        else:  # Causal or local
            thr_row_offset = tScS_t2r[0][ROW]
            seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset
            causal_offset = seqlenq_row_limit - seqlenk_col_limit
            if const_expr(mask_causal):
                # tidx = cute.arch.thread_idx()[0] % 256
                # if tidx < 32:
                #     cute.printf("tidx = {}, {} {}, {} {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1])
                row_limit_top = causal_offset
                if const_expr(mask_seqlen):
                    # If col is beyond the column limit, we want to mask out the entire
                    # column, by setting row limit to be self.tile_m.
                    if seqlenk_col_limit <= 0:
                        row_limit_top = self.tile_m
                r2p = True
                if const_expr(not r2p):
                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                        acc_S[i] = (
                            -cutlass.Float32.inf if t0ScS_t2r[i][ROW] < row_limit_top else acc_S[i]
                        )
                else:
                    num_rep = cute.size(tScS_t2r, mode=[0])  # 16 or 32
                    mask_r2p_transposed(acc_S, row_limit_top, num_rep)
            else:
                if const_expr(self.window_size_right is not None):
                    row_limit_top = causal_offset - self.window_size_right
                else:
                    row_limit_top = 0
                if const_expr(self.window_size_left is not None):
                    row_limit_bot = causal_offset + self.window_size_left
                if const_expr(mask_seqlen):
                    if seqlenk_col_limit <= 0:
                        row_limit_top = self.tile_m
                for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                    row_idx = t0ScS_t2r[i][ROW]
                    local_mask = row_idx < row_limit_top
                    if const_expr(self.window_size_left is not None):
                        local_mask |= row_idx > row_limit_bot
                    acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i]
build/torch-cuda/metadata.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": 0,
|
| 3 |
+
"python-depends": [
|
| 4 |
+
"einops",
|
| 5 |
+
"tvm-ffi",
|
| 6 |
+
"nvidia-cutlass-dsl"
|
| 7 |
+
]
|
| 8 |
+
}
|
build/torch-cuda/mma_sm100_desc.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
# Ported Cutlass code from C++ to Python:
|
| 3 |
+
# https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm100_desc.hpp
|
| 4 |
+
# https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm100.hpp
|
| 5 |
+
|
| 6 |
+
from enum import IntEnum
|
| 7 |
+
|
| 8 |
+
import cutlass
|
| 9 |
+
import cutlass.cute as cute
|
| 10 |
+
|
| 11 |
+
# ---------------------------------------------------------------------------
|
| 12 |
+
# Enumerations that match the HW encodings (values MUST stay identical)
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Major(IntEnum):  # matrix “layout” in the ISA docs
    """Operand layout field: K-major (0) vs MN-major (1); values match the HW encoding."""

    K = 0
    MN = 1
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ScaleIn(IntEnum):  # negate flags
    """Input-negation flag for an MMA operand: One (no negation) or Neg."""

    One = 0
    Neg = 1
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Saturate(IntEnum):
    """Accumulator saturation flag (trailing underscores avoid the False/True keywords)."""

    False_ = 0
    True_ = 1
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class CFormat(IntEnum):  # 2-bit field (bits 4-5)
    """Accumulator (matrix C) element format, packed into descriptor bits 4-5."""

    F16 = 0
    F32 = 1
    S32 = 2
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class F16F32Format(IntEnum):  # 3-bit field (A/B element type)
    """A/B element format for the f16/bf16/tf32 MMA family."""

    F16 = 0
    BF16 = 1
    TF32 = 2
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class S8Format(IntEnum):
    """A/B element format for the 8-bit integer MMA family."""

    UINT8 = 0
    INT8 = 1
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class MXF8F6F4Format(IntEnum):
    """A/B element format for the FP8/FP6/FP4 (MXF8F6F4) MMA family.

    NOTE(review): value 2 is skipped — presumably reserved in the HW encoding;
    confirm against cute/arch/mma_sm100_desc.hpp.
    """

    E4M3 = 0
    E5M2 = 1
    E2M3 = 3
    E3M2 = 4
    E2M1 = 5
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class MaxShift(IntEnum):
    """Maximum-shift field: no shift, or a max shift of 8/16/32."""

    NoShift = 0
    MaxShift8 = 1
    MaxShift16 = 2
    MaxShift32 = 3
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
# CUTLASS-type → encoding helpers
|
| 65 |
+
# ---------------------------------------------------------------------------
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def to_UMMA_format(cutlass_type) -> int:
    """
    Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B.

    Comparison is by class identity (``is``), checked in order. Attribute
    lookups on ``cutlass`` happen lazily, so members that may be absent in
    some CUTLASS builds (e.g. ``Uint8``) are only touched after earlier
    checks fail — do not reorder or table-ify these checks.

    Raises:
        TypeError: if ``cutlass_type`` is not a supported A/B scalar class.
    """
    if cutlass_type is cutlass.Int8:
        return S8Format.INT8
    # Unsigned 8-bit (if available in your CUTLASS build)
    if cutlass_type is cutlass.Uint8:
        return S8Format.UINT8
    # FP-16 / BF-16
    if cutlass_type is cutlass.Float16:
        return F16F32Format.F16
    if cutlass_type is cutlass.BFloat16:
        return F16F32Format.BF16
    # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits)
    if cutlass_type is cutlass.TFloat32:
        return F16F32Format.TF32
    # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them
    if cutlass_type is cutlass.FloatE4M3FN:
        return MXF8F6F4Format.E4M3
    if cutlass_type is cutlass.FloatE5M2:
        return MXF8F6F4Format.E5M2
    raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def to_C_format(cutlass_type) -> int:
    """
    Map a CUTLASS scalar class to the 2-bit accumulator encoding.

    Matching is by class identity (`is`), mirroring the A/B mapping above.

    Raises:
        TypeError: if ``cutlass_type`` is not a supported accumulator class.
    """
    supported = (
        (cutlass.Float16, CFormat.F16),
        (cutlass.Float32, CFormat.F32),
        (cutlass.Int32, CFormat.S32),
    )
    for scalar_cls, encoding in supported:
        if cutlass_type is scalar_cls:
            return encoding
    raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ---------------------------------------------------------------------------
|
| 107 |
+
# The constructor – accepts only CUTLASS scalar classes
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def make_instr_desc(
    a_type,  # CUTLASS scalar class, e.g. cutlass.Int8
    b_type,
    c_type,
    M: int,  # 64, 128 or 256
    N: int,  # 8 … 256 (multiple of 8)
    a_major: Major,
    b_major: Major,
    a_neg: ScaleIn = ScaleIn.One,
    b_neg: ScaleIn = ScaleIn.One,
    c_sat: Saturate = Saturate.False_,
    is_sparse: bool = False,
    max_shift: MaxShift = MaxShift.NoShift,
) -> int:
    """
    Assemble the 32-bit instruction descriptor for Blackwell MMA.

    All matrix/accumulator types must be CUTLASS scalar classes — passing
    raw integers is forbidden.  M/N are the tile dimensions; they are
    encoded into the descriptor's m_dim/n_dim fields.
    """
    # Element-format encodings (TypeError for unsupported scalar classes).
    a_fmt = int(to_UMMA_format(a_type))
    b_fmt = int(to_UMMA_format(b_type))
    c_fmt = int(to_C_format(c_type))

    # Tile-shape validation.
    if M not in (64, 128, 256):
        raise ValueError("M must be 64, 128 or 256")
    if N < 8 or N > 256 or (N & 7):
        raise ValueError("N must be a multiple of 8 in the range 8…256")

    # One (value, mask, shift) entry per bit-field of the descriptor.
    bit_fields = (
        (0, 0x3, 0),                # sparse_id2 (always 0 here)
        (int(is_sparse), 0x1, 2),   # sparse_flag
        (int(c_sat), 0x1, 3),       # saturate
        (c_fmt, 0x3, 4),            # c_format
        (a_fmt, 0x7, 7),            # a_format
        (b_fmt, 0x7, 10),           # b_format
        (int(a_neg), 0x1, 13),      # a_negate
        (int(b_neg), 0x1, 14),      # b_negate
        (int(a_major), 0x1, 15),    # a_major
        (int(b_major), 0x1, 16),    # b_major
        (N >> 3, 0x3F, 17),         # n_dim (6 bits)
        (M >> 4, 0x1F, 24),         # m_dim (5 bits)
        (int(max_shift), 0x3, 30),  # max_shift (2 bits)
    )
    desc = 0
    for value, mask, shift in bit_fields:
        desc |= (value & mask) << shift

    return desc & 0xFFFF_FFFF  # ensure 32-bit result
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp):
    """Build the 32-bit MMA instruction descriptor from a tcgen05 MmaOp."""

    def _major_of(mode) -> Major:
        # OperandMajorMode.K maps to Major.K; anything else is MN-major.
        return Major.K if mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN

    tile_m, tile_n = op.shape_mnk[0], op.shape_mnk[1]
    return make_instr_desc(
        op.a_dtype,
        op.b_dtype,
        op.acc_dtype,
        tile_m,
        tile_n,
        _major_of(op.a_major_mode),
        _major_of(op.b_major_mode),
    )
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class LayoutType(IntEnum):  # occupies the top-3 bits [61:64) of the smem descriptor
    """UMMA shared-memory swizzle family, encoded in smem-descriptor bits [61:64)."""

    SWIZZLE_NONE = 0  # no swizzle (a.k.a. “INTERLEAVE” in older docs)
    SWIZZLE_128B_BASE32B = 1  # 128-byte swizzle, 32-byte base unit
    SWIZZLE_128B = 2  # 128-byte swizzle
    SWIZZLE_64B = 4  # 64-byte swizzle
    SWIZZLE_32B = 6  # 32-byte swizzle
    # values 3, 5, 7 are reserved / illegal for UMMA
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# ---------------------------------------------------------------------------
|
| 187 |
+
# Helpers – figure out the SWIZZLE_* family from the tensor layout
|
| 188 |
+
# ---------------------------------------------------------------------------
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _layout_type(swizzle: cute.Swizzle) -> LayoutType:
    """
    Resolve the UMMA SWIZZLE_* family from a swizzle triple ``Swizzle<B, M, S>``.

    Reads ``num_bits`` (B), ``num_base`` (M) and ``num_shift`` (S) from the
    swizzle and returns the matching :class:`LayoutType`.

    Raises:
        ValueError: for any (B, M, S) triple that is not a UMMA-legal
            shared-memory layout.
    """
    B, M, S = swizzle.num_bits, swizzle.num_base, swizzle.num_shift

    if M == 4:  # Swizzle<*,4,3>
        if S != 3:
            raise ValueError("Unexpected swizzle shift – want S==3 for M==4")
        try:
            return {
                0: LayoutType.SWIZZLE_NONE,
                1: LayoutType.SWIZZLE_32B,
                2: LayoutType.SWIZZLE_64B,
                3: LayoutType.SWIZZLE_128B,
            }[B]
        except KeyError:
            # Fix: an out-of-range B previously escaped as a bare KeyError;
            # raise the same ValueError as every other illegal triple so
            # callers see a consistent error contract.
            raise ValueError("Unsupported swizzle triple for UMMA smem descriptor") from None
    if M == 5:  # Swizzle<2,5,2> (the only legal triple for M==5)
        if (B, S) != (2, 2):
            raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B")
        return LayoutType.SWIZZLE_128B_BASE32B

    # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout
    raise ValueError("Unsupported swizzle triple for UMMA smem descriptor")
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int:
    """
    Convert a 2-D *shared-memory* Cute layout into the Blackwell 64-bit
    smem-descriptor, without the smem start address (bits [0:14) stay zero;
    see make_smem_desc_start_addr).
    layout must correspond to layout of an uint128 tensor, i.e. all stride
    fields below are in units of 16 bytes.
    """
    # ------------------------------------------------------------------ meta
    layout_type = _layout_type(swizzle)  # resolve SWIZZLE_* family (raises on illegal triples)

    VERSION = 1  # bits 46–47
    LBO_MODE = 0  # bit 52
    BASE_OFFSET = 0  # bits 49–51 (CUTLASS always 0)

    # ---------------------------------------------------------- strides (units: uint128_t = 16 B)
    # MN extent of one swizzle atom, in 16-byte elements.
    swizzle_atom_mn_size = {
        LayoutType.SWIZZLE_NONE: 1,
        LayoutType.SWIZZLE_32B: 2,
        LayoutType.SWIZZLE_64B: 4,
        LayoutType.SWIZZLE_128B: 8,
        LayoutType.SWIZZLE_128B_BASE32B: 8,
    }[layout_type]

    if major is Major.MN:
        # Split the layout by the swizzle-atom tile and check it factors into
        # the canonical ((atom_mn, rest_mn), (atom_k, rest_k)) profile.
        swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8
        canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size))
        if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):
            raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.")
        # Inside a swizzled atom the MN stride must be contiguous.
        stride_00 = canonical_layout.stride[0][0]
        if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1:
            raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.")
        # The atom's K stride must step exactly one atom's worth of MN elements.
        stride_10 = canonical_layout.stride[1][0]
        if stride_10 != swizzle_atom_mn_size:
            raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.")
        # Which residual stride is "leading" vs "stride" flips with swizzling.
        stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1]
        if layout_type is LayoutType.SWIZZLE_NONE:
            stride_byte_offset, leading_byte_offset = stride_01, stride_11
        else:
            stride_byte_offset, leading_byte_offset = stride_11, stride_01
    else:
        # Major-K path: atoms are fixed 8x2 (in 16-byte elements).
        if layout_type == LayoutType.SWIZZLE_128B_BASE32B:
            raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K")
        if not cute.size(layout.shape[0]) % 8 == 0:
            raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.")
        canonical_layout = cute.logical_divide(layout, (8, 2))
        if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):
            raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.")
        stride_00 = canonical_layout.stride[0][0]
        if stride_00 != swizzle_atom_mn_size:
            raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.")
        stride_10 = canonical_layout.stride[1][0]
        if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1:
            raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.")
        stride_01 = canonical_layout.stride[0][1]
        stride_byte_offset, leading_byte_offset = stride_01, stride_10

    # ------------------------------------------------------------------ pack
    desc = 0
    # leading_byte_offset_ [16:30)
    desc |= (leading_byte_offset & 0x3FFF) << 16
    # stride_byte_offset_ [32:46)
    desc |= (stride_byte_offset & 0x3FFF) << 32
    # version_ [46:48)
    desc |= (VERSION & 0x3) << 46
    # base_offset_ [49:52)
    desc |= (BASE_OFFSET & 0x7) << 49
    # lbo_mode_ [52:53)
    desc |= (LBO_MODE & 0x1) << 52
    # layout_type_ [61:64)
    desc |= (int(layout_type) & 0x7) << 61

    return desc & 0xFFFF_FFFF_FFFF_FFFF  # force 64-bit width
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32:
    """Encode a shared-memory start address into the descriptor's 14-bit field."""
    # Keep the 18 low address bits, then drop the 4 LSBs (16-byte granularity)
    # to yield the 14-bit value stored in descriptor bits [0:14).
    return (start_addr.toint() & 0x3FFFF) >> 4
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int:
    """Build the address-less smem descriptor for smem tensor sA's first mode.

    The layout is recast from sA's element width to 128-bit (uint128) units,
    as required by make_smem_desc_base.
    """
    sA_swizzle = sA.iterator.type.swizzle_type
    return make_smem_desc_base(
        cute.recast_layout(128, sA.element_type.width, sA.layout[0]),
        sA_swizzle,
        major,
    )
|
build/torch-cuda/named_barrier.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import enum
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class NamedBarrierFwd(enum.IntEnum):
    """Named-barrier IDs for the forward kernel.

    IDs start at 1 because barrier 0 is reserved for sync_threads().
    """

    Epilogue = 1
    WarpSchedulerWG1 = 2
    WarpSchedulerWG2 = 3
    WarpSchedulerWG3 = 4
    PFull = 5
    PEmpty = 6
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class NamedBarrierBwd(enum.IntEnum):
    """Named-barrier IDs for the backward kernel (barrier 0 stays reserved)."""

    Epilogue = 1
    WarpSchedulerWG1 = 2
    WarpSchedulerWG2 = 3
    WarpSchedulerWG3 = 4
    PdS = 5
    dQFullWG0 = 6
    dQFullWG1 = 7
    dQEmptyWG0 = 8
    dQEmptyWG1 = 9
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class NamedBarrierBwdSm100(enum.IntEnum):
    """Named-barrier IDs for the SM100 backward kernel (barrier 0 stays reserved)."""

    EpilogueWG1 = 1
    EpilogueWG2 = 2
    Compute = 3
    dQaccReduce = 4
    TmemPtr = 5
|
build/torch-cuda/pack_gqa.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
|
| 7 |
+
from .quack import layout_utils
|
| 8 |
+
from . import utils
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PackGQA:
    """
    Helpers for packed-GQA addressing: the query-head dimension is folded into
    the row dimension so a single m-block covers qhead_per_kvhead query heads
    per sequence position.  Per-row gmem pointers are computed once per thread
    group and broadcast across the row's lanes with shuffles.
    """

    def __init__(
        self,
        m_block_size: cutlass.Constexpr[int],
        head_dim_padded: cutlass.Constexpr[int],
        check_hdim_oob: cutlass.Constexpr[bool],
        qhead_per_kvhead: cutlass.Constexpr[bool],
    ):
        # Packed rows handled per block (seq positions * q-heads).
        self.m_block_size = m_block_size
        # Padded head dimension used for identity tensors / smem tiles.
        self.head_dim_padded = head_dim_padded
        # Whether copies predicate the head-dim axis against the true headdim.
        self.check_hdim_oob = check_hdim_oob
        # Query heads per KV head.
        # NOTE(review): annotated Constexpr[bool] but used as an integer
        # divisor/multiplier throughout — presumably should be Constexpr[int];
        # confirm with callers.
        self.qhead_per_kvhead = qhead_per_kvhead

    @cute.jit
    def compute_ptr(
        self,
        tensor: cute.Tensor,
        cRows: cute.Tensor,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        threads_per_row: cutlass.Constexpr[int],
        num_threads: cutlass.Constexpr[int],
    ):
        """Compute per-row gmem pointers (as Int64) for this lane's rows.

        Each lane holds pointers for the rows at its position within a
        threads_per_row group; callers redistribute them via shuffle_sync.
        """
        num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row)
        tPrPtr = cute.make_fragment(num_ptr_per_thread, cutlass.Int64)
        for i in cutlass.range_constexpr(num_ptr_per_thread):
            row = i * num_threads + cRows[tidx % threads_per_row][0]
            # Unpack the packed row index into (head, sequence-position).
            idx = block * self.m_block_size + row
            m_idx = idx // self.qhead_per_kvhead
            h_idx = idx - m_idx * self.qhead_per_kvhead
            tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint()
        return tPrPtr

    @cute.jit
    def load_Q(
        self,
        mQ: cute.Tensor,  # ((qhead_per_kvhead, seqlen_q), headdim)
        sQ: cute.Tensor,  # (m_block_size, head_dim_padded)
        gmem_tiled_copy: cute.TiledCopy,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Copy one packed Q block from gmem into smem, row by row.

        Rows beyond the packed end (seqlen * qhead_per_kvhead) are skipped;
        the head-dim axis is predicated only when check_hdim_oob is set.
        """
        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
        cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        tQsQ = gmem_thr_copy.partition_D(sQ)
        tQcQ = gmem_thr_copy.partition_S(cQ)
        # Thread-0 slice gives block-relative coordinates for the bounds test.
        t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ)
        tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1])
        tQcQ_row = tQcQ[0, None, 0]
        threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
        num_threads = gmem_tiled_copy.size
        tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads)
        for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
            # Fetch the row's base pointer from the lane that computed it.
            q_ptr_i64 = utils.shuffle_sync(
                tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row
            )
            q_gmem_ptr = cute.make_ptr(
                mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
            )
            # Row is in-bounds w.r.t. the packed sequence length.
            if (
                t0QcQ[0, m, 0][0]
                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0]
            ):
                mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,))
                elems_per_load = cute.size(tQsQ.shape[0][0])
                mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,))
                for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])):
                    # Index of this thread's vector within the row.
                    ki = tQcQ[0, 0, k][1] // elems_per_load
                    cute.copy(
                        gmem_thr_copy,
                        mQ_cur_copy[None, ki],
                        tQsQ[None, m, k],
                        pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
                    )
        # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs

    @cute.jit
    def store_LSE(
        self,
        mLSE: cute.Tensor,  # (qhead_per_kvhead, seqlen_q)
        tLSErLSE: cute.Tensor,  # per-row LSE fragment — NOTE(review): original comment said (m_block_size, head_dim_padded); size checks below suggest one value per accumulator row; confirm
        tiled_mma: cute.TiledMma,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Write per-row log-sum-exp values back to gmem (column-0 lanes only)."""
        thr_mma = tiled_mma.get_slice(tidx)
        caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        taccOcO = thr_mma.partition_C(caccO)
        taccOcO_row = layout_utils.reshape_acc_to_mn(taccOcO)[None, 0]
        assert cute.size(tLSErLSE) == cute.size(taccOcO_row)
        threads_per_row = tiled_mma.tv_layout_C.shape[0][0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
        assert cute.size(tLSErLSE) <= threads_per_row
        num_threads = tiled_mma.size
        tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads)
        for m in cutlass.range_constexpr(cute.size(tLSErLSE)):
            lse_ptr_i64 = utils.shuffle_sync(
                tPrLSEPtr[m // threads_per_row],
                m % threads_per_row,
                width=threads_per_row,
            )
            lse_gmem_ptr = cute.make_ptr(
                mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4
            )
            row = block * self.m_block_size + taccOcO_row[m][0]
            # Only the thread corresponding to column 0 writes out the lse to gmem
            if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead:
                mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,))
                mLSE_copy[0] = tLSErLSE[m]

    @cute.jit
    def store_O(
        self,
        mO: cute.Tensor,  # ((qhead_per_kvhead, seqlen_q), headdim)
        tOrO: cute.Tensor,  # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy
        gmem_tiled_copy: cute.TiledCopy,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Write the output block from registers to gmem — mirror of load_Q
        with the copy direction reversed (registers → gmem)."""
        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
        cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        tOcO = gmem_thr_copy.partition_S(cO)
        # Thread-0 slice gives block-relative coordinates for the bounds test.
        t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO)
        tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
        tOcO_row = tOcO[0, None, 0]
        threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
        num_threads = gmem_tiled_copy.size
        tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads)
        for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
            o_ptr_i64 = utils.shuffle_sync(
                tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row
            )
            o_gmem_ptr = cute.make_ptr(
                mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
            )
            # Row is in-bounds w.r.t. the packed sequence length.
            if (
                t0OcO[0, m, 0][0]
                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0]
            ):
                mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,))
                elems_per_load = cute.size(tOrO.shape[0][0])
                mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,))
                for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])):
                    ki = tOcO[0, 0, k][1] // elems_per_load
                    cute.copy(
                        gmem_thr_copy,
                        tOrO[None, m, k],
                        mO_cur_copy[None, ki],
                        pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
                    )
|
build/torch-cuda/paged_kv.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Type
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
from cutlass.cute.nvgpu import cpasync
|
| 7 |
+
from cutlass import Int32, const_expr
|
| 8 |
+
|
| 9 |
+
from . import utils
|
| 10 |
+
from .quack.cute_dsl_utils import ParamsBase
|
| 11 |
+
from cutlass.cute import FastDivmodDivisor
|
| 12 |
+
|
| 13 |
+
import math
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class PagedKVManager(ParamsBase):
    """
    Loads K/V tiles from a paged KV cache into shared memory.

    Workflow per n-block: load_page_table reads the page ids this thread
    needs, compute_X_ptr turns them into per-row gmem pointers, and load_KV
    issues the predicated cp.async copies into smem.
    """

    # Page table for the current batch entry (sliced in create()).
    mPageTable: cute.Tensor
    # Paged K storage, sliced to the current KV head.
    mK_paged: cute.Tensor
    # Paged V storage, sliced to the current KV head.
    # NOTE(review): indexed as (0, page_offset, page) in compute_X_ptr while K
    # uses (page_offset, 0, page) — V appears to be stored transposed; confirm.
    mV_paged: cute.Tensor
    thread_idx: Int32

    page_size_divmod: FastDivmodDivisor  # fast divmod by the KV-cache page size
    seqlen_k: Int32
    leftpad_k: Int32
    n_block_size: Int32
    num_threads: cutlass.Constexpr[Int32]
    head_dim_padded: cutlass.Constexpr[Int32]
    head_dim_v_padded: cutlass.Constexpr[Int32]

    gmem_threads_per_row: cutlass.Constexpr[Int32]
    page_entry_per_thread: Int32
    async_copy_elems: Int32  # elements per 128-bit cp.async transfer

    gmem_tiled_copy_KV: cute.TiledCopy
    gmem_thr_copy_KV: cute.TiledCopy
    tPrPage: cute.Tensor  # per-thread page ids for the current n-block
    tPrPageOffset: cute.Tensor  # per-thread row offsets within those pages
    tKpK: cute.Tensor  # head-dim predicate for K copies
    tVpV: cute.Tensor  # head-dim predicate for V copies

    @staticmethod
    def create(
        mPageTable: cute.Tensor,
        mK_paged: cute.Tensor,
        mV_paged: cute.Tensor,
        page_size_divmod: FastDivmodDivisor,
        bidb: Int32,
        bidh: Int32,
        thread_idx: Int32,
        seqlen_k: Int32,
        leftpad_k: Int32,
        n_block_size: cutlass.Constexpr[Int32],
        head_dim_padded: cutlass.Constexpr[Int32],
        head_dim_v_padded: cutlass.Constexpr[Int32],
        num_threads: cutlass.Constexpr[Int32],
        dtype: Type[cutlass.Numeric],
    ):
        """Build a PagedKVManager: derive the cp.async tiled copy, allocate the
        per-thread page-table fragments, and precompute head-dim predicates."""
        # 128-bit vectorized cp.async transfers.
        universal_copy_bits = 128
        async_copy_elems = universal_copy_bits // dtype.width
        dtype_bytes = dtype.width // 8
        # Largest contiguous head-dim chunk that evenly tiles both K and V.
        gmem_k_block_size = math.gcd(
            head_dim_padded,
            head_dim_v_padded,
            128 // dtype_bytes,
        )
        assert gmem_k_block_size % async_copy_elems == 0
        gmem_threads_per_row = gmem_k_block_size // async_copy_elems
        assert cute.arch.WARP_SIZE % gmem_threads_per_row == 0
        atom_async_copy = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            dtype,
            num_bits_per_copy=universal_copy_bits,
        )
        thr_layout = cute.make_ordered_layout(
            (num_threads // gmem_threads_per_row, gmem_threads_per_row),
            order=(1, 0),
        )
        val_layout = cute.make_layout((1, async_copy_elems))
        gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout)
        gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx)
        page_entry_per_thread = n_block_size // num_threads

        # Per-thread fragments holding (page id, offset-in-page) pairs.
        tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
        tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32)

        # Slice down to the current batch entry / KV head.
        mPageTable = mPageTable[bidb, None]
        mK_paged = mK_paged[None, None, bidh, None]
        mV_paged = mV_paged[None, None, bidh, None]

        # Head-dim out-of-bounds predicates (row bounds are handled per copy).
        cK = cute.make_identity_tensor((n_block_size, head_dim_padded))
        tKcK = gmem_thr_copy_KV.partition_S(cK)
        tKpK = utils.predicate_k(tKcK, limit=mK_paged.shape[1])

        if const_expr(head_dim_padded == head_dim_v_padded):
            tVpV = tKpK
        else:
            cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded))
            tVcV = gmem_thr_copy_KV.partition_S(cV)
            # V's head-dim limit comes from shape[0] (see the transpose note
            # on mV_paged above).
            tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0])

        return PagedKVManager(
            mPageTable,
            mK_paged,
            mV_paged,
            thread_idx,
            page_size_divmod,
            seqlen_k,
            leftpad_k,
            n_block_size,
            num_threads,
            head_dim_padded,
            head_dim_v_padded,
            gmem_threads_per_row,
            page_entry_per_thread,
            async_copy_elems,
            gmem_tiled_copy_KV,
            gmem_thr_copy_KV,
            tPrPage,
            tPrPageOffset,
            tKpK,
            tVpV,
        )

    @cute.jit
    def load_page_table(self, n_block: Int32):
        """Read (page id, page offset) for each row this thread owns in the
        given n-block, storing the pair into tPrPage / tPrPageOffset."""
        for i in cutlass.range(self.page_entry_per_thread, unroll=1):
            # Row assignment matching the tiled-copy thread layout.
            row = (
                i * self.num_threads
                + (self.thread_idx % self.gmem_threads_per_row)
                * (self.num_threads // self.gmem_threads_per_row)
                + (self.thread_idx // self.gmem_threads_per_row)
            )
            row_idx = n_block * self.n_block_size + row

            # Map the (left-padded) sequence position to page id + offset.
            page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod)

            # Guard both the block-local row and the global sequence length
            # before touching the page table; invalid rows read page 0.
            is_valid = (
                (i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size
            ) and row_idx < self.seqlen_k
            page = self.mPageTable[page_idx] if is_valid else 0

            self.tPrPage[i] = page
            self.tPrPageOffset[i] = page_offset

    @cute.jit
    def compute_X_ptr(self, K_or_V: str):
        """Translate the cached (page, offset) pairs into per-row Int64 gmem
        pointers into K or V storage."""
        tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64)
        for i in cutlass.range(self.page_entry_per_thread, unroll=1):
            page = self.tPrPage[i]
            page_offset = self.tPrPageOffset[i]
            if const_expr(K_or_V == "K"):
                tPrXPtr[i] = utils.elem_pointer(self.mK_paged, (page_offset, 0, page)).toint()
            else:
                # V indexes the page offset on the second mode (see class note).
                tPrXPtr[i] = utils.elem_pointer(self.mV_paged, (0, page_offset, page)).toint()
        return tPrXPtr

    @cute.jit
    def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str):
        """Issue predicated cp.async copies of one K or V n-block into smem sX.

        Rows past seqlen_k are masked via a False predicate (the copy still
        runs, it just writes nothing); the head-dim predicates were
        precomputed in create().
        """
        assert K_or_V in ("K", "V")

        tPrXPtr = self.compute_X_ptr(K_or_V)

        # Finesse sX layout to be (M, N).
        sX_pi = cute.make_tensor(
            sX.iterator,
            cute.make_layout(
                (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])),
                stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])),
            ),
        )

        if const_expr(K_or_V == "V"):
            # Need to transpose V
            sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0]))

        head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded
        cX = cute.make_identity_tensor((self.n_block_size, head_dim))
        tXsX = self.gmem_thr_copy_KV.partition_D(sX_pi)
        tXcX = self.gmem_thr_copy_KV.partition_S(cX)
        # Thread-0 slice gives block-relative coordinates for the bounds test.
        tXc0X = self.gmem_thr_copy_KV.get_slice(0).partition_S(cX)

        # Remaining valid rows in this n-block (0 disables all copies).
        seqlenk_row_limit = (
            self.seqlen_k - n_block * self.n_block_size - tXcX[0][0] if n_block >= 0 else 0
        )
        for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])):
            row_valid = tXc0X[0, m, 0][0] < seqlenk_row_limit
            should_load = cute.make_fragment_like(tXsX[(0, None), m, 0], cute.Boolean)
            should_load.fill(row_valid)

            # Fetch the row pointer from the lane that computed it.
            x_ptr_i64 = utils.shuffle_sync(
                tPrXPtr[m // self.gmem_threads_per_row],
                m % self.gmem_threads_per_row,
                width=self.gmem_threads_per_row,
            )
            x_gmem_ptr = cute.make_ptr(
                self.mK_paged.element_type, x_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
            )
            mX_paged_cur = cute.make_tensor(x_gmem_ptr, cute.make_layout((head_dim,)))
            mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,))

            for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])):
                # Index of this thread's vector within the row.
                ki = tXcX[0, 0, k][1] // self.async_copy_elems
                mX_paged_cur_copy_ki = mX_paged_cur_copy[None, ki]
                tXsX_k = tXsX[None, m, k]
                # Re-view the gmem slice with the smem fragment's layout so
                # source and destination profiles match for cute.copy.
                mX_paged_cur_copy_ki = cute.make_tensor(
                    mX_paged_cur_copy_ki.iterator, tXsX_k.layout
                )
                cute.copy(
                    self.gmem_tiled_copy_KV,
                    mX_paged_cur_copy_ki,
                    tXsX_k,
                    pred=should_load,
                )
|
build/torch-cuda/pipeline.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
# import math
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
from cutlass import Boolean, Int32, const_expr
|
| 9 |
+
from cutlass.cutlass_dsl import if_generate, dsl_user_op
|
| 10 |
+
from cutlass.pipeline import PipelineState
|
| 11 |
+
from cutlass.pipeline import PipelineUserType
|
| 12 |
+
from cutlass.pipeline import NamedBarrier as NamedBarrierOg
|
| 13 |
+
from cutlass.pipeline import PipelineAsync as PipelineAsyncOg
|
| 14 |
+
from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
|
| 15 |
+
from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
|
| 16 |
+
from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg
|
| 17 |
+
from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class PipelineStateSimple:
    """
    Pipeline state containing an index and a phase bit for the current position in the
    circular buffer.

    Both values are packed into a single Int32 counter (``_phase_index``); ``index`` and
    ``phase`` are recovered with divmod. If stages is a power of 2, divmod turns into
    bit twiddling; when ``stages == 1`` the counter degenerates to the phase bit itself.
    """

    def __init__(self, stages: int, phase_index: Int32):
        # Historical 16/16-bit packed variant kept for reference:
        # assert stages < 2**16
        # self._log_stages = int(math.log2(stages))
        # assert 1 << self._log_stages == stages, "Number of stages must be a power of 2."
        self._stages = stages  # compile-time constant number of buffer stages
        self._phase_index = phase_index  # runtime packed phase/index counter

    def clone(self) -> "PipelineStateSimple":
        """Return a copy with the same stage count and counter value."""
        return PipelineStateSimple(self.stages, self._phase_index)

    @property
    def stages(self) -> int:
        # return 1 << self._log_stages
        return self._stages

    @property
    def index(self) -> Int32:
        """Buffer slot index: the packed counter modulo the number of stages."""
        # return self._phase_index & 0xFFFF
        # return self._phase_index & ((1 << self._log_stages) - 1)
        if const_expr(self._stages == 1):
            # Single-stage pipeline: the only valid slot is 0.
            return Int32(0)
        else:
            return self._phase_index % self._stages

    @property
    def phase(self) -> Int32:
        """Phase parity passed to barrier waits."""
        # return self._phase_index >> 16
        # PTX docs say that the phase parity needs to be 0 or 1, so by right we need to
        # take modulo 2. But in practice just passing the phase in without modulo works fine.
        # return (self._phase_index >> self._log_stages) % 2
        # return self._phase_index >> self._log_stages
        if const_expr(self._stages == 1):
            # With one stage, advance() toggles the counter 0/1, so it IS the phase.
            return self._phase_index
        else:
            return self._phase_index // self._stages

    def advance(self):
        """Move to the next buffer slot; wrapping flips the decoded phase implicitly."""
        if const_expr(self._stages == 1):
            # XOR toggles the phase bit directly.
            self._phase_index ^= 1
        else:
            self._phase_index += 1

        # Historical explicit-wrap variant kept for reference:
        # def then_body(phase_index):
        #     # XOR the phase bit and set the index to 0
        #     return (phase_index & 0xFFFF0000) ^ (1 << 16)

        # def else_body(phase_index):
        #     return phase_index

        # self._phase_index = if_generate(
        #     (self._phase_index & 0xFFFF) == self.stages,
        #     then_body,
        #     else_body,
        #     [self._phase_index],
        #     [Int32],
        # )

    def __extract_mlir_values__(self):
        # DSL protocol hook: expose the dynamic (runtime) values held by this object.
        phase_index = self._phase_index
        return [phase_index.ir_value()]

    def __new_from_mlir_values__(self, values):
        # DSL protocol hook: rebuild the state from regenerated MLIR values.
        return PipelineStateSimple(self.stages, Int32(values[0]))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def make_pipeline_state(type: PipelineUserType, stages: int):
    """
    Create a PipelineStateSimple for a producer or a consumer.

    Producers are assumed to start with an empty buffer and a flipped phase bit of 1:
    the packed counter starts at ``stages``, which decodes to index 0 / phase 1.
    Consumers start at counter 0 (index 0 / phase 0).

    Raises:
        ValueError: if ``type`` is not a recognized PipelineUserType.
    """
    if type is PipelineUserType.Producer:
        # return PipelineStateSimple(stages, Int32(1 << 16))
        return PipelineStateSimple(stages, Int32(stages))
    elif type is PipelineUserType.Consumer:
        return PipelineStateSimple(stages, Int32(0))
    else:
        # Raise instead of `assert False`: asserts are stripped under `python -O`,
        # which would silently fall through and return None.
        raise ValueError("Error: invalid PipelineUserType specified for make_pipeline_state.")
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@dataclass(frozen=True)
class NamedBarrier(NamedBarrierOg):
    """NamedBarrier extended with *_w_index variants that offset the barrier id."""

    @staticmethod
    def create(*args, **kwargs):
        """Build via the base class, then retag the instance as this subclass."""
        obj = NamedBarrierOg.create(*args, **kwargs)
        # Can't assign to __class__ directly since the dataclass is frozen
        object.__setattr__(obj, "__class__", NamedBarrier)
        return obj

    @dsl_user_op
    def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
        """
        Arrive (without waiting) on barrier ``barrier_id + index``.

        The aligned flavor of arrive is used when all threads in the CTA will execute the
        same instruction. See PTX documentation.
        """
        cute.arch.barrier_arrive(
            barrier_id=self.barrier_id + index,
            number_of_threads=self.num_threads,
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
        """Arrive and block on barrier ``barrier_id + index``."""
        cute.arch.barrier(
            barrier_id=self.barrier_id + index,
            number_of_threads=self.num_threads,
            loc=loc,
            ip=ip,
        )
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@dataclass(frozen=True)
class PipelineAsync(PipelineAsyncOg):
    """
    PipelineAsync extended with *_w_index / *_w_index_phase variants that take an
    explicit buffer index (and phase) instead of a PipelineState object.
    """

    @staticmethod
    def create(*args, **kwargs):
        """Build via the base class, then retag the instance as this subclass."""
        obj = PipelineAsyncOg.create(*args, **kwargs)
        # Can't assign to __class__ directly since the dataclass is frozen
        # obj.__class__ = PipelineAsync
        object.__setattr__(obj, "__class__", PipelineAsync)
        return obj

    @dsl_user_op
    def producer_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_acquire_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait on the "empty" barrier unless a prior try_acquire already succeeded
        # (token nonzero means the slot is known to be free).
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
        # Signal the "full" barrier for slot `index`.
        self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_wait_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_wait_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait on the "full" barrier unless a prior try_wait already succeeded.
        if_generate(
            try_wait_token is None or try_wait_token == 0,
            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
        # Signal the "empty" barrier so the producer can reuse slot `index`.
        self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@dataclass(frozen=True)
class PipelineTmaAsync(PipelineTmaAsyncOg):
    """
    Override producer_acquire to take in extra_tx_count parameter.
    """

    @staticmethod
    def create(*args, **kwargs):
        """Build via the base class, then retag the instance as this subclass."""
        obj = PipelineTmaAsyncOg.create(*args, **kwargs)
        # Can't assign to __class__ directly since the dataclass is frozen
        object.__setattr__(obj, "__class__", PipelineTmaAsync)
        return obj

    @dsl_user_op
    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        extra_tx_count: int = 0,
        *,
        loc=None,
        ip=None,
    ):
        """
        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.

        ``extra_tx_count`` is added to the barrier's configured transaction count for
        producers that issue additional TMA traffic beyond the default.
        """
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        if const_expr(extra_tx_count == 0):
            # Default tx count: plain arrive on the full barrier.
            self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip)
        else:
            # Fold the extra transactions into the expected-tx count.
            tx_count = self.sync_object_full.tx_count + extra_tx_count
            self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
@dataclass(frozen=True)
class PipelineTmaUmma(PipelineTmaUmmaOg):
    """
    Override producer_acquire to take in extra_tx_count parameter.

    Also adds *_w_index / *_w_index_phase variants that take an explicit buffer
    index (and phase) instead of a PipelineState object.
    """

    @staticmethod
    def create(*args, **kwargs):
        """Build via the base class, then retag the instance as this subclass."""
        obj = PipelineTmaUmmaOg.create(*args, **kwargs)
        # Can't assign to __class__ directly since the dataclass is frozen
        # obj.__class__ = PipelineTmaUmma
        object.__setattr__(obj, "__class__", PipelineTmaUmma)
        return obj

    @dsl_user_op
    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        extra_tx_count: int = 0,
        *,
        loc=None,
        ip=None,
    ):
        """
        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.

        ``extra_tx_count`` is added to the barrier's configured transaction count.
        Only the leader CTA arrives on the full barrier.
        """
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        if const_expr(extra_tx_count == 0):
            if_generate(
                self.is_leader_cta,
                lambda: self.sync_object_full.arrive(
                    state.index, self.producer_mask, loc=loc, ip=ip
                ),
                loc=loc,
                ip=ip,
            )
        else:
            # Fold the extra transactions into the expected-tx count.
            tx_count = self.sync_object_full.tx_count + extra_tx_count
            if_generate(
                self.is_leader_cta,
                lambda: self.sync_object_full.arrive_and_expect_tx(
                    state.index, tx_count, loc=loc, ip=ip
                ),
                loc=loc,
                ip=ip,
            )

    @dsl_user_op
    def producer_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_acquire_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        """
        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
        """
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        # Only the leader CTA arrives on the full barrier.
        if_generate(
            self.is_leader_cta,
            lambda: self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def consumer_wait_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_wait_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait on the "full" barrier unless a prior try_wait already succeeded.
        if_generate(
            try_wait_token is None or try_wait_token == 0,
            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
        """
        UMMA consumer release buffer empty, cta_group needs to be provided.
        """
        self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
@dataclass(frozen=True)
class PipelineUmmaAsync(PipelineUmmaAsyncOg):
    """
    PipelineUmmaAsync extended with *_w_index / *_w_index_phase variants that take
    an explicit buffer index (and phase) instead of a PipelineState object.
    """

    @staticmethod
    def create(*args, **kwargs):
        """Build via the base class, then retag the instance as this subclass."""
        obj = PipelineUmmaAsyncOg.create(*args, **kwargs)
        # Can't assign to __class__ directly since the dataclass is frozen
        object.__setattr__(obj, "__class__", PipelineUmmaAsync)
        return obj

    @dsl_user_op
    def producer_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_acquire_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait on the "empty" barrier unless a prior try_acquire already succeeded.
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
        """
        UMMA producer commit buffer full, cta_group needs to be provided.
        """
        self.sync_object_full.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_wait_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_wait_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait on the "full" barrier unless a prior try_wait already succeeded.
        if_generate(
            try_wait_token is None or try_wait_token == 0,
            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
        # Signal the "empty" barrier so the producer can reuse slot `index`.
        self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
@dataclass(frozen=True)
class PipelineAsyncUmma(PipelineAsyncUmmaOg):
    """
    PipelineAsyncUmma extended with *_w_index / *_w_index_phase variants that take
    an explicit buffer index (and phase) instead of a PipelineState object.
    """

    @staticmethod
    def create(*args, **kwargs):
        """Build via the base class, then retag the instance as this subclass."""
        obj = PipelineAsyncUmmaOg.create(*args, **kwargs)
        # Can't assign to __class__ directly since the dataclass is frozen
        object.__setattr__(obj, "__class__", PipelineAsyncUmma)
        return obj

    @dsl_user_op
    def producer_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_acquire_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait on the "empty" barrier unless a prior try_acquire already succeeded.
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
        # Signal the "full" barrier for slot `index`.
        self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_wait_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_wait_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait on the "full" barrier unless a prior try_wait already succeeded.
        if_generate(
            try_wait_token is None or try_wait_token == 0,
            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
        """
        UMMA consumer release buffer empty, cta_group needs to be provided.
        """
        self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)
|
build/torch-cuda/quack/__init__.py
ADDED
|
File without changes
|
build/torch-cuda/quack/activation.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
from cutlass import Float32, Boolean, const_expr
|
| 9 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 10 |
+
from cutlass._mlir.dialects import llvm, nvvm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
sub_packed_f32x2 = partial(
|
| 17 |
+
cute.arch.calc_packed_f32x2_op,
|
| 18 |
+
src_c=None,
|
| 19 |
+
calc_func=nvvm.sub_packed_f32x2,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dsl_user_op
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
    """Fast hardware tanh via the PTX `tanh.approx.f32` instruction."""
    return Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(a).ir_value(loc=loc, ip=ip)],
            "tanh.approx.f32 $0, $1;",
            "=f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dsl_user_op
def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x), for a scalar or a packed f32x2 pair."""
    if const_expr(not isinstance(x, tuple)):
        # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
        return 0.5 + 0.5 * tanh(0.5 * x)
    else:
        x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
        # 0.5 * tanh(x/2) + 0.5 in a single packed FMA.
        return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dsl_user_op
def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
    """Sigmoid backward from the forward output: d/dx sigmoid(x) = out * (1 - out).

    Written as `out - out * out` (equivalent to `dout * out * (1.0 - out)`) so the
    whole expression is two multiplies and a subtract.
    """
    sig_grad = out - out * out
    return dout * sig_grad
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dsl_user_op
def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """Elementwise max(x, 0) for a scalar Float32 or a packed pair of Float32s."""
    zero = Float32(0.0)
    if const_expr(isinstance(x, tuple)):
        lo, hi = x
        return cute.arch.fmax(lo, zero), cute.arch.fmax(hi, zero)
    return cute.arch.fmax(x, zero)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dsl_user_op
@cute.jit
def drelu(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """ReLU backward: returns (dx, relu(x)) with dx = dout where x > 0, else 0."""
    if const_expr(not isinstance(x, tuple)):
        x_pos = Boolean(x > 0)
        return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
    else:
        # Packed path: select per lane, then recompute the forward with relu().
        x0_pos = Boolean(x[0] > 0)
        x1_pos = Boolean(x[1] > 0)
        dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0))
        return dx, relu(x)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dsl_user_op
def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """ReLU-squared: max(x, 0) * x (x^2 for positive x, 0 otherwise)."""
    if const_expr(not isinstance(x, tuple)):
        return cute.arch.fmax(x, Float32(0.0)) * x
    else:
        relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
        return cute.arch.mul_packed_f32x2(relu_x, x)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@dsl_user_op
@cute.jit
def drelu_sq(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """
    ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
    Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
    Returns: (dx, relu_sq_out) where:
    - dx = dout * 2 * x if x > 0, else 0
    - relu_sq_out = max(x, 0) * x
    """
    if const_expr(not isinstance(x, tuple)):
        relu_x = relu(x)
        relu_sq_out = relu_x * x
        # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
        # relu_x already zeroes the negative side, so 2 * dout * relu_x suffices.
        dx = 2.0 * (dout * relu_x)
        return dx, relu_sq_out
    else:
        relu_x = relu(x)
        relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x)
        dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x))
        return dx, relu_sq_out
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@dsl_user_op
def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """
    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
    """
    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
    if const_expr(not isinstance(x, tuple)):
        return 0.5 * (
            x
            # Currently cute.math.tanh(x, fastmath=True) generates very slow code
            # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
            * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
        )
    else:
        # Packed path: inner polynomial with packed MUL/FMA, tanh per lane.
        x_sq = cute.arch.mul_packed_f32x2(x, x)
        x_sq_scaled = cute.arch.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
        tanh_z = (tanh(z[0]), tanh(z[1]))
        # x * tanh(z) + x == x * (1 + tanh(z))
        x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x)
        return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dsl_user_op
def dgelu_tanh_approx(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """
    GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
    Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
    Returns: (dx, gelu_out)

    Derivative uses the chain rule:
    d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
    where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
    and sech^2(z) = 1 - tanh^2(z)
    """
    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
    sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.01070322

    if const_expr(not isinstance(x, tuple)):
        # Compute z = x * (c1 + c2 * x^2)
        x_sq = x * x
        # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
        tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
        half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
        gelu_out = x * half_tanh_z_plus_one

        # Compute gradient
        # sech^2(z) = 1 - tanh^2(z)
        sech2_z = 1 - tanh_z * tanh_z
        # dz/dx = c1 + 3 * c2 * x^2
        dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
        dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))

        dx = dout * dgelu
        return dx, gelu_out
    else:
        # Compute z = x * (c1 + c2 * x^2)
        x_sq = cute.arch.mul_packed_f32x2(x, x)
        x_sq_scaled = cute.arch.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
        tanh_z = (tanh(z[0]), tanh(z[1]))
        half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
        gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one)

        # Compute gradient
        # sech^2(z) = 1 - tanh^2(z), via packed FMA: -tanh_z * tanh_z + 1
        sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
        # dz/dx = c1 + 3 * c2 * x^2
        dz_dx = cute.arch.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
        sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx)
        x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx)
        dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)

        dx = cute.arch.mul_packed_f32x2(dout, dgelu)
        return dx, gelu_out
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@dsl_user_op
@cute.jit
def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """
    softplus(x) = log(1 + exp(x)), falling back to the identity for x > 20,
    where exp(x) would lose float32 accuracy and softplus(x) ~= x anyway.
    """
    if const_expr(not isinstance(x, tuple)):
        use_linear = Boolean(x > 20.0)
        return (
            cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True)
            if not use_linear
            else x
        )
    else:
        # NOTE(review): the log2(e) pre-scale and the log2 * ln(2) post-scale suggest
        # cute.math.exp/log2 with fastmath lower to the PTX ex2/lg2 approximations
        # operating in base 2 — confirm against the cute.math docs.
        log2_e = math.log2(math.e)
        x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e))
        x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
        x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0))
        log_x_exp_p1 = (
            cute.math.log2(x_exp_p1[0], fastmath=True),
            cute.math.log2(x_exp_p1[1], fastmath=True),
        )
        ln2 = math.log(2.0)
        softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
        use_linear_0 = Boolean(x[0] > 20.0)
        use_linear_1 = Boolean(x[1] > 20.0)
        return (
            softplus_x[0] if not use_linear_0 else x[0],
            softplus_x[1] if not use_linear_1 else x[1],
        )
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@dsl_user_op
@cute.jit
def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
    """Softplus backward from the forward output: d/dx softplus(x) = 1 - exp(-out).

    The x > 20 linear region of the forward is mirrored by passing dout through
    unchanged; the cutoff is tested on `out`, which equals x in that regime.
    """
    use_linear = Boolean(out > 20.0)
    # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout
    dx = dout - dout * cute.math.exp(-out, fastmath=True)
    return dx if not use_linear else dout
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@dsl_user_op
def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2:
    """
    silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
    This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.

    When `already_halved` is True the caller supplies 0.5 * x directly and the
    initial FMUL is skipped.
    """
    if const_expr(not isinstance(x, tuple)):
        x_half = 0.5 * x if const_expr(not already_halved) else x
        # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
        return x_half * tanh(x_half) + x_half
    else:
        x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
        return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
@dsl_user_op
def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """SwiGLU gating: silu(x) * y, elementwise, for scalars or packed f32x2 pairs."""
    gate = silu(x)
    if const_expr(isinstance(x, tuple)):
        return cute.arch.mul_packed_f32x2(gate, y)
    return gate * y
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@dsl_user_op
def dswiglu(
    x: F32_or_F32x2,
    y: F32_or_F32x2,
    dout: F32_or_F32x2,
    *,
    already_halved: bool = False,
    loc=None,
    ip=None,
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
    Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)

    d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))

    This has been optimized to use fewer instructions (i.e. we expand things out
    to use FFMA instead of FADD and FMUL).

    `already_halved` means the caller pre-multiplied x by 0.5; inputs may be
    scalar Float32 or packed f32x2 tuples.
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
        # FMUL, MUFU.TANH, then FFMA
        if const_expr(not already_halved):
            sigmoid_x = sigmoid(x)
            silu_x = x * sigmoid_x  # FMUL
        else:
            # x is x_half here: x_half*tanh(x_half)+x_half == silu of the full x.
            tanh_x = tanh(x)  # MUFU.TANH
            sigmoid_x = 0.5 * tanh_x + 0.5  # FFMA
            silu_x = x * tanh_x + x  # FFMA
        silu_x_dout = silu_x * dout  # FMUL
        # d_silu(x) * dout
        # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
        # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
        d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
        dx = d_silu_x_dout * y  # FMUL
        dy = silu_x_dout
        swiglu_out = silu_x * y  # FMUL
        # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
        return dx, dy, swiglu_out
    else:
        # Compute sigmoid(x) and silu(x)
        if const_expr(not already_halved):
            sigmoid_x = sigmoid(x)
            silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x)
        else:
            tanh_x = (tanh(x[0]), tanh(x[1]))
            sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
            silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x)
        silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
        # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
        sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2(
            sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
        )
        d_silu_x_dout = cute.arch.fma_packed_f32x2(
            sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout
        )
        dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y)
        dy = silu_x_dout
        swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y)
        return dx, dy, swiglu_out
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
@dsl_user_op
def swiglu_oai(
    x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
) -> F32_or_F32x2:
    """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
    https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
    x * sigmoid(alpha * x) * (y + 1)
    Compile down to FMUL, FMUL, TANH, FFMA, FFMA
    """
    # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
    if const_expr(not isinstance(x, tuple)):
        x_half = 0.5 * x
        # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
        silu_x = x_half * tanh(alpha * x_half) + x_half
        # Multiply by (y + 1) with one FFMA: silu_x * y + silu_x.
        return silu_x * y + silu_x
    else:
        # Packed f32x2 path: same math, two lanes at a time.
        x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
        alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half)
        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
        silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
        return cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
@dsl_user_op
def dswiglu_oai(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    Swiglu OAI backward pass: computes gradients w.r.t. x and y
    Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
    Returns: (dx, dy, swiglu_oai_out)

    Derivative of x * sigmoid(alpha * x) w.r.t. x:
    d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
        alpha_x_half = (0.5 * alpha) * x  # FMUL
        # MUFU.TANH, then FFMA
        # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
        sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)
        silu_x = x * sigmoid_alpha_x  # FMUL
        silu_x_dout = silu_x * dout  # FMUL
        # FFMA, FFMA, FMUL
        # Uses alpha * x * sigmoid * (1 - sigmoid) = alpha * (silu_x - silu_x * sigmoid).
        d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
        dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
        dy = silu_x_dout
        swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
        # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
        return dx, dy, swiglu_out
    else:
        # Compute sigmoid(alpha * x)
        alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
        sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
        silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x)
        silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
        # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
        silu_x_minus_product = cute.arch.fma_packed_f32x2(
            silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
        )
        sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2(
            (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
        )
        d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
        dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
        dy = silu_x_dout
        swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
        return dx, dy, swiglu_out
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
@dsl_user_op
def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """GLU: Gated Linear Unit
    glu(x, y) = sigmoid(x) * y
    Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
    """
    if const_expr(not isinstance(x, tuple)):
        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
        return sigmoid_x * y  # FMUL
    else:
        # Packed f32x2 path.
        sigmoid_x = sigmoid(x)
        return cute.arch.mul_packed_f32x2(sigmoid_x, y)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
@dsl_user_op
def dglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
    Returns: (dx, dy, glu_out) where:
      - dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
      - dy = dout * sigmoid(x)
      - glu_out = sigmoid(x) * y
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
        sigmoid_x_dout = sigmoid_x * dout  # FMUL
        glu_out = sigmoid_x * y  # FMUL
        # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
        #    = y * (1 - sigmoid(x)) * sigmoid_x_dout
        #    = (y - y * sigmoid(x)) * sigmoid_x_dout
        #    = (y - glu_out) * sigmoid_x_dout
        dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
        dy = sigmoid_x_dout
        # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
        return dx, dy, glu_out
    else:
        # Packed f32x2 path: same algebra, two lanes at a time.
        sigmoid_x = sigmoid(x)
        sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout)
        glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y)
        # dx = (y - glu_out) * sigmoid_x_dout
        y_minus_glu_out = sub_packed_f32x2(y, glu_out)
        dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
        dy = sigmoid_x_dout
        return dx, dy, glu_out
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
@dsl_user_op
def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """ReGLU: ReLU Gated Linear Unit
    reglu(x, y) = relu(x) * y = max(x, 0) * y
    """
    if const_expr(not isinstance(x, tuple)):
        return cute.arch.fmax(x, Float32(0.0)) * y
    else:
        # Packed path: relu handles both lanes, then one packed multiply.
        relu_x = relu(x)
        return cute.arch.mul_packed_f32x2(relu_x, y)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
@dsl_user_op
@cute.jit
def dreglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
    Returns: (dx, dy, reglu_out) where:
      - dx = dout * y if x > 0, else 0
      - dy = dout * relu(x)
      - reglu_out = relu(x) * y
    """
    if const_expr(not isinstance(x, tuple)):
        x_pos = Boolean(x > 0)
        relu_x = cute.arch.fmax(x, Float32(0.0))
        # Gradient flows through the gate only where x was positive.
        dx = (dout * y) if x_pos else Float32(0.0)
        dy = dout * relu_x
        reglu_out = relu_x * y
        return dx, dy, reglu_out
    else:
        # Packed f32x2 path: per-lane sign tests, packed arithmetic.
        x0_pos = Boolean(x[0] > 0)
        x1_pos = Boolean(x[1] > 0)
        relu_x = relu(x)
        dout_y = cute.arch.mul_packed_f32x2(dout, y)
        dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
        dy = cute.arch.mul_packed_f32x2(dout, relu_x)
        reglu_out = cute.arch.mul_packed_f32x2(relu_x, y)
        return dx, dy, reglu_out
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
@dsl_user_op
def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """GeGLU: GELU Gated Linear Unit
    geglu(x, y) = gelu(x) * y
    Uses the tanh approximation of GELU
    """
    if const_expr(not isinstance(x, tuple)):
        return gelu_tanh_approx(x) * y
    else:
        # Packed path: gelu handles both lanes, then one packed multiply.
        return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y)
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
@dsl_user_op
def dgeglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
    Returns: (dx, dy, geglu_out) where:
      - dx = dout * y * d_gelu(x)
      - dy = dout * gelu(x)
      - geglu_out = gelu(x) * y
    """
    if const_expr(not isinstance(x, tuple)):
        # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
        dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
        # Compute gradients for geglu
        dx = dgelu_x_dout * y
        dy = gelu_x * dout
        geglu_out = gelu_x * y
        return dx, dy, geglu_out
    else:
        # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
        dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
        # Compute gradients for geglu (packed f32x2 path)
        dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y)
        dy = cute.arch.mul_packed_f32x2(gelu_x, dout)
        geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y)
        return dx, dy, geglu_out
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
# ============================================================================
# Activation name -> function maps
# ============================================================================

# Forward element-wise activations selectable by name (None disables the activation).
act_fn_map = {
    None: None,
    "silu": silu,
    "relu": relu,
    "relu_sq": relu_sq,
    "gelu_tanh_approx": gelu_tanh_approx,
}

# Backward counterparts of act_fn_map.
# NOTE(review): "silu" appears in act_fn_map but has no backward entry here —
# confirm whether a dsilu is intentionally omitted or missing.
dact_fn_map = {
    None: None,
    "relu": drelu,
    "relu_sq": drelu_sq,
    "gelu_tanh_approx": dgelu_tanh_approx,
}

# Gated (two-input) activations selectable by name.
gate_fn_map = {
    "swiglu": swiglu,
    "swiglu_oai": swiglu_oai,
    "reglu": reglu,
    "geglu": geglu,
    "glu": glu,
}

# Backward counterparts of gate_fn_map; each returns (dx, dy, forward_out).
dgate_fn_map = {
    "swiglu": dswiglu,
    "swiglu_oai": dswiglu_oai,
    "reglu": dreglu,
    "geglu": dgeglu,
    "glu": dglu,
}
|
build/torch-cuda/quack/compile_utils.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
    """Build a fake cute tensor with symbolic strides, or None when dtype is None.

    The `leading_dim` axis gets stride 1; every other axis gets a symbolic
    int64 stride assumed divisible by `divisibility`. A negative `leading_dim`
    counts from the end, as in normal Python indexing. The assumed alignment
    is `divisibility` elements expressed in bytes.
    """
    if leading_dim < 0:
        leading_dim = len(shape) + leading_dim
    if dtype is None:
        return None
    strides = []
    for axis in range(len(shape)):
        if axis == leading_dim:
            strides.append(1)
        else:
            strides.append(cute.sym_int64(divisibility=divisibility))
    byte_align = divisibility * dtype.width // 8
    return cute.runtime.make_fake_tensor(
        dtype, shape, stride=tuple(strides), assumed_align=byte_align
    )
|
build/torch-cuda/quack/copy_utils.py
ADDED
|
@@ -0,0 +1,1007 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Type, Tuple, Callable, Sequence
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
import cutlass
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
|
| 9 |
+
from cutlass import Int32, Int16, Boolean, const_expr
|
| 10 |
+
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
|
| 11 |
+
from cutlass.cute.nvgpu.tcgen05.mma import CtaGroup # noqa
|
| 12 |
+
from cutlass.cutlass_dsl import dsl_user_op
|
| 13 |
+
import cutlass.pipeline
|
| 14 |
+
from cutlass._mlir.dialects import llvm
|
| 15 |
+
from cutlass._mlir import ir
|
| 16 |
+
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
Sm100MmaPeerBitMask = 0xFEFFFFFF
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dsl_user_op
def cvt_copy(
    tiled_copy: cute.TiledCopy,
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    retile: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy a register-memory `src` to `dst`, converting dtype on the fly if needed.

    When element types differ, src is first stored into a fresh rmem tensor of
    dst's element type. With `retile=True`, src is retiled to the tiled_copy's
    layout before the copy. `pred` and extra kwargs are forwarded to cute.copy.
    """
    assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
    if const_expr(src.element_type != dst.element_type):
        src_cvt = cute.make_rmem_tensor_like(src, dst.element_type)
        src_cvt.store(src.load().to(dst.element_type))
        src = src_cvt
    if const_expr(retile):
        src = tiled_copy.retile(src)
    cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dsl_user_op
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    """Copy `src` into a freshly allocated register-memory tensor of the same
    shape/dtype via an auto-vectorized copy, and return the new tensor.
    (Name suggests shared-to-register; src's memspace is not checked here.)"""
    dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip)
    cute.autovec_copy(src, dst, loc=loc, ip=ip)
    return dst
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dsl_user_op
def load_s2r_retile(
    tiled_copy: cute.TiledCopy,
    src: cute.Tensor,
    dst_shape: cute.Tensor | cute.Shape,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Copy `src` through `tiled_copy` into a destination and return it.

    The destination is either freshly allocated in rmem from a shape, or an
    existing tensor (see note below); it is retiled to the tiled_copy's layout
    for the copy itself, but returned in its original layout.
    """
    # Will also accept dst_shape being a tensor, in which case we write into that tensor
    if const_expr(not isinstance(dst_shape, cute.Tensor)):
        dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip)
    else:
        dst = dst_shape
    cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
    return dst
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@dsl_user_op
def load_t2r(
    thr_copy: cute.ThrCopy, shape: cute.Shape, src: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
    """Copy `src` into a new register tensor sized as this thread's partition
    of an identity tensor of `shape`, and return it.
    (Name suggests tmem-to-register; src's memspace is not checked here.)"""
    cDst = cute.make_identity_tensor(shape)
    # partition_D on the identity tensor yields the per-thread destination shape.
    dst = cute.make_rmem_tensor(thr_copy.partition_D(cDst).shape, src.element_type, loc=loc, ip=ip)
    cute.copy(thr_copy, src, dst, loc=loc, ip=ip)
    return dst
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dsl_user_op
def get_copy_atom(
    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
) -> cute.CopyAtom:
    """Return a copy atom moving `num_copy_elems` elements of `dtype` per copy.

    The per-copy width is capped at 128 bits (the widest vectorized access).
    `is_async` selects the cp.async global-to-shared op instead of the
    universal copy op.
    """
    requested_bits = num_copy_elems * dtype.width
    num_copy_bits = const_expr(min(128, requested_bits))
    if is_async:
        copy_op = cpasync.CopyG2SOp()
    else:
        copy_op = cute.nvgpu.CopyUniversalOp()
    return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dsl_user_op
def copy(
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    is_async: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy src -> dst with a copy atom sized from src's vectorization.

    The element count per copy instruction is read from src.shape[0][0]
    (assumes src is partitioned so mode (0, 0) is the per-copy value count —
    TODO confirm against callers). `is_async` selects cp.async for G2S copies.
    """
    num_copy_elems = src.shape[0][0]
    copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
    cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def tiled_copy_1d(
    dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
) -> cute.TiledCopy:
    """Build a 1-D tiled copy: `num_threads` threads, each moving
    `num_copy_elems` contiguous elements of `dtype` per copy instruction.
    `is_async` selects the cp.async global-to-shared op."""
    if is_async:
        op = cpasync.CopyG2SOp()
    else:
        op = cute.nvgpu.CopyUniversalOp()
    bits_per_copy = dtype.width * num_copy_elems
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=bits_per_copy)
    thread_layout = cute.make_layout(num_threads)
    value_layout = cute.make_layout(num_copy_elems)
    return cute.make_tiled_copy_tv(atom, thread_layout, value_layout)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def tiled_copy_2d(
    dtype: Type[cutlass.Numeric],
    threads_per_row: int,
    num_threads: int,
    num_copy_elems: int = 1,
    is_async: bool = False,
) -> cute.TiledCopy:
    """Build a 2-D tiled copy with a (num_threads // threads_per_row,
    threads_per_row) thread arrangement, each thread moving `num_copy_elems`
    contiguous elements per copy. `is_async` selects the cp.async G2S op."""
    assert num_threads % threads_per_row == 0
    if is_async:
        op = cpasync.CopyG2SOp()
    else:
        op = cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=dtype.width * num_copy_elems)
    num_rows = num_threads // threads_per_row
    # order=(1, 0): consecutive thread ids walk along a row first.
    thread_layout = cute.make_ordered_layout(
        (num_rows, threads_per_row),
        order=(1, 0),
    )
    value_layout = cute.make_layout((1, num_copy_elems))
    return cute.make_tiled_copy_tv(atom, thread_layout, value_layout)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
    """Build a boolean predicate tensor masking the k dimension against `limit`.

    Only the k coordinate is compared; the predicate layout broadcasts across
    mode 1 (stride 0) so one comparison per (v, k) covers that whole mode.
    Returns a Boolean rmem tensor usable as `pred` for predicated copies.
    """
    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
    tApA = cute.make_rmem_tensor(
        cute.make_layout(
            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
            stride=(cute.size(tAcA, mode=[2]), 0, 1),
        ),
        Boolean,
    )
    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
            # [1] extracts the second coordinate (the k index) of the identity coord.
            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
    return tApA
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# def tiled_copy_2d(
|
| 151 |
+
# dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
|
| 152 |
+
# ) -> cute.TiledCopy:
|
| 153 |
+
# num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
|
| 154 |
+
# copy_elems = num_copy_bits // dtype.width
|
| 155 |
+
# copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
|
| 156 |
+
# copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
|
| 157 |
+
# gmem_threads_per_row = major_mode_size // copy_elems
|
| 158 |
+
# assert num_threads % gmem_threads_per_row == 0
|
| 159 |
+
# thr_layout = cute.make_ordered_layout(
|
| 160 |
+
# (num_threads // gmem_threads_per_row, gmem_threads_per_row),
|
| 161 |
+
# order=(1, 0),
|
| 162 |
+
# )
|
| 163 |
+
# val_layout = cute.make_layout((1, copy_elems))
|
| 164 |
+
# return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# Ragged tensor trick for TMA: encodes variable-length sequences into a higher-rank
|
| 168 |
+
# tensor so that TMA's out-of-bounds checking handles sequence boundaries.
|
| 169 |
+
#
|
| 170 |
+
# Given a tensor T with a ragged dimension (variable-length across batches), we create
|
| 171 |
+
# a higher-rank tensor where the ragged dim is replaced with a fixed size `big_int`, and
|
| 172 |
+
# extra dim(s) are appended. When indexing into a specific sequence at (offset, length),
|
| 173 |
+
# `offset_ragged_tensor` computes coordinates such that:
|
| 174 |
+
# ragged_coord = big_int - length (OOB check clamps reads past the sequence end)
|
| 175 |
+
# extra_coord(s) = f(offset, length) (selects the correct memory region)
|
| 176 |
+
#
|
| 177 |
+
# ptr_shift=True: 1-extra-dim approach (adds 1 dim, supports up to 4D input):
|
| 178 |
+
# Shape: (*before, big_int, *after, max_int)
|
| 179 |
+
# Stride: (*original_strides, stride_r) where stride_r = T.stride[ragged_dim]
|
| 180 |
+
# Pointer shifted backward by big_int * stride_r elements.
|
| 181 |
+
# Address for coords (big_int - length) in ragged dim, (offset + length) in extra dim:
|
| 182 |
+
# addr = (base - big_int * s_r) + (big_int - length) * s_r + (offset + length) * s_r
|
| 183 |
+
# = base + offset * s_r [correct]
|
| 184 |
+
# Works for epilogue TMA store. Does NOT work for TMA load with large big_int
|
| 185 |
+
# — the shifted pointer must land in physically mapped GPU memory.
|
| 186 |
+
#
|
| 187 |
+
# ptr_shift=False: 2-extra-dim approach (adds 2 dims, supports up to 3D input):
|
| 188 |
+
# Shape: (*before, big_int, *after, max_int, max_int)
|
| 189 |
+
# Stride: (*before_strides, stride_r, *after_strides, 2^34 - stride_r, stride_r)
|
| 190 |
+
# No pointer shift. Uses 64-bit address wraparound to cancel the ragged offset.
|
| 191 |
+
# Let W = 2^34 - stride_r. Address for coords (big_int - length) in ragged dim,
|
| 192 |
+
# big_int in extra dim 0, (offset + length) in extra dim 1:
|
| 193 |
+
# addr = base + (big_int - length) * s_r + big_int * W + (offset + length) * s_r
|
| 194 |
+
# = base + big_int * (s_r + W) - length * s_r + (offset + length) * s_r
|
| 195 |
+
# = base + big_int * 2^34 + offset * s_r
|
| 196 |
+
# Since big_int = 2^30: big_int * 2^34 = 2^64 ≡ 0 (mod 2^64), so:
|
| 197 |
+
# addr = base + offset * s_r [correct]
|
| 198 |
+
# Works for all TMA paths since the base pointer is never shifted.
|
| 199 |
+
#
|
| 200 |
+
# Ragged tensor was adapted from the implementation from Triton, but here we have an option that
|
| 201 |
+
# only needs 1 extra dimension instead of 2.
|
| 202 |
+
# https://github.com/triton-lang/triton/blob/main/python/triton/tools/ragged_tma.py
|
| 203 |
+
BIG_INT = 2**30
|
| 204 |
+
MAX_INT = 2**31 - 1
|
| 205 |
+
BIG_INT_INV = 2**64 // BIG_INT
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
@dsl_user_op
def create_ragged_tensor_for_tma(
    T: cute.Tensor,
    ragged_dim: int = 0,
    ptr_shift: bool = False,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Re-express `T` as a higher-rank tensor encoding a ragged dim for TMA.

    See the block comment above for the full addressing math. With
    ptr_shift=True, the ragged dim's extent becomes BIG_INT, one extra dim of
    stride stride_r is appended, and the base pointer is shifted back by
    BIG_INT elements along the ragged dim (rank <= 4). With ptr_shift=False,
    two extra dims are appended and 64-bit wraparound cancels the ragged
    offset, leaving the base pointer untouched (rank <= 3).
    """
    rank = cute.rank(T)
    if ragged_dim < 0:
        ragged_dim += rank
    if ptr_shift:
        assert rank <= 4, "ptr_shift ragged tensor only supports up to 4 dimensions"
        new_shape = T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT,)
        new_stride = T.stride + (T.stride[ragged_dim],)
        # Shift the base pointer backward by BIG_INT elements along the ragged dim.
        ptr_offset = (None,) * ragged_dim + (-BIG_INT,) + (None,) * (rank - ragged_dim - 1)
        new_ptr = cute.domain_offset(ptr_offset, T).iterator
        return cute.make_tensor(new_ptr, cute.make_layout(new_shape, stride=new_stride))
    else:
        assert rank <= 3, "non-ptr_shift ragged tensor only supports up to 3 dimensions"
        stride_r = T.stride[ragged_dim]
        new_shape = (
            T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT, MAX_INT)
        )
        # BIG_INT_INV - stride_r == 2^34 - stride_r; together with the final mode's
        # stride_r, indexing BIG_INT into both extra dims sums to 2^64 == 0 mod 2^64.
        new_stride = (
            T.stride[:ragged_dim]
            + (stride_r,)
            + T.stride[ragged_dim + 1 :]
            + (BIG_INT_INV - stride_r, stride_r)
        )
        return cute.make_tensor(T.iterator, cute.make_layout(new_shape, stride=new_stride))
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@dsl_user_op
def offset_ragged_tensor(
    T: cute.Tensor,
    offset: Int32,
    length: Int32,
    ragged_dim: int = 0,
    ptr_shift: bool = False,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Select an (offset, length) window of a ragged tensor.

    ``T`` must have been produced by ``create_ragged_tensor_for_tma`` with the
    same ``ragged_dim`` and ``ptr_shift`` values, since those determine how
    many extra trailing modes the layout carries.

    :param T: tensor returned by ``create_ragged_tensor_for_tma``
    :param offset: start of the window along the ragged dimension
    :param length: extent of the window along the ragged dimension
    :return: a view addressing rows [offset, offset + length) of the ragged dim
    """
    rank = cute.rank(T)
    if ragged_dim < 0:
        ragged_dim += rank
    # Extent of the ragged mode as created (this is BIG_INT at creation time).
    big_int = cute.size(T, mode=[ragged_dim])
    # Shift the ragged mode back by (big_int - length) so that, combined with
    # the trailing-mode index below, addressing starts at `offset`.
    offset_val = big_int - length
    if ptr_shift:
        # 1-extra-dim: rank = original_rank + 1
        assert rank >= ragged_dim + 2
        offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 2)
        index_tuple = (None,) * (rank - 1) + (offset + length,)
    else:
        # 2-extra-dim: rank = original_rank + 2, last 2 modes are the wraparound dims
        assert rank >= ragged_dim + 3
        offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 3)
        index_tuple = (None,) * (rank - 2) + (big_int, offset + length)
    return cute.domain_offset(offset_tuple, T[index_tuple])
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
    """Apply an XOR swizzle to an integer address.

    Selects bits [m+s, m+s+b) of ``ptr_int``, shifts them down by ``s``
    positions, and XORs them back into the address — the integer form of a
    Swizzle<b, m, s> pattern.
    """
    selected_bits = ptr_int & (((1 << b) - 1) << (m + s))
    return ptr_int ^ (selected_bits >> s)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def swizzle_ptr(ptr: cute.Pointer):
    """Return a new pointer whose raw address has the pointer's own swizzle applied."""
    pattern = ptr.type.swizzle_type
    swizzled_addr = swizzle_int(
        ptr.toint(), pattern.num_bits, pattern.num_base, pattern.num_shift
    )
    return cute.make_ptr(ptr.dtype, swizzled_addr, ptr.memspace, assumed_align=ptr.alignment)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
    """Re-express a swizzled smem tensor as a composed (swizzle o layout) tensor.

    The swizzle carried by the pointer type is moved into the layout as an
    inner swizzle and the pointer is recast to drop its swizzle, so the
    resulting tensor's addressing no longer depends on the pointer's base
    position.
    """
    outer = tensor.layout
    width = tensor.element_type.width
    swizzle_type = tensor.iterator.type.swizzle_type
    inner = cute.make_swizzle(swizzle_type.num_bits, swizzle_type.num_base, swizzle_type.num_shift)
    # Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for
    # for 16 bits and <3, 2, 3> for 32 bits)
    new_layout = cute.recast_layout(
        width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer))
    )
    # recast_ptr to remove the pointer swizzle
    return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def partition_D_position_independent(
    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
) -> cute.Tensor:
    """Partition ``tensor`` as a copy destination with position-independent addressing.

    The iterator comes from partitioning the raw tensor and applying its own
    swizzle to the resulting address; the layout comes from partitioning the
    position-independent (swizzle-in-layout) view of the same tensor.
    """
    raw_iter = thr_copy.partition_D(tensor).iterator
    pi_layout = thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout
    return cute.make_tensor(swizzle_ptr(raw_iter), pi_layout)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def partition_S_position_independent(
    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
) -> cute.Tensor:
    """Partition ``tensor`` as a copy source with position-independent addressing.

    Mirrors ``partition_D_position_independent`` but for the source side of the
    copy: swizzled iterator from the raw partition, layout from the
    position-independent view.
    """
    raw_iter = thr_copy.partition_S(tensor).iterator
    pi_layout = thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout
    return cute.make_tensor(swizzle_ptr(raw_iter), pi_layout)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@dsl_user_op
def sm90_get_smem_load_op(
    layout_c: cutlass.utils.LayoutEnum,
    elem_ty_c: Type[cutlass.Numeric],
    *,
    loc=None,
    ip=None,
) -> cute.CopyAtom:
    """
    Selects the largest vectorized smem load atom available subject to constraint of gmem layout.

    For 16-bit element types an ``ldmatrix``-based atom (x4) is returned,
    transposed when C is m-major; any other width falls back to a universal
    copy atom.

    Parameters:
    -----------
    layout_c : LayoutEnum
        The layout enum of the output tensor C.

    elem_ty_c : Type[Numeric]
        The element type for output tensor C.

    Returns:
    --------
    Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
    """

    if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
        raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
    is_m_major = layout_c.is_m_major_c()
    if elem_ty_c.width == 16:
        # ldmatrix x4; the transpose flag follows the m-major-ness of C.
        return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip)
    else:
        return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def get_smem_store_atom(
    arch: cutlass.Constexpr[int],
    element_type: Type[cute.Numeric],
    transpose: bool = False,
    major_mode_size: Optional[int] = None,
) -> cute.CopyAtom:
    """Pick the widest register->smem store atom for ``element_type`` on ``arch``.

    Pre-SM90, or for non-16-bit types, a universal copy atom is returned
    (2 elements per copy, 1 when transposed). Otherwise an ``stmatrix`` atom
    is used, with the number of matrices limited by how ``major_mode_size``
    divides: multiple of 16 -> 4, multiple of 8 -> 2, else 1.
    """
    if const_expr(arch < 90 or element_type.width != 16):
        return cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            element_type,
            num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
        )
    else:
        # Vectorize as much as the major-mode extent allows.
        num_matrices = (
            4
            if major_mode_size is None or major_mode_size % 16 == 0
            else (2 if major_mode_size % 8 == 0 else 1)
        )
        return cute.make_copy_atom(
            warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=num_matrices),
            element_type,
        )
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def get_smem_load_atom(
    arch: cutlass.Constexpr[int],
    element_type: Type[cute.Numeric],
    transpose: bool = False,
    major_mode_size: Optional[int] = None,
) -> cute.CopyAtom:
    """Pick the widest smem->register load atom for ``element_type`` on ``arch``.

    Mirrors ``get_smem_store_atom`` with ``ldmatrix`` in place of ``stmatrix``:
    pre-SM90 or non-16-bit types get a universal copy atom; otherwise an
    ``ldmatrix`` atom whose matrix count is limited by ``major_mode_size``
    divisibility (16 -> 4, 8 -> 2, else 1).
    """
    if const_expr(arch < 90 or element_type.width != 16):
        return cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            element_type,
            num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
        )
    else:
        # Vectorize as much as the major-mode extent allows.
        num_matrices = (
            4
            if major_mode_size is None or major_mode_size % 16 == 0
            else (2 if major_mode_size % 8 == 0 else 1)
        )
        return cute.make_copy_atom(
            warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=num_matrices),
            element_type,
        )
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def get_smem_store_C(
    tiled_mma: cute.TiledMma,
    sC: cute.Tensor,
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
    major_mode_size: Optional[int] = None,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up the register->smem store path for accumulator C.

    :return: ``(copy_fn, thr_copy, tRS_sC)``. ``copy_fn(src, dst_idx=None)``
        converts and copies a register tensor into the smem partition,
        with ``dst_idx`` selecting the pipeline stage (last mode) when given.
    """
    dtype = sC.element_type
    copy_atom = get_smem_store_atom(arch, dtype, transpose, major_mode_size=major_mode_size)
    tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tRS_sC = thr_copy.partition_D(sC)
    else:
        # Fold the smem swizzle into the layout so addressing is base-independent.
        tRS_sC = partition_D_position_independent(thr_copy, sC)

    def copy_fn(src: cute.Tensor, dst_idx: Optional[Int32] = None, **new_kwargs):
        # dst_idx indexes the stage mode when sC is multi-stage.
        dst_tensor = tRS_sC if const_expr(dst_idx is None) else tRS_sC[None, None, None, dst_idx]
        cvt_copy(tiled_copy, src, dst_tensor, retile=True, **new_kwargs)

    return copy_fn, thr_copy, tRS_sC
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def get_smem_load_C(
    tiled_mma: cute.TiledMma,
    sC: cute.Tensor,
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up the smem->register load path for C.

    :return: ``(copy_fn, thr_copy, tSR_sC)``. ``copy_fn(src_idx=None)`` loads
        the (optionally stage-indexed) smem partition into a register tensor
        shaped to match the store-side partitioning.
    """
    dtype = sC.element_type
    copy_atom = get_smem_load_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tSR_sC = thr_copy.partition_S(sC)
    else:
        tSR_sC = partition_S_position_independent(thr_copy, sC)
    # Shape the register destination like the store-side (R->S) partitioning so
    # values loaded here line up with what get_smem_store_C would have written.
    copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
    thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
    tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape

    def copy_fn(src_idx: Optional[Int32] = None, **new_kwargs):
        src_tensor = tSR_sC if const_expr(src_idx is None) else tSR_sC[None, None, None, src_idx]
        return load_s2r_retile(tiled_copy, src_tensor, dst_shape=tRS_shape, **new_kwargs)

    return copy_fn, thr_copy, tSR_sC
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def epilog_smem_copy_atom(
    tiled_mma: cute.TiledMma, epi_tile: cute.Shape, transpose: bool = False
) -> cute.TiledCopy:
    """Build the tiled copy atom used by the epilogue smem store path."""
    # Use 4 matrices per stmatrix when the epi tile N extent allows it, else 2.
    n_mat = 4 if epi_tile[1] % 16 == 0 else 2
    st_op = warp.StMatrix8x8x16bOp(transpose, num_matrices=n_mat)
    # Float16 here only serves to get the right source layout for the atom.
    atom = cute.make_copy_atom(st_op, cutlass.Float16)
    return cute.make_tiled_copy_C_atom(atom, tiled_mma)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def get_smem_store_epi(
    tiled_mma: cute.TiledMma,
    epi_tile: cute.Shape,
    sC: Optional[cute.Tensor],
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor, cute.Tensor]:
    """Set up the epilogue register->smem store path over ``epi_tile`` subtiles.

    :return: ``(copy_fn, thr_copy, tRS_sC, tRS_rC)``. ``copy_fn`` (and
        ``tRS_sC``) are None when ``sC`` is None (layout-only use);
        ``tRS_rC`` is a fresh register tensor in the accumulator dtype.
    """
    # Float16 fallback only affects atom selection when there is no smem tensor.
    dtype = sC.element_type if const_expr(sC is not None) else cutlass.Float16
    tiled_copy_C_atom = epilog_smem_copy_atom(tiled_mma, epi_tile)
    copy_atom = get_smem_store_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_S(copy_atom, tiled_copy_C_atom)
    thr_copy = tiled_copy.get_slice(tidx)
    tRS_sC = None
    if const_expr(sC is not None):
        if const_expr(not position_independent):
            tRS_sC = thr_copy.partition_D(sC)
        else:
            tRS_sC = partition_D_position_independent(thr_copy, sC)
    sC_shape = sC.shape[:2] if sC is not None else epi_tile
    # (R2S, R2S_M, R2S_N, PIPE_C)
    tRS_rC_shape = thr_copy.partition_S(cute.make_identity_tensor(sC_shape)).shape
    tRS_rC = cute.make_rmem_tensor(tRS_rC_shape, tiled_mma.op.acc_dtype)

    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
        # dst_idx selects the epilogue pipeline stage.
        cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], **new_kwargs)

    return copy_fn if const_expr(sC is not None) else None, thr_copy, tRS_sC, tRS_rC
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def get_smem_store_A(
    tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up the register->smem store path for operand A.

    Transposition is inferred from the MMA's A-operand major mode.

    :return: ``(copy_fn, thr_copy, tRS_sA)``; ``copy_fn(src, dst_idx)``
        converts and copies into pipeline stage ``dst_idx``.
    """
    dtype = sA.element_type
    # MN-major A needs the transposed stmatrix variant.
    transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
    copy_atom = get_smem_store_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tRS_sA = thr_copy.partition_D(sA)
    else:
        tRS_sA = partition_D_position_independent(thr_copy, sA)

    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
        cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs)

    return copy_fn, thr_copy, tRS_sA
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def get_smem_load_A(
    tiled_mma: cute.TiledMma,
    sA: cute.Tensor,
    tidx: Int32,
    arch: int,
    with_dst_tensor: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up the smem->register load path for operand A.

    :return: ``(copy_fn, thr_copy, tSR_sA)``. With ``with_dst_tensor=False``
        ``copy_fn(src_idx)`` allocates and returns a register tensor shaped by
        the MMA A-partitioning; otherwise ``copy_fn(src_idx, dst)`` loads into
        the caller-provided tensor.
    """
    dtype = sA.element_type
    # MN-major A needs the transposed ldmatrix variant.
    transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
    copy_atom = get_smem_load_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tSR_sA = thr_copy.partition_S(sA)
    else:
        tSR_sA = partition_S_position_independent(thr_copy, sA)
    tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])

    def copy_fn(src_idx: Int32, **new_kwargs):
        return load_s2r_retile(
            tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
        )

    def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs):
        return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs)

    return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
@dsl_user_op
def cpasync_reduce_bulk_add_f32(
    smem_ptr: cute.Pointer,
    gmem_ptr: cute.Pointer,
    store_bytes: int | Int32,
    *,
    loc=None,
    ip=None,
):
    """Issue a bulk async smem->gmem f32 add-reduction.

    Emits ``cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32`` over
    ``store_bytes`` bytes starting at ``smem_ptr``/``gmem_ptr``. The copy is
    asynchronous; the caller is responsible for bulk-group commit/wait to
    observe completion.
    """
    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST
    llvm.inline_asm(
        None,
        [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()],
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
        "l,r,r",
        # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()],
        # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;",
        # "l,r,r,l",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
@dsl_user_op
def get_tma_desc_addr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer:
    """
    Get the address of the TMA descriptor embedded in a TMA Copy Atom.

    Extracts the constant memory address of the TMA descriptor for use with
    custom PTX instructions.

    :param tma_atom: TMA Copy Atom from make_tiled_tma_atom
    :return: Pointer to TMA descriptor in constant memory (128-byte aligned)

    Example:
        >>> desc_ptr = get_tma_desc_addr(tma_atom)
    """
    # Materialize an exec atom so the descriptor address can be queried from IR.
    exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip)
    tma_desc_ptr_type = ir.Type.parse(
        "!cute.ptr<!cute_nvgpu.tma_descriptor_tiled, generic, align<128>>"
    )
    return _cute_nvgpu_ir.get_tma_desc_addr(tma_desc_ptr_type, exec_atom, loc=loc, ip=ip)
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
@dsl_user_op
def tma_gather4_load(
    tma_desc_ptr: cute.Pointer,
    dst_smem_ptr: cute.Pointer,
    mbarrier_ptr: cute.Pointer,
    col_idx: Int32,
    row_indices: Sequence[Int32],
    *,
    num_cta: int = 1,
    multicast_mask=None,
    loc=None,
    ip=None,
) -> None:
    """
    Perform TMA gather4 load from global memory to shared memory.

    Issues PTX instruction:
        cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes
            [dstMem], [tensorMap, {col_idx, row0, row1, row2, row3}], [smem_bar];

    This loads 4 rows (specified by row_indices) from a 2D tensor at the given
    column index into shared memory, using the TMA descriptor.

    :param tma_desc_ptr: Pointer to TMA descriptor in constant memory (128-byte aligned)
    :type tma_desc_ptr: Pointer
    :param dst_smem_ptr: Destination address in shared memory
    :type dst_smem_ptr: Pointer
    :param mbarrier_ptr: Pointer to mbarrier in shared memory for completion tracking
    :type mbarrier_ptr: Pointer
    :param col_idx: Column index
    :type col_idx: Int32
    :param row_indices: Sequence of exactly 4 row indices
    :type row_indices: Sequence[Int32]
    :param num_cta: Number of CTAs participating (default: 1)
    :type num_cta: int
    :param multicast_mask: Optional multicast mask (currently rejected; see below)
    :type multicast_mask: Int16

    Requirements:
        - row_indices must contain exactly 4 elements
        - Compute capability >= SM_100 (Blackwell)
        - TMA descriptor must be properly initialized for 2D tensor

    Example:
        >>> from cutlass.cute.nvgpu import cpasync
        >>> from cutlass.cute import core
        >>>
        >>> # Create TMA descriptor
        >>> tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(...)
        >>> tma_desc_ptr = get_tma_desc_addr(tma_atom)
        >>>
        >>> # Compute indices (typically from kernel logic)
        >>> col_idx = core.get(...) or 5  # Int32 value
        >>> row_indices = [core.get(...) for _ in range(4)]  # 4 Int32 values
        >>>
        >>> # Gather 4 rows at computed column
        >>> tma_gather4_load(
        ...     tma_desc_ptr=tma_desc_ptr,
        ...     dst_smem_ptr=smem_ptr,
        ...     mbarrier_ptr=barrier_ptr,
        ...     col_idx=col_idx,
        ...     row_indices=row_indices
        ... )
    """
    if len(row_indices) != 4:
        raise ValueError(f"gather4 requires exactly 4 row indices, got {len(row_indices)}")
    col_val = Int32(col_idx).ir_value()
    row_vals = [Int32(row_idx).ir_value() for row_idx in row_indices]
    # Convert pointers to integer addresses
    desc_addr = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value()
    dst_addr = dst_smem_ptr.toint(loc=loc, ip=ip).ir_value()
    mbar_addr = mbarrier_ptr.toint(loc=loc, ip=ip)
    if num_cta > 1:
        # Executed by both CTAs. Set peer bit to 0 so that the
        # transaction bytes will update CTA0's barrier.
        mbar_addr = mbar_addr & Sm100MmaPeerBitMask
    mbar_addr = mbar_addr.ir_value()
    # Handle multicast_mask - may already be ir.Value or Python int
    multicast_mask_val = None
    if multicast_mask is not None:
        multicast_mask_val = Int16(multicast_mask).ir_value()
    # NOTE(review): multicast is intentionally rejected for now — any non-None
    # mask trips this assert.
    assert multicast_mask_val is None, "multicast is not supported yet"
    # Emit inline PTX for TMA gather4
    # PTX: cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes
    #      [dstMem], [tensorMap, {col, row0, row1, row2, row3}], [smem_bar];
    ptx = (
        f"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::{num_cta} "
        "[$0], [$1, {$2, $3, $4, $5, $6}], [$7];"
    )

    llvm.inline_asm(
        None,
        [
            dst_addr,
            desc_addr,
            col_val,
            row_vals[0],
            row_vals[1],
            row_vals[2],
            row_vals[3],
            mbar_addr,
        ],
        ptx,
        "r,l,r,r,r,r,r,r",  # constraints: register, long, 6x register
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def cpasync_bulk_get_copy_fn(
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    """Build a bulk cp.async G2S copy function over pre-grouped tensors.

    With ``single_stage=False`` the last mode of each tensor is treated as a
    stage index and the returned ``copy_bulk(src_idx, dst_idx, tma_bar_ptr)``
    selects it; otherwise ``copy_bulk_single_stage(tma_bar_ptr)`` copies the
    whole tensor. Only one elected thread issues the copy.
    """
    group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0))
    group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0))
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    src = cute.group_modes(src_tensor, 0, group_rank_src)
    dst = cute.group_modes(dst_tensor, 0, group_rank_dst)

    def copy_bulk(src_idx, dst_idx, tma_bar_ptr: cute.Pointer, **new_kwargs):
        atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type)
        # A single elected thread issues the bulk copy for the CTA.
        with cute.arch.elect_one():
            cute.copy(
                atom,
                src[None, src_idx],
                dst[None, dst_idx],
                mbar_ptr=tma_bar_ptr,
                **new_kwargs,
                **kwargs,
            )

    def copy_bulk_single_stage(tma_bar_ptr: cute.Pointer, **new_kwargs):
        atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type)
        with cute.arch.elect_one():
            cute.copy(atom, src, dst, mbar_ptr=tma_bar_ptr, **new_kwargs, **kwargs)

    return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
@dsl_user_op
def tma_get_copy_fn(
    atom: cute.CopyAtom,
    cta_coord: cute.Coord,
    cta_layout: cute.Layout,
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    filter_zeros: bool = False,
    single_stage: bool = False,
    *,
    loc=None,
    ip=None,
    **kwargs,
) -> Callable:
    """Partition smem/gmem tensors for TMA and return a stage-indexed copy function.

    Copy direction (G2S vs S2G) is inferred from which of the two tensors lives
    in smem. NOTE(review): despite the ``Callable`` annotation this returns a
    ``(copy_fn, s, g)`` triple — the partitioned smem and gmem tensors are
    returned alongside the copy function.
    """
    src_is_smem = const_expr(
        isinstance(src_tensor.iterator, cute.Pointer)
        and src_tensor.memspace == cute.AddressSpace.smem
    )
    smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
    # Without single_stage, the last mode of each tensor is the stage index.
    group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))
    group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0))
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    s, g = cpasync.tma_partition(
        atom,
        cta_coord,
        cta_layout,
        cute.group_modes(smem_tensor, 0, group_rank_smem),
        cute.group_modes(gmem_tensor, 0, group_rank_gmem),
        loc=loc,
        ip=ip,
    )
    if const_expr(filter_zeros):
        # Drop zero-stride modes so the copy only walks real data.
        s = cute.filter_zeros(s)
        g = cute.filter_zeros(g)
    src, dst = (s, g) if src_is_smem else (g, s)

    @dsl_user_op
    def copy_tma(src_idx, dst_idx, *, loc=None, ip=None, **new_kwargs):
        cute.copy(
            atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs, loc=loc, ip=ip
        )

    @dsl_user_op
    def copy_tma_single_stage(*, loc=None, ip=None, **new_kwargs):
        cute.copy(atom, src, dst, **new_kwargs, **kwargs, loc=loc, ip=ip)

    return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
    """Wrap ``copy`` so the destination stage and barrier come from the producer state."""

    def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
        barrier = pipeline.producer_get_barrier(producer_state)
        copy(
            src_idx=src_idx,
            dst_idx=producer_state.index,
            tma_bar_ptr=barrier,
            **new_kwargs,
        )

    return copy_fn
|
| 788 |
+
|
| 789 |
+
|
| 790 |
+
@cute.jit
def gather_m_get_copy_fn(
    thr_copy_A: cute.ThrCopy,
    mA: cute.Tensor,  # (whatever, K)
    sA: cute.Tensor,  # (tile_M, tile_K, STAGE)
    gsAIdx: cute.Tensor,  # (tile_M), either gmem or smem
    limit_m: Int32,
    limit_k: Int32,
) -> Callable:
    """Build a gmem->smem copy function that gathers rows of A by index.

    Row indices are read once from ``gsAIdx`` and cached in registers
    (out-of-bounds rows fall back to row 0); the returned
    ``copy_fn(src_idx, dst_idx, pred=False)`` then copies K-tile ``src_idx``
    of the gathered rows into smem stage ``dst_idx``, optionally predicating
    the K direction against ``limit_k``.
    """
    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
    tAsA = thr_copy_A.partition_D(sA)
    # k-major
    assert tAsA.shape[2] == 1
    tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)

    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
    if const_expr(not is_even_m_smem):
        limit_m = min(limit_m, tile_shape_mk[0])
    elems_per_load = cute.size(tAsA.shape[0][0])
    cA = cute.make_identity_tensor(tile_shape_mk)
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
    # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
    # This is so that when we do the comparison, t0AcA is known at compile time.
    limit_m = limit_m - tAcA[0][0]
    limit_k = limit_k - tAcA[0][1]
    # Read and cache indices for A
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
    tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
    m_idx = cute.make_rmem_tensor(rows_per_thread, Int32)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        row_idx = tAcA[0, m, 0][0]
        if tApA_m[m]:
            m_idx[m] = gsAIdx[row_idx]
        else:
            m_idx[m] = 0  # It's ok to load row 0 in the case of OOB

    # Split the K mode of mA into tiles of tile_K so src_idx picks a K tile.
    mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1]))

    def copy_fn(src_idx, dst_idx, pred: bool = False):
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        mA_cur = mA_k[None, (None, src_idx)]
        for m in cutlass.range_constexpr(tAcA.shape[1]):
            # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
            # ((elems_per_load), thread_per_row)
            # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA
            # So we append 1s to the last dimension and then do tiled_divide, then slice.
            mA_row = cute.tiled_divide(
                cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
            )[None, None, 0]
            if const_expr(is_even_m_smem) or tApA_m[m]:
                # There's only 1 load per row
                assert cute.size(tAcA.shape, mode=[2]) == 1
                ki = tAcA[0, 0, 0][1] // elems_per_load
                cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k)

    return copy_fn
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
@cute.jit
def gather_k_get_copy_fn(
    thr_copy_A: cute.ThrCopy,
    mA: cute.Tensor,  # (tile_M, whatever)
    sA: cute.Tensor,  # (tile_M, tile_K, STAGE)
    gsAIdx: cute.Tensor,  # (tile_K, RestK), either gmem or smem
    limit_m: Int32,
    limit_k: Int32,
) -> Callable:
    """Build a gmem->smem copy path that gathers columns (K) of A by index.

    The index source may live in gmem or smem, which selects the prefetch
    variant that is returned. NOTE(review): despite the ``Callable``
    annotation this returns a ``(copy_fn, prefetch_fn)`` pair:
    ``prefetch_fn`` reads the column indices for one K tile into registers,
    and ``copy_fn(src_idx, dst_idx, (k_idx, tApA_k), pred=False)`` performs
    the gathered copy into smem stage ``dst_idx``.
    """
    gAIdx, sAIdx = None, None
    if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem):
        gAIdx = gsAIdx
    else:
        assert gsAIdx.memspace == cute.AddressSpace.smem
        sAIdx = gsAIdx
    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
    # (atom_v, CPY_M, 1, STAGE)
    tAsA = thr_copy_A.partition_D(sA)
    # m-major
    tAsA = cute.group_modes(tAsA, 0, 3)

    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
    if const_expr(not is_even_m_smem):
        limit_m = min(limit_m, tile_shape_mk[0])
    elems_per_load = cute.size(tAsA.shape[0][0])
    cA = cute.make_identity_tensor(tile_shape_mk)
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
    # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
    # This is so that when we do the comparison, t0AcA is known at compile time.
    limit_m = limit_m - tAcA[0][0]
    limit_k = limit_k - tAcA[0][1]
    # Read and cache indices for A
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
    tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
    threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
    # This is very convoluted but idk a better way
    # for tile_M=128, flat_divide gives (8, 16, K),
    # then logical_divide gives ((8, 1), (8, 2), K).
    tidx = thr_copy_A.thr_idx
    tAmA = cute.logical_divide(
        cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
    )[None, (tidx % threads_per_col, None), None]  # ((8, 1), 2, K)

    def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]:
        # Prefetch mAIdx early, even before smem is free
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        gAIdx_cur = gAIdx[None, src_idx]
        k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
        for k in cutlass.range(cols_per_thread):
            col_idx = tAcA[0, 0, k][1]
            if const_expr(not pred):
                k_idx[k] = gAIdx_cur[col_idx]
            else:
                if tApA_k[k]:
                    k_idx[k] = gAIdx_cur[col_idx]
                else:
                    # Sentinel for OOB columns; the copy is predicated off anyway.
                    k_idx[k] = -1
        return k_idx, tApA_k

    def prefetch_from_smem_fn(
        a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False
    ) -> Tuple[cute.Tensor, cute.Tensor]:
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        # Wait until the index stage has been produced before reading it.
        a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
        sAIdx_cur = sAIdx[None, dst_idx]
        k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
        for k in cutlass.range(cols_per_thread):
            col_idx = tAcA[0, 0, k][1]
            k_idx[k] = sAIdx_cur[col_idx]
        # Ensure the whole warp has read the indices before releasing the stage.
        cute.arch.sync_warp()
        with cute.arch.elect_one():
            a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
        return k_idx, tApA_k

    def copy_fn(
        src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False
    ):
        k_idx, tApA_k = k_idx_tApA_k
        tApA_k_pred = None
        if const_expr(pred):
            tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2)  # (1, cols_per_thread)
        for k in cutlass.range_constexpr(tAcA.shape[2]):
            # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
            for m in cutlass.range_constexpr(tAcA.shape[1]):
                if tApA_m[m]:
                    cute.copy(
                        thr_copy_A,
                        tAmA[None, m, k_idx[k]],
                        tAsA[(None, m, k), dst_idx],
                        pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k],
                    )

    return copy_fn, prefetch_from_gmem_fn if const_expr(
        gAIdx is not None
    ) else prefetch_from_smem_fn
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
@cute.jit
def gather_m_get_tma_copy_fn(
    tma_atom: cute.CopyAtom,
    mA: cute.Tensor,  # (whatever, K)
    sA: cute.Tensor,  # ((4, 32), (64, 1), STAGE)
    sAIdx: cute.Tensor,  # (tile_M),
    warp_idx: Int32,
    num_warps: int,
    num_cta: int = 1,
) -> Callable:
    """Build a copy function that gathers rows of ``mA`` into smem via TMA.

    Row indices for the current tile are read from smem (``sAIdx``) into
    registers once up front; the returned ``copy_fn`` then issues one
    ``tma_gather4_load`` per group of 4 rows owned by this warp.

    Returns:
        ``copy_fn(src_idx, dst_idx, tma_bar_ptr)`` loading K-tile ``src_idx``
        of the gathered rows into smem stage ``dst_idx``, with completion
        signaled on ``tma_bar_ptr``.
    """
    tile_M = cute.size(sAIdx, mode=[0])
    # K extent of one smem stage, derived from the stage's total element count.
    tile_K = cute.size(sA[None, None, 0]) // tile_M
    assert tile_M % 4 == 0  # rows are gathered 4 at a time
    # cta_group = 1 if tma_atom.op.cta_group == CtaGroup.ONE else 2
    cta_group = num_cta  # Somehow all tma_atom has CtaGroup.ONE inside the kernel

    # Each warp slot reads 4 consecutive Int32 row indices (128 bits) per copy.
    copy_AIdx_s2r = cute.make_tiled_copy_tv(
        cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128),
        cute.make_layout(num_warps),  # thr_layout
        cute.make_layout(4),  # val_layout
    )
    warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx)
    tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx)
    # ((4, 1), 8, (64, 1), STAGE)
    tSR_sA = warp_copy_AIdx_s2r.partition_S(sA)
    # Load this warp's row indices smem -> registers once; reused for every K-tile.
    tSR_rAIdx = load_s2r(tSR_sAIdx)
    tma_desc_ptr = get_tma_desc_addr(tma_atom)
    tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)

    def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer):
        # Starting K column of this source tile.
        col_idx = tile_K * src_idx
        for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
            row_indices = [tSR_rAIdx[v, m] for v in range(4)]
            smem_ptr = tSR_sA[None, m, None, dst_idx].iterator
            # A single thread issues the TMA instruction for the whole group.
            with cute.arch.elect_one():
                tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices)

    return copy_fn
|
build/torch-cuda/quack/cute_dsl_utils.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Tuple, get_origin
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from dataclasses import dataclass, fields
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from triton.tools.disasm import extract
|
| 11 |
+
except ImportError:
|
| 12 |
+
extract = None
|
| 13 |
+
|
| 14 |
+
import cutlass
|
| 15 |
+
import cutlass.cute as cute
|
| 16 |
+
from cutlass import Int32, Int64, Float16, BFloat16, Float32
|
| 17 |
+
from cutlass.base_dsl.typing import JitArgument
|
| 18 |
+
from cutlass.base_dsl.tvm_ffi_builder import spec
|
| 19 |
+
from cutlass.cutlass_dsl import NumericMeta
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
|
| 26 |
+
cute_compile_og = cute.compile
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Patch TVM-FFI converter to handle Constexpr type annotations as compile-time constants.
|
| 30 |
+
# Fields annotated with cutlass.Constexpr[T] are emitted as ConstNone (not runtime args).
|
| 31 |
+
# At call time, pass None for these fields; the compile-time value is baked in.
|
| 32 |
+
import cutlass.cute._tvm_ffi_args_spec_converter as _converter_module # noqa
|
| 33 |
+
|
| 34 |
+
_original_convert_single_arg = _converter_module._convert_single_arg
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _patched_convert_single_arg(arg, arg_name, arg_type, ctx):
    """Wrapper around the original TVM-FFI single-argument converter.

    Adds two behaviors on top of ``_original_convert_single_arg``:

    1. Arguments annotated ``cutlass.Constexpr[...]`` become ``spec.ConstNone``
       entries: they are compile-time constants, not runtime arguments.
    2. A NamedTuple value whose annotation lacks ``_fields`` (e.g. annotated as
       plain ``tuple``) is re-converted using the NamedTuple's own class so the
       converter sees its per-field type hints.
    """
    annotated_constexpr = arg_type is not None and get_origin(arg_type) is cutlass.Constexpr
    if annotated_constexpr:
        return spec.ConstNone(arg_name)

    value_is_namedtuple = isinstance(arg, tuple) and hasattr(type(arg), "_fields")
    annotation_has_fields = arg_type is not None and hasattr(arg_type, "_fields")
    if value_is_namedtuple and not annotation_has_fields:
        # Redirect with the NamedTuple's own type so its field hints are honored.
        return _original_convert_single_arg(arg, arg_name, type(arg), ctx)

    return _original_convert_single_arg(arg, arg_name, arg_type, ctx)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
_converter_module._convert_single_arg = _patched_convert_single_arg
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Mapping from torch dtypes to the corresponding cutlass numeric types.
torch2cute_dtype_map = {
    torch.float16: Float16,
    torch.bfloat16: BFloat16,
    torch.float32: Float32,
    torch.int32: Int32,
    torch.int64: Int64,
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@lru_cache
def get_max_active_clusters(cluster_size):
    """Memoized wrapper around ``HardwareInfo.get_max_active_clusters``.

    Cached because the hardware query is device-constant for a given
    ``cluster_size``.
    """
    hw_info = cutlass.utils.HardwareInfo()
    return hw_info.get_max_active_clusters(cluster_size=cluster_size)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@lru_cache
def get_device_capacity(device: "torch.device | None" = None) -> Tuple[int, int]:
    """Memoized (major, minor) CUDA compute capability of ``device``
    (current device when None)."""
    return torch.cuda.get_device_capability(device)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _partition_fields(obj):
    """Split dataclass fields into (constexpr_dict, non_constexpr_dict) by type."""
    static_fields, dynamic_fields = {}, {}
    # Single pass, preserving dataclass field declaration order in both dicts.
    for f in fields(obj):
        value = getattr(obj, f.name)
        target = static_fields if isinstance(value, StaticTypes) else dynamic_fields
        target[f.name] = value
    return static_fields, dynamic_fields
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _new_from_mlir_values(self, values):
    """Rebuild a dataclass instance from a flat list of MLIR values.

    Constexpr (static) fields are carried over from ``self`` unchanged; each
    non-constexpr field consumes the number of values recorded in
    ``self._values_pos`` by the matching extract/get-types call, in the same
    field order.
    """
    constexpr_fields, non_constexpr_fields = _partition_fields(self)
    for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
        non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
        values = values[n_items:]  # advance past the values this field consumed
    return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _namedtuple_new_from_mlir_values(self, values):
    """Generic __new_from_mlir_values__ for NamedTuples.

    Applied to NamedTuple classes via the ``@mlir_namedtuple`` decorator.

    Fields that are None or Constexpr (StaticTypes) are preserved from ``self`` (the compile-time
    template). Only non-static fields consume MLIR values. Multi-value fields (e.g. cute.Tensor)
    consume the correct number of values via ``cutlass.new_from_mlir_values``.

    Constexpr fields (annotated ``cutlass.Constexpr[T]``) are baked into the compiled kernel via
    a converter patch (see above). At call time, pass None for these fields.
    """
    from cutlass.base_dsl.typing import get_mlir_types

    values = list(values)
    new_fields = []
    for field_val in self:
        if field_val is None or isinstance(field_val, StaticTypes):
            # Static/None field: reuse the compile-time template's value as-is.
            new_fields.append(field_val)
        else:
            # Ask the DSL how many MLIR values this field occupies, then consume them.
            n_items = len(get_mlir_types(field_val))
            new_fields.append(cutlass.new_from_mlir_values(field_val, values[:n_items]))
            values = values[n_items:]
    return self.__class__(*new_fields)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def mlir_namedtuple(cls):
    """Decorator that adds MLIR value reconstruction to a NamedTuple class.

    Usage::

        @mlir_namedtuple
        class MyArgs(NamedTuple):
            tensor_arg: cute.Tensor
            const_arg: cutlass.Constexpr[int] = 0
    """
    # Install the generic reconstruction hook on the class and return it,
    # so the decorator can be applied directly to NamedTuple definitions.
    setattr(cls, "__new_from_mlir_values__", _namedtuple_new_from_mlir_values)
    return cls
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@dataclass
class ParamsBase:
    """Base class for kernel-parameter dataclasses traversed by the cutlass DSL.

    Only non-constexpr fields contribute MLIR values. ``_values_pos`` records
    how many values each such field produced so that ``_new_from_mlir_values``
    can split the flat value list back per field on reconstruction.
    """

    def __extract_mlir_values__(self):
        # Flatten every non-constexpr field into one list of MLIR values,
        # remembering each field's value count for later reconstruction.
        _, non_constexpr_fields = _partition_fields(self)
        values, self._values_pos = [], []
        for obj in non_constexpr_fields.values():
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    __new_from_mlir_values__ = _new_from_mlir_values
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@dataclass
class ArgumentsBase(JitArgument):
    """Base class for host-side argument dataclasses passed to JIT-compiled kernels.

    Mirrors ``ParamsBase`` at the C-ABI level: constexpr fields are skipped,
    and each remaining field contributes its C pointers and MLIR types.
    """

    def __c_pointers__(self):
        _, non_constexpr_fields = _partition_fields(self)
        c_ptrs = []
        for obj in non_constexpr_fields.values():
            # Fields without __c_pointers__ contribute no runtime pointers
            # (they still get a 0 entry in _values_pos via __get_mlir_types__).
            if hasattr(obj, "__c_pointers__"):
                c_ptrs.extend(obj.__c_pointers__())
        return c_ptrs

    def __get_mlir_types__(self):
        _, non_constexpr_fields = _partition_fields(self)
        types, self._values_pos = [], []
        for obj in non_constexpr_fields.values():
            if hasattr(obj, "__get_mlir_types__"):
                obj_types = obj.__get_mlir_types__()
                types.extend(obj_types)
                self._values_pos.append(len(obj_types))
            else:
                # Plain-value field: consumes zero MLIR values on reconstruction.
                self._values_pos.append(0)
        return types

    __new_from_mlir_values__ = _new_from_mlir_values
|
build/torch-cuda/quack/layout_utils.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
|
| 7 |
+
from cutlass import Int32, const_expr
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def transpose_view(a: cute.Tensor) -> cute.Tensor:
    """Transpose the first two dimensions of a tensor on smem (returns a view)."""
    swapped_shape = (a.shape[1], a.shape[0], *a.shape[2:])
    # Only modes 0 and 1 trade positions; trailing modes keep their order.
    swapped_order = (1, 0, *range(2, cute.rank(a)))
    layout = cute.make_ordered_layout(swapped_shape, order=swapped_order)
    return cute.composition(a, layout)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
    """Return a view of ``a`` keeping only the layout modes listed in ``mode``."""
    sub_layout = cute.select(a.layout, mode)
    return cute.make_tensor(a.iterator, sub_layout)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
    """Insert a broadcast mode of extent ``size`` at position ``dim``.

    The new mode has stride 0, so the data is repeated rather than copied
    (a view; no data movement).
    """
    before_shape, after_shape = a.shape[:dim], a.shape[dim:]
    before_stride, after_stride = a.layout.stride[:dim], a.layout.stride[dim:]
    expanded = cute.make_layout(
        (*before_shape, size, *after_shape),
        stride=(*before_stride, 0, *after_stride),
    )
    return cute.make_tensor(a.iterator, expanded)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@cute.jit
def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
    """In-place cross-lane permutation of 16-bit C registers within each quad.

    Pairs of b16 values are recast to u32, exchanged between the 4 lanes of a
    quad via warp shuffles, then byte-permuted (prmt) back into place.
    NOTE(review): the exact target ordering ("gated" interleaving) is not stated
    here -- inferred from the shuffle/prmt pattern; confirm against call sites.
    """
    assert t.element_type.width == 16
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
    # Process two b16 values per 32-bit register.
    t_u32 = cute.recast_tensor(t, Int32)

    quad_idx = cute.arch.lane_idx() % 4
    lane_03 = quad_idx == 0 or quad_idx == 3
    # prmt byte selectors: lanes 0/3 keep the 16-bit halves in order,
    # lanes 1/2 swap them (0x5410 vs 0x1054, 0x7632 vs 0x3276).
    selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
    selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
    # upper_map = [0, 3, 1, 2]
    # lower_map = [1, 2, 0, 3]
    # upper_idx = upper_map[quad_idx]
    # indexing isn't supported so we have to do arithmetic
    upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
    lower_idx = upper_idx ^ 1

    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4  # shuffles stay within each group of 4 lanes (one quad)
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
        upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
        # Lanes 1/2 start from the swapped register pair.
        upper0 = upper if lane_03 else lower
        lower0 = lower if lane_03 else upper
        upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
        lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
        t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper)
        t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@cute.jit
def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None:
    """Permute and shuffle within 4 threads to change the layout from
    T0 | T1 | T2 | T3
    a b | c d | e f | g h
    to
    T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
    a | b | c | d | e | f | g | h
    This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict.
    """

    assert t.element_type.width == 32
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"

    quad_idx = cute.arch.lane_idx() % 4
    # left_map = [0, 2, 1, 3]
    # right_map = [2, 0, 3, 1]
    # indexing isn't supported so we have to do arithmetic
    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
    right_idx = left_idx ^ 0b10

    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4  # shuffles stay within each group of 4 lanes (one quad)
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
        for r in cutlass.range(2, unroll_full=True):
            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
            # a b | c d | e f | g h -> a b | c d | f e | h g
            left0 = left if quad_idx < 2 else right
            right0 = right if quad_idx < 2 else left
            # a b | c d | f e | h g -> a b | f d | c e | h g
            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
            # a b | f d | c e | h g -> a e | f b | c g | h d
            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
            # a e | f b | c g | h d -> a e | b f | c g | d h
            t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0
            t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0
        # Finally swap the two middle registers of each group of 4.
        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@cute.jit
def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None:
    """Permute and shuffle within 4 threads to change the layout from
    T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
    a | b | c | d | e | f | g | h
    to
    T0 | T1 | T2 | T3
    a b | c d | e f | g h
    This is so that we can use LDSM (instead of LDS.64) to load C registers without bank conflict.
    """

    assert t.element_type.width == 32
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"

    quad_idx = cute.arch.lane_idx() % 4
    # left_map = [0, 2, 1, 3]
    # right_map = [1, 3, 0, 2]
    # indexing isn't supported so we have to do arithmetic
    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
    right_idx = left_idx ^ 0b01

    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4  # shuffles stay within each group of 4 lanes (one quad)
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    # This is just the inverse of permute_Cregs_b32_for_stsm
    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
        for r in cutlass.range(2, unroll_full=True):
            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
            # a e | b f | c g | d h -> a e | f b | c g | h d
            left0 = left if quad_idx % 2 == 0 else right
            right0 = right if quad_idx % 2 == 0 else left
            # a e | f b | c g | h d -> a b | f d | c e | h g
            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
            # a b | f d | c e | h g -> a b | c d | f e | h g
            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
            # a b | c d | f e | h g -> a b | c d | e f | g h
            t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0
            t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@cute.jit
def concat_layout(*layouts: cute.Layout) -> cute.Layout:
    """Concatenate layouts into one layout with one mode per input layout."""
    shapes = tuple(layout.shape for layout in layouts)
    strides = tuple(layout.stride for layout in layouts)
    return cute.make_layout(shapes, stride=strides)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout:
    """
    For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
    For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).

    With ``transpose=True`` the resulting M and N modes are swapped.
    """
    # Work on a col-major layout of the same shape, then compose at the end so
    # the original strides are preserved.
    acc_layout_col_major = cute.make_layout(acc_layout.shape)
    shape = (
        (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
        (
            acc_layout_col_major.shape[0][0],
            *acc_layout_col_major.shape[0][2:],
            acc_layout_col_major.shape[2],
        ),  # MMA_N
        *acc_layout_col_major.shape[3:],
    )
    stride = (
        (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
        (
            acc_layout_col_major.stride[0][0],
            *acc_layout_col_major.stride[0][2:],
            acc_layout_col_major.stride[2],
        ),  # MMA_N
        *acc_layout_col_major.stride[3:],
    )
    if const_expr(transpose):
        shape = (shape[1], shape[0], *shape[2:])
        stride = (stride[1], stride[0], *stride[2:])
    acc_layout_mn = cute.make_layout(shape, stride=stride)
    return cute.composition(acc_layout, acc_layout_mn)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
    """View an accumulator fragment in (M, N) form (see convert_layout_acc_mn)."""
    mn_layout = convert_layout_acc_mn(acc.layout, transpose=transpose)
    return cute.make_tensor(acc.iterator, mn_layout)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def reshape_acc_to_mn(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
    """View an accumulator fragment in (M, N) form.

    Alias of ``make_acc_tensor_mn_view`` kept for backward compatibility: the
    two functions were byte-identical duplicates, so this one now delegates to
    keep a single implementation.
    """
    return make_acc_tensor_mn_view(acc, transpose=transpose)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@cute.jit
def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
    # For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
    # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
    # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
    # If N / 8 is odd, we'll convert to ((2, 2, 1), MMA_M, (N / 8, MMA_N)).
    # TODO: Sm90 FP8
    if const_expr(cute.rank(acc_layout.shape[0]) == 3):  # Sm90
        # Split the V mode in half when possible so two N/16 halves form the K mode.
        div = 2 if const_expr(acc_layout.shape[0][2] % 2 == 0) else 1
        l = cute.logical_divide(
            acc_layout, ((None, None, div), None, None)
        )  # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
        rA_mma_view = cute.make_layout(
            (
                (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]),
                l.shape[1],
                (l.shape[0][2][1], l.shape[2]),
            ),
            stride=(
                (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]),
                l.stride[1],
                (l.stride[0][2][1], l.stride[2]),
            ),
        )
    else:  # Sm80
        # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
        l = cute.logical_divide(acc_layout, (None, None, 2))
        rA_mma_view = cute.make_layout(
            (
                (l.shape[0], l.shape[2][0]),
                l.shape[1],
                l.shape[2][1],
            ),
            stride=(
                (l.stride[0], l.stride[2][0]),
                l.stride[1],
                l.stride[2][1],
            ),
        )
    return rA_mma_view
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor:
    """View accumulator registers as an A-operand fragment for a back-to-back GEMM."""
    frgA_layout = convert_layout_acc_frgA(acc.layout)
    return cute.make_tensor(acc.iterator, frgA_layout)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def convert_layout_zero_stride(
    input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
) -> cute.Tensor | cute.Layout:
    """Regroup modes into ((nonzero-stride modes...), (zero-stride modes...)),
    where zero/nonzero is judged by the corresponding mode of ``ref_layout``.

    Accepts a tensor or a layout and returns the same kind as was passed in.
    """
    layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input
    # Group the modes with non-zero stride in the ref_layout together,
    # and the modes with zero stride together
    layout_flat = cute.flatten(layout)
    ref_layout_flat = cute.flatten(ref_layout)
    nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0]
    zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0]
    # There's an edge case when all modes are zero stride
    new_shape = (
        tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,),
        tuple(layout_flat[i].shape for i in zero_modes),
    )
    new_stride = (
        tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,),
        tuple(layout_flat[i].stride for i in zero_modes),
    )
    out_layout = cute.make_layout(new_shape, stride=new_stride)
    if const_expr(isinstance(input, cute.Tensor)):
        return cute.make_tensor(input.iterator, out_layout)
    else:
        return out_layout
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def mma_partition_C_vec(
    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Partition a (length, stage) vector in smem as if it were an MMA C operand.

    The vector is broadcast (stride 0) along the other tile dimension up to
    ``expand_shape``, partitioned with ``thr_mma.partition_C``, and the
    broadcast mode is then collapsed so each thread sees only the vector
    entries matching its C-fragment rows (``is_colvec``) or columns.
    """
    assert cute.rank(sVec) == 2
    assert sVec.stride[0] == 1  # vector must be contiguous along its length
    stage = sVec.shape[1]
    shape = (
        (sVec.shape[0], expand_shape, stage)
        if const_expr(is_colvec)
        else (expand_shape, sVec.shape[0], stage)
    )
    # The expanded dimension gets stride 0: a free broadcast, no data movement.
    stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
    sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
    tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma))
    # Drop the broadcast mode (index 0 is representative since its stride is 0).
    return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def mma_partition_A_vec(
    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Partition a (length, stage) vector in smem as if it were an MMA A operand.

    Same construction as ``mma_partition_C_vec`` but partitioned with
    ``thr_mma.partition_A`` instead of ``partition_C``.
    """
    assert cute.rank(sVec) == 2
    assert sVec.stride[0] == 1  # vector must be contiguous along its length
    stage = sVec.shape[1]
    shape = (
        (sVec.shape[0], expand_shape, stage)
        if const_expr(is_colvec)
        else (expand_shape, sVec.shape[0], stage)
    )
    # The expanded dimension gets stride 0: a free broadcast, no data movement.
    stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
    sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
    tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
    # Drop the broadcast mode (index 0 is representative since its stride is 0).
    return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
build/torch-cuda/quack/sm90_utils.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Type, Union, Optional
|
| 4 |
+
|
| 5 |
+
import cutlass
|
| 6 |
+
import cutlass.cute as cute
|
| 7 |
+
import cutlass.utils.hopper_helpers as sm90_utils_og
|
| 8 |
+
from cutlass.cute.nvgpu import warpgroup
|
| 9 |
+
from cutlass.cutlass_dsl import Numeric, dsl_user_op
|
| 10 |
+
from cutlass import Float32, Int32, Boolean, const_expr
|
| 11 |
+
from cutlass.utils import LayoutEnum
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dsl_user_op
def make_smem_layout(
    dtype: Type[Numeric],
    layout: LayoutEnum,
    tile: cute.Tile,
    stage: Optional[int] = None,
    major_mode_size: Optional[int] = None,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """Build a (possibly multi-stage) swizzled smem layout for ``tile``.

    The swizzle atom is chosen by the Hopper helpers from ``layout``/``dtype``
    and the major-mode extent (defaulting to the tile's own major extent),
    then tiled up to the full tile shape. When ``stage`` is given, a trailing
    stage mode is appended (varying slowest).
    """
    shape = cute.product_each(cute.shape(tile, loc=loc, ip=ip), loc=loc, ip=ip)
    if const_expr(major_mode_size is None):
        # Major extent: mode 1 for n-major, mode 0 otherwise.
        major_mode_size = shape[1] if layout.is_n_major_c() else shape[0]
    smem_layout_atom = warpgroup.make_smem_layout_atom(
        sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size),
        dtype,
    )
    # m-major tiles fill mode 1 first; the stage mode (2) always comes last.
    order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2)
    smem_layout_staged = cute.tile_to_shape(
        smem_layout_atom,
        cute.append(shape, stage) if const_expr(stage is not None) else shape,
        order=order if const_expr(stage is not None) else order[:2],
    )
    return smem_layout_staged
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# For compatibility with blackwell_helpers.py
|
| 42 |
+
make_smem_layout_epi = make_smem_layout
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dsl_user_op
def partition_for_epilogue(
    cT: cute.Tensor,
    epi_tile: cute.Tile,
    tiled_copy: cute.TiledCopy,
    tidx: Int32,
    reference_src: bool,  # do register tensors reference the src or dst layout of the tiled copy
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Subdivide ``cT`` into epilogue tiles and partition for thread ``tidx``.

    Partitions as the copy source when ``reference_src`` is True, otherwise as
    the destination.
    """
    thr_copy = tiled_copy.get_slice(tidx)
    cT_epi = cute.flat_divide(cT, epi_tile)
    # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
    if const_expr(reference_src):
        return thr_copy.partition_S(cT_epi, loc=loc, ip=ip)
    else:
        return thr_copy.partition_D(cT_epi, loc=loc, ip=ip)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: cutlass.Constexpr[bool] = False,
    wg_wait: cutlass.Constexpr[int] = 0,
    # A_in_regs: cutlass.Constexpr[bool] = False,
    swap_AB: cutlass.Constexpr[bool] = False,
) -> None:
    """Issue warpgroup MMAs over the K mode of ``tCrA``/``tCrB`` into ``acc``.

    ``zero_init=True`` clears the accumulator on the first K iteration;
    a negative ``wg_wait`` skips the wait so the caller can overlap work with
    the in-flight MMAs.
    """
    if const_expr(swap_AB):
        gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False)
    else:
        # Order prior register writes before the asynchronous MMAs.
        warpgroup.fence()
        # We make a new mma_atom since we'll be modifying its attribute (accumulate).
        # Otherwise the compiler complains "operand #0 does not dominate this use"
        mma_atom = cute.make_mma_atom(tiled_mma.op)
        mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init)
        for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
            cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
            # Only the first K iteration may zero-initialize; the rest accumulate.
            mma_atom.set(warpgroup.Field.ACCUMULATE, True)
        warpgroup.commit_group()
        if const_expr(wg_wait >= 0):
            warpgroup.wait_group(wg_wait)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def gemm_zero_init(
    tiled_mma: cute.TiledMma,
    shape: cute.Shape,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    wg_wait: int = -1,
    swap_AB: bool = False,
) -> cute.Tensor:
    """Allocate a fresh Float32 accumulator and run :func:`gemm` with ``zero_init=True``.

    :param shape: (M, N) logical shape used to size the accumulator; reversed under ``swap_AB``
    :param A_idx: optional stage index selecting ``tCrA[..., A_idx]`` along mode 3
    :param B_idx: optional stage index selecting ``tCrB[..., B_idx]`` along mode 3
    :param wg_wait: forwarded to :func:`gemm`; -1 skips the warpgroup wait
    :return: the newly created accumulator tensor
    """
    if const_expr(swap_AB):
        # Swap operands, stage indices, and the M/N extents together.
        return gemm_zero_init(
            tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False
        )
    else:
        acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C(shape), Float32)
        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
        gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait)
        return acc
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def gemm_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: Boolean,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    wg_wait: int = -1,
    swap_AB: bool = False,
) -> None:
    """Run :func:`gemm` into an existing accumulator, optionally selecting one
    pipeline stage of each operand.

    :param acc: accumulator written in place
    :param zero_init: runtime/Boolean flag forwarded to :func:`gemm`
    :param A_idx: optional stage index selecting ``tCrA[..., A_idx]`` along mode 3
    :param B_idx: optional stage index selecting ``tCrB[..., B_idx]`` along mode 3
    :param wg_wait: forwarded to :func:`gemm`; -1 skips the warpgroup wait
    """
    if const_expr(swap_AB):
        # Swap operands and their stage indices together.
        gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False)
    else:
        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
        gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def partition_fragment_ABC(
    thr_mma: cute.ThrMma,
    shape_mnk: cute.Shape,
    sA: Optional[cute.Tensor],
    sB: Optional[cute.Tensor],
    swap_AB: bool = False,
):
    """Create the accumulator and A/B fragments for one thread of ``thr_mma``.

    If the MMA atom sources operand A from registers (RMEM), the A-side fragment is
    allocated from a partition *shape* instead of partitioning a smem tensor.

    Under ``swap_AB`` the operand roles are deliberately exchanged: ``sB`` is
    partitioned through the MMA's A slot and ``sA`` through its B slot, and the
    accumulator is sized (N, M). This is intentional, not a typo.

    :param shape_mnk: logical (M, N, K) tile shape
    :param sA: smem tensor for A (may be None when A comes from registers and not swapped)
    :param sB: smem tensor for B (may be None when B comes from registers and swapped)
    :return: (acc, tCrA, tCrB)
    """
    is_rs = thr_mma.op.a_src == warpgroup.OperandSource.RMEM
    if const_expr(not swap_AB):
        acc = cute.make_rmem_tensor(thr_mma.partition_shape_C(shape_mnk[:2]), Float32)
        if const_expr(not is_rs):
            assert sA is not None
            tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA))
        else:
            # A lives in registers: build the fragment from the partition shape only.
            tCrA = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[0], shape_mnk[2])))
        assert sB is not None
        tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB))
    else:
        # Swapped: accumulator is (N, M).
        acc = cute.make_rmem_tensor(
            thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32
        )
        if const_expr(not is_rs):
            assert sB is not None
            # B takes the MMA's A slot when swapped.
            tCrB = thr_mma.make_fragment_A(thr_mma.partition_A(sB))
        else:  # B in rmem
            tCrB = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[1], shape_mnk[2])))
        assert sA is not None
        # A takes the MMA's B slot when swapped.
        tCrA = thr_mma.make_fragment_B(thr_mma.partition_B(sA))
    return acc, tCrA, tCrB
|
build/torch-cuda/seqlen_info.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
from cutlass import Int32, const_expr
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
This consolidates all the info related to sequence length. This is so that we can do all
|
| 10 |
+
the gmem reads once at the beginning of each tile, rather than having to repeat these reads
|
| 11 |
+
to compute various things like n_block_min, n_block_max, etc.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass(frozen=True)
class SeqlenInfo:
    """Per-batch sequence start offset and effective length.

    Resolved once per tile so the gmem reads for ``cu_seqlens``/``seqused``
    are not repeated.
    """

    offset: cutlass.Int32
    seqlen: cutlass.Int32

    @staticmethod
    def create(
        batch_idx: cutlass.Int32,
        seqlen_static: cutlass.Int32,
        cu_seqlens: Optional[cute.Tensor] = None,
        seqused: Optional[cute.Tensor] = None,
    ):
        """Resolve offset and length for ``batch_idx``.

        Priority for the length: ``seqused`` entry, then the ``cu_seqlens``
        difference, then the static ``seqlen_static``.
        """
        if const_expr(cu_seqlens is not None):
            start = cu_seqlens[batch_idx]
        else:
            start = 0
        if const_expr(seqused is not None):
            length = seqused[batch_idx]
        else:
            length = (
                cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
                if const_expr(cu_seqlens is not None)
                else seqlen_static
            )
        return SeqlenInfo(start, length)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass(frozen=True)
class SeqlenInfoQK:
    """Consolidated Q/K sequence-length info for one batch element.

    All gmem reads of ``cu_seqlens``/``seqused`` happen once in :meth:`create`,
    so block-min/max computations downstream need no further loads.
    """

    # Start of this batch's Q / K sequence within the packed (varlen) layout.
    offset_q: cutlass.Int32
    offset_k: cutlass.Int32
    # Offsets rounded down to a tile boundary (see `create` for the rounding rule).
    padded_offset_q: cutlass.Int32
    padded_offset_k: cutlass.Int32
    seqlen_q: cutlass.Int32
    seqlen_k: cutlass.Int32
    # Compile-time flags recording which optional tensors were supplied.
    has_cu_seqlens_q: cutlass.Constexpr[bool]
    has_cu_seqlens_k: cutlass.Constexpr[bool]
    has_seqused_q: cutlass.Constexpr[bool]
    has_seqused_k: cutlass.Constexpr[bool]

    @staticmethod
    def create(
        batch_idx: cutlass.Int32,
        seqlen_q_static: cutlass.Int32,
        seqlen_k_static: cutlass.Int32,
        mCuSeqlensQ: Optional[cute.Tensor] = None,
        mCuSeqlensK: Optional[cute.Tensor] = None,
        mSeqUsedQ: Optional[cute.Tensor] = None,
        mSeqUsedK: Optional[cute.Tensor] = None,
        tile_m: cutlass.Constexpr[cutlass.Int32] = 128,
        tile_n: cutlass.Constexpr[cutlass.Int32] = 128,
    ):
        """Resolve all offsets and lengths for ``batch_idx``.

        Lengths prefer ``mSeqUsed*`` over the ``mCuSeqlens*`` difference over the
        static value. Padded offsets round ``offset + batch_idx * tile`` down to a
        tile multiple (0 when not varlen).
        """
        offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]
        offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx]
        padded_offset_q = (
            0
            if const_expr(mCuSeqlensQ is None)
            else (offset_q + batch_idx * tile_m) // tile_m * tile_m
        )
        padded_offset_k = (
            0
            if const_expr(mCuSeqlensK is None)
            else (offset_k + batch_idx * tile_n) // tile_n * tile_n
        )
        if const_expr(mSeqUsedQ is not None):
            seqlen_q = mSeqUsedQ[batch_idx]
        else:
            seqlen_q = (
                seqlen_q_static
                if const_expr(mCuSeqlensQ is None)
                else mCuSeqlensQ[batch_idx + 1] - offset_q
            )
        if const_expr(mSeqUsedK is not None):
            seqlen_k = mSeqUsedK[batch_idx]
        else:
            seqlen_k = (
                seqlen_k_static
                if const_expr(mCuSeqlensK is None)
                else mCuSeqlensK[batch_idx + 1] - offset_k
            )
        has_cu_seqlens_q: bool = mCuSeqlensQ is not None
        has_cu_seqlens_k: bool = mCuSeqlensK is not None
        has_seqused_q: bool = mSeqUsedQ is not None
        has_seqused_k: bool = mSeqUsedK is not None
        return SeqlenInfoQK(
            offset_q,
            offset_k,
            padded_offset_q,
            padded_offset_k,
            seqlen_q,
            seqlen_k,
            has_cu_seqlens_q,
            has_cu_seqlens_k,
            has_seqused_q,
            has_seqused_k,
        )

    def offset_batch_Q(
        self,
        mQ: cute.Tensor,
        batch_idx: Int32,
        dim: int,
        padded: cutlass.Constexpr[bool] = False,
    ) -> cute.Tensor:
        """Seqlen must be the first dimension of mQ.

        Non-varlen: slice out batch ``batch_idx`` at mode ``dim``.
        Varlen: shift the tensor's domain by (padded) offset_q instead of slicing.
        """
        if const_expr(not self.has_cu_seqlens_q):
            idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)
            return mQ[idx]
        else:
            offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q
            # Mode 0 may itself be nested; place the offset in its second sub-mode then.
            offset = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, offset_q)
            idx = (offset,) + (0,) * (cute.rank(mQ) - 1)
            return cute.domain_offset(idx, mQ)

    def offset_batch_K(
        self,
        mK: cute.Tensor,
        batch_idx: Int32,
        dim: int,
        padded: cutlass.Constexpr[bool] = False,
    ) -> cute.Tensor:
        """Seqlen must be the first dimension of mK.

        Same scheme as :meth:`offset_batch_Q`, using the K-side offsets.
        """
        if const_expr(not self.has_cu_seqlens_k):
            idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)
            return mK[idx]
        else:
            offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k
            idx = (offset_k,) + (0,) * (cute.rank(mK) - 1)
            return cute.domain_offset(idx, mK)
|
build/torch-cuda/softmax.py
ADDED
|
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import operator
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
|
| 8 |
+
import cutlass
|
| 9 |
+
import cutlass.cute as cute
|
| 10 |
+
from cutlass import Float32
|
| 11 |
+
|
| 12 |
+
from .quack import layout_utils
|
| 13 |
+
from . import utils
|
| 14 |
+
from .quack.cute_dsl_utils import ParamsBase
|
| 15 |
+
from .seqlen_info import SeqlenInfoQK
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class Softmax(ParamsBase):
    """Online-softmax state (running row max and row sum) plus the update kernels."""

    # log2(e) * softmax scale: exponentials are computed with exp2.
    scale_log2: Float32
    num_rows: cutlass.Constexpr[int]
    # Running per-row maximum and (rescaled) sum of exponentials.
    row_max: cute.Tensor
    row_sum: cute.Tensor
    arch: cutlass.Constexpr[int] = 80
    softmax_scale: Float32 | None = None

    @staticmethod
    def create(
        scale_log2: Float32,
        num_rows: cutlass.Constexpr[int],
        arch: cutlass.Constexpr[int] = 80,
        softmax_scale: Float32 | None = None,
    ):
        """Allocate register-memory row_max/row_sum and build the state object."""
        row_max = cute.make_rmem_tensor(num_rows, Float32)
        row_sum = cute.make_rmem_tensor(num_rows, Float32)
        return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale)

    def reset(self) -> None:
        """Reset the running statistics (max to -inf, sum to 0)."""
        self.row_max.fill(-Float32.inf)
        self.row_sum.fill(0.0)

    def _compute_row_max(
        self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None
    ) -> Float32:
        # Max-reduce one row, optionally seeded with the previous running max.
        return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch)

    def _compute_row_sum(
        self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None
    ) -> Float32:
        # Add-reduce one row of exponentials, optionally seeded with a carried sum.
        return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch)

    @cute.jit
    def online_softmax(
        self,
        acc_S: cute.Tensor,
        is_first: cutlass.Constexpr[bool] = False,
        check_inf: cutlass.Constexpr[bool] = True,
    ) -> cute.Tensor:
        """Apply online softmax and return the row_scale to rescale O.

        Overwrites ``acc_S`` with exp2(scale*S - scale*rowmax) and updates
        row_max/row_sum in place.

        :param acc_S: acc_S tensor
        :type acc_S: cute.Tensor
        :param is_first: is first n_block
        :type is_first: cutlass.Constexpr
        :param check_inf: replace a -inf row max by 0 before exponentiating
        """
        # Change acc_S to M,N layout view.
        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
        row_scale = cute.make_fragment_like(self.row_max, Float32)

        # Local bindings of the state fields used in the hot loop.
        row_max = self.row_max
        row_sum = self.row_sum
        scale_log2 = self.scale_log2
        arch = self.arch

        # Each iteration processes one row of acc_S
        for r in cutlass.range(cute.size(row_max), unroll_full=True):
            acc_S_row = acc_S_mn[r, None].load()  # (n_block_size)

            row_max_cur = utils.fmax_reduce(
                acc_S_row,
                init_val=row_max[r] if cutlass.const_expr(not is_first) else None,
                arch=arch,
            )

            # Reduce across the 4 threads that share a row.
            row_max_cur = cute.arch.warp_reduction_max(row_max_cur, threads_in_group=4)
            # Update row_max before changing row_max_cur to safe value for -inf
            row_max_prev = row_max[r]
            row_max[r] = row_max_cur

            if cutlass.const_expr(check_inf):
                # Fully-masked row: use 0 so exp2 produces finite values.
                row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur

            if cutlass.const_expr(is_first):
                row_max_cur_scaled = row_max_cur * scale_log2
                acc_S_row_exp = cute.math.exp2(
                    acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True
                )
                acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch)
                row_scale[r] = 1.0
            else:
                row_max_cur_scaled = row_max_cur * scale_log2
                acc_S_row_exp = cute.math.exp2(
                    acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True
                )
                # row_scale[r] = cute.math.exp2(row_max_prev * self.scale_log2 - row_max_cur_scaled)
                row_scale[r] = cute.math.exp2(
                    (row_max_prev - row_max_cur) * scale_log2, fastmath=True
                )
                # Carry the previous sum, rescaled to the new max.
                acc_S_row_sum = utils.fadd_reduce(
                    acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch
                )

            row_sum[r] = acc_S_row_sum
            acc_S_mn[r, None].store(acc_S_row_exp)

        return row_scale

    @cute.jit
    def finalize(
        self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None
    ) -> cute.Tensor:
        """Finalize the online softmax by computing the scale and logsumexp.

        On return, ``row_sum`` holds the LSE (natural log) and the returned
        ``row_scale`` holds final_scale / row_sum for rescaling O.
        """
        if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)):
            assert cute.size(sink_val) == cute.size(self.row_sum)
        row_sum = self.row_sum
        row_max = self.row_max
        scale_log2 = self.scale_log2

        # quad reduction for row_sum as we didn't do it during each iteration of online softmax
        row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4))
        row_scale = cute.make_fragment_like(row_max, Float32)

        for r in cutlass.range(cute.size(row_sum), unroll_full=True):
            if cutlass.const_expr(sink_val is not None):
                # Attention-sink term: add exp(sink - max*scale) to the denominator.
                sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r]
                LOG2_E = math.log2(math.e)
                row_sum[r] += cute.math.exp2(
                    sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True
                )

            # if row_sum is zero or nan, set acc_O_mn_row to 1.0
            acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]
            row_scale[r] = (
                cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0)
            ) * final_scale
            row_sum_cur = row_sum[r]
            LN2 = math.log(2.0)
            # row_sum becomes the logsumexp in natural-log units (-inf for empty rows).
            row_sum[r] = (
                (row_max[r] * scale_log2 + cute.math.log2(row_sum_cur, fastmath=True)) * LN2
                if not acc_O_mn_row_is_zero_or_nan
                else -Float32.inf
            )
        return row_scale

    @cute.jit
    def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None:
        """Scale each row of acc_O by the given scale tensor.

        :param acc_O: input tensor
        :type acc_O: cute.Tensor
        :param row_scale: row_scale tensor
        :type row_scale: cute.Tensor
        """
        acc_O_mn = layout_utils.reshape_acc_to_mn(acc_O)
        assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0])
        for r in cutlass.range(cute.size(row_scale), unroll_full=True):
            acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r])
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@dataclass
class SoftmaxSm100(Softmax):
    """Softmax variant specialized for SM100: single row per thread, optional
    skip of the O rescale when the max moved less than ``rescale_threshold``."""

    # If > 0, skip rescaling when the (log2-scaled) max change is within this bound.
    rescale_threshold: cutlass.Constexpr[float] = 0.0

    @staticmethod
    def create(
        scale_log2: Float32,
        rescale_threshold: cutlass.Constexpr[float] = 0.0,
        softmax_scale: Float32 | None = None,
    ):
        """Build SM100 softmax state (num_rows fixed to 1, arch fixed to 100)."""
        num_rows = 1
        arch = 100
        row_max = cute.make_rmem_tensor(num_rows, Float32)
        row_sum = cute.make_rmem_tensor(num_rows, Float32)
        return SoftmaxSm100(
            scale_log2,
            num_rows,
            row_max,
            row_sum,
            arch,
            softmax_scale,
            rescale_threshold=rescale_threshold,
        )

    @cute.jit
    def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]:
        """Update the running max from one row of scores.

        :return: (row_max_safe, acc_scale) where row_max_safe has -inf replaced by 0
            and acc_scale is the factor to apply to the carried accumulator
            (0.0 on the first block, possibly 1.0 when within rescale_threshold).
        """
        if cutlass.const_expr(is_first):
            row_max_new = self._compute_row_max(acc_S_row)
            row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
            acc_scale = 0.0
        else:
            row_max_old = self.row_max[0]
            row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old)
            row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
            acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2
            acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
            if cutlass.const_expr(self.rescale_threshold > 0.0):
                # Max barely moved: keep the old max and skip the rescale.
                if acc_scale_ >= -self.rescale_threshold:
                    row_max_new = row_max_old
                    row_max_safe = row_max_old
                    acc_scale = 1.0
        self.row_max[0] = row_max_new
        return row_max_safe, acc_scale

    def update_row_sum(
        self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False
    ) -> None:
        """Fold one row of exponentials into row_sum, rescaling the carried sum by row_scale."""
        init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None
        # self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale)
        self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val)
        # tmp = self._compute_row_sum(acc_S_row_exp)
        # self.row_sum[0] = self.row_sum[0] * row_scale + tmp

    @cute.jit
    def scale_subtract_rowmax(
        self,
        acc_S_row: cute.Tensor,
        row_max: Float32,
    ):
        """In place: acc_S_row[i] = acc_S_row[i] * scale_log2 - row_max * scale_log2,
        two elements at a time via packed f32x2 FMA."""
        assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
        row_max_scaled = row_max * self.scale_log2
        for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True):
            acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2(
                (acc_S_row[i], acc_S_row[i + 1]),
                (self.scale_log2, self.scale_log2),
                (-row_max_scaled, -row_max_scaled),
            )

    @cute.jit
    def apply_exp2_convert(
        self,
        acc_S_row: cute.Tensor,
        acc_S_row_converted: cute.Tensor,
        ex2_emu_freq: cutlass.Constexpr[int] = 0,
        ex2_emu_res: cutlass.Constexpr[int] = 4,
        ex2_emu_start_frg: cutlass.Constexpr[int] = 0,
    ):
        """Exponentiate acc_S_row in 32-element fragments and store the converted
        result into acc_S_row_converted.

        When ``ex2_emu_freq`` > 0, a subset of element pairs (selected by the
        freq/res/start_frg constants) uses ``utils.ex2_emulation_2`` instead of
        the hardware exp2 — presumably to interleave SFU and emulated math;
        exact motivation not visible here.
        """
        assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
        frg_tile = 32
        assert frg_tile % 2 == 0
        frg_cnt = cute.size(acc_S_row) // frg_tile
        assert cute.size(acc_S_row) % frg_tile == 0
        acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))
        acc_S_row_converted_frg = cute.logical_divide(
            acc_S_row_converted, cute.make_layout(frg_tile)
        )
        for j in cutlass.range_constexpr(frg_cnt):
            for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
                # acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
                # acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
                if cutlass.const_expr(ex2_emu_freq == 0):
                    acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
                    acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
                else:
                    if cutlass.const_expr(
                        k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res
                        or j >= frg_cnt - 1
                        or j < ex2_emu_start_frg
                    ):
                        acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
                        acc_S_row_frg[k + 1, j] = cute.math.exp2(
                            acc_S_row_frg[k + 1, j], fastmath=True
                        )
                    else:
                        # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j])
                        acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2(
                            acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]
                        )
            # Convert this fragment to the destination element type.
            acc_S_row_converted_frg[None, j].store(
                acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
            )

    @cute.jit
    def scale_apply_exp2_convert(
        self,
        acc_S_row: cute.Tensor,
        row_max: Float32,
        acc_S_row_converted: cute.Tensor,
    ):
        """Fused scale-subtract-rowmax, exp2, and dtype conversion of one score row."""
        assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
        minus_row_max_scaled = -row_max * self.scale_log2
        # Packed FMA pass: acc_S_row = acc_S_row * scale_log2 - row_max * scale_log2.
        for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2):
            acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2(
                (acc_S_row[i], acc_S_row[i + 1]),
                (self.scale_log2, self.scale_log2),
                (minus_row_max_scaled, minus_row_max_scaled),
            )

        # for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2):
        #     acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2(
        #         (acc_S_row[i], acc_S_row[i + 1]),
        #         (self.scale_log2, self.scale_log2),
        #         (minus_row_max_scaled, minus_row_max_scaled),
        #     )
        #     acc_S_row[i] = cute.math.exp2(acc_S_row[i], fastmath=True)
        #     acc_S_row[i + 1] = cute.math.exp2(acc_S_row[i + 1], fastmath=True)

        frg_tile = 32
        assert frg_tile % 2 == 0
        frg_cnt = cute.size(acc_S_row) // frg_tile
        assert cute.size(acc_S_row) % frg_tile == 0
        acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))
        acc_S_row_converted_frg = cute.logical_divide(
            acc_S_row_converted, cute.make_layout(frg_tile)
        )
        for j in cutlass.range_constexpr(frg_cnt):
            for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
                # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = (
                #     cute.arch.fma_packed_f32x2(
                #         (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]),
                #         (self.scale_log2, self.scale_log2),
                #         (minus_row_max_scaled, minus_row_max_scaled),
                #     )
                # )
                # acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
                # acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
                acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
                acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
            acc_S_row_converted_frg[None, j].store(
                acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
            )
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@cute.jit
def floor_if_packed(
    q_idx,
    qhead_per_kvhead: cutlass.Constexpr[int],
) -> cute.Tensor:
    """Convert q_idx to packed format for Pack-GQA.

    Identity when one query head maps to one KV head; otherwise the packed
    index is floored to the logical query position.
    """
    if cutlass.const_expr(qhead_per_kvhead != 1):
        return q_idx // qhead_per_kvhead
    return q_idx
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
@cute.jit
def apply_score_mod_inner(
    score_tensor,
    index_tensor,
    score_mod: cutlass.Constexpr,
    batch_idx,
    head_idx,
    softmax_scale,
    vec_size: cutlass.Constexpr,
    qk_acc_dtype: cutlass.Constexpr,
    aux_tensors,
    fastdiv_mods,
    seqlen_info: SeqlenInfoQK,
    constant_q_idx: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    transpose_indices: cutlass.Constexpr[bool] = False,
):
    """Shared implementation for applying score modification.

    Fix: the no-bounds-check branch previously tested ``constant_q_idx`` with a
    plain ``if``; it is now wrapped in ``cutlass.const_expr`` like every other
    Constexpr test in this function, so the branch is uniformly resolved at
    trace time.

    Args:
        score_tensor: The scores to modify (acc_S for flash_fwd, tSrS_t2r for sm100)
        index_tensor: Index positions (tScS for flash_fwd, tScS_t2r for sm100)
        score_mod: The score modification function to apply
        batch_idx: Batch index
        head_idx: Head index
        softmax_scale: Scale to apply
        vec_size: Vector size for processing elements
        qk_acc_dtype: Data type for accumulator
        aux_tensors: Optional aux_tensors for FlexAttention
        fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping
        seqlen_info: Sequence length info
        constant_q_idx: If provided, use this constant for all q_idx values
                        If None, compute q_idx per-element
        qhead_per_kvhead_packgqa: Pack-GQA replication factor. Divide q_idx by this
                        when greater than 1 so score mods see logical heads.
        transpose_indices: If True, swap q_idx/kv_idx in index_tensor (for bwd kernel where S is transposed)
    """
    # Index positions in the index_tensor tuple
    # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx
    # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx
    if cutlass.const_expr(transpose_indices):
        q_idx_pos = cutlass.const_expr(1)
        kv_idx_pos = cutlass.const_expr(0)
    else:
        q_idx_pos = cutlass.const_expr(0)
        kv_idx_pos = cutlass.const_expr(1)

    n_vals = cutlass.const_expr(cute.size(score_tensor.shape))
    score_vec = cute.make_rmem_tensor(vec_size, qk_acc_dtype)
    kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

    # SSA values for batch (constant across all elements)
    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,))

    # Handle q_idx based on whether it's constant
    q_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

    # For Pack-GQA with non-constant q_idx, we need per-element head indices
    # since a thread may process multiple query head indices
    if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
        head_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

    for i in cutlass.range(0, n_vals, vec_size, unroll_full=True):
        for j in cutlass.range(vec_size, unroll_full=True):
            score_vec[j] = score_tensor[i + j] * softmax_scale

            # Extract head offset from packed q_idx for Pack-GQA
            if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
                q_idx_packed = index_tensor[i + j][q_idx_pos]
                # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead)
                q_idx_logical = q_idx_packed // qhead_per_kvhead
                head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead
                head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset

            # If we will do loads we mod, in order to not read OOB
            if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None):
                if cutlass.const_expr(constant_q_idx is None):
                    seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
                    q_idx_floored = floor_if_packed(
                        index_tensor[i + j][q_idx_pos], qhead_per_kvhead
                    )
                    _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod)
                    q_idx_vec[j] = q_idx_wrapped
                else:
                    _, seqlen_k_divmod = fastdiv_mods

                _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod)
                kv_idx_vec[j] = kv_idx_wrapped
            else:
                # No bounds checking - direct indexing
                # (const_expr added for consistency: constant_q_idx is a Constexpr)
                if cutlass.const_expr(constant_q_idx is None):
                    q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead)
                kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos]

        # Convert to SSA for score_mod call
        score_ssa = score_vec.load()
        kv_idx_ssa = kv_idx_vec.load()
        if cutlass.const_expr(constant_q_idx is None):
            q_idx_ssa = q_idx_vec.load()
        else:
            # NB we do not apply Pack-GQA division here, as constant_q_idx is assumed to already be logical
            q_idx_const = constant_q_idx
            q_idx_ssa = utils.scalar_to_ssa(q_idx_const, cutlass.Int32).broadcast_to((vec_size,))

        # Compute head_idx_ssa: per-element for Pack-GQA with non-constant q_idx, constant otherwise
        if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
            head_idx_ssa = head_idx_vec.load()
        else:
            head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,))

        aux_args = []
        if cutlass.const_expr(aux_tensors is not None):
            aux_args = aux_tensors

        post_mod_scores = score_mod(
            score_ssa,
            batch_idx_ssa,
            head_idx_ssa,
            q_idx=q_idx_ssa,
            kv_idx=kv_idx_ssa,
            seqlen_info=seqlen_info,
            aux_tensors=aux_args,
        )

        # Write back modified scores
        score_vec.store(post_mod_scores)
        for j in cutlass.range(vec_size, unroll_full=True):
            score_tensor[i + j] = score_vec[j]
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
@cute.jit
def apply_score_mod_bwd_inner(
    grad_tensor,
    score_tensor,
    index_tensor,
    score_mod_bwd: cutlass.Constexpr,
    batch_idx,
    head_idx,
    softmax_scale,
    vec_size: cutlass.Constexpr,
    qk_acc_dtype: cutlass.Constexpr,
    aux_tensors,
    fastdiv_mods,
    seqlen_info,
    constant_q_idx: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    transpose_indices: cutlass.Constexpr[bool] = False,
):
    """Apply backward score modification (joint graph).

    Processes the flat ``grad_tensor``/``score_tensor`` fragments in chunks of
    ``vec_size``, reconstructs per-element (q_idx, kv_idx, head_idx) coordinates
    from ``index_tensor``, and calls ``score_mod_bwd`` on each SSA vector,
    rewriting ``grad_tensor`` in place with the result.

    Args:
        grad_tensor: in/out: dlogits rewritten in-place with d(scaled_scores)
        score_tensor: pre-mod scores (unscaled QK tile), scaled by softmax_scale internally
        index_tensor: Index positions (same as forward)
        score_mod_bwd: The backward score modification function (joint graph)
        batch_idx: Batch index
        head_idx: Head index
        softmax_scale: Scale to apply to score_tensor
        vec_size: Vector size for processing elements
        qk_acc_dtype: Data type for accumulator
        aux_tensors: Optional aux_tensors for FlexAttention
        fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping
        seqlen_info: Sequence length info
        constant_q_idx: If provided, use this constant for all q_idx values
        qhead_per_kvhead: Pack-GQA replication factor
        transpose_indices: If True, swap q_idx/kv_idx in index_tensor
    """
    # Index positions in the index_tensor tuple
    # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx
    # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx
    if cutlass.const_expr(transpose_indices):
        q_idx_pos = cutlass.const_expr(1)
        kv_idx_pos = cutlass.const_expr(0)
    else:
        q_idx_pos = cutlass.const_expr(0)
        kv_idx_pos = cutlass.const_expr(1)
    n_vals = cutlass.const_expr(cute.size(grad_tensor.shape))
    # Per-chunk register fragments; loaded to SSA vectors before the callback.
    grad_vec = cute.make_fragment(vec_size, qk_acc_dtype)
    score_vec = cute.make_fragment(vec_size, qk_acc_dtype)
    kv_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)
    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,))
    q_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)

    # For Pack-GQA with non-constant q_idx, we need per-element head indices
    # (the packed q index interleaves the query-head offset).
    if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
        head_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)

    for i in cutlass.range(0, n_vals, vec_size, unroll_full=True):
        for j in cutlass.range(vec_size, unroll_full=True):
            grad_vec[j] = grad_tensor[i + j]
            # Scale score so joint graph sees same value as forward score_mod
            score_vec[j] = score_tensor[i + j] * softmax_scale

            if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
                # Packed q index = logical_q * qhead_per_kvhead + head_offset;
                # recover both pieces to build the per-element head index.
                q_idx_packed = index_tensor[i + j][q_idx_pos]
                q_idx_logical = q_idx_packed // qhead_per_kvhead
                head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead
                head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset

            if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None):
                # aux tensors are indexed modulo the sequence lengths, so wrap
                # indices with the precomputed fast-divmod divisors.
                if cutlass.const_expr(constant_q_idx is None):
                    seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
                    q_idx_floored = floor_if_packed(
                        index_tensor[i + j][q_idx_pos], qhead_per_kvhead
                    )
                    _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod)
                    q_idx_vec[j] = q_idx_wrapped
                else:
                    _, seqlen_k_divmod = fastdiv_mods

                _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod)
                kv_idx_vec[j] = kv_idx_wrapped
            else:
                # No bounds checking - direct indexing
                if constant_q_idx is None:
                    q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead)
                kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos]

        # Convert to SSA for the score_mod_bwd call.
        grad_ssa = grad_vec.load()
        score_ssa = score_vec.load()
        kv_idx_ssa = kv_idx_vec.load()

        if cutlass.const_expr(constant_q_idx is None):
            q_idx_ssa = q_idx_vec.load()
        else:
            # NB: constant_q_idx is assumed to already be logical (no Pack-GQA division).
            q_idx_ssa = utils.scalar_to_ssa(constant_q_idx, cutlass.Int32).broadcast_to((vec_size,))

        # head_idx_ssa: per-element for Pack-GQA with non-constant q_idx, constant otherwise.
        if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
            head_idx_ssa = head_idx_vec.load()
        else:
            head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,))

        aux_args = []
        if cutlass.const_expr(aux_tensors is not None):
            aux_args = aux_tensors

        grad_out_ssa = score_mod_bwd(
            grad_ssa,
            score_ssa,
            batch_idx_ssa,
            head_idx_ssa,
            q_idx=q_idx_ssa,
            kv_idx=kv_idx_ssa,
            seqlen_info=seqlen_info,
            aux_tensors=aux_args,
        )

        # Write back modified gradients.
        grad_vec.store(grad_out_ssa)
        for j in cutlass.range(vec_size, unroll_full=True):
            grad_tensor[i + j] = grad_vec[j]
|
build/torch-cuda/testing.py
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from contextlib import nullcontext
|
| 3 |
+
from functools import wraps
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange, repeat
|
| 9 |
+
from torch._guards import active_fake_mode
|
| 10 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class IndexFirstAxis(torch.autograd.Function):
    """Differentiable ``input[indices]`` along the first axis.

    Forward gathers the selected rows; backward scatters the incoming
    gradient into a dense zero tensor of the original first-axis size.
    """

    @staticmethod
    def forward(ctx, input, indices):
        """Return ``input[indices]`` (rows selected along dim 0)."""
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # Flatten trailing dims so a single gather along dim 0 suffices.
        flat = input.reshape(input.shape[0], second_dim)
        gather_idx = indices.unsqueeze(-1).expand(-1, second_dim)
        return torch.gather(flat, 0, gather_idx).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter grad rows back to their original first-axis positions."""
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        flat_grad = grad_output.reshape(grad_output.shape[0], -1)
        grad_input = torch.zeros(
            [ctx.first_axis_dim, flat_grad.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        scatter_idx = indices.unsqueeze(-1).expand(-1, flat_grad.shape[1]).contiguous()
        grad_input.scatter_(0, scatter_idx, flat_grad)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
index_first_axis = IndexFirstAxis.apply
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class IndexPutFirstAxis(torch.autograd.Function):
    """Differentiable ``output[indices] = values`` into a zero tensor.

    Forward scatters ``values`` rows into a zero tensor of first-axis size
    ``first_axis_dim``; backward gathers the gradient rows back out.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        """Place ``values`` at ``indices`` along dim 0 of a new zero tensor."""
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        # new_zeros keeps the device and dtype of ``values``.
        output = values.new_zeros((first_axis_dim, *values.shape[1:]))
        output[indices] = values
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Gather the gradient rows that correspond to the scattered values."""
        (indices,) = ctx.saved_tensors
        return grad_output[indices], None, None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
index_put_first_axis = IndexPutFirstAxis.apply
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def unpad_input(hidden_states, attention_mask, unused_mask=None):
    """Strip padding tokens from a (batch, seqlen, ...) tensor.

    Returns a tuple of (unpadded hidden_states, flat token indices,
    cumulative sequence lengths, max sequence length in batch, and the
    per-batch count of *used* (attention_mask) tokens).
    """
    if unused_mask is not None:
        all_masks = attention_mask + unused_mask
    else:
        all_masks = attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    if active_fake_mode() is not None:
        # torch.nonzero and .item() are not supported in FakeTensorMode,
        # so fall back to the dense "keep everything" index set.
        batch_size, seqlen = attention_mask.shape
        indices = torch.arange(batch_size * seqlen, device=hidden_states.device)
        max_seqlen_in_batch = seqlen
    else:
        indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Prepend 0 so cu_seqlens[i]:cu_seqlens[i+1] spans batch element i.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    flat_hidden = rearrange(hidden_states, "b s ... -> (b s) ...")
    return (
        index_first_axis(flat_hidden, indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def pad_input(hidden_states, indices, batch, seqlen):
    """Inverse of ``unpad_input``: scatter flat rows back to (batch, seqlen, ...)."""
    padded_flat = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(padded_flat, "(b s) ... -> b s ...", b=batch)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
    """Build a (batch_size, max_seqlen) boolean prefix mask of random lengths.

    mode="full" keeps every position; "random" draws lengths from
    [max_seqlen - 20, max_seqlen]; "third" from [max_seqlen // 3, max_seqlen].
    With zero_lengths, some rows are forced to length 0.
    """
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    else:
        # "random" and "third" differ only in the lower bound of the draw.
        low = max_seqlen - 20 if mode == "random" else max_seqlen // 3
        floor = 0 if zero_lengths else 1
        lengths = torch.randint(max(floor, low), max_seqlen + 1, (batch_size, 1), device=device)

    if zero_lengths:
        # Force every fifth row and the last row to be fully padded.
        for i in range(batch_size):
            if i % 5 == 0:
                lengths[i] = 0
        lengths[-1] = 0
    positions = torch.arange(max_seqlen, device=device).unsqueeze(0).expand(batch_size, -1)
    return positions < lengths
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    qv=None,
    kvpacked=False,
    qkvpacked=False,
    query_unused_mask=None,
    key_unused_mask=None,
):
    """Produce unpadded (varlen) and padded views of q/k/v for testing.

    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d_v)
        query_padding_mask: (batch_size, seqlen_q), bool
        key_padding_mask: (batch_size, seqlen_k), bool
        qv: optional extra query tensor, unpadded alongside q
        kvpacked / qkvpacked: select the packed return layout (mutually exclusive)
        query_unused_mask / key_unused_mask: extra tokens kept in the unpadded
            layout but excluded from the "used" sequence-length counts

    Returns (layout depends on packing mode):
        qkvpacked: (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn, dqkv_pad_fn)
        kvpacked:  (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                    max_seqlen_k, q, kv, output_pad_fn, dq_pad_fn, dkv_pad_fn)
        default:   (q_unpad, k_unpad, v_unpad, qv_unpad, cu_seqlens_q, cu_seqlens_k,
                    seqused_q, seqused_k, max_seqlen_q, max_seqlen_k,
                    q, k, v, qv, output_pad_fn, dq_pad_fn, dk_pad_fn)
    The *_pad_fn closures re-pad outputs/gradients back to (batch, seqlen, ...).
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    d_v = v.shape[-1]
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)
    if query_unused_mask is not None or key_unused_mask is not None:
        # Unused-token bookkeeping is only supported in the unpacked layout.
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q, query_padding_mask, query_unused_mask
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
    else:
        # No padding: every sequence has full length seqlen_q.
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        v_unpad, *_ = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        # Packing q/k/v together requires identical masks and head counts.
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            qv_unpad.detach() if qv is not None else None,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            qv.detach() if qv is not None else None,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(None, None),
    sink_token_length=0,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    """Boolean mask, True where attention is *disallowed* by a sliding window.

    window_size is (left, right) around the causal diagonal; ``None`` on a
    side means unbounded there. Keys before ``sink_token_length`` are always
    allowed on the left side. Padding masks shift the diagonal per batch.
    """
    rows = torch.arange(seqlen_q, device=device, dtype=torch.long).unsqueeze(-1)
    cols = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        # Re-base key positions so the first real key is 0; padded columns are
        # pushed to a huge index so every window test masks them out.
        key_leftpad = key_leftpad[:, None, None, None]
        cols = cols.expand(key_leftpad.shape[0], 1, 1, seqlen_k)
        cols = torch.where(cols >= key_leftpad, cols - key_leftpad, 2**32)
    if key_padding_mask is None:
        sk = seqlen_k
    else:
        sk = key_padding_mask.sum(-1)[:, None, None, None]
    if query_padding_mask is None:
        sq = seqlen_q
    else:
        sq = query_padding_mask.sum(-1)[:, None, None, None]
    if window_size[0] is None:
        # Unbounded left window: only the right edge can mask.
        return cols > rows + sk - sq + window_size[1]
    if key_padding_mask is None:
        sk = torch.full_like(cols, seqlen_k)
    if window_size[1] is None:
        beyond_right = cols > sk
    else:
        beyond_right = cols > torch.minimum(rows + sk - sq + window_size[1], sk)
    before_left = torch.logical_and(
        cols < rows + sk - sq - window_size[0], cols >= sink_token_length
    )
    return torch.logical_or(beyond_right, before_left)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def construct_chunk_mask(
    seqlen_q,
    seqlen_k,
    attention_chunk,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    """Boolean mask, True where query and key lie in different chunks.

    Positions are aligned at the causal diagonal (shifted by per-batch
    padding); each query may only attend inside its own chunk of size
    ``attention_chunk``.
    """
    rows = torch.arange(seqlen_q, device=device, dtype=torch.long).unsqueeze(-1)
    cols = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        # Same left-pad re-basing as construct_local_mask.
        key_leftpad = key_leftpad[:, None, None, None]
        cols = cols.expand(key_leftpad.shape[0], 1, 1, seqlen_k)
        cols = torch.where(cols >= key_leftpad, cols - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else key_padding_mask.sum(-1)[:, None, None, None]
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else query_padding_mask.sum(-1)[:, None, None, None]
    )
    if key_padding_mask is None:
        sk = torch.full_like(cols, seqlen_k)
    # Diagonal position of each query among the keys, snapped to chunk start.
    diag = rows + sk - sq
    chunk_start = diag - diag % attention_chunk
    return torch.logical_or(cols < chunk_start, cols >= chunk_start + attention_chunk)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(None, None),
    attention_chunk=0,
    sink_token_length=0,
    learnable_sink: Optional[torch.Tensor] = None,
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    intermediate_dtype=None,
):
    """Reference (eager PyTorch) attention used to validate the fused kernels.

    Arguments:
        q: (batch, seqlen_q, nheads, d); k/v: (batch, seqlen_k, nheads_k, d[_v])
        query_padding_mask / key_padding_mask: (batch, seqlen), bool
        attn_bias: broadcastable to (batch, nheads, seqlen_q, seqlen_k)
        dropout_p, dropout_mask: dropout probability and precomputed keep-mask
        causal: shorthand for window_size = (window_size[0], 0)
        qv / q_descale / k_descale / v_descale: extra query term and FP8-style
            per-(batch, head) descale factors
        window_size, attention_chunk, sink_token_length: locality constraints
        learnable_sink: per-head sink logit added to the softmax denominator
        softcap: tanh soft-capping of the logits when > 0
        upcast: compute in fp32 and cast the result back
        reorder_ops: scale k instead of q (matches kernel op order, changes
            numerics only, not semantics)
        intermediate_dtype: round-trip the post-dropout probabilities through
            this dtype to emulate kernel precision

    Returns:
        (output (batch, seqlen_q, nheads, d_v), attention probabilities) —
        both cast back to the input dtype.
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
        qv = qv.float() if qv is not None else None
    if q_descale is not None:
        q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
        q = (q.float() * q_descale).to(q.dtype)
        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
    if k_descale is not None:
        k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
    if v_descale is not None:
        v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # Expand KV heads for MQA/GQA so einsum sees matching head counts.
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    dv = v.shape[-1]
    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
    else:
        # Scale k instead of q: same math, kernel-matching rounding order.
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
    if qv is not None:
        scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
    if softcap > 0:
        scores = torch.tanh(scores / softcap) * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    local_mask = None
    if window_size[0] is not None or window_size[1] is not None:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            sink_token_length,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
    if attention_chunk > 0:
        chunk_mask = construct_chunk_mask(
            seqlen_q,
            seqlen_k,
            attention_chunk,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
        local_mask = (
            torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask
        )
    if local_mask is not None:
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    if learnable_sink is None:
        attention = torch.softmax(scores, dim=-1).to(v.dtype)
    else:
        # Softmax with an extra per-head "sink" logit in the denominator only.
        scores_fp32 = scores.to(torch.float32)
        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
        learnable_sink = rearrange(learnable_sink, "h -> h 1 1")
        logits_or_sinks_max = torch.maximum(learnable_sink, logits_max)
        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
            learnable_sink - logits_or_sinks_max
        )
        attention = (unnormalized_scores / normalizer).to(v.dtype)
    # Zero out probabilities for padded queries/keys and fully-masked rows
    # (softmax over all -inf would otherwise produce NaNs).
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    if key_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    if local_mask is not None:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    if intermediate_dtype is not None:
        # Emulate the kernel's reduced-precision probability storage.
        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def maybe_fake_tensor_mode(fake: bool = True):
    """Decorator factory: optionally run the wrapped function under FakeTensorMode.

    One way to populate/pre-compile cache is to use torch fake tensor mode,
    which does not allocate actual GPU tensors but retains tensor shape/dtype
    metadata for cute.compile. With ``fake=False`` the function runs normally.
    """

    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            ctx = FakeTensorMode() if fake else nullcontext()
            with ctx:
                return fn(*args, **kwargs)

        return wrapper

    return decorator
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def is_fake_mode() -> bool:
|
| 456 |
+
return active_fake_mode() is not None
|
build/torch-cuda/tile_scheduler.py
ADDED
|
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
from typing import override
|
| 8 |
+
except ImportError: # Python < 3.12
|
| 9 |
+
from typing_extensions import override
|
| 10 |
+
|
| 11 |
+
import cutlass
|
| 12 |
+
from cutlass._mlir import ir
|
| 13 |
+
import cutlass.cute as cute
|
| 14 |
+
from cutlass import Int32, const_expr
|
| 15 |
+
from cutlass.cute import FastDivmodDivisor
|
| 16 |
+
|
| 17 |
+
from .quack.cute_dsl_utils import ParamsBase
|
| 18 |
+
|
| 19 |
+
from . import utils
|
| 20 |
+
from .fast_math import clz
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class WorkTileInfo(cutlass.utils.WorkTileInfo):
    """Altered WorkTileInfo which includes four axes: (block, head, batch, split)"""

    @override
    def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo":
        # Reconstruct a WorkTileInfo from its flattened MLIR values.
        # 5 values total: 4 coordinates (block, head, batch, split) + 1 validity flag.
        assert len(values) == 5
        new_tile_idx = cutlass.new_from_mlir_values(self._tile_idx, values[:-1])
        new_is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [values[-1]])
        return WorkTileInfo(new_tile_idx, new_is_valid_tile)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class TileSchedulerArguments(ParamsBase):
    """Host-side arguments shared by every tile scheduler in this module.

    Describes the logical work space (block, head, batch, split) plus
    problem-size information that the L2-aware schedulers use for swizzling.
    NOTE: field order matters — ParamsBase flattens fields for MLIR
    serialization in declaration order.
    """

    num_block: Int32  # number of m-blocks per (head, batch)
    num_head: Int32
    num_batch: Int32
    num_splits: Int32  # split-KV factor (1 when split-KV is disabled)
    seqlen_k: Int32
    headdim: Int32
    headdim_v: Int32
    total_q: Int32  # total query tokens across the whole batch (varlen case)
    tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
    cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
    mCuSeqlensQ: Optional[cute.Tensor] = None  # cumulative query seqlens (varlen)
    mSeqUsedQ: Optional[cute.Tensor] = None  # per-batch used query seqlens (varlen)
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1  # >1 when PackGQA packs q-heads per kv-head
    element_size: cutlass.Constexpr[int] = 2  # bytes per K/V element (2 = fp16/bf16)
    is_persistent: cutlass.Constexpr[bool] = False
    lpt: cutlass.Constexpr[bool] = False  # longest-processing-time-first block order
    is_split_kv: cutlass.Constexpr[bool] = False
    head_swizzle: cutlass.Constexpr[bool] = False
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class SingleTileScheduler:
    """Simplest scheduler: one CTA per work tile, taken straight from blockIdx.

    Grid is (num_block, num_head * num_splits, num_batch); each CTA processes
    exactly one tile and then reports no more work.
    """

    @dataclass
    class Params(ParamsBase):
        num_block: Int32
        num_head: Int32
        num_batch: Int32
        num_splits: Int32
        num_splits_divmod: FastDivmodDivisor  # fast divmod by num_splits (split-KV decode of blockIdx.y)
        is_split_kv: cutlass.Constexpr[bool] = False
        cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)

        @staticmethod
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileScheduler.Params":
            """Build Params from the generic scheduler arguments."""
            return SingleTileScheduler.Params(
                args.num_block,
                args.num_head,
                args.num_batch,
                args.num_splits,
                FastDivmodDivisor(args.num_splits),
                args.is_split_kv,
                args.cluster_shape_mn,
            )

    def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None):
        self.params = params
        self._blk_coord = blk_coord  # this CTA's (x, y, z) block coordinate
        self._is_first_block = True  # becomes False after advance_to_next_work
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler":
        """Device-side constructor: bind the scheduler to this CTA's block index."""
        # if const_expr(cute.size(params.cluster_shape_mn) == 1):
        #     blk_coord = cute.arch.block_idx()
        # else:
        #     # All CTAs in a cluster must get the same block coordinate
        #     blk_coord = cute.arch.cluster_idx()
        # Temporary set to block_idx until we sort out the best way to handle cluster
        blk_coord = cute.arch.block_idx()
        return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """Launch grid: x = blocks (cluster-padded), y = head*split, z = batch."""
        # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)
        assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
        return (
            cute.round_up(params.num_block, params.cluster_shape_mn[0]),
            params.num_head * params.num_splits,
            params.num_batch,
        )

    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Decode (block, head, batch, split) from the CTA's block coordinate.

        The tile is valid only the first time it is queried (single-tile
        scheduler: there is no second tile).
        """
        block_idx, head_idx, batch_idx = self._blk_coord
        if const_expr(self.params.is_split_kv):
            # blockIdx.y packs (head, split); unpack via fast divmod.
            head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod)
        else:
            split_idx = Int32(0)
        return WorkTileInfo(
            (block_idx, head_idx, batch_idx, split_idx),
            self._is_first_block,
        )

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        # Nothing to prefetch: a CTA only ever has one tile.
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Mark the (only) tile as consumed so the next query reports invalid.
        self._is_first_block = False

    def __extract_mlir_values__(self):
        # Flatten params + block coordinate; remember per-object value counts
        # so __new_from_mlir_values__ can split them back apart.
        values, self._values_pos = [], []
        for obj in [self.params, self._blk_coord]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip([self.params, self._blk_coord], self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class StaticPersistentTileScheduler:
    """Persistent scheduler: one CTA (or cluster) loops over many tiles.

    The grid is capped at the SM count; each CTA starts at its linear block
    index and strides by the grid size until all tiles are consumed.
    Does not support split-KV (split index is always 0).
    """

    @dataclass
    class Params(ParamsBase):
        num_block_cluster_divmod: FastDivmodDivisor  # fast divmod by blocks-per-(head,batch), in cluster units
        num_head_divmod: FastDivmodDivisor
        total_blocks_cluster: Int32  # total work tiles, in cluster units
        cluster_shape_m: cutlass.Constexpr[int] = 1

        @staticmethod
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "StaticPersistentTileScheduler.Params":
            num_block_cluster = cute.ceil_div(args.num_block, cute.size(args.cluster_shape_mn))
            total_blocks_cluster = num_block_cluster * args.num_head * args.num_batch
            return StaticPersistentTileScheduler.Params(
                FastDivmodDivisor(num_block_cluster),
                FastDivmodDivisor(args.num_head),
                total_blocks_cluster,
                cluster_shape_m=args.cluster_shape_mn[0],
            )

    def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx  # linear tile index; advanced by grid stride
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler":
        """Device-side constructor: start at this CTA's (or cluster's) x index."""
        if const_expr(cute.size(params.cluster_shape_m) == 1):
            tile_idx = cute.arch.block_idx()[0]
        else:
            # With clusters, all CTAs of a cluster share one tile index.
            tile_idx = cute.arch.cluster_idx()[0]
        return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """1-D grid sized to min(SM count, total work), cluster-aligned."""
        hardware_info = cutlass.utils.HardwareInfo()
        sm_count = hardware_info.get_device_multiprocessor_count()
        # Grid must be a multiple of cluster_shape_m for CUDA cluster launch.
        max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m
        grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * params.cluster_shape_m)
        return (grid_x, Int32(1), Int32(1))

    # @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Decode linear tile index -> (block, head, batch); split is always 0."""
        hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod)
        batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod)
        is_valid = self._tile_idx < self.params.total_blocks_cluster
        # if cute.arch.thread_idx()[0] == 0:
        #     cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid)
        return WorkTileInfo(
            (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid
        )

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Grid-stride advance: step by the launched grid (or cluster-grid) size.
        if const_expr(self.params.cluster_shape_m == 1):
            self._tile_idx += cute.arch.grid_dim()[0]
        else:
            self._tile_idx += cute.arch.cluster_dim()[0]

    def __extract_mlir_values__(self):
        # Flatten params + tile index for MLIR; record split points for rebuild.
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip(
            [self.params, self._tile_idx],
            self._values_pos,
        ):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class SingleTileLPTScheduler:
    """Single-tile scheduler with L2-aware swizzling and LPT block ordering.

    Tiles are grouped into "sections" of `swizzle` (head, batch) pairs whose
    K/V working set is estimated to fit in L2, so CTAs scheduled close
    together reuse the same K/V. Within a section, blocks are visited in
    reverse (longest-processing-time-first) order.
    """

    @dataclass
    class Params(ParamsBase):
        total_blocks: Int32
        num_splits: Int32
        num_block: Int32
        l2_minor: Int32  # swizzle width: how many (head,batch) pairs share an L2 section
        num_block_divmod: FastDivmodDivisor
        num_head_divmod: FastDivmodDivisor
        l2_minor_divmod: FastDivmodDivisor  # divmod by swizzle (full sections)
        l2_major_divmod: FastDivmodDivisor  # divmod by swizzle * num_block (section size in tiles)
        l2_minor_residual_divmod: FastDivmodDivisor  # divmod by the tail-section width
        num_hb_quotient: Int32  # number of full swizzle sections in head*batch
        is_split_kv: cutlass.Constexpr[bool] = False

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileLPTScheduler.Params":
            """Derive the swizzle width from an L2-capacity estimate."""
            # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size)
            # Bytes of K + V for one kv head (assumes seqlen_k rows of each).
            size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
            size_one_head = size_one_kv_head
            size_l2 = 50 * 1024 * 1024  # 50 MB budget for K & V
            # Swizzle is the size of each "section". Round swizzle to a power of 2
            # Need to be careful about the case where only one head will fit
            # swizzle is how many heads can fit in L2
            # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
            # Seems faster if swizzle if a power of 2
            log2_floor = lambda n: 31 - clz(n)
            swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
            # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
            # If we're in the last section (called residual), we don't want to divide by
            # swizzle. Instead we want to divide by the remainder.
            num_hb_quotient = (args.num_head * args.num_batch) // swizzle
            num_hb_remainder = (args.num_head * args.num_batch) % swizzle
            return SingleTileLPTScheduler.Params(
                total_blocks=args.num_block * args.num_head * args.num_batch,
                num_block=args.num_block,
                l2_minor=Int32(swizzle),
                num_block_divmod=FastDivmodDivisor(args.num_block),
                num_head_divmod=FastDivmodDivisor(args.num_head),
                l2_minor_divmod=FastDivmodDivisor(swizzle),
                l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block),
                l2_minor_residual_divmod=FastDivmodDivisor(
                    max(num_hb_remainder, 1)
                ),  # don't divide by 0
                num_hb_quotient=Int32(num_hb_quotient),
                num_splits=args.num_splits,
                is_split_kv=args.is_split_kv,
            )

    def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx  # linear tile index from blockIdx.x
        self._split_idx = split_idx  # split-KV index from blockIdx.y
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    @cute.jit
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler":
        tile_idx, split_idx, _ = cute.arch.block_idx()
        return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        return (params.total_blocks, params.num_splits, Int32(1))

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Unswizzle the linear tile index into (block, head, batch, split)."""
        params = self.params
        # Implement LPT scheduling coordinate calculation
        bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod)
        # If we're in the last section (called residual), we don't want to divide by
        # swizzle. Instead we want to divide by the remainder.
        block, bidhb_residual = 0, 0
        if bidhb < params.num_hb_quotient:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)
        else:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)
        bidhb_actual = bidhb * params.l2_minor + bidhb_residual
        batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
        # Longest-processing-time-first
        block = params.num_block - 1 - block
        is_valid = self._tile_idx < params.total_blocks
        return WorkTileInfo(
            (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid
        )

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Single tile scheduler - set to invalid tile_idx to indicate no more work
        self._tile_idx = self.params.total_blocks

    def __extract_mlir_values__(self):
        # Flatten params + tile/split indices; record counts for rebuild.
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx, self._split_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return self.__class__(*(tuple(obj_list)), loc=self._loc)
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
class SingleTileLPTBwdScheduler:
    """Backward-pass variant of the LPT single-tile scheduler.

    Same L2-aware section swizzling as SingleTileLPTScheduler, but sized for
    the backward working set (Q/dO per head), cluster-aware in m, and with
    the reverse (shortest/longest-first) ordering gated by `spt`. No split-KV.
    """

    @dataclass
    class Params(ParamsBase):
        total_blocks: Int32
        num_block: Int32  # m-blocks per (head, batch), in cluster units
        l2_minor: Int32  # swizzle width (heads per L2 section)
        num_head_divmod: FastDivmodDivisor
        l2_minor_divmod: FastDivmodDivisor
        l2_major_divmod: FastDivmodDivisor  # divmod by swizzle * num_block
        l2_minor_residual_divmod: FastDivmodDivisor  # divmod by the tail-section width
        num_hb_quotient: Int32  # number of full swizzle sections in head*batch
        cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
        spt: cutlass.Constexpr[bool] = True  # reverse block order (taken from args.lpt)

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileLPTBwdScheduler.Params":
            size_l2 = 50 * 1024 * 1024
            # NOTE(review): uses seqlen_k with headdim + headdim_v as the Q/dO
            # working-set estimate; assumes seqlen_q == seqlen_k — confirm.
            size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
            # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4
            size_one_dqaccum_head = 0  # dQ-accum currently excluded from the estimate
            size_one_head = size_one_qdo_head + size_one_dqaccum_head
            log2_floor = lambda n: 31 - clz(n)
            # Round the section width down to a power of 2 (seems faster).
            swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
            # swizzle = 8
            # If we're in the last section (called residual), we don't want to divide by
            # swizzle. Instead we want to divide by the remainder.
            num_hb_quotient = (args.num_head * args.num_batch) // swizzle
            num_hb_remainder = (args.num_head * args.num_batch) % swizzle
            num_block = cute.ceil_div(args.num_block, args.cluster_shape_mn[0])
            return SingleTileLPTBwdScheduler.Params(
                total_blocks=(num_block * args.cluster_shape_mn[0])
                * args.num_head
                * args.num_batch,
                num_block=num_block,
                l2_minor=Int32(swizzle),
                num_head_divmod=FastDivmodDivisor(args.num_head),
                l2_minor_divmod=FastDivmodDivisor(swizzle),
                l2_major_divmod=FastDivmodDivisor(swizzle * num_block),
                l2_minor_residual_divmod=FastDivmodDivisor(
                    max(num_hb_remainder, 1)
                ),  # don't divide by 0
                num_hb_quotient=Int32(num_hb_quotient),
                cluster_shape_mn=args.cluster_shape_mn,
                spt=args.lpt,
            )

    def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx  # linear tile index from blockIdx.x
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    @cute.jit
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTBwdScheduler":
        tile_idx = cute.arch.block_idx()[0]
        return SingleTileLPTBwdScheduler(params, tile_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        return (params.total_blocks, Int32(1), Int32(1))

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
        """Unswizzle the linear tile index into (block, head, batch, 0)."""
        # Section math operates in cluster units; expand back to CTA blocks below.
        cluster_idx = self._tile_idx // self.params.cluster_shape_mn[0]
        params = self.params
        # Implement LPT scheduling coordinate calculation
        bidhb, l2_mod = divmod(cluster_idx, params.l2_major_divmod)
        # If we're in the last section (called residual), we don't want to divide by
        # swizzle. Instead we want to divide by the remainder.
        block, bidhb_residual = 0, 0
        if bidhb < params.num_hb_quotient:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)
        else:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)
        bidhb_actual = bidhb * params.l2_minor + bidhb_residual
        batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
        if cutlass.const_expr(params.spt):
            # Longest-processing-time-first: visit blocks in reverse order.
            block = params.num_block - 1 - block
        if cutlass.const_expr(params.cluster_shape_mn[0] > 1):
            # Fan the cluster-unit block index back out to per-CTA blocks.
            bidx_in_cluster = cute.arch.block_in_cluster_idx()
            block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0]
        is_valid = self._tile_idx < params.total_blocks
        return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Single tile scheduler - set to invalid tile_idx to indicate no more work
        self._tile_idx = self.params.total_blocks

    def __extract_mlir_values__(self):
        # Flatten params + tile index; record counts for rebuild.
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip([self.params, self._tile_idx], self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return self.__class__(*(tuple(obj_list)), loc=self._loc)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class SingleTileVarlenScheduler:
|
| 502 |
+
@dataclass
|
| 503 |
+
class Params(ParamsBase):
|
| 504 |
+
num_head: Int32
|
| 505 |
+
num_batch: Int32
|
| 506 |
+
total_q: Int32
|
| 507 |
+
num_splits: Int32
|
| 508 |
+
max_kvblock_in_l2: Int32
|
| 509 |
+
tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
|
| 510 |
+
mCuSeqlensQ: Optional[cute.Tensor] = None
|
| 511 |
+
mSeqUsedQ: Optional[cute.Tensor] = None
|
| 512 |
+
qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
|
| 513 |
+
lpt: cutlass.Constexpr[bool] = False
|
| 514 |
+
is_split_kv: cutlass.Constexpr[bool] = False
|
| 515 |
+
head_swizzle: cutlass.Constexpr[bool] = False
|
| 516 |
+
cluster_shape_m: cutlass.Constexpr[int] = 1
|
| 517 |
+
|
| 518 |
+
@staticmethod
|
| 519 |
+
@cute.jit
|
| 520 |
+
def create(
|
| 521 |
+
args: TileSchedulerArguments, *, loc=None, ip=None
|
| 522 |
+
) -> "SingleTileVarlenScheduler.Params":
|
| 523 |
+
size_l2 = 50 * 1024 * 1024 # 50 MB for K & V
|
| 524 |
+
max_kvblock_in_l2 = size_l2 // (
|
| 525 |
+
(args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
|
| 526 |
+
)
|
| 527 |
+
assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
|
| 528 |
+
"At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
|
| 529 |
+
)
|
| 530 |
+
assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
|
| 531 |
+
return SingleTileVarlenScheduler.Params(
|
| 532 |
+
num_head=args.num_head,
|
| 533 |
+
num_batch=args.num_batch,
|
| 534 |
+
total_q=args.total_q,
|
| 535 |
+
num_splits=args.num_splits,
|
| 536 |
+
max_kvblock_in_l2=max_kvblock_in_l2,
|
| 537 |
+
tile_shape_mn=args.tile_shape_mn,
|
| 538 |
+
mCuSeqlensQ=args.mCuSeqlensQ,
|
| 539 |
+
mSeqUsedQ=args.mSeqUsedQ,
|
| 540 |
+
qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa,
|
| 541 |
+
lpt=args.lpt,
|
| 542 |
+
is_split_kv=args.is_split_kv,
|
| 543 |
+
head_swizzle=args.head_swizzle,
|
| 544 |
+
cluster_shape_m=args.cluster_shape_mn[0],
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
|
| 548 |
+
self.params = params
|
| 549 |
+
self._tile_idx = tile_idx
|
| 550 |
+
self._split_idx = split_idx
|
| 551 |
+
self._is_first_block = True
|
| 552 |
+
self._loc = loc
|
| 553 |
+
self._ip = ip
|
| 554 |
+
|
| 555 |
+
@staticmethod
|
| 556 |
+
def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
|
| 557 |
+
return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip)
|
| 558 |
+
|
| 559 |
+
@staticmethod
|
| 560 |
+
def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler":
|
| 561 |
+
tile_idx, split_idx, _ = cute.arch.block_idx()
|
| 562 |
+
return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
|
| 563 |
+
|
| 564 |
+
# called by host
|
| 565 |
+
@staticmethod
|
| 566 |
+
def get_grid_shape(
|
| 567 |
+
params: Params,
|
| 568 |
+
*,
|
| 569 |
+
loc=None,
|
| 570 |
+
ip=None,
|
| 571 |
+
) -> Tuple[Int32, Int32, Int32]:
|
| 572 |
+
total_blocks_max = (
|
| 573 |
+
params.total_q
|
| 574 |
+
+ params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1)
|
| 575 |
+
) // params.tile_shape_mn[0]
|
| 576 |
+
# round down to nearest multiple of cluster since odd excess is always padding
|
| 577 |
+
total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m
|
| 578 |
+
return (total_blocks_max * params.num_head, params.num_splits, Int32(1))
|
| 579 |
+
|
| 580 |
+
@cute.jit
|
| 581 |
+
def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32:
|
| 582 |
+
params = self.params
|
| 583 |
+
batch_idx = lane + bidb_start
|
| 584 |
+
if cutlass.const_expr(params.mSeqUsedQ is not None):
|
| 585 |
+
seqlen = Int32(0)
|
| 586 |
+
if batch_idx < params.num_batch:
|
| 587 |
+
seqlen = params.mSeqUsedQ[batch_idx]
|
| 588 |
+
else:
|
| 589 |
+
assert params.mCuSeqlensQ is not None
|
| 590 |
+
cur_cu_seqlen = Int32(0)
|
| 591 |
+
if batch_idx <= params.num_batch:
|
| 592 |
+
cur_cu_seqlen = params.mCuSeqlensQ[batch_idx]
|
| 593 |
+
next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
|
| 594 |
+
seqlen = next_cu_seqlen - cur_cu_seqlen
|
| 595 |
+
if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1):
|
| 596 |
+
seqlen *= params.qhead_per_kvhead_packgqa
|
| 597 |
+
return (
|
| 598 |
+
cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), params.cluster_shape_m)
|
| 599 |
+
if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1
|
| 600 |
+
else Int32(0)
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
@cute.jit
|
| 604 |
+
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
|
| 605 |
+
params = self.params
|
| 606 |
+
lane_idx = cute.arch.lane_idx()
|
| 607 |
+
num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0)
|
| 608 |
+
num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
|
| 609 |
+
# Total number of blocks for the next 31 batches
|
| 610 |
+
m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1)
|
| 611 |
+
# Same for all lanes
|
| 612 |
+
group_end_tile = m_blocks_in_group * params.num_head
|
| 613 |
+
# if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group)
|
| 614 |
+
block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0)
|
| 615 |
+
next_tile_idx = self._tile_idx // params.cluster_shape_m
|
| 616 |
+
while group_end_tile <= next_tile_idx:
|
| 617 |
+
batch_idx += cute.arch.WARP_SIZE - 1
|
| 618 |
+
if batch_idx >= params.num_batch:
|
| 619 |
+
batch_idx = Int32(params.num_batch)
|
| 620 |
+
group_end_tile = next_tile_idx + 1
|
| 621 |
+
else:
|
| 622 |
+
num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx)
|
| 623 |
+
num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
|
| 624 |
+
m_blocks_in_group = cute.arch.shuffle_sync(
|
| 625 |
+
num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1
|
| 626 |
+
)
|
| 627 |
+
group_end_tile += m_blocks_in_group * params.num_head
|
| 628 |
+
is_valid = False
|
| 629 |
+
if batch_idx >= params.num_batch:
|
| 630 |
+
block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch)
|
| 631 |
+
else:
|
| 632 |
+
group_start_tile = group_end_tile - m_blocks_in_group * params.num_head
|
| 633 |
+
# if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx)
|
| 634 |
+
# The next problem to process is the first one that does not have ending tile position
|
| 635 |
+
# that is greater than or equal to tile index.
|
| 636 |
+
batch_idx_in_group = cute.arch.popc(
|
| 637 |
+
cute.arch.vote_ballot_sync(
|
| 638 |
+
group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx
|
| 639 |
+
)
|
| 640 |
+
)
|
| 641 |
+
batch_idx += batch_idx_in_group
|
| 642 |
+
num_m_blocks_prev_lane = (
|
| 643 |
+
0
|
| 644 |
+
if batch_idx_in_group == 0
|
| 645 |
+
else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1)
|
| 646 |
+
)
|
| 647 |
+
num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group)
|
| 648 |
+
mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head
|
| 649 |
+
if cutlass.const_expr(params.lpt or params.head_swizzle):
|
| 650 |
+
# This is a version of the SingleTileLPTScheduler, complicated by the fact that
|
| 651 |
+
# the seqlen can vary per batch.
|
| 652 |
+
# TODO: is there any case where num_m_blocks is 0?
|
| 653 |
+
# TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here
|
| 654 |
+
num_n_blocks = (
|
| 655 |
+
num_m_blocks
|
| 656 |
+
* params.tile_shape_mn[0]
|
| 657 |
+
// params.qhead_per_kvhead_packgqa
|
| 658 |
+
// params.tile_shape_mn[1]
|
| 659 |
+
)
|
| 660 |
+
# nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head)
|
| 661 |
+
# Seems faster to have this be a power of 2
|
| 662 |
+
nheads_in_l2 = (
|
| 663 |
+
16
|
| 664 |
+
if num_n_blocks * 16 <= params.max_kvblock_in_l2
|
| 665 |
+
else (
|
| 666 |
+
8
|
| 667 |
+
if num_n_blocks * 8 <= params.max_kvblock_in_l2
|
| 668 |
+
else (
|
| 669 |
+
4
|
| 670 |
+
if num_n_blocks * 4 <= params.max_kvblock_in_l2
|
| 671 |
+
else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1)
|
| 672 |
+
)
|
| 673 |
+
)
|
| 674 |
+
)
|
| 675 |
+
nheads_in_l2 = min(nheads_in_l2, params.num_head)
|
| 676 |
+
mh_in_l2 = nheads_in_l2 * num_m_blocks
|
| 677 |
+
section_idx = mh_block // mh_in_l2
|
| 678 |
+
l2_mod = mh_block - section_idx * mh_in_l2
|
| 679 |
+
# Deal with tail section
|
| 680 |
+
nheads_in_this_section = (
|
| 681 |
+
nheads_in_l2
|
| 682 |
+
if nheads_in_l2 * (section_idx + 1) <= params.num_head
|
| 683 |
+
else params.num_head - section_idx * nheads_in_l2
|
| 684 |
+
)
|
| 685 |
+
block = l2_mod // nheads_in_this_section
|
| 686 |
+
head_idx_residual = l2_mod - block * nheads_in_this_section
|
| 687 |
+
head_idx = section_idx * nheads_in_l2 + head_idx_residual
|
| 688 |
+
if cutlass.const_expr(params.lpt):
|
| 689 |
+
block = num_m_blocks - 1 - block
|
| 690 |
+
else:
|
| 691 |
+
head_idx = mh_block // num_m_blocks
|
| 692 |
+
block = mh_block - head_idx * num_m_blocks
|
| 693 |
+
is_valid = self._is_first_block and batch_idx < params.num_batch
|
| 694 |
+
if cutlass.const_expr(params.cluster_shape_m > 1):
|
| 695 |
+
bidx_in_cluster = cute.arch.block_in_cluster_idx()
|
| 696 |
+
block = block * params.cluster_shape_m + bidx_in_cluster[0]
|
| 697 |
+
# if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid)
|
| 698 |
+
split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)
|
| 699 |
+
return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)
|
| 700 |
+
|
| 701 |
+
    def initial_work_tile_info(self, *, loc=None, ip=None):
        """Return the first work tile.

        For this single-tile scheduler the initial tile is just whatever
        get_current_work reports; no separate initialization is needed.
        """
        return self.get_current_work(loc=loc, ip=ip)
|
| 703 |
+
|
| 704 |
+
    def prefetch_next_work(self, *, loc=None, ip=None):
        """No-op: a single-tile scheduler has no subsequent work to prefetch."""
        pass
|
| 706 |
+
|
| 707 |
+
    def advance_to_next_work(self, *, loc=None, ip=None):
        """Mark the (only) tile as consumed so the next get_current_work is invalid."""
        # Single tile scheduler - set to invalid tile_idx to indicate no more work
        self._is_first_block = False
|
| 710 |
+
|
| 711 |
+
    def __extract_mlir_values__(self):
        """Flatten this scheduler's dynamic state into a list of MLIR values.

        Records per-object value counts in ``self._values_pos`` so that
        ``__new_from_mlir_values__`` can split the flat list back apart.
        The object order here must match the order used there.
        """
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx, self._split_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values
|
| 718 |
+
|
| 719 |
+
    def __new_from_mlir_values__(self, values):
        """Rebuild a scheduler instance from a flat list of MLIR values.

        Consumes ``values`` in the same object order as
        ``__extract_mlir_values__``, using the per-object counts stored in
        ``self._values_pos`` to slice the list.
        """
        obj_list = []
        for obj, n_items in zip(
            [self.params, self._tile_idx, self._split_idx],
            self._values_pos,
        ):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc)
|
build/torch-cuda/utils.py
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import hashlib
|
| 5 |
+
import inspect
|
| 6 |
+
from typing import Type, Callable, Optional, Tuple, overload
|
| 7 |
+
|
| 8 |
+
import cutlass
|
| 9 |
+
import cutlass.cute as cute
|
| 10 |
+
|
| 11 |
+
from cutlass import Float32, const_expr
|
| 12 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 13 |
+
from cutlass._mlir.dialects import nvvm, llvm
|
| 14 |
+
from cutlass.cute.runtime import from_dlpack
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from .quack import activation
|
| 18 |
+
|
| 19 |
+
_MIXER_ATTRS = ("__vec_size__",)
|
| 20 |
+
|
| 21 |
+
# Obtained from sollya:
|
| 22 |
+
# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative);
|
| 23 |
+
# Minimax polynomial coefficients for exp2(frac), frac in [0, 1), keyed by
# polynomial degree.  Each entry is a tuple (c0, c1, ..., cdeg) with c_i the
# coefficient of x**i, evaluated via Horner's rule in evaluate_polynomial.
# Fix: degree 0 was written as `(1.0)` which is a bare float, not a 1-tuple;
# evaluate_polynomial calls len(poly), which would crash on a float.
POLY_EX2 = {
    0: (1.0,),
    1: (
        1.0,
        0.922497093677520751953125,
    ),
    2: (
        1.0,
        0.6657850742340087890625,
        0.330107033252716064453125,
    ),
    3: (
        1.0,
        0.695146143436431884765625,
        0.227564394474029541015625,
        0.077119089663028717041015625,
    ),
    4: (
        1.0,
        0.693042695522308349609375,
        0.2412912547588348388671875,
        5.2225358784198760986328125e-2,
        1.3434938155114650726318359375e-2,
    ),
    5: (
        1.0,
        0.693151414394378662109375,
        0.24016360938549041748046875,
        5.5802188813686370849609375e-2,
        9.01452265679836273193359375e-3,
        1.86810153536498546600341796875e-3,
    ),
}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _compute_base_hash(func: Callable) -> str:
    """Return a SHA-256 hex digest identifying *func*.

    The digest is computed from the function's source text when available,
    falling back to its raw bytecode and finally to ``repr(func)``.  Any
    closure cell contents are mixed in so two closures over different values
    hash differently.
    """
    try:
        payload = inspect.getsource(func).encode()
    except (OSError, TypeError):
        # No retrievable source (e.g. built-in, compiled, or REPL-defined).
        code_obj = getattr(func, "__code__", None)
        payload = code_obj.co_code if code_obj is not None else repr(func).encode()

    digest = hashlib.sha256(payload)

    closure = getattr(func, "__closure__", None)
    if closure is not None:
        for cell in closure:
            digest.update(repr(cell.cell_contents).encode())

    return digest.hexdigest()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def hash_callable(
    func: Callable, mixer_attrs: Tuple[str, ...] = _MIXER_ATTRS, set_cute_hash: bool = True
) -> str:
    """Hash a callable based on the source code or bytecode and closure values.

    Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__``
    attribute, that value is used immediately as the base hash, then the
    attributes named in *mixer_attrs* are mixed in to produce the final
    dict-key hash.

    Args:
        func: The callable to hash.
        mixer_attrs: Attribute names whose values (if set on *func*) are mixed
            into the hash, e.g. ``__vec_size__``.
        set_cute_hash: Whether to cache the base hash on ``func.__cute_hash__``.
    """
    # Resolve base hash
    if hasattr(func, "__cute_hash__"):
        base_hash = func.__cute_hash__
    else:
        # Unwrap decorated functions (e.g., cute.jit wrappers).
        base_func = getattr(func, "__wrapped__", func)

        if hasattr(base_func, "__cute_hash__"):
            base_hash = base_func.__cute_hash__
        else:
            base_hash = _compute_base_hash(base_func)

            if set_cute_hash:
                base_func.__cute_hash__ = base_hash

    # Mix in mutable metadata dunders
    mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs)

    if all(v is None for v in mixer_values):
        return base_hash

    hasher = hashlib.sha256(base_hash.encode())

    # Bug fix: previously this zipped the module constant _MIXER_ATTRS instead
    # of the mixer_attrs parameter, so caller-supplied attribute lists were
    # mislabeled (and truncated/padded to len(_MIXER_ATTRS)) in the digest.
    for attr, val in zip(mixer_attrs, mixer_values):
        hasher.update(f"{attr}={val!r}".encode())

    return hasher.hexdigest()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def create_softcap_scoremod(softcap_val):
    """Build a pre-mask score-mod closure applying tanh soft-capping.

    The returned jitted function rescales attention scores by
    ``1 / softcap_val`` and multiplies by ``tanh`` of the rescaled value.
    """
    softcap_reciprocal = 1.0 / softcap_val

    @cute.jit
    def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        scaled = acc_S_SSA * softcap_reciprocal
        return scaled * cute.math.tanh(scaled, fastmath=True)

    return scoremod_premask_fn
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
    """Wrap a dlpack-capable tensor as a cute.Tensor with a dynamic compact
    layout along *leading_dim*, assuming *alignment*-byte alignment and the
    given stride *divisibility*."""
    tensor = from_dlpack(x, assumed_align=alignment)
    tensor = tensor.mark_layout_dynamic(leading_dim=leading_dim)
    tensor = tensor.mark_compact_shape_dynamic(
        mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility
    )
    return tensor
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def convert_from_dlpack_leading_static(
    x, leading_dim, alignment=16, static_modes=None, stride_order=None
) -> cute.Tensor:
    """Wrap a dlpack-capable tensor as a cute.Tensor, marking every mode
    dynamic except *leading_dim* and any mode listed in *static_modes*.

    If *stride_order* is not given, the tensor's own dim_order() is used.
    """
    order = stride_order if stride_order is not None else x.dim_order()
    tensor = from_dlpack(x, assumed_align=alignment)
    for mode in range(x.ndim):
        if mode == leading_dim:
            continue
        if static_modes is not None and mode in static_modes:
            continue
        tensor = tensor.mark_compact_shape_dynamic(mode=mode, stride_order=order)
    return tensor
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def make_tiled_copy_A(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Tiled copy for MMA operand A; with swapAB set, builds the B-side copy instead."""
    maker = cute.make_tiled_copy_B if const_expr(swapAB) else cute.make_tiled_copy_A
    return maker(copy_atom, tiled_mma)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def make_tiled_copy_B(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Tiled copy for MMA operand B; with swapAB set, builds the A-side copy instead."""
    maker = cute.make_tiled_copy_A if const_expr(swapAB) else cute.make_tiled_copy_B
    return maker(copy_atom, tiled_mma)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def mma_make_fragment_A(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Make the per-thread fragment for MMA operand A from *smem*; with swapAB
    set, delegates to the B-side fragment maker."""
    if const_expr(not swapAB):
        return thr_mma.make_fragment_A(thr_mma.partition_A(smem))
    return mma_make_fragment_B(smem, thr_mma)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def mma_make_fragment_B(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Make the per-thread fragment for MMA operand B from *smem*; with swapAB
    set, delegates to the A-side fragment maker."""
    if const_expr(not swapAB):
        return thr_mma.make_fragment_B(thr_mma.partition_B(smem))
    return mma_make_fragment_A(smem, thr_mma)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Choose the shared-memory store atom for *element_type* on *arch*.

    SM90+ with 16-bit elements uses the stmatrix (4-matrix) warp op;
    everything else falls back to a universal copy of two elements.
    """
    use_stmatrix = const_expr(arch >= 90 and element_type.width == 16)
    if use_stmatrix:
        return cute.make_copy_atom(
            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )
    return cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        element_type,
        num_bits_per_copy=2 * element_type.width,
    )
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@cute.jit
def warp_reduce(
    val: cute.TensorSSA | cute.Numeric,
    op: Callable,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
    """Butterfly-shuffle reduction of *val* with binary *op* across *width* lanes.

    For a TensorSSA, each element is reduced independently (via the scalar
    path); for a scalar, log2(width) xor-shuffle rounds combine values so
    every participating lane ends up with the reduction result.
    """
    if const_expr(isinstance(val, cute.TensorSSA)):
        # Tensor case: materialize to a fragment, reduce each element, reload.
        res = cute.make_fragment(val.shape, val.dtype)
        res.store(val)
        for i in cutlass.range_constexpr(cute.size(val.shape)):
            res[i] = warp_reduce(res[i], op, width)
        return res.load()
    else:
        # Scalar case: classic butterfly with offsets 1, 2, 4, ..., width/2.
        for i in cutlass.range_constexpr(int(math.log2(width))):
            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
        return val
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
@dsl_user_op
def fmax(
    a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
) -> Float32:
    """Float32 max of two or (when *c* is given) three inputs via nvvm.fmax.

    The 3-input form maps to the hardware 3-operand max on architectures that
    support it.  The nvvm.fmax binding changed signature across CUDA versions,
    hence the version branch below.
    """
    from cutlass import CUDA_VERSION

    # * NVVM call based on nvvm version
    if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9:
        # Old API: requires explicit result type as first positional argument
        return Float32(
            nvvm.fmax(
                T.f32(),
                Float32(a).ir_value(loc=loc, ip=ip),
                Float32(b).ir_value(loc=loc, ip=ip),
                c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
                loc=loc,
                ip=ip,
            )
        )
    else:
        # New API: infers result type automatically
        return Float32(
            nvvm.fmax(
                Float32(a).ir_value(loc=loc, ip=ip),
                Float32(b).ir_value(loc=loc, ip=ip),
                c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
                loc=loc,
                ip=ip,
            )
        )
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@cute.jit
def fmax_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Horizontal max over all elements of *x*, optionally seeded with *init_val*.

    Two hand-scheduled paths: a 4-way tree for pre-SM100 (or sizes not
    divisible by 8), and an 8-way tree using the 3-input fmax on SM100+.
    The statement ordering is deliberate to keep independent max chains for
    instruction-level parallelism — do not reorder casually.
    Assumes cute.size(x.shape) >= 4 (first path) / >= 8 (second path).
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        # if const_expr(init_val is None):
        #     init_val = -cutlass.Float32.if
        # return x.reduce(cute.ReductionOp.MAX, init_val, 0)
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # local_max = [res[0], res[1]]
        # for i in cutlass.range_constexpr(2, cute.size(x.shape), 2):
        #     local_max[0] = fmax(local_max[0], res[i + 0])
        #     local_max[1] = fmax(local_max[1], res[i + 1])
        # local_max[0] = fmax(local_max[0], local_max[1])
        # return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
        # Four independent running maxima, merged at the end.
        local_max = [res[0], res[1], res[2], res[3]]
        for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
            local_max[0] = fmax(local_max[0], res[i + 0])
            local_max[1] = fmax(local_max[1], res[i + 1])
            local_max[2] = fmax(local_max[2], res[i + 2])
            local_max[3] = fmax(local_max[3], res[i + 3])
        local_max[0] = fmax(local_max[0], local_max[1])
        local_max[2] = fmax(local_max[2], local_max[3])
        local_max[0] = fmax(local_max[0], local_max[2])
        return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
    else:
        # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max
        # We instead force the 3-input max.
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        local_max_0 = (
            fmax(init_val, res[0], res[1])
            if const_expr(init_val is not None)
            else fmax(res[0], res[1])
        )
        local_max = [
            local_max_0,
            fmax(res[2], res[3]),
            fmax(res[4], res[5]),
            fmax(res[6], res[7]),
        ]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_max[0] = fmax(local_max[0], res[i], res[i + 1])
            local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3])
            local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5])
            local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7])
        local_max[0] = fmax(local_max[0], local_max[1])
        return fmax(local_max[0], local_max[2], local_max[3])
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
@cute.jit
def fadd_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Horizontal sum over all elements of *x*, optionally seeded with *init_val*.

    Pre-SM100 (or sizes not divisible by 8) uses the generic x.reduce; SM100+
    uses four independent packed-f32x2 accumulators for ILP, collapsed to a
    scalar at the end.  Note: the packed path changes the addition order, so
    results may differ from x.reduce in the last ulp.
    Assumes cute.size(x.shape) >= 8 on the packed path.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        if const_expr(init_val is None):
            init_val = Float32.zero
        return x.reduce(cute.ReductionOp.ADD, init_val, 0)
        # res = cute.make_fragment(x.shape, Float32)
        # res.store(x)
        # local_sum = [res[0], res[1], res[2], res[3]]
        # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
        #     local_sum[0] += res[i + 0]
        #     local_sum[1] += res[i + 1]
        #     local_sum[2] += res[i + 2]
        #     local_sum[3] += res[i + 3]
        # local_sum[0] += local_sum[1]
        # local_sum[2] += local_sum[3]
        # local_sum[0] += local_sum[2]
        # return local_sum[0] if const_expr(init_val is None) else local_sum[0] + init_val
    else:
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # Seed lane 0 of the first packed accumulator with init_val (paired
        # with 0.0 so the second lane is unaffected).
        local_sum_0 = (
            cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1]))
            # cute.arch.add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1]))
            if const_expr(init_val is not None)
            else (res[0], res[1])
        )
        local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1]))
            local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3]))
            local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5]))
            local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7]))
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1])
        local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3])
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2])
        # Collapse the final packed pair to a scalar.
        return local_sum[0][0] + local_sum[0][1]
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
@dsl_user_op
def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
    """Atomically add Float32 *a* to the value at global-memory *gmem_ptr*.

    Implemented via nvvm.atomicrmw FADD; the result is discarded (reduction
    semantics).  The commented-out inline-PTX variant (red.global.add, with
    optional L2 cache hints) is kept for reference.
    """
    # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
    # # cache_hint = cutlass.Int64(0x12F0000000000000)
    # llvm.inline_asm(
    #     None,
    #     [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip)],
    #     # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()],
    #     "red.global.add.f32 [$0], $1;",
    #     # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;",
    #     # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;",
    #     "l,f",
    #     # "l,f,l",
    #     has_side_effects=True,
    #     is_align_stack=False,
    #     asm_dialect=llvm.AsmDialect.AD_ATT,
    # )
    nvvm.atomicrmw(
        res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value()
    )
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
@dsl_user_op
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
    """Pointer to the element of *x* at logical coordinate *coord*
    (base iterator plus the linear index computed from the tensor's layout)."""
    return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
    """Build a boolean predicate fragment masking out-of-bounds k-coordinates.

    For each (vector, k) position, the predicate is true when the second
    component of the identity coordinate in tAcA is below *limit*.  The mn
    mode (mode 1) is broadcast via a zero stride: only index 0 is written and
    all mn positions share it.
    """
    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
    tApA = cute.make_fragment(
        cute.make_layout(
            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
            stride=(cute.size(tAcA, mode=[2]), 0, 1),
        ),
        cutlass.Boolean,
    )
    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
    return tApA
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
|
| 386 |
+
warp_group_idx = cute.arch.thread_idx()[0] // 128
|
| 387 |
+
if const_expr(sync):
|
| 388 |
+
warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx)
|
| 389 |
+
return warp_group_idx
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
# @dsl_user_op
|
| 393 |
+
# def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean:
|
| 394 |
+
# mask = cutlass.Int32(-1)
|
| 395 |
+
# return cutlass.Boolean(
|
| 396 |
+
# llvm.inline_asm(
|
| 397 |
+
# T.i32(),
|
| 398 |
+
# [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)],
|
| 399 |
+
# ".pred p1, p2;\n"
|
| 400 |
+
# "setp.lt.f32 p1, $1, $2;\n"
|
| 401 |
+
# "vote.sync.any.pred p2, p1, $3;\n"
|
| 402 |
+
# "selp.u32 $0, 1, 0, p2;",
|
| 403 |
+
# # "selp.u32 $0, 1, 0, p1;",
|
| 404 |
+
# "=r,f,f,r",
|
| 405 |
+
# has_side_effects=False,
|
| 406 |
+
# is_align_stack=False,
|
| 407 |
+
# asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 408 |
+
# )
|
| 409 |
+
# )
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
@cute.jit
def shuffle_sync(
    value: cute.Numeric,
    offset: cute.typing.Int,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Numeric:
    """Warp shuffle of an arbitrary 32-bit-multiple value from lane + *offset*.

    The value is split into 32-bit words (via a register-tensor recast), each
    word is shuffled independently, and the result is reassembled.  *width*
    restricts the shuffle to subgroups of that many lanes.
    """
    assert value.width % 32 == 0, "value type must be a multiple of 32 bits"
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp
    # important: need stride 1 and not 0 for recast_tensor to work
    val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))
    val[0] = value
    val_i32 = cute.recast_tensor(val, cutlass.Int32)
    for i in cutlass.range_constexpr(cute.size(val_i32)):
        val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp)
    return val[0]
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
@dsl_user_op
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """Right-shift *val* by *shift* via inline PTX.

    NOTE(review): despite the u32 name and Uint32 types, the PTX opcode is
    ``shr.s32`` — an *arithmetic* (sign-propagating) shift on the 32-bit
    pattern.  For inputs with the top bit set this differs from a logical
    shift; confirm this is intentional before "fixing" it to shr.u32.
    """
    return cutlass.Uint32(
        llvm.inline_asm(
            T.i32(),
            [
                cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
                cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
            ],
            "shr.s32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
@cute.jit
def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
    """Inclusive prefix sum of *val* across the warp (Hillis–Steele scan).

    Each of the log2(32) rounds pulls the partial sum from `offset` lanes
    below and adds it only when this lane actually has a predecessor at that
    distance.  *lane* defaults to the calling thread's lane index.
    """
    if const_expr(lane is None):
        lane = cute.arch.lane_idx()
    # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 32, val)
    for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
        offset = 1 << i
        # Very important that we set mask_and_clamp to 0
        partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
        if lane >= offset:
            val += partial_sum
        # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val)
    return val
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
@dsl_user_op
def cvt_f16x2_f32(
    a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
) -> cutlass.Int32:
    """Convert two Float32 values to a packed f16x2/bf16x2 pair in one Int32.

    Uses a single cvt.rn PTX instruction.  The source operands are passed as
    ``$2, $1`` (b then a) — presumably so that *a* lands in the low half of
    the packed result, matching little-endian element order when the Int32 is
    reinterpreted as two 16-bit elements; confirm against the PTX cvt spec.
    """
    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
    return cutlass.Int32(
        llvm.inline_asm(
            T.i32(),
            [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)],
            f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;",
            "=r,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
@overload
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...


@overload
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...


@cute.jit
def cvt_f16(src: cute.Tensor, dst_or_dtype):
    """Convert Float32 tensor to Float16/BFloat16.

    Conversion is done two elements at a time through the packed
    cvt_f16x2_f32 PTX helper, hence the even-size requirement.

    Args:
        src: Source tensor with Float32 element type
        dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16)

    Returns:
        None if dst is a tensor, or a new tensor if dtype is provided
    """
    if const_expr(isinstance(dst_or_dtype, type)):
        # dtype variant: create new tensor and call the tensor variant
        dtype = dst_or_dtype
        dst = cute.make_fragment(src.shape, dtype)
        cvt_f16(src, dst)
        return dst
    else:
        # tensor variant: write to dst
        dst = dst_or_dtype
        assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
        assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
        assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
            "dst must be BFloat16 or Float16"
        )
        assert src.element_type is Float32, "src must be Float32"
        # View dst as Int32 words: each word holds one converted f16/bf16 pair.
        dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
        assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
        for i in cutlass.range_constexpr(cute.size(dst_i32)):
            dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
@dsl_user_op
@cute.jit
def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:
    """Evaluate the polynomial with coefficients *poly* at *x* via Horner's rule.

    poly[i] is the coefficient of x**i; the loop unrolls at compile time.
    """
    deg = len(poly) - 1
    out = poly[deg]
    for i in cutlass.range_constexpr(deg - 1, -1, -1):
        out = out * x + poly[i]
    return out
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
@dsl_user_op
@cute.jit
def evaluate_polynomial_2(
    x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Horner evaluation of the same polynomial at *x* and *y* simultaneously,
    using packed f32x2 FMAs so both evaluations share each instruction."""
    deg = len(poly) - 1
    out = (poly[deg], poly[deg])
    for i in cutlass.range_constexpr(deg - 1, -1, -1):
        out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i]))
    return out
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
@dsl_user_op
def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:
    """Float32 add with round-toward-minus-infinity (add.rm.ftz.f32).

    The explicit rounding mode is the point: callers (e.g. ex2_emulation) rely
    on rounding *down* so the extracted fractional part stays in [0, 1).
    """
    # There's probably a way to call llvm or nvvm to do this instead of ptx
    return cutlass.Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
            "add.rm.ftz.f32 $0, $1, $2;",
            "=f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
@dsl_user_op
def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:
    """Combine the integer part (encoded in *x_rounded*'s low bits) with
    2**frac to finish an exp2 computation.

    Shifting x_rounded's bit pattern left by 23 moves the integer exponent
    into the IEEE-754 exponent field; adding that to frac_ex2's bit pattern
    multiplies frac_ex2 by 2**int in one integer add.  Assumes x_rounded came
    from the round-to-int bias trick in ex2_emulation — confirm before reuse.
    """
    return cutlass.Float32(
        llvm.inline_asm(
            T.f32(),
            [
                Float32(x_rounded).ir_value(loc=loc, ip=ip),
                Float32(frac_ex2).ir_value(loc=loc, ip=ip),
            ],
            "{\n\t"
            ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
            "mov.b32 x_rounded_i, $1;\n\t"
            "mov.b32 frac_ex_i, $2;\n\t"
            "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t"
            # add.u32 generates IMAD instruction and add.s32 generates LEA instruction
            # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik
            "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t"
            "mov.b32 $0, out_i;\n\t"
            "}\n",
            "=f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
@dsl_user_op
def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32:
    """Approximate 2**x without the MUFU ex2 instruction.

    Splits x into integer and fractional parts via the 2**23 + 2**22 bias
    trick (round-down so the fraction is in [0, 1)), evaluates a minimax
    polynomial of degree *poly_degree* for 2**frac, then splices the integer
    part into the exponent bits via combine_int_frac_ex2.
    Caller contract: x <= 127.0; inputs below -127 are clamped.
    """
    assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported"
    # We assume x <= 127.0
    fp32_round_int = float(2**23 + 2**22)
    x_clamped = cute.arch.fmax(x, -127.0)
    # We want to round down here, so that the fractional part is in [0, 1)
    x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip)
    # The integer floor of x is now in the last 8 bits of x_rounded
    # We assume the next 2 ops round to nearest even. The rounding mode is important.
    x_rounded_back = x_rounded - fp32_round_int
    x_frac = x_clamped - x_rounded_back
    x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip)
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
# TODO: check that the ex2_emulation_2 produces the same SASS as the ptx version
|
| 605 |
+
@dsl_user_op
def ex2_emulation_2(
    x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Approximate (2**x, 2**y) together using packed f32x2 arithmetic.

    Same bias/polynomial scheme as ex2_emulation, but the round-down add,
    the subtractions, and the Horner evaluation are all done two-wide so both
    results share each packed instruction.
    Caller contract: x <= 127.0 and y <= 127.0; inputs below -127 are clamped.
    """
    # We assume x <= 127.0 and y <= 127.0
    fp32_round_int = float(2**23 + 2**22)
    xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
    # We want to round down here, so that the fractional part is in [0, 1)
    xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd="rm")
    # The integer floor of x & y are now in the last 8 bits of xy_rounded
    # We want the next 2 ops to round to nearest even. The rounding mode is important.
    xy_rounded_back = activation.sub_packed_f32x2(
        xy_rounded, (fp32_round_int, fp32_round_int)
    )
    xy_frac = activation.sub_packed_f32x2(xy_clamped, xy_rounded_back)
    xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip)
    y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip)
    return x_out, y_out
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
@dsl_user_op
def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """Hand-written PTX implementing the two-lane exp2 emulation end to end.

    Mirrors `ex2_emulation_2` (degree-3 polynomial) as a single inline-asm
    block: clamp, magic-number round toward -inf, polynomial on the fraction
    via packed FMAs, then exponent-bit recombination with shl/add.

    Returns:
        The pair (exp2(x), exp2(y)) as Float32 values.
    """
    out_f32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32()]),
        # Fix: pass loc/ip to ir_value() for y as well — the original passed
        # them to the Float32 constructor, inconsistent with the x operand
        # (and with the operand lists elsewhere in this file).
        [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
        "{\n\t"
        ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
        ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
        ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
        # Clamp both inputs to -127.0 (0fC2FE0000).
        "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t"
        "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t"
        "mov.b64 l1, {f1, f2};\n\t"
        # 0f4B400000 == 2^23 + 2^22, the round-to-int magic constant.
        "mov.f32 f3, 0f4B400000;\n\t"
        "mov.b64 l2, {f3, f3};\n\t"
        # Round toward -inf so the fractional parts land in [0, 1).
        "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
        "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
        "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
        # Polynomial coefficients for 2**frac (degree 3).
        "mov.f32 f7, 0f3D9DF09D;\n\t"
        "mov.b64 l6, {f7, f7};\n\t"
        "mov.f32 f6, 0f3E6906A4;\n\t"
        "mov.b64 l5, {f6, f6};\n\t"
        "mov.f32 f5, 0f3F31F519;\n\t"
        "mov.b64 l4, {f5, f5};\n\t"
        "mov.f32 f4, 0f3F800000;\n\t"
        "mov.b64 l3, {f4, f4};\n\t"
        # Horner evaluation on both lanes at once.
        "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
        "mov.b64 {r1, r2}, l7;\n\t"
        "mov.b64 {r3, r4}, l10;\n\t"
        # Shift the integer floor into the exponent field and add the
        # polynomial result's bit pattern (exponent-bit recombination).
        "shl.b32 r5, r1, 23;\n\t"
        "add.s32 r7, r5, r3;\n\t"
        "shl.b32 r6, r2, 23;\n\t"
        "add.s32 r8, r6, r4;\n\t"
        "mov.b32 $0, r7;\n\t"
        "mov.b32 $1, r8;\n\t"
        "}\n",
        # NOTE(review): outputs are constrained "=r" while the result struct
        # is (f32, f32) — confirm the bit-level reinterpretation is intended
        # (combine_int_frac_ex2 uses "=f" for the analogous output).
        "=r,=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
    out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
    return out0, out1
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
@dsl_user_op
def domain_offset_aligned(
    coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
    """Rebase `tensor` at the element addressed by `coord`, keeping its layout.

    Builds a new pointer at the offset element while carrying over the
    original pointer's alignment claim.

    NOTE(review): assumes the offset does not change the pointer alignment.
    """
    assert isinstance(tensor.iterator, cute.Pointer)
    shifted_ptr = cute.make_ptr(
        tensor.element_type,
        elem_pointer(tensor, coord).toint(),
        tensor.memspace,
        assumed_align=tensor.iterator.alignment,
    )
    return cute.make_tensor(shifted_ptr, tensor.layout)
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
@cute.jit
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
    """Wrap the scalar `a` into a shape-(1,) TensorSSA of the given dtype."""
    frag = cute.make_fragment(1, dtype)
    frag[0] = a
    return frag.load()
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def ssa_to_scalar(val):
    """Extract the single element from a shape-(1,) value.

    Trivial, but kept as a function to mirror the `scalar_to_ssa` API.
    """
    return val[0]
|