Kernels:
Trusted publisher
Uploaded using `kernel-builder` (batch 25/32).
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- build/torch210-cxx11-cu128-x86_64-linux/legacy/__init__.py +5 -0
- build/torch210-cxx11-cu128-x86_64-linux/legacy/a_fused_k_grouped_gemm.py +88 -0
- build/torch210-cxx11-cu128-x86_64-linux/legacy/a_fused_m_grouped_gemm.py +92 -0
- build/torch210-cxx11-cu128-x86_64-linux/legacy/b_fused_k_grouped_gemm.py +86 -0
- build/torch210-cxx11-cu128-x86_64-linux/legacy/m_grouped_gemm.py +84 -0
- build/torch210-cxx11-cu128-x86_64-linux/legacy/tune_options.py +28 -0
- build/torch210-cxx11-cu128-x86_64-linux/mega/__init__.py +130 -0
- build/torch210-cxx11-cu128-x86_64-linux/metadata.json +3 -1
- build/torch210-cxx11-cu128-x86_64-linux/testing/bench.py +13 -4
- build/torch210-cxx11-cu128-x86_64-linux/utils/__init__.py +1 -0
- build/torch210-cxx11-cu128-x86_64-linux/utils/dist.py +74 -0
- build/torch210-cxx11-cu128-x86_64-linux/utils/layout.py +19 -23
- build/torch210-cxx11-cu128-x86_64-linux/utils/math.py +51 -15
- build/torch210-cxx11-cu130-x86_64-linux/_C.py +194 -0
- build/torch210-cxx11-cu130-x86_64-linux/__init__.py +152 -19
- build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so +3 -0
- build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/comm/barrier.cuh +83 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/compile.cuh +18 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/cute_tie.cuh +2 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/exception.cuh +43 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/math.cuh +153 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/tma_copy.cuh +92 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/types.cuh +43 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/utils.cuh +16 -149
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd.cuh +137 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh +144 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/transform.cuh +24 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh +150 -195
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh +31 -25
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh +457 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh +510 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh +514 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh +1380 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +9 -5
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +125 -126
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh +205 -164
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh +35 -30
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh +47 -40
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh +37 -28
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +79 -82
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +55 -46
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh +64 -63
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh +76 -155
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh +36 -29
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh +30 -23
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh +34 -21
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/layout/mega_moe.cuh +260 -0
- build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/layout/sym_buffer.cuh +41 -0
.gitattributes
CHANGED
|
@@ -47,3 +47,4 @@ build/torch211-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=
|
|
| 47 |
build/torch211-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 48 |
build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 49 |
build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 47 |
build/torch211-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 48 |
build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 49 |
build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so filter=lfs diff=lfs merge=lfs -text
|
build/torch210-cxx11-cu128-x86_64-linux/legacy/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# All kernels here may be deprecated in the future (or rewritten in TileLang)
|
| 2 |
+
from .m_grouped_gemm import *
|
| 3 |
+
from .a_fused_m_grouped_gemm import *
|
| 4 |
+
from .a_fused_k_grouped_gemm import *
|
| 5 |
+
from .b_fused_k_grouped_gemm import *
|
build/torch210-cxx11-cu128-x86_64-linux/legacy/a_fused_k_grouped_gemm.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
from .tune_options import *
|
| 7 |
+
from .._C import get_mk_alignment_for_contiguous_layout
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr'])
@triton.jit
def a_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr,
                                                   k_indices_ptr, k_start_ptr, k_end_ptr,
                                                   M: tl.constexpr,
                                                   N: tl.constexpr,
                                                   K,
                                                   ACC: tl.constexpr,
                                                   BLOCK_SIZE_M: tl.constexpr,
                                                   BLOCK_SIZE_N: tl.constexpr,
                                                   BLOCK_SIZE_K: tl.constexpr,
                                                   GROUP_SIZE_M: tl.constexpr):
    """K-grouped GEMM tile kernel with the row gather fused into the A operand.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of one batch of `d`.
    The batch's K range is `[k_start[batch], k_end[batch])`. For every K position the
    row of `a` is looked up through `k_indices`; negative indices are masked and
    contribute zeros. `b` is addressed directly by the K position.
    With `ACC` set, the previous contents of the `d` tile are added to the result.
    """
    # Decompose the 1D program id into (batch, tile-within-batch)
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    tiles_per_batch = tiles_m * tiles_n
    launch_id = tl.program_id(axis=0)
    batch_idx = (launch_id // tiles_per_batch).to(tl.int64)
    tile_id = launch_id % tiles_per_batch

    # Tiles are walked in bands of `GROUP_SIZE_M` row-tiles
    tiles_per_band = GROUP_SIZE_M * tiles_n
    band_first_m = (tile_id // tiles_per_band) * GROUP_SIZE_M
    band_rows = min(tiles_m - band_first_m, GROUP_SIZE_M)
    tile_m = band_first_m + (tile_id % band_rows)
    tile_n = (tile_id % tiles_per_band) // band_rows

    # Element offsets covered by this tile, with alignment hints for the compiler
    offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_m = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N)
    m_mask = (offs_m < M)[:, None]
    n_mask = (offs_n < N)[None, :]
    tile_mask = m_mask & n_mask

    # Output tile addresses: `d` is indexed as [batch, m, n] with dense strides
    d_ptrs = d_ptr + batch_idx * M * N + offs_m[:, None].to(tl.int64) * N + offs_n[None, :]

    k_start = tl.load(k_start_ptr + batch_idx)
    k_end = tl.load(k_end_ptr + batch_idx)
    if k_start >= k_end:
        # Empty K range: the product is zero, so only a non-accumulating call writes anything
        if not ACC:
            tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=tile_mask)
        return

    # Compute
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(k_start, k_end, BLOCK_SIZE_K):
        offs_k = k + tl.arange(0, BLOCK_SIZE_K)
        src_rows = tl.load(k_indices_ptr + offs_k).to(tl.int64)

        # `a` is read transposed: element (row, m) lives at `row * M + m`
        a_tile_ptrs = a_ptr + offs_m[:, None] + src_rows[None, :] * M
        b_tile_ptrs = b_ptr + offs_k[:, None].to(tl.int64) * N + offs_n[None, :]
        a_tile = tl.load(a_tile_ptrs, mask=(src_rows >= 0)[None, :] & m_mask, other=0)
        b_tile = tl.load(b_tile_ptrs, mask=n_mask, other=0)
        acc = tl.dot(a_tile, b_tile, acc)

    # Write back
    if ACC:
        acc += tl.load(d_ptrs, mask=tile_mask)
    acc = acc.to(d_ptr.dtype.element_ty)
    tl.store(d_ptrs, acc, mask=tile_mask)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def a_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor,
                                                 handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool):
    """Validate the operands and launch the A-fused K-grouped BF16 GEMM kernel.

    Args:
        a: 2D contiguous BF16 tensor; its second dimension is the output M.
        b: 2D contiguous BF16 tensor; its first dimension must be a multiple of
            `get_mk_alignment_for_contiguous_layout()` and match `k_indices`.
        d: 3D contiguous output of shape `(k_start.numel(), a.size(1), b.size(1))`.
        handle: `(k_indices, k_start, k_end)`, all contiguous int32 tensors.
        acc: whether to accumulate into the existing contents of `d`.
    """
    k_indices, k_start, k_end = handle

    # Layout and dtype requirements of the kernel
    assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous()
    assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous()
    assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16
    assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32

    # Shape consistency
    assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3
    assert k_start.numel() == k_end.numel() and k_indices.size(0) == b.size(0)
    assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1)
    assert b.size(0) % get_mk_alignment_for_contiguous_layout() == 0

    m = a.size(1)
    k, n = b.shape
    num_batches = k_start.numel()

    def grid(META):
        # One program per output tile per batch
        return (triton.cdiv(m, META['BLOCK_SIZE_M']) * triton.cdiv(n, META['BLOCK_SIZE_N']) * num_batches,)

    a_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid](
        a, b, d, k_indices, k_start, k_end, m, n, k, ACC=acc)
|
build/torch210-cxx11-cu128-x86_64-linux/legacy/a_fused_m_grouped_gemm.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
from .tune_options import *
|
| 7 |
+
from .._C import get_mk_alignment_for_contiguous_layout
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[])
@triton.jit
def a_fused_m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr,
                                                   m_indices_ptr, m_row_indices_ptr,
                                                   M,
                                                   N: tl.constexpr,
                                                   K: tl.constexpr,
                                                   BLOCK_SIZE_M: tl.constexpr,
                                                   BLOCK_SIZE_N: tl.constexpr,
                                                   BLOCK_SIZE_K: tl.constexpr,
                                                   GROUP_SIZE_M: tl.constexpr,
                                                   IS_B_K_MAJOR: tl.constexpr):
    """M-grouped GEMM tile kernel with the row gather fused into the A operand.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of `d`. The batch of
    `b` to use is read from `m_indices` at the first row of the tile; a negative
    value makes the program write zeros and exit. Rows of `a` are looked up through
    `m_row_indices`; negative row indices are masked and contribute zeros.
    `IS_B_K_MAJOR` selects which of the two trailing dimensions of `b` has unit stride.
    """
    # Tiles are walked in bands of `GROUP_SIZE_M` row-tiles
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    tile_id = tl.program_id(axis=0)
    tiles_per_band = GROUP_SIZE_M * tiles_n
    band_first_m = (tile_id // tiles_per_band) * GROUP_SIZE_M
    band_rows = min(tiles_m - band_first_m, GROUP_SIZE_M)
    tile_m = band_first_m + (tile_id % band_rows)
    tile_n = (tile_id % tiles_per_band) // band_rows

    # Element offsets covered by this tile, with alignment hints for the compiler
    offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_n = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_m = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N)
    n_mask = (offs_n < N)[None, :]

    # Output tile addresses; rows are not masked
    d_ptrs = d_ptr + offs_m[:, None].to(tl.int64) * N + offs_n[None, :]

    # A negative batch id at the first row of the tile marks the whole tile as empty
    batch_id = tl.load(m_indices_ptr + tile_m * BLOCK_SIZE_M).to(tl.int64)
    if batch_id < 0:
        tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask)
        return

    # Source rows of `a` for every output row of this tile
    src_rows = tl.load(m_row_indices_ptr + offs_m).to(tl.int64)
    row_valid = (src_rows >= 0)[:, None]

    # Compute
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        offs_k = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
        k_valid = offs_k < K
        a_tile_ptrs = a_ptr + src_rows[:, None] * K + offs_k[None, :]
        b_tile_ptrs = b_ptr + batch_id * K * N + offs_k[:, None] * (1 if IS_B_K_MAJOR else N) + offs_n[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1)
        a_tile = tl.load(a_tile_ptrs, mask=row_valid & k_valid[None, :], other=0.0)
        b_tile = tl.load(b_tile_ptrs, mask=k_valid[:, None] & n_mask, other=0.0)
        acc = tl.dot(a_tile, b_tile, acc)
    out_tile = acc.to(d_ptr.dtype.element_ty)

    # Write back
    tl.store(d_ptrs, out_tile, mask=n_mask)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor,
                                                 mappings: Tuple[torch.Tensor, torch.Tensor]):
    """Validate the operands and launch the A-fused M-grouped BF16 GEMM kernel (NT layout).

    Args:
        a: 2D contiguous BF16 tensor whose second dimension equals `b.size(2)`.
        b: 3D BF16 tensor; either `b` or `b.mT` must be contiguous.
        d: 2D contiguous BF16 output of shape `(m_indices.numel(), b.size(1))`.
        mappings: `(m_indices, m_row_indices)`, contiguous int32 tensors of equal length
            that is a multiple of `get_mk_alignment_for_contiguous_layout()`.

    Returns:
        `d` when the output has no rows (nothing is launched), otherwise `None`.
    """
    m_indices, m_row_indices = mappings
    num_batches, n, k = b.shape

    # Layout and dtype requirements of the kernel
    assert a.is_contiguous() and (b.is_contiguous() or b.mT.is_contiguous()) and d.is_contiguous()
    assert m_indices.is_contiguous() and m_row_indices.is_contiguous()
    assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 and d.dtype == torch.bfloat16
    assert m_indices.dtype == torch.int32 and m_row_indices.dtype == torch.int32

    # Shape consistency
    assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2
    assert a.size(1) == k and d.size(0) == m_indices.numel() and d.size(1) == n
    assert m_indices.numel() == m_row_indices.numel()
    assert m_indices.numel() % get_mk_alignment_for_contiguous_layout() == 0

    # Nothing to compute for an empty output
    if d.size(0) == 0:
        return d

    m = m_indices.numel()

    def grid(meta):
        # One program per output tile
        return (triton.cdiv(m, meta['BLOCK_SIZE_M']) * triton.cdiv(n, meta['BLOCK_SIZE_N']), )

    a_fused_m_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, m_indices, m_row_indices,
                                                         m, n, k, IS_B_K_MAJOR=b.is_contiguous())
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def a_fused_m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor,
                                                 mappings: Tuple[torch.Tensor, torch.Tensor]):
    """NN-layout entry point: forwards to the NT variant with the last two dims of `b` swapped."""
    b_transposed = b.mT
    a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a, b_transposed, d, mappings)
|
build/torch210-cxx11-cu128-x86_64-linux/legacy/b_fused_k_grouped_gemm.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
from .tune_options import *
|
| 7 |
+
from .._C import get_mk_alignment_for_contiguous_layout
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr'])
@triton.jit
def b_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr,
                                                   k_indices_ptr, k_start_ptr, k_end_ptr,
                                                   M: tl.constexpr,
                                                   N: tl.constexpr,
                                                   K,
                                                   ACC: tl.constexpr,
                                                   BLOCK_SIZE_M: tl.constexpr,
                                                   BLOCK_SIZE_N: tl.constexpr,
                                                   BLOCK_SIZE_K: tl.constexpr,
                                                   GROUP_SIZE_M: tl.constexpr):
    """K-grouped GEMM tile kernel with the row gather fused into the B operand.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of one batch of `d`.
    The batch's K range is `[k_start[batch], k_end[batch])`. `a` is addressed directly
    by the K position, while the row of `b` is looked up through `k_indices`;
    negative indices are masked and contribute zeros.
    With `ACC` set, the previous contents of the `d` tile are added to the result.
    """
    # Decompose the 1D program id into (batch, tile-within-batch)
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    tiles_per_batch = tiles_m * tiles_n
    launch_id = tl.program_id(axis=0)
    batch_idx = (launch_id // tiles_per_batch).to(tl.int64)
    tile_id = launch_id % tiles_per_batch

    # Tiles are walked in bands of `GROUP_SIZE_M` row-tiles
    tiles_per_band = GROUP_SIZE_M * tiles_n
    band_first_m = (tile_id // tiles_per_band) * GROUP_SIZE_M
    band_rows = min(tiles_m - band_first_m, GROUP_SIZE_M)
    tile_m = band_first_m + (tile_id % band_rows)
    tile_n = (tile_id % tiles_per_band) // band_rows

    k_start = tl.load(k_start_ptr + batch_idx)
    k_end = tl.load(k_end_ptr + batch_idx)

    # Element offsets covered by this tile, with alignment hints for the compiler
    offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_m = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N)
    m_mask = (offs_m < M)[:, None]
    n_mask = (offs_n < N)[None, :]
    tile_mask = m_mask & n_mask

    # Output tile addresses: `d` is indexed as [batch, m, n] with dense strides
    d_ptrs = d_ptr + batch_idx * M * N + offs_m[:, None].to(tl.int64) * N + offs_n[None, :]

    if k_start >= k_end:
        # Empty K range: the product is zero, so only a non-accumulating call writes anything
        if not ACC:
            tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=tile_mask)
        return

    # Compute
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(k_start, k_end, BLOCK_SIZE_K):
        offs_k = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
        src_rows = tl.load(k_indices_ptr + offs_k).to(tl.int64)

        # `a` is read transposed: element (k, m) lives at `k * M + m`
        a_tile_ptrs = a_ptr + offs_m[:, None] + offs_k[None, :] * M
        b_tile_ptrs = b_ptr + src_rows[:, None] * N + offs_n[None, :]
        a_tile = tl.load(a_tile_ptrs, mask=m_mask, other=0.0)
        b_tile = tl.load(b_tile_ptrs, mask=(src_rows >= 0)[:, None] & n_mask, other=0.0)
        acc = tl.dot(a_tile, b_tile, acc)

    # Write back
    if ACC:
        acc += tl.load(d_ptrs, mask=tile_mask)
    acc = acc.to(d_ptr.dtype.element_ty)
    tl.store(d_ptrs, acc, mask=tile_mask)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def b_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor,
                                                 handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool):
    """Validate the operands and launch the B-fused K-grouped BF16 GEMM kernel.

    Args:
        a: 2D contiguous BF16 tensor; its first dimension must be a multiple of
            `get_mk_alignment_for_contiguous_layout()` and match `k_indices`.
        b: 2D contiguous BF16 tensor; its second dimension is the output N.
        d: 3D contiguous output of shape `(k_start.numel(), a.size(1), b.size(1))`.
        handle: `(k_indices, k_start, k_end)`, all contiguous int32 tensors.
        acc: whether to accumulate into the existing contents of `d`.
    """
    k_indices, k_start, k_end = handle

    # Layout and dtype requirements of the kernel
    assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous()
    assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous()
    assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16
    assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32

    # Shape consistency
    assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3
    assert k_start.numel() == k_end.numel() and k_indices.size(0) == a.size(0)
    assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1)
    assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0

    k, m = a.shape
    n = b.size(1)
    num_batches = k_start.numel()

    def grid(META):
        # One program per output tile per batch
        return (triton.cdiv(m, META['BLOCK_SIZE_M']) * triton.cdiv(n, META['BLOCK_SIZE_N']) * num_batches,)

    b_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, k_indices, k_start, k_end, m, n, k, ACC=acc)
|
build/torch210-cxx11-cu128-x86_64-linux/legacy/m_grouped_gemm.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
from .tune_options import *
|
| 7 |
+
from .._C import get_mk_alignment_for_contiguous_layout
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[])
@triton.jit
def m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr,
                                           m_indices_ptr,
                                           M,
                                           N: tl.constexpr,
                                           K: tl.constexpr,
                                           BLOCK_SIZE_M: tl.constexpr,
                                           BLOCK_SIZE_N: tl.constexpr,
                                           BLOCK_SIZE_K: tl.constexpr,
                                           GROUP_SIZE_M: tl.constexpr,
                                           IS_B_K_MAJOR: tl.constexpr):
    """M-grouped GEMM tile kernel.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of `d`. The batch of
    `b` to use is read from `m_indices` at the first row of the tile; a negative
    value makes the program write zeros and exit.
    `IS_B_K_MAJOR` selects which of the two trailing dimensions of `b` has unit stride.
    """
    # Tiles are walked in bands of `GROUP_SIZE_M` row-tiles
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    tile_id = tl.program_id(axis=0)
    tiles_per_band = GROUP_SIZE_M * tiles_n
    band_first_m = (tile_id // tiles_per_band) * GROUP_SIZE_M
    band_rows = min(tiles_m - band_first_m, GROUP_SIZE_M)
    tile_m = band_first_m + (tile_id % band_rows)
    tile_n = (tile_id % tiles_per_band) // band_rows
    offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    n_mask = (offs_n < N)[None, :]

    # Output tile addresses; rows are not masked
    d_ptrs = d_ptr + offs_m[:, None].to(tl.int64) * N + offs_n[None, :]

    # Empty tokens: a negative batch id at the first row of the tile zeroes the whole tile
    batch_id = tl.load(m_indices_ptr + tile_m * BLOCK_SIZE_M).to(tl.int64)
    if batch_id < 0:
        tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask)
        return

    # Compute: both operand pointers start at K offset 0 and advance by one K block per iteration
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_tile_ptrs = a_ptr + offs_m[:, None].to(tl.int64) * K + offs_k[None, :]
    b_tile_ptrs = b_ptr + batch_id * K * N + \
        offs_k[:, None].to(tl.int64) * (1 if IS_B_K_MAJOR else N) + \
        offs_n[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        k_valid = (k + offs_k) < K
        a_tile = tl.load(a_tile_ptrs, mask=k_valid[None, :], other=0.0)
        b_tile = tl.load(b_tile_ptrs, mask=k_valid[:, None] & n_mask, other=0.0)
        acc = tl.dot(a_tile, b_tile, acc)
        a_tile_ptrs += BLOCK_SIZE_K
        b_tile_ptrs += BLOCK_SIZE_K * (1 if IS_B_K_MAJOR else N)

    # Write back
    tl.store(d_ptrs, acc.to(d_ptr.dtype.element_ty), mask=n_mask)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor,
                                         m_indices: torch.Tensor):
    """Validate the operands and launch the M-grouped BF16 GEMM kernel (NT layout).

    Args:
        a: 2D contiguous BF16 tensor of shape `(M, K)`; `M` must be a multiple of
            `get_mk_alignment_for_contiguous_layout()`.
        b: 3D BF16 tensor whose last dimension equals `K`; either `b` or `b.mT`
            must be contiguous, because the kernel addresses it with dense strides.
        d: 2D contiguous BF16 output of shape `(M, b.size(1))`.
        m_indices: contiguous int32 tensor with `M` elements.

    Raises:
        AssertionError: if any of the layout, dtype or shape requirements is violated.
    """
    r0, r1, r2 = b.shape

    # NOTES: `is_contiguous` must be *called*; the bare bound method is always truthy
    # and would silently accept a `b` that is dense in neither layout
    assert a.is_contiguous() and (b.is_contiguous() or b.mT.is_contiguous())
    assert m_indices.is_contiguous() and d.is_contiguous()
    assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16
    assert m_indices.dtype == torch.int32 and d.dtype == torch.bfloat16
    assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2
    assert a.size(1) == r2 and a.size(0) == d.size(0) and r1 == d.size(1)
    assert m_indices.numel() == a.size(0)
    assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0
    M, K = a.shape
    N = r1

    # With Triton 2.0, a persistent kernel leads to errors, so launch one program per output tile
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    m_grouped_bf16_gemm_contiguous_tl_impl[grid](
        a, b, d, m_indices, M, N, K, IS_B_K_MAJOR=b.is_contiguous())
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor,
                                         m_indices: torch.Tensor):
    """NN-layout entry point: forwards to the NT variant with the last two dims of `b` swapped."""
    b_transposed = b.mT
    m_grouped_bf16_gemm_nt_contiguous_tl(a, b_transposed, d, m_indices)
|
build/torch210-cxx11-cu128-x86_64-linux/legacy/tune_options.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from triton import Config
|
| 2 |
+
from .._C import get_mk_alignment_for_contiguous_layout
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_config_smem_size(config: Config, elem_bytes: int = 2):
    """Return the size computed as `(BLOCK_M + BLOCK_N) * BLOCK_K * elem_bytes * num_stages`.

    Args:
        config: autotune candidate providing `kwargs['BLOCK_SIZE_*']` and `num_stages`.
        elem_bytes: bytes per element; defaults to 2 (BF16).
    """
    # NOTES: FP8 kernels will not use Triton, so by default we assume BF16 kernels
    block = config.kwargs
    elems_per_stage = (block['BLOCK_SIZE_M'] + block['BLOCK_SIZE_N']) * block['BLOCK_SIZE_K']
    bytes_per_stage = elems_per_stage * elem_bytes
    return bytes_per_stage * config.num_stages
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Autotune candidates shared by all legacy Triton GEMM kernels.
# NOTES: every entry must be unique, as a duplicate is benchmarked twice for no benefit
_gemm_configs = [
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
    Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4),
]

# NOTES: we only consider A100 shared memory sizes here, as legacy kernels are only used for Ampere
_gemm_configs = list(filter(lambda x: get_config_smem_size(x) <= 166912, _gemm_configs))
# Drop candidates whose M or K block exceeds the contiguous-layout alignment
_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs))
_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs))

# Per-kernel-family getters; the alignment is re-read on every call
get_m_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs))
get_k_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs))
|
build/torch210-cxx11-cu128-x86_64-linux/mega/__init__.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Tuple, Optional
|
| 5 |
+
from ..utils.math import align
|
| 6 |
+
|
| 7 |
+
# noinspection PyBroadException
|
| 8 |
+
try:
|
| 9 |
+
# noinspection PyProtectedMember
|
| 10 |
+
import torch.distributed._symmetric_memory as symm_mem
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
except Exception as exception:
|
| 13 |
+
print(f'Failed to load mega kernels, please check your PyTorch version: {exception}')
|
| 14 |
+
|
| 15 |
+
from .. import _C
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SymmBuffer:
    """Symmetric-memory buffer plus the tensor views the mega MoE kernel reads its inputs from.

    The constructor allocates the buffer, performs the rendezvous with the other ranks
    of `group`, zeroes it, and then slices it into named views. Call `destroy()` to
    drop every reference this object holds to the allocation.
    """

    # Attributes assigned from `slice_input_buffers(self.buffer)`; `destroy()` clears
    # all of them, presumably because they are views that keep the allocation alive
    _VIEW_NAMES = (
        'x', 'x_sf',
        'topk_idx', 'topk_weights',
        'l1_acts', 'l1_acts_sf',
        'l2_acts', 'l2_acts_sf',
    )

    def __init__(self, group: dist.ProcessGroup,
                 # MoE arguments
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu'):
        """Allocate, rendezvous and slice the symmetric buffer.

        The size and the slicing function both come from
        `_C.get_symm_buffer_size_for_mega_moe`, called with the group size and the
        MoE arguments given here.
        """
        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Allocate a symmetric buffer
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)

        # Zero the buffer and wait for all ranks before handing out views
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self):
        """Drop the handle, the buffer, the group and every sliced view.

        All view attributes are cleared, not just `x` and `x_sf`; a view left behind
        would still reference the buffer after this call.
        """
        self.handle = None
        self.buffer = None
        self.group = None
        for view_name in self._VIEW_NAMES:
            setattr(self, view_name, None)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup,
                                 num_experts: int,
                                 num_max_tokens_per_rank: int, num_topk: int,
                                 hidden: int, intermediate_hidden: int,
                                 use_fp8_dispatch: bool = True,
                                 activation: str = 'swiglu') -> SymmBuffer:
    """Build a `SymmBuffer` after rounding the per-rank token capacity up.

    `num_max_tokens_per_rank` is passed through `align` with
    `_C.get_token_alignment_for_mega_moe()`; all other arguments are forwarded unchanged.
    """
    # Token count must be aligned to block sizes
    token_alignment = _C.get_token_alignment_for_mega_moe()
    aligned_max_tokens = align(num_max_tokens_per_rank, token_alignment)

    return SymmBuffer(
        group,
        num_experts=num_experts,
        num_max_tokens_per_rank=aligned_max_tokens,
        num_topk=num_topk,
        hidden=hidden,
        intermediate_hidden=intermediate_hidden,
        use_fp8_dispatch=use_fp8_dispatch,
        activation=activation,
    )
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 78 |
+
# [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up]
|
| 79 |
+
def interleave(t, gran: int = 8) -> torch.Tensor:
|
| 80 |
+
g, n, *rest = t.shape
|
| 81 |
+
half = n // 2
|
| 82 |
+
gate = t[:, :half].reshape(g, half // gran, gran, *rest)
|
| 83 |
+
up = t[:, half:].reshape(g, half // gran, gran, *rest)
|
| 84 |
+
return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))
|
| 85 |
+
|
| 86 |
+
return interleave(l1_weights[0]), interleave(l1_weights[1])
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
|
| 90 |
+
num_groups, mn, packed_sf_k = sf.shape
|
| 91 |
+
assert sf.dtype == torch.int and mn % 128 == 0
|
| 92 |
+
result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k)
|
| 93 |
+
.transpose(2, 3)
|
| 94 |
+
.reshape(num_groups, mn, packed_sf_k))
|
| 95 |
+
return torch.empty_like(sf).copy_(result)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def transform_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Prepare the (data, scale-factor) pairs of both layers.

    L1: both tensors go through `_interleave_l1_weights`, then the scale
    factors go through `_transpose_sf_for_utccp`.
    L2: the data tensor is returned as-is; only the scale factors are transposed.
    """
    l1_data, l1_sf = _interleave_l1_weights(l1_weights)
    l1_ready = (l1_data, _transpose_sf_for_utccp(l1_sf))
    l2_ready = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1]))
    return l1_ready, l2_ready
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def fp8_fp4_mega_moe(y: torch.Tensor,
                     l1_weights: Tuple[torch.Tensor, torch.Tensor],
                     l2_weights: Tuple[torch.Tensor, torch.Tensor],
                     sym_buffer: SymmBuffer,
                     cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                     recipe: Tuple[int, int, int] = (1, 1, 32),
                     activation: str = 'swiglu',
                     activation_clamp: Optional[float] = None,
                     fast_math: bool = True):
    """Call `_C.fp8_fp4_mega_moe`, filling in the fields taken from `sym_buffer`.

    From `sym_buffer` this reads the buffer tensor, `handle.buffer_ptrs`,
    the rank within `group`, and the token/expert/top-k sizes. Returns nothing.
    """
    symm_tensor = sym_buffer.buffer
    peer_ptrs = sym_buffer.handle.buffer_ptrs
    rank_idx = sym_buffer.group.rank()
    kernel_args = (
        y,
        l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        symm_tensor,
        peer_ptrs, rank_idx,
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe,
        activation, activation_clamp,
        fast_math,
    )
    _C.fp8_fp4_mega_moe(*kernel_args)
|
build/torch210-cxx11-cu128-x86_64-linux/metadata.json
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
{
|
| 2 |
-
"
|
|
|
|
|
|
|
| 3 |
"license": "MIT",
|
| 4 |
"python-depends": [],
|
| 5 |
"backend": {
|
|
|
|
| 1 |
{
|
| 2 |
+
"name": "deep-gemm",
|
| 3 |
+
"id": "_deep_gemm_cuda_388adb9",
|
| 4 |
+
"version": 2,
|
| 5 |
"license": "MIT",
|
| 6 |
"python-depends": [],
|
| 7 |
"backend": {
|
build/torch210-cxx11-cu128-x86_64-linux/testing/bench.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import torch
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
|
@@ -78,7 +79,8 @@ class suppress_stdout_stderr:
|
|
| 78 |
def bench_kineto(fn, kernel_names, num_tests: int = 30,
|
| 79 |
suppress_kineto_output: bool = False,
|
| 80 |
trace_path: str = None, flush_l2: bool = True,
|
| 81 |
-
with_multiple_kernels: bool = False
|
|
|
|
| 82 |
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
|
| 83 |
is_tuple = isinstance(kernel_names, tuple)
|
| 84 |
|
|
@@ -96,14 +98,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
|
|
| 96 |
# Profile
|
| 97 |
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
|
| 98 |
with suppress():
|
| 99 |
-
schedule = torch.profiler.schedule(wait=
|
| 100 |
-
profiler = torch.profiler.profile(
|
|
|
|
| 101 |
with profiler:
|
| 102 |
for i in range(2):
|
| 103 |
for _ in range(num_tests):
|
| 104 |
if flush_l2:
|
| 105 |
torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
fn()
|
|
|
|
| 107 |
profiler.step()
|
| 108 |
|
| 109 |
# Parse the profiling table
|
|
@@ -111,7 +120,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
|
|
| 111 |
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
|
| 112 |
if not with_multiple_kernels:
|
| 113 |
for name in kernel_names:
|
| 114 |
-
assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'
|
| 115 |
|
| 116 |
# Save chrome traces
|
| 117 |
if trace_path is not None:
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import torch
|
| 4 |
+
from typing import Callable, Optional
|
| 5 |
|
| 6 |
|
| 7 |
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
|
|
|
| 79 |
def bench_kineto(fn, kernel_names, num_tests: int = 30,
|
| 80 |
suppress_kineto_output: bool = False,
|
| 81 |
trace_path: str = None, flush_l2: bool = True,
|
| 82 |
+
with_multiple_kernels: bool = False,
|
| 83 |
+
barrier: Optional[Callable] = None):
|
| 84 |
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
|
| 85 |
is_tuple = isinstance(kernel_names, tuple)
|
| 86 |
|
|
|
|
| 98 |
# Profile
|
| 99 |
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
|
| 100 |
with suppress():
|
| 101 |
+
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
|
| 102 |
+
profiler = torch.profiler.profile(
|
| 103 |
+
activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule, acc_events=True)
|
| 104 |
with profiler:
|
| 105 |
for i in range(2):
|
| 106 |
for _ in range(num_tests):
|
| 107 |
if flush_l2:
|
| 108 |
torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
|
| 109 |
+
if barrier is not None:
|
| 110 |
+
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
|
| 111 |
+
# noinspection PyProtectedMember
|
| 112 |
+
torch.cuda._sleep(int(2e7)) # ~10ms
|
| 113 |
+
barrier()
|
| 114 |
fn()
|
| 115 |
+
torch.cuda.synchronize()
|
| 116 |
profiler.step()
|
| 117 |
|
| 118 |
# Parse the profiling table
|
|
|
|
| 120 |
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
|
| 121 |
if not with_multiple_kernels:
|
| 122 |
for name in kernel_names:
|
| 123 |
+
assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table {prof_lines}'
|
| 124 |
|
| 125 |
# Save chrome traces
|
| 126 |
if trace_path is not None:
|
build/torch210-cxx11-cu128-x86_64-linux/utils/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
from . import math, layout
|
| 2 |
from .layout import *
|
| 3 |
from .math import *
|
|
|
|
|
|
| 1 |
from . import math, layout
|
| 2 |
from .layout import *
|
| 3 |
from .math import *
|
| 4 |
+
from .dist import init_dist, uneven_all_gather
|
build/torch210-cxx11-cu128-x86_64-linux/utils/dist.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
_local_rank = None
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def init_dist(local_rank: int, num_local_ranks: int) -> Tuple[int, int, dist.ProcessGroup]:
    """Initialise the NCCL process group and return (rank, world size, group).

    NOTES: you may rewrite this function with your own cluster settings.
    The environment variables `WORLD_SIZE` and `RANK` are read here as the
    number of nodes and the node index; the global world size is
    `WORLD_SIZE * num_local_ranks`. Also records `local_rank` in the
    module-level `_local_rank` and makes CUDA the default device.
    """
    global _local_rank

    master_addr = os.getenv('MASTER_ADDR', '127.0.0.1')
    master_port = int(os.getenv('MASTER_PORT', '8361'))
    num_nodes = int(os.getenv('WORLD_SIZE', 1))
    node_rank = int(os.getenv('RANK', 0))

    # Remember the local rank
    _local_rank = local_rank

    total_ranks = num_nodes * num_local_ranks
    init_kwargs = dict(
        backend='nccl',
        init_method=f'tcp://{master_addr}:{master_port}',
        world_size=total_ranks,
        rank=node_rank * num_local_ranks + local_rank,
    )
    # `device_id` is only passed when this torch version accepts it
    if 'device_id' in inspect.signature(dist.init_process_group).parameters:
        # noinspection PyTypeChecker
        init_kwargs['device_id'] = torch.device(f'cuda:{local_rank}')
    dist.init_process_group(**init_kwargs)
    torch.set_default_device('cuda')
    torch.cuda.set_device(local_rank)

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    group = dist.new_group(list(range(num_local_ranks * num_nodes)))
    return rank, world_size, group
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def uneven_all_gather(tensor: torch.Tensor, dim: int = 0, group: dist.ProcessGroup = None) -> torch.Tensor:
    """All-gather tensors whose size along `dim` may differ between ranks.

    Sizes are exchanged first, each tensor is zero-padded along `dim` up to
    the largest size, gathered, and the padding is stripped before the
    per-rank pieces are concatenated along `dim`.
    """
    num_ranks = dist.get_world_size(group)

    # Exchange sizes
    my_size = torch.tensor([tensor.shape[dim]], device=tensor.device, dtype=torch.long)
    size_slots = [torch.zeros_like(my_size) for _ in range(num_ranks)]
    dist.all_gather(size_slots, my_size, group=group)
    sizes = [slot.item() for slot in size_slots]
    largest = max(sizes)

    # Pad up to the largest size
    missing = largest - tensor.shape[dim]
    if missing > 0:
        filler_shape = list(tensor.shape)
        filler_shape[dim] = missing
        filler = torch.zeros(filler_shape, dtype=tensor.dtype, device=tensor.device)
        send = torch.cat([tensor, filler], dim=dim)
    else:
        send = tensor.contiguous()

    # All-gather
    received = [torch.zeros_like(send) for _ in range(num_ranks)]
    dist.all_gather(received, send, group=group)

    # Remove padding
    pieces = [torch.narrow(chunk, dim, 0, size) for chunk, size in zip(received, sizes)]
    return torch.cat(pieces, dim=dim)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def dist_print(s: str = '', once_in_node: bool = False) -> None:
    """Print `s`, then wait on `dist.barrier()`.

    With `once_in_node=True` only the process whose recorded local rank is 0
    prints. Requires `_local_rank` to have been set (asserted).
    """
    global _local_rank
    assert _local_rank is not None
    should_print = _local_rank == 0 or not once_in_node
    if should_print:
        print(s, flush=True)
    dist.barrier()
|
build/torch210-cxx11-cu128-x86_64-linux/utils/layout.py
CHANGED
|
@@ -1,25 +1,21 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks):
|
| 21 |
-
return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|
| 25 |
get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
from .._C import (
|
| 3 |
+
get_tma_aligned_size,
|
| 4 |
+
get_mn_major_tma_aligned_tensor,
|
| 5 |
+
get_mn_major_tma_aligned_packed_ue8m0_tensor,
|
| 6 |
+
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
|
| 7 |
+
)
|
| 8 |
+
except ImportError:
|
| 9 |
+
# Expected behavior for CUDA runtime version before 12.1
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
# Valid for all CUDA versions
|
| 13 |
+
from .._C import (
|
| 14 |
+
set_mk_alignment_for_contiguous_layout,
|
| 15 |
+
get_mk_alignment_for_contiguous_layout,
|
| 16 |
+
get_theoretical_mk_alignment_for_contiguous_layout,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Some aliases
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|
| 21 |
get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|
build/torch210-cxx11-cu128-x86_64-linux/utils/math.py
CHANGED
|
@@ -11,21 +11,30 @@ def align(x: int, y: int) -> int:
|
|
| 11 |
|
| 12 |
|
| 13 |
def ceil_to_ue8m0(x: torch.Tensor):
|
| 14 |
-
|
| 15 |
-
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
assert x.dim() == 2
|
| 20 |
m, n = x.shape
|
| 21 |
padded_n = align(n, gran_k)
|
| 22 |
x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0)
|
| 23 |
x_padded[:, :n] = x
|
| 24 |
-
x_view = x_padded.view(m,
|
| 25 |
-
x_amax = x_view.abs().float().amax(dim=2).view(m,
|
| 26 |
sf = x_amax / 448.0
|
| 27 |
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
| 28 |
-
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
@@ -70,13 +79,14 @@ def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor:
|
|
| 70 |
code = idx.to(torch.uint8)
|
| 71 |
sign = (x < 0) & (idx != 0)
|
| 72 |
code = code | (sign.to(torch.uint8) << 3)
|
| 73 |
-
return code
|
| 74 |
|
| 75 |
|
| 76 |
-
def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128
|
| 77 |
-
|
| 78 |
m, n = x.shape
|
| 79 |
assert n % 2 == 0
|
|
|
|
| 80 |
padded_n = align(n, gran_k)
|
| 81 |
x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
|
| 82 |
x_padded[:, :n] = x
|
|
@@ -85,23 +95,49 @@ def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -
|
|
| 85 |
sf = x_amax / 6.0
|
| 86 |
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
| 87 |
x_scaled = x_view * (1.0 / sf.unsqueeze(2))
|
| 88 |
-
codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) #
|
| 89 |
codes2 = codes.view(m, padded_n // 2, 2)
|
| 90 |
-
packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) #
|
| 91 |
-
return packed[:, :n // 2].contiguous(), sf
|
| 92 |
|
| 93 |
|
| 94 |
def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor:
|
| 95 |
-
assert a.dtype == torch.
|
| 96 |
assert a.dim() == 2
|
| 97 |
m, n2 = a.shape
|
| 98 |
n = n2 * 2
|
| 99 |
assert (m % 2) == 0
|
| 100 |
lo = a & 0x0F
|
| 101 |
hi = (a >> 4) & 0x0F
|
| 102 |
-
codes = torch.empty((m, n), device=a.device, dtype=torch.
|
| 103 |
codes[:, 0::2], codes[:, 1::2] = lo, hi
|
| 104 |
codes_t = codes.transpose(0, 1).contiguous()
|
| 105 |
codes2 = codes_t.view(n, m // 2, 2)
|
| 106 |
out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4)
|
| 107 |
-
return out.contiguous()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
def ceil_to_ue8m0(x: torch.Tensor):
    """Round `|x|` up to a power of two using float32 bit manipulation.

    The biased exponent is incremented when any mantissa bit is set, clamped
    to [1, 254], and written back with a zero mantissa.
    """
    raw = x.abs().float().view(torch.int)
    exponent = (raw >> 23) & 0xFF
    has_fraction = (raw & 0x7FFFFF).bool().int()
    rounded = (exponent + has_fraction).clamp(1, 254)
    return (rounded << 23).view(torch.float)
|
| 17 |
|
| 18 |
|
| 19 |
+
def pack_ue8m0_to_int(x: torch.Tensor):
    """Pack the exponent bytes of four float32 values into one `torch.int`.

    `x` must be float32 with a last dimension divisible by 4 and all
    mantissa bits zero (both asserted). The last dimension shrinks by 4x.
    """
    assert x.dtype == torch.float and x.size(-1) % 4 == 0
    raw = x.view(torch.int)
    mantissa_mask = (1 << 23) - 1
    assert ((raw & mantissa_mask) == 0).all()
    exponent_bytes = (raw >> 23).to(torch.uint8)
    return exponent_bytes.view(torch.int)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128,
                          use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to `float8_e4m3fn` with one scale per row and `gran_k` columns.

    Args:
        x: 2-D input of shape (m, n); `n` is zero-padded up to a multiple of `gran_k`.
        use_ue8m0: round every scale up to a power of two via `ceil_to_ue8m0`.
        gran_k: number of columns sharing one scale.
        use_packed_ue8m0: return the scales through `pack_ue8m0_to_int`;
            requires `use_ue8m0` (same guard as `per_token_cast_to_fp4`).

    Returns:
        `(x_fp8, sf)`: the quantized tensor of shape (m, n) and the scales of
        shape (m, padded_n // gran_k), or their packed form when requested.
    """
    assert x.dim() == 2
    # Packing keeps only exponent bits, so the scales must be powers of two.
    # Without this guard the failure surfaced later as a bare assert inside `pack_ue8m0_to_int`.
    assert not use_packed_ue8m0 or use_ue8m0
    m, n = x.shape
    padded_n = align(n, gran_k)
    x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
    x_padded[:, :n] = x
    x_view = x_padded.view(m, padded_n // gran_k, gran_k)
    # Clamp so an all-zero group does not produce a zero scale
    x_amax = x_view.abs().float().amax(dim=2).view(m, padded_n // gran_k).clamp(1e-4)
    # 448 is the largest finite `float8_e4m3fn` value
    sf = x_amax / 448.0
    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_fp8 = (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous()
    return x_fp8, pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf
|
| 38 |
|
| 39 |
|
| 40 |
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
| 79 |
code = idx.to(torch.uint8)
|
| 80 |
sign = (x < 0) & (idx != 0)
|
| 81 |
code = code | (sign.to(torch.uint8) << 3)
|
| 82 |
+
return code.view(torch.int8)
|
| 83 |
|
| 84 |
|
| 85 |
+
def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128,
|
| 86 |
+
use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 87 |
m, n = x.shape
|
| 88 |
assert n % 2 == 0
|
| 89 |
+
assert not use_packed_ue8m0 or use_ue8m0
|
| 90 |
padded_n = align(n, gran_k)
|
| 91 |
x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
|
| 92 |
x_padded[:, :n] = x
|
|
|
|
| 95 |
sf = x_amax / 6.0
|
| 96 |
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
| 97 |
x_scaled = x_view * (1.0 / sf.unsqueeze(2))
|
| 98 |
+
codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # int8, (m, padded_n)
|
| 99 |
codes2 = codes.view(m, padded_n // 2, 2)
|
| 100 |
+
packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # int8
|
| 101 |
+
return packed[:, :n // 2].contiguous(), pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf
|
| 102 |
|
| 103 |
|
| 104 |
def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor:
    """Transpose a 2-D int8 tensor that stores two 4-bit codes per byte.

    Input (m, n // 2) is unpacked to (m, n) with the low nibble first,
    transposed, and re-packed along the new last axis to (n, m // 2).
    `m` must be even.
    """
    assert a.dtype == torch.int8
    assert a.dim() == 2
    rows, packed_cols = a.shape
    cols = packed_cols * 2
    assert (rows % 2) == 0
    nibbles = torch.empty((rows, cols), device=a.device, dtype=torch.int8)
    nibbles[:, 0::2] = a & 0x0F
    nibbles[:, 1::2] = (a >> 4) & 0x0F
    pairs = nibbles.transpose(0, 1).contiguous().view(cols, rows // 2, 2)
    low, high = pairs[:, :, 0], pairs[:, :, 1]
    repacked = (low & 0x0F) | ((high & 0x0F) << 4)
    return repacked.contiguous()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _dequantize_from_fp4_e2m1(x: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
fp4_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device, dtype=torch.float)
|
| 122 |
+
sign, value_idx = (x & 0x08) != 0, (x & 0x07).to(torch.int)
|
| 123 |
+
value = fp4_values[value_idx]
|
| 124 |
+
return torch.where(sign & (value_idx != 0), -value, value)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def unpack_ue8m0_from_int(packed_sf: torch.Tensor) -> torch.Tensor:
    """Expand each byte of `packed_sf` into a float32 whose exponent field is that byte."""
    exponents = packed_sf.view(torch.uint8).to(torch.int)
    return (exponents << 23).view(torch.float)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def cast_back_from_fp4(packed: torch.Tensor, sf: torch.Tensor, gran_k: int = 128,
                       use_packed_ue8m0: bool = False) -> torch.Tensor:
    """Dequantize packed 4-bit codes and multiply by their per-group scales.

    `packed` is (m, n // 2) with two codes per byte, low nibble first.
    Column `j` of the result uses scale column `j // gran_k`. When
    `use_packed_ue8m0` is set, `sf` is first expanded with
    `unpack_ue8m0_from_int`.
    """
    num_rows, num_packed_cols = packed.shape
    num_cols = num_packed_cols * 2
    scales = unpack_ue8m0_from_int(sf) if use_packed_ue8m0 else sf
    codes = torch.zeros((num_rows, num_cols), dtype=torch.int8, device=packed.device)
    codes[:, ::2] = packed & 0x0F
    codes[:, 1::2] = (packed >> 4) & 0x0F
    values = _dequantize_from_fp4_e2m1(codes)
    scale_column = torch.arange(num_cols, device=packed.device) // gran_k
    return values * scales[:, scale_column]
|
build/torch210-cxx11-cu130-x86_64-linux/_C.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ._ops import ops
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def set_num_sms(num_sms: int):
    """Forward `num_sms` to `ops.set_num_sms`."""
    ops.set_num_sms(num_sms)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_num_sms() -> int:
    """Return the value reported by `ops.get_num_sms`."""
    num_sms = ops.get_num_sms()
    return num_sms
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def set_tc_util(tc_util: int):
    """Forward `tc_util` to `ops.set_tc_util`."""
    ops.set_tc_util(tc_util)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_tc_util() -> int:
    """Return the value reported by `ops.get_tc_util`."""
    tc_util = ops.get_tc_util()
    return tc_util
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def set_ignore_compile_dims(value: bool):
    """Forward the flag to `ops.set_ignore_compile_dims`."""
    ops.set_ignore_compile_dims(value)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def set_block_size_multiple_of(value):
    """Pass `(block_m, block_n)` to `ops.set_block_size_multiple_of`.

    A tuple is unpacked into the two values; any other value is used for both.
    """
    block_m, block_n = value if isinstance(value, tuple) else (value, value)
    ops.set_block_size_multiple_of(block_m, block_n)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def set_pdl(enable_pdl: bool):
    """Forward the flag to `ops.set_pdl`."""
    ops.set_pdl(enable_pdl)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_pdl() -> bool:
    """Return the value reported by `ops.get_pdl`."""
    enabled = ops.get_pdl()
    return enabled
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def set_mk_alignment_for_contiguous_layout(value: int):
    """Forward `value` to `ops.set_mk_alignment_for_contiguous_layout`."""
    ops.set_mk_alignment_for_contiguous_layout(value)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_mk_alignment_for_contiguous_layout() -> int:
    """Return the value reported by `ops.get_mk_alignment_for_contiguous_layout`."""
    alignment = ops.get_mk_alignment_for_contiguous_layout()
    return alignment
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_theoretical_mk_alignment_for_contiguous_layout(expected_m=None) -> int:
    """Query `ops.get_theoretical_mk_alignment_for_contiguous_layout`.

    The optional `expected_m` is passed as a (value, is-present) pair;
    0 stands in for the value when it is `None`.
    """
    has_expected_m = expected_m is not None
    return ops.get_theoretical_mk_alignment_for_contiguous_layout(
        expected_m if has_expected_m else 0,
        has_expected_m,
    )
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_tma_aligned_size(mn: int, element_size: int) -> int:
    """Call `ops.get_tma_aligned_size` and unwrap its result with `.item()`."""
    aligned = ops.get_tma_aligned_size(mn, element_size)
    return aligned.item()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_mn_major_tma_aligned_tensor(sf):
    """Return `ops.get_mn_major_tma_aligned_tensor(sf)`."""
    aligned_sf = ops.get_mn_major_tma_aligned_tensor(sf)
    return aligned_sf
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
    """Return `ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)`."""
    packed_sf = ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
    return packed_sf
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
    sf, ks_tensor, ks, gran_k
):
    """Call the op of the same name with `ks` converted to a CPU int32 tensor."""
    # The device is pinned to the CPU explicitly
    ks_on_cpu = torch.tensor(ks, dtype=torch.int32, device="cpu")
    call_args = (sf, ks_tensor, ks_on_cpu, gran_k)
    return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(*call_args)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def transform_sf_into_required_layout(
    sf,
    mn,
    k,
    recipe,
    num_groups=None,
    is_sfa=None,
    disable_ue8m0_cast=False,
):
    """Flatten the Python-level arguments for `ops.transform_sf_into_required_layout`.

    `recipe` must have 2 or 3 entries; a 2-entry recipe gets 0 as its third
    value and the real length is passed alongside. The optional `num_groups`
    and `is_sfa` are each sent as a (value, is-present) pair.

    Raises:
        ValueError: if `recipe` has any other length.
    """
    recipe_len = len(recipe)
    if recipe_len == 3:
        r0, r1, r2 = recipe
    elif recipe_len == 2:
        r0, r1 = recipe
        r2 = 0
    else:
        raise ValueError("recipe must have length 2 or 3")

    has_num_groups = num_groups is not None
    has_is_sfa = is_sfa is not None
    return ops.transform_sf_into_required_layout(
        sf, mn, k,
        r0, r1, r2, recipe_len,
        num_groups if has_num_groups else 0, has_num_groups,
        is_sfa if has_is_sfa else False, has_is_sfa,
        disable_ue8m0_cast,
    )
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def get_token_alignment_for_mega_moe() -> int:
    """Return the value reported by `ops.get_token_alignment_for_mega_moe`."""
    alignment = ops.get_token_alignment_for_mega_moe()
    return alignment
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_symm_buffer_size_for_mega_moe(
    num_ranks,
    num_experts,
    num_max_tokens_per_rank,
    num_topk,
    hidden,
    intermediate_hidden,
    use_fp8_dispatch=True,
    activation="swiglu",
):
    """Return `(num_bytes, slice_input_buffers)` for one configuration.

    `num_bytes` comes from `ops.get_symm_buffer_size_for_mega_moe`.
    `slice_input_buffers(buffer)` calls
    `ops.get_symm_buffer_views_for_mega_moe` with the same configuration
    and returns its result as a tuple.
    """
    # Both ops take the same trailing configuration arguments
    config = (
        num_ranks,
        num_experts,
        num_max_tokens_per_rank,
        num_topk,
        hidden,
        intermediate_hidden,
        use_fp8_dispatch,
        activation,
    )
    num_bytes = ops.get_symm_buffer_size_for_mega_moe(*config)

    def slice_input_buffers(buffer):
        views = ops.get_symm_buffer_views_for_mega_moe(buffer, *config)
        return tuple(views)

    return num_bytes, slice_input_buffers
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def fp8_fp4_mega_moe(
    y,
    l1_weights,
    l2_weights,
    cumulative_local_expert_recv_stats,
    sym_buffer,
    sym_buffer_ptrs,
    rank_idx,
    num_max_tokens_per_rank,
    num_experts,
    num_topk,
    recipe,
    activation,
    activation_clamp,
    fast_math,
):
    """Unpack the tuple arguments and call `ops.fp8_fp4_mega_moe`.

    `l1_weights` and `l2_weights` are (data, scale-factor) pairs and
    `recipe` is a 3-tuple; each is passed to the op as separate values.
    Returns nothing.
    """
    l1_data, l1_sf = l1_weights
    l2_data, l2_sf = l2_weights
    r0, r1, r2 = recipe
    op_args = (
        y,
        l1_data, l1_sf,
        l2_data, l2_sf,
        cumulative_local_expert_recv_stats,
        sym_buffer, sym_buffer_ptrs,
        rank_idx,
        num_max_tokens_per_rank, num_experts, num_topk,
        r0, r1, r2,
        activation, activation_clamp,
        fast_math,
    )
    ops.fp8_fp4_mega_moe(*op_args)
|
build/torch210-cxx11-cu130-x86_64-linux/__init__.py
CHANGED
|
@@ -1,12 +1,18 @@
|
|
| 1 |
import os
|
| 2 |
import subprocess
|
|
|
|
| 3 |
import torch
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
# Import the compiled extension
|
| 6 |
-
from ._ops import ops, add_op_namespace_prefix
|
| 7 |
from . import utils
|
| 8 |
|
| 9 |
-
__version__ = "2.
|
| 10 |
|
| 11 |
|
| 12 |
# ── Register fake tensor implementations for torch.compile ──────────────────
|
|
@@ -32,6 +38,7 @@ for _op in [
|
|
| 32 |
"m_grouped_bf16_gemm_nn_contiguous",
|
| 33 |
"m_grouped_bf16_gemm_nt_masked",
|
| 34 |
"fp8_gemm_nt_skip_head_mid",
|
|
|
|
| 35 |
]:
|
| 36 |
|
| 37 |
@torch.library.register_fake(add_op_namespace_prefix(_op))
|
|
@@ -58,10 +65,41 @@ def get_tc_util() -> int:
|
|
| 58 |
return ops.get_tc_util()
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
def get_mk_alignment_for_contiguous_layout() -> int:
|
| 62 |
return ops.get_mk_alignment_for_contiguous_layout()
|
| 63 |
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
# Layout utilities
|
| 66 |
|
| 67 |
|
|
@@ -77,10 +115,12 @@ def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
|
|
| 77 |
return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
|
| 78 |
|
| 79 |
|
| 80 |
-
def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
|
|
|
|
|
|
|
| 81 |
ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu")
|
| 82 |
return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
|
| 83 |
-
sf, ks_tensor, ks_int
|
| 84 |
)
|
| 85 |
|
| 86 |
|
|
@@ -88,16 +128,20 @@ def transform_sf_into_required_layout(
|
|
| 88 |
sf,
|
| 89 |
mn,
|
| 90 |
k,
|
| 91 |
-
recipe
|
| 92 |
-
recipe_ab=None,
|
| 93 |
num_groups=None,
|
| 94 |
-
is_sfa=
|
| 95 |
disable_ue8m0_cast=False,
|
| 96 |
):
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
has_ng = num_groups is not None
|
| 102 |
ng = num_groups if has_ng else 0
|
| 103 |
return ops.transform_sf_into_required_layout(
|
|
@@ -107,13 +151,11 @@ def transform_sf_into_required_layout(
|
|
| 107 |
r0,
|
| 108 |
r1,
|
| 109 |
r2,
|
| 110 |
-
|
| 111 |
-
rab0,
|
| 112 |
-
rab1,
|
| 113 |
-
has_recipe_ab,
|
| 114 |
ng,
|
| 115 |
has_ng,
|
| 116 |
-
is_sfa,
|
|
|
|
| 117 |
disable_ue8m0_cast,
|
| 118 |
)
|
| 119 |
|
|
@@ -593,8 +635,37 @@ def fp8_mqa_logits(
|
|
| 593 |
)
|
| 594 |
|
| 595 |
|
| 596 |
-
def
|
| 597 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 598 |
|
| 599 |
|
| 600 |
def fp8_paged_mqa_logits(
|
|
@@ -606,6 +677,7 @@ def fp8_paged_mqa_logits(
|
|
| 606 |
schedule_meta,
|
| 607 |
max_context_len,
|
| 608 |
clean_logits=False,
|
|
|
|
| 609 |
):
|
| 610 |
return ops.fp8_paged_mqa_logits(
|
| 611 |
q,
|
|
@@ -616,6 +688,38 @@ def fp8_paged_mqa_logits(
|
|
| 616 |
schedule_meta,
|
| 617 |
max_context_len,
|
| 618 |
clean_logits,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
)
|
| 620 |
|
| 621 |
|
|
@@ -642,6 +746,14 @@ def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
|
|
| 642 |
ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns, has_ns)
|
| 643 |
|
| 644 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
# Initialize the C++ runtime
|
| 646 |
|
| 647 |
|
|
@@ -683,6 +795,14 @@ if "DG_CUTLASS_INCLUDE" not in os.environ:
|
|
| 683 |
_include, # legacy layout: include/cutlass
|
| 684 |
os.path.join(_include, "third-party", "cutlass", "include"), # submodule layout
|
| 685 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 686 |
for _cutlass_include in _cutlass_include_candidates:
|
| 687 |
if os.path.isdir(os.path.join(_cutlass_include, "cutlass")):
|
| 688 |
os.environ["DG_CUTLASS_INCLUDE"] = _cutlass_include
|
|
@@ -703,8 +823,21 @@ def _ensure_initialized():
|
|
| 703 |
global _initialized
|
| 704 |
if _initialized:
|
| 705 |
return
|
|
|
|
| 706 |
_initialized = True
|
| 707 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 708 |
|
| 709 |
|
| 710 |
# Try to initialize eagerly, but don't fail if CUDA is not found
|
|
|
|
| 1 |
import os
|
| 2 |
import subprocess
|
| 3 |
+
import sysconfig
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
# Avoid holding a CUDA tensor in DeepGEMM's process-lifetime runtime singleton.
|
| 7 |
+
# In packaged/lazy-loaded use, that can outlive PyTorch's CUDA teardown and crash
|
| 8 |
+
# during interpreter shutdown.
|
| 9 |
+
os.environ.setdefault("DG_USE_TEMP_CUBLASLT_WORKSPACE", "1")
|
| 10 |
+
|
| 11 |
# Import the compiled extension
|
| 12 |
+
from ._ops import ops as _ops, add_op_namespace_prefix
|
| 13 |
from . import utils
|
| 14 |
|
| 15 |
+
__version__ = "2.5.0"
|
| 16 |
|
| 17 |
|
| 18 |
# ── Register fake tensor implementations for torch.compile ──────────────────
|
|
|
|
| 38 |
"m_grouped_bf16_gemm_nn_contiguous",
|
| 39 |
"m_grouped_bf16_gemm_nt_masked",
|
| 40 |
"fp8_gemm_nt_skip_head_mid",
|
| 41 |
+
"fp8_fp4_mega_moe",
|
| 42 |
]:
|
| 43 |
|
| 44 |
@torch.library.register_fake(add_op_namespace_prefix(_op))
|
|
|
|
| 65 |
return ops.get_tc_util()
|
| 66 |
|
| 67 |
|
| 68 |
+
def set_ignore_compile_dims(value: bool):
    """Hand `value` to `ops.set_ignore_compile_dims`."""
    ops.set_ignore_compile_dims(value)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def set_block_size_multiple_of(value):
    """Call `ops.set_block_size_multiple_of(block_m, block_n)`.

    `value` is either a `(block_m, block_n)` tuple or a single value that
    is used for both.
    """
    block_m, block_n = value if isinstance(value, tuple) else (value, value)
    ops.set_block_size_multiple_of(block_m, block_n)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def set_pdl(enable_pdl: bool):
    """Forward the PDL switch to the compiled extension.

    Args:
        enable_pdl: Flag handed unchanged to ``ops.set_pdl``.
            NOTE(review): PDL presumably means Programmatic Dependent
            Launch; confirm against the extension.
    """
    setter = ops.set_pdl
    setter(enable_pdl)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_pdl() -> bool:
    """Return the PDL switch as reported by ``ops.get_pdl``."""
    enabled = ops.get_pdl()
    return enabled
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def set_mk_alignment_for_contiguous_layout(alignment: int):
    """Forward the M/K alignment for the contiguous layout to the extension.

    Args:
        alignment: Value handed unchanged to
            ``ops.set_mk_alignment_for_contiguous_layout``.
    """
    setter = ops.set_mk_alignment_for_contiguous_layout
    setter(alignment)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
def get_mk_alignment_for_contiguous_layout() -> int:
    """Return the value reported by ``ops.get_mk_alignment_for_contiguous_layout``."""
    alignment = ops.get_mk_alignment_for_contiguous_layout()
    return alignment
|
| 94 |
|
| 95 |
|
| 96 |
+
def get_theoretical_mk_alignment_for_contiguous_layout(expected_m=None) -> int:
    """Query the extension for the theoretical M/K alignment.

    The optional argument is passed as a ``(value, has_value)`` pair:
    ``expected_m=None`` is sent as ``(0, False)``.

    Args:
        expected_m: Optional value forwarded to the extension; ``None``
            means "not provided".

    Returns:
        Whatever ``ops.get_theoretical_mk_alignment_for_contiguous_layout``
        returns.
    """
    has_expected_m = expected_m is not None
    expected_m_arg = expected_m if has_expected_m else 0
    return ops.get_theoretical_mk_alignment_for_contiguous_layout(
        expected_m_arg, has_expected_m
    )
|
| 101 |
+
|
| 102 |
+
|
| 103 |
# Layout utilities
|
| 104 |
|
| 105 |
|
|
|
|
| 115 |
return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
|
| 116 |
|
| 117 |
|
| 118 |
+
def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
    sf, ks_tensor, ks, gran_k
):
    """Call the extension op of the same name with ``ks`` converted to a tensor.

    Args:
        sf: Forwarded unchanged to the op.
        ks_tensor: Forwarded unchanged to the op.
        ks: Sequence of integers; materialised as an ``int32`` CPU tensor
            before being passed on.
        gran_k: Forwarded unchanged to the op.

    Returns:
        The result of the extension op.
    """
    ks_cpu = torch.tensor(ks, dtype=torch.int32, device="cpu")
    op = ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
    return op(sf, ks_tensor, ks_cpu, gran_k)
|
| 125 |
|
| 126 |
|
|
|
|
| 128 |
sf,
|
| 129 |
mn,
|
| 130 |
k,
|
| 131 |
+
recipe,
|
|
|
|
| 132 |
num_groups=None,
|
| 133 |
+
is_sfa=None,
|
| 134 |
disable_ue8m0_cast=False,
|
| 135 |
):
|
| 136 |
+
if len(recipe) == 3:
|
| 137 |
+
r0, r1, r2 = recipe
|
| 138 |
+
recipe_len = 3
|
| 139 |
+
elif len(recipe) == 2:
|
| 140 |
+
r0, r1 = recipe
|
| 141 |
+
r2 = 0
|
| 142 |
+
recipe_len = 2
|
| 143 |
+
else:
|
| 144 |
+
raise ValueError("recipe must have length 2 or 3")
|
| 145 |
has_ng = num_groups is not None
|
| 146 |
ng = num_groups if has_ng else 0
|
| 147 |
return ops.transform_sf_into_required_layout(
|
|
|
|
| 151 |
r0,
|
| 152 |
r1,
|
| 153 |
r2,
|
| 154 |
+
recipe_len,
|
|
|
|
|
|
|
|
|
|
| 155 |
ng,
|
| 156 |
has_ng,
|
| 157 |
+
False if is_sfa is None else is_sfa,
|
| 158 |
+
is_sfa is not None,
|
| 159 |
disable_ue8m0_cast,
|
| 160 |
)
|
| 161 |
|
|
|
|
| 635 |
)
|
| 636 |
|
| 637 |
|
| 638 |
+
def fp8_fp4_mqa_logits(
    q,
    kv,
    weights,
    cu_seq_len_k_start,
    cu_seq_len_k_end,
    clean_logits=True,
    max_seqlen_k=0,
    logits_dtype=torch.float32,
):
    """Unpack the ``q`` / ``kv`` inputs and call ``ops.fp8_fp4_mqa_logits``.

    Args:
        q: Either a ``(data, scale_factor)`` tuple or a bare object; a bare
            object is forwarded with ``None`` as its scale factor.
        kv: A two-element iterable unpacked into ``(data, scale_factor)``.
        weights, cu_seq_len_k_start, cu_seq_len_k_end: Forwarded unchanged.
        clean_logits, max_seqlen_k, logits_dtype: Forwarded unchanged.

    Returns:
        The result of the extension op.
    """
    q_data, q_sf = q if isinstance(q, tuple) else (q, None)
    kv_data, kv_sf = kv
    op_args = (
        q_data,
        q_sf,
        kv_data,
        kv_sf,
        weights,
        cu_seq_len_k_start,
        cu_seq_len_k_end,
        clean_logits,
        max_seqlen_k,
        logits_dtype,
    )
    return ops.fp8_fp4_mqa_logits(*op_args)
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices=None):
    """Forward all arguments to ``ops.get_paged_mqa_logits_metadata``.

    Args:
        context_lens, block_kv, num_sms: Forwarded unchanged.
        indices: Optional value, forwarded unchanged (``None`` by default).

    Returns:
        The result of the extension op.
    """
    metadata = ops.get_paged_mqa_logits_metadata(
        context_lens, block_kv, num_sms, indices
    )
    return metadata
|
| 669 |
|
| 670 |
|
| 671 |
def fp8_paged_mqa_logits(
|
|
|
|
| 677 |
schedule_meta,
|
| 678 |
max_context_len,
|
| 679 |
clean_logits=False,
|
| 680 |
+
indices=None,
|
| 681 |
):
|
| 682 |
return ops.fp8_paged_mqa_logits(
|
| 683 |
q,
|
|
|
|
| 688 |
schedule_meta,
|
| 689 |
max_context_len,
|
| 690 |
clean_logits,
|
| 691 |
+
indices,
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
def fp8_fp4_paged_mqa_logits(
    q,
    kv_cache,
    weights,
    context_lens,
    block_table,
    schedule_meta,
    max_context_len,
    clean_logits=False,
    logits_dtype=torch.float32,
    indices=None,
):
    """Unpack ``q`` and call ``ops.fp8_fp4_paged_mqa_logits``.

    Args:
        q: Either a ``(data, scale_factor)`` tuple or a bare object; a bare
            object is forwarded with ``None`` as its scale factor.
        kv_cache, weights, context_lens, block_table, schedule_meta,
        max_context_len: Forwarded unchanged.
        clean_logits, logits_dtype, indices: Forwarded unchanged.

    Returns:
        The result of the extension op.
    """
    q_data, q_sf = q if isinstance(q, tuple) else (q, None)
    op_args = (
        q_data,
        q_sf,
        kv_cache,
        weights,
        context_lens,
        block_table,
        schedule_meta,
        max_context_len,
        clean_logits,
        logits_dtype,
        indices,
    )
    return ops.fp8_fp4_paged_mqa_logits(*op_args)
|
| 724 |
|
| 725 |
|
|
|
|
| 746 |
ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns, has_ns)
|
| 747 |
|
| 748 |
|
| 749 |
+
from .mega import (
|
| 750 |
+
SymmBuffer,
|
| 751 |
+
get_symm_buffer_for_mega_moe,
|
| 752 |
+
transform_weights_for_mega_moe,
|
| 753 |
+
fp8_fp4_mega_moe,
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
|
| 757 |
# Initialize the C++ runtime
|
| 758 |
|
| 759 |
|
|
|
|
| 795 |
_include, # legacy layout: include/cutlass
|
| 796 |
os.path.join(_include, "third-party", "cutlass", "include"), # submodule layout
|
| 797 |
]
|
| 798 |
+
for _site_packages in {
|
| 799 |
+
sysconfig.get_paths().get("purelib"),
|
| 800 |
+
sysconfig.get_paths().get("platlib"),
|
| 801 |
+
}:
|
| 802 |
+
if _site_packages:
|
| 803 |
+
_cutlass_include_candidates.append(
|
| 804 |
+
os.path.join(_site_packages, "cutlass_library", "source", "include")
|
| 805 |
+
)
|
| 806 |
for _cutlass_include in _cutlass_include_candidates:
|
| 807 |
if os.path.isdir(os.path.join(_cutlass_include, "cutlass")):
|
| 808 |
os.environ["DG_CUTLASS_INCLUDE"] = _cutlass_include
|
|
|
|
| 823 |
global _initialized
|
| 824 |
if _initialized:
|
| 825 |
return
|
| 826 |
+
_ops.init(_lib_root, _find_cuda_home())
|
| 827 |
_initialized = True
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
class _InitializedOps:
|
| 831 |
+
def __init__(self, raw_ops):
|
| 832 |
+
self._raw_ops = raw_ops
|
| 833 |
+
|
| 834 |
+
def __getattr__(self, name):
|
| 835 |
+
if name != "init":
|
| 836 |
+
_ensure_initialized()
|
| 837 |
+
return getattr(self._raw_ops, name)
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
ops = _InitializedOps(_ops)
|
| 841 |
|
| 842 |
|
| 843 |
# Try to initialize eagerly, but don't fail if CUDA is not found
|
build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ec220d340cd32423ffdca60d3ba9e1cf8196e1cee096bad64dd7726824d79898
|
| 3 |
+
size 3461568
|
build/torch210-cxx11-cu130-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _deep_gemm_cuda_388adb9
|
| 3 |
+
ops = torch.ops._deep_gemm_cuda_388adb9
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
    """
    Return ``op_name`` qualified with this build's op namespace,
    in the form ``<namespace>::<op_name>``.
    """
    return "_deep_gemm_cuda_388adb9::{}".format(op_name)
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/comm/barrier.cuh
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cutlass/arch/barrier.h>
|
| 4 |
+
|
| 5 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 6 |
+
#include <deep_gemm/layout/sym_buffer.cuh>
|
| 7 |
+
#include <deep_gemm/layout/mega_moe.cuh>
|
| 8 |
+
|
| 9 |
+
namespace deep_gemm::comm {

// Cluster-wide barrier built from a relaxed arrive followed by a wait.
CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() {
    // Perform cluster_sync with `barrier.cluster.arrive.relaxed`
    // This is slightly faster than `cute::cluster_sync` but has weaker memory ordering guarantee
    cute::cluster_arrive_relaxed();
    cute::cluster_wait();
}

// Grid-wide barrier across `kNumSMs` participants using one counter in `workspace`.
//
// `sync_scope` is a callable invoked before and after the counter handshake.
// NOTE(review): presumably a block-level barrier over the calling threads; confirm at call sites.
// Only the thread with `thread_idx == 0` touches the counter.
//
// Counter protocol: the participant with `sm_idx == 0` adds `kFinishSumTag - (kNumSMs - 1)`
// and every other participant adds 1, so the contributions of one round sum to exactly
// `kFinishSumTag`, i.e. the top bit of the counter flips once all `kNumSMs` have arrived.
template <uint32_t kNumSMs, uint32_t kGridSyncIndex = 0, typename sync_scope_t>
CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace,
                              const uint32_t& sm_idx, const uint32_t& thread_idx,
                              const sync_scope_t& sync_scope) {
    // NOTES: the implementation idea is from `cooperative_groups::this_grid().sync()`
    static constexpr uint32_t kFinishSumTag = 0x80000000u;
    sync_scope();
    if (thread_idx == 0) {
        const auto count_ptr = workspace.get_grid_sync_count_ptr<kGridSyncIndex>();
        const auto old_value = ptx::atomic_add_rel(
            count_ptr, sm_idx == 0 ? (kFinishSumTag - (kNumSMs - 1)) : 1);
        uint32_t new_value;
        // Spin until the top bit differs from the value observed on arrival
        do {
            new_value = ptx::ld_acq(count_ptr);
        } while (((new_value ^ old_value) & kFinishSumTag) == 0);
    }
    sync_scope();
}

// Cross-rank barrier: optional grid sync, a signal exchange performed by SM 0 through
// `sym_buffer`, then another optional grid sync.
//
// `kTag` is only used in the timeout diagnostic message.
// `sync_prologue` / `sync_epilogue` select whether the surrounding grid syncs are executed.
// On timeout the waiting thread prints a diagnostic and traps via `DG_DEVICE_ASSERT`.
template <uint32_t kNumRanks, uint32_t kNumSMs, uint32_t kNumThreads, uint32_t kGridSyncIndex, uint32_t kTag, typename sync_scope_t>
CUTLASS_DEVICE void nvlink_barrier(const layout::Workspace& workspace,
                                   const layout::SymBuffer<kNumRanks>& sym_buffer,
                                   const uint32_t& sm_idx, const uint32_t& thread_idx,
                                   const sync_scope_t& sync_scope,
                                   const bool& sync_prologue = true,
                                   const bool& sync_epilogue = true) {
    // One thread per rank sends the signal, so at least `kNumRanks` threads are required
    DG_STATIC_ASSERT(kNumRanks <= kNumThreads, "Insufficient threads");

    // Grid sync before NVLink signaling
    if (sync_prologue)
        grid_sync<kNumSMs, kGridSyncIndex>(workspace, sm_idx, thread_idx, sync_scope);

    // NVLink cross-rank barrier, only SM 0 participates
    if (sm_idx == 0) {
        auto* counter_ptr = workspace.get_nvl_barrier_counter_ptr();
        // Low two bits of the counter: bit 0 selects the signal slot (phase),
        // bit 1 selects the direction (sign) in which the signal is counted
        const auto status = (*counter_ptr) & 3;
        const auto signal_phase = status & 1, signal_sign = status >> 1;
        auto* signal_ptr = workspace.get_nvl_barrier_signal_ptr(signal_phase);

        // Send signals to remote ranks
        // Thread `i` adds +1 (or -1 when the sign bit is set) to the signal mapped for rank `i`
        if (thread_idx < kNumRanks)
            ptx::red_add_rel_sys(sym_buffer.map(signal_ptr, thread_idx), signal_sign ? -1 : 1);
        sync_scope();

        // Update status and wait arrival (with 30s timeout, at 2 GHz)
        constexpr int64_t kNumTimeoutCycles = 30ll * 2000000000ll;
        if (thread_idx == 0) {
            // Advance the counter so the next call uses the next phase/sign combination
            ptx::red_add(counter_ptr, 1);
            // The local signal reaches `kNumRanks` when counting up, or returns to 0 when counting down
            const int target = signal_sign ? 0 : static_cast<int>(kNumRanks);
            const auto start_clock = clock64();
            while (ptx::ld_acq_sys(signal_ptr) != target) {
                if (clock64() - start_clock >= kNumTimeoutCycles) {
                    printf("DeepGEMM NVLink barrier timeout (30s): rank=%d, counter=%d, signal=%d, target=%d, phase=%d, sign=%d, tag=%d\n",
                           sym_buffer.rank_idx, *counter_ptr, ptx::ld_acq_sys(signal_ptr), target, signal_phase, signal_sign, kTag);
                    DG_DEVICE_ASSERT(false and "NVLink barrier timeout");
                }
            }
        }
    }

    // Grid sync after NVLink completion
    if (sync_epilogue)
        grid_sync<kNumSMs, kGridSyncIndex>(workspace, sm_idx, thread_idx, sync_scope);
}

} // namespace deep_gemm::comm
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/compile.cuh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cutlass/detail/helper_macros.hpp>
|
| 4 |
+
|
| 5 |
+
// `DG_IN_CUDA_COMPILATION` is defined when the translation unit is processed by NVCC,
// by Clang in CUDA mode, by NVRTC, or by the CLion IDE indexer.
#if defined(__NVCC__) or (defined(__clang__) and defined(__CUDA__)) or defined(__CUDACC_RTC__) or defined(__CLION_IDE__)
#define DG_IN_CUDA_COMPILATION
#endif

// Function qualifiers without any inlining hint:
//  - NVCC / Clang-CUDA: host+device and device-only qualifiers;
//  - NVRTC: device-only for both macros;
//  - any other compiler: both macros expand to nothing.
// The alternative operator tokens (`or` / `and`) are used consistently with the condition above.
#if defined(__NVCC__) or (defined(__clang__) and defined(__CUDA__))
#define CUTLASS_HOST_DEVICE_NOINLINE __device__ __host__
#define CUTLASS_DEVICE_NOINLINE __device__
#elif defined(__CUDACC_RTC__)
#define CUTLASS_HOST_DEVICE_NOINLINE __device__
#define CUTLASS_DEVICE_NOINLINE __device__
#else
#define CUTLASS_HOST_DEVICE_NOINLINE
#define CUTLASS_DEVICE_NOINLINE
#endif
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/cute_tie.cuh
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
#pragma once
|
| 2 |
|
|
|
|
|
|
|
| 3 |
namespace cute {
|
| 4 |
|
| 5 |
struct ignore_t {
|
|
|
|
| 1 |
#pragma once
|
| 2 |
|
| 3 |
+
#include <cute/int_tuple.hpp>
|
| 4 |
+
|
| 5 |
namespace cute {
|
| 6 |
|
| 7 |
struct ignore_t {
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/exception.cuh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda/std/cstdint>
|
| 4 |
+
#include <deep_gemm/common/compile.cuh>
|
| 5 |
+
|
| 6 |
+
// IDE-only stub: under the CLion indexer `printf` is redirected to a variadic
// host/device function whose body just traps.
#ifdef __CLION_IDE__

CUTLASS_HOST_DEVICE void host_device_printf(const char* format, ...) {
    asm volatile("trap;");
}

#define printf host_device_printf
#endif

// Device assertion: on failure prints file, line and the stringified condition, then traps.
// Can be overridden by defining `DG_DEVICE_ASSERT` before including this header.
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
    if (not (cond)) { \
        printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
        asm("trap;"); \
    } \
} while (0)
#endif

// Device assertion that traps without printing anything.
#ifndef DG_TRAP_ONLY_DEVICE_ASSERT
#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
    if (not (cond)) \
        asm("trap;"); \
} while (0)
#endif

// Thin wrapper over `static_assert`; the message is passed through the variadic arguments.
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
#endif

// Assertion usable from both sides: maps to `DG_DEVICE_ASSERT` when `DG_IN_CUDA_COMPILATION`
// is defined, otherwise to `DG_HOST_ASSERT`.
// NOTE(review): `DG_HOST_ASSERT` is not defined in this header; it must be provided by the includer.
#ifndef DG_UNIFIED_ASSERT
#ifdef DG_IN_CUDA_COMPILATION
#define DG_UNIFIED_ASSERT(cond) DG_DEVICE_ASSERT(cond)
#else
#define DG_UNIFIED_ASSERT(cond) DG_HOST_ASSERT(cond)
#endif
#endif
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/math.cuh
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda/std/cstdint>
|
| 4 |
+
#include <deep_gemm/common/compile.cuh>
|
| 5 |
+
#include <deep_gemm/common/exception.cuh>
|
| 6 |
+
|
| 7 |
+
namespace deep_gemm::math {

/// Pointer operations
// Advance `ptr` by `num_bytes` bytes and reinterpret the result as `dtype_t*`.
template <typename dtype_t = void>
CUTLASS_HOST_DEVICE dtype_t* advance_ptr(void* ptr, const uint64_t num_bytes) {
    return reinterpret_cast<dtype_t*>(static_cast<uint8_t*>(ptr) + num_bytes);
}

/// Math functions
// Ceiling division computed as `(a + b - 1) / b`; `b` must be non-zero and
// `a + b - 1` must not overflow `T`.
template <typename T>
CUTLASS_HOST_DEVICE T ceil_div(T a, T b) {
    return (a + b - 1) / b;
}

// `constexpr` variant of `ceil_div`.
template <typename T>
CUTLASS_HOST_DEVICE constexpr T constexpr_ceil_div(T a, T b) {
    return (a + b - 1) / b;
}

// Round `a` to a multiple of `b`: upwards when `kDoCeilAlignment` is true, downwards otherwise.
template <typename T, bool kDoCeilAlignment = true>
CUTLASS_HOST_DEVICE T align(T a, T b) {
    return (kDoCeilAlignment ? ceil_div(a, b) : (a / b)) * b;
}

// `constexpr` variant of `align` (always rounds up).
template <typename T>
CUTLASS_HOST_DEVICE constexpr T constexpr_align(T a, T b) {
    return constexpr_ceil_div(a, b) * b;
}

// Greatest common divisor via the Euclidean recurrence.
template <typename T>
CUTLASS_HOST_DEVICE constexpr T constexpr_gcd(T a, T b) {
    return b == 0 ? a : constexpr_gcd(b, a % b);
}

// Smaller of two values; returns `b` when they compare equal.
template <typename T>
CUTLASS_HOST_DEVICE constexpr T constexpr_min(T a, T b) {
    return a < b ? a : b;
}

// Exchange two values through a temporary copy.
template <typename T>
CUTLASS_DEVICE void swap(T& a, T& b) {
    T temp = a;
    a = b;
    b = temp;
}

#ifdef DG_IN_CUDA_COMPILATION
// Element-wise fused multiply-add `a * b + c` on `float2`.
// Uses the packed intrinsic when `__CUDA_ARCH__ >= 1000`, two scalar FMAs otherwise.
CUTLASS_DEVICE float2 fma2(const float2& a, const float2& b, const float2& c) {
#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)
    return __ffma2_rn(a, b, c);
#else
    return make_float2(
        __fmaf_rn(a.x, b.x, c.x),
        __fmaf_rn(a.y, b.y, c.y)
    );
#endif
}

// Reciprocal of `x`: approximate (`rcp.approx.ftz.f32`) in device code, exact division on the host.
CUTLASS_HOST_DEVICE float fast_rcp(const float& x) {
#if defined(__CUDA_ARCH__)
    float ret;
    asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(ret) : "f"(x));
    return ret;
#else
    return 1.0f / x;
#endif
}

/// Casting
// Reinterpret the storage of `x` and `y` as two floats, convert them to BF16 (round-to-nearest)
// and return the packed pair as a 32-bit integer.
// NOTE(review): assumes `old_t` is a 32-bit type holding float bit patterns - confirm at call sites.
template <typename old_t>
CUTLASS_DEVICE int cast_into_bf16_and_pack(old_t& x, old_t& y) {
    auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)});
    return *reinterpret_cast<int*>(&bf16x2);
}

// Build the float whose biased exponent field is `x + 127` and whose mantissa is zero,
// i.e. `2^x` for exponents in the normal range.
CUTLASS_DEVICE float fast_pow2(const int& x) {
    // Shift as unsigned so that no signed value is ever left-shifted,
    // and convert the bit pattern with the intrinsic instead of pointer punning
    const uint32_t bits_x = static_cast<uint32_t>(x + 127) << 23;
    return __uint_as_float(bits_x);
}

// `ceil(log2(x))` read off the float bit pattern: unbiased exponent, plus one if any mantissa bit is set.
// The sign bit is not masked, so the result is only meaningful for positive normal inputs.
CUTLASS_DEVICE int fast_log2_ceil(float x) {
    const uint32_t bits = __float_as_uint(x);
    const int exp = static_cast<int>(bits >> 23);
    const uint32_t man = bits & ((1u << 23) - 1u);
    return exp - 127 + (man != 0 ? 1 : 0);
}

// For each component of `amax`, compute a power-of-two scale factor `sf >= amax / 448`
// and its inverse `sf_inv`.
template <bool kUseUE8M0 = true>
CUTLASS_DEVICE void get_e4m3_sf_and_sf_inv(const float2& amax, float2& sf, float2& sf_inv) {
    DG_STATIC_ASSERT(kUseUE8M0, "Must use UE8M0");
    // Same value as the original double-to-float initialisation, with the narrowing made explicit
    constexpr float kFinfoFactor = static_cast<float>(1.0 / 448.0);
#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)
    const float2 finfo_factor = {kFinfoFactor, kFinfoFactor};
    const auto scaled = __fmul2_rn(amax, finfo_factor);
#else
    // Scalar fallback, mirroring the architecture guard used by `fma2` for packed intrinsics
    const auto scaled = make_float2(__fmul_rn(amax.x, kFinfoFactor), __fmul_rn(amax.y, kFinfoFactor));
#endif
    const auto exp_x = fast_log2_ceil(scaled.x);
    const auto exp_y = fast_log2_ceil(scaled.y);
    sf.x = fast_pow2(exp_x), sf_inv.x = fast_pow2(-exp_x);
    sf.y = fast_pow2(exp_y), sf_inv.y = fast_pow2(-exp_y);
}

/// Reduction
// Inclusive prefix sum of `value` across the warp; `lane_idx` is the caller's lane index.
// Uses the full-warp mask, so all 32 lanes must call this together.
CUTLASS_DEVICE uint32_t warp_inclusive_sum(uint32_t value, const uint32_t& lane_idx) {
    #pragma unroll
    for (uint32_t offset = 1; offset < 32; offset <<= 1) {
        const uint32_t synced = __shfl_up_sync(0xffffffff, value, offset);
        // Lanes below `offset` have no source lane and keep their value
        if (lane_idx >= offset)
            value += synced;
    }
    return value;
}

// Operation functors
template <typename T> struct ReduceSum { CUTLASS_DEVICE T operator()(T a, T b) const { return a + b; } };
template <typename T> struct ReduceMax { CUTLASS_DEVICE T operator()(T a, T b) const { return a > b ? a : b; } };
template <typename T> struct ReduceMin { CUTLASS_DEVICE T operator()(T a, T b) const { return a < b ? a : b; } };
template <typename T> struct ReduceAnd { CUTLASS_DEVICE T operator()(T a, T b) const { return a & b; } };
template <typename T> struct ReduceOr  { CUTLASS_DEVICE T operator()(T a, T b) const { return a | b; } };

// Unified reduction function
// Butterfly (XOR-shuffle) reduction with the full-warp mask; all 32 lanes must participate.
//  - `kIntergroupReduce == false`: XOR distances below `kNumLanesPerGroup`,
//    i.e. lanes of the same aligned group of `kNumLanesPerGroup` lanes are combined;
//  - `kIntergroupReduce == true`: XOR distances from `kNumLanesPerGroup` up to 16,
//    i.e. lanes at the same position in different groups are combined.
template <uint32_t kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
CUTLASS_DEVICE T warp_reduce(T value, Op op) {
    DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
                     kNumLanesPerGroup == 4  or kNumLanesPerGroup == 2  or kNumLanesPerGroup == 1,
                     "Invalid number of lanes");
    constexpr uint32_t mask = 0xffffffff;
    if constexpr (kIntergroupReduce) {
        if constexpr (kNumLanesPerGroup <=  1) value = op(value, __shfl_xor_sync(mask, value,  1));
        if constexpr (kNumLanesPerGroup <=  2) value = op(value, __shfl_xor_sync(mask, value,  2));
        if constexpr (kNumLanesPerGroup <=  4) value = op(value, __shfl_xor_sync(mask, value,  4));
        if constexpr (kNumLanesPerGroup <=  8) value = op(value, __shfl_xor_sync(mask, value,  8));
        if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
    } else {
        if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
        if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value,  8));
        if constexpr (kNumLanesPerGroup >=  8) value = op(value, __shfl_xor_sync(mask, value,  4));
        if constexpr (kNumLanesPerGroup >=  4) value = op(value, __shfl_xor_sync(mask, value,  2));
        if constexpr (kNumLanesPerGroup >=  2) value = op(value, __shfl_xor_sync(mask, value,  1));
    }
    return value;
}

// Convenience aliases
// `warp_reduce` specialised for addition.
template <uint32_t kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
CUTLASS_DEVICE T warp_reduce_sum(T value) {
    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
}
#endif

} // namespace deep_gemm::math
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/tma_copy.cuh
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cute/arch/copy_sm90_tma.hpp>
|
| 4 |
+
#include <cute/arch/copy_sm100_tma.hpp>
|
| 5 |
+
#include <cutlass/arch/barrier.h>
|
| 6 |
+
|
| 7 |
+
#include <deep_gemm/common/exception.cuh>
|
| 8 |
+
|
| 9 |
+
namespace deep_gemm::tma {

// Number of `dtype_t` elements along the inner dimension covered by one TMA copy:
// the whole `BLOCK_INNER` when `kSwizzleMode == 0`, otherwise `kSwizzleMode / sizeof(dtype_t)`.
// NOTE(review): this reads as `kSwizzleMode` being a swizzle width in bytes - confirm.
template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
constexpr uint32_t get_inner_block_atom_size() {
    return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t);
}

// Issue TMA loads for a `BLOCK_INNER x BLOCK_OUTER` tile into shared memory, signalling `barrier_ptr`.
//
// The tile is split into `BLOCK_INNER / BLOCK_INNER_ATOM` pieces along the inner dimension;
// piece `i` is loaded from inner coordinate `inner_idx + i * BLOCK_INNER_ATOM` into
// `smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM`.
//
//  - `kIs3DTMA` selects the 3D copy operations, which additionally take `batch_idx`.
//  - `num_tma_multicast == 1` uses the plain SM90 load.
//  - Otherwise: the SM100 2-SM load when `__CUDA_ARCH__ >= 1000`; on `__CUDA_ARCH__ >= 900`
//    a multicast load issued only by the CTA with cluster rank 0; for any other
//    architecture no copy is emitted in this branch.
template <uint32_t BLOCK_INNER, uint32_t BLOCK_OUTER,
          uint32_t kSwizzleMode,
          typename dtype_t, bool kIs3DTMA = false>
CUTLASS_DEVICE void
copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr,
     dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx,
     const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) {
    // The SM90 and SM100 hint enums are used interchangeably below, so their values must agree
    DG_STATIC_ASSERT(static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL) ==
                     static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint");
    constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size<BLOCK_INNER, kSwizzleMode, dtype_t>();

    if constexpr (not kIs3DTMA) {
        if (num_tma_multicast == 1) {
            #pragma unroll
            for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
                cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
                                             static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
                                             smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
                                             inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
            }
        } else {
            #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
            // 2-CTA function will send signals to the leader CTA only
            #pragma unroll
            for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
                cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
                                                  static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
                                                  smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
                                                  inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
            }
            #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
            if (cute::block_rank_in_cluster() == 0) {
                #pragma unroll
                for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
                    // Multicast mask has the lowest `num_tma_multicast` bits set
                    cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
                                                           (1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
                                                           smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
                                                           inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
                }
            }
            #endif
        }
    } else {
        if (num_tma_multicast == 1) {
            #pragma unroll
            for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
                cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
                                             static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
                                             smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
                                             inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
            }
        } else {
            #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
            // 2-CTA function will send signals to the leader CTA only
            #pragma unroll
            for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
                cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
                                                  static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
                                                  smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
                                                  inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
            }
            #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
            if (cute::block_rank_in_cluster() == 0) {
                #pragma unroll
                for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
                    // Multicast mask has the lowest `num_tma_multicast` bits set
                    cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
                                                           (1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
                                                           smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
                                                           inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
                }
            }
            #endif
        }
    }
}

} // namespace deep_gemm::tma
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/types.cuh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cute/arch/mma_sm100_desc.hpp>
|
| 4 |
+
|
| 5 |
+
namespace deep_gemm {

// Kinds of MMA inputs distinguished by the library.
enum class MmaKind {
    BF16 = 0,
    MXFP8FP4 = 1,
};

// Element size reported for an MMA kind: 2 for BF16, 1 for MXFP8FP4, 0 for any other value.
constexpr CUTLASS_HOST_DEVICE int get_element_size(const MmaKind& mma_kind) {
    return mma_kind == MmaKind::BF16     ? 2
         : mma_kind == MmaKind::MXFP8FP4 ? 1
         : 0;
}

// GEMM problem variants.
enum class GemmType {
    Normal = 0,
    MGroupedContiguous = 1,
    MGroupedMasked = 2,
    KGroupedContiguous = 3,
    Batched = 4,
    MGroupedContiguousWithPsumLayout = 5,
};

// True for the two M-grouped contiguous variants (with or without the psum layout).
constexpr CUTLASS_HOST_DEVICE bool is_m_grouped_contiguous(const GemmType& gemm_type) {
    return gemm_type == GemmType::MGroupedContiguous or
           gemm_type == GemmType::MGroupedContiguousWithPsumLayout;
}

// Kernel flavours.
enum class KernelType {
    Kernel1D1D = 0,
    Kernel1D2D = 1,
    KernelNoSF = 2
};

} // namespace deep_gemm
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/utils.cuh
CHANGED
|
@@ -1,167 +1,24 @@
|
|
| 1 |
#pragma once
|
| 2 |
|
| 3 |
-
#include <cuda_bf16.h>
|
| 4 |
-
#include <cuda_fp8.h>
|
| 5 |
#include <cuda/std/cstdint>
|
| 6 |
-
#include <cuda/std/utility>
|
| 7 |
-
#include <cute/container/tuple.hpp>
|
| 8 |
|
| 9 |
-
#include
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
|
| 14 |
-
asm volatile("trap;");
|
| 15 |
-
}
|
| 16 |
-
|
| 17 |
-
#define printf host_device_printf
|
| 18 |
-
#endif
|
| 19 |
-
|
| 20 |
-
#ifndef DG_DEVICE_ASSERT
|
| 21 |
-
#define DG_DEVICE_ASSERT(cond) \
|
| 22 |
-
do { \
|
| 23 |
-
if (not (cond)) { \
|
| 24 |
-
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
|
| 25 |
-
asm("trap;"); \
|
| 26 |
-
} \
|
| 27 |
-
} while (0)
|
| 28 |
-
#endif
|
| 29 |
-
|
| 30 |
-
#ifndef DG_TRAP_ONLY_DEVICE_ASSERT
|
| 31 |
-
#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \
|
| 32 |
-
do { \
|
| 33 |
-
if (not (cond)) \
|
| 34 |
-
asm("trap;"); \
|
| 35 |
-
} while (0)
|
| 36 |
-
#endif
|
| 37 |
-
|
| 38 |
-
#ifndef DG_STATIC_ASSERT
|
| 39 |
-
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
|
| 40 |
-
#endif
|
| 41 |
-
|
| 42 |
-
namespace deep_gemm {
|
| 43 |
|
| 44 |
template <typename FuncT>
|
| 45 |
struct PatternVisitor {
|
| 46 |
FuncT func;
|
| 47 |
|
| 48 |
-
|
| 49 |
explicit PatternVisitor(FuncT&& func): func(std::forward<FuncT>(func)) {}
|
| 50 |
|
| 51 |
-
|
| 52 |
-
auto operator [](const uint32_t& i) {
|
| 53 |
return func(i);
|
| 54 |
}
|
| 55 |
};
|
| 56 |
|
| 57 |
-
template <typename T>
|
| 58 |
-
__device__ __host__ T ceil_div(T a, T b) {
|
| 59 |
-
return (a + b - 1) / b;
|
| 60 |
-
}
|
| 61 |
-
|
| 62 |
-
template <typename T>
|
| 63 |
-
__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) {
|
| 64 |
-
return (a + b - 1) / b;
|
| 65 |
-
}
|
| 66 |
-
|
| 67 |
-
template <typename T>
|
| 68 |
-
__device__ __host__ T align(T a, T b) {
|
| 69 |
-
return ceil_div(a, b) * b;
|
| 70 |
-
}
|
| 71 |
-
|
| 72 |
-
template <typename T>
|
| 73 |
-
__device__ __host__ constexpr T constexpr_align(T a, T b) {
|
| 74 |
-
return constexpr_ceil_div(a, b) * b;
|
| 75 |
-
}
|
| 76 |
-
|
| 77 |
-
template <typename T>
|
| 78 |
-
__device__ __host__ constexpr T constexpr_gcd(T a, T b) {
|
| 79 |
-
return b == 0 ? a : constexpr_gcd(b, a % b);
|
| 80 |
-
}
|
| 81 |
-
|
| 82 |
-
// Exchanges the values of `a` and `b` through a temporary copy (device code only).
template<typename T>
__forceinline__ __device__ void swap(T& a, T& b) {
    T saved_a = a;
    a = b;
    b = saved_a;
}
|
| 88 |
-
|
| 89 |
-
// Reads the PTX special register `%smid` and returns its value.
__forceinline__ __device__ uint32_t get_sm_idx() {
    uint32_t sm_id;
    asm ("mov.u32 %0, %%smid;" : "=r"(sm_id));
    return sm_id;
}
|
| 94 |
-
|
| 95 |
-
// Reads the PTX special register `%laneid` and returns its value.
__forceinline__ __device__ uint32_t get_lane_idx() {
    uint32_t lane_id;
    // NOTES: a literal `%` in an inline PTX template must be written as `%%`
    // (same spelling as the `%%smid` read in `get_sm_idx`)
    asm ("mov.u32 %0, %%laneid;" : "=r"(lane_id));
    return lane_id;
}
|
| 100 |
-
|
| 101 |
-
// Shared-memory load helpers: each overload converts the generic pointer with
// `__cvta_generic_to_shared` and issues one `ld.shared` instruction of the matching width.

// Loads one 32-bit unsigned value
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    uint32_t value;
    asm volatile("ld.shared.u32 %0, [%1];" : "=r"(value) : "l"(smem_addr));
    return value;
}

// Loads two floats with a single vectorized instruction
__device__ __forceinline__ float2 ld_shared(const float2* ptr) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    float2 value;
    asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(value.x), "=f"(value.y) : "l"(smem_addr));
    return value;
}

// Loads four floats with a single vectorized instruction
__device__ __forceinline__ float4 ld_shared(const float4* ptr) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    float4 value;
    asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(value.x), "=f"(value.y), "=f"(value.z), "=f"(value.w) : "l"(smem_addr));
    return value;
}

// Loads four 32-bit unsigned values with a single vectorized instruction
__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    uint4 value;
    asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(value.x), "=r"(value.y), "=r"(value.z), "=r"(value.w) : "l"(smem_addr));
    return value;
}

// Loads one float
__device__ __forceinline__ float ld_shared(const float* ptr) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    float value;
    asm volatile("ld.shared.f32 %0, [%1];" : "=f"(value) : "l"(smem_addr));
    return value;
}
|
| 130 |
-
|
| 131 |
-
// Shared-memory store helpers: each overload converts the generic pointer with
// `__cvta_generic_to_shared` and issues one `st.shared` instruction of the matching width.
// NOTES: the destination pointers are `const`-qualified in the signatures although they are stored through

// Stores one float
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    asm volatile("st.shared.f32 [%0], %1;" :: "l"(smem_addr), "f"(val));
}

// Stores two floats with a single vectorized instruction
__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(smem_addr), "f"(val.x), "f"(val.y));
}

// Stores one 32-bit unsigned value
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    asm volatile("st.shared.u32 [%0], %1;" :: "l"(smem_addr), "r"(val));
}

// Stores two 32-bit unsigned values with a single vectorized instruction
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(smem_addr), "r"(x), "r"(y));
}

// Stores four 32-bit unsigned values with a single vectorized instruction
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(smem_addr), "r"(x), "r"(y), "r"(z), "r"(w));
}

// Stores one 128-bit value
__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    asm volatile("st.shared.b128 [%0], %1;" :: "l"(smem_addr), "q"(val));
}
|
| 154 |
-
|
| 155 |
-
// Reinterprets the bits of `x` and `y` as two floats, converts the pair to BF16 with
// `__float22bfloat162_rn` and returns the resulting 32 bits as an `int`.
template <typename old_t>
__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) {
    // NOTES: `old_t` is only a bit container here, the values are read as `float`
    const float first = *reinterpret_cast<float*>(&x);
    const float second = *reinterpret_cast<float*>(&y);
    auto packed = __float22bfloat162_rn({first, second});
    return *reinterpret_cast<int*>(&packed);
}
|
| 160 |
-
|
| 161 |
-
// Issues a PTX `prefetch.global.L1` for the address `ptr`; returns nothing.
// NOTES: the instruction uses the `.global` state space, so `ptr` is presumably expected to
// point to global memory (confirm against callers)
__device__ __forceinline__ void prefetch_l1(void *ptr) {
|
| 162 |
-
asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
|
| 163 |
-
}
|
| 164 |
-
|
| 165 |
template <uint32_t kNumBytes>
|
| 166 |
struct Vectorized {
|
| 167 |
static auto zeros() {
|
|
@@ -180,4 +37,14 @@ struct Vectorized {
|
|
| 180 |
using vec_t = decltype(zeros());
|
| 181 |
};
|
| 182 |
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
#pragma once
|
| 2 |
|
|
|
|
|
|
|
| 3 |
#include <cuda/std/cstdint>
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
#include <deep_gemm/common/exception.cuh>
|
| 6 |
|
| 7 |
+
namespace deep_gemm::utils {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
// Wraps a callable so that it can be indexed like an array: `visitor[i]` evaluates `func(i)`.
template <typename FuncT>
struct PatternVisitor {
    // The wrapped callable, invoked on every subscript
    FuncT func;

    // Forwards the given callable into `func`
    CUTLASS_HOST_DEVICE
    explicit PatternVisitor(FuncT&& fn): func(std::forward<FuncT>(fn)) {}

    // Returns whatever `func` yields for index `idx`; callable on `const` visitors
    CUTLASS_HOST_DEVICE
    auto operator [](const uint32_t& idx) const {
        return func(idx);
    }
};
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
template <uint32_t kNumBytes>
|
| 23 |
struct Vectorized {
|
| 24 |
static auto zeros() {
|
|
|
|
| 37 |
using vec_t = decltype(zeros());
|
| 38 |
};
|
| 39 |
|
| 40 |
+
// Rounds a tensor memory column count up to the smallest of 32, 64, 128, 256 or 512
// that is not less than `kNumCols`; more than 512 columns is rejected at compile time.
template <uint32_t kNumCols>
CUTLASS_DEVICE constexpr uint32_t get_num_aligned_tmem_cols() {
    DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns");
    return kNumCols <= 32  ? 32  :
           kNumCols <= 64  ? 64  :
           kNumCols <= 128 ? 128 :
           kNumCols <= 256 ? 256 : 512;
}
|
| 49 |
+
|
| 50 |
+
} // namespace deep_gemm::utils
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd.cuh
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cute/atom/copy_traits_sm100.hpp>
|
| 4 |
+
|
| 5 |
+
#include <deep_gemm/common/math.cuh>
|
| 6 |
+
#include <deep_gemm/common/types.cuh>
|
| 7 |
+
#include <deep_gemm/common/utils.cuh>
|
| 8 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 9 |
+
#include <deep_gemm/ptx/tcgen05.cuh>
|
| 10 |
+
|
| 11 |
+
namespace deep_gemm::epilogue {
|
| 12 |
+
|
| 13 |
+
// SM100 epilogue: reads a `BLOCK_M x BLOCK_N` accumulator tile out of tensor memory in
// `STORE_BLOCK_M x STORE_BLOCK_N` slices, stages every slice in swizzled shared memory and
// writes it out with a TMA store (or a TMA reduce-add when `kWithAccumulation` is set).
//
// All `kNumUMMAStoreThreads` epilogue threads are expected to call this together, as they
// synchronize on named barrier 0 inside.
//  - `smem_cd`: visitor returning the shared memory staging buffer of a TMA store stage
//  - `tma_stage_idx`: current store stage, advanced modulo `kNumTMAStoreStages` after every
//    slice so that the pipeline state carries over to the next call
//  - `tmem_base_addr`: tensor memory address of the accumulator to read
//  - `base_m_idx`, `base_n_idx`: global coordinates of the tile
//  - `batch_idx`: third TMA coordinate, only used for `GemmType::Batched`
//  - `epilogue_warp_idx`, `lane_idx`: position of the calling thread among the epilogue warps
//  - `tmem_empty_barrier`: arrived on (CTA 0) once the last slice has been read from tensor memory
//  - `tensor_map_cd`: TMA descriptor of the output tensor
template <uint32_t BLOCK_M, uint32_t BLOCK_N,
          uint32_t STORE_BLOCK_M, uint32_t STORE_BLOCK_N,
          uint32_t kSwizzleCDMode,
          uint32_t kNumTMAStoreStages,
          uint32_t kNumUMMAStoreThreads,
          GemmType kGemmType, bool kWithAccumulation,
          typename cd_dtype_t,
          typename epilogue_type_t,
          typename pattern_cd_t>
CUTLASS_DEVICE void
sm100_store_cd(const utils::PatternVisitor<pattern_cd_t>& smem_cd, uint32_t& tma_stage_idx,
               const uint32_t& tmem_base_addr,
               const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx,
               const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx,
               const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier,
               const cute::TmaDescriptor& tensor_map_cd) {
    // TMA checks
    constexpr uint32_t kNumBankGroupBytes = 16;
    constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
    DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
    DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
    DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes");
    DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");

    // Share store pipeline between blocks
    auto advance_store_pipeline = [&]() {
        tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
    };

    // Iterate over M waves
    constexpr auto kNumMWaves = BLOCK_M / STORE_BLOCK_M;
    #pragma unroll
    for (uint32_t w = 0; w < kNumMWaves; ++ w) {
        // Issue every swizzled atom and pipeline STSM and TMA store
        constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
        #pragma unroll
        for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
            auto smem_base_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]);

            // Wait shared memory to be released
            // NOTES: only warp 0 issues TMA stores, so only warp 0 has store groups to wait for
            if (epilogue_warp_idx == 0)
                cute::tma_store_wait<kNumTMAStoreStages - 1>();
            cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);

            // The pipeline stage
            // NOTES: `apply_index_n` is a member template of a dependent type,
            // so the `template` disambiguator is required for a conforming parse
            const auto m_idx = base_m_idx + w * STORE_BLOCK_M;
            const auto n_idx = epilogue_type_t::template apply_index_n<STORE_BLOCK_N>(base_n_idx + s * STORE_BLOCK_N);

            // Store into shared memory
            #pragma unroll
            for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
                // Calculate the index of the bank group to be written in the atom
                auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);

                // Reshape the atom in another view and swizzle
                //  - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
                //  - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
                // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
                constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
                auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
                auto col = kHasShortcut ? (i) : (bank_group_index % 8);
                col ^= row % (kSwizzleCDMode / 16);

                // Source and destination memory address
                uint32_t tmem_addr = tmem_base_addr +                               // Accumulator offset
                                     w * BLOCK_N +                                  // Wave offset
                                     s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
                auto smem_ptr = smem_base_ptr +                                     // Base pointer
                                epilogue_warp_idx * 32 * kSwizzleCDMode +           // Warp offset
                                row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset

                // Load from tensor memory, store into shared memory
                // NOTES: the fence after each load must precede any use of `values`
                uint32_t values[kNumElemsPerBankGroup];
                if constexpr (cute::is_same_v<cd_dtype_t, float>) {
                    // For FP32 output, read and store
                    DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
                    cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
                        values[0], values[1], values[2], values[3]);
                    cutlass::arch::fence_view_async_tmem_load();
                    ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
                } else {
                    // For BF16 output, read, cast and store
                    DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
                    cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
                        values[0], values[1], values[2], values[3],
                        values[4], values[5], values[6], values[7]);
                    cutlass::arch::fence_view_async_tmem_load();
                    ptx::st_shared(
                        smem_ptr,
                        math::cast_into_bf16_and_pack(values[0], values[1]),
                        math::cast_into_bf16_and_pack(values[2], values[3]),
                        math::cast_into_bf16_and_pack(values[4], values[5]),
                        math::cast_into_bf16_and_pack(values[6], values[7])
                    );
                }
            }

            // Notify tensor memory empty (only at the leader CTA) arrival ASAP
            // NOTES: only the last stage needs to do this
            if (w == kNumMWaves - 1 and s == kNumStores - 1) {
                ptx::tcgen05_before_thread_sync();
                tmem_empty_barrier->arrive(0u);
            }

            // Synchronize all threads and issue TMA
            cute::tma_store_fence();
            cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
            if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
                if constexpr (kGemmType == GemmType::Batched) {
                    using cute_tma_t = cute::conditional_t<kWithAccumulation,
                        cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
                    cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx, batch_idx);
                } else {
                    using cute_tma_t = cute::conditional_t<kWithAccumulation,
                        cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
                    cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx);
                }
                cute::tma_store_arrive();
            }
            __syncwarp();
        }
    }
}
|
| 136 |
+
|
| 137 |
+
} // namespace deep_gemm::epilogue
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cute/atom/copy_traits_sm100.hpp>
|
| 4 |
+
|
| 5 |
+
#include <deep_gemm/common/math.cuh>
|
| 6 |
+
#include <deep_gemm/common/types.cuh>
|
| 7 |
+
#include <deep_gemm/common/utils.cuh>
|
| 8 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 9 |
+
#include <deep_gemm/ptx/tcgen05.cuh>
|
| 10 |
+
|
| 11 |
+
namespace deep_gemm::epilogue {
|
| 12 |
+
|
| 13 |
+
// SM100 epilogue for the swapped-A/B layout: reads the accumulator out of tensor memory in
// slices of `STORE_BLOCK_M` columns, stages every slice in 128B-swizzled shared memory
// (transposing through STSM for BF16) and writes it out with one TMA store (or TMA reduce-add
// when `kWithAccumulation` is set) per swizzle atom.
//
// All `kNumUMMAStoreThreads` epilogue threads are expected to call this together, as they
// synchronize on named barrier 0 inside.
//  - `smem_cd`: visitor returning the shared memory staging buffer of a TMA store stage
//  - `tma_stage_idx`: current store stage, advanced modulo `kNumTMAStoreStages` after every
//    slice so that the pipeline state carries over to the next call
//  - `tmem_base_addr`: tensor memory address of the accumulator to read
//  - `base_m_idx`, `base_n_idx`: global coordinates of the tile
//  - `batch_idx`: third TMA coordinate, only used for `GemmType::Batched`
//  - `effective_m`: number of M entries to store; `effective_m / STORE_BLOCK_M` slices are written
//  - `epilogue_warp_idx`, `lane_idx`: position of the calling thread among the epilogue warps
//  - `tmem_empty_barrier`: arrived on (CTA 0) once the last slice has been read from tensor memory
//  - `tensor_map_cd`: TMA descriptor of the output tensor
template <uint32_t BLOCK_M, uint32_t BLOCK_N,
          uint32_t STORE_BLOCK_M, uint32_t STORE_BLOCK_N,
          uint32_t kSwizzleCDMode,
          uint32_t kNumTMAStoreStages,
          uint32_t kNumUMMAStoreThreads,
          GemmType kGemmType, bool kWithAccumulation,
          typename cd_dtype_t,
          typename epilogue_type_t,
          typename pattern_cd_t>
CUTLASS_DEVICE void
sm100_store_cd_swap_ab(const utils::PatternVisitor<pattern_cd_t>& smem_cd, uint32_t& tma_stage_idx,
                       const uint32_t& tmem_base_addr,
                       const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx,
                       const uint32_t& effective_m,
                       const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx,
                       const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier,
                       const cute::TmaDescriptor& tensor_map_cd) {
    // NOTES: The epilogue requires a full warpgroup to read all 128 TMEM rows,
    // implying STORE_BLOCK_N must be 128.
    DG_STATIC_ASSERT(STORE_BLOCK_N == 128, "STORE_BLOCK_N must be 128 to match TMEM rows");

    // TMA checks
    constexpr uint32_t STORE_BLOCK_N_ATOM = kSwizzleCDMode / sizeof(cd_dtype_t);
    constexpr uint32_t kNumBankGroupBytes = 16;
    constexpr uint32_t kNumSwizzleAtomRows = 8;
    DG_STATIC_ASSERT(kSwizzleCDMode == 128, "TMA D must be 128B swizzled");
    DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes");
    DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
    DG_STATIC_ASSERT(STORE_BLOCK_M % kNumSwizzleAtomRows == 0, "Invalid swizzling");
    DG_STATIC_ASSERT(STORE_BLOCK_N % STORE_BLOCK_N_ATOM == 0, "Invalid swizzling");

    // Share store pipeline between blocks
    auto advance_store_pipeline = [&]() {
        tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
    };

    // Iterate over M blocks
    // NOTE(review): if `effective_m < STORE_BLOCK_M` the loop body never runs and
    // `tmem_empty_barrier` is never arrived on -- confirm callers guarantee at least one slice
    const auto num_stores = effective_m / STORE_BLOCK_M;
    for (uint32_t s = 0; s < num_stores; ++ s, advance_store_pipeline()) {
        // Wait shared memory to be released
        // NOTES: only warp 0 issues TMA stores, so only warp 0 has store groups to wait for
        if (epilogue_warp_idx == 0)
            cute::tma_store_wait<kNumTMAStoreStages - 1>();
        cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);

        // Store into shared memory
        #pragma unroll
        for (uint32_t i = 0; i < STORE_BLOCK_M / kNumSwizzleAtomRows; ++ i) {
            uint32_t tmem_addr = tmem_base_addr +
                                 s * STORE_BLOCK_M +        // Store stage offset
                                 i * kNumSwizzleAtomRows;   // In-block offset
            uint32_t values[kNumSwizzleAtomRows];

            // Warps cooperatively write an atomic block to shared memory
            DG_STATIC_ASSERT(STORE_BLOCK_N_ATOM % 32 == 0, "Invalid block sizes");
            constexpr uint32_t kNumWarpsPerAtom = STORE_BLOCK_N_ATOM / 32;
            uint32_t outer_atom_offset = (epilogue_warp_idx / kNumWarpsPerAtom) * STORE_BLOCK_M * kSwizzleCDMode;
            uint32_t inner_atom_offset = i * kNumSwizzleAtomRows * kSwizzleCDMode;
            auto smem_base_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + outer_atom_offset + inner_atom_offset;

            if constexpr (cute::is_same_v<cd_dtype_t, float>) {
                // NOTES: Swizzling is not required in this case, but used here for consistency with other cases
                cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, values[0], values[1], values[2], values[3],
                                                                 values[4], values[5], values[6], values[7]);
                // Tensor memory loads are asynchronous: wait for them before reading `values`
                // (and before tensor memory is reported empty below), as in the BF16 path
                cutlass::arch::fence_view_async_tmem_load();
                uint32_t col = lane_idx / 4;

                #pragma unroll
                for (uint32_t row = 0; row < kNumSwizzleAtomRows; ++ row) {
                    auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8)
                                                  + (col ^ row) * kNumBankGroupBytes
                                                  + (lane_idx % 4) * sizeof(float);
                    ptx::st_shared(reinterpret_cast<uint32_t*>(smem_ptr), values[row]);
                }
            } else {
                // Load from TMEM using `.16x256b` shape to satisfy STSM layout requirements
                // Start from lane index 0
                cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr,
                    values[0], values[1], values[2], values[3]);
                // Start from lane index 16
                cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000,
                    values[4], values[5], values[6], values[7]);
                cutlass::arch::fence_view_async_tmem_load();

                // Destination shared memory address
                uint32_t row = lane_idx % 8;
                uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8;
                auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8)
                                              + (col ^ row) * kNumBankGroupBytes;

                // Store matrix with transposition
                ptx::SM90_U32x4_STSM_T<int>::copy(math::cast_into_bf16_and_pack(values[0], values[1]),
                                                  math::cast_into_bf16_and_pack(values[2], values[3]),
                                                  math::cast_into_bf16_and_pack(values[4], values[5]),
                                                  math::cast_into_bf16_and_pack(values[6], values[7]),
                                                  smem_ptr);
            }
        }

        // Notify tensor memory empty (only at the leader CTA) arrival ASAP
        // NOTES: only the last stage needs to do this
        if (s == num_stores - 1) {
            ptx::tcgen05_before_thread_sync();
            tmem_empty_barrier->arrive(0u);
        }

        // Synchronize all threads and issue TMA
        cute::tma_store_fence();
        cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
        if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
            #pragma unroll
            for (uint32_t i = 0; i < STORE_BLOCK_N / STORE_BLOCK_N_ATOM; ++ i) {
                auto smem_ptr = smem_cd[tma_stage_idx] + i * STORE_BLOCK_M * STORE_BLOCK_N_ATOM;
                uint32_t m_idx = base_m_idx + s * STORE_BLOCK_M;
                // NOTES: `apply_index_n` is a member template of a dependent type,
                // so the `template` disambiguator is required for a conforming parse
                uint32_t n_idx = epilogue_type_t::template apply_index_n<STORE_BLOCK_N_ATOM>(base_n_idx + i * STORE_BLOCK_N_ATOM);

                // Issue 2D or 3D TMA store
                if constexpr (kGemmType == GemmType::Batched) {
                    using cute_tma_t = cute::conditional_t<kWithAccumulation,
                        cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
                    cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx, batch_idx);
                } else {
                    using cute_tma_t = cute::conditional_t<kWithAccumulation,
                        cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
                    cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx);
                }
            }
            cute::tma_store_arrive();
        }
        __syncwarp();
    }
}
|
| 143 |
+
|
| 144 |
+
} // namespace deep_gemm::epilogue
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/transform.cuh
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <deep_gemm/common/exception.cuh>
|
| 4 |
+
|
| 5 |
+
namespace deep_gemm::epilogue::transform {
|
| 6 |
+
|
| 7 |
+
// Epilogue index transform that leaves the N index untouched.
struct EpilogueIdentity {
|
| 8 |
+
// Returns `n_idx` as is; `STORE_BLOCK_N` is not used by this transform
template <uint32_t STORE_BLOCK_N>
|
| 9 |
+
CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) {
|
| 10 |
+
return n_idx;
|
| 11 |
+
}
|
| 12 |
+
};
|
| 13 |
+
|
| 14 |
+
// Epilogue index transform that shifts the N index by `kMid` for every
// `(n_idx + kRight) / (kLeft + kRight)` whole period.
template <uint32_t kLeft, uint32_t kMid, uint32_t kRight>
struct EpilogueHeadSplits: EpilogueIdentity {
    // Maps `n_idx` to `n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid`;
    // all three split widths must be multiples of `STORE_BLOCK_N`
    template <uint32_t STORE_BLOCK_N>
    CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) {
        constexpr bool kIsLeftAligned = kLeft % STORE_BLOCK_N == 0;
        constexpr bool kIsMidAligned = kMid % STORE_BLOCK_N == 0;
        constexpr bool kIsRightAligned = kRight % STORE_BLOCK_N == 0;
        DG_STATIC_ASSERT(kIsLeftAligned and kIsMidAligned and kIsRightAligned, "Invalid head splits config");

        const uint32_t num_shifts = (n_idx + kRight) / (kLeft + kRight);
        return n_idx + num_shifts * kMid;
    }
};
|
| 23 |
+
|
| 24 |
+
} // namespace deep_gemm::epilogue::transform
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh
CHANGED
|
@@ -4,14 +4,18 @@
|
|
| 4 |
|
| 5 |
#include <cutlass/arch/barrier.h>
|
| 6 |
|
| 7 |
-
#include <deep_gemm/
|
| 8 |
-
#include <deep_gemm/common/
|
| 9 |
-
#include <deep_gemm/common/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
namespace deep_gemm {
|
| 12 |
|
| 13 |
-
using namespace deep_gemm::sm100;
|
| 14 |
-
|
| 15 |
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
| 16 |
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 17 |
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K_,
|
|
@@ -21,9 +25,10 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
|
| 21 |
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
| 22 |
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
| 23 |
uint32_t kNumSMs,
|
|
|
|
| 24 |
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
|
| 25 |
uint64_t kTensorCoreUtilControl>
|
| 26 |
-
|
| 27 |
sm100_bf16_gemm_impl(int* grouped_layout,
|
| 28 |
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
| 29 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
|
@@ -48,41 +53,31 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 48 |
if constexpr (kWithAccumulation)
|
| 49 |
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
|
| 50 |
|
| 51 |
-
// Configs
|
| 52 |
constexpr uint32_t LAYOUT_AD_M = 128;
|
| 53 |
-
constexpr uint32_t
|
| 54 |
-
constexpr uint32_t
|
| 55 |
-
constexpr uint32_t
|
| 56 |
-
DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K");
|
| 57 |
-
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
|
| 58 |
-
DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode");
|
| 59 |
-
|
| 60 |
-
// Overwrite shape constants if the compiler gives
|
| 61 |
-
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
| 62 |
-
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
| 63 |
-
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
| 64 |
-
|
| 65 |
-
// Utils
|
| 66 |
-
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
|
| 67 |
-
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 68 |
-
const auto lane_idx = get_lane_idx();
|
| 69 |
-
|
| 70 |
-
// Align to 1024 bytes for swizzle-128B
|
| 71 |
-
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
| 72 |
-
|
| 73 |
-
// 2-CTA MMA
|
| 74 |
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
|
| 75 |
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
|
| 76 |
-
|
| 77 |
-
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
|
| 78 |
-
constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M;
|
| 79 |
-
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
|
| 80 |
-
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D");
|
| 81 |
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
|
| 83 |
|
| 84 |
// Share memory sizes
|
| 85 |
-
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M *
|
| 86 |
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
|
| 87 |
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
|
| 88 |
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
|
|
@@ -91,41 +86,54 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 91 |
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
|
| 92 |
|
| 93 |
// NOTES: Make sure we have enough shared memory for UMMA padding
|
| 94 |
-
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16);
|
| 95 |
-
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory
|
| 96 |
-
|
| 97 |
-
// Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
|
| 98 |
-
// TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2`
|
| 99 |
-
constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2;
|
| 100 |
|
| 101 |
// Real tensor memory size and offsets
|
| 102 |
-
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages *
|
| 103 |
-
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
// Prefetch TMA descriptors at the very beginning
|
| 106 |
-
if (warp_idx == 0
|
| 107 |
cute::prefetch_tma_descriptor(&tensor_map_a);
|
| 108 |
cute::prefetch_tma_descriptor(&tensor_map_b);
|
| 109 |
cute::prefetch_tma_descriptor(&tensor_map_cd);
|
| 110 |
}
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
// D/A/B shared memory
|
| 113 |
-
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
|
| 114 |
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
| 115 |
});
|
| 116 |
-
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 117 |
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
| 118 |
});
|
| 119 |
-
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 120 |
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
| 121 |
});
|
| 122 |
|
| 123 |
// Fill barriers
|
| 124 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 125 |
-
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 126 |
-
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 127 |
-
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
| 128 |
-
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
|
| 129 |
auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2;
|
| 130 |
|
| 131 |
// Fill the tensor memory pointer
|
|
@@ -159,9 +167,13 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 159 |
}
|
| 160 |
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
| 161 |
|
|
|
|
|
|
|
|
|
|
| 162 |
// Block scheduler
|
| 163 |
uint32_t m_block_idx, n_block_idx;
|
| 164 |
-
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(
|
|
|
|
| 165 |
|
| 166 |
// Pipeline and TMA phases
|
| 167 |
uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0;
|
|
@@ -178,16 +190,20 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 178 |
// TMA load warp
|
| 179 |
// Persistently schedule over blocks
|
| 180 |
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 183 |
// Wait consumer release
|
| 184 |
empty_barriers[stage_idx]->wait(phase ^ 1);
|
| 185 |
|
| 186 |
// Compute offsets
|
| 187 |
// NOTES: the group is always concatenated with the outer dimension
|
| 188 |
-
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> (
|
| 189 |
shape_m, BLOCK_M, m_block_idx);
|
| 190 |
-
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> (
|
| 191 |
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
| 192 |
|
| 193 |
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
|
@@ -195,14 +211,14 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 195 |
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
|
| 196 |
kMajorA == cute::UMMA::Major::K, "Invalid major");
|
| 197 |
uint32_t k_idx = k_block_idx * BLOCK_K;
|
| 198 |
-
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
|
| 199 |
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 200 |
-
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
|
| 201 |
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 202 |
|
| 203 |
// Add 2 CTA offsets
|
| 204 |
if constexpr (kNumMulticast > 1) {
|
| 205 |
-
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() *
|
| 206 |
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
| 207 |
}
|
| 208 |
|
|
@@ -210,16 +226,16 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 210 |
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
|
| 211 |
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
|
| 212 |
if constexpr (kMajorA == cute::UMMA::Major::K)
|
| 213 |
-
|
| 214 |
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx);
|
| 215 |
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
| 216 |
-
|
| 217 |
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx);
|
| 218 |
if constexpr (kMajorB == cute::UMMA::Major::K)
|
| 219 |
-
|
| 220 |
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx);
|
| 221 |
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
| 222 |
-
|
| 223 |
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx);
|
| 224 |
|
| 225 |
// Arrive at full barriers
|
|
@@ -235,17 +251,16 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 235 |
// MMA issue warp
|
| 236 |
// NOTES: only the leader CTA will do this
|
| 237 |
// Make instruction descriptor
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>();
|
| 243 |
|
| 244 |
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
| 245 |
// Merged stages only happens in NT normal GEMM cases
|
| 246 |
constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
|
| 247 |
-
auto a_desc = make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], 0, 0);
|
| 248 |
-
auto b_desc = make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
| 249 |
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
| 250 |
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
| 251 |
|
|
@@ -262,7 +277,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 262 |
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
|
| 263 |
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
| 264 |
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
|
| 265 |
-
tcgen05_after_thread_sync();
|
| 266 |
|
| 267 |
// UMMA and empty barrier arrival alias
|
| 268 |
auto umma_arrive = [](const uint64_t* barrier) {
|
|
@@ -279,36 +294,45 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 279 |
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
| 280 |
if (do_tmem_full_arrive)
|
| 281 |
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
|
|
|
| 282 |
};
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
// Launch MMAs
|
| 285 |
-
const auto
|
| 286 |
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 287 |
// Wait TMA arrival
|
| 288 |
full_barriers[stage_idx]->wait(phase);
|
| 289 |
-
tcgen05_after_thread_sync();
|
| 290 |
|
| 291 |
// Issue UMMA in the leader CTA
|
| 292 |
-
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_F16BF16_SS, SM100_MMA_F16BF16_2x1SM_SS>;
|
| 293 |
-
const auto
|
| 294 |
-
const auto
|
| 295 |
-
const auto
|
| 296 |
if (cute::elect_one_sync()) {
|
| 297 |
#pragma unroll
|
| 298 |
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
| 299 |
uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K;
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
mma_t::fma(a_desc,
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
|
|
|
| 309 |
}
|
| 310 |
}
|
| 311 |
}
|
|
|
|
| 312 |
|
| 313 |
// Commit to the mbarrier object
|
| 314 |
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
|
@@ -319,15 +343,16 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 319 |
if constexpr (kTensorCoreUtilControl < 100) {
|
| 320 |
// For utilization control
|
| 321 |
umma_arrive(reinterpret_cast<uint64_t*>(tensor_core_full_barrier));
|
|
|
|
| 322 |
|
| 323 |
// Wait for last UMMA to be done
|
| 324 |
tensor_core_full_barrier->wait(tensor_core_phase);
|
| 325 |
tensor_core_phase ^= 1;
|
| 326 |
|
| 327 |
// Sleep for certain cycles
|
| 328 |
-
constexpr static uint64_t kNumUMMACycles = (2ull *
|
| 329 |
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
|
| 330 |
-
const auto
|
| 331 |
if (cute::elect_one_sync())
|
| 332 |
while (clock64() - start_clock < kNumDummyCycles) {}
|
| 333 |
__syncwarp();
|
|
@@ -336,9 +361,9 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 336 |
}
|
| 337 |
|
| 338 |
// To safely deconstruct barriers, we need another round of waits
|
| 339 |
-
const auto
|
| 340 |
if (kNumMulticast > 1 and iter_idx >= 0) {
|
| 341 |
-
const auto
|
| 342 |
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
|
| 343 |
}
|
| 344 |
} else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
|
|
@@ -348,19 +373,10 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 348 |
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
| 349 |
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
|
| 350 |
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
| 351 |
-
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
| 352 |
-
|
| 353 |
-
// TMA checks
|
| 354 |
-
constexpr uint32_t kNumBankGroupBytes = 16;
|
| 355 |
-
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
|
| 356 |
-
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
|
| 357 |
-
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
|
| 358 |
|
| 359 |
// Share store pipeline between blocks
|
| 360 |
uint32_t tma_stage_idx = 0;
|
| 361 |
-
auto advance_store_pipeline = [&]() {
|
| 362 |
-
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
|
| 363 |
-
};
|
| 364 |
|
| 365 |
// Persistently schedule over blocks
|
| 366 |
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
|
@@ -369,108 +385,47 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
|
| 369 |
|
| 370 |
// Wait UMMA arrival
|
| 371 |
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
|
| 372 |
-
tcgen05_after_thread_sync();
|
| 373 |
|
| 374 |
// Load from tensor memory into registers, and write shared memory with STSM
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
|
| 403 |
-
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
|
| 404 |
-
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
|
| 405 |
-
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
|
| 406 |
-
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
|
| 407 |
-
col ^= row % (kSwizzleCDMode / 16);
|
| 408 |
-
|
| 409 |
-
// Source and destination memory address
|
| 410 |
-
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
|
| 411 |
-
w * BLOCK_N + // Wave offset
|
| 412 |
-
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
|
| 413 |
-
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
|
| 414 |
-
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
|
| 415 |
-
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
| 416 |
-
|
| 417 |
-
// Load from tensor memory, store into shared memory
|
| 418 |
-
uint32_t values[kNumElemsPerBankGroup];
|
| 419 |
-
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
|
| 420 |
-
// For FP32 output, read and store
|
| 421 |
-
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
| 422 |
-
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
| 423 |
-
values[0], values[1], values[2], values[3]);
|
| 424 |
-
cutlass::arch::fence_view_async_tmem_load();
|
| 425 |
-
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
| 426 |
-
} else {
|
| 427 |
-
// For BF16 output, read, cast and store
|
| 428 |
-
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
| 429 |
-
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
|
| 430 |
-
values[0], values[1], values[2], values[3],
|
| 431 |
-
values[4], values[5], values[6], values[7]);
|
| 432 |
-
cutlass::arch::fence_view_async_tmem_load();
|
| 433 |
-
st_shared(smem_ptr,
|
| 434 |
-
cast_into_bf16_and_pack(values[0], values[1]),
|
| 435 |
-
cast_into_bf16_and_pack(values[2], values[3]),
|
| 436 |
-
cast_into_bf16_and_pack(values[4], values[5]),
|
| 437 |
-
cast_into_bf16_and_pack(values[6], values[7]));
|
| 438 |
-
}
|
| 439 |
-
}
|
| 440 |
-
|
| 441 |
-
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
|
| 442 |
-
// NOTES: only the last stage needs to do this
|
| 443 |
-
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
|
| 444 |
-
tcgen05_before_thread_sync();
|
| 445 |
-
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
| 446 |
-
}
|
| 447 |
-
__syncwarp();
|
| 448 |
-
|
| 449 |
-
// Synchronize all threads and issue TMA
|
| 450 |
-
cute::tma_store_fence();
|
| 451 |
-
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
|
| 452 |
-
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
|
| 453 |
-
if constexpr (kGemmType == GemmType::Batched) {
|
| 454 |
-
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
| 455 |
-
cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
|
| 456 |
-
cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx],
|
| 457 |
-
n_idx, m_idx, scheduler.current_group_idx);
|
| 458 |
-
} else {
|
| 459 |
-
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
| 460 |
-
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
|
| 461 |
-
cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx);
|
| 462 |
-
}
|
| 463 |
-
cute::tma_store_arrive();
|
| 464 |
-
}
|
| 465 |
-
}
|
| 466 |
}
|
| 467 |
}
|
| 468 |
-
|
| 469 |
-
// Deallocate tensor memory by the last UMMA store warp
|
| 470 |
-
// NOTES: warp 0 is waiting TMA store
|
| 471 |
-
if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
|
| 472 |
-
Allocator().free(0, kNumTmemCols);
|
| 473 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
#else
|
| 475 |
if (blockIdx.x == 0 and threadIdx.x == 0)
|
| 476 |
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
|
|
|
|
| 4 |
|
| 5 |
#include <cutlass/arch/barrier.h>
|
| 6 |
|
| 7 |
+
#include <deep_gemm/scheduler/gemm.cuh>
|
| 8 |
+
#include <deep_gemm/common/math.cuh>
|
| 9 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 10 |
+
#include <deep_gemm/epilogue/sm100_store_cd.cuh>
|
| 11 |
+
#include <deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh>
|
| 12 |
+
#include <deep_gemm/epilogue/transform.cuh>
|
| 13 |
+
#include <deep_gemm/mma/sm100.cuh>
|
| 14 |
+
#include <deep_gemm/ptx/tcgen05.cuh>
|
| 15 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 16 |
|
| 17 |
namespace deep_gemm {
|
| 18 |
|
|
|
|
|
|
|
| 19 |
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
| 20 |
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 21 |
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K_,
|
|
|
|
| 25 |
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
| 26 |
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
| 27 |
uint32_t kNumSMs,
|
| 28 |
+
bool kSwapAB,
|
| 29 |
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
|
| 30 |
uint64_t kTensorCoreUtilControl>
|
| 31 |
+
CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
| 32 |
sm100_bf16_gemm_impl(int* grouped_layout,
|
| 33 |
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
| 34 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
|
|
|
| 53 |
if constexpr (kWithAccumulation)
|
| 54 |
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
|
| 55 |
|
| 56 |
+
// MMA Configs
|
| 57 |
constexpr uint32_t LAYOUT_AD_M = 128;
|
| 58 |
+
constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast;
|
| 59 |
+
constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N;
|
| 60 |
+
constexpr uint32_t UMMA_K = 16;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
|
| 62 |
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
|
| 63 |
+
DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K");
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
|
| 65 |
+
DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or
|
| 66 |
+
(not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size");
|
| 67 |
+
|
| 68 |
+
// Epilogue configs
|
| 69 |
+
// Always enable pipeline for better performance
|
| 70 |
+
constexpr uint32_t kNumEpilogueStages = 2;
|
| 71 |
+
constexpr uint32_t kNumTMAStoreStages = 2;
|
| 72 |
+
// NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N
|
| 73 |
+
// per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases
|
| 74 |
+
constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
| 75 |
+
constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t);
|
| 76 |
+
constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M;
|
| 77 |
DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
|
| 78 |
|
| 79 |
// Shared memory sizes
|
| 80 |
+
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t);
|
| 81 |
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
|
| 82 |
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
|
| 83 |
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
|
|
|
|
| 86 |
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
|
| 87 |
|
| 88 |
// NOTES: Make sure we have enough shared memory for UMMA padding
|
| 89 |
+
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16);
|
| 90 |
+
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory out of bound for UMMA");
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
// Real tensor memory size and offsets
|
| 93 |
+
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * UMMA_N;
|
| 94 |
+
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols>();
|
| 95 |
+
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
| 96 |
+
|
| 97 |
+
// Synchronize the cluster before 2-CTA TMEM allocation
|
| 98 |
+
kNumMulticast > 1 ? cute::cluster_sync() : void();
|
| 99 |
+
|
| 100 |
+
// Utils
|
| 101 |
+
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
|
| 102 |
+
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 103 |
+
const auto lane_idx = ptx::get_lane_idx();
|
| 104 |
|
| 105 |
// Prefetch TMA descriptors at the very beginning
|
| 106 |
+
if (warp_idx == 0) {
|
| 107 |
cute::prefetch_tma_descriptor(&tensor_map_a);
|
| 108 |
cute::prefetch_tma_descriptor(&tensor_map_b);
|
| 109 |
cute::prefetch_tma_descriptor(&tensor_map_cd);
|
| 110 |
}
|
| 111 |
|
| 112 |
+
// Overwrite the runtime shapes with the compile-time constants when they are provided (i.e., non-zero)
|
| 113 |
+
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
| 114 |
+
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
| 115 |
+
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
| 116 |
+
|
| 117 |
+
// Align to 1024 bytes for swizzle-128B
|
| 118 |
+
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
| 119 |
+
|
| 120 |
// D/A/B shared memory
|
| 121 |
+
auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
|
| 122 |
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
| 123 |
});
|
| 124 |
+
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
|
| 125 |
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
| 126 |
});
|
| 127 |
+
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
|
| 128 |
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
| 129 |
});
|
| 130 |
|
| 131 |
// Fill barriers
|
| 132 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 133 |
+
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 134 |
+
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 135 |
+
auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
| 136 |
+
auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
|
| 137 |
auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2;
|
| 138 |
|
| 139 |
// Fill the tensor memory pointer
|
|
|
|
| 167 |
}
|
| 168 |
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
| 169 |
|
| 170 |
+
// Wait for primary kernel completion
|
| 171 |
+
cudaGridDependencySynchronize();
|
| 172 |
+
|
| 173 |
// Block scheduler
|
| 174 |
uint32_t m_block_idx, n_block_idx;
|
| 175 |
+
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(
|
| 176 |
+
shape_m, shape_n, shape_k, grouped_layout);
|
| 177 |
|
| 178 |
// Pipeline and TMA phases
|
| 179 |
uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0;
|
|
|
|
| 190 |
// TMA load warp
|
| 191 |
// Persistently schedule over blocks
|
| 192 |
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 193 |
+
// Use dynamic load block M, when swap-AB is enabled
|
| 194 |
+
const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M;
|
| 195 |
+
|
| 196 |
+
// For k-grouped layout, the number of K blocks is variable
|
| 197 |
+
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 198 |
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 199 |
// Wait consumer release
|
| 200 |
empty_barriers[stage_idx]->wait(phase ^ 1);
|
| 201 |
|
| 202 |
// Compute offsets
|
| 203 |
// NOTES: the group is always concatenated with the outer dimension
|
| 204 |
+
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> (
|
| 205 |
shape_m, BLOCK_M, m_block_idx);
|
| 206 |
+
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> (
|
| 207 |
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
| 208 |
|
| 209 |
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
|
|
|
| 211 |
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
|
| 212 |
kMajorA == cute::UMMA::Major::K, "Invalid major");
|
| 213 |
uint32_t k_idx = k_block_idx * BLOCK_K;
|
| 214 |
+
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
|
| 215 |
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 216 |
+
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
|
| 217 |
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 218 |
|
| 219 |
// Add 2 CTA offsets
|
| 220 |
if constexpr (kNumMulticast > 1) {
|
| 221 |
+
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0;
|
| 222 |
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
| 223 |
}
|
| 224 |
|
|
|
|
| 226 |
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
|
| 227 |
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
|
| 228 |
if constexpr (kMajorA == cute::UMMA::Major::K)
|
| 229 |
+
tma::copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 230 |
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx);
|
| 231 |
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
| 232 |
+
tma::copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 233 |
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx);
|
| 234 |
if constexpr (kMajorB == cute::UMMA::Major::K)
|
| 235 |
+
tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 236 |
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx);
|
| 237 |
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
| 238 |
+
tma::copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 239 |
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx);
|
| 240 |
|
| 241 |
// Arrive at full barriers
|
|
|
|
| 251 |
// MMA issue warp
|
| 252 |
// NOTES: only the leader CTA will do this
|
| 253 |
// Make instruction descriptor
|
| 254 |
+
auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float,
|
| 255 |
+
UMMA_M, UMMA_N, kMajorB, kMajorA>()
|
| 256 |
+
: cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float,
|
| 257 |
+
UMMA_M, UMMA_N, kMajorA, kMajorB>();
|
|
|
|
| 258 |
|
| 259 |
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
| 260 |
// Merged stages only happens in NT normal GEMM cases
|
| 261 |
constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
|
| 262 |
+
auto a_desc = mma::sm100::make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], 0, 0);
|
| 263 |
+
auto b_desc = mma::sm100::make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
| 264 |
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
| 265 |
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
| 266 |
|
|
|
|
| 277 |
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
|
| 278 |
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
| 279 |
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
|
| 280 |
+
ptx::tcgen05_after_thread_sync();
|
| 281 |
|
| 282 |
// UMMA and empty barrier arrival alias
|
| 283 |
auto umma_arrive = [](const uint64_t* barrier) {
|
|
|
|
| 294 |
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
| 295 |
if (do_tmem_full_arrive)
|
| 296 |
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
| 297 |
+
__syncwarp();
|
| 298 |
};
|
| 299 |
|
| 300 |
+
// Dynamic update of UMMA N based on effective M, when swap-AB is enabled
|
| 301 |
+
if constexpr (kSwapAB) {
|
| 302 |
+
uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx);
|
| 303 |
+
mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n);
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
// Launch MMAs
|
| 307 |
+
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 308 |
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 309 |
// Wait TMA arrival
|
| 310 |
full_barriers[stage_idx]->wait(phase);
|
| 311 |
+
ptx::tcgen05_after_thread_sync();
|
| 312 |
|
| 313 |
// Issue UMMA in the leader CTA
|
| 314 |
+
using mma_t = cute::conditional_t<kNumMulticast == 1, ptx::SM100_MMA_F16BF16_SS, ptx::SM100_MMA_F16BF16_2x1SM_SS>;
|
| 315 |
+
const auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
| 316 |
+
const auto a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
|
| 317 |
+
const auto b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
|
| 318 |
if (cute::elect_one_sync()) {
|
| 319 |
#pragma unroll
|
| 320 |
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
| 321 |
uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K;
|
| 322 |
+
a_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(
|
| 323 |
+
a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
|
| 324 |
+
b_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(
|
| 325 |
+
b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
|
| 326 |
+
if (kSwapAB) {
|
| 327 |
+
mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N,
|
| 328 |
+
k_block_idx > 0 or k > 0, runtime_instr_desc);
|
| 329 |
+
} else {
|
| 330 |
+
mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N,
|
| 331 |
+
k_block_idx > 0 or k > 0, runtime_instr_desc);
|
| 332 |
}
|
| 333 |
}
|
| 334 |
}
|
| 335 |
+
__syncwarp();
|
| 336 |
|
| 337 |
// Commit to the mbarrier object
|
| 338 |
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
|
|
|
| 343 |
if constexpr (kTensorCoreUtilControl < 100) {
|
| 344 |
// For utilization control
|
| 345 |
umma_arrive(reinterpret_cast<uint64_t*>(tensor_core_full_barrier));
|
| 346 |
+
__syncwarp();
|
| 347 |
|
| 348 |
// Wait for last UMMA to be done
|
| 349 |
tensor_core_full_barrier->wait(tensor_core_phase);
|
| 350 |
tensor_core_phase ^= 1;
|
| 351 |
|
| 352 |
// Sleep for certain cycles
|
| 353 |
+
constexpr static uint64_t kNumUMMACycles = (2ull * UMMA_M * UMMA_N * BLOCK_K) / 8192ull;
|
| 354 |
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
|
| 355 |
+
const auto start_clock = clock64();
|
| 356 |
if (cute::elect_one_sync())
|
| 357 |
while (clock64() - start_clock < kNumDummyCycles) {}
|
| 358 |
__syncwarp();
|
|
|
|
| 361 |
}
|
| 362 |
|
| 363 |
// To safely deconstruct barriers, we need another round of waits
|
| 364 |
+
const auto iter_idx = scheduler.current_iter - 1;
|
| 365 |
if (kNumMulticast > 1 and iter_idx >= 0) {
|
| 366 |
+
const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
|
| 367 |
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
|
| 368 |
}
|
| 369 |
} else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
|
|
|
|
| 373 |
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
| 374 |
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
|
| 375 |
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
| 376 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
|
| 378 |
// Share store pipeline between blocks
|
| 379 |
uint32_t tma_stage_idx = 0;
|
|
|
|
|
|
|
|
|
|
| 380 |
|
| 381 |
// Persistently schedule over blocks
|
| 382 |
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
|
|
|
| 385 |
|
| 386 |
// Wait UMMA arrival
|
| 387 |
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
|
| 388 |
+
ptx::tcgen05_after_thread_sync();
|
| 389 |
|
| 390 |
// Load from tensor memory into registers, and write shared memory with STSM
|
| 391 |
+
const auto tmem_base_addr = accum_stage_idx * UMMA_N;
|
| 392 |
+
const auto base_m_idx = scheduler.template get_global_idx<
|
| 393 |
+
(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
|
| 394 |
+
const auto base_n_idx = n_block_idx * BLOCK_N;
|
| 395 |
+
|
| 396 |
+
if constexpr (kSwapAB) {
|
| 397 |
+
const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx);
|
| 398 |
+
epilogue::sm100_store_cd_swap_ab<BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
|
| 399 |
+
kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
|
| 400 |
+
kGemmType, kWithAccumulation,
|
| 401 |
+
cd_dtype_t, epilogue::transform::EpilogueIdentity>
|
| 402 |
+
(smem_cd, tma_stage_idx, tmem_base_addr,
|
| 403 |
+
base_m_idx, base_n_idx, scheduler.current_group_idx,
|
| 404 |
+
effective_m,
|
| 405 |
+
epilogue_warp_idx, lane_idx,
|
| 406 |
+
tmem_empty_barriers[accum_stage_idx],
|
| 407 |
+
tensor_map_cd);
|
| 408 |
+
} else {
|
| 409 |
+
epilogue::sm100_store_cd<BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
|
| 410 |
+
kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
|
| 411 |
+
kGemmType, kWithAccumulation,
|
| 412 |
+
cd_dtype_t, epilogue::transform::EpilogueIdentity>
|
| 413 |
+
(smem_cd, tma_stage_idx, tmem_base_addr,
|
| 414 |
+
base_m_idx, base_n_idx, scheduler.current_group_idx,
|
| 415 |
+
epilogue_warp_idx, lane_idx,
|
| 416 |
+
tmem_empty_barriers[accum_stage_idx],
|
| 417 |
+
tensor_map_cd);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
}
|
| 419 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
}
|
| 421 |
+
|
| 422 |
+
// TODO: Remove redundant synchronization
|
| 423 |
+
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
| 424 |
+
|
| 425 |
+
// Deallocate tensor memory
|
| 426 |
+
if (warp_idx == 0)
|
| 427 |
+
Allocator().free(0, kNumTmemCols);
|
| 428 |
+
|
| 429 |
#else
|
| 430 |
if (blockIdx.x == 0 and threadIdx.x == 0)
|
| 431 |
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh
CHANGED
|
@@ -5,18 +5,19 @@
|
|
| 5 |
#include <cutlass/arch/barrier.h>
|
| 6 |
|
| 7 |
#include <deep_gemm/common/utils.cuh>
|
| 8 |
-
#include <deep_gemm/
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
namespace deep_gemm {
|
| 11 |
|
| 12 |
-
using namespace deep_gemm::sm100;
|
| 13 |
-
|
| 14 |
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 15 |
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 16 |
uint32_t kSplitFactor,
|
| 17 |
uint32_t kSwizzleABMode, uint32_t kSwizzleCDMode,
|
| 18 |
uint32_t kNumStages, uint32_t kNumThreads>
|
| 19 |
-
|
| 20 |
sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
| 21 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 22 |
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
|
@@ -30,7 +31,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
|
| 30 |
|
| 31 |
// Utils
|
| 32 |
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 33 |
-
const auto lane_idx = get_lane_idx();
|
| 34 |
DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size");
|
| 35 |
DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode");
|
| 36 |
|
|
@@ -51,24 +52,24 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
|
| 51 |
}
|
| 52 |
|
| 53 |
// Real tensor memory size and offsets
|
| 54 |
-
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_N>();
|
| 55 |
|
| 56 |
// Fill D/A/B
|
| 57 |
-
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
|
| 58 |
return reinterpret_cast<float*>(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE));
|
| 59 |
});
|
| 60 |
-
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 61 |
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
| 62 |
});
|
| 63 |
-
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 64 |
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
| 65 |
});
|
| 66 |
|
| 67 |
// Fill barriers
|
| 68 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
|
| 69 |
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 70 |
-
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 71 |
-
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 72 |
auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2);
|
| 73 |
|
| 74 |
// Fill the tensor memory pointer
|
|
@@ -93,14 +94,17 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
|
| 93 |
__syncthreads();
|
| 94 |
|
| 95 |
// Block indices
|
| 96 |
-
const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
|
| 97 |
-
const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
|
| 98 |
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
|
| 99 |
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
|
| 100 |
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
|
| 101 |
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
|
| 102 |
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
|
| 103 |
|
|
|
|
|
|
|
|
|
|
| 104 |
if (warp_idx == 0) {
|
| 105 |
// TMA load warp
|
| 106 |
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
|
@@ -115,8 +119,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
|
| 115 |
|
| 116 |
// Issue TMAs
|
| 117 |
if (cute::elect_one_sync()) {
|
| 118 |
-
|
| 119 |
-
|
| 120 |
}
|
| 121 |
|
| 122 |
// Arrive at full barriers
|
|
@@ -134,8 +138,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
|
| 134 |
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
|
| 135 |
|
| 136 |
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
| 137 |
-
auto a_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleABMode>(smem_a[0], 0, 0);
|
| 138 |
-
auto b_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleABMode>(smem_b[0], 0, 0);
|
| 139 |
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
| 140 |
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
| 141 |
|
|
@@ -147,14 +151,14 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
|
| 147 |
"Invalid MMA instruction shape");
|
| 148 |
|
| 149 |
// Wait tensor memory empty barrier arrival
|
| 150 |
-
tcgen05_after_thread_sync();
|
| 151 |
|
| 152 |
// Launch MMAs
|
| 153 |
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 154 |
// Wait TMA arrival
|
| 155 |
const auto& stage_idx = s % kNumStages;
|
| 156 |
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
|
| 157 |
-
tcgen05_after_thread_sync();
|
| 158 |
|
| 159 |
// Issue UMMA in the leader CTA
|
| 160 |
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
|
@@ -163,9 +167,11 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
|
| 163 |
if (cute::elect_one_sync()) {
|
| 164 |
#pragma unroll
|
| 165 |
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
| 166 |
-
a_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_M, kSwizzleABMode, cutlass::bfloat16_t>(
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
| 169 |
}
|
| 170 |
}
|
| 171 |
|
|
@@ -180,7 +186,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
|
| 180 |
// i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`.
|
| 181 |
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
| 182 |
if (warp_idx == 2)
|
| 183 |
-
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
| 184 |
|
| 185 |
// TMA checks
|
| 186 |
constexpr uint32_t kNumBankGroupBytes = 16;
|
|
@@ -191,7 +197,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
|
| 191 |
|
| 192 |
// Wait UMMA arrival
|
| 193 |
tmem_full_barrier->wait(0);
|
| 194 |
-
tcgen05_after_thread_sync();
|
| 195 |
|
| 196 |
// Load from tensor memory into registers, and write shared memory with STSM
|
| 197 |
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
|
@@ -239,7 +245,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
|
| 239 |
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
| 240 |
values[0], values[1], values[2], values[3]);
|
| 241 |
cutlass::arch::fence_view_async_tmem_load();
|
| 242 |
-
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
| 243 |
}
|
| 244 |
|
| 245 |
// Synchronize all threads and issue TMA
|
|
|
|
| 5 |
#include <cutlass/arch/barrier.h>
|
| 6 |
|
| 7 |
#include <deep_gemm/common/utils.cuh>
|
| 8 |
+
#include <deep_gemm/mma/sm100.cuh>
|
| 9 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 10 |
+
#include <deep_gemm/ptx/tcgen05.cuh>
|
| 11 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 12 |
|
| 13 |
namespace deep_gemm {
|
| 14 |
|
|
|
|
|
|
|
| 15 |
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 16 |
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 17 |
uint32_t kSplitFactor,
|
| 18 |
uint32_t kSwizzleABMode, uint32_t kSwizzleCDMode,
|
| 19 |
uint32_t kNumStages, uint32_t kNumThreads>
|
| 20 |
+
CUTLASS_GLOBAL void __launch_bounds__(kNumThreads, 1)
|
| 21 |
sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
| 22 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 23 |
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
|
|
|
| 31 |
|
| 32 |
// Utils
|
| 33 |
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 34 |
+
const auto lane_idx = ptx::get_lane_idx();
|
| 35 |
DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size");
|
| 36 |
DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode");
|
| 37 |
|
|
|
|
| 52 |
}
|
| 53 |
|
| 54 |
// Real tensor memory size and offsets
|
| 55 |
+
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<BLOCK_N>();
|
| 56 |
|
| 57 |
// Fill D/A/B
|
| 58 |
+
auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
|
| 59 |
return reinterpret_cast<float*>(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE));
|
| 60 |
});
|
| 61 |
+
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
|
| 62 |
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
| 63 |
});
|
| 64 |
+
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
|
| 65 |
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
| 66 |
});
|
| 67 |
|
| 68 |
// Fill barriers
|
| 69 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
|
| 70 |
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 71 |
+
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 72 |
+
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 73 |
auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2);
|
| 74 |
|
| 75 |
// Fill the tensor memory pointer
|
|
|
|
| 94 |
__syncthreads();
|
| 95 |
|
| 96 |
// Block indices
|
| 97 |
+
const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N);
|
| 98 |
+
const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M);
|
| 99 |
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
|
| 100 |
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
|
| 101 |
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
|
| 102 |
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
|
| 103 |
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
|
| 104 |
|
| 105 |
+
// Wait for primary kernel completion
|
| 106 |
+
cudaGridDependencySynchronize();
|
| 107 |
+
|
| 108 |
if (warp_idx == 0) {
|
| 109 |
// TMA load warp
|
| 110 |
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
|
|
|
| 119 |
|
| 120 |
// Issue TMAs
|
| 121 |
if (cute::elect_one_sync()) {
|
| 122 |
+
tma::copy<BLOCK_K, BLOCK_M, kSwizzleABMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M);
|
| 123 |
+
tma::copy<BLOCK_K, BLOCK_N, kSwizzleABMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N);
|
| 124 |
}
|
| 125 |
|
| 126 |
// Arrive at full barriers
|
|
|
|
| 138 |
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
|
| 139 |
|
| 140 |
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
| 141 |
+
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleABMode>(smem_a[0], 0, 0);
|
| 142 |
+
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleABMode>(smem_b[0], 0, 0);
|
| 143 |
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
| 144 |
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
| 145 |
|
|
|
|
| 151 |
"Invalid MMA instruction shape");
|
| 152 |
|
| 153 |
// Wait tensor memory empty barrier arrival
|
| 154 |
+
ptx::tcgen05_after_thread_sync();
|
| 155 |
|
| 156 |
// Launch MMAs
|
| 157 |
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 158 |
// Wait TMA arrival
|
| 159 |
const auto& stage_idx = s % kNumStages;
|
| 160 |
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
|
| 161 |
+
ptx::tcgen05_after_thread_sync();
|
| 162 |
|
| 163 |
// Issue UMMA in the leader CTA
|
| 164 |
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
|
|
|
| 167 |
if (cute::elect_one_sync()) {
|
| 168 |
#pragma unroll
|
| 169 |
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
| 170 |
+
a_desc.lo = mma::sm100::advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_M, kSwizzleABMode, cutlass::bfloat16_t>(
|
| 171 |
+
a_desc_base_lo, 0, k * UMMA_K);
|
| 172 |
+
b_desc.lo = mma::sm100::advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_N, kSwizzleABMode, cutlass::bfloat16_t>(
|
| 173 |
+
b_desc_base_lo, 0, k * UMMA_K);
|
| 174 |
+
ptx::SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc);
|
| 175 |
}
|
| 176 |
}
|
| 177 |
|
|
|
|
| 186 |
// i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`.
|
| 187 |
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
| 188 |
if (warp_idx == 2)
|
| 189 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
|
| 190 |
|
| 191 |
// TMA checks
|
| 192 |
constexpr uint32_t kNumBankGroupBytes = 16;
|
|
|
|
| 197 |
|
| 198 |
// Wait UMMA arrival
|
| 199 |
tmem_full_barrier->wait(0);
|
| 200 |
+
ptx::tcgen05_after_thread_sync();
|
| 201 |
|
| 202 |
// Load from tensor memory into registers, and write shared memory with STSM
|
| 203 |
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
|
|
|
| 245 |
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
| 246 |
values[0], values[1], values[2], values[3]);
|
| 247 |
cutlass::arch::fence_view_async_tmem_load();
|
| 248 |
+
ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
| 249 |
}
|
| 250 |
|
| 251 |
// Synchronize all threads and issue TMA
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cutlass/arch/barrier.h>
|
| 4 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 5 |
+
|
| 6 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 7 |
+
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
+
|
| 9 |
+
#include <deep_gemm/common/cute_tie.cuh>
|
| 10 |
+
#include <deep_gemm/common/utils.cuh>
|
| 11 |
+
#include <deep_gemm/mma/sm100.cuh>
|
| 12 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 13 |
+
#include <deep_gemm/ptx/tcgen05.cuh>
|
| 14 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 15 |
+
|
| 16 |
+
namespace deep_gemm {
|
| 17 |
+
|
| 18 |
+
template <uint32_t kNumHeads, uint32_t kHeadDim,
|
| 19 |
+
bool kIsCompressedLogits,
|
| 20 |
+
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
|
| 21 |
+
uint32_t kNumQStages, uint32_t kNumKVStages,
|
| 22 |
+
uint32_t kNumSMs,
|
| 23 |
+
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
| 24 |
+
typename logits_dtype_t,
|
| 25 |
+
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
|
| 26 |
+
CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
|
| 27 |
+
void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
| 28 |
+
const uint32_t max_seqlen_k,
|
| 29 |
+
const uint32_t logits_stride,
|
| 30 |
+
const uint32_t* cu_seq_len_k_start,
|
| 31 |
+
const uint32_t* cu_seq_len_k_end,
|
| 32 |
+
logits_dtype_t* logits,
|
| 33 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 34 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q,
|
| 35 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 36 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv,
|
| 37 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
| 38 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 39 |
+
|
| 40 |
+
// Utils
|
| 41 |
+
const auto sm_idx = blockIdx.x;
|
| 42 |
+
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 43 |
+
const auto warpgroup_idx = warp_idx / 4;
|
| 44 |
+
const auto lane_idx = ptx::get_lane_idx();
|
| 45 |
+
constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
|
| 46 |
+
|
| 47 |
+
// Prefetch TMA descriptors
|
| 48 |
+
if (warp_idx == kSpecWarpStart) {
|
| 49 |
+
cute::prefetch_tma_descriptor(&tensor_map_q);
|
| 50 |
+
cute::prefetch_tma_descriptor(&tensor_map_sf_q);
|
| 51 |
+
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
| 52 |
+
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
| 53 |
+
cute::prefetch_tma_descriptor(&tensor_map_sf_kv);
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
// UMMA configs
|
| 57 |
+
static constexpr uint32_t kNumTmemStages = 3;
|
| 58 |
+
static constexpr uint32_t kNumUTCCPAlignedElems = 128;
|
| 59 |
+
static constexpr uint32_t UMMA_M = 128;
|
| 60 |
+
static constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
|
| 61 |
+
static constexpr uint32_t UMMA_K = 64;
|
| 62 |
+
static constexpr uint32_t kNumSFQ = math::constexpr_align(BLOCK_Q * kNumHeads, kNumUTCCPAlignedElems);
|
| 63 |
+
static constexpr uint32_t kNumSFKV = math::constexpr_align(BLOCK_KV, kNumUTCCPAlignedElems);
|
| 64 |
+
static constexpr uint32_t kRealNumSFQ = BLOCK_Q * kNumHeads;
|
| 65 |
+
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
| 66 |
+
DG_STATIC_ASSERT(BLOCK_KV == kNumMathWarpGroups * UMMA_M and BLOCK_KV % kNumUTCCPAlignedElems == 0, "Invalid `BLOCK_KV`");
|
| 67 |
+
|
| 68 |
+
// Shared memory configs
|
| 69 |
+
static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2);
|
| 70 |
+
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * (kHeadDim / 2);
|
| 71 |
+
static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQ * sizeof(int);
|
| 72 |
+
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * (kHeadDim / 2);
|
| 73 |
+
static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int);
|
| 74 |
+
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
|
| 75 |
+
|
| 76 |
+
// Align to swizzling alignment bytes
|
| 77 |
+
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
| 78 |
+
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 79 |
+
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 80 |
+
|
| 81 |
+
// Q and KV data on shared memory
|
| 82 |
+
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
|
| 83 |
+
return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i;
|
| 84 |
+
});
|
| 85 |
+
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
|
| 86 |
+
return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i;
|
| 87 |
+
});
|
| 88 |
+
const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages);
|
| 89 |
+
auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) {
|
| 90 |
+
return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i);
|
| 91 |
+
});
|
| 92 |
+
auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) {
|
| 93 |
+
return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i);
|
| 94 |
+
});
|
| 95 |
+
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
|
| 96 |
+
return reinterpret_cast<float*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages
|
| 97 |
+
+ SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 98 |
+
});
|
| 99 |
+
|
| 100 |
+
// Barriers and TMEM pointer on shared memory
|
| 101 |
+
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
|
| 102 |
+
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
| 103 |
+
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
|
| 104 |
+
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
|
| 105 |
+
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
|
| 106 |
+
const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
|
| 107 |
+
auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; });
|
| 108 |
+
auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; });
|
| 109 |
+
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(tmem_barrier_ptr + kNumTmemStages * 2);
|
| 110 |
+
|
| 111 |
+
// Tensor memory configs
|
| 112 |
+
constexpr uint32_t kNumAccumTmemCols = BLOCK_Q * kNumHeads * kNumTmemStages;
|
| 113 |
+
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFQ / 32 + kNumSFKV / 32>();
|
| 114 |
+
constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols;
|
| 115 |
+
constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQ / 32;
|
| 116 |
+
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
|
| 117 |
+
|
| 118 |
+
// Initialize barriers
|
| 119 |
+
if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) {
|
| 120 |
+
#pragma unroll
|
| 121 |
+
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
| 122 |
+
full_q_barriers[i]->init(1);
|
| 123 |
+
empty_q_barriers[i]->init(kNumMathThreads + 32);
|
| 124 |
+
}
|
| 125 |
+
#pragma unroll
|
| 126 |
+
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
| 127 |
+
full_kv_barriers[i]->init(1);
|
| 128 |
+
empty_kv_barriers[i]->init(1);
|
| 129 |
+
}
|
| 130 |
+
#pragma unroll
|
| 131 |
+
for (uint32_t i = 0; i < kNumTmemStages; ++i) {
|
| 132 |
+
full_tmem_barriers[i]->init(1);
|
| 133 |
+
empty_tmem_barriers[i]->init(128);
|
| 134 |
+
}
|
| 135 |
+
cutlass::arch::fence_barrier_init();
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
// Allocate tensor memory
|
| 139 |
+
if (warp_idx == kSpecWarpStart + 2)
|
| 140 |
+
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
| 141 |
+
__syncthreads();
|
| 142 |
+
|
| 143 |
+
// Scheduler
|
| 144 |
+
const uint32_t num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
|
| 145 |
+
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
|
| 146 |
+
auto load_schedule = [&](const uint32_t& q_idx) -> cute::tuple<uint32_t, uint32_t> {
|
| 147 |
+
uint32_t start = cute::numeric_limits<uint32_t>::max();
|
| 148 |
+
uint32_t end = cute::numeric_limits<uint32_t>::min();
|
| 149 |
+
#pragma unroll
|
| 150 |
+
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 151 |
+
const auto row_idx = cute::min(q_idx * BLOCK_Q + i, seq_len - 1);
|
| 152 |
+
seq_k_start[i] = cute::min(cu_seq_len_k_start[row_idx], seq_len_kv);
|
| 153 |
+
seq_k_end[i] = cute::min(cu_seq_len_k_end[row_idx], seq_len_kv);
|
| 154 |
+
start = cute::min(start, seq_k_start[i]);
|
| 155 |
+
end = cute::max(end, seq_k_end[i]);
|
| 156 |
+
}
|
| 157 |
+
// TMA alignment requirements for SF KV
|
| 158 |
+
start = start / 4 * 4;
|
| 159 |
+
return {start, math::ceil_div(end - start, BLOCK_KV)};
|
| 160 |
+
};
|
| 161 |
+
|
| 162 |
+
// Make Q, KV and TMEM pipeline
|
| 163 |
+
auto make_pipeline = [](const uint32_t& num_stages) {
|
| 164 |
+
// Return current stage and phase, and advance pipeline by steps
|
| 165 |
+
return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple<uint32_t, uint32_t> {
|
| 166 |
+
uint32_t current_idx = iter_idx;
|
| 167 |
+
iter_idx += step;
|
| 168 |
+
return {current_idx % num_stages, (current_idx / num_stages) & 1};
|
| 169 |
+
};
|
| 170 |
+
};
|
| 171 |
+
auto advance_q_pipeline = make_pipeline(kNumQStages);
|
| 172 |
+
auto advance_kv_pipeline = make_pipeline(kNumKVStages);
|
| 173 |
+
auto advance_tmem_pipeline = make_pipeline(kNumTmemStages);
|
| 174 |
+
|
| 175 |
+
// Register reconfigurations
|
| 176 |
+
constexpr uint32_t kNumSpecializedRegisters = 56;
|
| 177 |
+
constexpr uint32_t kNumMathRegisters = 224;
|
| 178 |
+
|
| 179 |
+
// Wait for primary kernel completion
|
| 180 |
+
cudaGridDependencySynchronize();
|
| 181 |
+
|
| 182 |
+
if (warp_idx == kSpecWarpStart) {
|
| 183 |
+
// TMA warp for loading Q
|
| 184 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 185 |
+
|
| 186 |
+
// Enumerate Q blocks
|
| 187 |
+
if (cute::elect_one_sync()) {
|
| 188 |
+
for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
|
| 189 |
+
// Wait Q consumer release
|
| 190 |
+
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
|
| 191 |
+
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
| 192 |
+
|
| 193 |
+
// Issue TMA Q
|
| 194 |
+
cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[q_stage_idx]),
|
| 195 |
+
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
|
| 196 |
+
smem_q[q_stage_idx], 0, q_idx * BLOCK_Q * kNumHeads);
|
| 197 |
+
tma::copy<BLOCK_Q * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_idx * BLOCK_Q);
|
| 198 |
+
tma::copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_idx * BLOCK_Q);
|
| 199 |
+
full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQ * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 200 |
+
}
|
| 201 |
+
}
|
| 202 |
+
__syncwarp();
|
| 203 |
+
} else if (warp_idx == kSpecWarpStart + 1) {
|
| 204 |
+
// TMA warp for loading KV cache
|
| 205 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 206 |
+
|
| 207 |
+
if (cute::elect_one_sync()) {
|
| 208 |
+
// Enumerate Q blocks
|
| 209 |
+
for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
|
| 210 |
+
// Load KV block ranges
|
| 211 |
+
CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks);
|
| 212 |
+
|
| 213 |
+
// Enumerate KV blocks
|
| 214 |
+
for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) {
|
| 215 |
+
// Wait KV consumer release
|
| 216 |
+
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
|
| 217 |
+
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
| 218 |
+
|
| 219 |
+
// Issue TMA KV
|
| 220 |
+
cute::SM90_TMA_LOAD_2D::copy(&tensor_map_kv, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
|
| 221 |
+
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
|
| 222 |
+
smem_kv[kv_stage_idx], 0, kv_start + kv_idx * BLOCK_KV);
|
| 223 |
+
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx],
|
| 224 |
+
smem_sf_kv[kv_stage_idx],
|
| 225 |
+
kv_start + kv_idx * BLOCK_KV, 0);
|
| 226 |
+
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE);
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
} else if (warp_idx == kSpecWarpStart + 2) {
|
| 231 |
+
// UMMA warp
|
| 232 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 233 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
|
| 234 |
+
|
| 235 |
+
// UTCCP transposer
|
| 236 |
+
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
|
| 237 |
+
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
|
| 238 |
+
uint32_t values[4];
|
| 239 |
+
#pragma unroll
|
| 240 |
+
for (uint32_t i = 0; i < 4; ++ i)
|
| 241 |
+
values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
|
| 242 |
+
__syncwarp();
|
| 243 |
+
#pragma unroll
|
| 244 |
+
for (uint32_t i = 0; i < 4; ++ i)
|
| 245 |
+
ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
|
| 246 |
+
};
|
| 247 |
+
|
| 248 |
+
// Make UMMA desc
|
| 249 |
+
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e2m1_t, cutlass::float_e2m1_t, float, cutlass::float_ue8m0_t,
|
| 250 |
+
UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
|
| 251 |
+
auto sf_desc = mma::sm100::make_sf_desc(nullptr);
|
| 252 |
+
|
| 253 |
+
// Enumerate Q blocks
|
| 254 |
+
for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
|
| 255 |
+
// Load KV block ranges
|
| 256 |
+
CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks);
|
| 257 |
+
|
| 258 |
+
// Wait TMA Q arrivals
|
| 259 |
+
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
|
| 260 |
+
full_q_barriers[q_stage_idx]->wait(q_phase);
|
| 261 |
+
|
| 262 |
+
// Transpose and copy SF Q
|
| 263 |
+
#pragma unroll
|
| 264 |
+
for (uint32_t i = 0; i < kNumSFQ / kNumUTCCPAlignedElems; ++ i) {
|
| 265 |
+
auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems;
|
| 266 |
+
utccp_required_smem_warp_transpose(smem_ptr);
|
| 267 |
+
cutlass::arch::fence_view_async_shared();
|
| 268 |
+
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
| 269 |
+
if (cute::elect_one_sync())
|
| 270 |
+
cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4);
|
| 271 |
+
__syncwarp();
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
// Enumerate KV blocks
|
| 275 |
+
for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) {
|
| 276 |
+
// Wait TMA KV arrivals
|
| 277 |
+
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
|
| 278 |
+
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 279 |
+
|
| 280 |
+
// Transpose
|
| 281 |
+
#pragma unroll
|
| 282 |
+
for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
|
| 283 |
+
auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
|
| 284 |
+
utccp_required_smem_warp_transpose(smem_ptr);
|
| 285 |
+
cutlass::arch::fence_view_async_shared();
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
// UMMA with SF
|
| 289 |
+
if (cute::elect_one_sync()) {
|
| 290 |
+
// Copy SF KV
|
| 291 |
+
#pragma unroll
|
| 292 |
+
for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
|
| 293 |
+
auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
|
| 294 |
+
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
| 295 |
+
cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4);
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
#pragma unroll
|
| 299 |
+
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
| 300 |
+
// Wait TMEM release
|
| 301 |
+
CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase);
|
| 302 |
+
uint32_t tmem_addr = tmem_stage_idx * UMMA_N;
|
| 303 |
+
|
| 304 |
+
empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1);
|
| 305 |
+
ptx::tcgen05_after_thread_sync();
|
| 306 |
+
|
| 307 |
+
// Issue UMMA with SF
|
| 308 |
+
#pragma unroll
|
| 309 |
+
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
|
| 310 |
+
auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2);
|
| 311 |
+
// TODO: generalize umma desc
|
| 312 |
+
DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim");
|
| 313 |
+
auto a_desc = mma::sm100::make_smem_desc(
|
| 314 |
+
cute::UMMA::LayoutType::SWIZZLE_64B,
|
| 315 |
+
smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2,
|
| 316 |
+
8 * (kHeadDim / 2), 0);
|
| 317 |
+
auto b_desc = mma::sm100::make_smem_desc(
|
| 318 |
+
cute::UMMA::LayoutType::SWIZZLE_64B,
|
| 319 |
+
smem_q[q_stage_idx] + k * UMMA_K / 2,
|
| 320 |
+
8 * (kHeadDim / 2), 0);
|
| 321 |
+
ptx::SM100_MMA_MXF4_SS::fma(
|
| 322 |
+
a_desc, b_desc, tmem_addr, k, runtime_instr_desc,
|
| 323 |
+
kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ);
|
| 324 |
+
}
|
| 325 |
+
// TODO: move this into `deep_gemm/ptx/tcgen05.cuh`
|
| 326 |
+
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
|
| 327 |
+
::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx])));
|
| 328 |
+
}
|
| 329 |
+
}
|
| 330 |
+
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_kv_barriers[kv_stage_idx]));
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
// UMMA warp must also arrive on empty_q to prevent running ahead
|
| 334 |
+
// of math warps in the Q pipeline. Without this, UMMA can consume
|
| 335 |
+
// kNumQStages Q blocks before math warps release any, causing a
|
| 336 |
+
// circular dependency: UMMA waits full_q -> TMA_Q waits empty_q
|
| 337 |
+
// -> Math waits full_tmem -> UMMA (already moved on).
|
| 338 |
+
empty_q_barriers[q_stage_idx]->arrive();
|
| 339 |
+
}
|
| 340 |
+
} else if (warp_idx == kSpecWarpStart + 3) {
|
| 341 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 342 |
+
} else if (warp_idx < kSpecWarpStart) {
|
| 343 |
+
// Math warpgroups for reduce
|
| 344 |
+
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 345 |
+
|
| 346 |
+
const auto math_warpgroup_idx = warpgroup_idx;
|
| 347 |
+
const auto math_thread_idx = threadIdx.x;
|
| 348 |
+
|
| 349 |
+
// Helper lambda for loading tensor memory
|
| 350 |
+
auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
|
| 351 |
+
constexpr uint32_t N = decltype(num_elems_c)::value;
|
| 352 |
+
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
|
| 353 |
+
using Loader = cute::conditional_t<N == 32,
|
| 354 |
+
cute::SM100_TMEM_LOAD_32dp32b32x,
|
| 355 |
+
cute::SM100_TMEM_LOAD_32dp32b64x>;
|
| 356 |
+
[&]<size_t... Is>(cute::index_sequence<Is...>) {
|
| 357 |
+
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
|
| 358 |
+
}(cute::make_index_sequence<N>{});
|
| 359 |
+
cutlass::arch::fence_view_async_tmem_load();
|
| 360 |
+
};
|
| 361 |
+
|
| 362 |
+
// Math warpgroups process TMEM stages alternately
|
| 363 |
+
// Advance pipeline to align with the assigned stage
|
| 364 |
+
advance_tmem_pipeline(math_warpgroup_idx);
|
| 365 |
+
|
| 366 |
+
// Local register buffers
|
| 367 |
+
float accum[kNumHeads];
|
| 368 |
+
float weights[BLOCK_Q][kNumHeads];
|
| 369 |
+
|
| 370 |
+
// Enumerate Q blocks
|
| 371 |
+
for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
|
| 372 |
+
// Load KV block ranges
|
| 373 |
+
CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks);
|
| 374 |
+
|
| 375 |
+
// Wait TMA Q arrivals
|
| 376 |
+
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
|
| 377 |
+
full_q_barriers[q_stage_idx]->wait(q_phase);
|
| 378 |
+
|
| 379 |
+
// Read weights
|
| 380 |
+
// TODO: optimize bank conflicts
|
| 381 |
+
#pragma unroll
|
| 382 |
+
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 383 |
+
#pragma unroll
|
| 384 |
+
for (uint32_t j = 0; j < kNumHeads; ++ j)
|
| 385 |
+
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
// Enumerate KV blocks
|
| 389 |
+
for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) {
|
| 390 |
+
// Calculate KV offset in advance
|
| 391 |
+
auto kv_offset = kv_start + kv_idx * BLOCK_KV + math_thread_idx;
|
| 392 |
+
|
| 393 |
+
// Advance pipeline by `kNumMathWarpGroups` steps
|
| 394 |
+
// Wait UMMA arrival
|
| 395 |
+
CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase);
|
| 396 |
+
full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase);
|
| 397 |
+
ptx::tcgen05_after_thread_sync();
|
| 398 |
+
|
| 399 |
+
// Reduce over the head dim and store
|
| 400 |
+
#pragma unroll
|
| 401 |
+
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 402 |
+
// Load accumulator from TMEM
|
| 403 |
+
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
|
| 404 |
+
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
|
| 405 |
+
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
|
| 406 |
+
|
| 407 |
+
// Release TMEM empty
|
| 408 |
+
if (i == BLOCK_Q - 1) {
|
| 409 |
+
ptx::tcgen05_before_thread_sync();
|
| 410 |
+
empty_tmem_barriers[tmem_stage_idx]->arrive();
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
// Accumulate weighted ReLU in parallel
|
| 414 |
+
auto sum_0 = make_float2(0, 0);
|
| 415 |
+
auto sum_1 = make_float2(0, 0);
|
| 416 |
+
|
| 417 |
+
const auto transform = [&](const uint32_t& j, const float2& sum) {
|
| 418 |
+
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
| 419 |
+
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
| 420 |
+
return __ffma2_rn(a, b, sum);
|
| 421 |
+
};
|
| 422 |
+
|
| 423 |
+
#pragma unroll
|
| 424 |
+
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
| 425 |
+
sum_0 = transform(j, sum_0);
|
| 426 |
+
sum_1 = transform(j + 2, sum_1);
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
auto sum = __fadd2_rn(sum_0, sum_1);
|
| 430 |
+
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
|
| 431 |
+
|
| 432 |
+
// Store into the global memory
|
| 433 |
+
// NOTE: some of these writes are redundant; this needs more careful consideration
|
| 434 |
+
// TODO: optimize performance
|
| 435 |
+
const auto q_offset = (q_idx * BLOCK_Q + i) * static_cast<uint64_t>(logits_stride);
|
| 436 |
+
if constexpr (kIsCompressedLogits) {
|
| 437 |
+
if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i])
|
| 438 |
+
logits[q_offset + kv_offset - seq_k_start[i]] = result;
|
| 439 |
+
} else {
|
| 440 |
+
logits[q_offset + kv_offset] = result;
|
| 441 |
+
}
|
| 442 |
+
__syncwarp();
|
| 443 |
+
}
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
// Release last Q empty
|
| 447 |
+
empty_q_barriers[q_stage_idx]->arrive();
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
// Free tensor memory
|
| 451 |
+
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
|
| 452 |
+
if (warp_idx == 0)
|
| 453 |
+
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
| 454 |
+
}
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
} // namespace deep_gemm
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cutlass/arch/barrier.h>
|
| 4 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 5 |
+
|
| 6 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 7 |
+
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
+
|
| 9 |
+
#include <deep_gemm/common/cute_tie.cuh>
|
| 10 |
+
#include <deep_gemm/common/math.cuh>
|
| 11 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 12 |
+
#include <deep_gemm/common/utils.cuh>
|
| 13 |
+
#include <deep_gemm/mma/sm100.cuh>
|
| 14 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 15 |
+
#include <deep_gemm/ptx/tcgen05.cuh>
|
| 16 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 17 |
+
#include <deep_gemm/scheduler/paged_mqa_logits.cuh>
|
| 18 |
+
|
| 19 |
+
namespace deep_gemm {
|
| 20 |
+
|
| 21 |
+
template <uint32_t kNextN, uint32_t kNumHeads,
|
| 22 |
+
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
| 23 |
+
bool kIsContextLens2D, bool kIsVarlen,
|
| 24 |
+
uint32_t kNumQStages, uint32_t kNumKVStages,
|
| 25 |
+
uint32_t SPLIT_KV,
|
| 26 |
+
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
| 27 |
+
typename logits_dtype_t,
|
| 28 |
+
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
|
| 29 |
+
CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
|
| 30 |
+
void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
| 31 |
+
const uint32_t logits_stride, const uint32_t block_table_stride,
|
| 32 |
+
const uint32_t* context_lens, logits_dtype_t* logits,
|
| 33 |
+
const uint32_t* block_table, const uint32_t* indices,
|
| 34 |
+
const uint32_t* schedule_meta,
|
| 35 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 36 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q,
|
| 37 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 38 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv,
|
| 39 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
| 40 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 41 |
+
|
| 42 |
+
// Utils
|
| 43 |
+
const auto sm_idx = blockIdx.x;
|
| 44 |
+
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 45 |
+
const auto warpgroup_idx = warp_idx / 4;
|
| 46 |
+
const auto lane_idx = ptx::get_lane_idx();
|
| 47 |
+
constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
|
| 48 |
+
|
| 49 |
+
// Prefetch TMA descriptors
|
| 50 |
+
if (warp_idx == kSpecWarpStart) {
|
| 51 |
+
cute::prefetch_tma_descriptor(&tensor_map_q);
|
| 52 |
+
cute::prefetch_tma_descriptor(&tensor_map_sf_q);
|
| 53 |
+
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
| 54 |
+
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
| 55 |
+
cute::prefetch_tma_descriptor(&tensor_map_sf_kv);
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
|
| 59 |
+
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
|
| 60 |
+
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
|
| 61 |
+
static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
|
| 62 |
+
|
| 63 |
+
// UMMA configs
|
| 64 |
+
static constexpr uint32_t kNumTmemStages = 3;
|
| 65 |
+
static constexpr uint32_t kNumUTCCPAlignedElems = 128;
|
| 66 |
+
static constexpr uint32_t UMMA_M = 128;
|
| 67 |
+
static constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads;
|
| 68 |
+
static constexpr uint32_t UMMA_K = 64;
|
| 69 |
+
static constexpr uint32_t kNumSFQAtom = math::constexpr_align(kNextNAtom * kNumHeads, kNumUTCCPAlignedElems);
|
| 70 |
+
static constexpr uint32_t kNumSFKV = math::constexpr_align(SPLIT_KV, kNumUTCCPAlignedElems);
|
| 71 |
+
static constexpr uint32_t kRealNumSFQAtom = kNextNAtom * kNumHeads;
|
| 72 |
+
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
| 73 |
+
DG_STATIC_ASSERT(SPLIT_KV == kNumMathWarpGroups * UMMA_M and SPLIT_KV % kNumUTCCPAlignedElems == 0, "Invalid `SPLIT_KV`");
|
| 74 |
+
|
| 75 |
+
// Shared memory configs
|
| 76 |
+
static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2);
|
| 77 |
+
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * (kHeadDim / 2);
|
| 78 |
+
static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQAtom * sizeof(int);
|
| 79 |
+
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * (kHeadDim / 2);
|
| 80 |
+
static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int);
|
| 81 |
+
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float);
|
| 82 |
+
|
| 83 |
+
// Align to swizzling alignment bytes
|
| 84 |
+
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
| 85 |
+
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 86 |
+
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 87 |
+
|
| 88 |
+
// Q and KV data on shared memory
|
| 89 |
+
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
|
| 90 |
+
return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i;
|
| 91 |
+
});
|
| 92 |
+
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
|
| 93 |
+
return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i;
|
| 94 |
+
});
|
| 95 |
+
const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages);
|
| 96 |
+
auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) {
|
| 97 |
+
return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i);
|
| 98 |
+
});
|
| 99 |
+
auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) {
|
| 100 |
+
return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i);
|
| 101 |
+
});
|
| 102 |
+
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
|
| 103 |
+
return reinterpret_cast<float*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages
|
| 104 |
+
+ SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 105 |
+
});
|
| 106 |
+
|
| 107 |
+
// Barriers and TMEM pointer on shared memory
|
| 108 |
+
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
|
| 109 |
+
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
| 110 |
+
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
|
| 111 |
+
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
|
| 112 |
+
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
|
| 113 |
+
const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
|
| 114 |
+
auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; });
|
| 115 |
+
auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; });
|
| 116 |
+
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(tmem_barrier_ptr + kNumTmemStages * 2);
|
| 117 |
+
|
| 118 |
+
// Tensor memory configs
|
| 119 |
+
constexpr uint32_t kNumAccumTmemCols = kNextNAtom * kNumHeads * kNumTmemStages;
|
| 120 |
+
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFQAtom / 32 + kNumSFKV / 32>();
|
| 121 |
+
constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols;
|
| 122 |
+
constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQAtom / 32;
|
| 123 |
+
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
|
| 124 |
+
|
| 125 |
+
// Initialize barriers
|
| 126 |
+
if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
|
| 127 |
+
#pragma unroll
|
| 128 |
+
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
| 129 |
+
full_q_barriers[i]->init(1);
|
| 130 |
+
empty_q_barriers[i]->init(kNumMathThreads + 32);
|
| 131 |
+
}
|
| 132 |
+
cutlass::arch::fence_barrier_init();
|
| 133 |
+
}
|
| 134 |
+
if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) {
|
| 135 |
+
#pragma unroll
|
| 136 |
+
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
| 137 |
+
full_kv_barriers[i]->init(1);
|
| 138 |
+
empty_kv_barriers[i]->init(1);
|
| 139 |
+
}
|
| 140 |
+
cutlass::arch::fence_barrier_init();
|
| 141 |
+
}
|
| 142 |
+
if (warp_idx == kSpecWarpStart + 2) {
|
| 143 |
+
if (cute::elect_one_sync()) {
|
| 144 |
+
#pragma unroll
|
| 145 |
+
for (uint32_t i = 0; i < kNumTmemStages; ++i) {
|
| 146 |
+
full_tmem_barriers[i]->init(1);
|
| 147 |
+
empty_tmem_barriers[i]->init(128);
|
| 148 |
+
}
|
| 149 |
+
cutlass::arch::fence_barrier_init();
|
| 150 |
+
}
|
| 151 |
+
// Allocate tensor memory
|
| 152 |
+
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
| 153 |
+
}
|
| 154 |
+
__syncthreads();
|
| 155 |
+
|
| 156 |
+
// Wait for primary kernel completion
|
| 157 |
+
cudaGridDependencySynchronize();
|
| 158 |
+
|
| 159 |
+
// Scheduler
|
| 160 |
+
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
|
| 161 |
+
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
|
| 162 |
+
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
|
| 163 |
+
|
| 164 |
+
// Make Q, KV and TMEM pipeline
|
| 165 |
+
auto make_pipeline = [](const uint32_t& num_stages) {
|
| 166 |
+
// Return current stage and phase, and advance pipeline by steps
|
| 167 |
+
return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple<uint32_t, uint32_t> {
|
| 168 |
+
uint32_t current_idx = iter_idx;
|
| 169 |
+
iter_idx += step;
|
| 170 |
+
return {current_idx % num_stages, (current_idx / num_stages) & 1};
|
| 171 |
+
};
|
| 172 |
+
};
|
| 173 |
+
auto advance_q_pipeline = make_pipeline(kNumQStages);
|
| 174 |
+
auto advance_kv_pipeline = make_pipeline(kNumKVStages);
|
| 175 |
+
auto advance_tmem_pipeline = make_pipeline(kNumTmemStages);
|
| 176 |
+
|
| 177 |
+
// Register reconfigurations
|
| 178 |
+
constexpr uint32_t kNumSpecializedRegisters = 56;
|
| 179 |
+
constexpr uint32_t kNumMathRegisters = 224;
|
| 180 |
+
|
| 181 |
+
if (warp_idx == kSpecWarpStart) {
|
| 182 |
+
// TMA warp for loading Q
|
| 183 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 184 |
+
|
| 185 |
+
if (cute::elect_one_sync()) {
|
| 186 |
+
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
| 187 |
+
|
| 188 |
+
// Persistently schedule over blocks
|
| 189 |
+
// Initialize outside valid range to indicate no previous task
|
| 190 |
+
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
|
| 191 |
+
uint32_t q_atom_idx, _, __;
|
| 192 |
+
while (scheduler.fetch_next_task(q_atom_idx, _, __)) {
|
| 193 |
+
// Issue TMA Q when (q_idx, atom_idx) changes
|
| 194 |
+
if (q_atom_idx != last_q_atom_idx) {
|
| 195 |
+
// Wait Q consumer release
|
| 196 |
+
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
|
| 197 |
+
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
| 198 |
+
|
| 199 |
+
// Issue TMA Q
|
| 200 |
+
const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx);
|
| 201 |
+
cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[q_stage_idx]),
|
| 202 |
+
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
|
| 203 |
+
smem_q[q_stage_idx], 0, q_token_idx * kNumHeads);
|
| 204 |
+
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx);
|
| 205 |
+
tma::copy<kNumHeads, kNextNAtom, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx);
|
| 206 |
+
full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 207 |
+
}
|
| 208 |
+
last_q_atom_idx = q_atom_idx;
|
| 209 |
+
}
|
| 210 |
+
}
|
| 211 |
+
__syncwarp();
|
| 212 |
+
} else if (warp_idx == kSpecWarpStart + 1) {
|
| 213 |
+
// TMA warp for loading KV cache
|
| 214 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 215 |
+
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
| 216 |
+
|
| 217 |
+
// Persistently schedule over blocks
|
| 218 |
+
uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage;
|
| 219 |
+
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
|
| 220 |
+
uint32_t q_atom_idx, kv_idx, num_kv;
|
| 221 |
+
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, num_kv)) {
|
| 222 |
+
// Reset block table cache on kv restart
|
| 223 |
+
if (q_atom_idx != last_q_atom_idx)
|
| 224 |
+
kv_block_idx_ptr = 32;
|
| 225 |
+
last_q_atom_idx = q_atom_idx;
|
| 226 |
+
|
| 227 |
+
// Coalesced load of block table
|
| 228 |
+
if (kv_block_idx_ptr == 32) {
|
| 229 |
+
kv_block_idx_ptr = 0;
|
| 230 |
+
const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
|
| 231 |
+
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
|
| 232 |
+
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
|
| 233 |
+
}
|
| 234 |
+
__syncwarp();
|
| 235 |
+
|
| 236 |
+
// Broadcast KV block indices
|
| 237 |
+
int kv_block_idx[kNumBlocksPerSplit];
|
| 238 |
+
#pragma unroll
|
| 239 |
+
for (int i = 0; i < kNumBlocksPerSplit; ++ i)
|
| 240 |
+
kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
|
| 241 |
+
kv_block_idx_ptr += kNumBlocksPerSplit;
|
| 242 |
+
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `SPLIT_KV`");
|
| 243 |
+
|
| 244 |
+
// Wait KV consumer release
|
| 245 |
+
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
|
| 246 |
+
|
| 247 |
+
// Issue TMA KV
|
| 248 |
+
if (cute::elect_one_sync()) {
|
| 249 |
+
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
| 250 |
+
#pragma unroll
|
| 251 |
+
for (int i = 0; i < kNumBlocksPerSplit; ++ i) {
|
| 252 |
+
cute::SM90_TMA_LOAD_3D::copy(&tensor_map_kv, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
|
| 253 |
+
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
|
| 254 |
+
smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim / 2) * i,
|
| 255 |
+
0, 0, kv_block_idx[i]);
|
| 256 |
+
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx],
|
| 257 |
+
smem_sf_kv[kv_stage_idx] + BLOCK_KV * i,
|
| 258 |
+
0, kv_block_idx[i]);
|
| 259 |
+
}
|
| 260 |
+
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE);
|
| 261 |
+
}
|
| 262 |
+
}
|
| 263 |
+
} else if (warp_idx == kSpecWarpStart + 2) {
|
| 264 |
+
// UMMA warp
|
| 265 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 266 |
+
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
| 267 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
|
| 268 |
+
|
| 269 |
+
// UTCCP transposer
|
| 270 |
+
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
|
| 271 |
+
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
|
| 272 |
+
uint32_t values[4];
|
| 273 |
+
#pragma unroll
|
| 274 |
+
for (uint32_t i = 0; i < 4; ++ i)
|
| 275 |
+
values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
|
| 276 |
+
__syncwarp();
|
| 277 |
+
#pragma unroll
|
| 278 |
+
for (uint32_t i = 0; i < 4; ++ i)
|
| 279 |
+
ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
|
| 280 |
+
};
|
| 281 |
+
|
| 282 |
+
// Make UMMA desc
|
| 283 |
+
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e2m1_t, cutlass::float_e2m1_t, float, cutlass::float_ue8m0_t,
|
| 284 |
+
UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
|
| 285 |
+
auto sf_desc = mma::sm100::make_sf_desc(nullptr);
|
| 286 |
+
|
| 287 |
+
// Persistently schedule over blocks
|
| 288 |
+
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
|
| 289 |
+
uint32_t q_atom_idx, kv_idx, _;
|
| 290 |
+
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) {
|
| 291 |
+
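// Wait TMA Q arrivals. NOTE(review): `q_stage_idx`/`q_phase` are declared per loop iteration just below but are only assigned when the Q atom changes, while `smem_q[q_stage_idx]` is read on every iteration; confirm, and consider hoisting the declaration above the `while` loop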
|
| 292 |
+
uint32_t q_stage_idx, q_phase;
|
| 293 |
+
if (q_atom_idx != last_q_atom_idx) {
|
| 294 |
+
CUTE_TIE(advance_q_pipeline(), q_stage_idx, q_phase);
|
| 295 |
+
|
| 296 |
+
// Release previous Q empty (UMMA warp must participate to prevent
|
| 297 |
+
// running ahead of math warps in the Q pipeline)
|
| 298 |
+
if (last_q_atom_idx != batch_size * kNumNextNAtoms)
|
| 299 |
+
empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive();
|
| 300 |
+
|
| 301 |
+
full_q_barriers[q_stage_idx]->wait(q_phase);
|
| 302 |
+
|
| 303 |
+
// Transpose and copy SF Q
|
| 304 |
+
#pragma unroll
|
| 305 |
+
for (uint32_t i = 0; i < kNumSFQAtom / kNumUTCCPAlignedElems; ++ i) {
|
| 306 |
+
auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems;
|
| 307 |
+
utccp_required_smem_warp_transpose(smem_ptr);
|
| 308 |
+
cutlass::arch::fence_view_async_shared();
|
| 309 |
+
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
| 310 |
+
if (cute::elect_one_sync())
|
| 311 |
+
cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4);
|
| 312 |
+
__syncwarp();
|
| 313 |
+
}
|
| 314 |
+
}
|
| 315 |
+
last_q_atom_idx = q_atom_idx;
|
| 316 |
+
|
| 317 |
+
// Wait TMA KV arrivals
|
| 318 |
+
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
|
| 319 |
+
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 320 |
+
|
| 321 |
+
// Transpose
|
| 322 |
+
#pragma unroll
|
| 323 |
+
for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
|
| 324 |
+
auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
|
| 325 |
+
utccp_required_smem_warp_transpose(smem_ptr);
|
| 326 |
+
cutlass::arch::fence_view_async_shared();
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
// UMMA with SF
|
| 330 |
+
if (cute::elect_one_sync()) {
|
| 331 |
+
// Copy SF KV
|
| 332 |
+
#pragma unroll
|
| 333 |
+
for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
|
| 334 |
+
auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
|
| 335 |
+
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
| 336 |
+
cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4);
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
#pragma unroll
|
| 340 |
+
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
| 341 |
+
// Wait TMEM release
|
| 342 |
+
CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase);
|
| 343 |
+
uint32_t tmem_addr = tmem_stage_idx * UMMA_N;
|
| 344 |
+
|
| 345 |
+
empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1);
|
| 346 |
+
ptx::tcgen05_after_thread_sync();
|
| 347 |
+
|
| 348 |
+
// Issue UMMA with SF
|
| 349 |
+
#pragma unroll
|
| 350 |
+
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
|
| 351 |
+
auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2);
|
| 352 |
+
// TODO: generalize UMMA desc
|
| 353 |
+
DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim");
|
| 354 |
+
auto a_desc = mma::sm100::make_smem_desc(
|
| 355 |
+
cute::UMMA::LayoutType::SWIZZLE_64B,
|
| 356 |
+
smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2,
|
| 357 |
+
8 * (kHeadDim / 2), 0);
|
| 358 |
+
auto b_desc = mma::sm100::make_smem_desc(
|
| 359 |
+
cute::UMMA::LayoutType::SWIZZLE_64B,
|
| 360 |
+
smem_q[q_stage_idx] + k * UMMA_K / 2,
|
| 361 |
+
8 * (kHeadDim / 2), 0);
|
| 362 |
+
ptx::SM100_MMA_MXF4_SS::fma(a_desc, b_desc, tmem_addr, k, runtime_instr_desc,
|
| 363 |
+
kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ);
|
| 364 |
+
}
|
| 365 |
+
// TODO: move this PTX into headers
|
| 366 |
+
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
|
| 367 |
+
::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx])));
|
| 368 |
+
}
|
| 369 |
+
}
|
| 370 |
+
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_kv_barriers[kv_stage_idx]));
|
| 371 |
+
}
|
| 372 |
+
} else if (warp_idx == kSpecWarpStart + 3) {
|
| 373 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 374 |
+
} else if (warp_idx < kSpecWarpStart) {
|
| 375 |
+
// Math warpgroups for reduce
|
| 376 |
+
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 377 |
+
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
| 378 |
+
|
| 379 |
+
const auto math_warpgroup_idx = warpgroup_idx;
|
| 380 |
+
const auto math_thread_idx = warp_idx * 32 + lane_idx;
|
| 381 |
+
|
| 382 |
+
// Helper lambda for loading tensor memory
|
| 383 |
+
// Loads N 32-bit accumulator values from tensor memory at `tmem_addr` into `accum`.
// `num_elems_c` is a compile-time integer constant (its `::value` is N); only N == 32 and N == 64 are accepted.
auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
|
| 384 |
+
constexpr int N = decltype(num_elems_c)::value;
|
| 385 |
+
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
|
| 386 |
+
// Pick the 32-element or 64-element loader to match N
using Loader = cute::conditional_t<N == 32,
|
| 387 |
+
cute::SM100_TMEM_LOAD_32dp32b32x,
|
| 388 |
+
cute::SM100_TMEM_LOAD_32dp32b64x>;
|
| 389 |
+
// Expand `accum[0..N-1]` (reinterpreted as 32-bit words) into N separate arguments of `Loader::copy`
[&]<size_t... Is>(cute::index_sequence<Is...>) {
|
| 390 |
+
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
|
| 391 |
+
}(cute::make_index_sequence<N>{});
|
| 392 |
+
// Fence issued after the load; presumably makes the loaded registers safe to consume (TODO confirm against CUTLASS docs)
cutlass::arch::fence_view_async_tmem_load();
|
| 393 |
+
};
|
| 394 |
+
|
| 395 |
+
// Math warpgroups process TMEM stages alternately
|
| 396 |
+
// Advance pipeline to align with the assigned stage
|
| 397 |
+
advance_tmem_pipeline(math_warpgroup_idx);
|
| 398 |
+
|
| 399 |
+
// Local register buffers
|
| 400 |
+
float accum[kNumHeads];
|
| 401 |
+
float weights[kNextNAtom][kNumHeads];
|
| 402 |
+
|
| 403 |
+
// Persistently schedule over blocks
|
| 404 |
+
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
|
| 405 |
+
uint32_t q_atom_idx, kv_idx, _;
|
| 406 |
+
bool is_paired_atom = false;
|
| 407 |
+
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) {
|
| 408 |
+
if (q_atom_idx != last_q_atom_idx) {
|
| 409 |
+
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
|
| 410 |
+
|
| 411 |
+
// Release last Q empty
|
| 412 |
+
if (last_q_atom_idx != batch_size * kNumNextNAtoms)
|
| 413 |
+
empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive();
|
| 414 |
+
|
| 415 |
+
// Wait TMA Q arrivals
|
| 416 |
+
full_q_barriers[q_stage_idx]->wait(q_phase);
|
| 417 |
+
|
| 418 |
+
// Read weights
|
| 419 |
+
#pragma unroll
|
| 420 |
+
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
|
| 421 |
+
#pragma unroll
|
| 422 |
+
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
| 423 |
+
float4 raw = ptx::ld_shared((float4*)(smem_weights[q_stage_idx] + i * kNumHeads + j));
|
| 424 |
+
weights[i][j + 0] = raw.x;
|
| 425 |
+
weights[i][j + 1] = raw.y;
|
| 426 |
+
weights[i][j + 2] = raw.z;
|
| 427 |
+
weights[i][j + 3] = raw.w;
|
| 428 |
+
}
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
// Check if this atom pairs two tokens from the same sequence
|
| 432 |
+
if constexpr (kIsVarlen) {
|
| 433 |
+
is_paired_atom = (scheduler.get_atom_advance(q_atom_idx, batch_size) == 2);
|
| 434 |
+
}
|
| 435 |
+
}
|
| 436 |
+
last_q_atom_idx = q_atom_idx;
|
| 437 |
+
|
| 438 |
+
// Calculate KV offset in advance
|
| 439 |
+
auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx;
|
| 440 |
+
|
| 441 |
+
// Advance pipeline by `kNumMathWarpGroups` steps
|
| 442 |
+
// Wait UMMA arrival
|
| 443 |
+
CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase);
|
| 444 |
+
full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase);
|
| 445 |
+
ptx::tcgen05_after_thread_sync();
|
| 446 |
+
|
| 447 |
+
// Reduce over the head dim and store
|
| 448 |
+
const auto reduce_and_store = [&](auto num_iters_c) {
|
| 449 |
+
constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
|
| 450 |
+
|
| 451 |
+
// Only loop over valid iterations
|
| 452 |
+
#pragma unroll
|
| 453 |
+
for (uint32_t i = 0; i < kNumIters; ++ i) {
|
| 454 |
+
// Load accumulator from TMEM
|
| 455 |
+
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
|
| 456 |
+
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
|
| 457 |
+
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
|
| 458 |
+
|
| 459 |
+
// Accumulate weighted ReLU in parallel
|
| 460 |
+
auto sum_0 = make_float2(0, 0);
|
| 461 |
+
auto sum_1 = make_float2(0, 0);
|
| 462 |
+
|
| 463 |
+
const auto transform = [&](const uint32_t& j, const float2& sum) {
|
| 464 |
+
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
| 465 |
+
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
| 466 |
+
return __ffma2_rn(a, b, sum);
|
| 467 |
+
};
|
| 468 |
+
|
| 469 |
+
#pragma unroll
|
| 470 |
+
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
| 471 |
+
sum_0 = transform(j, sum_0);
|
| 472 |
+
sum_1 = transform(j + 2, sum_1);
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
auto sum = __fadd2_rn(sum_0, sum_1);
|
| 476 |
+
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
|
| 477 |
+
|
| 478 |
+
// Store into the global memory
|
| 479 |
+
logits[kv_offset + i * static_cast<uint64_t>(logits_stride)] = result;
|
| 480 |
+
__syncwarp();
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
// Release TMEM empty
|
| 484 |
+
ptx::tcgen05_before_thread_sync();
|
| 485 |
+
empty_tmem_barriers[tmem_stage_idx]->arrive();
|
| 486 |
+
};
|
| 487 |
+
|
| 488 |
+
if constexpr (kIsVarlen) {
|
| 489 |
+
if (is_paired_atom)
|
| 490 |
+
reduce_and_store(cute::Int<kNextNAtom>{});
|
| 491 |
+
else
|
| 492 |
+
reduce_and_store(cute::Int<1>{});
|
| 493 |
+
} else if constexpr (kPadOddN) {
|
| 494 |
+
if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
|
| 495 |
+
reduce_and_store(cute::Int<1>{});
|
| 496 |
+
else
|
| 497 |
+
reduce_and_store(cute::Int<kNextNAtom>{});
|
| 498 |
+
} else {
|
| 499 |
+
reduce_and_store(cute::Int<kNextNAtom>{});
|
| 500 |
+
}
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
// Free tensor memory
|
| 504 |
+
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
|
| 505 |
+
if (warp_idx == 0)
|
| 506 |
+
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
| 507 |
+
}
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
} // namespace deep_gemm
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#pragma clang diagnostic push
|
| 3 |
+
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
| 4 |
+
|
| 5 |
+
#include <cutlass/arch/barrier.h>
|
| 6 |
+
|
| 7 |
+
#include <deep_gemm/common/math.cuh>
|
| 8 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 9 |
+
#include <deep_gemm/epilogue/transform.cuh>
|
| 10 |
+
#include <deep_gemm/epilogue/sm100_store_cd.cuh>
|
| 11 |
+
#include <deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh>
|
| 12 |
+
#include <deep_gemm/mma/sm100.cuh>
|
| 13 |
+
#include <deep_gemm/scheduler/gemm.cuh>
|
| 14 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 15 |
+
|
| 16 |
+
namespace deep_gemm {
|
| 17 |
+
|
| 18 |
+
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
| 19 |
+
uint32_t kGranKA, uint32_t kGranKB,
|
| 20 |
+
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 21 |
+
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 22 |
+
uint32_t kNumGroups,
|
| 23 |
+
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
|
| 24 |
+
uint32_t kNumStages,
|
| 25 |
+
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
| 26 |
+
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
| 27 |
+
uint32_t kNumSMs,
|
| 28 |
+
bool kSwapAB,
|
| 29 |
+
GemmType kGemmType, bool kWithAccumulation,
|
| 30 |
+
typename a_dtype_t, typename b_dtype_t, typename cd_dtype_t,
|
| 31 |
+
typename epilogue_type_t>
|
| 32 |
+
CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
| 33 |
+
sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
|
| 34 |
+
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
| 35 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 36 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
| 37 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
|
| 38 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
|
| 39 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
|
| 40 |
+
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
| 41 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 42 |
+
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
|
| 43 |
+
|
| 44 |
+
// GEMM with accumulation must have FP32 output
|
| 45 |
+
if constexpr (kWithAccumulation)
|
| 46 |
+
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
|
| 47 |
+
|
| 48 |
+
// MMA Configs
|
| 49 |
+
constexpr uint32_t LAYOUT_AD_M = 128;
|
| 50 |
+
constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast;
|
| 51 |
+
constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N;
|
| 52 |
+
constexpr uint32_t UMMA_K = 32;
|
| 53 |
+
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
|
| 54 |
+
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
|
| 55 |
+
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
|
| 56 |
+
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
|
| 57 |
+
DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or
|
| 58 |
+
(not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size");
|
| 59 |
+
|
| 60 |
+
// SF configs
|
| 61 |
+
constexpr uint32_t kNumUTCCPAlignedElems = 128;
|
| 62 |
+
constexpr uint32_t SF_BLOCK_M = math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems);
|
| 63 |
+
constexpr uint32_t SF_BLOCK_N = math::constexpr_align(BLOCK_N, kNumUTCCPAlignedElems);
|
| 64 |
+
constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4;
|
| 65 |
+
constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4;
|
| 66 |
+
DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A");
|
| 67 |
+
DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B");
|
| 68 |
+
DG_STATIC_ASSERT((kGemmType != GemmType::KGroupedContiguous) or kGranKA == kGranKB, "K-grouped SF requires kGranKA == kGranKB");
|
| 69 |
+
|
| 70 |
+
// Epilogue configs
|
| 71 |
+
// Always enable pipeline for better performance
|
| 72 |
+
constexpr uint32_t kNumEpilogueStages = 2;
|
| 73 |
+
constexpr uint32_t kNumTMAStoreStages = 2;
|
| 74 |
+
// NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N
|
| 75 |
+
// per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases
|
| 76 |
+
constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
| 77 |
+
constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t);
|
| 78 |
+
constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M;
|
| 79 |
+
DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
|
| 80 |
+
|
| 81 |
+
// Share memory sizes
|
| 82 |
+
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t);
|
| 83 |
+
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
|
| 84 |
+
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
|
| 85 |
+
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
|
| 86 |
+
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
|
| 87 |
+
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
|
| 88 |
+
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
|
| 89 |
+
"Shared memory of A/B must be aligned to 1024 bytes");
|
| 90 |
+
// NOTES: Make sure we have enough shared memory for UMMA padding
|
| 91 |
+
constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t);
|
| 92 |
+
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
|
| 93 |
+
|
| 94 |
+
// Tensor memory size and offsets
|
| 95 |
+
constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
|
| 96 |
+
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
|
| 97 |
+
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
|
| 98 |
+
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
|
| 99 |
+
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
|
| 100 |
+
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
|
| 101 |
+
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
| 102 |
+
|
| 103 |
+
// Synchronize the cluster before 2-CTA TMEM allocation
|
| 104 |
+
kNumMulticast > 1 ? cute::cluster_sync() : void();
|
| 105 |
+
|
| 106 |
+
// Utils
|
| 107 |
+
const bool is_leader_cta = cute::block_rank_in_cluster() == 0;
|
| 108 |
+
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 109 |
+
const auto lane_idx = ptx::get_lane_idx();
|
| 110 |
+
|
| 111 |
+
// Prefetch TMA descriptors at the very beginning
|
| 112 |
+
if (warp_idx == 0) {
|
| 113 |
+
cute::prefetch_tma_descriptor(&tensor_map_a);
|
| 114 |
+
cute::prefetch_tma_descriptor(&tensor_map_b);
|
| 115 |
+
cute::prefetch_tma_descriptor(&tensor_map_sfa);
|
| 116 |
+
cute::prefetch_tma_descriptor(&tensor_map_sfb);
|
| 117 |
+
cute::prefetch_tma_descriptor(&tensor_map_cd);
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
// Overwrite shape constants if the compiler gives
|
| 121 |
+
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
| 122 |
+
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
| 123 |
+
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
| 124 |
+
const auto shape_sfa_k = math::ceil_div(shape_k, kGranKA * 4);
|
| 125 |
+
const auto shape_sfb_k = math::ceil_div(shape_k, kGranKB * 4);
|
| 126 |
+
|
| 127 |
+
// Align to 1024 bytes for swizzle-128B
|
| 128 |
+
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
| 129 |
+
|
| 130 |
+
// D/A/B shared memory
|
| 131 |
+
auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
|
| 132 |
+
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
| 133 |
+
});
|
| 134 |
+
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
|
| 135 |
+
return reinterpret_cast<a_dtype_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
| 136 |
+
});
|
| 137 |
+
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
|
| 138 |
+
return reinterpret_cast<b_dtype_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
| 139 |
+
});
|
| 140 |
+
|
| 141 |
+
// SFA/SFB shared memory
|
| 142 |
+
auto sf_start_ptr = reinterpret_cast<uint8_t*>(smem_b[kNumStages]);
|
| 143 |
+
auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) {
|
| 144 |
+
return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
|
| 145 |
+
});
|
| 146 |
+
auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) {
|
| 147 |
+
return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
|
| 148 |
+
});
|
| 149 |
+
|
| 150 |
+
// Barriers and tensor memory pointer
|
| 151 |
+
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_sfb[kNumStages]);;
|
| 152 |
+
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 153 |
+
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 154 |
+
auto with_sf_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
| 155 |
+
auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
|
| 156 |
+
auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); });
|
| 157 |
+
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
|
| 158 |
+
|
| 159 |
+
// Initialize barriers
|
| 160 |
+
if (warp_idx == 1 and cute::elect_one_sync()) {
|
| 161 |
+
#pragma unroll
|
| 162 |
+
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
| 163 |
+
// Arrive at all CTAs
|
| 164 |
+
full_barriers[i]->init(1);
|
| 165 |
+
empty_barriers[i]->init(1);
|
| 166 |
+
// Arrive only at the leader CTA
|
| 167 |
+
with_sf_full_barriers[i]->init(kNumMulticast * 32);
|
| 168 |
+
}
|
| 169 |
+
#pragma unroll
|
| 170 |
+
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
|
| 171 |
+
// Arrive at all CTAs
|
| 172 |
+
tmem_full_barriers[i]->init(1);
|
| 173 |
+
// Arrive only at the leader CTA
|
| 174 |
+
tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads);
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
// Make initialized barrier visible in async proxy
|
| 178 |
+
cutlass::arch::fence_barrier_init();
|
| 179 |
+
} else if (warp_idx == 2) {
|
| 180 |
+
// Allocate tensor memory
|
| 181 |
+
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
| 182 |
+
}
|
| 183 |
+
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
| 184 |
+
|
| 185 |
+
// Wait for primary kernel completion
|
| 186 |
+
cudaGridDependencySynchronize();
|
| 187 |
+
|
| 188 |
+
// Block scheduler
|
| 189 |
+
uint32_t m_block_idx, n_block_idx;
|
| 190 |
+
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs, kGranKA * 4>(
|
| 191 |
+
shape_m, shape_n, shape_k, grouped_layout);
|
| 192 |
+
|
| 193 |
+
// Pipeline and TMA phases
|
| 194 |
+
uint32_t stage_idx = 0, phase = 0;
|
| 195 |
+
// Steps the pipeline state by one K block: increments the caller's `k_block_idx`
// and updates the captured `stage_idx` / `phase` in place.
auto advance_pipeline = [&](uint32_t& k_block_idx) {
|
| 196 |
+
++ k_block_idx;
|
| 197 |
+
|
| 198 |
+
// Move to the next stage cyclically (wrapping after `kNumStages - 1`),
// and flip the phase bit only when the stage index wraps back to 0
|
| 199 |
+
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
|
| 200 |
+
phase ^= stage_idx == 0;
|
| 201 |
+
};
|
| 202 |
+
|
| 203 |
+
// Dispatch warps into different roles
|
| 204 |
+
if (warp_idx == 0 and cute::elect_one_sync()) {
|
| 205 |
+
// TMA load warp
|
| 206 |
+
// Persistently schedule over blocks
|
| 207 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 208 |
+
// Use dynamic load block M, when swap-AB is enabled
|
| 209 |
+
const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M;
|
| 210 |
+
|
| 211 |
+
// For k-grouped layout, the number of block K is variable
|
| 212 |
+
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 213 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 214 |
+
// Wait consumer release
|
| 215 |
+
empty_barriers[stage_idx]->wait(phase ^ 1);
|
| 216 |
+
|
| 217 |
+
// Compute offsets
|
| 218 |
+
// NOTES: the group is always concatenated with the outer dimension
|
| 219 |
+
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> (
|
| 220 |
+
shape_m, BLOCK_M, m_block_idx);
|
| 221 |
+
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> (
|
| 222 |
+
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
| 223 |
+
|
| 224 |
+
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
| 225 |
+
// And for all m-grouped GEMMs, A must be K-major
|
| 226 |
+
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
|
| 227 |
+
kMajorA == cute::UMMA::Major::K, "Invalid major");
|
| 228 |
+
uint32_t k_idx = k_block_idx * BLOCK_K;
|
| 229 |
+
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
|
| 230 |
+
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 231 |
+
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
|
| 232 |
+
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 233 |
+
|
| 234 |
+
// Add 2 CTA offsets
|
| 235 |
+
if constexpr (kNumMulticast > 1) {
|
| 236 |
+
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0;
|
| 237 |
+
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
// Issue TMAs
|
| 241 |
+
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
|
| 242 |
+
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
|
| 243 |
+
if constexpr (kMajorA == cute::UMMA::Major::K)
|
| 244 |
+
tma::copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
|
| 245 |
+
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx);
|
| 246 |
+
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
| 247 |
+
tma::copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
|
| 248 |
+
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx);
|
| 249 |
+
if constexpr (kMajorB == cute::UMMA::Major::K)
|
| 250 |
+
tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
|
| 251 |
+
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx);
|
| 252 |
+
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
| 253 |
+
tma::copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
|
| 254 |
+
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx);
|
| 255 |
+
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v<a_dtype_t, cutlass::float_e4m3_t> ? 1 : 2) +
|
| 256 |
+
SMEM_B_SIZE_PER_STAGE / (std::is_same_v<b_dtype_t, cutlass::float_e4m3_t> ? 1 : 2);
|
| 257 |
+
|
| 258 |
+
// Issue SFA and SFB TMAs at certain stages
|
| 259 |
+
// No swizzling, so one TMA for one SF is enough
|
| 260 |
+
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
|
| 261 |
+
uint32_t sfa_m_idx = m_block_idx * BLOCK_M;
|
| 262 |
+
uint32_t sfa_k_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::SF_K>(
|
| 263 |
+
shape_sfa_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad));
|
| 264 |
+
tma::copy<BLOCK_M, 1, 0>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx);
|
| 265 |
+
num_arrival_bytes += BLOCK_M * sizeof(uint32_t);
|
| 266 |
+
}
|
| 267 |
+
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
|
| 268 |
+
uint32_t sfb_n_idx = n_block_idx * BLOCK_N;
|
| 269 |
+
uint32_t sfb_k_idx = scheduler.template get_global_idx<true, sched::IndexType::SF_K>(
|
| 270 |
+
shape_sfb_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx);
|
| 271 |
+
tma::copy<BLOCK_N, 1, 0>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx);
|
| 272 |
+
num_arrival_bytes += BLOCK_N * sizeof(uint32_t);
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
// Arrive at full barriers
|
| 276 |
+
full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes);
|
| 277 |
+
}
|
| 278 |
+
}
|
| 279 |
+
} else if (warp_idx == 1 and is_leader_cta) {
|
| 280 |
+
// MMA issue warp
|
| 281 |
+
// NOTES: only the leader CTA will do this
|
| 282 |
+
// Make instruction descriptor
|
| 283 |
+
auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc_block_scaled<b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t,
|
| 284 |
+
UMMA_M, UMMA_N, kMajorB, kMajorA>()
|
| 285 |
+
: cute::UMMA::make_instr_desc_block_scaled<a_dtype_t, b_dtype_t, float, cutlass::float_ue8m0_t,
|
| 286 |
+
UMMA_M, UMMA_N, kMajorA, kMajorB>();
|
| 287 |
+
auto sf_desc = mma::sm100::make_sf_desc(nullptr);
|
| 288 |
+
|
| 289 |
+
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
| 290 |
+
auto a_desc = mma::sm100::make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
|
| 291 |
+
auto b_desc = mma::sm100::make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
| 292 |
+
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
| 293 |
+
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
| 294 |
+
|
| 295 |
+
// Checks for MMA instructions
|
| 296 |
+
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
|
| 297 |
+
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
| 298 |
+
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
|
| 299 |
+
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
|
| 300 |
+
"Invalid MMA instruction shape");
|
| 301 |
+
|
| 302 |
+
// Persistently schedule over blocks
|
| 303 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 304 |
+
// Wait tensor memory empty barrier arrival
|
| 305 |
+
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
|
| 306 |
+
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
| 307 |
+
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
|
| 308 |
+
ptx::tcgen05_after_thread_sync();
|
| 309 |
+
|
| 310 |
+
// Empty barrier arrival
|
| 311 |
+
// Empty barrier arrival
// Signals `empty_barriers[stage_idx]` for the current stage and, when
// `do_tmem_full_arrive` is true, also `tmem_full_barriers[accum_stage_idx]`.
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
|
| 312 |
+
// Helper: use the plain arrive for a single CTA, otherwise the 2x1SM multicast arrive
auto umma_arrive = [](const uint64_t* barrier) {
|
| 313 |
+
if constexpr (kNumMulticast == 1) {
|
| 314 |
+
cutlass::arch::umma_arrive(barrier);
|
| 315 |
+
} else {
|
| 316 |
+
// One mask bit per CTA in the multicast group (low `kNumMulticast` bits set)
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
|
| 317 |
+
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
|
| 318 |
+
}
|
| 319 |
+
};
|
| 320 |
+
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
|
| 321 |
+
|
| 322 |
+
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
| 323 |
+
if (do_tmem_full_arrive)
|
| 324 |
+
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
| 325 |
+
// Reconverge the warp after the arrivals
__syncwarp();
|
| 326 |
+
};
|
| 327 |
+
|
| 328 |
+
// Dynamic update of UMMA N based on effective M, when swap-AB is enabled
|
| 329 |
+
if constexpr (kSwapAB) {
|
| 330 |
+
uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx);
|
| 331 |
+
mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n);
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
// Launch MMAs
|
| 335 |
+
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 336 |
+
#pragma unroll 4
|
| 337 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 338 |
+
// Wait TMA and SF-transpose arrival
|
| 339 |
+
with_sf_full_barriers[stage_idx]->wait(phase);
|
| 340 |
+
ptx::tcgen05_after_thread_sync();
|
| 341 |
+
|
| 342 |
+
const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx);
|
| 343 |
+
const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx);
|
| 344 |
+
if (cute::elect_one_sync()) {
|
| 345 |
+
// Do SF copy at certain stages
|
| 346 |
+
// TODO: process shared memory descriptor by addition
|
| 347 |
+
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
|
| 348 |
+
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
|
| 349 |
+
const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad;
|
| 350 |
+
if (sfa_stage_in_group_idx == 0) {
|
| 351 |
+
#pragma unroll
|
| 352 |
+
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
|
| 353 |
+
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
|
| 354 |
+
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
| 355 |
+
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
|
| 356 |
+
}
|
| 357 |
+
}
|
| 358 |
+
const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad;
|
| 359 |
+
if (sfb_stage_in_group_idx == 0) {
|
| 360 |
+
#pragma unroll
|
| 361 |
+
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
|
| 362 |
+
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
|
| 363 |
+
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
| 364 |
+
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
|
| 365 |
+
}
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
// Issue UMMA
|
| 369 |
+
using mma_t = cute::conditional_t<
|
| 370 |
+
kNumMulticast == 1, ptx::SM100_MMA_MXF8F6F4_SS, ptx::SM100_MMA_MXF8F6F4_2x1SM_SS>;
|
| 371 |
+
#pragma unroll
|
| 372 |
+
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
| 373 |
+
const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx);
|
| 374 |
+
const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx);
|
| 375 |
+
const auto runtime_instr_desc = kSwapAB ?
|
| 376 |
+
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfb_id, sfa_id):
|
| 377 |
+
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id);
|
| 378 |
+
|
| 379 |
+
a_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K);
|
| 380 |
+
b_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
|
| 381 |
+
if constexpr (kSwapAB) {
|
| 382 |
+
mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N,
|
| 383 |
+
k_block_idx > 0 or k > 0, runtime_instr_desc,
|
| 384 |
+
kTmemStartColOfSFB, kTmemStartColOfSFA);
|
| 385 |
+
} else {
|
| 386 |
+
mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N,
|
| 387 |
+
k_block_idx > 0 or k > 0, runtime_instr_desc,
|
| 388 |
+
kTmemStartColOfSFA, kTmemStartColOfSFB);
|
| 389 |
+
}
|
| 390 |
+
}
|
| 391 |
+
}
|
| 392 |
+
__syncwarp();
|
| 393 |
+
|
| 394 |
+
// Commit to the mbarrier object
|
| 395 |
+
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
| 396 |
+
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
|
| 397 |
+
}
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
// To safely deconstruct barriers, we need another round of waits
|
| 401 |
+
const auto iter_idx = scheduler.current_iter - 1;
|
| 402 |
+
if (kNumMulticast > 1 and iter_idx >= 0) {
|
| 403 |
+
const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
|
| 404 |
+
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
|
| 405 |
+
}
|
| 406 |
+
} else if (warp_idx == 2) {
|
| 407 |
+
// UTCCP transposer
|
| 408 |
+
// In-place transpose of one 128-word group pointed to by `smem_ptr`, done cooperatively by a warp.
// Viewing the group as 4 rows x 32 columns of 32-bit words, each lane reads the element at
// (row, lane_idx) for all 4 rows and writes it back at offset `lane_idx * 4 + row`,
// i.e. the transposed 32 x 4 position.
// `lane_idx` is captured by reference from the enclosing scope; its definition is outside this
// excerpt (presumably the calling thread's lane index within the warp — confirm).
// NOTE(review): `smem_ptr` is const-qualified although the buffer is written through `ptx::st_shared`.
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
|
| 409 |
+
// The index math below hard-codes 4 rows of 32 words, so the group size must be exactly 128
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
|
| 410 |
+
// Per-lane staging storage: one word per row
uint32_t values[4];
|
| 411 |
+
#pragma unroll
|
| 412 |
+
for (uint32_t i = 0; i < 4; ++ i)
|
| 413 |
+
// Row visited at step `i` is `i ^ (lane_idx >> 3)`: a per-lane permutation of rows 0..3.
// The store loop uses the same permutation, so `values[i]` still lands in its transposed slot.
// NOTE(review): the permutation looks intended to spread the stride-4 stores over distinct
// shared-memory banks — confirm.
values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
|
| 414 |
+
// The transpose is in place: every lane must finish loading before any lane overwrites the buffer
__syncwarp();
|
| 415 |
+
#pragma unroll
|
| 416 |
+
for (uint32_t i = 0; i < 4; ++ i)
|
| 417 |
+
// Write the word loaded from (row, lane_idx) to offset `lane_idx * 4 + row`
ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
|
| 418 |
+
};
|
| 419 |
+
|
| 420 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 421 |
+
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 422 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 423 |
+
// Wait TMA arrival
|
| 424 |
+
full_barriers[stage_idx]->wait(phase);
|
| 425 |
+
|
| 426 |
+
// Transpose for UTCCP at certain stages
|
| 427 |
+
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
|
| 428 |
+
#pragma unroll
|
| 429 |
+
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
|
| 430 |
+
utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
|
| 431 |
+
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
|
| 432 |
+
cutlass::arch::fence_view_async_shared();
|
| 433 |
+
}
|
| 434 |
+
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
|
| 435 |
+
#pragma unroll
|
| 436 |
+
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
|
| 437 |
+
utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
|
| 438 |
+
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
|
| 439 |
+
cutlass::arch::fence_view_async_shared();
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
// Arrive
|
| 443 |
+
with_sf_full_barriers[stage_idx]->arrive(0u);
|
| 444 |
+
}
|
| 445 |
+
}
|
| 446 |
+
} else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
|
| 447 |
+
// Epilogue warp groups
|
| 448 |
+
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
|
| 449 |
+
|
| 450 |
+
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
| 451 |
+
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
|
| 452 |
+
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
| 453 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
|
| 454 |
+
|
| 455 |
+
// Share store pipeline between blocks
|
| 456 |
+
uint32_t tma_stage_idx = 0;
|
| 457 |
+
|
| 458 |
+
// Persistently schedule over blocks
|
| 459 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 460 |
+
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
|
| 461 |
+
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
| 462 |
+
|
| 463 |
+
// Wait UMMA arrival
|
| 464 |
+
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
|
| 465 |
+
ptx::tcgen05_after_thread_sync();
|
| 466 |
+
|
| 467 |
+
const auto tmem_base_addr = accum_stage_idx * UMMA_N;
|
| 468 |
+
const auto base_m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
|
| 469 |
+
const auto base_n_idx = n_block_idx * BLOCK_N;
|
| 470 |
+
|
| 471 |
+
if constexpr (kSwapAB) {
|
| 472 |
+
const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx);
|
| 473 |
+
epilogue::sm100_store_cd_swap_ab<
|
| 474 |
+
BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
|
| 475 |
+
kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
|
| 476 |
+
kGemmType, kWithAccumulation,
|
| 477 |
+
cd_dtype_t, epilogue_type_t>
|
| 478 |
+
(smem_cd, tma_stage_idx, tmem_base_addr,
|
| 479 |
+
base_m_idx, base_n_idx, scheduler.current_group_idx,
|
| 480 |
+
effective_m,
|
| 481 |
+
epilogue_warp_idx, lane_idx,
|
| 482 |
+
tmem_empty_barriers[accum_stage_idx],
|
| 483 |
+
tensor_map_cd);
|
| 484 |
+
} else {
|
| 485 |
+
epilogue::sm100_store_cd<
|
| 486 |
+
BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
|
| 487 |
+
kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
|
| 488 |
+
kGemmType, kWithAccumulation,
|
| 489 |
+
cd_dtype_t, epilogue_type_t>
|
| 490 |
+
(smem_cd, tma_stage_idx, tmem_base_addr,
|
| 491 |
+
base_m_idx, base_n_idx, scheduler.current_group_idx,
|
| 492 |
+
epilogue_warp_idx, lane_idx,
|
| 493 |
+
tmem_empty_barriers[accum_stage_idx],
|
| 494 |
+
tensor_map_cd);
|
| 495 |
+
}
|
| 496 |
+
}
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
// TODO: Remove redundant synchronization
|
| 500 |
+
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
| 501 |
+
|
| 502 |
+
// Deallocate tensor memory
|
| 503 |
+
if (warp_idx == 0)
|
| 504 |
+
Allocator().free(0, kNumTmemCols);
|
| 505 |
+
|
| 506 |
+
#else
|
| 507 |
+
if (blockIdx.x == 0 and threadIdx.x == 0)
|
| 508 |
+
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
|
| 509 |
+
#endif
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
}; // namespace deep_gemm
|
| 513 |
+
|
| 514 |
+
#pragma clang diagnostic pop
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh
ADDED
|
@@ -0,0 +1,1380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cstdint>
|
| 4 |
+
#include <cutlass/arch/barrier.h>
|
| 5 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 6 |
+
|
| 7 |
+
#include <deep_gemm/common/math.cuh>
|
| 8 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 9 |
+
#include <deep_gemm/common/utils.cuh>
|
| 10 |
+
#include <deep_gemm/comm/barrier.cuh>
|
| 11 |
+
#include <deep_gemm/layout/sym_buffer.cuh>
|
| 12 |
+
#include <deep_gemm/layout/mega_moe.cuh>
|
| 13 |
+
#include <deep_gemm/mma/sm100.cuh>
|
| 14 |
+
#include <deep_gemm/scheduler/mega_moe.cuh>
|
| 15 |
+
#include <deep_gemm/ptx/tcgen05.cuh>
|
| 16 |
+
#include <deep_gemm/ptx/tma.cuh>
|
| 17 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 18 |
+
|
| 19 |
+
namespace deep_gemm {
|
| 20 |
+
|
| 21 |
+
template <
|
| 22 |
+
uint32_t kNumMaxTokensPerRank,
|
| 23 |
+
uint32_t kHidden, uint32_t kIntermediateHidden,
|
| 24 |
+
uint32_t kNumExperts, uint32_t kNumTopk,
|
| 25 |
+
uint32_t kNumExpertsPerWave,
|
| 26 |
+
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 27 |
+
uint32_t STORE_BLOCK_M,
|
| 28 |
+
uint32_t SF_BLOCK_M, uint32_t SF_BLOCK_N,
|
| 29 |
+
uint32_t kNumMaxPoolTokens,
|
| 30 |
+
uint32_t kNumPaddedSFPoolTokens,
|
| 31 |
+
uint32_t kNumStages,
|
| 32 |
+
uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads,
|
| 33 |
+
uint32_t kNumEpilogueThreads,
|
| 34 |
+
uint32_t kNumSMs, uint32_t kNumRanks,
|
| 35 |
+
float kActivationClamp,
|
| 36 |
+
bool kFastMath,
|
| 37 |
+
uint32_t L1_SHAPE_N = kIntermediateHidden * 2,
|
| 38 |
+
uint32_t L1_SHAPE_K = kHidden,
|
| 39 |
+
uint32_t L2_SHAPE_N = kHidden,
|
| 40 |
+
uint32_t L2_SHAPE_K = kIntermediateHidden,
|
| 41 |
+
uint32_t kNumDispatchWarps = kNumDispatchThreads / 32,
|
| 42 |
+
uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32,
|
| 43 |
+
uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32,
|
| 44 |
+
uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4,
|
| 45 |
+
uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads,
|
| 46 |
+
uint32_t kNumTokensPerWarp = 32 / kNumTopk,
|
| 47 |
+
uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks
|
| 48 |
+
>
|
| 49 |
+
CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void
|
| 50 |
+
sm100_fp8_fp4_mega_moe_impl(void* y,
|
| 51 |
+
int* cumulative_local_expert_recv_stats,
|
| 52 |
+
const uint32_t num_tokens,
|
| 53 |
+
const __grid_constant__ layout::SymBuffer<kNumRanks> sym_buffer,
|
| 54 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts,
|
| 55 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf,
|
| 56 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights,
|
| 57 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights_sf,
|
| 58 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output,
|
| 59 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts,
|
| 60 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf,
|
| 61 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights,
|
| 62 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights_sf) {
|
| 63 |
+
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
| 64 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 65 |
+
using Allocator = cute::TMEM::Allocator2Sm;
|
| 66 |
+
|
| 67 |
+
// Template checks
|
| 68 |
+
DG_STATIC_ASSERT(kNumDispatchThreads % 128 == 0, "Invalid number of dispatch threads");
|
| 69 |
+
DG_STATIC_ASSERT(kNumNonEpilogueThreads == 128, "Invalid number of MMA non-epilogue threads");
|
| 70 |
+
DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of MMA epilogue and combine threads");
|
| 71 |
+
DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks");
|
| 72 |
+
|
| 73 |
+
// Thread indices
|
| 74 |
+
const bool is_leader_cta = cute::block_rank_in_cluster() == 0;
|
| 75 |
+
const uint32_t sm_idx = blockIdx.x;
|
| 76 |
+
const uint32_t thread_idx = threadIdx.x;
|
| 77 |
+
const uint32_t warp_idx = cutlass::canonical_warp_idx_sync();
|
| 78 |
+
const uint32_t lane_idx = ptx::get_lane_idx();
|
| 79 |
+
|
| 80 |
+
// Prefetch TMA descriptors at the very beginning
|
| 81 |
+
if (warp_idx == 0) {
|
| 82 |
+
cute::prefetch_tma_descriptor(&tensor_map_l1_acts);
|
| 83 |
+
cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf);
|
| 84 |
+
cute::prefetch_tma_descriptor(&tensor_map_l1_weights);
|
| 85 |
+
cute::prefetch_tma_descriptor(&tensor_map_l1_weights_sf);
|
| 86 |
+
cute::prefetch_tma_descriptor(&tensor_map_l1_output);
|
| 87 |
+
cute::prefetch_tma_descriptor(&tensor_map_l2_acts);
|
| 88 |
+
cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf);
|
| 89 |
+
cute::prefetch_tma_descriptor(&tensor_map_l2_weights);
|
| 90 |
+
cute::prefetch_tma_descriptor(&tensor_map_l2_weights_sf);
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// Workspaces
|
| 94 |
+
const auto workspace = layout::Workspace(
|
| 95 |
+
sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk);
|
| 96 |
+
|
| 97 |
+
// Token and buffer layouts
|
| 98 |
+
constexpr auto fp8_token_layout = layout::Data(kHidden);
|
| 99 |
+
constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16));
|
| 100 |
+
constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden);
|
| 101 |
+
constexpr auto fp8_sf_layout = layout::Data(kHidden / 32);
|
| 102 |
+
constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32);
|
| 103 |
+
constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false);
|
| 104 |
+
constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false);
|
| 105 |
+
constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
|
| 106 |
+
|
| 107 |
+
// Registered inputs
|
| 108 |
+
const auto input_token_buffer = layout::Buffer(
|
| 109 |
+
fp8_token_layout, 1, kNumMaxTokensPerRank,
|
| 110 |
+
workspace.get_end_ptr());
|
| 111 |
+
const auto input_sf_buffer = layout::Buffer(
|
| 112 |
+
fp8_sf_layout, 1, kNumMaxTokensPerRank,
|
| 113 |
+
input_token_buffer.get_end_ptr());
|
| 114 |
+
const auto input_topk_idx_buffer = layout::Buffer(
|
| 115 |
+
input_topk_idx_layout, 1, kNumMaxTokensPerRank,
|
| 116 |
+
input_sf_buffer.get_end_ptr());
|
| 117 |
+
const auto input_topk_weights_buffer = layout::Buffer(
|
| 118 |
+
input_topk_weights_layout, 1, kNumMaxTokensPerRank,
|
| 119 |
+
input_topk_idx_buffer.get_end_ptr());
|
| 120 |
+
|
| 121 |
+
// SF and its buffer configs
|
| 122 |
+
constexpr uint32_t kGranK = 32;
|
| 123 |
+
constexpr uint32_t kNumUTCCPAlignedElems = 128;
|
| 124 |
+
DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M");
|
| 125 |
+
DG_STATIC_ASSERT(SF_BLOCK_N == BLOCK_N, "No padding is needed for SFB");
|
| 126 |
+
|
| 127 |
+
// UTCCP 4x32 transpose index mapping within each 128-element group
|
| 128 |
+
// Maps a token index within an expert to its slot index in the scale-factor (SF) layout.
//  - `token_idx_in_expert / BLOCK_M` selects the M block; consecutive blocks are `SF_BLOCK_M` slots apart
//    (`SF_BLOCK_M` is `BLOCK_M` aligned up to 128, as asserted earlier in this kernel).
//  - Inside a block, the 128-aligned group base is kept and the position within the group is
//    remapped from `row * 32 + col` to `col * 4 + row`, where `col = idx & 31` and `row = (idx >> 5) & 3`.
const auto transform_sf_token_idx = [](const uint32_t& token_idx_in_expert) {
|
| 129 |
+
// Token offset inside its M block
const uint32_t idx = token_idx_in_expert % BLOCK_M;
|
| 130 |
+
// Block base (padded stride) ...
return token_idx_in_expert / BLOCK_M * SF_BLOCK_M +
|
| 131 |
+
// ... plus 128-group base, plus the transposed (col * 4 + row) position inside the group
(idx & ~127u) + (idx & 31u) * 4 + ((idx >> 5) & 3u);
|
| 132 |
+
};
|
| 133 |
+
|
| 134 |
+
// L1 inputs
|
| 135 |
+
const auto l1_token_buffer = layout::Buffer(
|
| 136 |
+
fp8_token_layout, 1, kNumMaxPoolTokens,
|
| 137 |
+
input_topk_weights_buffer.get_end_ptr());
|
| 138 |
+
const auto l1_sf_buffer = layout::Buffer(
|
| 139 |
+
fp8_sf_layout, 1, kNumPaddedSFPoolTokens,
|
| 140 |
+
l1_token_buffer.get_end_ptr());
|
| 141 |
+
const auto l1_topk_weights_buffer = layout::Buffer(
|
| 142 |
+
l1_topk_weights_layout, 1, kNumMaxPoolTokens,
|
| 143 |
+
l1_sf_buffer.get_end_ptr());
|
| 144 |
+
|
| 145 |
+
// L2 inputs
|
| 146 |
+
const auto l2_token_buffer = layout::Buffer(
|
| 147 |
+
fp8_intermediate_token_layout, 1, kNumMaxPoolTokens,
|
| 148 |
+
l1_topk_weights_buffer.get_end_ptr()
|
| 149 |
+
);
|
| 150 |
+
const auto l2_sf_buffer = layout::Buffer(
|
| 151 |
+
fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens,
|
| 152 |
+
l2_token_buffer.get_end_ptr()
|
| 153 |
+
);
|
| 154 |
+
|
| 155 |
+
// Combine inputs
|
| 156 |
+
const auto combine_token_buffer = layout::Buffer(
|
| 157 |
+
bf16_token_layout, kNumTopk, kNumMaxTokensPerRank,
|
| 158 |
+
l2_sf_buffer.get_end_ptr()
|
| 159 |
+
);
|
| 160 |
+
|
| 161 |
+
// Data types
|
| 162 |
+
// NOTES: activations are FP8 (e4m3), weights are FP4 (e2m1)
|
| 163 |
+
using a_dtype_t = cutlass::float_e4m3_t;
|
| 164 |
+
using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t;
|
| 165 |
+
|
| 166 |
+
// MMA configs
|
| 167 |
+
// NOTES: always swap A/B, 2-CTA MMA, and matrices are K-major
|
| 168 |
+
constexpr uint32_t LAYOUT_AD_M = 128;
|
| 169 |
+
constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2;
|
| 170 |
+
constexpr uint32_t UMMA_N = BLOCK_M; // Swap AB
|
| 171 |
+
constexpr uint32_t UMMA_K = 32;
|
| 172 |
+
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A
|
| 173 |
+
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N;
|
| 174 |
+
DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M");
|
| 175 |
+
DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N");
|
| 176 |
+
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
|
| 177 |
+
|
| 178 |
+
// Swizzle configs
|
| 179 |
+
constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t);
|
| 180 |
+
constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t);
|
| 181 |
+
constexpr uint32_t kSwizzleCDMode = 128;
|
| 182 |
+
DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N");
|
| 183 |
+
|
| 184 |
+
// Epilogue configs
|
| 185 |
+
constexpr uint32_t kNumEpilogueStages = 2;
|
| 186 |
+
constexpr uint32_t kNumTMAStoreStages = 2;
|
| 187 |
+
|
| 188 |
+
// Shared memory
|
| 189 |
+
constexpr uint32_t kSharedMemoryAlignment = 1024;
|
| 190 |
+
extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[];
|
| 191 |
+
|
| 192 |
+
// Shared memory sizes
|
| 193 |
+
// NOTES: FP8 CD output for L1 (2 TMA stages, BLOCK_N/2 post-SwiGLU), BF16 output for L2 (no TMA, a single stage)
|
| 194 |
+
constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2;
|
| 195 |
+
constexpr uint32_t SMEM_EXPERT_COUNT_SIZE =
|
| 196 |
+
math::constexpr_align<uint32_t>(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment);
|
| 197 |
+
constexpr uint32_t SMEM_SEND_BUFFER_SIZE =
|
| 198 |
+
math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment);
|
| 199 |
+
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
|
| 200 |
+
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
|
| 201 |
+
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
|
| 202 |
+
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
|
| 203 |
+
constexpr uint32_t SMEM_CD_L1_SIZE =
|
| 204 |
+
kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) * kNumTMAStoreStages;
|
| 205 |
+
constexpr uint32_t SMEM_CD_L2_SIZE =
|
| 206 |
+
kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16);
|
| 207 |
+
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE;
|
| 208 |
+
constexpr uint32_t SMEM_CD_L1_SIZE_PER_STAGE = SMEM_CD_L1_SIZE / kNumTMAStoreStages;
|
| 209 |
+
constexpr uint32_t SMEM_BEFORE_BARRIER_SIZE =
|
| 210 |
+
SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
| 211 |
+
DG_STATIC_ASSERT(SMEM_CD_SIZE % kSharedMemoryAlignment == 0 and
|
| 212 |
+
SMEM_A_SIZE_PER_STAGE % kSharedMemoryAlignment == 0 and
|
| 213 |
+
SMEM_B_SIZE_PER_STAGE % kSharedMemoryAlignment == 0,
|
| 214 |
+
"Shared memory of CD/A/B must be aligned to 1024 bytes");
|
| 215 |
+
|
| 216 |
+
// Tensor memory size
|
| 217 |
+
constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
|
| 218 |
+
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
|
| 219 |
+
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
|
| 220 |
+
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
|
| 221 |
+
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
|
| 222 |
+
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
|
| 223 |
+
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
| 224 |
+
|
| 225 |
+
// Assign shared memory for dispatch warps
|
| 226 |
+
const auto smem_expert_count = reinterpret_cast<uint32_t*>(smem_buffer);
|
| 227 |
+
const auto smem_send_buffers = layout::Buffer(
|
| 228 |
+
fp8_token_layout, kNumDispatchWarps, 1,
|
| 229 |
+
math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE));
|
| 230 |
+
|
| 231 |
+
// GEMM shared memory: C/D, A, B
|
| 232 |
+
// NOTES: GEMM shared memory starts after the dispatch region, aligned to 1024 bytes
|
| 233 |
+
auto smem_gemm_base = math::advance_ptr(
|
| 234 |
+
smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE
|
| 235 |
+
);
|
| 236 |
+
|
| 237 |
+
// D/A/B shared memory
|
| 238 |
+
auto smem_cd = utils::PatternVisitor([=](const uint32_t& i) {
|
| 239 |
+
return math::advance_ptr<uint8_t>(smem_gemm_base, i * SMEM_CD_L1_SIZE_PER_STAGE);
|
| 240 |
+
});
|
| 241 |
+
auto smem_cd_l2 = smem_cd[0];
|
| 242 |
+
auto smem_a = utils::PatternVisitor([=](const uint32_t& i) {
|
| 243 |
+
return math::advance_ptr<a_dtype_t>(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
| 244 |
+
});
|
| 245 |
+
auto smem_b = utils::PatternVisitor([=](const uint32_t& i) {
|
| 246 |
+
return math::advance_ptr<b_dtype_t>(smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
| 247 |
+
});
|
| 248 |
+
|
| 249 |
+
// SF shared memory: SFA and SFB per pipeline stage
|
| 250 |
+
auto sf_start_ptr = math::advance_ptr<uint8_t>(smem_gemm_base,
|
| 251 |
+
SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 252 |
+
auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) {
|
| 253 |
+
return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
|
| 254 |
+
});
|
| 255 |
+
auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) {
|
| 256 |
+
return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
|
| 257 |
+
});
|
| 258 |
+
|
| 259 |
+
// Epilogue amax reduction shared memory
|
| 260 |
+
auto smem_amax_reduction = reinterpret_cast<float2*>(smem_sfb[kNumStages]);
|
| 261 |
+
|
| 262 |
+
// Barriers and tensor memory pointer
|
| 263 |
+
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_amax_reduction + STORE_BLOCK_M * kNumEpilogueWarps / 2);
|
| 264 |
+
auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 265 |
+
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + i); });
|
| 266 |
+
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages + i); });
|
| 267 |
+
auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + i); });
|
| 268 |
+
auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages + i); });
|
| 269 |
+
auto combine_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + i); });
|
| 270 |
+
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2);
|
| 271 |
+
|
| 272 |
+
// A cluster sync is essential for 2CTA tensor memory allocation
|
| 273 |
+
comm::cluster_sync_with_relaxed_arrive();
|
| 274 |
+
|
| 275 |
+
// Initialization
|
| 276 |
+
if (warp_idx == 0) {
|
| 277 |
+
// Clean shared memory
|
| 278 |
+
if (cute::elect_one_sync())
|
| 279 |
+
ptx::st_shared_bulk(smem_expert_count, kNumExperts * sizeof(uint32_t));
|
| 280 |
+
} else if (warp_idx == 1) {
|
| 281 |
+
// Init m-barriers for dispatch
|
| 282 |
+
#pragma unroll
|
| 283 |
+
for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32)
|
| 284 |
+
dispatch_barriers[i]->init(1);
|
| 285 |
+
cutlass::arch::fence_barrier_init();
|
| 286 |
+
} else if (warp_idx == 2) {
|
| 287 |
+
// Init GEMM barriers
|
| 288 |
+
if (cute::elect_one_sync()) {
|
| 289 |
+
#pragma unroll
|
| 290 |
+
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
| 291 |
+
// Arrive at all CTAs
|
| 292 |
+
full_barriers[i]->init(2 * 2);
|
| 293 |
+
empty_barriers[i]->init(1);
|
| 294 |
+
}
|
| 295 |
+
#pragma unroll
|
| 296 |
+
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
|
| 297 |
+
// Arrive at all CTAs
|
| 298 |
+
tmem_full_barriers[i]->init(1);
|
| 299 |
+
// Arrive only at the leader CTA
|
| 300 |
+
tmem_empty_barriers[i]->init(2 * kNumEpilogueThreads);
|
| 301 |
+
}
|
| 302 |
+
#pragma unroll
|
| 303 |
+
for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i)
|
| 304 |
+
combine_barriers[i]->init(1);
|
| 305 |
+
}
|
| 306 |
+
cutlass::arch::fence_barrier_init();
|
| 307 |
+
} else if (warp_idx == 3) {
|
| 308 |
+
// Allocate tensor memory
|
| 309 |
+
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
| 310 |
+
}
|
| 311 |
+
// NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`,
|
| 312 |
+
// and `barrier.cluster.wait.aligned` is by default `.acquire`
|
| 313 |
+
comm::cluster_sync_with_relaxed_arrive();
|
| 314 |
+
|
| 315 |
+
// Task scheduler
|
| 316 |
+
auto scheduler = sched::MegaMoEScheduler<
|
| 317 |
+
BLOCK_M, BLOCK_N, BLOCK_K,
|
| 318 |
+
L1_SHAPE_N, L1_SHAPE_K,
|
| 319 |
+
L2_SHAPE_N, L2_SHAPE_K,
|
| 320 |
+
kNumExpertsPerRank,
|
| 321 |
+
kNumExpertsPerWave,
|
| 322 |
+
kNumSMs, kNumRanks>(workspace);
|
| 323 |
+
|
| 324 |
+
// MMA pipeline and TMA phases
|
| 325 |
+
uint32_t stage_idx = 0, phase = 0;
|
| 326 |
+
auto advance_pipeline = [&](uint32_t& k_block_idx) {
|
| 327 |
+
++ k_block_idx;
|
| 328 |
+
|
| 329 |
+
// Flip the phase only when the stage index wraps around to the first stage
|
| 330 |
+
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
|
| 331 |
+
phase ^= stage_idx == 0;
|
| 332 |
+
};
|
| 333 |
+
|
| 334 |
+
// Intra-SM Barrier indices
|
| 335 |
+
constexpr uint32_t kDispatchBarrierIdx = 0;
|
| 336 |
+
constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1;
|
| 337 |
+
constexpr uint32_t kEpilogueFullBarrierIdx = 2;
|
| 338 |
+
constexpr uint32_t kEpilogueWGBarrierStartIdx = 3;
|
| 339 |
+
|
| 340 |
+
// NVLink barrier tags
|
| 341 |
+
constexpr uint32_t kBeforeDispatchPullBarrierTag = 1;
|
| 342 |
+
constexpr uint32_t kBeforeCombineReduceBarrierTag = 2;
|
| 343 |
+
constexpr uint32_t kAfterWorkspaceCleanBarrierTag = 3;
|
| 344 |
+
|
| 345 |
+
// Adjust registers
|
| 346 |
+
constexpr uint32_t kNumDispatchRegisters = 48;
|
| 347 |
+
constexpr uint32_t kNumNonEpilogueRegisters = 40;
|
| 348 |
+
constexpr uint32_t kNumEpilogueRegisters = 208;
|
| 349 |
+
DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads +
|
| 350 |
+
kNumNonEpilogueRegisters * kNumNonEpilogueThreads +
|
| 351 |
+
kNumEpilogueRegisters * kNumEpilogueThreads <= 64512,
|
| 352 |
+
"Too many registers");
|
| 353 |
+
|
| 354 |
+
// Grid sync index assignments (dispatch and epilogue use separate counters to avoid conflicts)
|
| 355 |
+
constexpr uint32_t kDispatchGridSyncIndex = 0;
|
| 356 |
+
constexpr uint32_t kEpilogueGridSyncIndex = 1;
|
| 357 |
+
|
| 358 |
+
// Different warp roles
|
| 359 |
+
if (warp_idx < kNumDispatchWarps) {
|
| 360 |
+
// Adjust registers
|
| 361 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumDispatchRegisters>();
|
| 362 |
+
|
| 363 |
+
// Dispatch warps
|
| 364 |
+
DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk");
|
| 365 |
+
constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk;
|
| 366 |
+
const auto read_topk_idx = [&](const auto& process) {
|
| 367 |
+
// TODO: figure out better unrolling
|
| 368 |
+
// Now, `unroll` is better than `unroll 8`
|
| 369 |
+
#pragma unroll
|
| 370 |
+
for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp;
|
| 371 |
+
i < num_tokens;
|
| 372 |
+
i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) {
|
| 373 |
+
// Allocate slots for each token-topk
|
| 374 |
+
int expert_idx = -1;
|
| 375 |
+
if (i + (lane_idx / kNumTopk) < num_tokens and lane_idx < kNumActivateLanes) {
|
| 376 |
+
expert_idx = static_cast<int>(
|
| 377 |
+
__ldg(input_topk_idx_buffer.get_base_ptr<int64_t>() + i * kNumTopk + lane_idx));
|
| 378 |
+
if (expert_idx >= 0)
|
| 379 |
+
process(i * kNumTopk + lane_idx, expert_idx);
|
| 380 |
+
}
|
| 381 |
+
__syncwarp();
|
| 382 |
+
}
|
| 383 |
+
};
|
| 384 |
+
|
| 385 |
+
// Count experts' tokens
|
| 386 |
+
read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) {
|
| 387 |
+
atomicAdd_block(smem_expert_count + expert_idx, 1);
|
| 388 |
+
});
|
| 389 |
+
ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
|
| 390 |
+
|
| 391 |
+
// Get SM offset (~6.5 us)
|
| 392 |
+
#pragma unroll
|
| 393 |
+
for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) {
|
| 394 |
+
const uint64_t send_value = (1ull << 32) | static_cast<uint64_t>(smem_expert_count[i]);
|
| 395 |
+
smem_expert_count[i] = static_cast<uint32_t>(
|
| 396 |
+
ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value));
|
| 397 |
+
}
|
| 398 |
+
ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
|
| 399 |
+
|
| 400 |
+
// Write source indices (~2 us with 512 tokens)
|
| 401 |
+
read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) {
|
| 402 |
+
const auto dst_rank_idx = expert_idx / kNumExpertsPerRank;
|
| 403 |
+
const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1);
|
| 404 |
+
const auto dst_ptr = workspace.get_src_token_topk_idx_ptr(
|
| 405 |
+
expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx);
|
| 406 |
+
*sym_buffer.map(dst_ptr, dst_rank_idx) = token_topk_idx;
|
| 407 |
+
});
|
| 408 |
+
|
| 409 |
+
// Grid sync
|
| 410 |
+
comm::grid_sync<kNumSMs, kDispatchGridSyncIndex>(
|
| 411 |
+
workspace, sm_idx, thread_idx,
|
| 412 |
+
[=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }
|
| 413 |
+
);
|
| 414 |
+
|
| 415 |
+
// Write expert count
|
| 416 |
+
if (sm_idx == 0) {
|
| 417 |
+
#pragma unroll
|
| 418 |
+
for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) {
|
| 419 |
+
const auto dst_rank_idx = i / kNumExpertsPerRank;
|
| 420 |
+
const auto dst_local_expert_idx = i % kNumExpertsPerRank;
|
| 421 |
+
const auto expert_status = *workspace.get_expert_send_count_ptr(i);
|
| 422 |
+
*sym_buffer.map(
|
| 423 |
+
workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx),
|
| 424 |
+
dst_rank_idx) = expert_status & 0xffffffff;
|
| 425 |
+
ptx::atomic_add_sys(
|
| 426 |
+
sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx),
|
| 427 |
+
expert_status);
|
| 428 |
+
}
|
| 429 |
+
}
|
| 430 |
+
ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
|
| 431 |
+
|
| 432 |
+
// Barrier before pulling
|
| 433 |
+
comm::nvlink_barrier<kNumRanks, kNumSMs, kNumDispatchThreads,
|
| 434 |
+
kDispatchGridSyncIndex, kBeforeDispatchPullBarrierTag>(
|
| 435 |
+
workspace, sym_buffer, sm_idx, thread_idx,
|
| 436 |
+
[=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); },
|
| 437 |
+
/* After the grid sync above, there is no more writes by other SMs (except 0) */ false,
|
| 438 |
+
/* After the NVLink barrier, there is a grid sync */ true
|
| 439 |
+
);
|
| 440 |
+
|
| 441 |
+
// Ensure the epilogue barrier cannot run with the pull barrier
|
| 442 |
+
ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
|
| 443 |
+
|
| 444 |
+
// Pull token data and SF from remote ranks into local L1 buffer
|
| 445 |
+
uint32_t pull_mbarrier_phase = 0;
|
| 446 |
+
const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0);
|
| 447 |
+
const auto pull_mbarrier = dispatch_barriers[warp_idx];
|
| 448 |
+
|
| 449 |
+
// Cache expert token counts in registers (same pattern as scheduler)
|
| 450 |
+
scheduler.fetch_expert_recv_count();
|
| 451 |
+
|
| 452 |
+
// Per-rank counts for current expert (re-loaded when expert changes)
|
| 453 |
+
constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u);
|
| 454 |
+
int current_expert_idx = -1;
|
| 455 |
+
uint32_t stored_rank_count[kNumRanksPerLane] = {};
|
| 456 |
+
uint32_t expert_start_idx = 0, expert_end_idx = 0;
|
| 457 |
+
uint32_t expert_pool_block_offset = 0;
|
| 458 |
+
|
| 459 |
+
constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps;
|
| 460 |
+
for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) {
|
| 461 |
+
// Advance expert until within the range
|
| 462 |
+
int old_expert_idx = current_expert_idx;
|
| 463 |
+
while (token_idx >= expert_end_idx) {
|
| 464 |
+
if (++ current_expert_idx >= kNumExpertsPerRank)
|
| 465 |
+
break;
|
| 466 |
+
|
| 467 |
+
// Update pool block offset for the new expert
|
| 468 |
+
expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M);
|
| 469 |
+
|
| 470 |
+
// Move start and end to the next expert
|
| 471 |
+
expert_start_idx = expert_end_idx;
|
| 472 |
+
expert_end_idx += scheduler.get_num_tokens(current_expert_idx);
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
// Finish all tokens
|
| 476 |
+
if (current_expert_idx >= kNumExpertsPerRank)
|
| 477 |
+
break;
|
| 478 |
+
|
| 479 |
+
// Load per-rank counts when expert changes
|
| 480 |
+
if (old_expert_idx != current_expert_idx) {
|
| 481 |
+
old_expert_idx = current_expert_idx;
|
| 482 |
+
#pragma unroll
|
| 483 |
+
for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) {
|
| 484 |
+
const uint32_t j = i * 32 + lane_idx;
|
| 485 |
+
// TODO: this is not coalesced
|
| 486 |
+
stored_rank_count[i] = j < kNumRanks ?
|
| 487 |
+
static_cast<uint32_t>(*workspace.get_expert_recv_count_ptr(j, current_expert_idx)) : 0;
|
| 488 |
+
}
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
// Round-robin rank selection via iterative min-peeling
|
| 492 |
+
uint32_t current_rank_in_expert_idx;
|
| 493 |
+
uint32_t remaining[kNumRanksPerLane];
|
| 494 |
+
#pragma unroll
|
| 495 |
+
for (uint32_t i = 0; i < kNumRanksPerLane; ++ i)
|
| 496 |
+
remaining[i] = stored_rank_count[i];
|
| 497 |
+
uint32_t offset = 0;
|
| 498 |
+
uint32_t token_idx_in_expert = token_idx - expert_start_idx;
|
| 499 |
+
uint32_t slot_idx = token_idx_in_expert;
|
| 500 |
+
uint32_t token_idx_in_rank;
|
| 501 |
+
while (true) {
|
| 502 |
+
// Compute active count and min across all ranks
|
| 503 |
+
// NOTES: reduce within each lane first, then warp-reduce once
|
| 504 |
+
uint32_t num_actives_in_lane = 0;
|
| 505 |
+
uint32_t min_in_lane = 0xffffffff;
|
| 506 |
+
#pragma unroll
|
| 507 |
+
for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) {
|
| 508 |
+
num_actives_in_lane += remaining[i] > 0;
|
| 509 |
+
if (remaining[i] > 0)
|
| 510 |
+
min_in_lane = cute::min(min_in_lane, remaining[i]);
|
| 511 |
+
}
|
| 512 |
+
const uint32_t num_active_ranks = __reduce_add_sync(0xffffffff, num_actives_in_lane);
|
| 513 |
+
const uint32_t length = __reduce_min_sync(0xffffffff, min_in_lane);
|
| 514 |
+
|
| 515 |
+
// Hit in the current round
|
| 516 |
+
const uint32_t num_round_tokens = length * num_active_ranks;
|
| 517 |
+
if (slot_idx < num_round_tokens) {
|
| 518 |
+
const uint32_t slot_idx_in_round = slot_idx % num_active_ranks;
|
| 519 |
+
uint32_t num_seen_ranks = 0;
|
| 520 |
+
current_rank_in_expert_idx = 0;
|
| 521 |
+
#pragma unroll
|
| 522 |
+
for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) {
|
| 523 |
+
const uint32_t mask = __ballot_sync(0xffffffff, remaining[i] > 0);
|
| 524 |
+
const uint32_t num_active_lanes = __popc(mask);
|
| 525 |
+
if (slot_idx_in_round >= num_seen_ranks and slot_idx_in_round < num_seen_ranks + num_active_lanes)
|
| 526 |
+
current_rank_in_expert_idx = i * 32 + __fns(mask, 0, slot_idx_in_round - num_seen_ranks + 1);
|
| 527 |
+
num_seen_ranks += num_active_lanes;
|
| 528 |
+
}
|
| 529 |
+
token_idx_in_rank = offset + (slot_idx / num_active_ranks);
|
| 530 |
+
break;
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
// Move into the next round
|
| 534 |
+
slot_idx -= num_round_tokens;
|
| 535 |
+
offset += length;
|
| 536 |
+
#pragma unroll
|
| 537 |
+
for (uint32_t i = 0; i < kNumRanksPerLane; ++ i)
|
| 538 |
+
remaining[i] -= cute::min(remaining[i], length);
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
// Read source token-topk index (written by remote dispatch via NVLink)
|
| 542 |
+
const uint32_t src_token_topk_idx = *workspace.get_src_token_topk_idx_ptr(
|
| 543 |
+
current_expert_idx, current_rank_in_expert_idx, token_idx_in_rank);
|
| 544 |
+
const uint32_t src_token_idx = src_token_topk_idx / kNumTopk;
|
| 545 |
+
const uint32_t src_topk_idx = src_token_topk_idx % kNumTopk;
|
| 546 |
+
|
| 547 |
+
// TMA load token from remote rank into shared memory
|
| 548 |
+
if (cute::elect_one_sync()) {
|
| 549 |
+
ptx::tma_load_1d(
|
| 550 |
+
pull_buffer.get_base_ptr(),
|
| 551 |
+
sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(),
|
| 552 |
+
current_rank_in_expert_idx),
|
| 553 |
+
pull_mbarrier, kHidden);
|
| 554 |
+
}
|
| 555 |
+
__syncwarp();
|
| 556 |
+
|
| 557 |
+
// Load and store SF (overlaps with TMA token load)
|
| 558 |
+
constexpr uint32_t kNumSFUint32 = kHidden / 128;
|
| 559 |
+
DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 128 == 0, "Invalid SF");
|
| 560 |
+
const auto remote_sf_ptr = sym_buffer.map(
|
| 561 |
+
input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr<uint32_t>(),
|
| 562 |
+
current_rank_in_expert_idx);
|
| 563 |
+
const auto local_sf_ptr = l1_sf_buffer.get_base_ptr<uint32_t>();
|
| 564 |
+
const auto sf_pool_token_idx = expert_pool_block_offset * SF_BLOCK_M +
|
| 565 |
+
transform_sf_token_idx(token_idx_in_expert);
|
| 566 |
+
#pragma unroll
|
| 567 |
+
for (uint32_t i = 0; i < math::constexpr_ceil_div(kNumSFUint32, 32u); ++ i) {
|
| 568 |
+
const uint32_t j = i * 32 + lane_idx;
|
| 569 |
+
if (j < kNumSFUint32)
|
| 570 |
+
local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = remote_sf_ptr[j];
|
| 571 |
+
}
|
| 572 |
+
__syncwarp();
|
| 573 |
+
|
| 574 |
+
// Store weights and token data
|
| 575 |
+
const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert;
|
| 576 |
+
if (cute::elect_one_sync()) {
|
| 577 |
+
// Load weights
|
| 578 |
+
const auto weight = *sym_buffer.map(
|
| 579 |
+
input_topk_weights_buffer.get_base_ptr<float>() + src_token_topk_idx,
|
| 580 |
+
current_rank_in_expert_idx);
|
| 581 |
+
*l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr<float>() = weight;
|
| 582 |
+
|
| 583 |
+
// Wait for TMA token load to complete
|
| 584 |
+
ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden);
|
| 585 |
+
ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase);
|
| 586 |
+
|
| 587 |
+
// Store token to local L1 buffer via TMA
|
| 588 |
+
ptx::tma_store_1d(
|
| 589 |
+
l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(),
|
| 590 |
+
pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes());
|
| 591 |
+
|
| 592 |
+
// Write source metadata for combine write-back
|
| 593 |
+
*workspace.get_token_src_metadata_ptr(pool_token_idx) =
|
| 594 |
+
{current_rank_in_expert_idx, src_token_idx, src_topk_idx};
|
| 595 |
+
|
| 596 |
+
// Wait for token TMA store to complete
|
| 597 |
+
cute::tma_store_arrive();
|
| 598 |
+
ptx::tma_store_wait<0>();
|
| 599 |
+
ptx::red_add_rel(
|
| 600 |
+
workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1);
|
| 601 |
+
}
|
| 602 |
+
__syncwarp();
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
// Clean workspace for the next usage, and also do cumulative stats
|
| 606 |
+
// NOTES: it is overlapped with combine reduction epilogue
|
| 607 |
+
ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
|
| 608 |
+
|
| 609 |
+
DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count");
|
| 610 |
+
if (sm_idx == 0) {
|
| 611 |
+
// SM 0: clear expert send count
|
| 612 |
+
#pragma unroll
|
| 613 |
+
for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads)
|
| 614 |
+
*workspace.get_expert_send_count_ptr(i) = 0;
|
| 615 |
+
} else {
|
| 616 |
+
// Other SMs: clean blocks
|
| 617 |
+
for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) {
|
| 618 |
+
// Read expert token count before clearing
|
| 619 |
+
const auto num_recv_tokens = static_cast<uint32_t>(
|
| 620 |
+
*workspace.get_expert_recv_count_sum_ptr(i));
|
| 621 |
+
const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M);
|
| 622 |
+
|
| 623 |
+
// Compute expert pool block offset
|
| 624 |
+
expert_pool_block_offset = scheduler.get_pool_block_offset(i);
|
| 625 |
+
|
| 626 |
+
// Wait until all dispatch threads have read the count before it is cleared
|
| 627 |
+
ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
|
| 628 |
+
|
| 629 |
+
// Clean expert token count, and add cumulative results
|
| 630 |
+
DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps");
|
| 631 |
+
if (warp_idx == 0) {
|
| 632 |
+
*workspace.get_expert_recv_count_sum_ptr(i) = 0;
|
| 633 |
+
} else if (warp_idx == 1) {
|
| 634 |
+
if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr)
|
| 635 |
+
ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast<int>(num_recv_tokens));
|
| 636 |
+
__syncwarp();
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
// Clean per-rank token count
|
| 640 |
+
for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads)
|
| 641 |
+
*workspace.get_expert_recv_count_ptr(j, i) = 0;
|
| 642 |
+
__syncwarp();
|
| 643 |
+
|
| 644 |
+
// Clean L1 arrival counts and L2 arrival masks
|
| 645 |
+
for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) {
|
| 646 |
+
*workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0;
|
| 647 |
+
*workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0;
|
| 648 |
+
}
|
| 649 |
+
__syncwarp();
|
| 650 |
+
}
|
| 651 |
+
}
|
| 652 |
+
|
| 653 |
+
// Wait for all ranks to finish cleaning
|
| 654 |
+
comm::nvlink_barrier<kNumRanks, kNumSMs, kNumDispatchThreads,
|
| 655 |
+
kDispatchGridSyncIndex, kAfterWorkspaceCleanBarrierTag>(
|
| 656 |
+
workspace, sym_buffer, sm_idx, thread_idx,
|
| 657 |
+
[=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); },
|
| 658 |
+
/* Before the NVLink barrier, there is a grid sync */ true,
|
| 659 |
+
/* At the end of kernel does not need to sync */ false
|
| 660 |
+
);
|
| 661 |
+
} else if (warp_idx == kNumDispatchWarps) {
|
| 662 |
+
// Adjust registers
|
| 663 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
|
| 664 |
+
|
| 665 |
+
// GEMM TMA load warp for tokens with SFA
|
| 666 |
+
scheduler.for_each_block([&](const sched::BlockPhase& block_phase,
|
| 667 |
+
const uint32_t& local_expert_idx,
|
| 668 |
+
const uint32_t& num_k_blocks,
|
| 669 |
+
const uint32_t& m_block_idx, const uint32_t& n_block_idx) {
|
| 670 |
+
const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2
|
| 671 |
+
? &tensor_map_l2_acts : &tensor_map_l1_acts;
|
| 672 |
+
const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2
|
| 673 |
+
? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf;
|
| 674 |
+
|
| 675 |
+
const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K;
|
| 676 |
+
const auto shape_sfa_k = math::ceil_div(shape_k, kGranK * 4u);
|
| 677 |
+
|
| 678 |
+
// Compute pool block offset for this expert
|
| 679 |
+
const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx;
|
| 680 |
+
|
| 681 |
+
// Wait the entire token arrival for linear 1
|
| 682 |
+
if (block_phase == sched::BlockPhase::Linear1) {
|
| 683 |
+
const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx);
|
| 684 |
+
const auto expected = scheduler.template get_valid_m<false>();
|
| 685 |
+
while (ptx::ld_acq(ptr) != expected);
|
| 686 |
+
} else {
|
| 687 |
+
// The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2x L1 blocks' arrival
|
| 688 |
+
// NOTES: Originally we wait blocks on-demand to overlap L1 calculation
|
| 689 |
+
// with L2, but this optimization is negative when `num_experts_per_wave`
|
| 690 |
+
// guarantees L1's completion when L2 starts. So we remove it.
|
| 691 |
+
// In the future, if `num_experts_per_wave` is not large enough
|
| 692 |
+
// due to small `num_experts_per_rank`, we may need to add it back or add a switch
|
| 693 |
+
DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes");
|
| 694 |
+
const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx);
|
| 695 |
+
// NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts
|
| 696 |
+
// to avoid undefined behavior when `num_k_blocks == 32`
|
| 697 |
+
const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1;
|
| 698 |
+
while (ptx::ld_acq_gpu(ptr) != expected);
|
| 699 |
+
}
|
| 700 |
+
|
| 701 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) {
|
| 702 |
+
// Wait consumer release
|
| 703 |
+
empty_barriers[stage_idx]->wait(phase ^ 1);
|
| 704 |
+
|
| 705 |
+
// Compute token offset from pool block index
|
| 706 |
+
uint32_t m_idx = pool_block_idx * BLOCK_M;
|
| 707 |
+
uint32_t k_idx = k_block_idx * BLOCK_K;
|
| 708 |
+
uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M;
|
| 709 |
+
uint32_t sfa_k_idx = k_block_idx;
|
| 710 |
+
|
| 711 |
+
// Add 2 CTA offsets for non-leader CTA
|
| 712 |
+
if (not is_leader_cta)
|
| 713 |
+
m_idx += scheduler.template get_valid_m<true>() / 2;
|
| 714 |
+
|
| 715 |
+
// TMA copy tokens and SFA, then arrive at full barrier
|
| 716 |
+
if (cute::elect_one_sync()) {
|
| 717 |
+
tma::copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(
|
| 718 |
+
tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2);
|
| 719 |
+
tma::copy<SF_BLOCK_M, 1, 0>(
|
| 720 |
+
tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2);
|
| 721 |
+
if (is_leader_cta) {
|
| 722 |
+
full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SF_BLOCK_M * sizeof(uint32_t) * 2);
|
| 723 |
+
} else {
|
| 724 |
+
full_barriers[stage_idx]->arrive(0u);
|
| 725 |
+
}
|
| 726 |
+
}
|
| 727 |
+
__syncwarp();
|
| 728 |
+
}
|
| 729 |
+
});
|
| 730 |
+
} else if (warp_idx == kNumDispatchWarps + 1) {
|
| 731 |
+
// Adjust registers
|
| 732 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
|
| 733 |
+
|
| 734 |
+
// GEMM TMA load warp for weights with SF
|
| 735 |
+
scheduler.for_each_block([&](const sched::BlockPhase& block_phase,
|
| 736 |
+
const uint32_t& local_expert_idx,
|
| 737 |
+
const uint32_t& num_k_blocks,
|
| 738 |
+
const uint32_t& m_block_idx, const uint32_t& n_block_idx) {
|
| 739 |
+
const auto tensor_map_b_ptr =
|
| 740 |
+
block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights;
|
| 741 |
+
const auto tensor_map_sfb_ptr =
|
| 742 |
+
block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights_sf : &tensor_map_l1_weights_sf;
|
| 743 |
+
|
| 744 |
+
const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K;
|
| 745 |
+
const auto shape_n = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_N : L1_SHAPE_N;
|
| 746 |
+
const auto shape_sfb_k = math::ceil_div(shape_k, kGranK * 4u);
|
| 747 |
+
|
| 748 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) {
|
| 749 |
+
// Wait consumer release
|
| 750 |
+
empty_barriers[stage_idx]->wait(phase ^ 1);
|
| 751 |
+
|
| 752 |
+
// Compute weight offset
|
| 753 |
+
uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N;
|
| 754 |
+
uint32_t k_idx = k_block_idx * BLOCK_K;
|
| 755 |
+
uint32_t sfb_n_idx = n_block_idx * BLOCK_N;
|
| 756 |
+
uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx;
|
| 757 |
+
|
| 758 |
+
// TMA copy weights with SF
|
| 759 |
+
if (cute::elect_one_sync()) {
|
| 760 |
+
tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(
|
| 761 |
+
tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2);
|
| 762 |
+
tma::copy<BLOCK_N, 1, 0>(
|
| 763 |
+
tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2);
|
| 764 |
+
if (is_leader_cta) {
|
| 765 |
+
full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + BLOCK_N * sizeof(uint32_t) * 2);
|
| 766 |
+
} else {
|
| 767 |
+
full_barriers[stage_idx]->arrive(0u);
|
| 768 |
+
}
|
| 769 |
+
}
|
| 770 |
+
__syncwarp();
|
| 771 |
+
}
|
| 772 |
+
});
|
| 773 |
+
} else if (warp_idx == kNumDispatchWarps + 2) {
|
| 774 |
+
// Adjust registers
|
| 775 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
|
| 776 |
+
|
| 777 |
+
// GEMM MMA issue warp (only the leader CTA will run)
|
| 778 |
+
if (is_leader_cta) {
|
| 779 |
+
// Make instruction descriptor with block scaling
|
| 780 |
+
// NOTES: always swap A/B
|
| 781 |
+
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<
|
| 782 |
+
b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t,
|
| 783 |
+
UMMA_M, UMMA_N,
|
| 784 |
+
cute::UMMA::Major::K, cute::UMMA::Major::K
|
| 785 |
+
>();
|
| 786 |
+
auto sf_desc = mma::sm100::make_sf_desc(nullptr);
|
| 787 |
+
|
| 788 |
+
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
| 789 |
+
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
|
| 790 |
+
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
| 791 |
+
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
| 792 |
+
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
| 793 |
+
|
| 794 |
+
// Checks for MMA instructions
|
| 795 |
+
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
| 796 |
+
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
|
| 797 |
+
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
|
| 798 |
+
"Invalid MMA instruction shape");
|
| 799 |
+
|
| 800 |
+
// Persistently schedule over blocks
|
| 801 |
+
uint32_t current_iter_idx = 0;
|
| 802 |
+
scheduler.for_each_block([&](const sched::BlockPhase& block_phase,
|
| 803 |
+
const uint32_t& local_expert_idx,
|
| 804 |
+
const uint32_t& num_k_blocks,
|
| 805 |
+
const uint32_t& m_block_idx, const uint32_t& n_block_idx) {
|
| 806 |
+
// Dynamic update of UMMA N based on effective M
|
| 807 |
+
mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m<true>());
|
| 808 |
+
|
| 809 |
+
// Wait tensor memory empty barrier arrival
|
| 810 |
+
const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages;
|
| 811 |
+
const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1;
|
| 812 |
+
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase ^ 1);
|
| 813 |
+
ptx::tcgen05_after_thread_sync();
|
| 814 |
+
|
| 815 |
+
// Empty barrier arrival
|
| 816 |
+
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
|
| 817 |
+
auto umma_arrive = [](const uint64_t* barrier) {
|
| 818 |
+
constexpr uint16_t kCTAMask = (1 << 2) - 1;
|
| 819 |
+
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
|
| 820 |
+
};
|
| 821 |
+
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
|
| 822 |
+
|
| 823 |
+
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
| 824 |
+
if (do_tmem_full_arrive)
|
| 825 |
+
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
| 826 |
+
__syncwarp();
|
| 827 |
+
};
|
| 828 |
+
|
| 829 |
+
// Launch MMAs
|
| 830 |
+
#pragma unroll 2
|
| 831 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) {
|
| 832 |
+
// Wait TMA load completion
|
| 833 |
+
full_barriers[stage_idx]->wait(phase);
|
| 834 |
+
ptx::tcgen05_after_thread_sync();
|
| 835 |
+
|
| 836 |
+
const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx);
|
| 837 |
+
const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx);
|
| 838 |
+
if (cute::elect_one_sync()) {
|
| 839 |
+
// UTCCP copy SFA and SFB to TMEM
|
| 840 |
+
using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta;
|
| 841 |
+
#pragma unroll
|
| 842 |
+
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
|
| 843 |
+
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
|
| 844 |
+
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
| 845 |
+
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
|
| 846 |
+
}
|
| 847 |
+
#pragma unroll
|
| 848 |
+
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
|
| 849 |
+
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
|
| 850 |
+
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
| 851 |
+
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
|
| 852 |
+
}
|
| 853 |
+
|
| 854 |
+
// Issue UMMA
|
| 855 |
+
#pragma unroll
|
| 856 |
+
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
| 857 |
+
const auto runtime_instr_desc =
|
| 858 |
+
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k);
|
| 859 |
+
a_desc.lo = mma::sm100::advance_umma_desc_lo<
|
| 860 |
+
cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K);
|
| 861 |
+
b_desc.lo = mma::sm100::advance_umma_desc_lo<
|
| 862 |
+
cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
|
| 863 |
+
ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma(
|
| 864 |
+
b_desc, a_desc, accum_stage_idx * UMMA_N,
|
| 865 |
+
k_block_idx > 0 or k > 0, runtime_instr_desc,
|
| 866 |
+
kTmemStartColOfSFB, kTmemStartColOfSFA);
|
| 867 |
+
}
|
| 868 |
+
}
|
| 869 |
+
__syncwarp();
|
| 870 |
+
|
| 871 |
+
// Commit to the mbarrier object
|
| 872 |
+
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
| 873 |
+
empty_barrier_arrive(k_block_idx == num_k_blocks - 1);
|
| 874 |
+
}
|
| 875 |
+
});
|
| 876 |
+
|
| 877 |
+
// To safely deconstruct barriers, we need another round of waits
|
| 878 |
+
if (current_iter_idx > 0) {
|
| 879 |
+
const auto accum_phase_idx = ((current_iter_idx - 1) / kNumEpilogueStages) & 1;
|
| 880 |
+
tmem_empty_barriers[(current_iter_idx - 1) % kNumEpilogueStages]->wait(accum_phase_idx);
|
| 881 |
+
}
|
| 882 |
+
}
|
| 883 |
+
} else if (warp_idx == kNumDispatchWarps + 3) {
|
| 884 |
+
// Adjust registers
|
| 885 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
|
| 886 |
+
|
| 887 |
+
} else if (warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps) {
|
| 888 |
+
// Adjust registers
|
| 889 |
+
cutlass::arch::warpgroup_reg_alloc<kNumEpilogueRegisters>();
|
| 890 |
+
|
| 891 |
+
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
| 892 |
+
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
|
| 893 |
+
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
| 894 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
|
| 895 |
+
|
| 896 |
+
// GEMM epilogue warps
|
| 897 |
+
const auto epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps);
|
| 898 |
+
const auto epilogue_wg_idx = epilogue_warp_idx / 4;
|
| 899 |
+
const auto epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx;
|
| 900 |
+
const auto warp_idx_in_wg = epilogue_warp_idx % 4;
|
| 901 |
+
DG_STATIC_ASSERT((kNumDispatchWarps + kNumMMANonEpilogueWarps) % 4 == 0 and
|
| 902 |
+
kNumEpilogueWarps % 4 == 0, "Invalid epilogue warps");
|
| 903 |
+
|
| 904 |
+
// TODO: support effective block M
|
| 905 |
+
// NOTES:
|
| 906 |
+
// - 2 warpgroups divide the whole BM into BM / 2
|
| 907 |
+
// - 4 warps divide the whole BN into BN / 4
|
| 908 |
+
// - BM / 2 is further divided into stored blocks, i.e. with `STORE_BLOCK_M` size
|
| 909 |
+
// - `STORE_BLOCK_M` is further divided into `ATOM_M`
|
| 910 |
+
constexpr uint32_t WG_BLOCK_M = BLOCK_M / kNumEpilogueWarpgroups;
|
| 911 |
+
constexpr uint32_t ATOM_M = 8;
|
| 912 |
+
constexpr uint32_t kNumBankGroupBytes = 16u;
|
| 913 |
+
constexpr uint32_t kNumAtomsPerStore = STORE_BLOCK_M / ATOM_M;
|
| 914 |
+
DG_STATIC_ASSERT(BLOCK_M % kNumEpilogueWarpgroups == 0, "Invalid block M");
|
| 915 |
+
DG_STATIC_ASSERT(WG_BLOCK_M % STORE_BLOCK_M == 0, "Invalid warpgroup block M");
|
| 916 |
+
DG_STATIC_ASSERT(STORE_BLOCK_M % ATOM_M == 0, "Invalid store block M");
|
| 917 |
+
DG_STATIC_ASSERT(BLOCK_N == 128, "Invalid block N");
|
| 918 |
+
|
| 919 |
+
// Ensure the epilogue barrier cannot run with the pull barrier
|
| 920 |
+
ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
|
| 921 |
+
|
| 922 |
+
// Persistently schedule over blocks
|
| 923 |
+
uint32_t current_iter_idx = 0;
|
| 924 |
+
scheduler.for_each_block([&](const sched::BlockPhase& block_phase,
|
| 925 |
+
const uint32_t& local_expert_idx,
|
| 926 |
+
const uint32_t& num_k_blocks,
|
| 927 |
+
const uint32_t& m_block_idx, const uint32_t& n_block_idx) {
|
| 928 |
+
// Wait UMMA arrival
|
| 929 |
+
const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages;
|
| 930 |
+
const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1;
|
| 931 |
+
tmem_full_barriers[accum_stage_idx]->wait(accum_phase);
|
| 932 |
+
ptx::tcgen05_after_thread_sync();
|
| 933 |
+
|
| 934 |
+
// Compute offsets
|
| 935 |
+
// NOTES: use shuffle here to let NVCC know warp divergence won't happen
|
| 936 |
+
const uint32_t valid_m = ptx::exchange(scheduler.template get_valid_m<false>(), 0);
|
| 937 |
+
const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx;
|
| 938 |
+
uint32_t m_idx = pool_block_idx * BLOCK_M;
|
| 939 |
+
uint32_t n_idx = n_block_idx * BLOCK_N;
|
| 940 |
+
|
| 941 |
+
if (block_phase == sched::BlockPhase::Linear1) {
|
| 942 |
+
// Unified L1 epilogue: SwiGLU in-place using granularity 8 interleaved weights
|
| 943 |
+
// With `SM100_TMEM_LOAD_16dp256b1x`, gate/up pairs are:
|
| 944 |
+
// (values[0], values[2]), (values[1], values[3]),
|
| 945 |
+
// (values[4], values[6]), (values[5], values[7])
|
| 946 |
+
float stored_cached_weight = 0;
|
| 947 |
+
|
| 948 |
+
#pragma unroll
|
| 949 |
+
for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) {
|
| 950 |
+
// Early break if the entire store block is beyond the valid token range
|
| 951 |
+
if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) {
|
| 952 |
+
ptx::tcgen05_before_thread_sync();
|
| 953 |
+
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
| 954 |
+
break;
|
| 955 |
+
}
|
| 956 |
+
|
| 957 |
+
// Iterate all atoms in the store block
|
| 958 |
+
float2 swiglu_values[kNumAtomsPerStore * 2];
|
| 959 |
+
float2 amax_values[kNumAtomsPerStore];
|
| 960 |
+
#pragma unroll
|
| 961 |
+
for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) {
|
| 962 |
+
const uint32_t j = s * kNumAtomsPerStore + i;
|
| 963 |
+
|
| 964 |
+
// Load weights from global into register cache per 32 tokens
|
| 965 |
+
DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size");
|
| 966 |
+
if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) {
|
| 967 |
+
stored_cached_weight = *l1_topk_weights_buffer
|
| 968 |
+
.get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx)
|
| 969 |
+
.get_base_ptr<float>();
|
| 970 |
+
}
|
| 971 |
+
|
| 972 |
+
// Load weights from register cache
|
| 973 |
+
const float2 weights = {
|
| 974 |
+
ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 0),
|
| 975 |
+
ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 1)
|
| 976 |
+
};
|
| 977 |
+
|
| 978 |
+
// Load from TMEM
|
| 979 |
+
uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M;
|
| 980 |
+
uint32_t values[ATOM_M];
|
| 981 |
+
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr,
|
| 982 |
+
values[0], values[1], values[2], values[3]);
|
| 983 |
+
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000,
|
| 984 |
+
values[4], values[5], values[6], values[7]);
|
| 985 |
+
cutlass::arch::fence_view_async_tmem_load();
|
| 986 |
+
|
| 987 |
+
// Signal tensor memory consumed on the last atom
|
| 988 |
+
if (j == WG_BLOCK_M / ATOM_M - 1) {
|
| 989 |
+
ptx::tcgen05_before_thread_sync();
|
| 990 |
+
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
| 991 |
+
}
|
| 992 |
+
|
| 993 |
+
// Apply SwiGLU: silu(gate) * up
|
| 994 |
+
// Gate/up pairs: (0, 2), (1, 3), (4, 6), (5, 7)
|
| 995 |
+
auto fp32_values = reinterpret_cast<float*>(values);
|
| 996 |
+
#pragma unroll
|
| 997 |
+
for (uint32_t k = 0; k < 2; ++ k) {
|
| 998 |
+
auto bf16_gate = __float22bfloat162_rn(make_float2(fp32_values[k * 4], fp32_values[k * 4 + 1]));
|
| 999 |
+
auto bf16_up = __float22bfloat162_rn(make_float2(fp32_values[k * 4 + 2], fp32_values[k * 4 + 3]));
|
| 1000 |
+
|
| 1001 |
+
// Clamp
|
| 1002 |
+
if constexpr (kActivationClamp != cute::numeric_limits<float>::infinity()) {
|
| 1003 |
+
bf16_gate = __hmin2(bf16_gate, {kActivationClamp, kActivationClamp});
|
| 1004 |
+
bf16_up = __hmax2(bf16_up, {-kActivationClamp, -kActivationClamp});
|
| 1005 |
+
bf16_up = __hmin2(bf16_up, {kActivationClamp, kActivationClamp});
|
| 1006 |
+
}
|
| 1007 |
+
|
| 1008 |
+
// SwiGLU
|
| 1009 |
+
auto gate = __bfloat1622float2(bf16_gate);
|
| 1010 |
+
auto neg_gate_exp = make_float2(
|
| 1011 |
+
kFastMath ? __expf(-gate.x) : expf(-gate.x),
|
| 1012 |
+
kFastMath ? __expf(-gate.y) : expf(-gate.y));
|
| 1013 |
+
const auto denom = __fadd2_rn({1.0f, 1.0f}, neg_gate_exp);
|
| 1014 |
+
if constexpr (kFastMath) {
|
| 1015 |
+
gate = __fmul2_rn(gate, {math::fast_rcp(denom.x), math::fast_rcp(denom.y)});
|
| 1016 |
+
} else {
|
| 1017 |
+
gate = {gate.x / denom.x, gate.y / denom.y};
|
| 1018 |
+
}
|
| 1019 |
+
const auto up = __bfloat1622float2(bf16_up);
|
| 1020 |
+
swiglu_values[i * 2 + k] = __fmul2_rn(__fmul2_rn(gate, up), weights);
|
| 1021 |
+
}
|
| 1022 |
+
|
| 1023 |
+
// Amax reduction
|
| 1024 |
+
amax_values[i].x = math::warp_reduce<4, true>(
|
| 1025 |
+
cute::max(cute::abs(swiglu_values[i * 2 + 0].x), cute::abs(swiglu_values[i * 2 + 1].x)),
|
| 1026 |
+
math::ReduceMax<float>());
|
| 1027 |
+
amax_values[i].y = math::warp_reduce<4, true>(
|
| 1028 |
+
cute::max(cute::abs(swiglu_values[i * 2 + 0].y), cute::abs(swiglu_values[i * 2 + 1].y)),
|
| 1029 |
+
math::ReduceMax<float>());
|
| 1030 |
+
if (lane_idx < 4)
|
| 1031 |
+
smem_amax_reduction[epilogue_warp_idx * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx] = amax_values[i];
|
| 1032 |
+
__syncwarp();
|
| 1033 |
+
}
|
| 1034 |
+
|
| 1035 |
+
// Wait shared memory release from previous TMA store
|
| 1036 |
+
// And fence `smem_amax_reduction`
|
| 1037 |
+
const uint32_t tma_stage_idx = s % kNumTMAStoreStages;
|
| 1038 |
+
ptx::tma_store_wait<kNumTMAStoreStages - 1>();
|
| 1039 |
+
ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
|
| 1040 |
+
|
| 1041 |
+
// Cast to FP8 E4M3 and store into shared memory
|
| 1042 |
+
#pragma unroll
|
| 1043 |
+
for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) {
|
| 1044 |
+
// Reduce amax
|
| 1045 |
+
const float2 wp_amax =
|
| 1046 |
+
smem_amax_reduction[(epilogue_warp_idx ^ 1) * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx % 4];
|
| 1047 |
+
amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x);
|
| 1048 |
+
amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y);
|
| 1049 |
+
|
| 1050 |
+
// Calculate SF
|
| 1051 |
+
float2 sf, sf_inv;
|
| 1052 |
+
math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv);
|
| 1053 |
+
|
| 1054 |
+
// Cast
|
| 1055 |
+
const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv);
|
| 1056 |
+
const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv);
|
| 1057 |
+
const auto fp8x4_values = __nv_fp8x4_e4m3(make_float4(upper.x, upper.y, lower.x, lower.y));
|
| 1058 |
+
|
| 1059 |
+
// STSM
|
| 1060 |
+
uint32_t row = lane_idx;
|
| 1061 |
+
uint32_t col = warp_idx_in_wg;
|
| 1062 |
+
const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N
|
| 1063 |
+
+ i * ATOM_M * L1_OUT_BLOCK_N
|
| 1064 |
+
+ row * L1_OUT_BLOCK_N
|
| 1065 |
+
+ (col ^ (row / 2)) * kNumBankGroupBytes;
|
| 1066 |
+
ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr);
|
| 1067 |
+
|
| 1068 |
+
// Store SF to `l2_sf_buffer` as UE8M0 (MN-major layout)
|
| 1069 |
+
// Only one warp per pair writes (both hold the same SF after cross-warp reduce)
|
| 1070 |
+
// Each lane < 4 holds SF for 2 rows (sf.x and sf.y)
|
| 1071 |
+
if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) {
|
| 1072 |
+
const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2;
|
| 1073 |
+
const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4;
|
| 1074 |
+
const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t);
|
| 1075 |
+
const auto sf_base_ptr = l2_sf_buffer.get_base_ptr<uint8_t>();
|
| 1076 |
+
// NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4
|
| 1077 |
+
// NOTES: originally there was:
|
| 1078 |
+
// - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2`
|
| 1079 |
+
// - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)`
|
| 1080 |
+
// We find out that
|
| 1081 |
+
// 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside
|
| 1082 |
+
// 2. `lane_idx * 2` controls the lowest 3 bits of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside
|
| 1083 |
+
// This reduces the number of computation instructions.
|
| 1084 |
+
const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M;
|
| 1085 |
+
__builtin_assume(token_base_idx < BLOCK_M);
|
| 1086 |
+
const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M
|
| 1087 |
+
+ m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4;
|
| 1088 |
+
const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx;
|
| 1089 |
+
sf_base_ptr[sf_addr] =
|
| 1090 |
+
(*reinterpret_cast<const uint32_t*>(&sf.x) >> 23);
|
| 1091 |
+
sf_base_ptr[sf_addr + 4 * static_cast<uint32_t>(sizeof(uint32_t))] =
|
| 1092 |
+
(*reinterpret_cast<const uint32_t*>(&sf.y) >> 23);
|
| 1093 |
+
}
|
| 1094 |
+
__syncwarp();
|
| 1095 |
+
}
|
| 1096 |
+
ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
|
| 1097 |
+
|
| 1098 |
+
// Issue TMA store after all atoms in this store block
|
| 1099 |
+
if (warp_idx_in_wg == 0 and cute::elect_one_sync()) {
|
| 1100 |
+
uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N;
|
| 1101 |
+
cute::tma_store_fence();
|
| 1102 |
+
cute::SM90_TMA_STORE_2D::copy(
|
| 1103 |
+
&tensor_map_l1_output,
|
| 1104 |
+
smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N,
|
| 1105 |
+
out_n_idx,
|
| 1106 |
+
m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M);
|
| 1107 |
+
cute::tma_store_arrive();
|
| 1108 |
+
}
|
| 1109 |
+
__syncwarp();
|
| 1110 |
+
}
|
| 1111 |
+
|
| 1112 |
+
// Notify L2
|
| 1113 |
+
// TODO: less epilogue sync scope
|
| 1114 |
+
ptx::tma_store_wait<0>();
|
| 1115 |
+
ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx);
|
| 1116 |
+
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
|
| 1117 |
+
DG_STATIC_ASSERT(L2_SHAPE_K <= 64 * L1_OUT_BLOCK_N, "L2 shape K is too large");
|
| 1118 |
+
ptx::red_or_rel_gpu(
|
| 1119 |
+
workspace.get_l2_arrival_mask_ptr(pool_block_idx),
|
| 1120 |
+
1ull << n_block_idx
|
| 1121 |
+
);
|
| 1122 |
+
}
|
| 1123 |
+
__syncwarp();
|
| 1124 |
+
} else {
|
| 1125 |
+
DG_STATIC_ASSERT(STORE_BLOCK_M % 8 == 0, "Invalid store M");
|
| 1126 |
+
constexpr uint32_t kNumRowsPerWarp = STORE_BLOCK_M / 8;
|
| 1127 |
+
|
| 1128 |
+
// L2 BF16 epilogue: write GEMM output to remote combine buffer via NVLink
|
| 1129 |
+
#pragma unroll
|
| 1130 |
+
for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) {
|
| 1131 |
+
// Early break if the entire store block is beyond the valid token range
|
| 1132 |
+
// TODO: check performance
|
| 1133 |
+
if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) {
|
| 1134 |
+
ptx::tcgen05_before_thread_sync();
|
| 1135 |
+
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
| 1136 |
+
break;
|
| 1137 |
+
}
|
| 1138 |
+
|
| 1139 |
+
#pragma unroll
|
| 1140 |
+
for (uint32_t i = 0; i < STORE_BLOCK_M / ATOM_M; ++ i) {
|
| 1141 |
+
// Load from TMEM using .16x256b shape to satisfy STSM layout requirements
|
| 1142 |
+
// Start from lane index 0 and 16
|
| 1143 |
+
uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M;
|
| 1144 |
+
uint32_t values[ATOM_M];
|
| 1145 |
+
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr,
|
| 1146 |
+
values[0], values[1], values[2], values[3]);
|
| 1147 |
+
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000,
|
| 1148 |
+
values[4], values[5], values[6], values[7]);
|
| 1149 |
+
cutlass::arch::fence_view_async_tmem_load();
|
| 1150 |
+
|
| 1151 |
+
// Wait shared memory release from previous NVLink store
|
| 1152 |
+
// NOTES: skip for the first store block since the prior full barrier already ensures completion
|
| 1153 |
+
if (i == 0 and s > 0)
|
| 1154 |
+
ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
|
| 1155 |
+
|
| 1156 |
+
// Signal tensor memory consumed
|
| 1157 |
+
if (s == WG_BLOCK_M / STORE_BLOCK_M - 1 and i == STORE_BLOCK_M / ATOM_M - 1) {
|
| 1158 |
+
ptx::tcgen05_before_thread_sync();
|
| 1159 |
+
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
| 1160 |
+
}
|
| 1161 |
+
|
| 1162 |
+
// Store into shared memory
|
| 1163 |
+
// NOTES: only use first 16 lanes for address
|
| 1164 |
+
// NOTES: 2 warps share a BF16 swizzle atom
|
| 1165 |
+
uint32_t row = lane_idx % 8;
|
| 1166 |
+
uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8;
|
| 1167 |
+
const auto smem_ptr = smem_cd_l2 +
|
| 1168 |
+
epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(nv_bfloat16)) +
|
| 1169 |
+
(warp_idx_in_wg / 2) * STORE_BLOCK_M * kSwizzleCDMode +
|
| 1170 |
+
i * ATOM_M * kSwizzleCDMode +
|
| 1171 |
+
row * (kNumBankGroupBytes * 8) +
|
| 1172 |
+
(col ^ row) * kNumBankGroupBytes;
|
| 1173 |
+
ptx::SM90_U32x4_STSM_T<uint32_t>::copy(
|
| 1174 |
+
math::cast_into_bf16_and_pack(values[0], values[1]),
|
| 1175 |
+
math::cast_into_bf16_and_pack(values[2], values[3]),
|
| 1176 |
+
math::cast_into_bf16_and_pack(values[4], values[5]),
|
| 1177 |
+
math::cast_into_bf16_and_pack(values[6], values[7]),
|
| 1178 |
+
smem_ptr
|
| 1179 |
+
);
|
| 1180 |
+
}
|
| 1181 |
+
|
| 1182 |
+
// Wait shared memory ready
|
| 1183 |
+
ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
|
| 1184 |
+
|
| 1185 |
+
// Write into remote buffers
|
| 1186 |
+
// One warp per row, now the layout is different from shared memory storing
|
| 1187 |
+
const uint32_t row_in_atom = (warp_idx_in_wg * 2 + lane_idx / 16) % ATOM_M;
|
| 1188 |
+
const uint32_t bank_group_idx = lane_idx % 8;
|
| 1189 |
+
|
| 1190 |
+
#pragma unroll
|
| 1191 |
+
for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) {
|
| 1192 |
+
const uint32_t row_in_store = j * 8 + warp_idx_in_wg * 2 + lane_idx / 16;
|
| 1193 |
+
const uint32_t m_idx_in_block = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + row_in_store;
|
| 1194 |
+
|
| 1195 |
+
// Skip padding rows beyond the actual token count for this expert
|
| 1196 |
+
if (m_idx_in_block >= valid_m)
|
| 1197 |
+
break;
|
| 1198 |
+
|
| 1199 |
+
const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block);
|
| 1200 |
+
const uint32_t dst_rank_idx = src_metadata.rank_idx;
|
| 1201 |
+
const uint32_t dst_token_idx = src_metadata.token_idx;
|
| 1202 |
+
const uint32_t dst_topk_idx = src_metadata.topk_idx;
|
| 1203 |
+
|
| 1204 |
+
// Read from shared memory
|
| 1205 |
+
const auto smem_ptr = smem_cd_l2 +
|
| 1206 |
+
epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(nv_bfloat16)) +
|
| 1207 |
+
(lane_idx % 16 / 8) * STORE_BLOCK_M * kSwizzleCDMode +
|
| 1208 |
+
row_in_store * kSwizzleCDMode +
|
| 1209 |
+
(bank_group_idx ^ row_in_atom) * kNumBankGroupBytes;
|
| 1210 |
+
const auto packed = ptx::ld_shared(reinterpret_cast<float4*>(smem_ptr));
|
| 1211 |
+
|
| 1212 |
+
// Write into remote
|
| 1213 |
+
const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx)
|
| 1214 |
+
.get_data_buffer(dst_token_idx);
|
| 1215 |
+
const auto dst_ptr = math::advance_ptr<float4>(
|
| 1216 |
+
dst_token.get_base_ptr(),
|
| 1217 |
+
n_idx * static_cast<uint32_t>(sizeof(nv_bfloat16)) + (lane_idx % 16) * static_cast<uint32_t>(sizeof(float4)));
|
| 1218 |
+
*sym_buffer.map(dst_ptr, dst_rank_idx) = packed;
|
| 1219 |
+
}
|
| 1220 |
+
}
|
| 1221 |
+
|
| 1222 |
+
// Ensure it is safe for the next epilogue to use shared memory
|
| 1223 |
+
ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx);
|
| 1224 |
+
}
|
| 1225 |
+
});
|
| 1226 |
+
|
| 1227 |
+
// Deallocate tensor memory
|
| 1228 |
+
// NOTES: must be called by the same logical warp ID on both CTAs
|
| 1229 |
+
if (epilogue_warp_idx == 0)
|
| 1230 |
+
Allocator().free(0, kNumTmemCols);
|
| 1231 |
+
|
| 1232 |
+
// NVLink barrier (grid sync + cross-rank signal + grid sync): ~4 us
|
| 1233 |
+
comm::nvlink_barrier<kNumRanks, kNumSMs, kNumEpilogueThreads,
|
| 1234 |
+
kEpilogueGridSyncIndex, kBeforeCombineReduceBarrierTag>(
|
| 1235 |
+
workspace, sym_buffer, sm_idx, epilogue_thread_idx,
|
| 1236 |
+
[&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); }
|
| 1237 |
+
);
|
| 1238 |
+
|
| 1239 |
+
// Barrier with dispatch warps, so that they can clean up the workspace
|
| 1240 |
+
ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
|
| 1241 |
+
|
| 1242 |
+
// Combine: reduce top-k results and write back
|
| 1243 |
+
// NOTES: reuse shared memory from start up to the barriers
|
| 1244 |
+
// 1 token, 1 topk latency: ~3 us
|
| 1245 |
+
constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16);
|
| 1246 |
+
constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162);
|
| 1247 |
+
|
| 1248 |
+
// 3 chunk slots are needed: 2 load stages and 1 store
|
| 1249 |
+
constexpr uint32_t kNumChunkSlots = 3;
|
| 1250 |
+
constexpr uint32_t kNumMaxRegistersForBuffer = 128;
|
| 1251 |
+
|
| 1252 |
+
// NOTES: either 1 or 2 chunks for simplicity
|
| 1253 |
+
// NOTES: restricted by both the shared memory and the register budgets
|
| 1254 |
+
constexpr uint32_t kNumChunks =
|
| 1255 |
+
kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE and kHidden <= 32 * kNumMaxRegistersForBuffer ? 1 : 2;
|
| 1256 |
+
constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks;
|
| 1257 |
+
constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4);
|
| 1258 |
+
constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32;
|
| 1259 |
+
DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks");
|
| 1260 |
+
DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large");
|
| 1261 |
+
DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)");
|
| 1262 |
+
DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes");
|
| 1263 |
+
DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements (one per lane)");
|
| 1264 |
+
DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp");
|
| 1265 |
+
|
| 1266 |
+
// Verify combined shared memory budget at runtime
|
| 1267 |
+
DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast<uint32_t>(
|
| 1268 |
+
reinterpret_cast<uint8_t*>(barrier_start_ptr) - smem_buffer));
|
| 1269 |
+
|
| 1270 |
+
// Per-warp buffer: 2 stage load buffers + 1 store buffer
|
| 1271 |
+
const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) {
|
| 1272 |
+
return math::advance_ptr<uint4>(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes);
|
| 1273 |
+
});
|
| 1274 |
+
const auto combine_store_buffer = math::advance_ptr<uint4>(smem_buffer, (epilogue_warp_idx + kNumEpilogueWarps * 2) * kNumChunkBytes);
|
| 1275 |
+
|
| 1276 |
+
// Per-warp barriers
|
| 1277 |
+
auto combine_load_barriers = utils::PatternVisitor([&](const uint32_t& i) {
|
| 1278 |
+
return combine_barriers[i + epilogue_warp_idx * 2];
|
| 1279 |
+
});
|
| 1280 |
+
|
| 1281 |
+
// Iterate over all tokens
|
| 1282 |
+
uint32_t combine_phase = 0;
|
| 1283 |
+
uint32_t load_stage_idx = 0;
|
| 1284 |
+
for (uint32_t token_idx = sm_idx * kNumEpilogueWarps + epilogue_warp_idx;
|
| 1285 |
+
token_idx < num_tokens;
|
| 1286 |
+
token_idx += kNumSMs * kNumEpilogueWarps) {
|
| 1287 |
+
// Read top-k slot indices: each lane reads one slot, then broadcast via exchange
|
| 1288 |
+
DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk");
|
| 1289 |
+
const int stored_topk_slot_idx = lane_idx < kNumTopk ?
|
| 1290 |
+
static_cast<int>(__ldg(input_topk_idx_buffer.get_base_ptr<int64_t>() + token_idx * kNumTopk + lane_idx)) : -1;
|
| 1291 |
+
const uint32_t total_mask = __ballot_sync(0xffffffff, stored_topk_slot_idx >= 0);
|
| 1292 |
+
|
| 1293 |
+
// Iterate all chunks
|
| 1294 |
+
for (uint32_t chunk = 0; chunk < kNumChunks; ++ chunk) {
|
| 1295 |
+
const uint32_t chunk_byte_offset = chunk * kNumChunkBytes;
|
| 1296 |
+
|
| 1297 |
+
// Move mask and load
|
| 1298 |
+
uint32_t mask = total_mask;
|
| 1299 |
+
const auto move_mask_and_load = [&](const uint32_t& i) {
|
| 1300 |
+
if (mask) {
|
| 1301 |
+
// Move
|
| 1302 |
+
const uint32_t slot_idx = __ffs(mask) - 1;
|
| 1303 |
+
mask ^= 1 << slot_idx;
|
| 1304 |
+
|
| 1305 |
+
// Load
|
| 1306 |
+
if (cute::elect_one_sync()) {
|
| 1307 |
+
const auto src_ptr = math::advance_ptr<uint8_t>(
|
| 1308 |
+
combine_token_buffer.get_rank_buffer(slot_idx)
|
| 1309 |
+
.get_data_buffer(token_idx).get_base_ptr(),
|
| 1310 |
+
chunk_byte_offset);
|
| 1311 |
+
ptx::tma_load_1d(combine_load_buffer[i], src_ptr, combine_load_barriers[i], kNumChunkBytes);
|
| 1312 |
+
ptx::mbarrier_arrive_and_set_tx(combine_load_barriers[i], kNumChunkBytes);
|
| 1313 |
+
}
|
| 1314 |
+
__syncwarp();
|
| 1315 |
+
return true;
|
| 1316 |
+
}
|
| 1317 |
+
return false;
|
| 1318 |
+
};
|
| 1319 |
+
|
| 1320 |
+
// Load the first selection
|
| 1321 |
+
bool do_reduce = move_mask_and_load(load_stage_idx);
|
| 1322 |
+
|
| 1323 |
+
// Accumulate all top-k contributions for this chunk in float registers
|
| 1324 |
+
float2 reduced[kNumUint4PerLane * kNumElemsPerUint4] = {};
|
| 1325 |
+
while (do_reduce) {
|
| 1326 |
+
// Prefetch next top-k into the buffer while current is being accumulated
|
| 1327 |
+
do_reduce = move_mask_and_load(load_stage_idx ^ 1);
|
| 1328 |
+
|
| 1329 |
+
// Accumulate
|
| 1330 |
+
combine_load_barriers[load_stage_idx]->wait(combine_phase);
|
| 1331 |
+
#pragma unroll
|
| 1332 |
+
for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) {
|
| 1333 |
+
const auto uint4_values = combine_load_buffer[load_stage_idx][j * 32 + lane_idx];
|
| 1334 |
+
const auto bf16_values = reinterpret_cast<const nv_bfloat162*>(&uint4_values);
|
| 1335 |
+
#pragma unroll
|
| 1336 |
+
for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l)
|
| 1337 |
+
ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]);
|
| 1338 |
+
}
|
| 1339 |
+
combine_phase ^= load_stage_idx;
|
| 1340 |
+
load_stage_idx ^= 1;
|
| 1341 |
+
}
|
| 1342 |
+
|
| 1343 |
+
// Cast
|
| 1344 |
+
#pragma unroll
|
| 1345 |
+
for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) {
|
| 1346 |
+
uint4 casted;
|
| 1347 |
+
auto casted_bf16 = reinterpret_cast<nv_bfloat162*>(&casted);
|
| 1348 |
+
#pragma unroll
|
| 1349 |
+
for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l)
|
| 1350 |
+
casted_bf16[l] = __float22bfloat162_rn(reduced[j * kNumElemsPerUint4 + l]);
|
| 1351 |
+
|
| 1352 |
+
// Wait for shared memory release, then write
|
| 1353 |
+
if (j == 0) {
|
| 1354 |
+
ptx::tma_store_wait<0>();
|
| 1355 |
+
__syncwarp();
|
| 1356 |
+
}
|
| 1357 |
+
ptx::st_shared(combine_store_buffer + j * 32 + lane_idx,
|
| 1358 |
+
casted.x, casted.y, casted.z, casted.w);
|
| 1359 |
+
}
|
| 1360 |
+
__syncwarp();
|
| 1361 |
+
|
| 1362 |
+
// TMA store the token chunk
|
| 1363 |
+
if (cute::elect_one_sync()) {
|
| 1364 |
+
cute::tma_store_fence();
|
| 1365 |
+
ptx::tma_store_1d(
|
| 1366 |
+
math::advance_ptr(y, static_cast<uint64_t>(token_idx) * kNumHiddenBytes + chunk_byte_offset),
|
| 1367 |
+
combine_store_buffer, kNumChunkBytes);
|
| 1368 |
+
cute::tma_store_arrive();
|
| 1369 |
+
}
|
| 1370 |
+
__syncwarp();
|
| 1371 |
+
}
|
| 1372 |
+
}
|
| 1373 |
+
}
|
| 1374 |
+
#else
|
| 1375 |
+
if (blockIdx.x == 0 and threadIdx.x == 0)
|
| 1376 |
+
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
|
| 1377 |
+
#endif
|
| 1378 |
+
}
|
| 1379 |
+
|
| 1380 |
+
} // namespace deep_gemm
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh
CHANGED
|
@@ -155,6 +155,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
|
| 155 |
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
|
| 156 |
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
| 157 |
|
|
|
|
|
|
|
|
|
|
| 158 |
// Initialize barriers
|
| 159 |
if (warp_idx == 1 and cute::elect_one_sync()) {
|
| 160 |
#pragma unroll
|
|
@@ -546,12 +549,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
|
| 546 |
}
|
| 547 |
}
|
| 548 |
}
|
| 549 |
-
|
| 550 |
-
// Deallocate tensor memory by the last UMMA store warp
|
| 551 |
-
// NOTES: warp 0 is waiting TMA store
|
| 552 |
-
if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
|
| 553 |
-
Allocator().free(0, kNumTmemCols);
|
| 554 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 555 |
#else
|
| 556 |
if (blockIdx.x == 0 and threadIdx.x == 0)
|
| 557 |
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
|
|
|
|
| 155 |
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
|
| 156 |
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
| 157 |
|
| 158 |
+
if (kNumMulticast > 1)
|
| 159 |
+
cute::cluster_sync();
|
| 160 |
+
|
| 161 |
// Initialize barriers
|
| 162 |
if (warp_idx == 1 and cute::elect_one_sync()) {
|
| 163 |
#pragma unroll
|
|
|
|
| 549 |
}
|
| 550 |
}
|
| 551 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
}
|
| 553 |
+
|
| 554 |
+
// Deallocate tensor memory
|
| 555 |
+
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
| 556 |
+
if (warp_idx == 0)
|
| 557 |
+
Allocator().free(0, kNumTmemCols);
|
| 558 |
+
|
| 559 |
#else
|
| 560 |
if (blockIdx.x == 0 and threadIdx.x == 0)
|
| 561 |
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh
CHANGED
|
@@ -6,27 +6,31 @@
|
|
| 6 |
#include <cute/arch/cluster_sm90.hpp>
|
| 7 |
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
#include <deep_gemm/common/utils.cuh>
|
| 10 |
-
#include <deep_gemm/
|
| 11 |
-
#include <deep_gemm/
|
|
|
|
|
|
|
| 12 |
|
| 13 |
namespace deep_gemm {
|
| 14 |
|
| 15 |
-
using namespace deep_gemm::sm90;
|
| 16 |
-
using namespace deep_gemm::sm100;
|
| 17 |
-
|
| 18 |
template <uint32_t kNumHeads, uint32_t kHeadDim,
|
| 19 |
bool kIsCompressedLogits,
|
| 20 |
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
|
| 21 |
uint32_t kNumQStages, uint32_t kNumKVStages,
|
|
|
|
| 22 |
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
|
|
|
| 23 |
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
|
| 24 |
-
|
| 25 |
void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
| 26 |
-
const uint32_t max_seqlen_k, const
|
| 27 |
uint32_t* cu_seq_len_k_start,
|
| 28 |
uint32_t* cu_seq_len_k_end,
|
| 29 |
-
|
| 30 |
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 31 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 32 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
|
@@ -35,26 +39,26 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 35 |
// Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64`
|
| 36 |
// For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
|
| 37 |
// Q should be loaded only once per block
|
| 38 |
-
const auto
|
| 39 |
|
| 40 |
// Types
|
| 41 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 42 |
|
| 43 |
-
//
|
| 44 |
-
const auto
|
| 45 |
-
const auto
|
| 46 |
-
const auto
|
| 47 |
-
const auto
|
|
|
|
| 48 |
|
| 49 |
// Prefetch TMA descriptors
|
| 50 |
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
| 51 |
-
if (warp_idx ==
|
| 52 |
cute::prefetch_tma_descriptor(&tensor_map_q);
|
| 53 |
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
| 54 |
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
|
| 55 |
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
| 56 |
}
|
| 57 |
-
__syncwarp();
|
| 58 |
|
| 59 |
// Shared memory configs
|
| 60 |
// NOTES: weight may be unaligned
|
|
@@ -62,7 +66,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 62 |
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
|
| 63 |
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 64 |
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
|
| 65 |
-
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u);
|
| 66 |
|
| 67 |
// Align to 512 bytes for swizzle-64B
|
| 68 |
extern __shared__ __align__(512) uint8_t smem_buffer[];
|
|
@@ -75,19 +79,19 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 75 |
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
|
| 76 |
|
| 77 |
// Data on shared memory
|
| 78 |
-
auto smem_q = PatternVisitor([&](const uint32_t& i) {
|
| 79 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
|
| 80 |
SMEM_Q_SIZE_PER_STAGE * i);
|
| 81 |
});
|
| 82 |
-
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
| 83 |
return reinterpret_cast<float*>(smem_buffer +
|
| 84 |
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 85 |
});
|
| 86 |
-
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
|
| 87 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
|
| 88 |
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
|
| 89 |
});
|
| 90 |
-
auto smem_kv_scales =
|
| 91 |
return reinterpret_cast<float*>(smem_buffer +
|
| 92 |
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages +
|
| 93 |
SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
|
@@ -95,76 +99,77 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 95 |
|
| 96 |
// TMA barriers
|
| 97 |
auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
| 98 |
-
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
| 99 |
-
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
|
| 100 |
-
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
|
| 101 |
-
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
|
| 102 |
-
auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); });
|
| 103 |
-
auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); });
|
| 104 |
|
| 105 |
// Tensor memory allocation
|
| 106 |
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2);
|
| 107 |
|
| 108 |
// Initialize barriers
|
| 109 |
DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads");
|
| 110 |
-
|
| 111 |
-
const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1));
|
| 112 |
-
if (is_tma_load_warp and cute::elect_one_sync()) {
|
| 113 |
#pragma unroll
|
| 114 |
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
| 115 |
full_q_barriers[i]->init(1);
|
| 116 |
-
empty_q_barriers[i]->init(kNumMathThreads);
|
| 117 |
}
|
| 118 |
#pragma unroll
|
| 119 |
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
| 120 |
full_kv_barriers[i]->init(1);
|
| 121 |
empty_kv_barriers[i]->init(kNumMathThreads);
|
| 122 |
}
|
| 123 |
-
#pragma unroll
|
| 124 |
-
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
| 125 |
-
full_umma_barriers[i]->init(1);
|
| 126 |
-
empty_umma_barriers[i]->init(128);
|
| 127 |
-
}
|
| 128 |
-
|
| 129 |
-
// Make initialized barrier visible in async proxy
|
| 130 |
cutlass::arch::fence_barrier_init();
|
| 131 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
// Allocate tensor memory
|
| 133 |
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
| 134 |
}
|
| 135 |
__syncthreads();
|
| 136 |
|
| 137 |
// Register reconfigurations
|
| 138 |
-
constexpr uint32_t kNumSpecializedRegisters =
|
| 139 |
-
constexpr uint32_t kNumMathRegisters =
|
| 140 |
|
| 141 |
// Block scheduler
|
| 142 |
-
uint32_t block_q_idx =
|
| 143 |
-
const auto
|
| 144 |
-
return {block_q_idx +
|
| 145 |
};
|
| 146 |
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
|
| 147 |
-
const auto
|
| 148 |
uint32_t start = cute::numeric_limits<uint32_t>::max();
|
| 149 |
uint32_t end = cute::numeric_limits<uint32_t>::min();
|
| 150 |
|
| 151 |
#pragma unroll
|
| 152 |
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 153 |
-
const auto
|
| 154 |
-
seq_k_start[i] =
|
| 155 |
-
seq_k_end[i] =
|
| 156 |
start = min(start, min(seq_k_start[i], seq_len_kv));
|
| 157 |
end = max(end, min(seq_k_end[i], seq_len_kv));
|
| 158 |
}
|
|
|
|
| 159 |
start = start / 4 * 4;
|
| 160 |
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
|
| 161 |
((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
|
| 162 |
-
start, ceil_div(end - start, BLOCK_KV)}; // Task info
|
| 163 |
};
|
| 164 |
|
| 165 |
// KV pipeline
|
| 166 |
uint32_t num_total_kv_blocks = 0;
|
| 167 |
-
const auto
|
| 168 |
return {
|
| 169 |
(num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
|
| 170 |
((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
|
|
@@ -177,13 +182,16 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 177 |
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
|
| 178 |
constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
|
| 179 |
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
| 181 |
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 182 |
|
| 183 |
// Prefetch
|
| 184 |
-
const auto
|
| 185 |
-
|
| 186 |
-
|
| 187 |
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 188 |
};
|
| 189 |
if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
|
|
@@ -209,10 +217,10 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 209 |
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
| 210 |
|
| 211 |
// Issue TMA KV
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
| 217 |
}
|
| 218 |
num_total_kv_blocks += num_kv_blocks;
|
|
@@ -221,11 +229,11 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 221 |
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
| 222 |
}
|
| 223 |
}
|
| 224 |
-
} else if (
|
| 225 |
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 226 |
|
| 227 |
// Require full allocation
|
| 228 |
-
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
| 229 |
|
| 230 |
// Make UMMA desc
|
| 231 |
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
|
@@ -252,12 +260,12 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 252 |
#pragma unroll
|
| 253 |
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
| 254 |
empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1);
|
| 255 |
-
tcgen05_after_thread_sync();
|
| 256 |
#pragma unroll
|
| 257 |
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
|
| 258 |
-
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
| 259 |
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
|
| 260 |
-
auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
| 261 |
smem_q[q_stage_idx], 0, k * UMMA_K);
|
| 262 |
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
|
| 263 |
}
|
|
@@ -266,23 +274,37 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 266 |
}
|
| 267 |
num_total_kv_blocks += num_kv_blocks;
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
// Jump to the next block
|
| 270 |
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
| 271 |
}
|
| 272 |
-
} else if (warp_idx
|
| 273 |
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 274 |
-
} else if (warp_idx <
|
| 275 |
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 276 |
|
| 277 |
// Offsets
|
| 278 |
-
const auto
|
| 279 |
-
const auto
|
| 280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
-
//
|
| 283 |
-
|
| 284 |
-
float weights[BLOCK_Q][kNumWeightsInReg];
|
| 285 |
-
DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
|
| 286 |
|
| 287 |
while (block_q_idx < num_q_blocks) {
|
| 288 |
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
|
@@ -293,9 +315,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 293 |
// Read weights
|
| 294 |
#pragma unroll
|
| 295 |
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
}
|
| 300 |
|
| 301 |
// Compute over KV blocks
|
|
@@ -307,82 +329,59 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 307 |
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 308 |
|
| 309 |
// Read per-KV scales
|
| 310 |
-
float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] +
|
| 311 |
|
| 312 |
// Wait UMMA arrival
|
| 313 |
full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
|
| 314 |
-
tcgen05_after_thread_sync();
|
| 315 |
|
| 316 |
// Release KV empty
|
| 317 |
empty_kv_barriers[kv_stage_idx]->arrive();
|
| 318 |
|
| 319 |
// Reduce over the head dim and store
|
| 320 |
-
const auto
|
| 321 |
-
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
|
| 322 |
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
| 323 |
|
| 324 |
-
constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q;
|
| 325 |
-
DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems");
|
| 326 |
-
uint32_t shifted_accum[kNumLDTMElems];
|
| 327 |
-
auto tmem_load = [&](auto... Is) {
|
| 328 |
-
if constexpr (kNumLDTMElems == 32) {
|
| 329 |
-
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
|
| 330 |
-
} else if constexpr (kNumLDTMElems == 64) {
|
| 331 |
-
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
|
| 332 |
-
} else if constexpr (kNumLDTMElems == 128) {
|
| 333 |
-
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
|
| 334 |
-
}
|
| 335 |
-
};
|
| 336 |
-
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
|
| 337 |
-
cutlass::arch::fence_view_async_tmem_load();
|
| 338 |
-
|
| 339 |
-
tcgen05_before_thread_sync();
|
| 340 |
-
empty_umma_barriers[warpgroup_idx]->arrive();
|
| 341 |
-
|
| 342 |
#pragma unroll
|
| 343 |
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 344 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
|
|
|
| 346 |
auto sum_0 = make_float2(0, 0);
|
| 347 |
auto sum_1 = make_float2(0, 0);
|
| 348 |
|
| 349 |
-
const auto
|
| 350 |
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
| 351 |
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
| 352 |
return __ffma2_rn(a, b, sum);
|
| 353 |
};
|
| 354 |
|
| 355 |
#pragma unroll
|
| 356 |
-
for (
|
| 357 |
-
sum_0 =
|
| 358 |
-
sum_1 =
|
| 359 |
-
}
|
| 360 |
-
|
| 361 |
-
const auto& transform_smem = [&](const uint32_t& j, const float2& sum) {
|
| 362 |
-
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
| 363 |
-
auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
|
| 364 |
-
ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
|
| 365 |
-
return __ffma2_rn(a, b, sum);
|
| 366 |
-
};
|
| 367 |
-
|
| 368 |
-
#pragma unroll
|
| 369 |
-
for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
|
| 370 |
-
sum_0 = transform_smem(j, sum_0);
|
| 371 |
-
sum_1 = transform_smem(j + 2, sum_1);
|
| 372 |
}
|
| 373 |
|
| 374 |
auto sum = __fadd2_rn(sum_0, sum_1);
|
| 375 |
-
|
| 376 |
|
| 377 |
// Store into the global memory
|
| 378 |
-
|
| 379 |
-
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
|
| 380 |
if constexpr (kIsCompressedLogits) {
|
| 381 |
-
if (seq_k_start[i] <= kv_offset
|
| 382 |
-
logits[
|
| 383 |
} else {
|
| 384 |
-
logits[
|
| 385 |
}
|
|
|
|
| 386 |
}
|
| 387 |
}
|
| 388 |
num_total_kv_blocks += num_kv_blocks;
|
|
@@ -393,12 +392,12 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 393 |
// Jump to the next block
|
| 394 |
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
| 395 |
}
|
| 396 |
-
}
|
| 397 |
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
|
|
|
| 402 |
}
|
| 403 |
|
| 404 |
} // namespace deep_gemm
|
|
|
|
| 6 |
#include <cute/arch/cluster_sm90.hpp>
|
| 7 |
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
|
| 9 |
+
#include <deep_gemm/common/cute_tie.cuh>
|
| 10 |
+
#include <deep_gemm/common/math.cuh>
|
| 11 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 12 |
#include <deep_gemm/common/utils.cuh>
|
| 13 |
+
#include <deep_gemm/mma/sm100.cuh>
|
| 14 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 15 |
+
#include <deep_gemm/ptx/tcgen05.cuh>
|
| 16 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 17 |
|
| 18 |
namespace deep_gemm {
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
template <uint32_t kNumHeads, uint32_t kHeadDim,
|
| 21 |
bool kIsCompressedLogits,
|
| 22 |
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
|
| 23 |
uint32_t kNumQStages, uint32_t kNumKVStages,
|
| 24 |
+
uint32_t kNumSMs,
|
| 25 |
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
| 26 |
+
typename logits_dtype_t,
|
| 27 |
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
|
| 28 |
+
CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
|
| 29 |
void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
| 30 |
+
const uint32_t max_seqlen_k, const uint32_t stride_logits,
|
| 31 |
uint32_t* cu_seq_len_k_start,
|
| 32 |
uint32_t* cu_seq_len_k_end,
|
| 33 |
+
logits_dtype_t* logits,
|
| 34 |
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 35 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 36 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
|
|
|
| 39 |
// Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64`
|
| 40 |
// For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
|
| 41 |
// Q should be load only at once for a block
|
| 42 |
+
const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
|
| 43 |
|
| 44 |
// Types
|
| 45 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 46 |
|
| 47 |
+
// Utils
|
| 48 |
+
const auto sm_idx = blockIdx.x;
|
| 49 |
+
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 50 |
+
const auto warpgroup_idx = warp_idx / 4;
|
| 51 |
+
const auto lane_idx = ptx::get_lane_idx();
|
| 52 |
+
constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
|
| 53 |
|
| 54 |
// Prefetch TMA descriptors
|
| 55 |
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
| 56 |
+
if (warp_idx == kSpecWarpStart) {
|
| 57 |
cute::prefetch_tma_descriptor(&tensor_map_q);
|
| 58 |
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
| 59 |
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
|
| 60 |
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
| 61 |
}
|
|
|
|
| 62 |
|
| 63 |
// Shared memory configs
|
| 64 |
// NOTES: weight may be unaligned
|
|
|
|
| 66 |
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
|
| 67 |
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 68 |
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
|
| 69 |
+
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u);
|
| 70 |
|
| 71 |
// Align to 512 bytes for swizzle-64B
|
| 72 |
extern __shared__ __align__(512) uint8_t smem_buffer[];
|
|
|
|
| 79 |
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
|
| 80 |
|
| 81 |
// Data on shared memory
|
| 82 |
+
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
|
| 83 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
|
| 84 |
SMEM_Q_SIZE_PER_STAGE * i);
|
| 85 |
});
|
| 86 |
+
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
|
| 87 |
return reinterpret_cast<float*>(smem_buffer +
|
| 88 |
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 89 |
});
|
| 90 |
+
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
|
| 91 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
|
| 92 |
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
|
| 93 |
});
|
| 94 |
+
auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
|
| 95 |
return reinterpret_cast<float*>(smem_buffer +
|
| 96 |
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages +
|
| 97 |
SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
|
|
|
| 99 |
|
| 100 |
// TMA barriers
|
| 101 |
auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
| 102 |
+
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
| 103 |
+
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
|
| 104 |
+
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
|
| 105 |
+
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
|
| 106 |
+
auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); });
|
| 107 |
+
auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); });
|
| 108 |
|
| 109 |
// Tensor memory allocation
|
| 110 |
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2);
|
| 111 |
|
| 112 |
// Initialize barriers
|
| 113 |
DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads");
|
| 114 |
+
if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
|
|
|
|
|
|
|
| 115 |
#pragma unroll
|
| 116 |
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
| 117 |
full_q_barriers[i]->init(1);
|
| 118 |
+
empty_q_barriers[i]->init(kNumMathThreads + 32);
|
| 119 |
}
|
| 120 |
#pragma unroll
|
| 121 |
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
| 122 |
full_kv_barriers[i]->init(1);
|
| 123 |
empty_kv_barriers[i]->init(kNumMathThreads);
|
| 124 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
cutlass::arch::fence_barrier_init();
|
| 126 |
+
}
|
| 127 |
+
if (warp_idx == kSpecWarpStart + 1) {
|
| 128 |
+
if (cute::elect_one_sync()) {
|
| 129 |
+
#pragma unroll
|
| 130 |
+
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
| 131 |
+
full_umma_barriers[i]->init(1);
|
| 132 |
+
empty_umma_barriers[i]->init(128);
|
| 133 |
+
}
|
| 134 |
+
cutlass::arch::fence_barrier_init();
|
| 135 |
+
}
|
| 136 |
// Allocate tensor memory
|
| 137 |
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
| 138 |
}
|
| 139 |
__syncthreads();
|
| 140 |
|
| 141 |
// Register reconfigurations
|
| 142 |
+
constexpr uint32_t kNumSpecializedRegisters = 40;
|
| 143 |
+
constexpr uint32_t kNumMathRegisters = 232;
|
| 144 |
|
| 145 |
// Block scheduler
|
| 146 |
+
uint32_t block_q_idx = sm_idx, q_iter_idx = 0;
|
| 147 |
+
const auto get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
|
| 148 |
+
return {block_q_idx + kNumSMs, q_iter_idx + 1};
|
| 149 |
};
|
| 150 |
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
|
| 151 |
+
const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
|
| 152 |
uint32_t start = cute::numeric_limits<uint32_t>::max();
|
| 153 |
uint32_t end = cute::numeric_limits<uint32_t>::min();
|
| 154 |
|
| 155 |
#pragma unroll
|
| 156 |
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 157 |
+
const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
|
| 158 |
+
seq_k_start[i] = cu_seq_len_k_start[q_idx];
|
| 159 |
+
seq_k_end[i] = cu_seq_len_k_end[q_idx];
|
| 160 |
start = min(start, min(seq_k_start[i], seq_len_kv));
|
| 161 |
end = max(end, min(seq_k_end[i], seq_len_kv));
|
| 162 |
}
|
| 163 |
+
// TMA alignment requirements for SF KV
|
| 164 |
start = start / 4 * 4;
|
| 165 |
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
|
| 166 |
((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
|
| 167 |
+
start, math::ceil_div(end - start, BLOCK_KV)}; // Task info
|
| 168 |
};
|
| 169 |
|
| 170 |
// KV pipeline
|
| 171 |
uint32_t num_total_kv_blocks = 0;
|
| 172 |
+
const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
|
| 173 |
return {
|
| 174 |
(num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
|
| 175 |
((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
|
|
|
|
| 182 |
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
|
| 183 |
constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
|
| 184 |
|
| 185 |
+
// Wait for primary kernel completion
|
| 186 |
+
cudaGridDependencySynchronize();
|
| 187 |
+
|
| 188 |
+
if (warp_idx == kSpecWarpStart) {
|
| 189 |
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 190 |
|
| 191 |
// Prefetch
|
| 192 |
+
const auto issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
|
| 193 |
+
tma::copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
|
| 194 |
+
tma::copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
|
| 195 |
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 196 |
};
|
| 197 |
if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
|
|
|
|
| 217 |
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
| 218 |
|
| 219 |
// Issue TMA KV
|
| 220 |
+
tma::copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
|
| 221 |
+
smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
|
| 222 |
+
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
|
| 223 |
+
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
|
| 224 |
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
| 225 |
}
|
| 226 |
num_total_kv_blocks += num_kv_blocks;
|
|
|
|
| 229 |
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
| 230 |
}
|
| 231 |
}
|
| 232 |
+
} else if (warp_idx == kSpecWarpStart + 1) {
|
| 233 |
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 234 |
|
| 235 |
// Require full allocation
|
| 236 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
|
| 237 |
|
| 238 |
// Make UMMA desc
|
| 239 |
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
|
|
|
| 260 |
#pragma unroll
|
| 261 |
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
| 262 |
empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1);
|
| 263 |
+
ptx::tcgen05_after_thread_sync();
|
| 264 |
#pragma unroll
|
| 265 |
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
|
| 266 |
+
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
| 267 |
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
|
| 268 |
+
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
| 269 |
smem_q[q_stage_idx], 0, k * UMMA_K);
|
| 270 |
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
|
| 271 |
}
|
|
|
|
| 274 |
}
|
| 275 |
num_total_kv_blocks += num_kv_blocks;
|
| 276 |
|
| 277 |
+
// UMMA warp must also arrive on empty_q to prevent running ahead
|
| 278 |
+
// of math warps in the Q pipeline
|
| 279 |
+
empty_q_barriers[q_stage_idx]->arrive();
|
| 280 |
+
|
| 281 |
// Jump to the next block
|
| 282 |
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
| 283 |
}
|
| 284 |
+
} else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) {
|
| 285 |
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 286 |
+
} else if (warp_idx < kSpecWarpStart) {
|
| 287 |
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 288 |
|
| 289 |
// Offsets
|
| 290 |
+
const auto tmem_start = warpgroup_idx * UMMA_N;
|
| 291 |
+
const auto math_thread_idx = warp_idx * 32 + lane_idx;
|
| 292 |
+
|
| 293 |
+
// Helper lambda for loading tensor memory
|
| 294 |
+
auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
|
| 295 |
+
constexpr int N = decltype(num_elems_c)::value;
|
| 296 |
+
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
|
| 297 |
+
using Loader = cute::conditional_t<N == 32,
|
| 298 |
+
cute::SM100_TMEM_LOAD_32dp32b32x,
|
| 299 |
+
cute::SM100_TMEM_LOAD_32dp32b64x>;
|
| 300 |
+
[&]<size_t... Is>(cute::index_sequence<Is...>) {
|
| 301 |
+
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
|
| 302 |
+
}(cute::make_index_sequence<N>{});
|
| 303 |
+
cutlass::arch::fence_view_async_tmem_load();
|
| 304 |
+
};
|
| 305 |
|
| 306 |
+
// Local register buffers
|
| 307 |
+
float weights[BLOCK_Q][kNumHeads];
|
|
|
|
|
|
|
| 308 |
|
| 309 |
while (block_q_idx < num_q_blocks) {
|
| 310 |
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
|
|
|
| 315 |
// Read weights
|
| 316 |
#pragma unroll
|
| 317 |
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 318 |
+
#pragma unroll
|
| 319 |
+
for (uint32_t j = 0; j < kNumHeads; ++ j)
|
| 320 |
+
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
|
| 321 |
}
|
| 322 |
|
| 323 |
// Compute over KV blocks
|
|
|
|
| 329 |
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 330 |
|
| 331 |
// Read per-KV scales
|
| 332 |
+
float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx);
|
| 333 |
|
| 334 |
// Wait UMMA arrival
|
| 335 |
full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
|
| 336 |
+
ptx::tcgen05_after_thread_sync();
|
| 337 |
|
| 338 |
// Release KV empty
|
| 339 |
empty_kv_barriers[kv_stage_idx]->arrive();
|
| 340 |
|
| 341 |
// Reduce over the head dim and store
|
| 342 |
+
const auto kv_offset = kv_start + kv_block_idx * BLOCK_KV + math_thread_idx;
|
|
|
|
| 343 |
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
| 344 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
#pragma unroll
|
| 346 |
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 347 |
+
// Load accumulator from TMEM
|
| 348 |
+
float accum[kNumHeads];
|
| 349 |
+
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
|
| 350 |
+
|
| 351 |
+
// Release TMEM empty
|
| 352 |
+
if (i == BLOCK_Q - 1) {
|
| 353 |
+
ptx::tcgen05_before_thread_sync();
|
| 354 |
+
empty_umma_barriers[warpgroup_idx]->arrive();
|
| 355 |
+
}
|
| 356 |
|
| 357 |
+
// Accumulate weighted ReLU in parallel
|
| 358 |
auto sum_0 = make_float2(0, 0);
|
| 359 |
auto sum_1 = make_float2(0, 0);
|
| 360 |
|
| 361 |
+
const auto transform = [&](const uint32_t& j, const float2& sum) {
|
| 362 |
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
| 363 |
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
| 364 |
return __ffma2_rn(a, b, sum);
|
| 365 |
};
|
| 366 |
|
| 367 |
#pragma unroll
|
| 368 |
+
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
| 369 |
+
sum_0 = transform(j, sum_0);
|
| 370 |
+
sum_1 = transform(j + 2, sum_1);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
}
|
| 372 |
|
| 373 |
auto sum = __fadd2_rn(sum_0, sum_1);
|
| 374 |
+
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
|
| 375 |
|
| 376 |
// Store into the global memory
|
| 377 |
+
const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast<uint64_t>(stride_logits);
|
|
|
|
| 378 |
if constexpr (kIsCompressedLogits) {
|
| 379 |
+
if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i])
|
| 380 |
+
logits[q_offset + kv_offset - seq_k_start[i]] = result;
|
| 381 |
} else {
|
| 382 |
+
logits[q_offset + kv_offset] = result;
|
| 383 |
}
|
| 384 |
+
__syncwarp();
|
| 385 |
}
|
| 386 |
}
|
| 387 |
num_total_kv_blocks += num_kv_blocks;
|
|
|
|
| 392 |
// Jump to the next block
|
| 393 |
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
| 394 |
}
|
|
|
|
| 395 |
|
| 396 |
+
// Free tensor memory
|
| 397 |
+
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
|
| 398 |
+
if (warp_idx == 0)
|
| 399 |
+
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
| 400 |
+
}
|
| 401 |
}
|
| 402 |
|
| 403 |
} // namespace deep_gemm
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh
CHANGED
|
@@ -6,56 +6,65 @@
|
|
| 6 |
#include <cute/arch/cluster_sm90.hpp>
|
| 7 |
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
#include <deep_gemm/common/utils.cuh>
|
| 10 |
-
#include <deep_gemm/
|
| 11 |
-
#include <deep_gemm/
|
| 12 |
-
|
| 13 |
-
#include <deep_gemm/
|
|
|
|
| 14 |
|
| 15 |
namespace deep_gemm {
|
| 16 |
|
| 17 |
-
using namespace deep_gemm::sm90;
|
| 18 |
-
using namespace deep_gemm::sm100;
|
| 19 |
-
|
| 20 |
template <uint32_t kNextN, uint32_t kNumHeads,
|
| 21 |
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
| 22 |
-
bool kIsContextLens2D,
|
| 23 |
uint32_t kNumQStages, uint32_t kNumKVStages,
|
| 24 |
uint32_t SPLIT_KV,
|
| 25 |
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
|
|
|
| 26 |
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
|
| 27 |
-
|
| 28 |
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
| 29 |
-
const
|
| 30 |
-
const uint32_t* context_lens,
|
| 31 |
-
const uint32_t* block_table, const uint32_t*
|
|
|
|
| 32 |
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 33 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 34 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
| 35 |
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
| 36 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 37 |
|
| 38 |
-
//
|
| 39 |
-
const auto
|
| 40 |
-
const auto
|
| 41 |
-
const auto
|
|
|
|
|
|
|
| 42 |
|
| 43 |
// Prefetch TMA descriptors
|
| 44 |
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
| 45 |
-
if (warp_idx ==
|
| 46 |
cute::prefetch_tma_descriptor(&tensor_map_q);
|
| 47 |
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
| 48 |
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
|
| 49 |
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
| 50 |
}
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
// Shared memory configs
|
| 54 |
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
| 55 |
-
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE =
|
| 56 |
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 57 |
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float);
|
| 58 |
-
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE =
|
| 59 |
|
| 60 |
// Align to swizzling alignment bytes
|
| 61 |
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
|
@@ -63,43 +72,40 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 63 |
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 64 |
|
| 65 |
// Q and KV data on shared memory
|
| 66 |
-
auto smem_q = PatternVisitor([&](const uint32_t& i) {
|
| 67 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
|
| 68 |
});
|
| 69 |
-
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
|
| 70 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i);
|
| 71 |
});
|
| 72 |
constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages;
|
| 73 |
-
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
|
| 74 |
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
| 75 |
});
|
| 76 |
-
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
| 77 |
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 78 |
});
|
| 79 |
|
| 80 |
// Barriers and TMEM pointer on shared memory
|
| 81 |
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
|
| 82 |
-
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
| 83 |
-
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
|
| 84 |
-
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
|
| 85 |
-
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
|
| 86 |
const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
|
| 87 |
-
auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; });
|
| 88 |
-
auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
|
| 89 |
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
|
| 90 |
|
| 91 |
-
constexpr uint32_t kNumTmemCols =
|
| 92 |
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
|
| 93 |
-
const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4);
|
| 94 |
-
const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4);
|
| 95 |
-
const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1);
|
| 96 |
|
| 97 |
// Initialize barriers
|
| 98 |
-
if (
|
| 99 |
#pragma unroll
|
| 100 |
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
| 101 |
full_q_barriers[i]->init(1);
|
| 102 |
-
empty_q_barriers[i]->init(kNumMathThreads);
|
| 103 |
}
|
| 104 |
#pragma unroll
|
| 105 |
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
|
@@ -108,7 +114,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 108 |
}
|
| 109 |
cutlass::arch::fence_barrier_init();
|
| 110 |
}
|
| 111 |
-
if (
|
| 112 |
if (cute::elect_one_sync()) {
|
| 113 |
#pragma unroll
|
| 114 |
for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) {
|
|
@@ -123,79 +129,92 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 123 |
__syncthreads();
|
| 124 |
|
| 125 |
// Register reconfigurations
|
| 126 |
-
constexpr uint32_t kNumSpecializedRegisters =
|
| 127 |
-
constexpr uint32_t kNumMathRegisters =
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
// Scheduler
|
| 130 |
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
|
| 131 |
-
|
| 132 |
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
|
| 133 |
|
| 134 |
// Q and KV pipeline
|
| 135 |
-
const auto
|
| 136 |
return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
|
| 137 |
};
|
| 138 |
-
const auto
|
| 139 |
return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
|
| 140 |
};
|
| 141 |
-
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
| 142 |
|
| 143 |
// UMMA settings
|
| 144 |
// Construct instruction with layout D
|
| 145 |
constexpr uint32_t UMMA_M = 128;
|
| 146 |
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
|
| 147 |
-
constexpr uint32_t UMMA_N =
|
| 148 |
DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
|
| 149 |
|
| 150 |
-
if (
|
| 151 |
-
// TMA warp
|
| 152 |
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
|
|
|
|
|
|
| 153 |
|
| 154 |
-
const auto
|
| 155 |
if (cute::elect_one_sync()) {
|
| 156 |
-
|
| 157 |
-
|
|
|
|
| 158 |
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 159 |
}
|
| 160 |
};
|
| 161 |
|
| 162 |
-
// Initialize
|
| 163 |
-
uint32_t
|
| 164 |
-
uint32_t
|
| 165 |
bool fetched_next_task;
|
| 166 |
|
| 167 |
// Prefetch the first Q
|
| 168 |
-
if ((fetched_next_task = scheduler.fetch_next_task(
|
| 169 |
-
issue_tma_q(0,
|
| 170 |
|
| 171 |
-
|
| 172 |
uint32_t kv_block_idx_storage;
|
| 173 |
|
| 174 |
while (fetched_next_task) {
|
| 175 |
-
// Prefetch next Q when
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
kv_idx = next_kv_idx;
|
| 179 |
num_kv = next_num_kv;
|
| 180 |
|
| 181 |
// Read KV block index
|
| 182 |
-
// TODO:
|
| 183 |
-
if (
|
| 184 |
kv_block_idx_ptr = 0;
|
| 185 |
-
|
|
|
|
|
|
|
| 186 |
}
|
|
|
|
| 187 |
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
|
| 188 |
|
| 189 |
// Wait Q consumer release and issue TMA Q
|
| 190 |
if (prefetch_q) {
|
| 191 |
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
| 192 |
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
| 193 |
-
issue_tma_q(q_stage_idx,
|
| 194 |
}
|
| 195 |
|
| 196 |
-
|
| 197 |
#pragma unroll
|
| 198 |
-
for (
|
| 199 |
kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
|
| 200 |
kv_block_idx_ptr += kNumBlocksPerSplit;
|
| 201 |
|
|
@@ -205,45 +224,53 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 205 |
|
| 206 |
if (cute::elect_one_sync()) {
|
| 207 |
#pragma unroll
|
| 208 |
-
for (
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
}
|
| 216 |
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
| 217 |
}
|
| 218 |
|
| 219 |
// Fetch next task
|
| 220 |
-
fetched_next_task = scheduler.fetch_next_task(
|
| 221 |
}
|
| 222 |
-
} else if (
|
| 223 |
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
|
|
|
|
|
|
| 224 |
|
| 225 |
// Require full allocation
|
| 226 |
-
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
| 227 |
|
| 228 |
// Make UMMA desc
|
| 229 |
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
| 230 |
UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
|
| 231 |
auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
| 232 |
|
| 233 |
-
uint32_t
|
| 234 |
-
uint32_t
|
| 235 |
uint32_t q_stage_idx, q_phase;
|
| 236 |
uint32_t umma_phase = 1;
|
| 237 |
|
| 238 |
-
while (scheduler.fetch_next_task(
|
| 239 |
-
if (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
| 241 |
full_q_barriers[q_stage_idx]->wait(q_phase);
|
| 242 |
}
|
| 243 |
|
| 244 |
-
|
| 245 |
kv_idx = next_kv_idx;
|
| 246 |
|
|
|
|
| 247 |
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
| 248 |
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 249 |
|
|
@@ -251,12 +278,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 251 |
#pragma unroll
|
| 252 |
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
| 253 |
empty_umma_barriers[i]->wait(umma_phase);
|
| 254 |
-
tcgen05_after_thread_sync();
|
| 255 |
#pragma unroll
|
| 256 |
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
|
| 257 |
-
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
| 258 |
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
|
| 259 |
-
auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
| 260 |
smem_q[q_stage_idx], 0, k * UMMA_K);
|
| 261 |
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
|
| 262 |
}
|
|
@@ -264,29 +291,46 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 264 |
}
|
| 265 |
umma_phase ^= 1;
|
| 266 |
}
|
| 267 |
-
} else if (
|
| 268 |
-
|
|
|
|
|
|
|
| 269 |
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
|
|
|
|
|
|
| 270 |
|
| 271 |
// Offsets
|
| 272 |
-
const auto
|
| 273 |
-
const
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
-
//
|
| 276 |
-
|
| 277 |
-
float weights[kNextN][kNumWeightsInReg];
|
| 278 |
-
DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
|
| 279 |
|
| 280 |
-
// Initialize
|
| 281 |
-
uint32_t
|
| 282 |
-
uint32_t
|
| 283 |
uint32_t q_stage_idx, q_phase;
|
| 284 |
uint32_t umma_phase = 0;
|
|
|
|
| 285 |
|
| 286 |
-
while (scheduler.fetch_next_task(
|
| 287 |
-
//
|
| 288 |
-
if (
|
| 289 |
-
// Release
|
| 290 |
if (q_iter_idx > 0)
|
| 291 |
empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
|
| 292 |
|
|
@@ -296,30 +340,34 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 296 |
|
| 297 |
// Read weights
|
| 298 |
#pragma unroll
|
| 299 |
-
for (uint32_t i = 0; i <
|
| 300 |
-
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
}
|
| 303 |
}
|
| 304 |
|
| 305 |
-
// Get current
|
| 306 |
-
|
| 307 |
kv_idx = next_kv_idx;
|
| 308 |
|
| 309 |
// Calculate KV offset in advance
|
| 310 |
-
auto kv_offset =
|
| 311 |
|
| 312 |
-
// Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]`
|
| 313 |
// Wait TMA KV arrival
|
| 314 |
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
| 315 |
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 316 |
|
| 317 |
// Read per-KV scales
|
| 318 |
-
float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] +
|
| 319 |
|
| 320 |
// Wait UMMA arrival
|
| 321 |
-
full_umma_barriers[
|
| 322 |
-
tcgen05_after_thread_sync();
|
| 323 |
umma_phase ^= 1;
|
| 324 |
|
| 325 |
// Release KV empty
|
|
@@ -327,72 +375,65 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 327 |
|
| 328 |
// Reduce over the head dim and store
|
| 329 |
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
| 330 |
-
constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN;
|
| 331 |
-
uint32_t shifted_accum[kNumLDTMElems];
|
| 332 |
-
DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM");
|
| 333 |
-
auto tmem_load = [&](auto... Is) {
|
| 334 |
-
if constexpr (kNumLDTMElems == 32) {
|
| 335 |
-
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
|
| 336 |
-
} else if constexpr (kNumLDTMElems == 64) {
|
| 337 |
-
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
|
| 338 |
-
} else if constexpr (kNumLDTMElems == 128) {
|
| 339 |
-
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
|
| 340 |
-
}
|
| 341 |
-
};
|
| 342 |
-
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
|
| 343 |
-
cutlass::arch::fence_view_async_tmem_load();
|
| 344 |
-
|
| 345 |
-
tcgen05_before_thread_sync();
|
| 346 |
-
empty_umma_barriers[warpgroup_idx]->arrive();
|
| 347 |
-
|
| 348 |
-
#pragma unroll
|
| 349 |
-
for (uint32_t i = 0; i < kNextN; ++ i) {
|
| 350 |
-
auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
|
| 351 |
-
|
| 352 |
-
auto sum_0 = make_float2(0, 0);
|
| 353 |
-
auto sum_1 = make_float2(0, 0);
|
| 354 |
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
return __ffma2_rn(a, b, sum);
|
| 359 |
-
};
|
| 360 |
|
| 361 |
#pragma unroll
|
| 362 |
-
for (
|
| 363 |
-
|
| 364 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
}
|
| 366 |
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
return __ffma2_rn(a, b, sum);
|
| 372 |
-
};
|
| 373 |
-
|
| 374 |
-
#pragma unroll
|
| 375 |
-
for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
|
| 376 |
-
sum_0 = transform_smem(j, sum_0);
|
| 377 |
-
sum_1 = transform_smem(j + 2, sum_1);
|
| 378 |
-
}
|
| 379 |
-
|
| 380 |
-
auto sum = __fadd2_rn(sum_0, sum_1);
|
| 381 |
-
float result = scale_kv * (sum.x + sum.y);
|
| 382 |
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
}
|
| 387 |
}
|
| 388 |
-
} else {
|
| 389 |
-
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 390 |
-
}
|
| 391 |
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
|
|
|
| 396 |
}
|
| 397 |
|
| 398 |
} // namespace deep_gemm
|
|
|
|
| 6 |
#include <cute/arch/cluster_sm90.hpp>
|
| 7 |
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
|
| 9 |
+
#include <deep_gemm/common/cute_tie.cuh>
|
| 10 |
+
#include <deep_gemm/common/math.cuh>
|
| 11 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 12 |
#include <deep_gemm/common/utils.cuh>
|
| 13 |
+
#include <deep_gemm/mma/sm100.cuh>
|
| 14 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 15 |
+
#include <deep_gemm/ptx/tcgen05.cuh>
|
| 16 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 17 |
+
#include <deep_gemm/scheduler/paged_mqa_logits.cuh>
|
| 18 |
|
| 19 |
namespace deep_gemm {
|
| 20 |
|
|
|
|
|
|
|
|
|
|
| 21 |
template <uint32_t kNextN, uint32_t kNumHeads,
|
| 22 |
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
| 23 |
+
bool kIsContextLens2D, bool kIsVarlen,
|
| 24 |
uint32_t kNumQStages, uint32_t kNumKVStages,
|
| 25 |
uint32_t SPLIT_KV,
|
| 26 |
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
| 27 |
+
typename logits_dtype_t,
|
| 28 |
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
|
| 29 |
+
CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
|
| 30 |
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
| 31 |
+
const uint32_t logits_stride, const uint32_t block_table_stride,
|
| 32 |
+
const uint32_t* context_lens, logits_dtype_t* logits,
|
| 33 |
+
const uint32_t* block_table, const uint32_t* indices,
|
| 34 |
+
const uint32_t* schedule_meta,
|
| 35 |
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 36 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 37 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
| 38 |
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
| 39 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 40 |
|
| 41 |
+
// Utils
|
| 42 |
+
const auto sm_idx = blockIdx.x;
|
| 43 |
+
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 44 |
+
const auto warpgroup_idx = warp_idx / 4;
|
| 45 |
+
const auto lane_idx = ptx::get_lane_idx();
|
| 46 |
+
constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
|
| 47 |
|
| 48 |
// Prefetch TMA descriptors
|
| 49 |
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
| 50 |
+
if (warp_idx == kSpecWarpStart) {
|
| 51 |
cute::prefetch_tma_descriptor(&tensor_map_q);
|
| 52 |
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
| 53 |
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
|
| 54 |
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
| 55 |
}
|
| 56 |
+
|
| 57 |
+
// For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
|
| 58 |
+
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
|
| 59 |
+
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
|
| 60 |
+
static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
|
| 61 |
|
| 62 |
// Shared memory configs
|
| 63 |
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
| 64 |
+
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 65 |
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 66 |
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float);
|
| 67 |
+
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float);
|
| 68 |
|
| 69 |
// Align to swizzling alignment bytes
|
| 70 |
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
|
|
|
| 72 |
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 73 |
|
| 74 |
// Q and KV data on shared memory
|
| 75 |
+
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
|
| 76 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
|
| 77 |
});
|
| 78 |
+
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
|
| 79 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i);
|
| 80 |
});
|
| 81 |
constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages;
|
| 82 |
+
auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
|
| 83 |
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
| 84 |
});
|
| 85 |
+
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
|
| 86 |
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 87 |
});
|
| 88 |
|
| 89 |
// Barriers and TMEM pointer on shared memory
|
| 90 |
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
|
| 91 |
+
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
| 92 |
+
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
|
| 93 |
+
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
|
| 94 |
+
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
|
| 95 |
const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
|
| 96 |
+
auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; });
|
| 97 |
+
auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
|
| 98 |
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
|
| 99 |
|
| 100 |
+
constexpr uint32_t kNumTmemCols = kNextNAtom * kNumHeads * kNumMathWarpGroups;
|
| 101 |
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
// Initialize barriers
|
| 104 |
+
if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
|
| 105 |
#pragma unroll
|
| 106 |
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
| 107 |
full_q_barriers[i]->init(1);
|
| 108 |
+
empty_q_barriers[i]->init(kNumMathThreads + 32);
|
| 109 |
}
|
| 110 |
#pragma unroll
|
| 111 |
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
|
|
|
| 114 |
}
|
| 115 |
cutlass::arch::fence_barrier_init();
|
| 116 |
}
|
| 117 |
+
if (warp_idx == kSpecWarpStart + 1) {
|
| 118 |
if (cute::elect_one_sync()) {
|
| 119 |
#pragma unroll
|
| 120 |
for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) {
|
|
|
|
| 129 |
__syncthreads();
|
| 130 |
|
| 131 |
// Register reconfigurations
|
| 132 |
+
constexpr uint32_t kNumSpecializedRegisters = 56;
|
| 133 |
+
constexpr uint32_t kNumMathRegisters = 224;
|
| 134 |
+
|
| 135 |
+
// Wait for primary kernel completion
|
| 136 |
+
cudaGridDependencySynchronize();
|
| 137 |
|
| 138 |
// Scheduler
|
| 139 |
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
|
| 140 |
+
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
|
| 141 |
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
|
| 142 |
|
| 143 |
// Q and KV pipeline
|
| 144 |
+
const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
| 145 |
return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
|
| 146 |
};
|
| 147 |
+
const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
| 148 |
return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
|
| 149 |
};
|
|
|
|
| 150 |
|
| 151 |
// UMMA settings
|
| 152 |
// Construct instruction with layout D
|
| 153 |
constexpr uint32_t UMMA_M = 128;
|
| 154 |
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
|
| 155 |
+
constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads;
|
| 156 |
DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
|
| 157 |
|
| 158 |
+
if (warp_idx == kSpecWarpStart) {
|
| 159 |
+
// TMA warp for loading data
|
| 160 |
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 161 |
+
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
| 162 |
+
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
| 163 |
|
| 164 |
+
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) {
|
| 165 |
if (cute::elect_one_sync()) {
|
| 166 |
+
const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx);
|
| 167 |
+
tma::copy<kHeadDim, kNextNAtom * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads);
|
| 168 |
+
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx);
|
| 169 |
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 170 |
}
|
| 171 |
};
|
| 172 |
|
| 173 |
+
// Initialize outside valid range to indicate no previous task
|
| 174 |
+
uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx, num_kv;
|
| 175 |
+
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
|
| 176 |
bool fetched_next_task;
|
| 177 |
|
| 178 |
// Prefetch the first Q
|
| 179 |
+
if ((fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)))
|
| 180 |
+
issue_tma_q(0, next_q_atom_idx), q_iter_idx = 1;
|
| 181 |
|
| 182 |
+
uint32_t kv_block_idx_ptr = 32;
|
| 183 |
uint32_t kv_block_idx_storage;
|
| 184 |
|
| 185 |
while (fetched_next_task) {
|
| 186 |
+
// Prefetch next Q when (q, atom) changes
|
| 187 |
+
const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size);
|
| 188 |
+
bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance);
|
| 189 |
+
|
| 190 |
+
if (q_atom_idx != next_q_atom_idx)
|
| 191 |
+
kv_block_idx_ptr = 32;
|
| 192 |
+
|
| 193 |
+
q_atom_idx = next_q_atom_idx;
|
| 194 |
kv_idx = next_kv_idx;
|
| 195 |
num_kv = next_num_kv;
|
| 196 |
|
| 197 |
// Read KV block index
|
| 198 |
+
// TODO(xuzhean): consider -1
|
| 199 |
+
if (kv_block_idx_ptr == 32) {
|
| 200 |
kv_block_idx_ptr = 0;
|
| 201 |
+
const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
|
| 202 |
+
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
|
| 203 |
+
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
|
| 204 |
}
|
| 205 |
+
__syncwarp();
|
| 206 |
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
|
| 207 |
|
| 208 |
// Wait Q consumer release and issue TMA Q
|
| 209 |
if (prefetch_q) {
|
| 210 |
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
| 211 |
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
| 212 |
+
issue_tma_q(q_stage_idx, q_atom_idx + next_advance);
|
| 213 |
}
|
| 214 |
|
| 215 |
+
uint32_t kv_block_idx[kNumBlocksPerSplit];
|
| 216 |
#pragma unroll
|
| 217 |
+
for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i)
|
| 218 |
kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
|
| 219 |
kv_block_idx_ptr += kNumBlocksPerSplit;
|
| 220 |
|
|
|
|
| 224 |
|
| 225 |
if (cute::elect_one_sync()) {
|
| 226 |
#pragma unroll
|
| 227 |
+
for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) {
|
| 228 |
+
tma::copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
|
| 229 |
+
smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i,
|
| 230 |
+
0, 0, 1, kv_block_idx[i]);
|
| 231 |
+
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
|
| 232 |
+
smem_kv_scales[kv_stage_idx] + BLOCK_KV * i,
|
| 233 |
+
0, kv_block_idx[i]);
|
| 234 |
}
|
| 235 |
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
| 236 |
}
|
| 237 |
|
| 238 |
// Fetch next task
|
| 239 |
+
fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv);
|
| 240 |
}
|
| 241 |
+
} else if (warp_idx == kSpecWarpStart + 1) {
|
| 242 |
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 243 |
+
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
| 244 |
+
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
| 245 |
|
| 246 |
// Require full allocation
|
| 247 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
|
| 248 |
|
| 249 |
// Make UMMA desc
|
| 250 |
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
| 251 |
UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
|
| 252 |
auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
| 253 |
|
| 254 |
+
uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx;
|
| 255 |
+
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
|
| 256 |
uint32_t q_stage_idx, q_phase;
|
| 257 |
uint32_t umma_phase = 1;
|
| 258 |
|
| 259 |
+
while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) {
|
| 260 |
+
if (q_atom_idx != next_q_atom_idx) {
|
| 261 |
+
// Release previous Q empty (UMMA warp must participate to prevent
|
| 262 |
+
// running ahead of math warps in the Q pipeline)
|
| 263 |
+
if (q_iter_idx > 0)
|
| 264 |
+
empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
|
| 265 |
+
|
| 266 |
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
| 267 |
full_q_barriers[q_stage_idx]->wait(q_phase);
|
| 268 |
}
|
| 269 |
|
| 270 |
+
q_atom_idx = next_q_atom_idx;
|
| 271 |
kv_idx = next_kv_idx;
|
| 272 |
|
| 273 |
+
// Wait KV arrival
|
| 274 |
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
| 275 |
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 276 |
|
|
|
|
| 278 |
#pragma unroll
|
| 279 |
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
| 280 |
empty_umma_barriers[i]->wait(umma_phase);
|
| 281 |
+
ptx::tcgen05_after_thread_sync();
|
| 282 |
#pragma unroll
|
| 283 |
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
|
| 284 |
+
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
| 285 |
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
|
| 286 |
+
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
| 287 |
smem_q[q_stage_idx], 0, k * UMMA_K);
|
| 288 |
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
|
| 289 |
}
|
|
|
|
| 291 |
}
|
| 292 |
umma_phase ^= 1;
|
| 293 |
}
|
| 294 |
+
} else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) {
|
| 295 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 296 |
+
} else if (warp_idx < kSpecWarpStart) {
|
| 297 |
+
// Math warpgroups for reduce
|
| 298 |
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 299 |
+
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
| 300 |
+
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
| 301 |
|
| 302 |
// Offsets
|
| 303 |
+
const auto math_warpgroup_idx = warpgroup_idx;
|
| 304 |
+
const auto tmem_start = math_warpgroup_idx * UMMA_N;
|
| 305 |
+
const auto math_thread_idx = warp_idx * 32 + lane_idx;
|
| 306 |
+
|
| 307 |
+
// Helper lambda for loading tensor memory
|
| 308 |
+
auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
|
| 309 |
+
constexpr int N = decltype(num_elems_c)::value;
|
| 310 |
+
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
|
| 311 |
+
using Loader = cute::conditional_t<N == 32,
|
| 312 |
+
cute::SM100_TMEM_LOAD_32dp32b32x,
|
| 313 |
+
cute::SM100_TMEM_LOAD_32dp32b64x>;
|
| 314 |
+
[&]<size_t... Is>(cute::index_sequence<Is...>) {
|
| 315 |
+
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
|
| 316 |
+
}(cute::make_index_sequence<N>{});
|
| 317 |
+
cutlass::arch::fence_view_async_tmem_load();
|
| 318 |
+
};
|
| 319 |
|
| 320 |
+
// Local register buffers
|
| 321 |
+
float weights[kNextNAtom][kNumHeads];
|
|
|
|
|
|
|
| 322 |
|
| 323 |
+
// Initialize outside valid range to indicate no previous task
|
| 324 |
+
uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx;
|
| 325 |
+
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
|
| 326 |
uint32_t q_stage_idx, q_phase;
|
| 327 |
uint32_t umma_phase = 0;
|
| 328 |
+
bool is_paired_atom = false;
|
| 329 |
|
| 330 |
+
while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) {
|
| 331 |
+
// Q or atom changes
|
| 332 |
+
if (q_atom_idx != next_q_atom_idx) {
|
| 333 |
+
// Release last Q empty
|
| 334 |
if (q_iter_idx > 0)
|
| 335 |
empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
|
| 336 |
|
|
|
|
| 340 |
|
| 341 |
// Read weights
|
| 342 |
#pragma unroll
|
| 343 |
+
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
|
| 344 |
+
#pragma unroll
|
| 345 |
+
for (uint32_t j = 0; j < kNumHeads; ++ j)
|
| 346 |
+
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
if constexpr (kIsVarlen) {
|
| 350 |
+
is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2);
|
| 351 |
}
|
| 352 |
}
|
| 353 |
|
| 354 |
+
// Get current task indices
|
| 355 |
+
q_atom_idx = next_q_atom_idx;
|
| 356 |
kv_idx = next_kv_idx;
|
| 357 |
|
| 358 |
// Calculate KV offset in advance
|
| 359 |
+
auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV;
|
| 360 |
|
|
|
|
| 361 |
// Wait TMA KV arrival
|
| 362 |
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
| 363 |
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 364 |
|
| 365 |
// Read per-KV scales
|
| 366 |
+
float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx);
|
| 367 |
|
| 368 |
// Wait UMMA arrival
|
| 369 |
+
full_umma_barriers[math_warpgroup_idx]->wait(umma_phase);
|
| 370 |
+
ptx::tcgen05_after_thread_sync();
|
| 371 |
umma_phase ^= 1;
|
| 372 |
|
| 373 |
// Release KV empty
|
|
|
|
| 375 |
|
| 376 |
// Reduce over the head dim and store
|
| 377 |
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
+
const auto reduce_and_store = [&](auto num_iters_c) {
|
| 380 |
+
constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
|
| 381 |
+
float accum[kNumHeads];
|
|
|
|
|
|
|
| 382 |
|
| 383 |
#pragma unroll
|
| 384 |
+
for (uint32_t i = 0; i < kNumIters; ++ i) {
|
| 385 |
+
// Load accumulator from TMEM
|
| 386 |
+
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
|
| 387 |
+
|
| 388 |
+
// Accumulate weighted ReLU in parallel
|
| 389 |
+
auto sum_0 = make_float2(0, 0);
|
| 390 |
+
auto sum_1 = make_float2(0, 0);
|
| 391 |
+
|
| 392 |
+
const auto transform = [&](const uint32_t& j, const float2& sum) {
|
| 393 |
+
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
| 394 |
+
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
| 395 |
+
return __ffma2_rn(a, b, sum);
|
| 396 |
+
};
|
| 397 |
+
|
| 398 |
+
#pragma unroll
|
| 399 |
+
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
| 400 |
+
sum_0 = transform(j, sum_0);
|
| 401 |
+
sum_1 = transform(j + 2, sum_1);
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
auto sum = __fadd2_rn(sum_0, sum_1);
|
| 405 |
+
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
|
| 406 |
+
|
| 407 |
+
// Store into the global memory
|
| 408 |
+
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + math_thread_idx] = result;
|
| 409 |
+
__syncwarp();
|
| 410 |
}
|
| 411 |
|
| 412 |
+
// Release TMEM empty
|
| 413 |
+
ptx::tcgen05_before_thread_sync();
|
| 414 |
+
empty_umma_barriers[math_warpgroup_idx]->arrive();
|
| 415 |
+
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
+
if constexpr (kIsVarlen) {
|
| 418 |
+
if (is_paired_atom)
|
| 419 |
+
reduce_and_store(cute::Int<kNextNAtom>{});
|
| 420 |
+
else
|
| 421 |
+
reduce_and_store(cute::Int<1>{});
|
| 422 |
+
} else if constexpr (kPadOddN) {
|
| 423 |
+
if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
|
| 424 |
+
reduce_and_store(cute::Int<1>{});
|
| 425 |
+
else
|
| 426 |
+
reduce_and_store(cute::Int<kNextNAtom>{});
|
| 427 |
+
} else {
|
| 428 |
+
reduce_and_store(cute::Int<kNextNAtom>{});
|
| 429 |
}
|
| 430 |
}
|
|
|
|
|
|
|
|
|
|
| 431 |
|
| 432 |
+
// Free tensor memory
|
| 433 |
+
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
|
| 434 |
+
if (warp_idx == 0)
|
| 435 |
+
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
| 436 |
+
}
|
| 437 |
}
|
| 438 |
|
| 439 |
} // namespace deep_gemm
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh
CHANGED
|
@@ -4,20 +4,22 @@
|
|
| 4 |
|
| 5 |
#include <cutlass/arch/barrier.h>
|
| 6 |
|
| 7 |
-
#include <deep_gemm/common/
|
|
|
|
|
|
|
| 8 |
#include <deep_gemm/common/utils.cuh>
|
| 9 |
-
#include <deep_gemm/
|
| 10 |
-
#include <deep_gemm/
|
|
|
|
|
|
|
| 11 |
|
| 12 |
namespace deep_gemm {
|
| 13 |
|
| 14 |
-
using namespace deep_gemm::sm100;
|
| 15 |
-
|
| 16 |
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
|
| 17 |
-
|
| 18 |
uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) {
|
| 19 |
// Calculate the index of the bank group to be written in the atom
|
| 20 |
-
const auto
|
| 21 |
|
| 22 |
// Reshape the atom in another view and swizzle
|
| 23 |
// - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)`
|
|
@@ -37,7 +39,7 @@ template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
|
| 37 |
uint32_t kSwizzleCDMode,
|
| 38 |
uint32_t kNumStages,
|
| 39 |
uint32_t kNumMMAThreads, uint32_t kNumCastAndReduceThreads>
|
| 40 |
-
|
| 41 |
sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
| 42 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 43 |
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
|
@@ -58,7 +60,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 58 |
|
| 59 |
// Utils
|
| 60 |
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 61 |
-
const auto lane_idx = get_lane_idx();
|
| 62 |
|
| 63 |
// Align to 1024 bytes for swizzle-128B
|
| 64 |
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
|
@@ -70,7 +72,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 70 |
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
| 71 |
|
| 72 |
// Real tensor memory size and offsets
|
| 73 |
-
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_K * kNumCastStages + BLOCK_N>();
|
| 74 |
|
| 75 |
// Prefetch TMA descriptors at the very beginning
|
| 76 |
if (warp_idx == 0 and cute::elect_one_sync()) {
|
|
@@ -82,20 +84,20 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 82 |
// Data on shared memory (layout as ordered below)
|
| 83 |
// Fill D/A/B pointers
|
| 84 |
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
|
| 85 |
-
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 86 |
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
| 87 |
});
|
| 88 |
-
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 89 |
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
| 90 |
});
|
| 91 |
|
| 92 |
// Fill barriers
|
| 93 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
|
| 94 |
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 95 |
-
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 96 |
-
auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 97 |
-
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
| 98 |
-
auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
|
| 99 |
auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4;
|
| 100 |
|
| 101 |
// Fill the tensor memory pointer
|
|
@@ -121,7 +123,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 121 |
}
|
| 122 |
__syncthreads();
|
| 123 |
|
| 124 |
-
constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
|
| 125 |
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
|
| 126 |
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
|
| 127 |
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
|
|
@@ -131,6 +133,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 131 |
const uint32_t m_offset = shape_m * k_split_idx;
|
| 132 |
const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
|
| 133 |
|
|
|
|
|
|
|
|
|
|
| 134 |
// Dispatch warps into different roles
|
| 135 |
if (warp_idx < kNumMMAThreads / 32) {
|
| 136 |
// TMA load warp
|
|
@@ -145,8 +150,8 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 145 |
uint32_t k_idx = k_offset + s * BLOCK_K;
|
| 146 |
|
| 147 |
// Issue TMAs
|
| 148 |
-
|
| 149 |
-
|
| 150 |
|
| 151 |
// Arrive at full barriers
|
| 152 |
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
|
@@ -168,7 +173,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 168 |
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
| 169 |
|
| 170 |
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
| 171 |
-
auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_SWIZZLED_BK, kSwizzleBMode>(smem_b[0], 0, 0);
|
| 172 |
const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
| 173 |
|
| 174 |
// Checks for MMA instructions
|
|
@@ -185,7 +190,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 185 |
const auto& stage_idx = s % kNumStages;
|
| 186 |
const auto& cast_stage_idx = s % kNumCastStages;
|
| 187 |
full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1);
|
| 188 |
-
tcgen05_after_thread_sync();
|
| 189 |
|
| 190 |
// Issue UMMA
|
| 191 |
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
|
|
@@ -194,7 +199,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 194 |
const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK;
|
| 195 |
const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK;
|
| 196 |
const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK;
|
| 197 |
-
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, float>(b_desc_base_lo, offset, in_atom_idx);
|
| 198 |
umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc);
|
| 199 |
}
|
| 200 |
|
|
@@ -218,7 +223,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 218 |
|
| 219 |
// Wait UMMA arrival
|
| 220 |
tmem_full_barrier->wait(0);
|
| 221 |
-
tcgen05_after_thread_sync();
|
| 222 |
|
| 223 |
// Load from tensor memory into registers, and write shared memory with STSM
|
| 224 |
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough");
|
|
@@ -239,7 +244,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 239 |
values[0], values[1], values[2], values[3]);
|
| 240 |
cutlass::arch::fence_view_async_tmem_load();
|
| 241 |
if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16))
|
| 242 |
-
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
| 243 |
if constexpr (BLOCK_M == 64)
|
| 244 |
__syncwarp();
|
| 245 |
}
|
|
@@ -290,9 +295,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 290 |
#pragma unroll
|
| 291 |
for (uint32_t i = 0; i < kNumLoads; i += 2) {
|
| 292 |
auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset<kSwizzleAMode>(i + lane_idx / 16, lane_idx % 16);
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
}
|
| 297 |
|
| 298 |
// Wait tensor memory empty
|
|
@@ -321,15 +326,15 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 321 |
cutlass::arch::fence_view_async_tmem_store();
|
| 322 |
|
| 323 |
// Arrive for issuing MMAs
|
| 324 |
-
tcgen05_before_thread_sync();
|
| 325 |
full_cast_barriers[cast_stage_idx]->arrive();
|
| 326 |
}
|
| 327 |
|
| 328 |
// Intra-warp reduction and write back
|
| 329 |
#pragma unroll
|
| 330 |
for (uint32_t u = 0; u < 2; ++ u) {
|
| 331 |
-
const auto
|
| 332 |
-
const auto
|
| 333 |
if (lane_idx % 4 == 0 and m_idx < shape_m)
|
| 334 |
sqr_sum[m_offset + m_idx] = reduced_sum;
|
| 335 |
}
|
|
|
|
| 4 |
|
| 5 |
#include <cutlass/arch/barrier.h>
|
| 6 |
|
| 7 |
+
#include <deep_gemm/common/cute_tie.cuh>
|
| 8 |
+
#include <deep_gemm/common/math.cuh>
|
| 9 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 10 |
#include <deep_gemm/common/utils.cuh>
|
| 11 |
+
#include <deep_gemm/mma/sm100.cuh>
|
| 12 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 13 |
+
#include <deep_gemm/ptx/tcgen05.cuh>
|
| 14 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 15 |
|
| 16 |
namespace deep_gemm {
|
| 17 |
|
|
|
|
|
|
|
| 18 |
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
|
| 19 |
+
CUTLASS_DEVICE
|
| 20 |
uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) {
|
| 21 |
// Calculate the index of the bank group to be written in the atom
|
| 22 |
+
const auto bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase);
|
| 23 |
|
| 24 |
// Reshape the atom in another view and swizzle
|
| 25 |
// - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)`
|
|
|
|
| 39 |
uint32_t kSwizzleCDMode,
|
| 40 |
uint32_t kNumStages,
|
| 41 |
uint32_t kNumMMAThreads, uint32_t kNumCastAndReduceThreads>
|
| 42 |
+
CUTLASS_GLOBAL void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1)
|
| 43 |
sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
| 44 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 45 |
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
|
|
|
| 60 |
|
| 61 |
// Utils
|
| 62 |
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 63 |
+
const auto lane_idx = ptx::get_lane_idx();
|
| 64 |
|
| 65 |
// Align to 1024 bytes for swizzle-128B
|
| 66 |
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
|
|
|
| 72 |
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
| 73 |
|
| 74 |
// Real tensor memory size and offsets
|
| 75 |
+
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<BLOCK_K * kNumCastStages + BLOCK_N>();
|
| 76 |
|
| 77 |
// Prefetch TMA descriptors at the very beginning
|
| 78 |
if (warp_idx == 0 and cute::elect_one_sync()) {
|
|
|
|
| 84 |
// Data on shared memory (layout as ordered below)
|
| 85 |
// Fill D/A/B pointers
|
| 86 |
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
|
| 87 |
+
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
|
| 88 |
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
| 89 |
});
|
| 90 |
+
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
|
| 91 |
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
| 92 |
});
|
| 93 |
|
| 94 |
// Fill barriers
|
| 95 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
|
| 96 |
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 97 |
+
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 98 |
+
auto full_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 99 |
+
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
| 100 |
+
auto empty_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
|
| 101 |
auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4;
|
| 102 |
|
| 103 |
// Fill the tensor memory pointer
|
|
|
|
| 123 |
}
|
| 124 |
__syncthreads();
|
| 125 |
|
| 126 |
+
constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K);
|
| 127 |
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
|
| 128 |
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
|
| 129 |
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
|
|
|
|
| 133 |
const uint32_t m_offset = shape_m * k_split_idx;
|
| 134 |
const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
|
| 135 |
|
| 136 |
+
// Wait for primary kernel completion
|
| 137 |
+
cudaGridDependencySynchronize();
|
| 138 |
+
|
| 139 |
// Dispatch warps into different roles
|
| 140 |
if (warp_idx < kNumMMAThreads / 32) {
|
| 141 |
// TMA load warp
|
|
|
|
| 150 |
uint32_t k_idx = k_offset + s * BLOCK_K;
|
| 151 |
|
| 152 |
// Issue TMAs
|
| 153 |
+
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
|
| 154 |
+
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
|
| 155 |
|
| 156 |
// Arrive at full barriers
|
| 157 |
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
|
|
|
| 173 |
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
| 174 |
|
| 175 |
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
| 176 |
+
auto b_desc = mma::sm100::make_umma_desc<kMajorB, BLOCK_N, BLOCK_SWIZZLED_BK, kSwizzleBMode>(smem_b[0], 0, 0);
|
| 177 |
const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
| 178 |
|
| 179 |
// Checks for MMA instructions
|
|
|
|
| 190 |
const auto& stage_idx = s % kNumStages;
|
| 191 |
const auto& cast_stage_idx = s % kNumCastStages;
|
| 192 |
full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1);
|
| 193 |
+
ptx::tcgen05_after_thread_sync();
|
| 194 |
|
| 195 |
// Issue UMMA
|
| 196 |
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
|
|
|
|
| 199 |
const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK;
|
| 200 |
const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK;
|
| 201 |
const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK;
|
| 202 |
+
b_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, float>(b_desc_base_lo, offset, in_atom_idx);
|
| 203 |
umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc);
|
| 204 |
}
|
| 205 |
|
|
|
|
| 223 |
|
| 224 |
// Wait UMMA arrival
|
| 225 |
tmem_full_barrier->wait(0);
|
| 226 |
+
ptx::tcgen05_after_thread_sync();
|
| 227 |
|
| 228 |
// Load from tensor memory into registers, and write shared memory with STSM
|
| 229 |
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough");
|
|
|
|
| 244 |
values[0], values[1], values[2], values[3]);
|
| 245 |
cutlass::arch::fence_view_async_tmem_load();
|
| 246 |
if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16))
|
| 247 |
+
ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
| 248 |
if constexpr (BLOCK_M == 64)
|
| 249 |
__syncwarp();
|
| 250 |
}
|
|
|
|
| 295 |
#pragma unroll
|
| 296 |
for (uint32_t i = 0; i < kNumLoads; i += 2) {
|
| 297 |
auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset<kSwizzleAMode>(i + lane_idx / 16, lane_idx % 16);
|
| 298 |
+
ptx::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0],
|
| 299 |
+
uint32_values[0][i + 1], uint32_values[1][i + 1],
|
| 300 |
+
smem_ptr);
|
| 301 |
}
|
| 302 |
|
| 303 |
// Wait tensor memory empty
|
|
|
|
| 326 |
cutlass::arch::fence_view_async_tmem_store();
|
| 327 |
|
| 328 |
// Arrive for issuing MMAs
|
| 329 |
+
ptx::tcgen05_before_thread_sync();
|
| 330 |
full_cast_barriers[cast_stage_idx]->arrive();
|
| 331 |
}
|
| 332 |
|
| 333 |
// Intra-warp reduction and write back
|
| 334 |
#pragma unroll
|
| 335 |
for (uint32_t u = 0; u < 2; ++ u) {
|
| 336 |
+
const auto reduced_sum = math::warp_reduce_sum<4>(sum[u].x + sum[u].y);
|
| 337 |
+
const auto m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8;
|
| 338 |
if (lane_idx % 4 == 0 and m_idx < shape_m)
|
| 339 |
sqr_sum[m_offset + m_idx] = reduced_sum;
|
| 340 |
}
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh
CHANGED
|
@@ -11,14 +11,19 @@
|
|
| 11 |
#include <cute/arch/copy_sm90_tma.hpp>
|
| 12 |
#include <cute/arch/mma_sm100_desc.hpp>
|
| 13 |
|
|
|
|
| 14 |
#include <deep_gemm/common/utils.cuh>
|
| 15 |
-
#include <deep_gemm/common/
|
| 16 |
-
#include <deep_gemm/common/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
namespace deep_gemm {
|
| 19 |
|
| 20 |
-
using namespace deep_gemm::sm90;
|
| 21 |
-
|
| 22 |
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
| 23 |
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 24 |
uint32_t kNumGroups,
|
|
@@ -30,7 +35,7 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
|
| 30 |
uint32_t kNumSMs,
|
| 31 |
GemmType kGemmType, bool kWithAccumulation,
|
| 32 |
typename cd_dtype_t>
|
| 33 |
-
|
| 34 |
sm90_bf16_gemm_impl(int* grouped_layout,
|
| 35 |
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
| 36 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
|
@@ -51,7 +56,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 51 |
constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge;
|
| 52 |
|
| 53 |
// Types
|
| 54 |
-
using WGMMA = typename BF16MMASelector<BLOCK_N, kMajorA, kMajorB>::type;
|
| 55 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 56 |
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
|
| 57 |
|
|
@@ -61,7 +66,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 61 |
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
| 62 |
|
| 63 |
// Shared memory
|
| 64 |
-
static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(cd_dtype_t)), 1024u);
|
| 65 |
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
|
| 66 |
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
|
| 67 |
|
|
@@ -71,7 +76,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 71 |
|
| 72 |
// Configs
|
| 73 |
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 74 |
-
const uint32_t lane_idx = get_lane_idx();
|
| 75 |
|
| 76 |
// Prefetch TMA descriptors at the very beginning
|
| 77 |
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
|
@@ -88,17 +93,17 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 88 |
|
| 89 |
// D/A/B shared memory
|
| 90 |
auto smem_d = reinterpret_cast<cd_dtype_t*>(smem_buffer);
|
| 91 |
-
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 92 |
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
| 93 |
});
|
| 94 |
-
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 95 |
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
| 96 |
});
|
| 97 |
|
| 98 |
// Fill barriers
|
| 99 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 100 |
-
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 101 |
-
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 102 |
|
| 103 |
// Initialize barriers
|
| 104 |
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
|
|
@@ -119,9 +124,12 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 119 |
constexpr uint32_t kNumTMARegisters = 48;
|
| 120 |
constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224;
|
| 121 |
|
|
|
|
|
|
|
|
|
|
| 122 |
// Block scheduler
|
| 123 |
uint32_t m_block_idx, n_block_idx;
|
| 124 |
-
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
|
| 125 |
|
| 126 |
// Pipeline and TMA phases
|
| 127 |
uint32_t stage_idx = 0, phase = 0;
|
|
@@ -151,7 +159,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 151 |
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
| 152 |
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
| 153 |
|
| 154 |
-
const auto
|
| 155 |
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 156 |
// Wait consumer release
|
| 157 |
empty_barriers[stage_idx]->wait(phase ^ 1);
|
|
@@ -159,31 +167,30 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 159 |
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
|
| 160 |
auto& full_barrier = *full_barriers[stage_idx];
|
| 161 |
|
| 162 |
-
const auto m_idx = scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
|
| 163 |
-
const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
| 164 |
|
| 165 |
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
|
| 166 |
-
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
|
| 167 |
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 168 |
-
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
|
| 169 |
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 170 |
|
| 171 |
// Issue TMAs
|
| 172 |
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
|
| 173 |
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
|
| 174 |
if constexpr (kMajorA == cute::UMMA::Major::K)
|
| 175 |
-
|
| 176 |
&tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx);
|
| 177 |
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
| 178 |
-
|
| 179 |
&tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx);
|
| 180 |
if constexpr (kMajorB == cute::UMMA::Major::K)
|
| 181 |
-
|
| 182 |
&tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx);
|
| 183 |
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
| 184 |
-
|
| 185 |
&tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx);
|
| 186 |
-
|
| 187 |
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
| 188 |
}
|
| 189 |
}
|
|
@@ -203,8 +210,8 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 203 |
|
| 204 |
// Merged stages only happens in NT normal GEMM cases
|
| 205 |
constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
|
| 206 |
-
auto a_desc = make_gmma_desc<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], math_wg_idx * WGMMA::M, 0);
|
| 207 |
-
auto b_desc = make_gmma_desc<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
| 208 |
const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
|
| 209 |
const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
|
| 210 |
|
|
@@ -229,10 +236,10 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 229 |
};
|
| 230 |
|
| 231 |
// TODO: remove some useless computation for unaligned Ms
|
| 232 |
-
const auto
|
| 233 |
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 234 |
-
const auto
|
| 235 |
-
const auto
|
| 236 |
|
| 237 |
// Wait TMA arrivals
|
| 238 |
full_barriers[stage_idx]->wait(phase);
|
|
@@ -240,26 +247,26 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 240 |
// Commit WGMMA instructions
|
| 241 |
#pragma unroll
|
| 242 |
for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
|
| 243 |
-
warpgroup_fence_operand(accum[i]);
|
| 244 |
-
warpgroup_arrive();
|
| 245 |
#pragma unroll
|
| 246 |
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
| 247 |
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
|
| 248 |
#pragma unroll
|
| 249 |
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
| 250 |
-
const uint32_t
|
| 251 |
-
a_desc.reg32_[0] = advance_gmma_desc_lo<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode, nv_bfloat16>(
|
| 252 |
a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K);
|
| 253 |
-
b_desc.reg32_[0] = advance_gmma_desc_lo<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode, nv_bfloat16>(
|
| 254 |
b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K);
|
| 255 |
WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1);
|
| 256 |
}
|
| 257 |
}
|
| 258 |
-
warpgroup_commit_batch();
|
| 259 |
#pragma unroll
|
| 260 |
for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
|
| 261 |
-
warpgroup_fence_operand(accum[i]);
|
| 262 |
-
warpgroup_wait<0>();
|
| 263 |
|
| 264 |
// Notify barrier arrival
|
| 265 |
empty_barrier_arrive(stage_idx);
|
|
@@ -324,7 +331,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 324 |
}
|
| 325 |
|
| 326 |
// NOTES: only 16 lanes' addresses are used
|
| 327 |
-
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
| 328 |
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
|
| 329 |
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
|
| 330 |
smem_ptr
|
|
@@ -341,8 +348,8 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 341 |
auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2);
|
| 342 |
#pragma unroll
|
| 343 |
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
| 344 |
-
st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
|
| 345 |
-
st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
|
| 346 |
}
|
| 347 |
}
|
| 348 |
}
|
|
@@ -350,7 +357,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
|
| 350 |
cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
|
| 351 |
|
| 352 |
// Use TMA store to write back to global memory
|
| 353 |
-
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
|
| 354 |
DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
|
| 355 |
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
|
| 356 |
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
|
|
|
|
| 11 |
#include <cute/arch/copy_sm90_tma.hpp>
|
| 12 |
#include <cute/arch/mma_sm100_desc.hpp>
|
| 13 |
|
| 14 |
+
#include <deep_gemm/common/math.cuh>
|
| 15 |
#include <deep_gemm/common/utils.cuh>
|
| 16 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 17 |
+
#include <deep_gemm/common/types.cuh>
|
| 18 |
+
#include <deep_gemm/mma/sm90.cuh>
|
| 19 |
+
#include <deep_gemm/epilogue/transform.cuh>
|
| 20 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 21 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 22 |
+
#include <deep_gemm/ptx/wgmma.cuh>
|
| 23 |
+
#include <deep_gemm/scheduler/gemm.cuh>
|
| 24 |
|
| 25 |
namespace deep_gemm {
|
| 26 |
|
|
|
|
|
|
|
| 27 |
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
| 28 |
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 29 |
uint32_t kNumGroups,
|
|
|
|
| 35 |
uint32_t kNumSMs,
|
| 36 |
GemmType kGemmType, bool kWithAccumulation,
|
| 37 |
typename cd_dtype_t>
|
| 38 |
+
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
| 39 |
sm90_bf16_gemm_impl(int* grouped_layout,
|
| 40 |
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
| 41 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
|
|
|
| 56 |
constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge;
|
| 57 |
|
| 58 |
// Types
|
| 59 |
+
using WGMMA = typename mma::sm90::BF16MMASelector<BLOCK_N, kMajorA, kMajorB>::type;
|
| 60 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 61 |
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
|
| 62 |
|
|
|
|
| 66 |
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
| 67 |
|
| 68 |
// Shared memory
|
| 69 |
+
static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(cd_dtype_t)), 1024u);
|
| 70 |
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
|
| 71 |
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
|
| 72 |
|
|
|
|
| 76 |
|
| 77 |
// Configs
|
| 78 |
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 79 |
+
const uint32_t lane_idx = ptx::get_lane_idx();
|
| 80 |
|
| 81 |
// Prefetch TMA descriptors at the very beginning
|
| 82 |
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
|
|
|
| 93 |
|
| 94 |
// D/A/B shared memory
|
| 95 |
auto smem_d = reinterpret_cast<cd_dtype_t*>(smem_buffer);
|
| 96 |
+
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
|
| 97 |
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
| 98 |
});
|
| 99 |
+
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
|
| 100 |
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
| 101 |
});
|
| 102 |
|
| 103 |
// Fill barriers
|
| 104 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 105 |
+
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 106 |
+
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 107 |
|
| 108 |
// Initialize barriers
|
| 109 |
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
|
|
|
|
| 124 |
constexpr uint32_t kNumTMARegisters = 48;
|
| 125 |
constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224;
|
| 126 |
|
| 127 |
+
// Wait for primary kernel completion
|
| 128 |
+
cudaGridDependencySynchronize();
|
| 129 |
+
|
| 130 |
// Block scheduler
|
| 131 |
uint32_t m_block_idx, n_block_idx;
|
| 132 |
+
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
|
| 133 |
|
| 134 |
// Pipeline and TMA phases
|
| 135 |
uint32_t stage_idx = 0, phase = 0;
|
|
|
|
| 159 |
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
| 160 |
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
| 161 |
|
| 162 |
+
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 163 |
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 164 |
// Wait consumer release
|
| 165 |
empty_barriers[stage_idx]->wait(phase ^ 1);
|
|
|
|
| 167 |
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
|
| 168 |
auto& full_barrier = *full_barriers[stage_idx];
|
| 169 |
|
| 170 |
+
const auto m_idx = scheduler.template get_global_idx<kWithGroupOffsetA, sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
|
| 171 |
+
const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
| 172 |
|
| 173 |
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
|
| 174 |
+
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
|
| 175 |
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 176 |
+
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
|
| 177 |
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 178 |
|
| 179 |
// Issue TMAs
|
| 180 |
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
|
| 181 |
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
|
| 182 |
if constexpr (kMajorA == cute::UMMA::Major::K)
|
| 183 |
+
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 184 |
&tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx);
|
| 185 |
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
| 186 |
+
tma::copy<BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 187 |
&tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx);
|
| 188 |
if constexpr (kMajorB == cute::UMMA::Major::K)
|
| 189 |
+
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 190 |
&tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx);
|
| 191 |
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
| 192 |
+
tma::copy<BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 193 |
&tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx);
|
|
|
|
| 194 |
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
| 195 |
}
|
| 196 |
}
|
|
|
|
| 210 |
|
| 211 |
// Merged stages only happens in NT normal GEMM cases
|
| 212 |
constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
|
| 213 |
+
auto a_desc = mma::sm90::make_gmma_desc<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], math_wg_idx * WGMMA::M, 0);
|
| 214 |
+
auto b_desc = mma::sm90::make_gmma_desc<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
| 215 |
const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
|
| 216 |
const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
|
| 217 |
|
|
|
|
| 236 |
};
|
| 237 |
|
| 238 |
// TODO: remove some useless computation for unaligned Ms
|
| 239 |
+
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 240 |
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 241 |
+
const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
|
| 242 |
+
const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
|
| 243 |
|
| 244 |
// Wait TMA arrivals
|
| 245 |
full_barriers[stage_idx]->wait(phase);
|
|
|
|
| 247 |
// Commit WGMMA instructions
|
| 248 |
#pragma unroll
|
| 249 |
for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
|
| 250 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 251 |
+
ptx::warpgroup_arrive();
|
| 252 |
#pragma unroll
|
| 253 |
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
| 254 |
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
|
| 255 |
#pragma unroll
|
| 256 |
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
| 257 |
+
const uint32_t atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K;
|
| 258 |
+
a_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode, nv_bfloat16>(
|
| 259 |
a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K);
|
| 260 |
+
b_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode, nv_bfloat16>(
|
| 261 |
b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K);
|
| 262 |
WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1);
|
| 263 |
}
|
| 264 |
}
|
| 265 |
+
ptx::warpgroup_commit_batch();
|
| 266 |
#pragma unroll
|
| 267 |
for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
|
| 268 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 269 |
+
ptx::warpgroup_wait<0>();
|
| 270 |
|
| 271 |
// Notify barrier arrival
|
| 272 |
empty_barrier_arrive(stage_idx);
|
|
|
|
| 331 |
}
|
| 332 |
|
| 333 |
// NOTES: only 16 lanes' addresses are used
|
| 334 |
+
ptx::SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
| 335 |
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
|
| 336 |
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
|
| 337 |
smem_ptr
|
|
|
|
| 348 |
auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2);
|
| 349 |
#pragma unroll
|
| 350 |
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
| 351 |
+
ptx::st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
|
| 352 |
+
ptx::st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
|
| 353 |
}
|
| 354 |
}
|
| 355 |
}
|
|
|
|
| 357 |
cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
|
| 358 |
|
| 359 |
// Use TMA store to write back to global memory
|
| 360 |
+
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
|
| 361 |
DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
|
| 362 |
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
|
| 363 |
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh
CHANGED
|
@@ -4,26 +4,32 @@
|
|
| 4 |
#include <cutlass/arch/barrier.h>
|
| 5 |
#include <cutlass/arch/reg_reconfig.h>
|
| 6 |
|
|
|
|
| 7 |
#include <deep_gemm/common/utils.cuh>
|
| 8 |
-
#include <deep_gemm/common/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
namespace deep_gemm {
|
| 11 |
|
| 12 |
-
using namespace deep_gemm::sm90;
|
| 13 |
-
|
| 14 |
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 15 |
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 16 |
uint32_t kSplitFactor,
|
| 17 |
uint32_t kNumStages,
|
| 18 |
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
|
| 19 |
-
|
| 20 |
sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
|
| 21 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 22 |
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
| 23 |
float *d) {
|
| 24 |
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
| 25 |
// Types
|
| 26 |
-
using WGMMA = typename BF16MMASelector<BLOCK_N>::type;
|
| 27 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 28 |
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
|
| 29 |
|
|
@@ -33,7 +39,7 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
|
|
| 33 |
|
| 34 |
// Configs
|
| 35 |
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 36 |
-
const uint32_t lane_idx = get_lane_idx();
|
| 37 |
DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M");
|
| 38 |
DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads");
|
| 39 |
DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads");
|
|
@@ -48,17 +54,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
|
|
| 48 |
// Align to 1024 bytes for swizzle-128B
|
| 49 |
// Fill shared memory pointers
|
| 50 |
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
| 51 |
-
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 52 |
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE));
|
| 53 |
});
|
| 54 |
-
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 55 |
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
| 56 |
});
|
| 57 |
|
| 58 |
// Fill barriers
|
| 59 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 60 |
-
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 61 |
-
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 62 |
|
| 63 |
// Initialize barriers
|
| 64 |
if (warp_idx == 1 and cute::elect_one_sync()) {
|
|
@@ -80,14 +86,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
|
|
| 80 |
constexpr uint32_t kNumMathRegisters = 232;
|
| 81 |
|
| 82 |
// Block indices
|
| 83 |
-
const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
|
| 84 |
-
const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
|
| 85 |
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
|
| 86 |
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
|
| 87 |
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
|
| 88 |
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
|
| 89 |
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
|
| 90 |
|
|
|
|
|
|
|
|
|
|
| 91 |
if (warp_idx >= kNumMathThreads / 32) {
|
| 92 |
// TMA warp-group for loading data
|
| 93 |
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
|
@@ -98,18 +107,18 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
|
|
| 98 |
#pragma unroll
|
| 99 |
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 100 |
// Wait consumer release
|
| 101 |
-
const auto
|
| 102 |
empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1);
|
| 103 |
|
| 104 |
auto& full_barrier = *full_barriers[stage_idx];
|
| 105 |
-
const uint32_t
|
| 106 |
-
const uint32_t
|
| 107 |
-
const uint32_t
|
| 108 |
|
| 109 |
constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16);
|
| 110 |
-
|
| 111 |
&tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1);
|
| 112 |
-
|
| 113 |
&tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1);
|
| 114 |
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
| 115 |
}
|
|
@@ -125,32 +134,32 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
|
|
| 125 |
// Launch MMAs
|
| 126 |
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 127 |
// Wait TMA arrivals
|
| 128 |
-
const auto
|
| 129 |
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
|
| 130 |
|
| 131 |
// Commit WGMMA instructions
|
| 132 |
#pragma unroll
|
| 133 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 134 |
-
warpgroup_fence_operand(accum[i]);
|
| 135 |
-
warpgroup_arrive();
|
| 136 |
#pragma unroll
|
| 137 |
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
| 138 |
-
auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
|
| 139 |
-
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
|
| 140 |
WGMMA::wgmma(desc_a, desc_b, accum, 1);
|
| 141 |
}
|
| 142 |
-
warpgroup_commit_batch();
|
| 143 |
#pragma unroll
|
| 144 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 145 |
-
warpgroup_fence_operand(accum[i]);
|
| 146 |
-
warpgroup_wait<0>();
|
| 147 |
|
| 148 |
// Notify barrier arrival at the last warpgroup wave
|
| 149 |
empty_barriers[stage_idx]->arrive();
|
| 150 |
}
|
| 151 |
|
| 152 |
-
const auto
|
| 153 |
-
const auto
|
| 154 |
#pragma unroll
|
| 155 |
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
| 156 |
if (col + i * 8 >= SHAPE_N)
|
|
|
|
| 4 |
#include <cutlass/arch/barrier.h>
|
| 5 |
#include <cutlass/arch/reg_reconfig.h>
|
| 6 |
|
| 7 |
+
#include <deep_gemm/common/math.cuh>
|
| 8 |
#include <deep_gemm/common/utils.cuh>
|
| 9 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 10 |
+
#include <deep_gemm/common/types.cuh>
|
| 11 |
+
#include <deep_gemm/mma/sm90.cuh>
|
| 12 |
+
#include <deep_gemm/epilogue/transform.cuh>
|
| 13 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 14 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 15 |
+
#include <deep_gemm/ptx/wgmma.cuh>
|
| 16 |
+
#include <deep_gemm/scheduler/gemm.cuh>
|
| 17 |
|
| 18 |
namespace deep_gemm {
|
| 19 |
|
|
|
|
|
|
|
| 20 |
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 21 |
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 22 |
uint32_t kSplitFactor,
|
| 23 |
uint32_t kNumStages,
|
| 24 |
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
|
| 25 |
+
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
| 26 |
sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
|
| 27 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 28 |
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
| 29 |
float *d) {
|
| 30 |
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
| 31 |
// Types
|
| 32 |
+
using WGMMA = typename mma::sm90::BF16MMASelector<BLOCK_N>::type;
|
| 33 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 34 |
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
|
| 35 |
|
|
|
|
| 39 |
|
| 40 |
// Configs
|
| 41 |
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 42 |
+
const uint32_t lane_idx = ptx::get_lane_idx();
|
| 43 |
DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M");
|
| 44 |
DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads");
|
| 45 |
DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads");
|
|
|
|
| 54 |
// Align to 1024 bytes for swizzle-128B
|
| 55 |
// Fill shared memory pointers
|
| 56 |
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
| 57 |
+
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
|
| 58 |
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE));
|
| 59 |
});
|
| 60 |
+
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
|
| 61 |
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
| 62 |
});
|
| 63 |
|
| 64 |
// Fill barriers
|
| 65 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 66 |
+
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 67 |
+
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 68 |
|
| 69 |
// Initialize barriers
|
| 70 |
if (warp_idx == 1 and cute::elect_one_sync()) {
|
|
|
|
| 86 |
constexpr uint32_t kNumMathRegisters = 232;
|
| 87 |
|
| 88 |
// Block indices
|
| 89 |
+
const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N);
|
| 90 |
+
const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M);
|
| 91 |
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
|
| 92 |
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
|
| 93 |
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
|
| 94 |
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
|
| 95 |
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
|
| 96 |
|
| 97 |
+
// Wait for primary kernel completion
|
| 98 |
+
cudaGridDependencySynchronize();
|
| 99 |
+
|
| 100 |
if (warp_idx >= kNumMathThreads / 32) {
|
| 101 |
// TMA warp-group for loading data
|
| 102 |
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
|
|
|
| 107 |
#pragma unroll
|
| 108 |
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 109 |
// Wait consumer release
|
| 110 |
+
const auto stage_idx = s % kNumStages;
|
| 111 |
empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1);
|
| 112 |
|
| 113 |
auto& full_barrier = *full_barriers[stage_idx];
|
| 114 |
+
const uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
|
| 115 |
+
const uint32_t k_idx = sk_idx % SHAPE_K;
|
| 116 |
+
const uint32_t s_idx = sk_idx / SHAPE_K;
|
| 117 |
|
| 118 |
constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16);
|
| 119 |
+
tma::copy<BLOCK_K, BLOCK_M, kSwizzle>(
|
| 120 |
&tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1);
|
| 121 |
+
tma::copy<BLOCK_K, BLOCK_N, kSwizzle>(
|
| 122 |
&tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1);
|
| 123 |
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
| 124 |
}
|
|
|
|
| 134 |
// Launch MMAs
|
| 135 |
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 136 |
// Wait TMA arrivals
|
| 137 |
+
const auto stage_idx = s % kNumStages;
|
| 138 |
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
|
| 139 |
|
| 140 |
// Commit WGMMA instructions
|
| 141 |
#pragma unroll
|
| 142 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 143 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 144 |
+
ptx::warpgroup_arrive();
|
| 145 |
#pragma unroll
|
| 146 |
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
| 147 |
+
auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
|
| 148 |
+
auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
|
| 149 |
WGMMA::wgmma(desc_a, desc_b, accum, 1);
|
| 150 |
}
|
| 151 |
+
ptx::warpgroup_commit_batch();
|
| 152 |
#pragma unroll
|
| 153 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 154 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 155 |
+
ptx::warpgroup_wait<0>();
|
| 156 |
|
| 157 |
// Notify barrier arrival at the last warpgroup wave
|
| 158 |
empty_barriers[stage_idx]->arrive();
|
| 159 |
}
|
| 160 |
|
| 161 |
+
const auto row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
|
| 162 |
+
const auto col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
|
| 163 |
#pragma unroll
|
| 164 |
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
| 165 |
if (col + i * 8 >= SHAPE_N)
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh
CHANGED
|
@@ -6,18 +6,26 @@
|
|
| 6 |
#include <cutlass/arch/barrier.h>
|
| 7 |
#include <cutlass/arch/reg_reconfig.h>
|
| 8 |
|
|
|
|
| 9 |
#include <cute/arch/cluster_sm90.hpp>
|
| 10 |
#include <cute/arch/copy_sm90_desc.hpp>
|
| 11 |
#include <cute/arch/copy_sm90_tma.hpp>
|
| 12 |
|
|
|
|
|
|
|
| 13 |
#include <deep_gemm/common/utils.cuh>
|
| 14 |
-
#include <deep_gemm/common/
|
| 15 |
-
#include <deep_gemm/common/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
namespace deep_gemm {
|
| 18 |
|
| 19 |
-
using namespace deep_gemm::sm90;
|
| 20 |
-
|
| 21 |
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 22 |
uint32_t kNumGroups,
|
| 23 |
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
|
@@ -27,7 +35,7 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
|
| 27 |
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
| 28 |
uint32_t kNumSMs,
|
| 29 |
GemmType kGemmType, typename cd_dtype_t>
|
| 30 |
-
|
| 31 |
sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
| 32 |
int* grouped_layout,
|
| 33 |
cute::TmaDescriptor* tensor_map_buffer,
|
|
@@ -45,7 +53,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
|
| 45 |
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");
|
| 46 |
|
| 47 |
// Types
|
| 48 |
-
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
| 49 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 50 |
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
|
| 51 |
|
|
@@ -55,13 +63,13 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
|
| 55 |
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
| 56 |
|
| 57 |
// Shared memory
|
| 58 |
-
static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) *
|
| 59 |
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
|
| 60 |
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 61 |
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 62 |
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
| 63 |
static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
|
| 64 |
-
static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
|
| 65 |
DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment");
|
| 66 |
|
| 67 |
// Configs
|
|
@@ -83,47 +91,41 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
|
| 83 |
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
| 84 |
|
| 85 |
// Tensor maps on shared and global memory
|
| 86 |
-
auto smem_tensor_map_a =
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
auto
|
| 90 |
-
return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * (2 + i));
|
| 91 |
-
});
|
| 92 |
-
auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; });
|
| 93 |
-
auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; });
|
| 94 |
|
| 95 |
// Data on shared memory
|
| 96 |
auto smem_d = reinterpret_cast<float*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
|
| 97 |
-
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 98 |
-
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
| 99 |
});
|
| 100 |
-
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 101 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
| 102 |
});
|
| 103 |
constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
| 104 |
-
auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
|
| 105 |
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE));
|
| 106 |
});
|
| 107 |
-
auto smem_sfb = PatternVisitor([&](const uint32_t& i) {
|
| 108 |
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE));
|
| 109 |
});
|
| 110 |
|
| 111 |
// Barriers on shared memory
|
| 112 |
constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE);
|
| 113 |
-
auto full_barriers = PatternVisitor([&](const uint32_t& i) {
|
| 114 |
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast<uint32_t>(sizeof(Barrier))));
|
| 115 |
});
|
| 116 |
-
auto empty_barriers = PatternVisitor([&](const uint32_t& i) {
|
| 117 |
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast<uint32_t>(sizeof(Barrier))));
|
| 118 |
});
|
| 119 |
|
| 120 |
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
|
| 121 |
// Load tensormap A/B to shared memory
|
| 122 |
if constexpr (kGemmType == GemmType::KGroupedContiguous) {
|
| 123 |
-
*smem_tensor_map_a
|
| 124 |
-
*
|
| 125 |
-
*smem_tensor_map_b[0] = tensor_map_b_base;
|
| 126 |
-
*smem_tensor_map_b[1] = tensor_map_b_base;
|
| 127 |
}
|
| 128 |
|
| 129 |
// Initialize barriers
|
|
@@ -149,12 +151,15 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
|
| 149 |
constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24);
|
| 150 |
constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240);
|
| 151 |
|
|
|
|
|
|
|
|
|
|
| 152 |
// Block scheduler
|
| 153 |
uint32_t m_block_idx, n_block_idx;
|
| 154 |
-
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs, 128u>(shape_m, shape_n, shape_k, grouped_layout);
|
| 155 |
|
| 156 |
// TMA and MMA pipeline
|
| 157 |
-
const auto
|
| 158 |
return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase
|
| 159 |
};
|
| 160 |
uint32_t iter_idx = 0;
|
|
@@ -165,9 +170,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
|
| 165 |
|
| 166 |
// NOTES: only one thread (or warp) will be used
|
| 167 |
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
| 168 |
-
|
| 169 |
-
const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base;
|
| 170 |
-
uint32_t last_group_idx = kNumGroups, sum_k = 0;
|
| 171 |
|
| 172 |
// Persistently schedule over blocks
|
| 173 |
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
|
@@ -177,35 +180,27 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
|
| 177 |
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
| 178 |
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
| 179 |
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
| 180 |
-
|
| 181 |
-
const uint32_t
|
| 182 |
-
const uint32_t
|
| 183 |
-
const uint32_t
|
| 184 |
-
|
| 185 |
-
if (kGemmType == GemmType::KGroupedContiguous
|
| 186 |
-
const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1;
|
| 187 |
-
const uint32_t& next_stage_idx = stage_idx ^ 1;
|
| 188 |
last_group_idx = scheduler.current_group_idx;
|
| 189 |
|
| 190 |
-
//
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
if (scheduler.current_num_valid_groups > 0) {
|
| 204 |
-
tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]);
|
| 205 |
-
tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]);
|
| 206 |
-
current_tensor_map_a = gmem_tensor_map_a[stage_idx];
|
| 207 |
-
current_tensor_map_b = gmem_tensor_map_b[stage_idx];
|
| 208 |
-
}
|
| 209 |
}
|
| 210 |
|
| 211 |
#pragma unroll kNumPipelineUnrolls
|
|
@@ -216,12 +211,14 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
|
| 216 |
|
| 217 |
// Issue TMA
|
| 218 |
auto& full_barrier = *full_barriers[stage_idx];
|
| 219 |
-
const uint32_t
|
| 220 |
-
const uint32_t
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
| 225 |
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE);
|
| 226 |
}
|
| 227 |
}
|
|
@@ -248,9 +245,9 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
|
| 248 |
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 249 |
// Accumulation for WGMMA or CUDA promotion
|
| 250 |
DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes");
|
| 251 |
-
const uint32_t
|
| 252 |
-
const uint32_t
|
| 253 |
-
const uint32_t
|
| 254 |
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
|
| 255 |
float2 scales_b[WGMMA::kNumAccum / 4];
|
| 256 |
|
|
@@ -272,30 +269,30 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
|
| 272 |
|
| 273 |
// Read A scales
|
| 274 |
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
|
| 275 |
-
auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0);
|
| 276 |
-
auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1);
|
| 277 |
|
| 278 |
// Read B scales
|
| 279 |
#pragma unroll
|
| 280 |
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
|
| 281 |
-
scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
|
| 282 |
|
| 283 |
// Commit WGMMA instructions
|
| 284 |
#pragma unroll
|
| 285 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 286 |
-
warpgroup_fence_operand(accum[i]);
|
| 287 |
-
warpgroup_arrive();
|
| 288 |
#pragma unroll
|
| 289 |
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
| 290 |
-
auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
|
| 291 |
-
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
|
| 292 |
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
| 293 |
}
|
| 294 |
-
warpgroup_commit_batch();
|
| 295 |
#pragma unroll
|
| 296 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 297 |
-
warpgroup_fence_operand(accum[i]);
|
| 298 |
-
warpgroup_wait<0>();
|
| 299 |
|
| 300 |
// Notify barrier arrival
|
| 301 |
empty_barrier_arrive(stage_idx);
|
|
@@ -318,12 +315,12 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
|
| 318 |
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
|
| 319 |
|
| 320 |
// Store to D shared memory
|
| 321 |
-
const auto
|
| 322 |
-
const auto
|
| 323 |
#pragma unroll
|
| 324 |
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
| 325 |
-
st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
|
| 326 |
-
st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
|
| 327 |
}
|
| 328 |
cute::tma_store_fence();
|
| 329 |
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
|
|
|
|
| 6 |
#include <cutlass/arch/barrier.h>
|
| 7 |
#include <cutlass/arch/reg_reconfig.h>
|
| 8 |
|
| 9 |
+
#include <cute/int_tuple.hpp>
|
| 10 |
#include <cute/arch/cluster_sm90.hpp>
|
| 11 |
#include <cute/arch/copy_sm90_desc.hpp>
|
| 12 |
#include <cute/arch/copy_sm90_tma.hpp>
|
| 13 |
|
| 14 |
+
#include <deep_gemm/common/cute_tie.cuh>
|
| 15 |
+
#include <deep_gemm/common/math.cuh>
|
| 16 |
#include <deep_gemm/common/utils.cuh>
|
| 17 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 18 |
+
#include <deep_gemm/common/types.cuh>
|
| 19 |
+
#include <deep_gemm/mma/sm90.cuh>
|
| 20 |
+
#include <deep_gemm/epilogue/transform.cuh>
|
| 21 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 22 |
+
#include <deep_gemm/ptx/tma.cuh>
|
| 23 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 24 |
+
#include <deep_gemm/ptx/wgmma.cuh>
|
| 25 |
+
#include <deep_gemm/scheduler/gemm.cuh>
|
| 26 |
|
| 27 |
namespace deep_gemm {
|
| 28 |
|
|
|
|
|
|
|
| 29 |
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 30 |
uint32_t kNumGroups,
|
| 31 |
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
|
|
|
| 35 |
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
| 36 |
uint32_t kNumSMs,
|
| 37 |
GemmType kGemmType, typename cd_dtype_t>
|
| 38 |
+
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
| 39 |
sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
| 40 |
int* grouped_layout,
|
| 41 |
cute::TmaDescriptor* tensor_map_buffer,
|
|
|
|
| 53 |
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");
|
| 54 |
|
| 55 |
// Types
|
| 56 |
+
using WGMMA = typename mma::sm90::FP8MMASelector<BLOCK_N>::type;
|
| 57 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 58 |
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
|
| 59 |
|
|
|
|
| 63 |
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
| 64 |
|
| 65 |
// Shared memory
|
| 66 |
+
static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 2 : 0);
|
| 67 |
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
|
| 68 |
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 69 |
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 70 |
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
| 71 |
static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
|
| 72 |
+
static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
|
| 73 |
DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment");
|
| 74 |
|
| 75 |
// Configs
|
|
|
|
| 91 |
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
| 92 |
|
| 93 |
// Tensor maps on shared and global memory
|
| 94 |
+
auto smem_tensor_map_a = reinterpret_cast<cute::TmaDescriptor*>(smem_buffer);
|
| 95 |
+
auto smem_tensor_map_b = smem_tensor_map_a + 1;
|
| 96 |
+
auto gmem_tensor_map_a = tensor_map_buffer + blockIdx.x * 2;
|
| 97 |
+
auto gmem_tensor_map_b = gmem_tensor_map_a + 1;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
// Data on shared memory
|
| 100 |
auto smem_d = reinterpret_cast<float*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
|
| 101 |
+
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
|
| 102 |
+
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
| 103 |
});
|
| 104 |
+
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
|
| 105 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
| 106 |
});
|
| 107 |
constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
| 108 |
+
auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) {
|
| 109 |
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE));
|
| 110 |
});
|
| 111 |
+
auto smem_sfb = utils::PatternVisitor([&](const uint32_t& i) {
|
| 112 |
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE));
|
| 113 |
});
|
| 114 |
|
| 115 |
// Barriers on shared memory
|
| 116 |
constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE);
|
| 117 |
+
auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) {
|
| 118 |
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast<uint32_t>(sizeof(Barrier))));
|
| 119 |
});
|
| 120 |
+
auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) {
|
| 121 |
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast<uint32_t>(sizeof(Barrier))));
|
| 122 |
});
|
| 123 |
|
| 124 |
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
|
| 125 |
// Load tensormap A/B to shared memory
|
| 126 |
if constexpr (kGemmType == GemmType::KGroupedContiguous) {
|
| 127 |
+
*smem_tensor_map_a = tensor_map_a_base;
|
| 128 |
+
*smem_tensor_map_b = tensor_map_b_base;
|
|
|
|
|
|
|
| 129 |
}
|
| 130 |
|
| 131 |
// Initialize barriers
|
|
|
|
| 151 |
constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24);
|
| 152 |
constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240);
|
| 153 |
|
| 154 |
+
// Wait for primary kernel completion
|
| 155 |
+
cudaGridDependencySynchronize();
|
| 156 |
+
|
| 157 |
// Block scheduler
|
| 158 |
uint32_t m_block_idx, n_block_idx;
|
| 159 |
+
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs, 128u>(shape_m, shape_n, shape_k, grouped_layout);
|
| 160 |
|
| 161 |
// TMA and MMA pipeline
|
| 162 |
+
const auto get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
| 163 |
return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase
|
| 164 |
};
|
| 165 |
uint32_t iter_idx = 0;
|
|
|
|
| 170 |
|
| 171 |
// NOTES: only one thread (or warp) will be used
|
| 172 |
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
| 173 |
+
uint32_t last_group_idx = kNumGroups;
|
|
|
|
|
|
|
| 174 |
|
| 175 |
// Persistently schedule over blocks
|
| 176 |
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
|
|
|
| 180 |
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
| 181 |
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
| 182 |
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
| 183 |
+
|
| 184 |
+
const uint32_t num_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 185 |
+
const uint32_t m_idx = m_block_idx * BLOCK_M;
|
| 186 |
+
const uint32_t n_idx = n_block_idx * BLOCK_N;
|
| 187 |
+
|
| 188 |
+
if (kGemmType == GemmType::KGroupedContiguous && last_group_idx != scheduler.current_group_idx) {
|
|
|
|
|
|
|
| 189 |
last_group_idx = scheduler.current_group_idx;
|
| 190 |
|
| 191 |
+
// Directly update current tensor map
|
| 192 |
+
const uint64_t current_k_offset = scheduler.current_k_cumsum;
|
| 193 |
+
ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_a, gmem_a_ptr + current_k_offset * shape_m);
|
| 194 |
+
ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_b, gmem_b_ptr + current_k_offset * shape_n);
|
| 195 |
+
ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a, scheduler.current_shape_k, scheduler.current_shape_k);
|
| 196 |
+
ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b, scheduler.current_shape_k, scheduler.current_shape_k);
|
| 197 |
+
*(gmem_tensor_map_a) = *(smem_tensor_map_a);
|
| 198 |
+
*(gmem_tensor_map_b) = *(smem_tensor_map_b);
|
| 199 |
+
ptx::tensor_map_release_gpu();
|
| 200 |
+
|
| 201 |
+
// Immediately acquire current tensor map
|
| 202 |
+
ptx::tensor_map_acquire_gpu(gmem_tensor_map_a);
|
| 203 |
+
ptx::tensor_map_acquire_gpu(gmem_tensor_map_b);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
}
|
| 205 |
|
| 206 |
#pragma unroll kNumPipelineUnrolls
|
|
|
|
| 211 |
|
| 212 |
// Issue TMA
|
| 213 |
auto& full_barrier = *full_barriers[stage_idx];
|
| 214 |
+
const uint32_t k_idx = k_block_idx * BLOCK_K;
|
| 215 |
+
const uint32_t sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
|
| 216 |
+
const auto tensor_map_a_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_a : &tensor_map_a_base);
|
| 217 |
+
const auto tensor_map_b_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_b : &tensor_map_b_base);
|
| 218 |
+
tma::copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
|
| 219 |
+
tma::copy<BLOCK_N, BLOCK_K, 0>(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
|
| 220 |
+
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(tensor_map_a_ptr, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
|
| 221 |
+
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(tensor_map_b_ptr, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
|
| 222 |
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE);
|
| 223 |
}
|
| 224 |
}
|
|
|
|
| 245 |
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 246 |
// Accumulation for WGMMA or CUDA promotion
|
| 247 |
DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes");
|
| 248 |
+
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
|
| 249 |
+
const uint32_t current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
|
| 250 |
+
const uint32_t num_k_blocks = math::ceil_div(current_shape_k, BLOCK_K);
|
| 251 |
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
|
| 252 |
float2 scales_b[WGMMA::kNumAccum / 4];
|
| 253 |
|
|
|
|
| 269 |
|
| 270 |
// Read A scales
|
| 271 |
// NOTES: all shared memory reads must happen before `warpgroup_arrive`, so that the next scheduled block cannot pollute the results
|
| 272 |
+
auto scale_a_0 = ptx::ld_shared(smem_sfa[stage_idx] + r_0);
|
| 273 |
+
auto scale_a_1 = ptx::ld_shared(smem_sfa[stage_idx] + r_1);
|
| 274 |
|
| 275 |
// Read B scales
|
| 276 |
#pragma unroll
|
| 277 |
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
|
| 278 |
+
scales_b[i] = ptx::ld_shared(reinterpret_cast<float2*>(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
|
| 279 |
|
| 280 |
// Commit WGMMA instructions
|
| 281 |
#pragma unroll
|
| 282 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 283 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 284 |
+
ptx::warpgroup_arrive();
|
| 285 |
#pragma unroll
|
| 286 |
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
| 287 |
+
auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
|
| 288 |
+
auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
|
| 289 |
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
| 290 |
}
|
| 291 |
+
ptx::warpgroup_commit_batch();
|
| 292 |
#pragma unroll
|
| 293 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 294 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 295 |
+
ptx::warpgroup_wait<0>();
|
| 296 |
|
| 297 |
// Notify barrier arrival
|
| 298 |
empty_barrier_arrive(stage_idx);
|
|
|
|
| 315 |
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
|
| 316 |
|
| 317 |
// Store to D shared memory
|
| 318 |
+
const auto smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
|
| 319 |
+
const auto smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
|
| 320 |
#pragma unroll
|
| 321 |
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
| 322 |
+
ptx::st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
|
| 323 |
+
ptx::st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
|
| 324 |
}
|
| 325 |
cute::tma_store_fence();
|
| 326 |
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
CHANGED
|
@@ -10,17 +10,21 @@
|
|
| 10 |
#include <cute/arch/copy_sm90_desc.hpp>
|
| 11 |
#include <cute/arch/copy_sm90_tma.hpp>
|
| 12 |
|
| 13 |
-
#include <deep_gemm/common/
|
| 14 |
#include <deep_gemm/common/utils.cuh>
|
| 15 |
-
#include <deep_gemm/common/
|
| 16 |
-
#include <deep_gemm/common/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
namespace deep_gemm {
|
| 19 |
|
| 20 |
-
using namespace deep_gemm::sm90;
|
| 21 |
-
|
| 22 |
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd, typename func_t>
|
| 23 |
-
|
| 24 |
if (num_former_iters == kNumFormerIters) {
|
| 25 |
func(cute::Int<kNumFormerIters>{});
|
| 26 |
return;
|
|
@@ -35,12 +39,12 @@ template <cute::UMMA::Major kMajorSFB,
|
|
| 35 |
uint32_t kNumGroups,
|
| 36 |
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 37 |
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleDMode,
|
| 38 |
-
uint32_t kNumStages,
|
| 39 |
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
| 40 |
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
| 41 |
uint32_t kNumSMs, GemmType kGemmType,
|
| 42 |
typename epilogue_type_t>
|
| 43 |
-
|
| 44 |
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
| 45 |
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
| 46 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
|
@@ -50,10 +54,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
|
| 50 |
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
| 51 |
// Scaling checks
|
| 52 |
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
| 53 |
-
DG_STATIC_ASSERT(
|
|
|
|
|
|
|
| 54 |
|
| 55 |
// Types
|
| 56 |
-
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
| 57 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 58 |
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
|
| 59 |
|
|
@@ -64,23 +70,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
|
| 64 |
|
| 65 |
// Shared memory
|
| 66 |
static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
|
| 67 |
-
static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(__nv_bfloat16)), 1024u);
|
| 68 |
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 69 |
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 70 |
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
| 71 |
-
static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u);
|
| 72 |
-
const uint32_t
|
| 73 |
-
const uint32_t
|
| 74 |
-
const uint32_t
|
| 75 |
|
| 76 |
// NOTES: Make sure we have enough shared memory for WGMMA padding
|
| 77 |
static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 78 |
DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA");
|
| 79 |
|
| 80 |
// Configs
|
| 81 |
-
const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K);
|
| 82 |
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 83 |
-
const uint32_t lane_idx = get_lane_idx();
|
| 84 |
|
| 85 |
// Prefetch TMA descriptors at the very beginning
|
| 86 |
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
|
@@ -97,22 +103,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
|
| 97 |
|
| 98 |
// Data on shared memory
|
| 99 |
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
|
| 100 |
-
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 101 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
| 102 |
});
|
| 103 |
-
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 104 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
| 105 |
});
|
| 106 |
constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
| 107 |
-
auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
|
| 108 |
return reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
|
| 109 |
});
|
| 110 |
auto smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
|
| 111 |
|
| 112 |
// Fill barriers
|
| 113 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
|
| 114 |
-
auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
|
| 115 |
-
auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
|
| 116 |
|
| 117 |
// Initialize barriers
|
| 118 |
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
|
|
@@ -136,9 +142,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
|
| 136 |
constexpr uint32_t kNumTMARegisters = 40;
|
| 137 |
constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232;
|
| 138 |
|
|
|
|
|
|
|
|
|
|
| 139 |
// Block scheduler
|
| 140 |
uint32_t m_block_idx, n_block_idx;
|
| 141 |
-
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
|
| 142 |
|
| 143 |
// Pipeline and TMA phases
|
| 144 |
uint32_t stage_idx = 0, phase = 0;
|
|
@@ -177,15 +186,15 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
|
| 177 |
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
|
| 178 |
auto& full_barrier = *full_barriers[stage_idx];
|
| 179 |
const uint32_t k_idx = k_block_idx * BLOCK_K;
|
| 180 |
-
|
| 181 |
smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
|
| 182 |
num_tma_multicast_a, batch_idx);
|
| 183 |
-
|
| 184 |
-
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::SF_K>(shape_k_scales, 1, k_block_idx),
|
| 185 |
num_tma_multicast_a);
|
| 186 |
|
| 187 |
// Issue TMA B
|
| 188 |
-
|
| 189 |
smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
|
| 190 |
num_tma_multicast_b, batch_idx);
|
| 191 |
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
|
|
@@ -206,8 +215,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
|
| 206 |
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
| 207 |
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
|
| 208 |
|
| 209 |
-
auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
|
| 210 |
-
auto b_desc = make_smem_desc(smem_b[0], 1);
|
| 211 |
const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
|
| 212 |
const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
|
| 213 |
|
|
@@ -225,14 +234,14 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
|
| 225 |
// Load B scales with math warp-groups
|
| 226 |
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
|
| 227 |
if (threadIdx.x >= 32) {
|
| 228 |
-
auto previous_group_offset = scheduler.template get_global_idx<true, IndexType::SF_K>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
|
| 229 |
const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales;
|
| 230 |
const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1;
|
| 231 |
auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb;
|
| 232 |
|
| 233 |
#pragma unroll
|
| 234 |
for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
|
| 235 |
-
st_shared(smem_sfb + i,
|
| 236 |
}
|
| 237 |
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
|
| 238 |
|
|
@@ -259,22 +268,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
|
| 259 |
// Skip useless computations
|
| 260 |
if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) {
|
| 261 |
// The compiler must know the dynamic variable `num_former_iters`'s real value
|
| 262 |
-
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
|
| 263 |
-
constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
|
| 264 |
constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
|
| 265 |
|
| 266 |
// Dispatch `num_former_iters` and launch MMAs
|
| 267 |
dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) {
|
| 268 |
#pragma unroll 8
|
| 269 |
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 270 |
-
const auto
|
| 271 |
-
const auto
|
| 272 |
|
| 273 |
// Read B scales
|
| 274 |
-
float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1;
|
| 275 |
// NOTES: even though some blocks do not need to read the second row, we still load one to align with other blocks
|
| 276 |
if constexpr (not kMustUseUniformedScaleB)
|
| 277 |
-
scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales);
|
| 278 |
|
| 279 |
// Wait TMA arrivals
|
| 280 |
full_barriers[stage_idx]->wait(phase);
|
|
@@ -286,25 +295,25 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
|
| 286 |
|
| 287 |
// Read A scales
|
| 288 |
// NOTES: all shared memory reads must happen before `warpgroup_arrive`, so that the next scheduled block cannot pollute the results
|
| 289 |
-
auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0;
|
| 290 |
-
auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0;
|
| 291 |
|
| 292 |
// Commit WGMMA instructions
|
| 293 |
#pragma unroll
|
| 294 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 295 |
-
warpgroup_fence_operand(accum[i]);
|
| 296 |
-
warpgroup_arrive();
|
| 297 |
#pragma unroll
|
| 298 |
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
| 299 |
a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16;
|
| 300 |
b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16;
|
| 301 |
WGMMA::wgmma(a_desc, b_desc, accum, k);
|
| 302 |
}
|
| 303 |
-
warpgroup_commit_batch();
|
| 304 |
#pragma unroll
|
| 305 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 306 |
-
warpgroup_fence_operand(accum[i]);
|
| 307 |
-
warpgroup_wait<0>();
|
| 308 |
|
| 309 |
// Notify barrier arrival at the last warpgroup wave
|
| 310 |
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
|
|
@@ -325,7 +334,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
|
| 325 |
#pragma unroll
|
| 326 |
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
| 327 |
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
|
| 328 |
-
const bool
|
| 329 |
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
| 330 |
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
| 331 |
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
|
@@ -399,7 +408,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
|
| 399 |
}
|
| 400 |
|
| 401 |
// NOTES: only 16 lanes' addresses are used
|
| 402 |
-
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
| 403 |
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
|
| 404 |
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
|
| 405 |
smem_ptr
|
|
|
|
| 10 |
#include <cute/arch/copy_sm90_desc.hpp>
|
| 11 |
#include <cute/arch/copy_sm90_tma.hpp>
|
| 12 |
|
| 13 |
+
#include <deep_gemm/common/math.cuh>
|
| 14 |
#include <deep_gemm/common/utils.cuh>
|
| 15 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 16 |
+
#include <deep_gemm/common/types.cuh>
|
| 17 |
+
#include <deep_gemm/mma/sm90.cuh>
|
| 18 |
+
#include <deep_gemm/epilogue/transform.cuh>
|
| 19 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 20 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 21 |
+
#include <deep_gemm/ptx/wgmma.cuh>
|
| 22 |
+
#include <deep_gemm/scheduler/gemm.cuh>
|
| 23 |
|
| 24 |
namespace deep_gemm {
|
| 25 |
|
|
|
|
|
|
|
| 26 |
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd, typename func_t>
|
| 27 |
+
CUTLASS_DEVICE void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
|
| 28 |
if (num_former_iters == kNumFormerIters) {
|
| 29 |
func(cute::Int<kNumFormerIters>{});
|
| 30 |
return;
|
|
|
|
| 39 |
uint32_t kNumGroups,
|
| 40 |
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 41 |
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleDMode,
|
| 42 |
+
uint32_t kNumStages,
|
| 43 |
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
| 44 |
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
| 45 |
uint32_t kNumSMs, GemmType kGemmType,
|
| 46 |
typename epilogue_type_t>
|
| 47 |
+
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
| 48 |
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
| 49 |
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
| 50 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
|
|
|
| 54 |
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
| 55 |
// Scaling checks
|
| 56 |
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
| 57 |
+
DG_STATIC_ASSERT(
|
| 58 |
+
math::constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or
|
| 59 |
+
(math::constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
|
| 60 |
|
| 61 |
// Types
|
| 62 |
+
using WGMMA = typename mma::sm90::FP8MMASelector<BLOCK_N>::type;
|
| 63 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 64 |
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
|
| 65 |
|
|
|
|
| 70 |
|
| 71 |
// Shared memory
|
| 72 |
static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
|
| 73 |
+
static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(__nv_bfloat16)), 1024u);
|
| 74 |
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 75 |
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 76 |
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
| 77 |
+
static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u);
|
| 78 |
+
const uint32_t shape_k_scales = math::ceil_div(shape_k, BLOCK_K);
|
| 79 |
+
const uint32_t shape_n_sfb = math::ceil_div(shape_n, BLOCK_K);
|
| 80 |
+
const uint32_t smem_sfb_size = math::align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
|
| 81 |
|
| 82 |
// NOTES: Make sure we have enough shared memory for WGMMA padding
|
| 83 |
static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 84 |
DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA");
|
| 85 |
|
| 86 |
// Configs
|
| 87 |
+
const uint32_t num_total_k_blocks = math::ceil_div(shape_k, BLOCK_K);
|
| 88 |
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 89 |
+
const uint32_t lane_idx = ptx::get_lane_idx();
|
| 90 |
|
| 91 |
// Prefetch TMA descriptors at the very beginning
|
| 92 |
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
|
|
|
| 103 |
|
| 104 |
// Data on shared memory
|
| 105 |
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
|
| 106 |
+
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
|
| 107 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
| 108 |
});
|
| 109 |
+
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
|
| 110 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
| 111 |
});
|
| 112 |
constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
| 113 |
+
auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) {
|
| 114 |
return reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
|
| 115 |
});
|
| 116 |
auto smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
|
| 117 |
|
| 118 |
// Fill barriers
|
| 119 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
|
| 120 |
+
auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
|
| 121 |
+
auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
|
| 122 |
|
| 123 |
// Initialize barriers
|
| 124 |
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
|
|
|
|
| 142 |
constexpr uint32_t kNumTMARegisters = 40;
|
| 143 |
constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232;
|
| 144 |
|
| 145 |
+
// Wait for primary kernel completion
|
| 146 |
+
cudaGridDependencySynchronize();
|
| 147 |
+
|
| 148 |
// Block scheduler
|
| 149 |
uint32_t m_block_idx, n_block_idx;
|
| 150 |
+
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
|
| 151 |
|
| 152 |
// Pipeline and TMA phases
|
| 153 |
uint32_t stage_idx = 0, phase = 0;
|
|
|
|
| 186 |
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
|
| 187 |
auto& full_barrier = *full_barriers[stage_idx];
|
| 188 |
const uint32_t k_idx = k_block_idx * BLOCK_K;
|
| 189 |
+
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_a, &full_barrier,
|
| 190 |
smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
|
| 191 |
num_tma_multicast_a, batch_idx);
|
| 192 |
+
tma::copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier,
|
| 193 |
+
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx<kWithGroupOffsetA, sched::IndexType::SF_K>(shape_k_scales, 1, k_block_idx),
|
| 194 |
num_tma_multicast_a);
|
| 195 |
|
| 196 |
// Issue TMA B
|
| 197 |
+
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_b, &full_barrier,
|
| 198 |
smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
|
| 199 |
num_tma_multicast_b, batch_idx);
|
| 200 |
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
|
|
|
|
| 215 |
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
| 216 |
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
|
| 217 |
|
| 218 |
+
auto a_desc = mma::sm90::make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
|
| 219 |
+
auto b_desc = mma::sm90::make_smem_desc(smem_b[0], 1);
|
| 220 |
const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
|
| 221 |
const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
|
| 222 |
|
|
|
|
| 234 |
// Load B scales with math warp-groups
|
| 235 |
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
|
| 236 |
if (threadIdx.x >= 32) {
|
| 237 |
+
auto previous_group_offset = scheduler.template get_global_idx<true, sched::IndexType::SF_K>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
|
| 238 |
const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales;
|
| 239 |
const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1;
|
| 240 |
auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb;
|
| 241 |
|
| 242 |
#pragma unroll
|
| 243 |
for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
|
| 244 |
+
ptx::st_shared(smem_sfb + i, i < shape_k_scales ? local_sfb[i * stride_k_sfb] : local_sfb[(i - shape_k_scales) * stride_k_sfb + stride_n_sfb]);
|
| 245 |
}
|
| 246 |
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
|
| 247 |
|
|
|
|
| 268 |
// Skip useless computations
|
| 269 |
if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) {
|
| 270 |
// The compiler must know the dynamic variable `num_former_iters`'s real value
|
| 271 |
+
constexpr bool kShouldOptimize = BLOCK_K / math::constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
|
| 272 |
+
constexpr uint32_t kGap = math::constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
|
| 273 |
constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
|
| 274 |
|
| 275 |
// Dispatch `num_former_iters` and launch MMAs
|
| 276 |
dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) {
|
| 277 |
#pragma unroll 8
|
| 278 |
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 279 |
+
const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
|
| 280 |
+
const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
|
| 281 |
|
| 282 |
// Read B scales
|
| 283 |
+
float scale_b_0 = ptx::ld_shared(smem_sfb + k_block_idx), scale_b_1;
|
| 284 |
// NOTES: even though some blocks do not need to read the second row, we still load one to align with other blocks
|
| 285 |
if constexpr (not kMustUseUniformedScaleB)
|
| 286 |
+
scale_b_1 = ptx::ld_shared(smem_sfb + k_block_idx + shape_k_scales);
|
| 287 |
|
| 288 |
// Wait TMA arrivals
|
| 289 |
full_barriers[stage_idx]->wait(phase);
|
|
|
|
| 295 |
|
| 296 |
// Read A scales
|
| 297 |
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
|
| 298 |
+
auto scale_a_0 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0;
|
| 299 |
+
auto scale_a_1 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0;
|
| 300 |
|
| 301 |
// Commit WGMMA instructions
|
| 302 |
#pragma unroll
|
| 303 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 304 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 305 |
+
ptx::warpgroup_arrive();
|
| 306 |
#pragma unroll
|
| 307 |
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
| 308 |
a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16;
|
| 309 |
b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16;
|
| 310 |
WGMMA::wgmma(a_desc, b_desc, accum, k);
|
| 311 |
}
|
| 312 |
+
ptx::warpgroup_commit_batch();
|
| 313 |
#pragma unroll
|
| 314 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 315 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 316 |
+
ptx::warpgroup_wait<0>();
|
| 317 |
|
| 318 |
// Notify barrier arrival at the last warpgroup wave
|
| 319 |
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
|
|
|
|
| 334 |
#pragma unroll
|
| 335 |
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
| 336 |
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
|
| 337 |
+
const bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
|
| 338 |
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
| 339 |
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
| 340 |
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
|
|
|
| 408 |
}
|
| 409 |
|
| 410 |
// NOTES: only 16 lanes' addresses are used
|
| 411 |
+
ptx::SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
| 412 |
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
|
| 413 |
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
|
| 414 |
smem_ptr
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh
CHANGED
|
@@ -7,36 +7,31 @@
|
|
| 7 |
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
#include <cute/arch/mma_sm90_desc.hpp>
|
| 9 |
|
|
|
|
|
|
|
| 10 |
#include <deep_gemm/common/utils.cuh>
|
| 11 |
-
#include <deep_gemm/common/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
namespace deep_gemm {
|
| 14 |
|
| 15 |
-
using namespace deep_gemm::sm90;
|
| 16 |
-
|
| 17 |
-
// ReSharper disable once CppNotAllPathsReturnValue
|
| 18 |
-
template <uint32_t kHeadDim>
|
| 19 |
-
static constexpr int to_swizzle_cute_type() {
|
| 20 |
-
DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling");
|
| 21 |
-
if constexpr (kHeadDim == 32)
|
| 22 |
-
return static_cast<int>(cute::SM90::GMMA::LayoutType::B32);
|
| 23 |
-
if constexpr (kHeadDim == 64)
|
| 24 |
-
return static_cast<int>(cute::SM90::GMMA::LayoutType::B64);
|
| 25 |
-
if constexpr (kHeadDim == 128)
|
| 26 |
-
return static_cast<int>(cute::SM90::GMMA::LayoutType::B128);
|
| 27 |
-
}
|
| 28 |
-
|
| 29 |
template <uint32_t kNumHeads, uint32_t kHeadDim,
|
| 30 |
bool kIsCompressedLogits,
|
| 31 |
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
|
| 32 |
uint32_t kNumQStages, uint32_t kNumKVStages,
|
| 33 |
-
uint32_t
|
| 34 |
-
|
|
|
|
|
|
|
| 35 |
void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
| 36 |
-
const uint32_t max_seqlen_k, const
|
| 37 |
uint32_t* cu_seq_len_k_start,
|
| 38 |
uint32_t* cu_seq_len_k_end,
|
| 39 |
-
|
| 40 |
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 41 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 42 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
|
@@ -44,10 +39,10 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 44 |
// TODO: consider TMA multicast
|
| 45 |
// For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
|
| 46 |
// Q should be load only at once for a block
|
| 47 |
-
const auto
|
| 48 |
|
| 49 |
// Types
|
| 50 |
-
using WGMMA = typename FP8MMASelector<BLOCK_Q * kNumHeads>::type;
|
| 51 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 52 |
|
| 53 |
// Prefetch TMA descriptors
|
|
@@ -74,19 +69,19 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 74 |
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 75 |
|
| 76 |
// Data on shared memory
|
| 77 |
-
auto smem_q = PatternVisitor([&](const uint32_t& i) {
|
| 78 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
|
| 79 |
SMEM_Q_SIZE_PER_STAGE * i);
|
| 80 |
});
|
| 81 |
-
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
|
| 82 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
|
| 83 |
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
|
| 84 |
});
|
| 85 |
-
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
| 86 |
return reinterpret_cast<float*>(smem_buffer +
|
| 87 |
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 88 |
});
|
| 89 |
-
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
|
| 90 |
return reinterpret_cast<float*>(smem_buffer +
|
| 91 |
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages +
|
| 92 |
SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
|
@@ -94,13 +89,13 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 94 |
|
| 95 |
// TMA barriers
|
| 96 |
auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
| 97 |
-
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
| 98 |
-
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
|
| 99 |
-
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
|
| 100 |
-
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
|
| 101 |
|
| 102 |
// Initialize barriers
|
| 103 |
-
const bool
|
| 104 |
if (is_tma_load_warp and cute::elect_one_sync()) {
|
| 105 |
#pragma unroll
|
| 106 |
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
|
@@ -123,38 +118,43 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 123 |
constexpr uint32_t kNumMathRegisters = 112;
|
| 124 |
|
| 125 |
// Block scheduler
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
| 129 |
};
|
| 130 |
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
|
| 131 |
-
const auto
|
| 132 |
uint32_t start = cute::numeric_limits<uint32_t>::max();
|
| 133 |
uint32_t end = cute::numeric_limits<uint32_t>::min();
|
| 134 |
|
| 135 |
#pragma unroll
|
| 136 |
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 137 |
-
const auto
|
| 138 |
-
seq_k_start[i] =
|
| 139 |
-
seq_k_end[i] =
|
| 140 |
start = min(start, min(seq_k_start[i], seq_len_kv));
|
| 141 |
end = max(end, min(seq_k_end[i], seq_len_kv));
|
| 142 |
}
|
|
|
|
| 143 |
start = start / 4 * 4;
|
| 144 |
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
|
| 145 |
((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
|
| 146 |
-
start, ceil_div(end - start, BLOCK_KV)}; // Task info
|
| 147 |
};
|
| 148 |
|
| 149 |
// KV pipeline
|
| 150 |
uint32_t num_total_kv_blocks = 0;
|
| 151 |
-
const auto
|
| 152 |
return {
|
| 153 |
(num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
|
| 154 |
((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
|
| 155 |
};
|
| 156 |
};
|
| 157 |
|
|
|
|
|
|
|
|
|
|
| 158 |
if (threadIdx.x >= kNumMathThreads) {
|
| 159 |
// TMA warp-group for loading data
|
| 160 |
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
|
@@ -165,8 +165,8 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 165 |
|
| 166 |
// Prefetch
|
| 167 |
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
|
| 168 |
-
|
| 169 |
-
|
| 170 |
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 171 |
};
|
| 172 |
if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
|
|
@@ -192,9 +192,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 192 |
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
| 193 |
|
| 194 |
// Issue TMA KV
|
| 195 |
-
|
| 196 |
smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
|
| 197 |
-
|
| 198 |
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
|
| 199 |
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
| 200 |
}
|
|
@@ -212,7 +212,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 212 |
const auto& thread_idx = threadIdx.x % kNumMathThreads;
|
| 213 |
const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0);
|
| 214 |
const auto& warpgroup_idx = warp_idx / 4;
|
| 215 |
-
const auto& lane_idx = get_lane_idx();
|
| 216 |
float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4];
|
| 217 |
|
| 218 |
const auto& warp_offset = warp_idx * 16;
|
|
@@ -230,7 +230,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 230 |
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 231 |
#pragma unroll
|
| 232 |
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
|
| 233 |
-
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
|
| 234 |
}
|
| 235 |
|
| 236 |
// Compute over KV blocks
|
|
@@ -242,29 +242,31 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 242 |
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 243 |
|
| 244 |
// Read per-KV scales
|
| 245 |
-
float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
|
| 246 |
-
float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);
|
| 247 |
|
| 248 |
// Issue WGMMA
|
| 249 |
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
|
| 250 |
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
|
| 251 |
#pragma unroll
|
| 252 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 253 |
-
warpgroup_fence_operand(accum[i]);
|
| 254 |
-
warpgroup_arrive();
|
| 255 |
#pragma unroll
|
| 256 |
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
|
| 257 |
-
auto desc_a =
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
|
|
|
|
|
|
| 261 |
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
| 262 |
}
|
| 263 |
-
warpgroup_commit_batch();
|
| 264 |
#pragma unroll
|
| 265 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 266 |
-
warpgroup_fence_operand(accum[i]);
|
| 267 |
-
warpgroup_wait<0>();
|
| 268 |
|
| 269 |
// Release KV empty
|
| 270 |
empty_kv_barriers[kv_stage_idx]->arrive();
|
|
@@ -278,7 +280,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 278 |
#pragma unroll
|
| 279 |
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 280 |
auto shifted_accum = accum + i * kNumAccumPerReduce;
|
| 281 |
-
const auto
|
| 282 |
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
|
| 283 |
};
|
| 284 |
|
|
@@ -302,16 +304,15 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
|
| 302 |
}
|
| 303 |
|
| 304 |
// Store into the global memory
|
| 305 |
-
|
| 306 |
-
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
|
| 307 |
if constexpr (kIsCompressedLogits) {
|
| 308 |
if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i])
|
| 309 |
-
logits[
|
| 310 |
if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i])
|
| 311 |
-
logits[
|
| 312 |
} else {
|
| 313 |
-
logits[
|
| 314 |
-
logits[
|
| 315 |
}
|
| 316 |
}
|
| 317 |
}
|
|
|
|
| 7 |
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
#include <cute/arch/mma_sm90_desc.hpp>
|
| 9 |
|
| 10 |
+
#include <deep_gemm/common/cute_tie.cuh>
|
| 11 |
+
#include <deep_gemm/common/math.cuh>
|
| 12 |
#include <deep_gemm/common/utils.cuh>
|
| 13 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 14 |
+
#include <deep_gemm/common/types.cuh>
|
| 15 |
+
#include <deep_gemm/mma/sm90.cuh>
|
| 16 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 17 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 18 |
+
#include <deep_gemm/ptx/wgmma.cuh>
|
| 19 |
|
| 20 |
namespace deep_gemm {
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
template <uint32_t kNumHeads, uint32_t kHeadDim,
|
| 23 |
bool kIsCompressedLogits,
|
| 24 |
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
|
| 25 |
uint32_t kNumQStages, uint32_t kNumKVStages,
|
| 26 |
+
uint32_t kNumSMs,
|
| 27 |
+
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
| 28 |
+
typename logits_dtype_t>
|
| 29 |
+
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
|
| 30 |
void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
| 31 |
+
const uint32_t max_seqlen_k, const uint32_t stride_logits,
|
| 32 |
uint32_t* cu_seq_len_k_start,
|
| 33 |
uint32_t* cu_seq_len_k_end,
|
| 34 |
+
logits_dtype_t* logits,
|
| 35 |
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 36 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 37 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
|
|
|
| 39 |
// TODO: consider TMA multicast
|
| 40 |
// For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
|
| 41 |
// Q should be load only at once for a block
|
| 42 |
+
const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
|
| 43 |
|
| 44 |
// Types
|
| 45 |
+
using WGMMA = typename mma::sm90::FP8MMASelector<BLOCK_Q * kNumHeads>::type;
|
| 46 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 47 |
|
| 48 |
// Prefetch TMA descriptors
|
|
|
|
| 69 |
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 70 |
|
| 71 |
// Data on shared memory
|
| 72 |
+
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
|
| 73 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
|
| 74 |
SMEM_Q_SIZE_PER_STAGE * i);
|
| 75 |
});
|
| 76 |
+
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
|
| 77 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
|
| 78 |
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
|
| 79 |
});
|
| 80 |
+
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
|
| 81 |
return reinterpret_cast<float*>(smem_buffer +
|
| 82 |
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 83 |
});
|
| 84 |
+
auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
|
| 85 |
return reinterpret_cast<float*>(smem_buffer +
|
| 86 |
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages +
|
| 87 |
SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
|
|
|
| 89 |
|
| 90 |
// TMA barriers
|
| 91 |
auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
| 92 |
+
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
| 93 |
+
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
|
| 94 |
+
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
|
| 95 |
+
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
|
| 96 |
|
| 97 |
// Initialize barriers
|
| 98 |
+
const bool is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32;
|
| 99 |
if (is_tma_load_warp and cute::elect_one_sync()) {
|
| 100 |
#pragma unroll
|
| 101 |
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
|
|
|
| 118 |
constexpr uint32_t kNumMathRegisters = 112;
|
| 119 |
|
| 120 |
// Block scheduler
|
| 121 |
+
const auto sm_idx = blockIdx.x;
|
| 122 |
+
uint32_t block_q_idx = sm_idx, q_iter_idx = 0;
|
| 123 |
+
const auto get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
|
| 124 |
+
return {block_q_idx + kNumSMs, q_iter_idx + 1};
|
| 125 |
};
|
| 126 |
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
|
| 127 |
+
const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
|
| 128 |
uint32_t start = cute::numeric_limits<uint32_t>::max();
|
| 129 |
uint32_t end = cute::numeric_limits<uint32_t>::min();
|
| 130 |
|
| 131 |
#pragma unroll
|
| 132 |
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 133 |
+
const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
|
| 134 |
+
seq_k_start[i] = cu_seq_len_k_start[q_idx];
|
| 135 |
+
seq_k_end[i] = cu_seq_len_k_end[q_idx];
|
| 136 |
start = min(start, min(seq_k_start[i], seq_len_kv));
|
| 137 |
end = max(end, min(seq_k_end[i], seq_len_kv));
|
| 138 |
}
|
| 139 |
+
// TMA alignment requirements for SF KV
|
| 140 |
start = start / 4 * 4;
|
| 141 |
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
|
| 142 |
((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
|
| 143 |
+
start, math::ceil_div(end - start, BLOCK_KV)}; // Task info
|
| 144 |
};
|
| 145 |
|
| 146 |
// KV pipeline
|
| 147 |
uint32_t num_total_kv_blocks = 0;
|
| 148 |
+
const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
|
| 149 |
return {
|
| 150 |
(num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
|
| 151 |
((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
|
| 152 |
};
|
| 153 |
};
|
| 154 |
|
| 155 |
+
// Wait for primary kernel completion
|
| 156 |
+
cudaGridDependencySynchronize();
|
| 157 |
+
|
| 158 |
if (threadIdx.x >= kNumMathThreads) {
|
| 159 |
// TMA warp-group for loading data
|
| 160 |
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
|
|
|
| 165 |
|
| 166 |
// Prefetch
|
| 167 |
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
|
| 168 |
+
tma::copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
|
| 169 |
+
tma::copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
|
| 170 |
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 171 |
};
|
| 172 |
if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
|
|
|
|
| 192 |
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
| 193 |
|
| 194 |
// Issue TMA KV
|
| 195 |
+
tma::copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
|
| 196 |
smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
|
| 197 |
+
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
|
| 198 |
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
|
| 199 |
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
| 200 |
}
|
|
|
|
| 212 |
const auto& thread_idx = threadIdx.x % kNumMathThreads;
|
| 213 |
const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0);
|
| 214 |
const auto& warpgroup_idx = warp_idx / 4;
|
| 215 |
+
const auto& lane_idx = ptx::get_lane_idx();
|
| 216 |
float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4];
|
| 217 |
|
| 218 |
const auto& warp_offset = warp_idx * 16;
|
|
|
|
| 230 |
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 231 |
#pragma unroll
|
| 232 |
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
|
| 233 |
+
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
|
| 234 |
}
|
| 235 |
|
| 236 |
// Compute over KV blocks
|
|
|
|
| 242 |
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 243 |
|
| 244 |
// Read per-KV scales
|
| 245 |
+
float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
|
| 246 |
+
float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);
|
| 247 |
|
| 248 |
// Issue WGMMA
|
| 249 |
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
|
| 250 |
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
|
| 251 |
#pragma unroll
|
| 252 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 253 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 254 |
+
ptx::warpgroup_arrive();
|
| 255 |
#pragma unroll
|
| 256 |
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
|
| 257 |
+
auto desc_a = mma::sm90::make_smem_desc(
|
| 258 |
+
smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K,
|
| 259 |
+
mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
| 260 |
+
auto desc_b = mma::sm90::make_smem_desc(
|
| 261 |
+
smem_q[q_stage_idx] + k * WGMMA::K,
|
| 262 |
+
mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
| 263 |
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
| 264 |
}
|
| 265 |
+
ptx::warpgroup_commit_batch();
|
| 266 |
#pragma unroll
|
| 267 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 268 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 269 |
+
ptx::warpgroup_wait<0>();
|
| 270 |
|
| 271 |
// Release KV empty
|
| 272 |
empty_kv_barriers[kv_stage_idx]->arrive();
|
|
|
|
| 280 |
#pragma unroll
|
| 281 |
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 282 |
auto shifted_accum = accum + i * kNumAccumPerReduce;
|
| 283 |
+
const auto transform = [&](const uint32_t& j) {
|
| 284 |
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
|
| 285 |
};
|
| 286 |
|
|
|
|
| 304 |
}
|
| 305 |
|
| 306 |
// Store into the global memory
|
| 307 |
+
const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast<uint64_t>(stride_logits);
|
|
|
|
| 308 |
if constexpr (kIsCompressedLogits) {
|
| 309 |
if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i])
|
| 310 |
+
logits[q_offset + kv_offset + v_0_offset - seq_k_start[i]] = static_cast<logits_dtype_t>(v_0);
|
| 311 |
if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i])
|
| 312 |
+
logits[q_offset + kv_offset + v_1_offset - seq_k_start[i]] = static_cast<logits_dtype_t>(v_1);
|
| 313 |
} else {
|
| 314 |
+
logits[q_offset + kv_offset + v_0_offset] = static_cast<logits_dtype_t>(v_0);
|
| 315 |
+
logits[q_offset + kv_offset + v_1_offset] = static_cast<logits_dtype_t>(v_1);
|
| 316 |
}
|
| 317 |
}
|
| 318 |
}
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh
CHANGED
|
@@ -6,133 +6,46 @@
|
|
| 6 |
#include <cute/arch/cluster_sm90.hpp>
|
| 7 |
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
|
|
|
|
|
|
|
| 9 |
#include <deep_gemm/common/utils.cuh>
|
| 10 |
-
#include <deep_gemm/common/
|
| 11 |
-
#include <deep_gemm/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
namespace deep_gemm {
|
| 14 |
|
| 15 |
-
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs>
|
| 16 |
-
__global__ __launch_bounds__(32, 1)
|
| 17 |
-
void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d,
|
| 18 |
-
const uint32_t* context_lens, uint32_t* schedule_metadata) {
|
| 19 |
-
DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
|
| 20 |
-
const uint32_t lane_idx = get_lane_idx();
|
| 21 |
-
|
| 22 |
-
uint32_t num_segs[kAlignedBatchSize / 32];
|
| 23 |
-
#pragma unroll
|
| 24 |
-
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
|
| 25 |
-
const uint32_t q_idx = k * 32 + lane_idx;
|
| 26 |
-
const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
|
| 27 |
-
const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0);
|
| 28 |
-
num_segs[k] = ceil_div(context_len, SPLIT_KV);
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
__shared__ uint32_t prefix_sum[kAlignedBatchSize];
|
| 32 |
-
uint32_t sum = 0;
|
| 33 |
-
#pragma unroll
|
| 34 |
-
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
|
| 35 |
-
uint32_t x = num_segs[k];
|
| 36 |
-
#pragma unroll
|
| 37 |
-
for (uint32_t offset = 1; offset < 32; offset <<= 1) {
|
| 38 |
-
const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset);
|
| 39 |
-
x += (lane_idx >= offset ? y : 0);
|
| 40 |
-
}
|
| 41 |
-
x += sum;
|
| 42 |
-
prefix_sum[k * 32 + lane_idx] = x;
|
| 43 |
-
sum = __shfl_sync(0xffffffff, x, 31);
|
| 44 |
-
}
|
| 45 |
-
|
| 46 |
-
const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs;
|
| 47 |
-
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
|
| 48 |
-
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
|
| 49 |
-
uint32_t q_idx = 0;
|
| 50 |
-
while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts)
|
| 51 |
-
++ q_idx;
|
| 52 |
-
const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]);
|
| 53 |
-
__syncwarp();
|
| 54 |
-
|
| 55 |
-
schedule_metadata[sm_idx * 2] = q_idx;
|
| 56 |
-
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
|
| 57 |
-
}
|
| 58 |
-
}
|
| 59 |
-
|
| 60 |
-
template <uint32_t kNextN, bool kIsContextLens2D,
|
| 61 |
-
uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit>
|
| 62 |
-
struct PagedMQALogitsScheduler {
|
| 63 |
-
uint32_t batch_size;
|
| 64 |
-
const uint32_t* context_lens;
|
| 65 |
-
|
| 66 |
-
uint32_t current_q_idx, current_kv_idx;
|
| 67 |
-
uint32_t end_q_idx, end_kv_idx;
|
| 68 |
-
uint32_t current_num_kv;
|
| 69 |
-
|
| 70 |
-
__device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) {
|
| 71 |
-
const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
|
| 72 |
-
return q_idx < batch_size ? ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0;
|
| 73 |
-
}
|
| 74 |
-
|
| 75 |
-
__device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx,
|
| 76 |
-
const uint32_t* context_lens, const uint32_t* schedule_meta) {
|
| 77 |
-
this->batch_size = batch_size;
|
| 78 |
-
this->context_lens = context_lens;
|
| 79 |
-
|
| 80 |
-
const auto& current_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx);
|
| 81 |
-
const auto& end_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx + 1);
|
| 82 |
-
current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit;
|
| 83 |
-
end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit;
|
| 84 |
-
|
| 85 |
-
current_num_kv = get_num_kv(current_q_idx);
|
| 86 |
-
}
|
| 87 |
-
|
| 88 |
-
__device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) {
|
| 89 |
-
q_idx = current_q_idx;
|
| 90 |
-
kv_idx = current_kv_idx;
|
| 91 |
-
num_kv = current_num_kv;
|
| 92 |
-
|
| 93 |
-
if (q_idx == end_q_idx and kv_idx == end_kv_idx)
|
| 94 |
-
return false;
|
| 95 |
-
|
| 96 |
-
current_kv_idx += kNumBlocksPerSplit;
|
| 97 |
-
if (current_kv_idx >= current_num_kv) {
|
| 98 |
-
++ current_q_idx;
|
| 99 |
-
current_kv_idx = 0;
|
| 100 |
-
current_num_kv = get_num_kv(current_q_idx);
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
return true;
|
| 104 |
-
}
|
| 105 |
-
|
| 106 |
-
__device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const {
|
| 107 |
-
return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx;
|
| 108 |
-
}
|
| 109 |
-
};
|
| 110 |
-
|
| 111 |
-
using namespace deep_gemm::sm90;
|
| 112 |
-
|
| 113 |
template <uint32_t kNextN, uint32_t kNumHeads,
|
| 114 |
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
| 115 |
-
bool kIsContextLens2D,
|
| 116 |
uint32_t kNumQStages, uint32_t kNumKVStages,
|
| 117 |
uint32_t SPLIT_KV,
|
| 118 |
-
uint32_t kNumTMAThreads, uint32_t kNumMathThreads
|
| 119 |
-
|
|
|
|
| 120 |
void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
| 121 |
-
const
|
| 122 |
-
const uint32_t* context_lens,
|
| 123 |
-
const uint32_t* block_table, const uint32_t*
|
|
|
|
| 124 |
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 125 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 126 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
| 127 |
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
|
|
|
|
|
|
| 128 |
// Types
|
| 129 |
-
using WGMMA = typename FP8MMASelector<kNextN * kNumHeads>::type;
|
| 130 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 131 |
|
| 132 |
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
| 133 |
-
const auto
|
| 134 |
-
const auto
|
| 135 |
-
const auto
|
| 136 |
|
| 137 |
// Prefetch TMA descriptors
|
| 138 |
static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
|
|
@@ -150,15 +63,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 150 |
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
| 151 |
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 152 |
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
|
| 153 |
-
static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
|
| 154 |
static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
|
| 155 |
-
constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
|
| 156 |
|
| 157 |
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 158 |
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
|
| 159 |
-
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
|
| 160 |
static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
|
| 161 |
-
constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);
|
| 162 |
|
| 163 |
// Align to swizzling alignment bytes
|
| 164 |
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
|
@@ -166,31 +79,31 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 166 |
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 167 |
|
| 168 |
// Q data and barriers on shared memory
|
| 169 |
-
auto smem_q = PatternVisitor([&](const uint32_t& i) {
|
| 170 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
|
| 171 |
});
|
| 172 |
-
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
| 173 |
return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 174 |
});
|
| 175 |
auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
|
| 176 |
-
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
|
| 177 |
-
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });
|
| 178 |
|
| 179 |
// Separate math warpgroups and tma load warps into KV groups
|
| 180 |
// Each math warpgroup corresponds to a tma load warp
|
| 181 |
-
const auto
|
| 182 |
|
| 183 |
// Per group KV data and barriers on shared memory
|
| 184 |
-
const auto
|
| 185 |
-
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
|
| 186 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i);
|
| 187 |
});
|
| 188 |
-
auto smem_kv_scales =
|
| 189 |
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
| 190 |
});
|
| 191 |
auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
| 192 |
-
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
|
| 193 |
-
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });
|
| 194 |
|
| 195 |
// Initialize barriers
|
| 196 |
if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) {
|
|
@@ -218,15 +131,19 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 218 |
constexpr uint32_t kNumTMARegisters = 64;
|
| 219 |
constexpr uint32_t kNumMathRegisters = 104;
|
| 220 |
|
|
|
|
|
|
|
|
|
|
| 221 |
// Scheduler
|
| 222 |
-
auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups
|
|
|
|
| 223 |
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
|
| 224 |
|
| 225 |
// Q and KV pipeline
|
| 226 |
-
const auto
|
| 227 |
return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
|
| 228 |
};
|
| 229 |
-
const auto
|
| 230 |
return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
|
| 231 |
};
|
| 232 |
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
|
@@ -237,10 +154,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 237 |
if (kv_group_idx >= kNumMathWarpGroups)
|
| 238 |
return;
|
| 239 |
|
| 240 |
-
const auto
|
| 241 |
if (kv_group_idx == 0 and cute::elect_one_sync()) {
|
| 242 |
-
|
| 243 |
-
|
| 244 |
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 245 |
}
|
| 246 |
};
|
|
@@ -259,7 +176,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 259 |
|
| 260 |
while (fetched_next_task) {
|
| 261 |
// Prefetch next Q when current Q changes
|
| 262 |
-
bool prefetch_q = (q_idx != next_q_idx and scheduler.
|
| 263 |
q_idx = next_q_idx;
|
| 264 |
kv_idx = next_kv_idx;
|
| 265 |
num_kv = next_num_kv;
|
|
@@ -276,9 +193,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 276 |
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
|
| 277 |
kv_block_idx_ptr = 0;
|
| 278 |
kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ?
|
| 279 |
-
|
| 280 |
}
|
| 281 |
-
const auto
|
| 282 |
|
| 283 |
// Wait KV consumer release
|
| 284 |
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
|
@@ -286,10 +203,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 286 |
|
| 287 |
// Issue TMA KV
|
| 288 |
if (cute::elect_one_sync()) {
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
| 294 |
}
|
| 295 |
|
|
@@ -301,9 +218,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 301 |
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 302 |
|
| 303 |
float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4];
|
| 304 |
-
const auto
|
| 305 |
-
const auto
|
| 306 |
-
const auto
|
| 307 |
|
| 308 |
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
|
| 309 |
uint32_t q_idx = batch_size, kv_idx;
|
|
@@ -326,7 +243,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 326 |
for (uint32_t i = 0; i < kNextN; ++ i) {
|
| 327 |
#pragma unroll
|
| 328 |
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
|
| 329 |
-
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
|
| 330 |
}
|
| 331 |
}
|
| 332 |
|
|
@@ -335,7 +252,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 335 |
kv_idx = next_kv_idx;
|
| 336 |
|
| 337 |
// Calculate KV offset in advance
|
| 338 |
-
auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
|
| 339 |
|
| 340 |
// Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
|
| 341 |
// Wait TMA KV arrival
|
|
@@ -347,25 +264,29 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 347 |
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
|
| 348 |
#pragma unroll
|
| 349 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 350 |
-
warpgroup_fence_operand(accum[i]);
|
| 351 |
-
warpgroup_arrive();
|
| 352 |
#pragma unroll
|
| 353 |
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
|
| 354 |
-
auto desc_a =
|
| 355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
| 357 |
}
|
| 358 |
-
warpgroup_commit_batch();
|
| 359 |
#pragma unroll
|
| 360 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 361 |
-
warpgroup_fence_operand(accum[i]);
|
| 362 |
|
| 363 |
// Read per-KV scales
|
| 364 |
-
float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset);
|
| 365 |
-
float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset);
|
| 366 |
|
| 367 |
// Wait WGMMA
|
| 368 |
-
warpgroup_wait<0>();
|
| 369 |
|
| 370 |
// Release KV empty
|
| 371 |
empty_kv_barriers[kv_stage_idx]->arrive();
|
|
@@ -378,7 +299,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 378 |
#pragma unroll
|
| 379 |
for (uint32_t i = 0; i < kNextN; ++ i) {
|
| 380 |
auto shifted_accum = accum + i * kNumAccumPerReduce;
|
| 381 |
-
const auto
|
| 382 |
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
|
| 383 |
};
|
| 384 |
|
|
@@ -396,15 +317,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
|
| 396 |
// Inter-thread reduction
|
| 397 |
#pragma unroll
|
| 398 |
for (uint32_t j = 0; j < 2; ++ j) {
|
| 399 |
-
const auto
|
| 400 |
v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
|
| 401 |
v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
|
| 402 |
}
|
| 403 |
|
| 404 |
// Store into the global memory
|
| 405 |
// NOTES: we have redundant writes here, consider more carefully
|
| 406 |
-
logits[kv_offset + i * logits_stride + v_0_offset] = v_0;
|
| 407 |
-
logits[kv_offset + i * logits_stride + v_1_offset] = v_1;
|
| 408 |
}
|
| 409 |
}
|
| 410 |
}
|
|
|
|
| 6 |
#include <cute/arch/cluster_sm90.hpp>
|
| 7 |
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
|
| 9 |
+
#include <deep_gemm/common/cute_tie.cuh>
|
| 10 |
+
#include <deep_gemm/common/math.cuh>
|
| 11 |
#include <deep_gemm/common/utils.cuh>
|
| 12 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 13 |
+
#include <deep_gemm/common/types.cuh>
|
| 14 |
+
#include <deep_gemm/mma/sm90.cuh>
|
| 15 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 16 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 17 |
+
#include <deep_gemm/ptx/wgmma.cuh>
|
| 18 |
+
#include <deep_gemm/scheduler/paged_mqa_logits.cuh>
|
| 19 |
|
| 20 |
namespace deep_gemm {
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
template <uint32_t kNextN, uint32_t kNumHeads,
|
| 23 |
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
| 24 |
+
bool kIsContextLens2D, bool kIsVarlen,
|
| 25 |
uint32_t kNumQStages, uint32_t kNumKVStages,
|
| 26 |
uint32_t SPLIT_KV,
|
| 27 |
+
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
| 28 |
+
typename logits_dtype_t>
|
| 29 |
+
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
|
| 30 |
void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
| 31 |
+
const uint32_t logits_stride, const uint32_t block_table_stride,
|
| 32 |
+
const uint32_t* context_lens, logits_dtype_t* logits,
|
| 33 |
+
const uint32_t* block_table, const uint32_t* indices,
|
| 34 |
+
const uint32_t* schedule_meta,
|
| 35 |
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 36 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 37 |
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
| 38 |
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
| 39 |
+
DG_STATIC_ASSERT(not kIsVarlen, "Varlen is not supported for SM90 paged MQA logits");
|
| 40 |
+
|
| 41 |
// Types
|
| 42 |
+
using WGMMA = typename mma::sm90::FP8MMASelector<kNextN * kNumHeads>::type;
|
| 43 |
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 44 |
|
| 45 |
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
| 46 |
+
const auto warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 47 |
+
const auto warpgroup_idx = warp_idx / 4;
|
| 48 |
+
const auto lane_idx = ptx::get_lane_idx();
|
| 49 |
|
| 50 |
// Prefetch TMA descriptors
|
| 51 |
static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
|
|
|
|
| 63 |
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
| 64 |
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 65 |
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
|
| 66 |
+
static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = math::constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
|
| 67 |
static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
|
| 68 |
+
math::constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
|
| 69 |
|
| 70 |
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 71 |
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
|
| 72 |
+
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
|
| 73 |
static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
|
| 74 |
+
math::constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);
|
| 75 |
|
| 76 |
// Align to swizzling alignment bytes
|
| 77 |
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
|
|
|
| 79 |
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 80 |
|
| 81 |
// Q data and barriers on shared memory
|
| 82 |
+
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
|
| 83 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
|
| 84 |
});
|
| 85 |
+
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
|
| 86 |
return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 87 |
});
|
| 88 |
auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
|
| 89 |
+
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
|
| 90 |
+
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });
|
| 91 |
|
| 92 |
// Separate math warpgroups and tma load warps into KV groups
|
| 93 |
// Each math warpgroup corresponds to a tma load warp
|
| 94 |
+
const auto kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);
|
| 95 |
|
| 96 |
// Per group KV data and barriers on shared memory
|
| 97 |
+
const auto smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
|
| 98 |
+
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
|
| 99 |
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i);
|
| 100 |
});
|
| 101 |
+
auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
|
| 102 |
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
| 103 |
});
|
| 104 |
auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
| 105 |
+
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
|
| 106 |
+
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });
|
| 107 |
|
| 108 |
// Initialize barriers
|
| 109 |
if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) {
|
|
|
|
| 131 |
constexpr uint32_t kNumTMARegisters = 64;
|
| 132 |
constexpr uint32_t kNumMathRegisters = 104;
|
| 133 |
|
| 134 |
+
// Wait for primary kernel completion
|
| 135 |
+
cudaGridDependencySynchronize();
|
| 136 |
+
|
| 137 |
// Scheduler
|
| 138 |
+
auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumMathWarpGroups, 1>(
|
| 139 |
+
blockIdx.x, batch_size, context_lens, schedule_meta, indices);
|
| 140 |
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
|
| 141 |
|
| 142 |
// Q and KV pipeline
|
| 143 |
+
const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
| 144 |
return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
|
| 145 |
};
|
| 146 |
+
const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
| 147 |
return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
|
| 148 |
};
|
| 149 |
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
|
|
|
| 154 |
if (kv_group_idx >= kNumMathWarpGroups)
|
| 155 |
return;
|
| 156 |
|
| 157 |
+
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
|
| 158 |
if (kv_group_idx == 0 and cute::elect_one_sync()) {
|
| 159 |
+
tma::copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
|
| 160 |
+
tma::copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx * kNextN);
|
| 161 |
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 162 |
}
|
| 163 |
};
|
|
|
|
| 176 |
|
| 177 |
while (fetched_next_task) {
|
| 178 |
// Prefetch next Q when current Q changes
|
| 179 |
+
bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_atom_idx(next_q_idx + 1));
|
| 180 |
q_idx = next_q_idx;
|
| 181 |
kv_idx = next_kv_idx;
|
| 182 |
num_kv = next_num_kv;
|
|
|
|
| 193 |
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
|
| 194 |
kv_block_idx_ptr = 0;
|
| 195 |
kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ?
|
| 196 |
+
block_table[q_idx * static_cast<uint64_t>(block_table_stride) + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)] : 0);
|
| 197 |
}
|
| 198 |
+
const auto kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);
|
| 199 |
|
| 200 |
// Wait KV consumer release
|
| 201 |
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
|
|
|
| 203 |
|
| 204 |
// Issue TMA KV
|
| 205 |
if (cute::elect_one_sync()) {
|
| 206 |
+
tma::copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
|
| 207 |
+
smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx);
|
| 208 |
+
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
|
| 209 |
+
smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
|
| 210 |
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
| 211 |
}
|
| 212 |
|
|
|
|
| 218 |
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 219 |
|
| 220 |
float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4];
|
| 221 |
+
const auto sub_warp_offset = (warp_idx % 4) * 16;
|
| 222 |
+
const auto v_0_offset = lane_idx / 4 + 0;
|
| 223 |
+
const auto v_1_offset = lane_idx / 4 + 8;
|
| 224 |
|
| 225 |
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
|
| 226 |
uint32_t q_idx = batch_size, kv_idx;
|
|
|
|
| 243 |
for (uint32_t i = 0; i < kNextN; ++ i) {
|
| 244 |
#pragma unroll
|
| 245 |
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
|
| 246 |
+
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
|
| 247 |
}
|
| 248 |
}
|
| 249 |
|
|
|
|
| 252 |
kv_idx = next_kv_idx;
|
| 253 |
|
| 254 |
// Calculate KV offset in advance
|
| 255 |
+
auto kv_offset = q_idx * kNextN * static_cast<uint64_t>(logits_stride) + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
|
| 256 |
|
| 257 |
// Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
|
| 258 |
// Wait TMA KV arrival
|
|
|
|
| 264 |
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
|
| 265 |
#pragma unroll
|
| 266 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 267 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 268 |
+
ptx::warpgroup_arrive();
|
| 269 |
#pragma unroll
|
| 270 |
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
|
| 271 |
+
auto desc_a = mma::sm90::make_smem_desc(
|
| 272 |
+
smem_kv[kv_stage_idx] + k * WGMMA::K,
|
| 273 |
+
mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
| 274 |
+
auto desc_b = mma::sm90::make_smem_desc(
|
| 275 |
+
smem_q[q_stage_idx] + k * WGMMA::K,
|
| 276 |
+
mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
| 277 |
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
| 278 |
}
|
| 279 |
+
ptx::warpgroup_commit_batch();
|
| 280 |
#pragma unroll
|
| 281 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 282 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 283 |
|
| 284 |
// Read per-KV scales
|
| 285 |
+
float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset);
|
| 286 |
+
float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset);
|
| 287 |
|
| 288 |
// Wait WGMMA
|
| 289 |
+
ptx::warpgroup_wait<0>();
|
| 290 |
|
| 291 |
// Release KV empty
|
| 292 |
empty_kv_barriers[kv_stage_idx]->arrive();
|
|
|
|
| 299 |
#pragma unroll
|
| 300 |
for (uint32_t i = 0; i < kNextN; ++ i) {
|
| 301 |
auto shifted_accum = accum + i * kNumAccumPerReduce;
|
| 302 |
+
const auto transform = [&](const uint32_t& j) {
|
| 303 |
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
|
| 304 |
};
|
| 305 |
|
|
|
|
| 317 |
// Inter-thread reduction
|
| 318 |
#pragma unroll
|
| 319 |
for (uint32_t j = 0; j < 2; ++ j) {
|
| 320 |
+
const auto offset = static_cast<int>(1u << j);
|
| 321 |
v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
|
| 322 |
v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
|
| 323 |
}
|
| 324 |
|
| 325 |
// Store into the global memory
|
| 326 |
// NOTES: we have redundant writes here, consider more carefully
|
| 327 |
+
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + v_0_offset] = static_cast<logits_dtype_t>(v_0);
|
| 328 |
+
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + v_1_offset] = static_cast<logits_dtype_t>(v_1);
|
| 329 |
}
|
| 330 |
}
|
| 331 |
}
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh
CHANGED
|
@@ -5,20 +5,23 @@
|
|
| 5 |
#include <cutlass/arch/barrier.h>
|
| 6 |
#include <cutlass/arch/reg_reconfig.h>
|
| 7 |
|
| 8 |
-
#include <deep_gemm/common/
|
| 9 |
#include <deep_gemm/common/utils.cuh>
|
| 10 |
-
#include <deep_gemm/common/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
namespace deep_gemm {
|
| 13 |
|
| 14 |
-
using namespace deep_gemm::sm90;
|
| 15 |
-
|
| 16 |
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
|
| 17 |
-
|
| 18 |
uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) {
|
| 19 |
constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase;
|
| 20 |
|
| 21 |
-
const auto
|
| 22 |
|
| 23 |
constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
|
| 24 |
constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups;
|
|
@@ -35,7 +38,7 @@ template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
|
| 35 |
uint32_t kSwizzleCDMode,
|
| 36 |
uint32_t kNumStages,
|
| 37 |
uint32_t kNumMathThreads, uint32_t kNumTMAThreads>
|
| 38 |
-
|
| 39 |
sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
| 40 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 41 |
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
|
@@ -56,7 +59,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 56 |
|
| 57 |
// Utils
|
| 58 |
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 59 |
-
const auto lane_idx = get_lane_idx();
|
| 60 |
|
| 61 |
// Align to 1024 bytes for swizzle-128B
|
| 62 |
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
|
@@ -76,17 +79,17 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 76 |
// Data on shared memory (layout as ordered below)
|
| 77 |
// Fill D/A/B pointers
|
| 78 |
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
|
| 79 |
-
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 80 |
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
| 81 |
});
|
| 82 |
-
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 83 |
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
| 84 |
});
|
| 85 |
|
| 86 |
// Fill barriers
|
| 87 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 88 |
-
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 89 |
-
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 90 |
|
| 91 |
// Initialize barriers
|
| 92 |
if (warp_idx == 1 and cute::elect_one_sync()) {
|
|
@@ -101,7 +104,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 101 |
}
|
| 102 |
__syncthreads();
|
| 103 |
|
| 104 |
-
constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
|
| 105 |
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
|
| 106 |
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
|
| 107 |
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
|
|
@@ -113,12 +116,15 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 113 |
constexpr uint32_t kNumTMARegisters = 40;
|
| 114 |
constexpr uint32_t kNumMathRegisters = 256;
|
| 115 |
|
|
|
|
|
|
|
|
|
|
| 116 |
// TMA load warp
|
| 117 |
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
| 118 |
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
| 119 |
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 120 |
// Wait consumer release
|
| 121 |
-
const auto
|
| 122 |
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
|
| 123 |
|
| 124 |
// Compute offsets
|
|
@@ -126,8 +132,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 126 |
uint32_t k_idx = k_offset + s * BLOCK_K;
|
| 127 |
|
| 128 |
// Issue TMAs
|
| 129 |
-
|
| 130 |
-
|
| 131 |
|
| 132 |
// Arrive at full barriers
|
| 133 |
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
|
@@ -135,7 +141,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 135 |
}
|
| 136 |
|
| 137 |
for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) {
|
| 138 |
-
const auto
|
| 139 |
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
|
| 140 |
}
|
| 141 |
} else if (warp_idx < kNumMathThreads / 32) {
|
|
@@ -148,7 +154,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 148 |
constexpr uint32_t WGMMA_N = BLOCK_N;
|
| 149 |
constexpr uint32_t WGMMA_K = 8;
|
| 150 |
|
| 151 |
-
using WGMMA = typename TF32MMASelector<WGMMA_N, true>::type;
|
| 152 |
float accum[WGMMA::kNumAccum] = {0};
|
| 153 |
|
| 154 |
constexpr uint32_t kNumBankGroupBytes = 16;
|
|
@@ -196,14 +202,14 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 196 |
sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y;
|
| 197 |
}
|
| 198 |
|
| 199 |
-
warpgroup_wait<0>();
|
| 200 |
if (s > 0)
|
| 201 |
empty_barriers[(s - 1) % kNumStages]->arrive();
|
| 202 |
|
| 203 |
#pragma unroll
|
| 204 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 205 |
-
warpgroup_fence_operand(accum[i]);
|
| 206 |
-
warpgroup_arrive();
|
| 207 |
|
| 208 |
constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float);
|
| 209 |
constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K;
|
|
@@ -213,18 +219,19 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 213 |
for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) {
|
| 214 |
#pragma unroll
|
| 215 |
for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) {
|
| 216 |
-
auto b_desc =
|
|
|
|
| 217 |
WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1);
|
| 218 |
}
|
| 219 |
}
|
| 220 |
-
warpgroup_commit_batch();
|
| 221 |
#pragma unroll
|
| 222 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 223 |
-
warpgroup_fence_operand(accum[i]);
|
| 224 |
}
|
| 225 |
|
| 226 |
-
const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0);
|
| 227 |
-
const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1);
|
| 228 |
|
| 229 |
const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4);
|
| 230 |
if (lane_idx % 4 == 0) {
|
|
@@ -233,7 +240,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 233 |
if (m_idx + 8 < shape_m)
|
| 234 |
sqr_sum[m_offset + m_idx + 8] = reduced_sum_1;
|
| 235 |
}
|
| 236 |
-
warpgroup_wait<0>();
|
| 237 |
empty_barriers[(num_total_stages-1) % kNumStages]->arrive();
|
| 238 |
|
| 239 |
// Write accum to shared memory
|
|
@@ -260,8 +267,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
|
| 260 |
|
| 261 |
// 0/1 write to the same row, 2/3 write to another row
|
| 262 |
auto values = reinterpret_cast<uint32_t*>(accum + i * 2);
|
| 263 |
-
st_shared(smem_ptr, values[0], values[1]);
|
| 264 |
-
st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]);
|
| 265 |
}
|
| 266 |
cute::tma_store_fence();
|
| 267 |
cutlass::arch::NamedBarrier::sync(128, 1);
|
|
|
|
| 5 |
#include <cutlass/arch/barrier.h>
|
| 6 |
#include <cutlass/arch/reg_reconfig.h>
|
| 7 |
|
| 8 |
+
#include <deep_gemm/common/math.cuh>
|
| 9 |
#include <deep_gemm/common/utils.cuh>
|
| 10 |
+
#include <deep_gemm/common/tma_copy.cuh>
|
| 11 |
+
#include <deep_gemm/common/types.cuh>
|
| 12 |
+
#include <deep_gemm/mma/sm90.cuh>
|
| 13 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 14 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 15 |
+
#include <deep_gemm/ptx/wgmma.cuh>
|
| 16 |
|
| 17 |
namespace deep_gemm {
|
| 18 |
|
|
|
|
|
|
|
| 19 |
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
|
| 20 |
+
CUTLASS_DEVICE
|
| 21 |
uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) {
|
| 22 |
constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase;
|
| 23 |
|
| 24 |
+
const auto bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange;
|
| 25 |
|
| 26 |
constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
|
| 27 |
constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups;
|
|
|
|
| 38 |
uint32_t kSwizzleCDMode,
|
| 39 |
uint32_t kNumStages,
|
| 40 |
uint32_t kNumMathThreads, uint32_t kNumTMAThreads>
|
| 41 |
+
CUTLASS_GLOBAL void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1)
|
| 42 |
sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
| 43 |
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 44 |
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
|
|
|
| 59 |
|
| 60 |
// Utils
|
| 61 |
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 62 |
+
const auto lane_idx = ptx::get_lane_idx();
|
| 63 |
|
| 64 |
// Align to 1024 bytes for swizzle-128B
|
| 65 |
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
|
|
|
| 79 |
// Data on shared memory (layout as ordered below)
|
| 80 |
// Fill D/A/B pointers
|
| 81 |
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
|
| 82 |
+
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
|
| 83 |
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
| 84 |
});
|
| 85 |
+
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
|
| 86 |
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
| 87 |
});
|
| 88 |
|
| 89 |
// Fill barriers
|
| 90 |
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 91 |
+
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 92 |
+
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 93 |
|
| 94 |
// Initialize barriers
|
| 95 |
if (warp_idx == 1 and cute::elect_one_sync()) {
|
|
|
|
| 104 |
}
|
| 105 |
__syncthreads();
|
| 106 |
|
| 107 |
+
constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K);
|
| 108 |
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
|
| 109 |
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
|
| 110 |
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
|
|
|
|
| 116 |
constexpr uint32_t kNumTMARegisters = 40;
|
| 117 |
constexpr uint32_t kNumMathRegisters = 256;
|
| 118 |
|
| 119 |
+
// Wait for primary kernel completion
|
| 120 |
+
cudaGridDependencySynchronize();
|
| 121 |
+
|
| 122 |
// TMA load warp
|
| 123 |
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
| 124 |
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
| 125 |
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 126 |
// Wait consumer release
|
| 127 |
+
const auto stage_idx = s % kNumStages;
|
| 128 |
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
|
| 129 |
|
| 130 |
// Compute offsets
|
|
|
|
| 132 |
uint32_t k_idx = k_offset + s * BLOCK_K;
|
| 133 |
|
| 134 |
// Issue TMAs
|
| 135 |
+
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
|
| 136 |
+
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
|
| 137 |
|
| 138 |
// Arrive at full barriers
|
| 139 |
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
|
|
|
| 141 |
}
|
| 142 |
|
| 143 |
for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) {
|
| 144 |
+
const auto stage_idx = s % kNumStages;
|
| 145 |
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
|
| 146 |
}
|
| 147 |
} else if (warp_idx < kNumMathThreads / 32) {
|
|
|
|
| 154 |
constexpr uint32_t WGMMA_N = BLOCK_N;
|
| 155 |
constexpr uint32_t WGMMA_K = 8;
|
| 156 |
|
| 157 |
+
using WGMMA = typename mma::sm90::TF32MMASelector<WGMMA_N, true>::type;
|
| 158 |
float accum[WGMMA::kNumAccum] = {0};
|
| 159 |
|
| 160 |
constexpr uint32_t kNumBankGroupBytes = 16;
|
|
|
|
| 202 |
sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y;
|
| 203 |
}
|
| 204 |
|
| 205 |
+
ptx::warpgroup_wait<0>();
|
| 206 |
if (s > 0)
|
| 207 |
empty_barriers[(s - 1) % kNumStages]->arrive();
|
| 208 |
|
| 209 |
#pragma unroll
|
| 210 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 211 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 212 |
+
ptx::warpgroup_arrive();
|
| 213 |
|
| 214 |
constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float);
|
| 215 |
constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K;
|
|
|
|
| 219 |
for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) {
|
| 220 |
#pragma unroll
|
| 221 |
for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) {
|
| 222 |
+
auto b_desc = mma::sm90::make_smem_desc(
|
| 223 |
+
smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1);
|
| 224 |
WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1);
|
| 225 |
}
|
| 226 |
}
|
| 227 |
+
ptx::warpgroup_commit_batch();
|
| 228 |
#pragma unroll
|
| 229 |
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 230 |
+
ptx::warpgroup_fence_operand(accum[i]);
|
| 231 |
}
|
| 232 |
|
| 233 |
+
const auto& reduced_sum_0 = math::warp_reduce_sum<4>(sqr_sum_acc_0);
|
| 234 |
+
const auto& reduced_sum_1 = math::warp_reduce_sum<4>(sqr_sum_acc_1);
|
| 235 |
|
| 236 |
const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4);
|
| 237 |
if (lane_idx % 4 == 0) {
|
|
|
|
| 240 |
if (m_idx + 8 < shape_m)
|
| 241 |
sqr_sum[m_offset + m_idx + 8] = reduced_sum_1;
|
| 242 |
}
|
| 243 |
+
ptx::warpgroup_wait<0>();
|
| 244 |
empty_barriers[(num_total_stages-1) % kNumStages]->arrive();
|
| 245 |
|
| 246 |
// Write accum to shared memory
|
|
|
|
| 267 |
|
| 268 |
// 0/1 write to the same row, 2/3 write to another row
|
| 269 |
auto values = reinterpret_cast<uint32_t*>(accum + i * 2);
|
| 270 |
+
ptx::st_shared(smem_ptr, values[0], values[1]);
|
| 271 |
+
ptx::st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]);
|
| 272 |
}
|
| 273 |
cute::tma_store_fence();
|
| 274 |
cutlass::arch::NamedBarrier::sync(128, 1);
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh
CHANGED
|
@@ -3,21 +3,24 @@
|
|
| 3 |
#include <cutlass/arch/barrier.h>
|
| 4 |
#include <cute/arch/cluster_sm90.hpp>
|
| 5 |
|
| 6 |
-
#include <deep_gemm/common/
|
|
|
|
| 7 |
|
| 8 |
namespace deep_gemm {
|
| 9 |
|
| 10 |
-
template <uint32_t kNextN, uint32_t BLOCK_KV, uint32_t kNumWarps>
|
| 11 |
-
|
| 12 |
void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits,
|
| 13 |
-
const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end,
|
| 14 |
-
const uint32_t
|
| 15 |
-
const uint32_t
|
| 16 |
-
const uint32_t
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
|
| 19 |
// Allocate filled `-inf` shared memory
|
| 20 |
-
extern __shared__ __align__(1024)
|
| 21 |
#pragma unroll
|
| 22 |
for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32)
|
| 23 |
smem_buffer[i] = neg_inf;
|
|
@@ -25,38 +28,42 @@ void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const
|
|
| 25 |
__syncthreads();
|
| 26 |
|
| 27 |
// Assign sequence to each warp
|
| 28 |
-
const auto
|
| 29 |
-
|
| 30 |
-
const auto
|
| 31 |
-
return {start + idx * per + min(idx, rem), per + (idx < rem)};
|
| 32 |
};
|
| 33 |
CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len);
|
| 34 |
CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len);
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
if (cute::elect_one_sync()) {
|
| 37 |
for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
|
| 38 |
-
const auto
|
| 39 |
-
const auto
|
| 40 |
-
const auto
|
| 41 |
|
| 42 |
for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) {
|
| 43 |
-
const auto
|
| 44 |
if (right <= ks or ke <= left) {
|
| 45 |
-
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(
|
| 46 |
} else {
|
| 47 |
if (left < aligned_ks)
|
| 48 |
-
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(
|
| 49 |
if (aligned_ke < right)
|
| 50 |
-
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(
|
| 51 |
}
|
| 52 |
}
|
| 53 |
}
|
| 54 |
}
|
|
|
|
| 55 |
|
| 56 |
for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
|
| 57 |
-
const auto
|
| 58 |
-
const auto
|
| 59 |
-
const auto
|
| 60 |
for (uint32_t j = aligned_ks; j < ks; ++ j)
|
| 61 |
logits[i * stride_logits + j] = neg_inf;
|
| 62 |
for (uint32_t j = ke; j < aligned_ke; ++ j)
|
|
|
|
| 3 |
#include <cutlass/arch/barrier.h>
|
| 4 |
#include <cute/arch/cluster_sm90.hpp>
|
| 5 |
|
| 6 |
+
#include <deep_gemm/common/cute_tie.cuh>
|
| 7 |
+
#include <deep_gemm/common/math.cuh>
|
| 8 |
|
| 9 |
namespace deep_gemm {
|
| 10 |
|
| 11 |
+
template <uint32_t kNextN, uint32_t BLOCK_KV, uint32_t kNumWarps, typename logits_dtype_t>
|
| 12 |
+
CUTLASS_GLOBAL __launch_bounds__(kNumWarps * 32, 1)
|
| 13 |
void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits,
|
| 14 |
+
const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, logits_dtype_t* logits) {
|
| 15 |
+
const uint32_t num_sms = gridDim.x;
|
| 16 |
+
const uint32_t sm_idx = blockIdx.x;
|
| 17 |
+
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 18 |
+
|
| 19 |
+
constexpr uint32_t kAlignment = 16 / sizeof(logits_dtype_t);
|
| 20 |
+
const logits_dtype_t neg_inf = -cute::numeric_limits<logits_dtype_t>::infinity();
|
| 21 |
|
| 22 |
// Allocate filled `-inf` shared memory
|
| 23 |
+
extern __shared__ __align__(1024) logits_dtype_t smem_buffer[];
|
| 24 |
#pragma unroll
|
| 25 |
for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32)
|
| 26 |
smem_buffer[i] = neg_inf;
|
|
|
|
| 28 |
__syncthreads();
|
| 29 |
|
| 30 |
// Assign sequence to each warp
|
| 31 |
+
const auto assign_task = [&](const uint32_t& num, const uint32_t& idx,
|
| 32 |
+
const uint32_t& start, const uint32_t& total) -> cute::tuple<uint32_t, uint32_t> {
|
| 33 |
+
const auto per = total / num, rem = total % num;
|
| 34 |
+
return {start + idx * per + cute::min(idx, rem), per + (idx < rem)};
|
| 35 |
};
|
| 36 |
CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len);
|
| 37 |
CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len);
|
| 38 |
|
| 39 |
+
// Wait for primary kernel completion
|
| 40 |
+
cudaGridDependencySynchronize();
|
| 41 |
+
|
| 42 |
if (cute::elect_one_sync()) {
|
| 43 |
for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
|
| 44 |
+
const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN];
|
| 45 |
+
const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1;
|
| 46 |
+
const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment;
|
| 47 |
|
| 48 |
for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) {
|
| 49 |
+
const auto right = cute::min(left + BLOCK_KV, static_cast<uint32_t>(stride_logits));
|
| 50 |
if (right <= ks or ke <= left) {
|
| 51 |
+
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(logits_dtype_t));
|
| 52 |
} else {
|
| 53 |
if (left < aligned_ks)
|
| 54 |
+
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(logits_dtype_t));
|
| 55 |
if (aligned_ke < right)
|
| 56 |
+
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(logits_dtype_t));
|
| 57 |
}
|
| 58 |
}
|
| 59 |
}
|
| 60 |
}
|
| 61 |
+
__syncwarp();
|
| 62 |
|
| 63 |
for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
|
| 64 |
+
const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN];
|
| 65 |
+
const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1;
|
| 66 |
+
const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment;
|
| 67 |
for (uint32_t j = aligned_ks; j < ks; ++ j)
|
| 68 |
logits[i * stride_logits + j] = neg_inf;
|
| 69 |
for (uint32_t j = ke; j < aligned_ke; ++ j)
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh
CHANGED
|
@@ -1,13 +1,16 @@
|
|
| 1 |
#pragma once
|
| 2 |
|
|
|
|
| 3 |
#include <deep_gemm/common/utils.cuh>
|
|
|
|
|
|
|
| 4 |
|
| 5 |
namespace deep_gemm {
|
| 6 |
|
| 7 |
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K,
|
| 8 |
uint32_t PADDED_SF_K = SF_K + (1 - (SF_K % 2))>
|
| 9 |
-
|
| 10 |
-
typedef typename Vectorized<sizeof(float) * SF_K>::vec_t in_vec_t;
|
| 11 |
constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float);
|
| 12 |
constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec;
|
| 13 |
|
|
@@ -15,16 +18,19 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
|
|
| 15 |
extern __shared__ float smem_buffer[];
|
| 16 |
constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(float));
|
| 17 |
const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
|
| 18 |
-
const auto tma_aligned_mn = align<uint32_t>(mn, kNumTMAAlignedElems);
|
| 19 |
|
| 20 |
// Shift into the block
|
| 21 |
sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
|
| 22 |
out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * SF_K;
|
| 23 |
const auto& local_sf = reinterpret_cast<const in_vec_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
|
| 24 |
|
|
|
|
|
|
|
|
|
|
| 25 |
// Load
|
| 26 |
for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) {
|
| 27 |
-
auto in_vec =
|
| 28 |
const auto& in_values = reinterpret_cast<float*>(&in_vec);
|
| 29 |
|
| 30 |
const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec;
|
|
@@ -39,26 +45,29 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
|
|
| 39 |
for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) {
|
| 40 |
const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn;
|
| 41 |
const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx;
|
| 42 |
-
out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx);
|
| 43 |
}
|
| 44 |
}
|
| 45 |
|
| 46 |
// NOTES: the two kernels below always pack the K dimension
|
| 47 |
|
| 48 |
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K>
|
| 49 |
-
|
| 50 |
extern __shared__ uint32_t smem_buffer[];
|
| 51 |
|
| 52 |
// Shapes and strides
|
| 53 |
-
constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u);
|
| 54 |
constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(int));
|
| 55 |
const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
|
| 56 |
-
const auto tma_aligned_mn = align<uint64_t>(mn, kNumTMAAlignedElems);
|
| 57 |
|
| 58 |
// Shift into the group
|
| 59 |
sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
|
| 60 |
out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * kNumPackedSFK;
|
| 61 |
|
|
|
|
|
|
|
|
|
|
| 62 |
// Load FP32 SFs
|
| 63 |
DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size");
|
| 64 |
const auto local_sf = reinterpret_cast<uint32_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
|
|
@@ -66,13 +75,13 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con
|
|
| 66 |
const auto num_uint4 = num_values / 4;
|
| 67 |
#pragma unroll
|
| 68 |
for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) {
|
| 69 |
-
const auto& [x, y, z, w] =
|
| 70 |
-
st_shared(reinterpret_cast<uint4*>(smem_buffer) + i, x, y, z, w);
|
| 71 |
}
|
| 72 |
|
| 73 |
// Fill unaligned values as well
|
| 74 |
if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values)
|
| 75 |
-
st_shared(smem_buffer + unaligned_idx,
|
| 76 |
__syncthreads();
|
| 77 |
|
| 78 |
// Pack into UE8M0 and store
|
|
@@ -85,7 +94,7 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con
|
|
| 85 |
#pragma unroll
|
| 86 |
for (uint32_t j = 0; j < 4; ++ j) {
|
| 87 |
const auto sf_k_idx = sf_k_pack_idx * 4 + j;
|
| 88 |
-
values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0;
|
| 89 |
}
|
| 90 |
|
| 91 |
// Pack and store
|
|
@@ -101,8 +110,9 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con
|
|
| 101 |
|
| 102 |
template <uint32_t kNumGroups, uint32_t kNumThreads,
|
| 103 |
uint32_t BLOCK_MN, uint32_t BLOCK_PACKED_SF_K, bool kTransposed = true>
|
| 104 |
-
|
| 105 |
-
|
|
|
|
| 106 |
// Always packing the K dimension
|
| 107 |
// NOTES: should also assert `mn % 4 == 0` at launch
|
| 108 |
DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)");
|
|
@@ -120,11 +130,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
|
|
| 120 |
|
| 121 |
// Each warp is responsible for a packed row
|
| 122 |
const auto warp_idx = threadIdx.x / 32;
|
| 123 |
-
const auto lane_idx = get_lane_idx();
|
| 124 |
const auto packed_sf_k_idx = static_cast<uint64_t>(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx;
|
| 125 |
if (warp_idx >= in_block_packed_sf_k)
|
| 126 |
return;
|
| 127 |
|
|
|
|
|
|
|
|
|
|
| 128 |
// Make an offset on the input
|
| 129 |
uint32_t input_offset = 0;
|
| 130 |
if constexpr (kNumGroups > 1) {
|
|
@@ -134,18 +147,18 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
|
|
| 134 |
#pragma unroll
|
| 135 |
for (uint32_t i = 0; i < 4; ++ i) {
|
| 136 |
const auto group_idx = lane_idx * 4 + i;
|
| 137 |
-
group_ks[i] = group_idx < kNumGroups ?
|
| 138 |
}
|
| 139 |
__syncwarp();
|
| 140 |
|
| 141 |
// Make the offset
|
| 142 |
sf_k = 0;
|
| 143 |
-
|
| 144 |
#pragma unroll
|
| 145 |
for (uint32_t i = 0; i < kNumGroups; ++ i) {
|
| 146 |
-
const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] /
|
| 147 |
sf_k += sf_k_in_group;
|
| 148 |
-
sum_packed_sf_k += ceil_div(sf_k_in_group, 4u);
|
| 149 |
if (packed_sf_k_idx < sum_packed_sf_k)
|
| 150 |
break;
|
| 151 |
if (const auto remainder = sf_k_in_group % 4; remainder > 0)
|
|
@@ -153,14 +166,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
|
|
| 153 |
}
|
| 154 |
}
|
| 155 |
|
| 156 |
-
for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) {
|
| 157 |
// Load
|
| 158 |
uint4 values[4];
|
| 159 |
#pragma unroll
|
| 160 |
for (uint32_t j = 0; j < 4; ++ j) {
|
| 161 |
values[j] = make_uint4(0, 0, 0, 0);
|
| 162 |
if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k)
|
| 163 |
-
values[j] =
|
| 164 |
}
|
| 165 |
|
| 166 |
// Pack and store
|
|
|
|
| 1 |
#pragma once
|
| 2 |
|
| 3 |
+
#include <deep_gemm/common/math.cuh>
|
| 4 |
#include <deep_gemm/common/utils.cuh>
|
| 5 |
+
#include <deep_gemm/ptx/ld_st.cuh>
|
| 6 |
+
#include <deep_gemm/ptx/utils.cuh>
|
| 7 |
|
| 8 |
namespace deep_gemm {
|
| 9 |
|
| 10 |
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K,
|
| 11 |
uint32_t PADDED_SF_K = SF_K + (1 - (SF_K % 2))>
|
| 12 |
+
CUTLASS_GLOBAL void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
|
| 13 |
+
typedef typename utils::Vectorized<sizeof(float) * SF_K>::vec_t in_vec_t;
|
| 14 |
constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float);
|
| 15 |
constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec;
|
| 16 |
|
|
|
|
| 18 |
extern __shared__ float smem_buffer[];
|
| 19 |
constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(float));
|
| 20 |
const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
|
| 21 |
+
const auto tma_aligned_mn = math::align<uint32_t>(mn, kNumTMAAlignedElems);
|
| 22 |
|
| 23 |
// Shift into the block
|
| 24 |
sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
|
| 25 |
out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * SF_K;
|
| 26 |
const auto& local_sf = reinterpret_cast<const in_vec_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
|
| 27 |
|
| 28 |
+
// Wait for primary kernel completion
|
| 29 |
+
cudaGridDependencySynchronize();
|
| 30 |
+
|
| 31 |
// Load
|
| 32 |
for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) {
|
| 33 |
+
auto in_vec = local_sf[i];
|
| 34 |
const auto& in_values = reinterpret_cast<float*>(&in_vec);
|
| 35 |
|
| 36 |
const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec;
|
|
|
|
| 45 |
for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) {
|
| 46 |
const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn;
|
| 47 |
const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx;
|
| 48 |
+
out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ptx::ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx);
|
| 49 |
}
|
| 50 |
}
|
| 51 |
|
| 52 |
// NOTES: the two kernels below always pack the K dimension
|
| 53 |
|
| 54 |
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K>
|
| 55 |
+
CUTLASS_GLOBAL void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) {
|
| 56 |
extern __shared__ uint32_t smem_buffer[];
|
| 57 |
|
| 58 |
// Shapes and strides
|
| 59 |
+
constexpr auto kNumPackedSFK = math::constexpr_ceil_div(SF_K, 4u);
|
| 60 |
constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(int));
|
| 61 |
const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
|
| 62 |
+
const auto tma_aligned_mn = math::align<uint64_t>(mn, kNumTMAAlignedElems);
|
| 63 |
|
| 64 |
// Shift into the group
|
| 65 |
sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
|
| 66 |
out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * kNumPackedSFK;
|
| 67 |
|
| 68 |
+
// Wait for primary kernel completion
|
| 69 |
+
cudaGridDependencySynchronize();
|
| 70 |
+
|
| 71 |
// Load FP32 SFs
|
| 72 |
DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size");
|
| 73 |
const auto local_sf = reinterpret_cast<uint32_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
|
|
|
|
| 75 |
const auto num_uint4 = num_values / 4;
|
| 76 |
#pragma unroll
|
| 77 |
for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) {
|
| 78 |
+
const auto& [x, y, z, w] = reinterpret_cast<const uint4*>(local_sf)[i];
|
| 79 |
+
ptx::st_shared(reinterpret_cast<uint4*>(smem_buffer) + i, x, y, z, w);
|
| 80 |
}
|
| 81 |
|
| 82 |
// Fill unaligned values as well
|
| 83 |
if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values)
|
| 84 |
+
ptx::st_shared(smem_buffer + unaligned_idx, local_sf[unaligned_idx]);
|
| 85 |
__syncthreads();
|
| 86 |
|
| 87 |
// Pack into UE8M0 and store
|
|
|
|
| 94 |
#pragma unroll
|
| 95 |
for (uint32_t j = 0; j < 4; ++ j) {
|
| 96 |
const auto sf_k_idx = sf_k_pack_idx * 4 + j;
|
| 97 |
+
values[j] = sf_k_idx < SF_K ? ptx::ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0;
|
| 98 |
}
|
| 99 |
|
| 100 |
// Pack and store
|
|
|
|
| 110 |
|
| 111 |
template <uint32_t kNumGroups, uint32_t kNumThreads,
|
| 112 |
uint32_t BLOCK_MN, uint32_t BLOCK_PACKED_SF_K, bool kTransposed = true>
|
| 113 |
+
CUTLASS_GLOBAL void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
|
| 114 |
+
const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k,
|
| 115 |
+
const uint32_t gran_k) {
|
| 116 |
// Always packing the K dimension
|
| 117 |
// NOTES: should also assert `mn % 4 == 0` at launch
|
| 118 |
DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)");
|
|
|
|
| 130 |
|
| 131 |
// Each warp is responsible for a packed row
|
| 132 |
const auto warp_idx = threadIdx.x / 32;
|
| 133 |
+
const auto lane_idx = ptx::get_lane_idx();
|
| 134 |
const auto packed_sf_k_idx = static_cast<uint64_t>(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx;
|
| 135 |
if (warp_idx >= in_block_packed_sf_k)
|
| 136 |
return;
|
| 137 |
|
| 138 |
+
// Wait for primary kernel completion
|
| 139 |
+
cudaGridDependencySynchronize();
|
| 140 |
+
|
| 141 |
// Make an offset on the input
|
| 142 |
uint32_t input_offset = 0;
|
| 143 |
if constexpr (kNumGroups > 1) {
|
|
|
|
| 147 |
#pragma unroll
|
| 148 |
for (uint32_t i = 0; i < 4; ++ i) {
|
| 149 |
const auto group_idx = lane_idx * 4 + i;
|
| 150 |
+
group_ks[i] = group_idx < kNumGroups ? ks[group_idx] : 0;
|
| 151 |
}
|
| 152 |
__syncwarp();
|
| 153 |
|
| 154 |
// Make the offset
|
| 155 |
sf_k = 0;
|
| 156 |
+
uint32_t sum_packed_sf_k = 0;
|
| 157 |
#pragma unroll
|
| 158 |
for (uint32_t i = 0; i < kNumGroups; ++ i) {
|
| 159 |
+
const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / gran_k, i / 4);
|
| 160 |
sf_k += sf_k_in_group;
|
| 161 |
+
sum_packed_sf_k += math::ceil_div(sf_k_in_group, 4u);
|
| 162 |
if (packed_sf_k_idx < sum_packed_sf_k)
|
| 163 |
break;
|
| 164 |
if (const auto remainder = sf_k_in_group % 4; remainder > 0)
|
|
|
|
| 166 |
}
|
| 167 |
}
|
| 168 |
|
| 169 |
+
for (uint32_t mn_idx = ptx::get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) {
|
| 170 |
// Load
|
| 171 |
uint4 values[4];
|
| 172 |
#pragma unroll
|
| 173 |
for (uint32_t j = 0; j < 4; ++ j) {
|
| 174 |
values[j] = make_uint4(0, 0, 0, 0);
|
| 175 |
if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k)
|
| 176 |
+
values[j] = reinterpret_cast<const uint4*>(sf + sf_k_idx * mn)[mn_idx];
|
| 177 |
}
|
| 178 |
|
| 179 |
// Pack and store
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/layout/mega_moe.cuh
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cute/numeric/math.hpp>
|
| 4 |
+
|
| 5 |
+
#include <deep_gemm/common/math.cuh>
|
| 6 |
+
#include <deep_gemm/common/exception.cuh>
|
| 7 |
+
|
| 8 |
+
namespace deep_gemm::layout {
|
| 9 |
+
|
| 10 |
+
// Candidate `BLOCK_M` values, plus hand-maintained summary constants derived from the table
static constexpr int kNumCandidateBlockMs = 7;
static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192};
static constexpr int kMaxCandidateBlockM = 192;
static constexpr int kMinCandidateBlockM = 8;
static constexpr int kLCMCandidateBlockM = 384;

// Compile-time guard: the summary constants above are written by hand, so verify that
// they still agree with the table (min, max, and "common multiple of every entry")
constexpr bool are_candidate_block_m_constants_consistent() {
    int min_value = kCandidateBlockM[0], max_value = kCandidateBlockM[0];
    for (int i = 0; i < kNumCandidateBlockMs; ++ i) {
        const int value = kCandidateBlockM[i];
        if (value <= 0 or kLCMCandidateBlockM % value != 0)
            return false;
        min_value = value < min_value ? value : min_value;
        max_value = value > max_value ? value : max_value;
    }
    return min_value == kMinCandidateBlockM and max_value == kMaxCandidateBlockM;
}
static_assert(are_candidate_block_m_constants_consistent(),
              "Min/max/LCM constants are out of sync with `kCandidateBlockM`");
|
| 15 |
+
|
| 16 |
+
// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding, among all possible BLOCK_M
|
| 17 |
+
template <typename T>
|
| 18 |
+
CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk,
|
| 19 |
+
T num_experts_per_rank) {
|
| 20 |
+
const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank;
|
| 21 |
+
const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank);
|
| 22 |
+
return math::constexpr_align(
|
| 23 |
+
num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast<T>(kMaxCandidateBlockM) - 1),
|
| 24 |
+
static_cast<T>(kLCMCandidateBlockM));
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
// SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M
|
| 28 |
+
template <typename T>
|
| 29 |
+
CUTLASS_HOST_DEVICE constexpr T get_num_padded_sf_pool_tokens(T num_max_pool_tokens, T block_m) {
|
| 30 |
+
return (num_max_pool_tokens / block_m) * math::constexpr_align(block_m, static_cast<T>(128));
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
// Source metadata kept per pool token, used by the combine write-back.
// Three `uint32_t` indices; presumably (source rank, token index on that rank,
// top-k slot) -- confirm against the dispatch/combine kernels.
struct TokenSrcMetadata {
    uint32_t rank_idx, token_idx, topk_idx;
};
|
| 39 |
+
|
| 40 |
+
struct Workspace {
|
| 41 |
+
void* base;
|
| 42 |
+
uint32_t num_ranks, num_experts;
|
| 43 |
+
uint32_t num_experts_per_rank;
|
| 44 |
+
uint32_t num_max_tokens_per_rank;
|
| 45 |
+
uint32_t num_max_recv_tokens_per_expert;
|
| 46 |
+
|
| 47 |
+
// Pool capacity: all local experts share a contiguous token pool
|
| 48 |
+
uint32_t num_max_pool_tokens;
|
| 49 |
+
uint32_t num_max_pool_blocks;
|
| 50 |
+
|
| 51 |
+
// For both grid barrier and NVLink barrier
|
| 52 |
+
static constexpr uint64_t kNumBarrierSignalBytes = 32;
|
| 53 |
+
|
| 54 |
+
CUTLASS_HOST_DEVICE
|
| 55 |
+
Workspace(void* base,
|
| 56 |
+
const uint32_t& num_ranks,
|
| 57 |
+
const uint32_t& num_experts,
|
| 58 |
+
const uint32_t& num_max_tokens_per_rank,
|
| 59 |
+
const uint32_t& num_topk):
|
| 60 |
+
base(base),
|
| 61 |
+
num_ranks(num_ranks), num_experts(num_experts),
|
| 62 |
+
num_max_tokens_per_rank(num_max_tokens_per_rank) {
|
| 63 |
+
num_experts_per_rank = num_experts / num_ranks;
|
| 64 |
+
num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank;
|
| 65 |
+
num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
|
| 66 |
+
num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
CUTLASS_HOST_DEVICE
|
| 70 |
+
uint64_t get_num_bytes() const {
|
| 71 |
+
uint64_t num_bytes = 0;
|
| 72 |
+
|
| 73 |
+
// Barrier
|
| 74 |
+
num_bytes += kNumBarrierSignalBytes;
|
| 75 |
+
|
| 76 |
+
// Expert send/recv count
|
| 77 |
+
num_bytes += num_experts * sizeof(uint64_t) * 2;
|
| 78 |
+
|
| 79 |
+
// Expert recv count sum
|
| 80 |
+
num_bytes += num_experts_per_rank * sizeof(uint64_t);
|
| 81 |
+
|
| 82 |
+
// L1 arrival count (padded to even entry count for `uint64_t` alignment of L2 mask)
|
| 83 |
+
num_bytes += math::align(num_max_pool_blocks, 2u) * sizeof(uint32_t);
|
| 84 |
+
|
| 85 |
+
// L2 block arrival mask
|
| 86 |
+
num_bytes += num_max_pool_blocks * sizeof(uint64_t);
|
| 87 |
+
|
| 88 |
+
// Dispatch pulling source token-topk
|
| 89 |
+
num_bytes += num_experts_per_rank * num_ranks * num_max_recv_tokens_per_expert * sizeof(int);
|
| 90 |
+
|
| 91 |
+
// Combine push source indices
|
| 92 |
+
num_bytes += num_max_pool_tokens * sizeof(TokenSrcMetadata);
|
| 93 |
+
|
| 94 |
+
// Align to TMA descriptor requirements
|
| 95 |
+
num_bytes = math::align<uint64_t>(num_bytes, 16);
|
| 96 |
+
return num_bytes;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
CUTLASS_HOST_DEVICE
|
| 100 |
+
void* get_end_ptr() const {
|
| 101 |
+
return math::advance_ptr(base, get_num_bytes());
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
// Grid sync counters: `kNumBarrierSignalBytes` layout
|
| 105 |
+
// [ 0..15]: 4 x `uint32_t` grid sync counters
|
| 106 |
+
// [16..20]: `uint32_t` NVLink barrier counter
|
| 107 |
+
// [20..27]: 2 x `int` NVLink barrier signals (phase 0 and 1)
|
| 108 |
+
static constexpr uint32_t kNumMaxGridSyncCounters = 4;
|
| 109 |
+
|
| 110 |
+
template <uint32_t kIndex = 0>
|
| 111 |
+
CUTLASS_DEVICE
|
| 112 |
+
uint32_t* get_grid_sync_count_ptr() const {
|
| 113 |
+
DG_STATIC_ASSERT(kIndex < kNumMaxGridSyncCounters, "Grid sync index out of bounds");
|
| 114 |
+
return static_cast<uint32_t*>(base) + kIndex;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
CUTLASS_DEVICE
|
| 118 |
+
uint32_t* get_nvl_barrier_counter_ptr() const {
|
| 119 |
+
return static_cast<uint32_t*>(base) + kNumMaxGridSyncCounters;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
CUTLASS_DEVICE
|
| 123 |
+
int* get_nvl_barrier_signal_ptr(const uint32_t& phase) const {
|
| 124 |
+
// NOTES: the signal is signed, as we may minus
|
| 125 |
+
return math::advance_ptr<int>(base, (kNumMaxGridSyncCounters + 1) * sizeof(uint32_t) + phase * sizeof(int));
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
CUTLASS_DEVICE
|
| 129 |
+
uint64_t* get_expert_send_count_ptr(const uint32_t& expert_idx = 0) const {
|
| 130 |
+
return math::advance_ptr<uint64_t>(base, kNumBarrierSignalBytes) + expert_idx;
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
CUTLASS_DEVICE
|
| 134 |
+
uint64_t* get_expert_recv_count_ptr(
|
| 135 |
+
const uint32_t& rank_idx = 0, const uint32_t& expert_idx = 0) const {
|
| 136 |
+
return get_expert_send_count_ptr(num_experts) + rank_idx * num_experts_per_rank + expert_idx;
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
CUTLASS_DEVICE
|
| 140 |
+
uint64_t* get_expert_recv_count_sum_ptr(const uint32_t& expert_idx = 0) const {
|
| 141 |
+
return get_expert_send_count_ptr(num_experts * 2) + expert_idx;
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
CUTLASS_DEVICE
|
| 145 |
+
uint32_t* get_l1_arrival_count_ptr(const uint32_t& pool_block_idx = 0) const {
|
| 146 |
+
const auto base = get_expert_recv_count_sum_ptr(num_experts_per_rank);
|
| 147 |
+
return reinterpret_cast<uint32_t*>(base) + pool_block_idx;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
CUTLASS_DEVICE
|
| 151 |
+
uint64_t* get_l2_arrival_mask_ptr(const uint32_t& pool_block_idx = 0) const {
|
| 152 |
+
// Pad L1 entry count to even so that the `l2_arrival_mask` is 8-byte aligned
|
| 153 |
+
const auto base = get_l1_arrival_count_ptr(math::align(num_max_pool_blocks, 2u));
|
| 154 |
+
return reinterpret_cast<uint64_t*>(base) + pool_block_idx;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
// For dispatch pulling
|
| 158 |
+
CUTLASS_DEVICE
|
| 159 |
+
uint32_t* get_src_token_topk_idx_ptr(
|
| 160 |
+
const uint32_t& expert_idx = 0, const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const {
|
| 161 |
+
const auto base = get_l2_arrival_mask_ptr(num_max_pool_blocks);
|
| 162 |
+
return reinterpret_cast<uint32_t*>(base) +
|
| 163 |
+
expert_idx * (num_ranks * num_max_recv_tokens_per_expert) +
|
| 164 |
+
rank_idx * num_max_recv_tokens_per_expert + token_idx;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
// For combine usages
|
| 168 |
+
CUTLASS_DEVICE
|
| 169 |
+
TokenSrcMetadata* get_token_src_metadata_ptr(const uint32_t& pool_token_idx = 0) const {
|
| 170 |
+
const auto base = reinterpret_cast<TokenSrcMetadata*>(get_src_token_topk_idx_ptr(num_experts_per_rank));
|
| 171 |
+
return base + pool_token_idx;
|
| 172 |
+
}
|
| 173 |
+
};
|
| 174 |
+
|
| 175 |
+
// Describes one fixed-size chunk of bytes: its size, whether that size must satisfy the
// 16-byte check enforced in the constructor, and (optionally) where the chunk lives.
struct Data {
    // Size of the chunk in bytes
    uint32_t num_bytes;
    // When true, the constructor asserts that `num_bytes` is a multiple of 16
    bool require_tma_alignment;
    // Start address of the chunk; may be null for a pure layout descriptor
    void* base;

    // Builds the descriptor and asserts that `num_bytes` is a multiple of 16
    // unless `require_tma_alignment` is false.
    CUTLASS_HOST_DEVICE
    constexpr explicit Data(
        const uint32_t& num_bytes,
        const bool& require_tma_alignment = true,
        void* base = nullptr)
        : num_bytes(num_bytes),
          require_tma_alignment(require_tma_alignment),
          base(base) {
        DG_UNIFIED_ASSERT(num_bytes % 16 == 0 or not require_tma_alignment);
    }

    // Chunk size in bytes, converted to `dtype_t` (pass `uint64_t` to do 64-bit offset math)
    template <typename dtype_t = uint32_t>
    CUTLASS_HOST_DEVICE constexpr dtype_t get_num_bytes() const {
        return static_cast<dtype_t>(this->num_bytes);
    }

    // Start address of the chunk, viewed as `dtype_t*`
    template <typename dtype_t = void>
    CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
        return static_cast<dtype_t*>(this->base);
    }

    // Re-points the descriptor at `ptr` without touching the size or alignment flag
    CUTLASS_HOST_DEVICE void set_base_ptr(void* ptr) {
        this->base = ptr;
    }
};
|
| 203 |
+
|
| 204 |
+
// A contiguous region holding `num_ranks` back-to-back slabs, each with room for
// `num_max_tokens_per_rank` chunks laid out as described by `data_layout`.
struct Buffer {
    // Per-token chunk description (only its size and alignment flag are used for layout math)
    Data data_layout;
    // Number of per-rank slabs in this buffer
    uint32_t num_ranks;
    // Capacity of each slab, in tokens
    uint32_t num_max_tokens_per_rank;

    // Start address of the whole region; may be null when only sizes are needed
    void* base;

    CUTLASS_HOST_DEVICE
    Buffer(const Data& data_layout,
           const uint32_t& num_ranks,
           const uint32_t& max_num_tokens_per_rank,
           void* base = nullptr) :
        data_layout(data_layout),
        num_ranks(num_ranks), num_max_tokens_per_rank(max_num_tokens_per_rank),
        base(base) {}

    // Bytes occupied by one rank's slab (computed in 64 bits)
    CUTLASS_HOST_DEVICE
    uint64_t get_num_bytes_per_rank() const {
        return num_max_tokens_per_rank * data_layout.get_num_bytes<uint64_t>();
    }

    // Bytes occupied by the whole buffer, i.e. all `num_ranks` slabs
    CUTLASS_HOST_DEVICE
    uint64_t get_num_bytes() const {
        return get_num_bytes_per_rank() * num_ranks;
    }

    // Start address of the buffer, viewed as `dtype_t*`
    template <typename dtype_t = void>
    CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
        return static_cast<dtype_t*>(base);
    }

    // Address obtained by advancing `base` by the total size of the buffer
    CUTLASS_HOST_DEVICE
    void* get_end_ptr() const {
        return math::advance_ptr(base, get_num_bytes());
    }

    // Single-rank view (`num_ranks == 1`) over the slab that belongs to `rank_idx`
    CUTLASS_HOST_DEVICE
    Buffer get_rank_buffer(const uint32_t& rank_idx) const {
        return {
            data_layout,
            1, num_max_tokens_per_rank,
            math::advance_ptr(base, get_num_bytes_per_rank() * rank_idx)
        };
    }

    // Descriptor for the chunk of token `token_idx`, counted from `base`.
    // On a multi-rank buffer the index runs across all slabs, so the caller must opt in
    // with `global = true`; otherwise the buffer has to be a single-rank view.
    CUTLASS_HOST_DEVICE
    Data get_data_buffer(const uint32_t& token_idx, const bool& global = false) const {
        // This function is host+device, so use the unified assert, matching the `Data`
        // constructor that runs on this very call path, rather than the device-only one
        DG_UNIFIED_ASSERT(num_ranks == 1 or global);
        return Data(
            data_layout.num_bytes,
            data_layout.require_tma_alignment,
            math::advance_ptr(base, data_layout.get_num_bytes<uint64_t>() * token_idx)
        );
    }
};
|
| 259 |
+
|
| 260 |
+
} // namespace deep_gemm::layout
|
build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/layout/sym_buffer.cuh
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <deep_gemm/common/exception.cuh>
|
| 4 |
+
|
| 5 |
+
namespace deep_gemm::layout {
|
| 6 |
+
|
| 7 |
+
// Fixed capacity of the per-rank offset table in `SymBuffer`, and the upper bound for its `kNumRanks` parameter
constexpr static uint32_t kNumMaxRanks = 72;
|
| 8 |
+
|
| 9 |
+
// Address table for a buffer that exists once per rank: stores the local rank's
// base address plus, for every rank, the signed distance from the local base to that rank's base.
// The offset table always has `kNumMaxRanks` entries regardless of `kNumRanks`.
template <uint32_t kNumRanks = kNumMaxRanks>
struct SymBuffer {
    // Base address of the local rank's buffer, stored as an integer
    int64_t base;
    // `offsets[i]` is rank `i`'s base minus the local base; unused entries are 0
    int64_t offsets[kNumMaxRanks];
    // Index of the local rank
    uint32_t rank_idx;

    DG_STATIC_ASSERT(kNumRanks <= kNumMaxRanks, "Too many ranks");

    SymBuffer() = default;

    // Builds the table from `c`, an indexable container of per-rank base addresses
    // (needs `size()` and `operator[]`, with elements convertible to `int64_t`).
    // NOTE(review): `c[rank_idx]` is read without checking `rank_idx < c.size()`, and
    // entries beyond `kNumMaxRanks` are silently ignored — callers must pass a valid index.
    template <typename Container>
    explicit SymBuffer(const Container& c, const uint32_t& rank_idx): rank_idx(rank_idx) {
        const auto size = static_cast<uint32_t>(c.size());
        base = c[rank_idx];
        // Fill the whole fixed-size table so no entry is left uninitialized
        for (uint32_t i = 0; i < kNumMaxRanks; ++ i)
            offsets[i] = i < size ? (c[i] - base) : 0;
    }

#if defined(__CUDA_ARCH__) or defined(__CLION_IDE__)
    // Local base address, converted to the requested pointer type
    template <typename ptr_t = void*>
    CUTLASS_DEVICE ptr_t get_base_ptr() const {
        return reinterpret_cast<ptr_t>(base);
    }

    // Translates `ptr` (an address inside the local buffer) into the corresponding
    // address inside rank `dst_rank_idx`'s buffer by adding that rank's offset.
    template <typename ptr_t>
    CUTLASS_DEVICE ptr_t map(const ptr_t& ptr, const uint32_t& dst_rank_idx) const {
        const int64_t mapped_ptr = offsets[dst_rank_idx] + reinterpret_cast<int64_t>(ptr);
        // Convert the integer back directly instead of reading the `int64_t` object
        // through a `ptr_t*`, which would be type punning
        return reinterpret_cast<ptr_t>(mapped_ptr);
    }
#endif
};
|
| 40 |
+
|
| 41 |
+
} // namespace deep_gemm::layout
|