diff --git a/.gitattributes b/.gitattributes index 209f74f0e890874a6414498146fb10b5a454e9b0..e3ee4f0efd360196276157329cb400210c947fd6 100644 --- a/.gitattributes +++ b/.gitattributes @@ -47,3 +47,4 @@ build/torch211-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter= build/torch211-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/build/torch210-cxx11-cu128-x86_64-linux/legacy/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/legacy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cce39ec7be8e80c6c99a4f9f10cba12c63f059ad --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/legacy/__init__.py @@ -0,0 +1,5 @@ +# All kernels may be deprecated in the future (or rewrite in TileLang) +from .m_grouped_gemm import * +from .a_fused_m_grouped_gemm import * +from .a_fused_k_grouped_gemm import * +from .b_fused_k_grouped_gemm import * diff --git a/build/torch210-cxx11-cu128-x86_64-linux/legacy/a_fused_k_grouped_gemm.py b/build/torch210-cxx11-cu128-x86_64-linux/legacy/a_fused_k_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..7b42f152ac183ecbdf72aae4e121295af6504e11 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/legacy/a_fused_k_grouped_gemm.py @@ -0,0 +1,88 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr']) +@triton.jit +def a_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, 
d_ptr, + k_indices_ptr, k_start_ptr, k_end_ptr, + M: tl.constexpr, + N: tl.constexpr, + K, + ACC: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_b = (pid // (num_pid_m * num_pid_n)).to(tl.int64) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + m_mask = (m_range < M)[:, None] + n_mask = (n_range < N)[None, :] + + k_start = tl.load(k_start_ptr + pid_b) + k_end = tl.load(k_end_ptr + pid_b) + if k_start >= k_end: + if not ACC: + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=m_mask & n_mask) + return + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_range = k + tl.arange(0, BLOCK_SIZE_K) + rows = tl.load(k_indices_ptr + k_range).to(tl.int64) + a_ptrs = a_ptr + m_range[:, None] + rows[None, :] * M + + b_ptrs = b_ptr + k_range[:, None].to(tl.int64) * N + n_range[None, :] + a = tl.load(a_ptrs, mask=(rows >= 0)[None, :] & m_mask, other=0) + b = tl.load(b_ptrs, mask=n_mask, other=0) + acc = tl.dot(a, b, acc) + + # Write back + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + 
if ACC: + acc += tl.load(d_ptrs, mask=m_mask & n_mask) + acc = acc.to(d_ptr.dtype.element_ty) + tl.store(d_ptrs, acc, mask=m_mask & n_mask) + + +def a_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool): + k_indices, k_start, k_end = handle + + assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous() + assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3 + assert k_start.numel() == k_end.numel() and k_indices.size(0) == b.size(0) + assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1) + assert b.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + + K_, M = a.shape + K, N = b.shape + B = k_start.numel() + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']) * B,) + a_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid]( + a, b, d, k_indices, k_start, k_end, M, N, K, ACC=acc) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/legacy/a_fused_m_grouped_gemm.py b/build/torch210-cxx11-cu128-x86_64-linux/legacy/a_fused_m_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..41b35d539796c30bb7589b9f5b5f98bb5a4d468e --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/legacy/a_fused_m_grouped_gemm.py @@ -0,0 +1,92 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[]) +@triton.jit +def a_fused_m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + m_indices_ptr, m_row_indices_ptr, + 
M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_B_K_MAJOR: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + n_mask = (n_range < N)[None, :] + + batch_id = tl.load(m_indices_ptr + pid_m * BLOCK_SIZE_M).to(tl.int64) + if batch_id < 0: + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask) + return + + # b block + rows = tl.load(m_row_indices_ptr + m_range).to(tl.int64) + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + k_range = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + k_mask = k_range < K + a_ptrs = a_ptr + rows[:, None] * K + k_range[None, :] + b_ptrs = b_ptr + batch_id * K * N + k_range[:, None] * (1 if IS_B_K_MAJOR else N) + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) + a = tl.load(a_ptrs, mask=(rows >= 0)[:, None] & k_mask[None, :], other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0) + acc = tl.dot(a, b, acc) + d = acc.to(d_ptr.dtype.element_ty) + + # Write back + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, d, 
mask=n_mask) + + +def a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + mappings: Tuple[torch.Tensor, torch.Tensor]): + m_indices, m_row_indices = mappings + r0, r1, r2 = b.shape + + assert a.is_contiguous() and (b.is_contiguous() or b.mT.is_contiguous()) and d.is_contiguous() + assert m_indices.is_contiguous() and m_row_indices.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 and d.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 and m_row_indices.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2 + assert a.size(1) == r2 and d.size(0) == m_indices.numel() and d.size(1) == r1 + assert m_indices.numel() == m_row_indices.numel() + assert m_indices.numel() % get_mk_alignment_for_contiguous_layout() == 0 + + if d.size(0) == 0: + return d + + M_, K = a.shape + B, K, N = r0, r2, r1 + M = m_indices.numel() + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), ) + a_fused_m_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, m_indices, m_row_indices, + M, N, K, IS_B_K_MAJOR=b.is_contiguous()) + + +def a_fused_m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + mappings: Tuple[torch.Tensor, torch.Tensor]): + a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a, b.mT, d, mappings) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/legacy/b_fused_k_grouped_gemm.py b/build/torch210-cxx11-cu128-x86_64-linux/legacy/b_fused_k_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..7df8741fa9b8d00498b5de61b609ef0980a3e873 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/legacy/b_fused_k_grouped_gemm.py @@ -0,0 +1,86 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + 
+@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr']) +@triton.jit +def b_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + k_indices_ptr, k_start_ptr, k_end_ptr, + M: tl.constexpr, + N: tl.constexpr, + K, + ACC: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_b = (pid // (num_pid_m * num_pid_n)).to(tl.int64) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + k_start = tl.load(k_start_ptr + pid_b) + k_end = tl.load(k_end_ptr + pid_b) + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + m_mask = (m_range < M)[:, None] + n_mask = (n_range < N)[None, :] + + if k_start >= k_end: + if not ACC: + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=m_mask & n_mask) + return + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_range = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + rows = tl.load(k_indices_ptr + k_range).to(tl.int64) + a_ptrs = a_ptr + m_range[:, None] + k_range[None, :] * M + b_ptrs = b_ptr + rows[:, None] * N + n_range[None, :] + a = tl.load(a_ptrs, mask=m_mask, other=0.0) + b = 
tl.load(b_ptrs, mask=(rows >= 0)[:, None] & n_mask, other=0.0) + acc = tl.dot(a, b, acc) + + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + if ACC: + acc += tl.load(d_ptrs, mask=m_mask & n_mask) + acc = acc.to(d_ptr.dtype.element_ty) + tl.store(d_ptrs, acc, mask=m_mask & n_mask) + + +def b_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool): + k_indices, k_start, k_end = handle + + assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous() + assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3 + assert k_start.numel() == k_end.numel() and k_indices.size(0) == a.size(0) + assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1) + assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + + K, M = a.shape + K_, N = b.shape + B = k_start.numel() + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']) * B,) + b_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, k_indices, k_start, k_end, M, N, K, ACC=acc) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/legacy/m_grouped_gemm.py b/build/torch210-cxx11-cu128-x86_64-linux/legacy/m_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..e685a9ab01b44ead9d16e4d1696d08716e12e47c --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/legacy/m_grouped_gemm.py @@ -0,0 +1,84 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + 
+@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[]) +@triton.jit +def m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + m_indices_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_B_K_MAJOR: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + n_mask = (n_range < N)[None, :] + + # Empty tokens + batch_id = tl.load(m_indices_ptr + pid_m * BLOCK_SIZE_M).to(tl.int64) + if batch_id < 0: + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask) + return + + # Compute + a_ptrs = a_ptr + m_range[:, None].to(tl.int64) * K + tl.arange(0, BLOCK_SIZE_K)[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + b_ptrs = b_ptr + batch_id * K * N + \ + tl.arange(0, BLOCK_SIZE_K)[:, None].to(tl.int64) * (1 if IS_B_K_MAJOR else N) + \ + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) + for k in range(0, K, BLOCK_SIZE_K): + k_mask = (k + tl.arange(0, BLOCK_SIZE_K)) < K + a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * (1 if IS_B_K_MAJOR else N) + + # Write back + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, 
accumulator.to(d_ptr.dtype.element_ty), mask=n_mask) + + +def m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + m_indices: torch.Tensor): + r0, r1, r2 = b.shape + + assert a.is_contiguous() and (b.is_contiguous or b.mT.is_contiguous()) + assert m_indices.is_contiguous() and d.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 and d.dtype == torch.bfloat16 + assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2 + assert a.size(1) == r2 and a.size(0) == d.size(0) and r1 == d.size(1) + assert m_indices.numel() == a.size(0) + assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + M, K = a.shape + B, N, K_ = r0, r1, r2 + + # For Triton 2.0, persistent kernel will lead to errors + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + m_grouped_bf16_gemm_contiguous_tl_impl[grid]( + a, b, d, m_indices, M, N, K, IS_B_K_MAJOR=b.is_contiguous()) + + +def m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + m_indices: torch.Tensor): + m_grouped_bf16_gemm_nt_contiguous_tl(a, b.mT, d, m_indices) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/legacy/tune_options.py b/build/torch210-cxx11-cu128-x86_64-linux/legacy/tune_options.py new file mode 100644 index 0000000000000000000000000000000000000000..ed6a7f77c05ccea324a0e99e12e1506cdea0a086 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/legacy/tune_options.py @@ -0,0 +1,28 @@ +from triton import Config +from .._C import get_mk_alignment_for_contiguous_layout + + +def get_config_smem_size(config: Config, elem_bytes: int = 2): + # NOTES: FP8 kernels will not use Triton, so by default we assume BF16 kernels + return (config.kwargs['BLOCK_SIZE_M'] + config.kwargs['BLOCK_SIZE_N']) * config.kwargs['BLOCK_SIZE_K'] * elem_bytes * config.num_stages + + +_gemm_configs = [ + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 
256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), + Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), +] + +# NOTES: we only consider A100 shared memory sizes here, as legacy kernels are only used for Ampere +_gemm_configs = list(filter(lambda x: get_config_smem_size(x) <= 166912, _gemm_configs)) +_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) +_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) + +get_m_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) +get_k_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/mega/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/mega/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..670b409dada5ef46b62324fb458a10b585ca0a01 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/mega/__init__.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import torch +from typing import Tuple, Optional +from ..utils.math import align + +# noinspection PyBroadException +try: + # noinspection PyProtectedMember + import torch.distributed._symmetric_memory as symm_mem + import torch.distributed as dist +except Exception as exception: + print(f'Failed to load mega kernels, please check your PyTorch version: {exception}') + +from .. import _C + + +class SymmBuffer: + def __init__(self, group: dist.ProcessGroup, + # MoE arguments + num_experts: int, + num_max_tokens_per_rank: int, num_topk: int, + hidden: int, intermediate_hidden: int, + use_fp8_dispatch: bool = True, + activation: str = 'swiglu'): + self.group = group + self.num_experts = num_experts + self.num_max_tokens_per_rank = num_max_tokens_per_rank + self.num_topk = num_topk + self.hidden = hidden + self.intermediate_hidden = intermediate_hidden + + # Allocate a symmetric buffer + num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe( + group.size(), num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + use_fp8_dispatch, activation + ) + self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda') + self.handle = symm_mem.rendezvous(self.buffer, group=group) + self.buffer.zero_() + self.group.barrier() + torch.cuda.synchronize() + + # Create input buffer views + (self.x, self.x_sf, + self.topk_idx, self.topk_weights, + self.l1_acts, self.l1_acts_sf, + self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer) + + def destroy(self): + self.handle = None + self.buffer = None + self.group = None + self.x = None + self.x_sf = None + + +def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup, + num_experts: int, + num_max_tokens_per_rank: int, num_topk: int, + hidden: int, 
intermediate_hidden: int, + use_fp8_dispatch: bool = True, + activation: str = 'swiglu') -> SymmBuffer: + # Token count must be aligned to block sizes + num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe()) + + return SymmBuffer( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + use_fp8_dispatch, activation + ) + + +def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + # [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up] + def interleave(t, gran: int = 8) -> torch.Tensor: + g, n, *rest = t.shape + half = n // 2 + gate = t[:, :half].reshape(g, half // gran, gran, *rest) + up = t[:, half:].reshape(g, half // gran, gran, *rest) + return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest)) + + return interleave(l1_weights[0]), interleave(l1_weights[1]) + + +def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor: + num_groups, mn, packed_sf_k = sf.shape + assert sf.dtype == torch.int and mn % 128 == 0 + result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k) + .transpose(2, 3) + .reshape(num_groups, mn, packed_sf_k)) + return torch.empty_like(sf).copy_(result) + + +def transform_weights_for_mega_moe( + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor] +) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + # L1: interleave gate/up, then transpose SF for UTCCP + l1_interleaved = _interleave_l1_weights(l1_weights) + l1_weights = (l1_interleaved[0], _transpose_sf_for_utccp(l1_interleaved[1])) + # L2: only transpose SF for UTCCP + l2_weights = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1])) + return l1_weights, l2_weights + + +def fp8_fp4_mega_moe(y: torch.Tensor, + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor], + sym_buffer: SymmBuffer, + 
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, + recipe: Tuple[int, int, int] = (1, 1, 32), + activation: str = 'swiglu', + activation_clamp: Optional[float] = None, + fast_math: bool = True): + _C.fp8_fp4_mega_moe( + y, + l1_weights, l2_weights, + cumulative_local_expert_recv_stats, + sym_buffer.buffer, + sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), + sym_buffer.num_max_tokens_per_rank, + sym_buffer.num_experts, sym_buffer.num_topk, + recipe, + activation, activation_clamp, + fast_math + ) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json index 4899badb63d45293425e2164944268b6058af95d..843aad1b6073c0237b3a8e4e8a99029dadeceb52 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json +++ b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json @@ -1,5 +1,7 @@ { - "version": 1, + "name": "deep-gemm", + "id": "_deep_gemm_cuda_388adb9", + "version": 2, "license": "MIT", "python-depends": [], "backend": { diff --git a/build/torch210-cxx11-cu128-x86_64-linux/testing/bench.py b/build/torch210-cxx11-cu128-x86_64-linux/testing/bench.py index 2c752da2d3bb0aba7e03ef1921428432b396917a..552b9aa18a037a14d0869fac3527e34bac6d7760 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/testing/bench.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/testing/bench.py @@ -1,6 +1,7 @@ import os import sys import torch +from typing import Callable, Optional def bench(fn, num_warmups: int = 5, num_tests: int = 10, @@ -78,7 +79,8 @@ class suppress_stdout_stderr: def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, trace_path: str = None, flush_l2: bool = True, - with_multiple_kernels: bool = False): + with_multiple_kernels: bool = False, + barrier: Optional[Callable] = None): assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) is_tuple = isinstance(kernel_names, tuple) @@ -96,14 +98,21 @@ def bench_kineto(fn, 
kernel_names, num_tests: int = 30, # Profile suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress with suppress(): - schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) - profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule, acc_events=True) with profiler: for i in range(2): for _ in range(num_tests): if flush_l2: torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + if barrier is not None: + # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead + # noinspection PyProtectedMember + torch.cuda._sleep(int(2e7)) # ~10ms + barrier() fn() + torch.cuda.synchronize() profiler.step() # Parse the profiling table @@ -111,7 +120,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names if not with_multiple_kernels: for name in kernel_names: - assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table' + assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table {prof_lines}' # Save chrome traces if trace_path is not None: diff --git a/build/torch210-cxx11-cu128-x86_64-linux/utils/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/utils/__init__.py index e8f859a20726fcc0ea32c54ed8df37b19b3960a4..a0dc6f783bcd24c7be4d6afaec9fe5d12a6847d0 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/utils/__init__.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/utils/__init__.py @@ -1,3 +1,4 @@ from . 
import math, layout from .layout import * from .math import * +from .dist import init_dist, uneven_all_gather diff --git a/build/torch210-cxx11-cu128-x86_64-linux/utils/dist.py b/build/torch210-cxx11-cu128-x86_64-linux/utils/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..426c39676f2f4374fa6e6c646cbcc0ca8b5a7b88 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/utils/dist.py @@ -0,0 +1,74 @@ +import inspect +import os +import torch +import torch.distributed as dist +from typing import Tuple + +_local_rank = None + + +def init_dist(local_rank: int, num_local_ranks: int) -> Tuple[int, int, dist.ProcessGroup]: + # NOTES: you may rewrite this function with your own cluster settings + ip = os.getenv('MASTER_ADDR', '127.0.0.1') + port = int(os.getenv('MASTER_PORT', '8361')) + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + node_rank = int(os.getenv('RANK', 0)) + + # Set local rank + global _local_rank + _local_rank = local_rank + + sig = inspect.signature(dist.init_process_group) + params = { + 'backend': 'nccl', + 'init_method': f'tcp://{ip}:{port}', + 'world_size': num_nodes * num_local_ranks, + 'rank': node_rank * num_local_ranks + local_rank, + } + if 'device_id' in sig.parameters: + # noinspection PyTypeChecker + params['device_id'] = torch.device(f'cuda:{local_rank}') + dist.init_process_group(**params) + torch.set_default_device('cuda') + torch.cuda.set_device(local_rank) + + return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes))) + + +def uneven_all_gather(tensor: torch.Tensor, dim: int = 0, group: dist.ProcessGroup = None) -> torch.Tensor: + world_size = dist.get_world_size(group) + + # Exchange sizes + local_dim_size = torch.tensor([tensor.shape[dim]], device=tensor.device, dtype=torch.long) + all_dim_sizes = [torch.zeros_like(local_dim_size) for _ in range(world_size)] + dist.all_gather(all_dim_sizes, local_dim_size, group=group) + all_dim_sizes = [s.item() for s in 
all_dim_sizes] + max_dim_size = max(all_dim_sizes) + + # Pad + if tensor.shape[dim] < max_dim_size: + pad_shape = list(tensor.shape) + pad_shape[dim] = max_dim_size - tensor.shape[dim] + padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + tensor_padded = torch.cat([tensor, padding], dim=dim) + else: + tensor_padded = tensor.contiguous() + + # All-gather + gathered = [torch.zeros_like(tensor_padded) for _ in range(world_size)] + dist.all_gather(gathered, tensor_padded, group=group) + + # Remove padding + trimmed = [ + torch.narrow(gathered[i], dim, 0, all_dim_sizes[i]) + for i in range(world_size) + ] + return torch.cat(trimmed, dim=dim) + + +def dist_print(s: str = '', once_in_node: bool = False) -> None: + global _local_rank + assert _local_rank is not None + if not once_in_node or _local_rank == 0: + print(s, flush=True) + dist.barrier() diff --git a/build/torch210-cxx11-cu128-x86_64-linux/utils/layout.py b/build/torch210-cxx11-cu128-x86_64-linux/utils/layout.py index a6bc29d9aaae296a83b8c3546b832a083ade6b28..6512c5ab7aee2bb07ca8324b7c6e49c420bd9df9 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/utils/layout.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/utils/layout.py @@ -1,25 +1,21 @@ -from .._ops import ops - - -def get_mk_alignment_for_contiguous_layout(): - return ops.get_mk_alignment_for_contiguous_layout() - - -def get_tma_aligned_size(mn: int, element_size: int): - return ops.get_tma_aligned_size(mn, element_size).item() - - -def get_mn_major_tma_aligned_tensor(sf): - return ops.get_mn_major_tma_aligned_tensor(sf) - - -def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): - return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) - - -def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): - return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks) - - +try: + from .._C import ( + get_tma_aligned_size, + get_mn_major_tma_aligned_tensor, + 
get_mn_major_tma_aligned_packed_ue8m0_tensor, + get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor + ) +except ImportError: + # Expected behavior for CUDA runtime version before 12.1 + pass + +# Valid for all CUDA versions +from .._C import ( + set_mk_alignment_for_contiguous_layout, + get_mk_alignment_for_contiguous_layout, + get_theoretical_mk_alignment_for_contiguous_layout, +) + +# Some alias get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/build/torch210-cxx11-cu128-x86_64-linux/utils/math.py b/build/torch210-cxx11-cu128-x86_64-linux/utils/math.py index c65026e54b87faf34b498d14d3f81a94759615f4..f1582ed560344e18980054bf502083fd641c1437 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/utils/math.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/utils/math.py @@ -11,21 +11,30 @@ def align(x: int, y: int) -> int: def ceil_to_ue8m0(x: torch.Tensor): - assert x.view(-1).amax().item() > 0 - return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + bits = x.abs().float().view(torch.int) + exp = ((bits >> 23) & 0xFF) + (bits & 0x7FFFFF).bool().int() + return (exp.clamp(1, 254) << 23).view(torch.float) -def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: +def pack_ue8m0_to_int(x: torch.Tensor): + assert x.dtype == torch.float and x.size(-1) % 4 == 0 + assert (x.view(torch.int) & ((1 << 23) - 1) == 0).all() + return (x.view(torch.int) >> 23).to(torch.uint8).view(torch.int) + + +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape padded_n = align(n, gran_k) x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) x_padded[:, :n] = x - x_view = x_padded.view(m, -1, gran_k) - x_amax = x_view.abs().float().amax(dim=2).view(m, 
-1).clamp(1e-4) + x_view = x_padded.view(m, padded_n // gran_k, gran_k) + x_amax = x_view.abs().float().amax(dim=2).view(m, padded_n // gran_k).clamp(1e-4) sf = x_amax / 448.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf - return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf + x_fp8 = (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous() + return x_fp8, pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: @@ -70,13 +79,14 @@ def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: code = idx.to(torch.uint8) sign = (x < 0) & (idx != 0) code = code | (sign.to(torch.uint8) << 3) - return code # uint8, 0..15 + return code.view(torch.int8) -def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 +def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: m, n = x.shape assert n % 2 == 0 + assert not use_packed_ue8m0 or use_ue8m0 padded_n = align(n, gran_k) x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) x_padded[:, :n] = x @@ -85,23 +95,49 @@ def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) - sf = x_amax / 6.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = x_view * (1.0 / sf.unsqueeze(2)) - codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n) + codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # int8, (m, padded_n) codes2 = codes.view(m, padded_n // 2, 2) - packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8 - return packed[:, :n // 2].contiguous(), sf + packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # int8 + return packed[:, :n // 2].contiguous(), 
pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: - assert a.dtype == torch.uint8 + assert a.dtype == torch.int8 assert a.dim() == 2 m, n2 = a.shape n = n2 * 2 assert (m % 2) == 0 lo = a & 0x0F hi = (a >> 4) & 0x0F - codes = torch.empty((m, n), device=a.device, dtype=torch.uint8) + codes = torch.empty((m, n), device=a.device, dtype=torch.int8) codes[:, 0::2], codes[:, 1::2] = lo, hi codes_t = codes.transpose(0, 1).contiguous() codes2 = codes_t.view(n, m // 2, 2) out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) - return out.contiguous() \ No newline at end of file + return out.contiguous() + + +def _dequantize_from_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + fp4_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device, dtype=torch.float) + sign, value_idx = (x & 0x08) != 0, (x & 0x07).to(torch.int) + value = fp4_values[value_idx] + return torch.where(sign & (value_idx != 0), -value, value) + + +def unpack_ue8m0_from_int(packed_sf: torch.Tensor) -> torch.Tensor: + return (packed_sf.view(torch.uint8).to(torch.int) << 23).view(torch.float) + + +def cast_back_from_fp4(packed: torch.Tensor, sf: torch.Tensor, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> torch.Tensor: + m, n2 = packed.shape + n = n2 * 2 + if use_packed_ue8m0: + sf = unpack_ue8m0_from_int(sf) + unpacked = torch.zeros((m, n), dtype=torch.int8, device=packed.device) + unpacked[:, ::2] = packed & 0x0F + unpacked[:, 1::2] = (packed >> 4) & 0x0F + x_dequantized = _dequantize_from_fp4_e2m1(unpacked) + group_idx = torch.arange(n, device=packed.device) // gran_k + x_restored = x_dequantized * sf[:, group_idx] + return x_restored \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_C.py b/build/torch210-cxx11-cu130-x86_64-linux/_C.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2fd6df85149f4ecf481d67fcd12b52a929a7a4 --- /dev/null +++ 
b/build/torch210-cxx11-cu130-x86_64-linux/_C.py @@ -0,0 +1,194 @@ +import torch + +from ._ops import ops + + +def set_num_sms(num_sms: int): + ops.set_num_sms(num_sms) + + +def get_num_sms() -> int: + return ops.get_num_sms() + + +def set_tc_util(tc_util: int): + ops.set_tc_util(tc_util) + + +def get_tc_util() -> int: + return ops.get_tc_util() + + +def set_ignore_compile_dims(value: bool): + ops.set_ignore_compile_dims(value) + + +def set_block_size_multiple_of(value): + if isinstance(value, tuple): + block_m, block_n = value + else: + block_m = block_n = value + ops.set_block_size_multiple_of(block_m, block_n) + + +def set_pdl(enable_pdl: bool): + ops.set_pdl(enable_pdl) + + +def get_pdl() -> bool: + return ops.get_pdl() + + +def set_mk_alignment_for_contiguous_layout(value: int): + ops.set_mk_alignment_for_contiguous_layout(value) + + +def get_mk_alignment_for_contiguous_layout() -> int: + return ops.get_mk_alignment_for_contiguous_layout() + + +def get_theoretical_mk_alignment_for_contiguous_layout(expected_m=None) -> int: + return ops.get_theoretical_mk_alignment_for_contiguous_layout( + 0 if expected_m is None else expected_m, + expected_m is not None, + ) + + +def get_tma_aligned_size(mn: int, element_size: int) -> int: + return ops.get_tma_aligned_size(mn, element_size).item() + + +def get_mn_major_tma_aligned_tensor(sf): + return ops.get_mn_major_tma_aligned_tensor(sf) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): + return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + + +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks, gran_k +): + ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu") + return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks_int, gran_k + ) + + +def transform_sf_into_required_layout( + sf, + mn, + k, + recipe, + num_groups=None, + is_sfa=None, + disable_ue8m0_cast=False, +): + if len(recipe) == 3: + r0, r1, r2 = recipe + recipe_len = 3 + elif 
len(recipe) == 2: + r0, r1 = recipe + r2 = 0 + recipe_len = 2 + else: + raise ValueError("recipe must have length 2 or 3") + + return ops.transform_sf_into_required_layout( + sf, + mn, + k, + r0, + r1, + r2, + recipe_len, + 0 if num_groups is None else num_groups, + num_groups is not None, + False if is_sfa is None else is_sfa, + is_sfa is not None, + disable_ue8m0_cast, + ) + + +def get_token_alignment_for_mega_moe() -> int: + return ops.get_token_alignment_for_mega_moe() + + +def get_symm_buffer_size_for_mega_moe( + num_ranks, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + use_fp8_dispatch=True, + activation="swiglu", +): + num_bytes = ops.get_symm_buffer_size_for_mega_moe( + num_ranks, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + use_fp8_dispatch, + activation, + ) + + def slice_input_buffers(buffer): + return tuple( + ops.get_symm_buffer_views_for_mega_moe( + buffer, + num_ranks, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + use_fp8_dispatch, + activation, + ) + ) + + return num_bytes, slice_input_buffers + + +def fp8_fp4_mega_moe( + y, + l1_weights, + l2_weights, + cumulative_local_expert_recv_stats, + sym_buffer, + sym_buffer_ptrs, + rank_idx, + num_max_tokens_per_rank, + num_experts, + num_topk, + recipe, + activation, + activation_clamp, + fast_math, +): + l1_weights_data, l1_weights_sf = l1_weights + l2_weights_data, l2_weights_sf = l2_weights + r0, r1, r2 = recipe + ops.fp8_fp4_mega_moe( + y, + l1_weights_data, + l1_weights_sf, + l2_weights_data, + l2_weights_sf, + cumulative_local_expert_recv_stats, + sym_buffer, + sym_buffer_ptrs, + rank_idx, + num_max_tokens_per_rank, + num_experts, + num_topk, + r0, + r1, + r2, + activation, + activation_clamp, + fast_math, + ) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py index 
8f0a7f80daf98c3979512b6fb75258a0f4cefdc5..8c4fe1c51ce5c419fc1b9db3b9f7e3ca03258c28 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py @@ -1,12 +1,18 @@ import os import subprocess +import sysconfig import torch +# Avoid holding a CUDA tensor in DeepGEMM's process-lifetime runtime singleton. +# In packaged/lazy-loaded use, that can outlive PyTorch's CUDA teardown and crash +# during interpreter shutdown. +os.environ.setdefault("DG_USE_TEMP_CUBLASLT_WORKSPACE", "1") + # Import the compiled extension -from ._ops import ops, add_op_namespace_prefix +from ._ops import ops as _ops, add_op_namespace_prefix from . import utils -__version__ = "2.3.0" +__version__ = "2.5.0" # ── Register fake tensor implementations for torch.compile ────────────────── @@ -32,6 +38,7 @@ for _op in [ "m_grouped_bf16_gemm_nn_contiguous", "m_grouped_bf16_gemm_nt_masked", "fp8_gemm_nt_skip_head_mid", + "fp8_fp4_mega_moe", ]: @torch.library.register_fake(add_op_namespace_prefix(_op)) @@ -58,10 +65,41 @@ def get_tc_util() -> int: return ops.get_tc_util() +def set_ignore_compile_dims(value: bool): + ops.set_ignore_compile_dims(value) + + +def set_block_size_multiple_of(value): + if isinstance(value, tuple): + block_m, block_n = value + else: + block_m = block_n = value + ops.set_block_size_multiple_of(block_m, block_n) + + +def set_pdl(enable_pdl: bool): + ops.set_pdl(enable_pdl) + + +def get_pdl() -> bool: + return ops.get_pdl() + + +def set_mk_alignment_for_contiguous_layout(alignment: int): + ops.set_mk_alignment_for_contiguous_layout(alignment) + + def get_mk_alignment_for_contiguous_layout() -> int: return ops.get_mk_alignment_for_contiguous_layout() +def get_theoretical_mk_alignment_for_contiguous_layout(expected_m=None) -> int: + return ops.get_theoretical_mk_alignment_for_contiguous_layout( + 0 if expected_m is None else expected_m, + expected_m is not None, + ) + + # Layout utilities @@ -77,10 +115,12 @@ def 
get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) -def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks, gran_k +): ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu") return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( - sf, ks_tensor, ks_int + sf, ks_tensor, ks_int, gran_k ) @@ -88,16 +128,20 @@ def transform_sf_into_required_layout( sf, mn, k, - recipe=None, - recipe_ab=None, + recipe, num_groups=None, - is_sfa=False, + is_sfa=None, disable_ue8m0_cast=False, ): - has_recipe = recipe is not None - r0, r1, r2 = recipe if has_recipe else (0, 0, 0) - has_recipe_ab = recipe_ab is not None - rab0, rab1 = recipe_ab if has_recipe_ab else (0, 0) + if len(recipe) == 3: + r0, r1, r2 = recipe + recipe_len = 3 + elif len(recipe) == 2: + r0, r1 = recipe + r2 = 0 + recipe_len = 2 + else: + raise ValueError("recipe must have length 2 or 3") has_ng = num_groups is not None ng = num_groups if has_ng else 0 return ops.transform_sf_into_required_layout( @@ -107,13 +151,11 @@ def transform_sf_into_required_layout( r0, r1, r2, - has_recipe, - rab0, - rab1, - has_recipe_ab, + recipe_len, ng, has_ng, - is_sfa, + False if is_sfa is None else is_sfa, + is_sfa is not None, disable_ue8m0_cast, ) @@ -593,8 +635,37 @@ def fp8_mqa_logits( ) -def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms): - return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms) +def fp8_fp4_mqa_logits( + q, + kv, + weights, + cu_seq_len_k_start, + cu_seq_len_k_end, + clean_logits=True, + max_seqlen_k=0, + logits_dtype=torch.float32, +): + if isinstance(q, tuple): + q_data, q_sf = q + else: + q_data, q_sf = q, None + kv_data, kv_sf = kv + return ops.fp8_fp4_mqa_logits( + q_data, + q_sf, + kv_data, + kv_sf, + weights, + cu_seq_len_k_start, + cu_seq_len_k_end, + clean_logits, + max_seqlen_k, + 
logits_dtype, + ) + + +def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices=None): + return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices) def fp8_paged_mqa_logits( @@ -606,6 +677,7 @@ def fp8_paged_mqa_logits( schedule_meta, max_context_len, clean_logits=False, + indices=None, ): return ops.fp8_paged_mqa_logits( q, @@ -616,6 +688,38 @@ def fp8_paged_mqa_logits( schedule_meta, max_context_len, clean_logits, + indices, + ) + + +def fp8_fp4_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits=False, + logits_dtype=torch.float32, + indices=None, +): + if isinstance(q, tuple): + q_data, q_sf = q + else: + q_data, q_sf = q, None + return ops.fp8_fp4_paged_mqa_logits( + q_data, + q_sf, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits, + logits_dtype, + indices, ) @@ -642,6 +746,14 @@ def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None): ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns, has_ns) +from .mega import ( + SymmBuffer, + get_symm_buffer_for_mega_moe, + transform_weights_for_mega_moe, + fp8_fp4_mega_moe, +) + + # Initialize the C++ runtime @@ -683,6 +795,14 @@ if "DG_CUTLASS_INCLUDE" not in os.environ: _include, # legacy layout: include/cutlass os.path.join(_include, "third-party", "cutlass", "include"), # submodule layout ] + for _site_packages in { + sysconfig.get_paths().get("purelib"), + sysconfig.get_paths().get("platlib"), + }: + if _site_packages: + _cutlass_include_candidates.append( + os.path.join(_site_packages, "cutlass_library", "source", "include") + ) for _cutlass_include in _cutlass_include_candidates: if os.path.isdir(os.path.join(_cutlass_include, "cutlass")): os.environ["DG_CUTLASS_INCLUDE"] = _cutlass_include @@ -703,8 +823,21 @@ def _ensure_initialized(): global _initialized if _initialized: return + _ops.init(_lib_root, _find_cuda_home()) _initialized = True - 
ops.init(_lib_root, _find_cuda_home()) + + +class _InitializedOps: + def __init__(self, raw_ops): + self._raw_ops = raw_ops + + def __getattr__(self, name): + if name != "init": + _ensure_initialized() + return getattr(self._raw_ops, name) + + +ops = _InitializedOps(_ops) # Try to initialize eagerly, but don't fail if CUDA is not found diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..3c8e7bea1b8e5eb3ad137f04b60a8f51f3370d82 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec220d340cd32423ffdca60d3ba9e1cf8196e1cee096bad64dd7726824d79898 +size 3461568 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py index 65e09b4e92d96545922fbce68acd103c33cd3845..d017d96b9d37776819ba7ab2e5d291158427f1a8 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _deep_gemm_cuda_8546a43 -ops = torch.ops._deep_gemm_cuda_8546a43 +from . import _deep_gemm_cuda_388adb9 +ops = torch.ops._deep_gemm_cuda_388adb9 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_deep_gemm_cuda_8546a43::{op_name}" + return f"_deep_gemm_cuda_388adb9::{op_name}" diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/comm/barrier.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/comm/barrier.cuh new file mode 100644 index 0000000000000000000000000000000000000000..eb9858d8010db9088ae09ead48e6222a40f91075 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/comm/barrier.cuh @@ -0,0 +1,83 @@ +#pragma once + +#include + +#include +#include +#include + +namespace deep_gemm::comm { + +CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() { + // Perform cluster_sync with `barrier.cluster.arrive.relaxed` + // This is slightly faster than `cute::cluster_sync` but has weaker memory ordering guarantee + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); +} + +template +CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace, + const uint32_t& sm_idx, const uint32_t& thread_idx, + const sync_scope_t& sync_scope) { + // NOTES: the implementation idea is from `cooperative_groups::this_grid().sync()` + static constexpr uint32_t kFinishSumTag = 0x80000000u; + sync_scope(); + if (thread_idx == 0) { + const auto count_ptr = workspace.get_grid_sync_count_ptr(); + const auto old_value = ptx::atomic_add_rel( + count_ptr, sm_idx == 0 ? 
(kFinishSumTag - (kNumSMs - 1)) : 1); + uint32_t new_value; + do { + new_value = ptx::ld_acq(count_ptr); + } while (((new_value ^ old_value) & kFinishSumTag) == 0); + } + sync_scope(); +} + +template +CUTLASS_DEVICE void nvlink_barrier(const layout::Workspace& workspace, + const layout::SymBuffer& sym_buffer, + const uint32_t& sm_idx, const uint32_t& thread_idx, + const sync_scope_t& sync_scope, + const bool& sync_prologue = true, + const bool& sync_epilogue = true) { + DG_STATIC_ASSERT(kNumRanks <= kNumThreads, "Insufficient threads"); + + // Grid sync before NVLink signaling + if (sync_prologue) + grid_sync(workspace, sm_idx, thread_idx, sync_scope); + + // NVLink cross-rank barrier, only SM 0 participates + if (sm_idx == 0) { + auto* counter_ptr = workspace.get_nvl_barrier_counter_ptr(); + const auto status = (*counter_ptr) & 3; + const auto signal_phase = status & 1, signal_sign = status >> 1; + auto* signal_ptr = workspace.get_nvl_barrier_signal_ptr(signal_phase); + + // Send signals to remote ranks + if (thread_idx < kNumRanks) + ptx::red_add_rel_sys(sym_buffer.map(signal_ptr, thread_idx), signal_sign ? -1 : 1); + sync_scope(); + + // Update status and wait arrival (with 30s timeout, at 2 GHz) + constexpr int64_t kNumTimeoutCycles = 30ll * 2000000000ll; + if (thread_idx == 0) { + ptx::red_add(counter_ptr, 1); + const int target = signal_sign ? 
0 : static_cast(kNumRanks); + const auto start_clock = clock64(); + while (ptx::ld_acq_sys(signal_ptr) != target) { + if (clock64() - start_clock >= kNumTimeoutCycles) { + printf("DeepGEMM NVLink barrier timeout (30s): rank=%d, counter=%d, signal=%d, target=%d, phase=%d, sign=%d, tag=%d\n", + sym_buffer.rank_idx, *counter_ptr, ptx::ld_acq_sys(signal_ptr), target, signal_phase, signal_sign, kTag); + DG_DEVICE_ASSERT(false and "NVLink barrier timeout"); + } + } + } + } + + // Grid sync after NVLink completion + if (sync_epilogue) + grid_sync(workspace, sm_idx, thread_idx, sync_scope); +} + +} // namespace deep_gemm::comm diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/compile.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/compile.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e93c43fb77049ef91ca34490657db28bc132783b --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/compile.cuh @@ -0,0 +1,18 @@ +#pragma once + +#include + +#if defined(__NVCC__) or (defined(__clang__) and defined(__CUDA__)) or defined(__CUDACC_RTC__) or defined(__CLION_IDE__) +#define DG_IN_CUDA_COMPILATION +#endif + +#if defined(__NVCC__) || (defined(__clang__) and defined(__CUDA__)) +#define CUTLASS_HOST_DEVICE_NOINLINE __device__ __host__ +#define CUTLASS_DEVICE_NOINLINE __device__ +#elif defined(__CUDACC_RTC__) +#define CUTLASS_HOST_DEVICE_NOINLINE __device__ +#define CUTLASS_DEVICE_NOINLINE __device__ +#else +#define CUTLASS_HOST_DEVICE_NOINLINE +#define CUTLASS_DEVICE_NOINLINE +#endif diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/cute_tie.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/cute_tie.cuh index cd2aace7a8b8dd642f4c149bfc974c3d21e5f5b5..a3a8b62a2823835d14fbbfc26dd603680f2c5a02 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/cute_tie.cuh +++ 
b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/cute_tie.cuh @@ -1,5 +1,7 @@ #pragma once +#include + namespace cute { struct ignore_t { diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/exception.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/exception.cuh new file mode 100644 index 0000000000000000000000000000000000000000..78acf74755f9f1293b50198fbd74d96873354bc3 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/exception.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include +#include + +#ifdef __CLION_IDE__ + +CUTLASS_HOST_DEVICE void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + +#define printf host_device_printf +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif + +#ifndef DG_TRAP_ONLY_DEVICE_ASSERT +#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) 
static_assert(cond, __VA_ARGS__) +#endif + +#ifndef DG_UNIFIED_ASSERT +#ifdef DG_IN_CUDA_COMPILATION +#define DG_UNIFIED_ASSERT(cond) DG_DEVICE_ASSERT(cond) +#else +#define DG_UNIFIED_ASSERT(cond) DG_HOST_ASSERT(cond) +#endif +#endif diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/math.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/math.cuh new file mode 100644 index 0000000000000000000000000000000000000000..03bee8f91cf10cd39dadebe8dc6cc2334baed65d --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/math.cuh @@ -0,0 +1,153 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm::math { + +/// Pointer operations +template +CUTLASS_HOST_DEVICE dtype_t* advance_ptr(void* ptr, const uint64_t num_bytes) { + return reinterpret_cast(static_cast(ptr) + num_bytes); +} + +/// Math functions +template +CUTLASS_HOST_DEVICE T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +CUTLASS_HOST_DEVICE T align(T a, T b) { + return (kDoCeilAlignment ? ceil_div(a, b) : (a / b)) * b; +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_align(T a, T b) { + return constexpr_ceil_div(a, b) * b; +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_gcd(T a, T b) { + return b == 0 ? a : constexpr_gcd(b, a % b); +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_min(T a, T b) { + return a < b ? 
a : b; +} + +template +CUTLASS_DEVICE void swap(T& a, T& b) { + T temp = a; + a = b; + b = temp; +} + +#ifdef DG_IN_CUDA_COMPILATION +CUTLASS_DEVICE float2 fma2(const float2& a, const float2& b, const float2& c) { +#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000) + return __ffma2_rn(a, b, c); +#else + return make_float2( + __fmaf_rn(a.x, b.x, c.x), + __fmaf_rn(a.y, b.y, c.y) + ); +#endif +} + +CUTLASS_HOST_DEVICE float fast_rcp(const float& x) { +#if defined(__CUDA_ARCH__) + float ret; + asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(ret) : "f"(x)); + return ret; +#else + return 1.0f / x; +#endif +} + +/// Casting +template +CUTLASS_DEVICE int cast_into_bf16_and_pack(old_t& x, old_t& y) { + auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); + return *reinterpret_cast(&bf16x2); +} + +CUTLASS_DEVICE float fast_pow2(const int& x) { + uint32_t bits_x = (x + 127) << 23; + return *reinterpret_cast(&bits_x); +} + +CUTLASS_DEVICE int fast_log2_ceil(float x) { + const auto bits = *reinterpret_cast(&x); + const auto exp = bits >> 23; + const auto man = bits & ((1 << 23) - 1); + return exp - 127 + (man != 0); +} + +template +CUTLASS_DEVICE void get_e4m3_sf_and_sf_inv(const float2& amax, float2& sf, float2& sf_inv) { + DG_STATIC_ASSERT(kUseUE8M0, "Must use UE8M0"); + const float2 finfo_factor = {1.0 / 448.0, 1.0 / 448.0}; + const auto scaled = __fmul2_rn(amax, finfo_factor); + const auto exp_x = fast_log2_ceil(scaled.x); + const auto exp_y = fast_log2_ceil(scaled.y); + sf.x = fast_pow2(exp_x), sf_inv.x = fast_pow2(-exp_x); + sf.y = fast_pow2(exp_y), sf_inv.y = fast_pow2(-exp_y); +} + +/// Reduction +CUTLASS_DEVICE uint32_t warp_inclusive_sum(uint32_t value, const uint32_t& lane_idx) { + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t synced = __shfl_up_sync(0xffffffff, value, offset); + if (lane_idx >= offset) + value += synced; + } + return value; +} + +// Operation functors +template struct 
ReduceSum { CUTLASS_DEVICE T operator()(T a, T b) const { return a + b; } }; +template struct ReduceMax { CUTLASS_DEVICE T operator()(T a, T b) const { return a > b ? a : b; } }; +template struct ReduceMin { CUTLASS_DEVICE T operator()(T a, T b) const { return a < b ? a : b; } }; +template struct ReduceAnd { CUTLASS_DEVICE T operator()(T a, T b) const { return a & b; } }; +template struct ReduceOr { CUTLASS_DEVICE T operator()(T a, T b) const { return a | b; } }; + +// Unified reduction function +template +CUTLASS_DEVICE T warp_reduce(T value, Op op) { + DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or + kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1, + "Invalid number of lanes"); + constexpr uint32_t mask = 0xffffffff; + if constexpr (kIntergroupReduce) { + if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1)); + if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16)); + } else { + if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16)); + if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1)); + } + return value; +} + +// Convenience aliases +template +CUTLASS_DEVICE T warp_reduce_sum(T value) { + return warp_reduce(value, ReduceSum{}); +} +#endif + +} // namespace deep_gemm diff --git 
a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/tma_copy.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/tma_copy.cuh new file mode 100644 index 0000000000000000000000000000000000000000..2c5bf708d49737b8912c991d856fa9d4ceb5b5d0 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/tma_copy.cuh @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include + +#include + +namespace deep_gemm::tma { + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +CUTLASS_DEVICE void +copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, + dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx, + const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) { + DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); + constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); + + if constexpr (not kIs3DTMA) { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if 
(cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } + #endif + } + } else { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } + #endif + } + } +} + +} // namespace deep_gemm::tma diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/types.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/types.cuh new file mode 100644 index 
0000000000000000000000000000000000000000..e07df0af8a95a2ae0c6f32493adaa5ec00c09633 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/types.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include + +namespace deep_gemm { + +enum class MmaKind { + BF16 = 0, + MXFP8FP4 = 1, +}; + +constexpr CUTLASS_HOST_DEVICE int get_element_size(const MmaKind& mma_kind) { + switch (mma_kind) { + case MmaKind::BF16: return 2; + case MmaKind::MXFP8FP4: return 1; + default: return 0; + } +} + +enum class GemmType { + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, + Batched = 4, + MGroupedContiguousWithPsumLayout = 5, +}; + +constexpr CUTLASS_HOST_DEVICE bool is_m_grouped_contiguous(const GemmType& gemm_type) { + switch (gemm_type) { + case GemmType::MGroupedContiguous: return true; + case GemmType::MGroupedContiguousWithPsumLayout: return true; + default: return false; + } +} + +enum class KernelType { + Kernel1D1D = 0, + Kernel1D2D = 1, + KernelNoSF = 2 +}; + +} // namespace deep_gemm diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/utils.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/utils.cuh index 8fb6c2fc53b6d1eb067d13c113462a9f7de4133a..3a5f7ad668878aced913e859780b39ce2c06d3e8 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/utils.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/common/utils.cuh @@ -1,167 +1,24 @@ #pragma once -#include -#include #include -#include -#include -#include "cute_tie.cuh" +#include -#ifdef __CLION_IDE__ - -__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) 
{ - asm volatile("trap;"); -} - -#define printf host_device_printf -#endif - -#ifndef DG_DEVICE_ASSERT -#define DG_DEVICE_ASSERT(cond) \ -do { \ - if (not (cond)) { \ - printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ - asm("trap;"); \ - } \ -} while (0) -#endif - -#ifndef DG_TRAP_ONLY_DEVICE_ASSERT -#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ -do { \ - if (not (cond)) \ - asm("trap;"); \ -} while (0) -#endif - -#ifndef DG_STATIC_ASSERT -#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) -#endif - -namespace deep_gemm { +namespace deep_gemm::utils { template struct PatternVisitor { FuncT func; - __device__ __host__ + CUTLASS_HOST_DEVICE explicit PatternVisitor(FuncT&& func): func(std::forward(func)) {} - __device__ __host__ - auto operator [](const uint32_t& i) { + CUTLASS_HOST_DEVICE + auto operator [](const uint32_t& i) const { return func(i); } }; -template -__device__ __host__ T ceil_div(T a, T b) { - return (a + b - 1) / b; -} - -template -__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) { - return (a + b - 1) / b; -} - -template -__device__ __host__ T align(T a, T b) { - return ceil_div(a, b) * b; -} - -template -__device__ __host__ constexpr T constexpr_align(T a, T b) { - return constexpr_ceil_div(a, b) * b; -} - -template -__device__ __host__ constexpr T constexpr_gcd(T a, T b) { - return b == 0 ? 
a : constexpr_gcd(b, a % b); -} - -template -__forceinline__ __device__ void swap(T& a, T& b) { - T temp = a; - a = b; - b = temp; -} - -__forceinline__ __device__ uint32_t get_sm_idx() { - uint32_t sm_idx; - asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); - return sm_idx; -} - -__forceinline__ __device__ uint32_t get_lane_idx() { - uint32_t lane_id; - asm ("mov.u32 %0, %laneid;" : "=r"(lane_id)); - return lane_id; -} - -__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) { - uint32_t ret; - asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ float2 ld_shared(const float2* ptr) { - float2 ret; - asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ float4 ld_shared(const float4* ptr) { - float4 ret; - asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) { - uint4 ret; - asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ float ld_shared(const float* ptr) { - float ret; - asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ void st_shared(const float* ptr, float val) { - asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); -} - -__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { - asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); -} - -__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { - asm volatile("st.shared.u32 [%0], 
%1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); -} - -__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) { - asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); -} - -__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { - asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); -} - -__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) { - asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); -} - -template -__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) { - auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); - return *reinterpret_cast(&bf16x2); -} - -__device__ __forceinline__ void prefetch_l1(void *ptr) { - asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); -} - template struct Vectorized { static auto zeros() { @@ -180,4 +37,14 @@ struct Vectorized { using vec_t = decltype(zeros()); }; -} // namespace `deep_gemm` +template +CUTLASS_DEVICE constexpr uint32_t get_num_aligned_tmem_cols() { + DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns"); + if constexpr (kNumCols <= 32) return 32; + if constexpr (kNumCols <= 64) return 64; + if constexpr (kNumCols <= 128) return 128; + if constexpr (kNumCols <= 256) return 256; + return 512; +} + +} // namespace deep_gemm::utils diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bf0e460c8f636117969d81e21c00e0d2a2586d78 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd.cuh @@ -0,0 +1,137 @@ +#pragma once + +#include + 
+#include +#include +#include +#include +#include + +namespace deep_gemm::epilogue { + +template +CUTLASS_DEVICE void +sm100_store_cd(const utils::PatternVisitor& smem_cd, uint32_t& tma_stage_idx, + const uint32_t& tmem_base_addr, + const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx, + const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx, + const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier, + const cute::TmaDescriptor& tensor_map_cd) { + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Share store pipeline between blocks + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Iterate over M waves + constexpr auto kNumMWaves = BLOCK_M / STORE_BLOCK_M; + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + auto smem_base_ptr = reinterpret_cast(smem_cd[tma_stage_idx]); + + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = base_m_idx + w * STORE_BLOCK_M; + const auto n_idx = epilogue_type_t::apply_index_n(base_n_idx + s * STORE_BLOCK_N); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank 
group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = tmem_base_addr + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = smem_base_ptr + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + ptx::st_shared( + smem_ptr, + math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + 
math::cast_into_bf16_and_pack(values[6], values[7]) + ); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barrier->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx, batch_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +} + +} // namespace deep_gemm::epilogue diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f3f5351e6ac6cb0526bb6d2ca8abd5a99ebe45df --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh @@ -0,0 +1,144 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace deep_gemm::epilogue { + +template +CUTLASS_DEVICE void +sm100_store_cd_swap_ab(const utils::PatternVisitor& smem_cd, uint32_t& tma_stage_idx, + const uint32_t& tmem_base_addr, + const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx, + const uint32_t& effective_m, + const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx, + const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier, + const cute::TmaDescriptor& tensor_map_cd) { + // NOTES: The epilogue requires a full warpgroup to read all 128 TMEM rows, + // 
implying STORE_BLOCK_N must be 128. + DG_STATIC_ASSERT(STORE_BLOCK_N == 128, "STORE_BLOCK_N must be 128 to match TMEM rows"); + + // TMA checks + constexpr uint32_t STORE_BLOCK_N_ATOM = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumSwizzleAtomRows = 8; + DG_STATIC_ASSERT(kSwizzleCDMode == 128, "TMA D must be 128B swizzled"); + DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(STORE_BLOCK_M % kNumSwizzleAtomRows == 0, "Invalid swizzling"); + DG_STATIC_ASSERT(STORE_BLOCK_N % STORE_BLOCK_N_ATOM == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Iterate over M blocks + const auto num_stores = effective_m / STORE_BLOCK_M; + for (uint32_t s = 0; s < num_stores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / kNumSwizzleAtomRows; ++ i) { + uint32_t tmem_addr = tmem_base_addr + + s * STORE_BLOCK_M + // Store stage offset + i * kNumSwizzleAtomRows; // In-block offset + uint32_t values[kNumSwizzleAtomRows]; + + // Warps cooperatively write an atomic block to shared memory + DG_STATIC_ASSERT(STORE_BLOCK_N_ATOM % 32 == 0, "Invalid block sizes"); + constexpr uint32_t kNumWarpsPerAtom = STORE_BLOCK_N_ATOM / 32; + uint32_t outer_atom_offset = (epilogue_warp_idx / kNumWarpsPerAtom) * STORE_BLOCK_M * kSwizzleCDMode; + uint32_t inner_atom_offset = i * kNumSwizzleAtomRows * kSwizzleCDMode; + auto smem_base_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + outer_atom_offset + inner_atom_offset; + + if constexpr (cute::is_same_v) { + // NOTES: Swizzling 
is not required in this case, but used here for consistency with other cases + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + uint32_t col = lane_idx / 4; + + #pragma unroll + for (uint32_t row = 0; row < kNumSwizzleAtomRows; ++ row) { + auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes + + (lane_idx % 4) * sizeof(float); + ptx::st_shared(reinterpret_cast(smem_ptr), values[row]); + } + } else { + // Load from TMEM using `.16x256b` shape to satisfy STSM layout requirements + // Start from lane index 0 + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + // Start from lane index 16 + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Destination shared memory address + uint32_t row = lane_idx % 8; + uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8; + auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes; + + // Store matrix with transposition + ptx::SM90_U32x4_STSM_T::copy(math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]), + smem_ptr); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (s == num_stores - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barrier->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / STORE_BLOCK_N_ATOM; ++ i) { + auto 
smem_ptr = smem_cd[tma_stage_idx] + i * STORE_BLOCK_M * STORE_BLOCK_N_ATOM; + uint32_t m_idx = base_m_idx + s * STORE_BLOCK_M; + uint32_t n_idx = epilogue_type_t::apply_index_n(base_n_idx + i * STORE_BLOCK_N_ATOM); + + // Issue 2D or 3D TMA store + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx, batch_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx); + } + } + cute::tma_store_arrive(); + } + __syncwarp(); + } +} + +} // namespace deep_gemm::epilogue diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/transform.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/transform.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0266f4d402ab25878a792fb351b32ce1a04924cb --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/epilogue/transform.cuh @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace deep_gemm::epilogue::transform { + +struct EpilogueIdentity { + template + CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) { + return n_idx; + } +}; + +template +struct EpilogueHeadSplits: EpilogueIdentity { + template + CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) { + DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 and + kRight % STORE_BLOCK_N == 0, "Invalid head splits config"); + return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid; + } +}; + +} // namespace deep_gemm::epilogue::transform diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh index 0227b3e80061409c4dcf89f3f402ce408751246f..a60e2de8df85457a36145b77f06482d49eed0ed7 100644 --- 
a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -4,14 +4,18 @@ #include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm100; - template -__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_bf16_gemm_impl(int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, @@ -48,41 +53,31 @@ sm100_bf16_gemm_impl(int* grouped_layout, if constexpr (kWithAccumulation) DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); - // Configs + // MMA Configs constexpr uint32_t LAYOUT_AD_M = 128; - constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); - constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; - constexpr uint32_t kNumTMAStoreStages = 2; - DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); - DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); - DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode"); - - // Overwrite shape constants if the compiler gives - shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; - shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; - shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; - - // Utils - bool is_leader_cta = cute::block_rank_in_cluster() == 0; - const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); - - // Align to 1024 bytes for swizzle-128B - extern __shared__ __align__(1024) uint8_t smem_buffer[]; - - // 2-CTA MMA + constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast; + constexpr uint32_t UMMA_N = kSwapAB ? 
BLOCK_M : BLOCK_N; + constexpr uint32_t UMMA_K = 16; constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); - constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); - constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); - constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; - DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); - DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or + (not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size"); + + // Epilogue configs + // Always enable pipeline for better performance + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + // NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N + // per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases + constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? 
kNumEpilogueThreads: STORE_BLOCK_M; DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); // Share memory sizes - constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t); constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); @@ -91,41 +86,54 @@ sm100_bf16_gemm_impl(int* grouped_layout, DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); // NOTES: Make sure we have enough shared memory for UMMA padding - static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); - DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); - - // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size - // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` - constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 
1 : 2; + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory out of bound for UMMA"); // Real tensor memory size and offsets - constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; - constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * UMMA_N; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Synchronize the cluster before 2-CTA TMEM allocation + kNumMulticast > 1 ? cute::cluster_sync() : void(); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors at the very beginning - if (warp_idx == 0 and cute::elect_one_sync()) { + if (warp_idx == 0) { cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(&tensor_map_b); cute::prefetch_tma_descriptor(&tensor_map_cd); } + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? 
SHAPE_K : shape_k; + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + // D/A/B shared memory - auto smem_cd = PatternVisitor([&](const uint32_t& i) { + auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); }); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); - auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); - auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + 
kNumEpilogueStages * 2; // Fill the tensor memory pointer @@ -159,9 +167,13 @@ sm100_bf16_gemm_impl(int* grouped_layout, } kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler( + shape_m, shape_n, shape_k, grouped_layout); // Pipeline and TMA phases uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0; @@ -178,16 +190,20 @@ sm100_bf16_gemm_impl(int* grouped_layout, // TMA load warp // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + // Use dynamic load block M, when swap-AB is enabled + const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M; + + // For k-grouped layout, the number of block K is variable + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait consumer release empty_barriers[stage_idx]->wait(phase ^ 1); // Compute offsets // NOTES: the group is always concatenated with the outer dimension - uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> ( shape_m, BLOCK_M, m_block_idx); - uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> ( shape_n, BLOCK_N, n_block_idx, m_block_idx); // NOTES: `k_idx` is actually the k index default for K-major, while 
`k_b_idx` may be MN-major @@ -195,14 +211,14 @@ sm100_bf16_gemm_impl(int* grouped_layout, DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or kMajorA == cute::UMMA::Major::K, "Invalid major"); uint32_t k_idx = k_block_idx * BLOCK_K; - uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); - uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); // Add 2 CTA offsets if constexpr (kNumMulticast > 1) { - m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0; n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); } @@ -210,16 +226,16 @@ sm100_bf16_gemm_impl(int* grouped_layout, constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); const uint32_t batch_idx = (kIsBatchedMM ? 
scheduler.current_group_idx : 0); if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx); if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx); // Arrive at full barriers @@ -235,17 +251,16 @@ sm100_bf16_gemm_impl(int* grouped_layout, // MMA issue warp // NOTES: only the leader CTA will do this // Make instruction descriptor - // TODO: refactor `UMMA_M` calculation - constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); - constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); - constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); - auto instr_desc = cute::UMMA::make_instr_desc(); + auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc() + : cute::UMMA::make_instr_desc(); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); // Merged stages only happens in NT normal GEMM cases constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; - auto a_desc = make_umma_desc(smem_a[0], 0, 0); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; uint32_t b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; @@ -262,7 +277,7 @@ sm100_bf16_gemm_impl(int* grouped_layout, auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // UMMA and empty barrier arrival alias auto umma_arrive = [](const uint64_t* barrier) { @@ -279,36 +294,45 @@ sm100_bf16_gemm_impl(int* grouped_layout, // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting if (do_tmem_full_arrive) umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); }; + // Dynamic update of UMMA N based on effective M, when swap-AB is enabled + if constexpr (kSwapAB) { + uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx); + mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n); + } + // Launch MMAs - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait TMA arrival full_barriers[stage_idx]->wait(phase); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Issue UMMA in the leader CTA - using mma_t = cute::conditional_t; - const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); - const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); - const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + using mma_t = cute::conditional_t; + const auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto b_desc_base_lo = __shfl_sync(0xffffffff, 
b_desc_lo, static_cast(stage_idx)); if (cute::elect_one_sync()) { #pragma unroll for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K; - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); - #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); - mma_t::fma(a_desc, b_desc, - accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, - k_block_idx > 0 or k > 0, - runtime_instr_desc); + a_desc.lo = mma::sm100::advance_umma_desc_lo( + a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo( + b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + if (kSwapAB) { + mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc); + } else { + mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc); } } } + __syncwarp(); // Commit to the mbarrier object // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` @@ -319,15 +343,16 @@ sm100_bf16_gemm_impl(int* grouped_layout, if constexpr (kTensorCoreUtilControl < 100) { // For utilization control umma_arrive(reinterpret_cast(tensor_core_full_barrier)); + __syncwarp(); // Wait for last UMMA to be done tensor_core_full_barrier->wait(tensor_core_phase); tensor_core_phase ^= 1; // Sleep for certain cycles - constexpr static uint64_t kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumUMMACycles = (2ull * UMMA_M * UMMA_N * BLOCK_K) / 8192ull; constexpr static uint64_t 
kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; - const auto& start_clock = clock64(); + const auto start_clock = clock64(); if (cute::elect_one_sync()) while (clock64() - start_clock < kNumDummyCycles) {} __syncwarp(); @@ -336,9 +361,9 @@ sm100_bf16_gemm_impl(int* grouped_layout, } // To safely deconstruct barriers, we need another round of waits - const auto& iter_idx = scheduler.current_iter - 1; + const auto iter_idx = scheduler.current_iter - 1; if (kNumMulticast > 1 and iter_idx >= 0) { - const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); } } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { @@ -348,19 +373,10 @@ sm100_bf16_gemm_impl(int* grouped_layout, // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. 
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); - - // TMA checks - constexpr uint32_t kNumBankGroupBytes = 16; - constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); - DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); - DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // Share store pipeline between blocks uint32_t tma_stage_idx = 0; - auto advance_store_pipeline = [&]() { - tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; - }; // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { @@ -369,108 +385,47 @@ sm100_bf16_gemm_impl(int* grouped_layout, // Wait UMMA arrival tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Load from tensor memory into registers, and write shared memory with STSM - DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); - DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); - - // Iterate over M waves - #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - // Issue every swizzled atom and pipeline STSM and TMA store - constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; - #pragma unroll - for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { - // Wait shared memory to be released - if (epilogue_warp_idx == 0) - cute::tma_store_wait(); - cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); - - // The pipeline stage - const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; - const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; - - // Store into shared memory - #pragma unroll - for 
(uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { - // Calculate the index of the bank group to be written in the atom - auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); - - // Reshape the atom in another view and swizzle - // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` - // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` - // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern - constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; - auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); - auto col = kHasShortcut ? (i) : (bank_group_index % 8); - col ^= row % (kSwizzleCDMode / 16); - - // Source and destination memory address - uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset - w * BLOCK_N + // Wave offset - s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset - auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer - epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset - row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - - // Load from tensor memory, store into shared memory - uint32_t values[kNumElemsPerBankGroup]; - if constexpr (cute::is_same_v) { - // For FP32 output, read and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, - values[0], values[1], values[2], values[3]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); - } else { - // For BF16 output, read, cast and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, - values[0], values[1], values[2], values[3], - values[4], values[5], values[6], values[7]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, - 
cast_into_bf16_and_pack(values[0], values[1]), - cast_into_bf16_and_pack(values[2], values[3]), - cast_into_bf16_and_pack(values[4], values[5]), - cast_into_bf16_and_pack(values[6], values[7])); - } - } - - // Notify tensor memory empty (only at the leader CTA) arrival ASAP - // NOTES: only the last stage needs to do this - if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { - tcgen05_before_thread_sync(); - tmem_empty_barriers[accum_stage_idx]->arrive(0u); - } - __syncwarp(); - - // Synchronize all threads and issue TMA - cute::tma_store_fence(); - cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); - if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - if constexpr (kGemmType == GemmType::Batched) { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], - n_idx, m_idx, scheduler.current_group_idx); - } else { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); - } - cute::tma_store_arrive(); - } - } + const auto tmem_base_addr = accum_stage_idx * UMMA_N; + const auto base_m_idx = scheduler.template get_global_idx< + (not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto base_n_idx = n_block_idx * BLOCK_N; + + if constexpr (kSwapAB) { + const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx); + epilogue::sm100_store_cd_swap_ab + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + effective_m, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } else { + epilogue::sm100_store_cd + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); } } - - // Deallocate tensor memory by the last UMMA store warp - // NOTES: warp 0 is waiting TMA store - if 
(epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) - Allocator().free(0, kNumTmemCols); } + + // TODO: Remove redundant synchronization + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Deallocate tensor memory + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh index 86303347d9c7a3a93b65a16d6ad4a7b73eb2ad1a..13bb087232772ac1e9d65997f733164ed5827c49 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh @@ -5,18 +5,19 @@ #include #include -#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm100; - template -__global__ void __launch_bounds__(kNumThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumThreads, 1) sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -30,7 +31,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Utils const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size"); DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode"); @@ -51,24 +52,24 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, } // Real tensor memory size and offsets - constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); // Fill D/A/B - auto smem_cd = PatternVisitor([&](const 
uint32_t& i) { + auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE)); }); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2); // Fill the tensor memory pointer @@ -93,14 +94,17 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, __syncthreads(); // Block indices - const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); - const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M); const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; const uint32_t n_block_idx = mn_block_idx % num_n_blocks; const uint32_t m_block_idx = mn_block_idx / num_n_blocks; const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - 
sk_block_idx * kSplitFactor); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (warp_idx == 0) { // TMA load warp for (uint32_t s = 0; s < num_total_stages; ++ s) { @@ -115,8 +119,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Issue TMAs if (cute::elect_one_sync()) { - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); } // Arrive at full barriers @@ -134,8 +138,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, auto instr_desc = cute::UMMA::make_instr_desc(); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); - auto a_desc = make_umma_desc(smem_a[0], 0, 0); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; uint32_t b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; @@ -147,14 +151,14 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, "Invalid MMA instruction shape"); // Wait tensor memory empty barrier arrival - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Launch MMAs for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait TMA arrival const auto& stage_idx = s % kNumStages; full_barriers[stage_idx]->wait((s / kNumStages) & 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Issue UMMA in the leader CTA const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); @@ -163,9 +167,11 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, if (cute::elect_one_sync()) { #pragma unroll for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); - SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); + a_desc.lo = mma::sm100::advance_umma_desc_lo( + a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo( + b_desc_base_lo, 0, k * UMMA_K); + ptx::SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); } } @@ -180,7 +186,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`. 
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory if (warp_idx == 2) - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // TMA checks constexpr uint32_t kNumBankGroupBytes = 16; @@ -191,7 +197,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Wait UMMA arrival tmem_full_barrier->wait(0); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Load from tensor memory into registers, and write shared memory with STSM DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); @@ -239,7 +245,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, values[0], values[1], values[2], values[3]); cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); } // Synchronize all threads and issue TMA diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b8a99fd04273d48a6b500b6e76f1e938be8858da --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh @@ -0,0 +1,457 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, + const uint32_t logits_stride, + const uint32_t* cu_seq_len_k_start, + const uint32_t* cu_seq_len_k_end, + logits_dtype_t* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ 
cute::TmaDescriptor tensor_map_sf_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; + + // Prefetch TMA descriptors + if (warp_idx == kSpecWarpStart) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_sf_q); + cute::prefetch_tma_descriptor(&tensor_map_weights); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_sf_kv); + } + + // UMMA configs + static constexpr uint32_t kNumTmemStages = 3; + static constexpr uint32_t kNumUTCCPAlignedElems = 128; + static constexpr uint32_t UMMA_M = 128; + static constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; + static constexpr uint32_t UMMA_K = 64; + static constexpr uint32_t kNumSFQ = math::constexpr_align(BLOCK_Q * kNumHeads, kNumUTCCPAlignedElems); + static constexpr uint32_t kNumSFKV = math::constexpr_align(BLOCK_KV, kNumUTCCPAlignedElems); + static constexpr uint32_t kRealNumSFQ = BLOCK_Q * kNumHeads; + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(BLOCK_KV == kNumMathWarpGroups * UMMA_M and BLOCK_KV % kNumUTCCPAlignedElems == 0, "Invalid `BLOCK_KV`"); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQ * sizeof(int); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * (kHeadDim / 2); + static constexpr 
uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i; + }); + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i; + }); + const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages); + auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i); + }); + auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return 
barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; }); + auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(tmem_barrier_ptr + kNumTmemStages * 2); + + // Tensor memory configs + constexpr uint32_t kNumAccumTmemCols = BLOCK_Q * kNumHeads * kNumTmemStages; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQ / 32; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Initialize barriers + if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads + 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumTmemStages; ++i) { + full_tmem_barriers[i]->init(1); + empty_tmem_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + + // Allocate tensor memory + if (warp_idx == kSpecWarpStart + 2) + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + __syncthreads(); + + // Scheduler + const uint32_t num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + auto load_schedule = [&](const uint32_t& q_idx) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto row_idx = cute::min(q_idx * BLOCK_Q + i, seq_len - 1); + 
seq_k_start[i] = cute::min(cu_seq_len_k_start[row_idx], seq_len_kv); + seq_k_end[i] = cute::min(cu_seq_len_k_end[row_idx], seq_len_kv); + start = cute::min(start, seq_k_start[i]); + end = cute::max(end, seq_k_end[i]); + } + // TMA alignment requirements for SF KV + start = start / 4 * 4; + return {start, math::ceil_div(end - start, BLOCK_KV)}; + }; + + // Make Q, KV and TMEM pipeline + auto make_pipeline = [](const uint32_t& num_stages) { + // Return current stage and phase, and advance pipeline by steps + return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple { + uint32_t current_idx = iter_idx; + iter_idx += step; + return {current_idx % num_stages, (current_idx / num_stages) & 1}; + }; + }; + auto advance_q_pipeline = make_pipeline(kNumQStages); + auto advance_kv_pipeline = make_pipeline(kNumKVStages); + auto advance_tmem_pipeline = make_pipeline(kNumTmemStages); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading Q + cutlass::arch::warpgroup_reg_dealloc(); + + // Enumerate Q blocks + if (cute::elect_one_sync()) { + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Wait Q consumer release + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast(full_q_barriers[q_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_q[q_stage_idx], 0, q_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_idx * BLOCK_Q); + tma::copy(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_idx * BLOCK_Q); + 
full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQ * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE); + } + } + __syncwarp(); + } else if (warp_idx == kSpecWarpStart + 1) { + // TMA warp for loading KV cache + cutlass::arch::warpgroup_reg_dealloc(); + + if (cute::elect_one_sync()) { + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Wait KV consumer release + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_kv[kv_stage_idx], 0, kv_start + kv_idx * BLOCK_KV); + tma::copy(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx], + smem_sf_kv[kv_stage_idx], + kv_start + kv_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE); + } + } + } + } else if (warp_idx == kSpecWarpStart + 2) { + // UMMA warp + cutlass::arch::warpgroup_reg_dealloc(); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = 
mma::sm100::make_sf_desc(nullptr); + + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Wait TMA Q arrivals + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Transpose and copy SF Q + #pragma unroll + for (uint32_t i = 0; i < kNumSFQ / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + if (cute::elect_one_sync()) + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4); + __syncwarp(); + } + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Wait TMA KV arrivals + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Transpose + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + } + + // UMMA with SF + if (cute::elect_one_sync()) { + // Copy SF KV + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4); + } + + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + // Wait TMEM release + CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase); + uint32_t tmem_addr = tmem_stage_idx * UMMA_N; + + empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1); + 
ptx::tcgen05_after_thread_sync(); + + // Issue UMMA with SF + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2); + // TODO: generalize umma desc + DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim"); + auto a_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + auto b_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_q[q_stage_idx] + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + ptx::SM100_MMA_MXF4_SS::fma( + a_desc, b_desc, tmem_addr, k, runtime_instr_desc, + kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ); + } + // TODO: move this into `deep_gemm/ptx/tcgen05.cuh` + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx]))); + } + } + cutlass::arch::umma_arrive(reinterpret_cast(empty_kv_barriers[kv_stage_idx])); + } + + // UMMA warp must also arrive on empty_q to prevent running ahead + // of math warps in the Q pipeline. Without this, UMMA can consume + // kNumQStages Q blocks before math warps release any, causing a + // circular dependency: UMMA waits full_q -> TMA_Q waits empty_q + // -> Math waits full_tmem -> UMMA (already moved on). 
+ empty_q_barriers[q_stage_idx]->arrive(); + } + } else if (warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce + cutlass::arch::warpgroup_reg_alloc(); + + const auto math_warpgroup_idx = warpgroup_idx; + const auto math_thread_idx = threadIdx.x; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr uint32_t N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; + + // Math warpgroups process TMEM stages alternately + // Advance pipeline to align with the assigned stage + advance_tmem_pipeline(math_warpgroup_idx); + + // Local register buffers + float accum[kNumHeads]; + float weights[BLOCK_Q][kNumHeads]; + + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Wait TMA Q arrivals + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + // TODO: optimize bank conflicts + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Calculate KV offset in advance + auto kv_offset = kv_start + kv_idx * BLOCK_KV + math_thread_idx; + + // Advance pipeline by `kNumMathWarpGroups` steps + // Wait UMMA arrival + CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, 
tmem_phase); + full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase); + ptx::tcgen05_after_thread_sync(); + + // Reduce over the head dim and store + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + // Load accumulator from TMEM + uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; + tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2); + + // Release TMEM empty + if (i == BLOCK_Q - 1) { + ptx::tcgen05_before_thread_sync(); + empty_tmem_barriers[tmem_stage_idx]->arrive(); + } + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + // TODO: optimize performance + const auto q_offset = (q_idx * BLOCK_Q + i) * static_cast(logits_stride); + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i]) + logits[q_offset + kv_offset - seq_k_start[i]] = result; + } else { + logits[q_offset + kv_offset] = result; + } + __syncwarp(); + } + } + + // Release last Q empty + empty_q_barriers[q_stage_idx]->arrive(); + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } +} + +} // namespace deep_gemm diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh 
b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d9add53425517d936ed201f78d277db775d19507 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh @@ -0,0 +1,510 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; + + // Prefetch TMA descriptors + if (warp_idx == kSpecWarpStart) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_sf_q); + cute::prefetch_tma_descriptor(&tensor_map_weights); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_sf_kv); + } + + // For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill. 
+ static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1; + static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom); + + // UMMA configs + static constexpr uint32_t kNumTmemStages = 3; + static constexpr uint32_t kNumUTCCPAlignedElems = 128; + static constexpr uint32_t UMMA_M = 128; + static constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads; + static constexpr uint32_t UMMA_K = 64; + static constexpr uint32_t kNumSFQAtom = math::constexpr_align(kNextNAtom * kNumHeads, kNumUTCCPAlignedElems); + static constexpr uint32_t kNumSFKV = math::constexpr_align(SPLIT_KV, kNumUTCCPAlignedElems); + static constexpr uint32_t kRealNumSFQAtom = kNextNAtom * kNumHeads; + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(SPLIT_KV == kNumMathWarpGroups * UMMA_M and SPLIT_KV % kNumUTCCPAlignedElems == 0, "Invalid `SPLIT_KV`"); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQAtom * sizeof(int); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + 
SMEM_Q_SIZE_PER_STAGE * i; + }); + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i; + }); + const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages); + auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i); + }); + auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; }); + auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(tmem_barrier_ptr + kNumTmemStages * 2); + + // Tensor memory configs + constexpr uint32_t kNumAccumTmemCols = kNextNAtom * kNumHeads * kNumTmemStages; + constexpr uint32_t 
kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQAtom / 32; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Initialize barriers + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads + 32); + } + cutlass::arch::fence_barrier_init(); + } + if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(1); + } + cutlass::arch::fence_barrier_init(); + } + if (warp_idx == kSpecWarpStart + 2) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumTmemStages; ++i) { + full_tmem_barriers[i]->init(1); + empty_tmem_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Scheduler + constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; + using Scheduler = sched::PagedMQALogitsScheduler; + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); + + // Make Q, KV and TMEM pipeline + auto make_pipeline = [](const uint32_t& num_stages) { + // Return current stage and phase, and advance pipeline by steps + return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple { + uint32_t current_idx = iter_idx; + iter_idx += step; + return {current_idx % num_stages, (current_idx / num_stages) & 1}; + }; + }; + auto advance_q_pipeline = make_pipeline(kNumQStages); + auto advance_kv_pipeline = make_pipeline(kNumKVStages); + auto advance_tmem_pipeline = 
make_pipeline(kNumTmemStages); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading Q + cutlass::arch::warpgroup_reg_dealloc(); + + if (cute::elect_one_sync()) { + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + // Persistently schedule over blocks + // Initialize outside valid range to indicate no previous task + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, _, __; + while (scheduler.fetch_next_task(q_atom_idx, _, __)) { + // Issue TMA Q when (q_idx, atom_idx) changes + if (q_atom_idx != last_q_atom_idx) { + // Wait Q consumer release + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx); + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast(full_q_barriers[q_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_q[q_stage_idx], 0, q_token_idx * kNumHeads); + tma::copy(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx); + tma::copy(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx); + full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE); + } + last_q_atom_idx = q_atom_idx; + } + } + __syncwarp(); + } else if (warp_idx == kSpecWarpStart + 1) { + // TMA warp for loading KV cache + cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + // Persistently schedule over blocks + uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage; + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, num_kv; + while 
(scheduler.fetch_next_task(q_atom_idx, kv_idx, num_kv)) { + // Reset block table cache on kv restart + if (q_atom_idx != last_q_atom_idx) + kv_block_idx_ptr = 32; + last_q_atom_idx = q_atom_idx; + + // Coalesced load of block table + if (kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast(block_table_stride); + kv_block_idx_storage = (kv_idx + lane_idx < num_kv) + ? block_table[block_table_offset + kv_idx + lane_idx] : 0; + } + __syncwarp(); + + // Broadcast KV block indices + int kv_block_idx[kNumBlocksPerSplit]; + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) + kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); + kv_block_idx_ptr += kNumBlocksPerSplit; + DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `SPLIT_KV`"); + + // Wait KV consumer release + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + + // Issue TMA KV + if (cute::elect_one_sync()) { + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim / 2) * i, + 0, 0, kv_block_idx[i]); + tma::copy(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx], + smem_sf_kv[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); + } + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE); + } + } + } else if (warp_idx == kSpecWarpStart + 2) { + // UMMA warp + cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + 
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + // Persistently schedule over blocks + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, _; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) { + // Wait TMA Q arrivals + uint32_t q_stage_idx, q_phase; + if (q_atom_idx != last_q_atom_idx) { + CUTE_TIE(advance_q_pipeline(), q_stage_idx, q_phase); + + // Release previous Q empty (UMMA warp must participate to prevent + // running ahead of math warps in the Q pipeline) + if (last_q_atom_idx != batch_size * kNumNextNAtoms) + empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive(); + + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Transpose and copy SF Q + #pragma unroll + for (uint32_t i = 0; i < kNumSFQAtom / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + if (cute::elect_one_sync()) + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4); + __syncwarp(); + } + } + last_q_atom_idx = q_atom_idx; + + // Wait TMA KV arrivals + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Transpose + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + 
utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + } + + // UMMA with SF + if (cute::elect_one_sync()) { + // Copy SF KV + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4); + } + + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + // Wait TMEM release + CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase); + uint32_t tmem_addr = tmem_stage_idx * UMMA_N; + + empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Issue UMMA with SF + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2); + // TODO: generalize UMMA desc + DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim"); + auto a_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + auto b_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_q[q_stage_idx] + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + ptx::SM100_MMA_MXF4_SS::fma(a_desc, b_desc, tmem_addr, k, runtime_instr_desc, + kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ); + } + // TODO: move this PTX into headers + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx]))); + } + } + cutlass::arch::umma_arrive(reinterpret_cast(empty_kv_barriers[kv_stage_idx])); + } + } else if (warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce + 
cutlass::arch::warpgroup_reg_alloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + const auto math_warpgroup_idx = warpgroup_idx; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; + + // Math warpgroups process TMEM stages alternately + // Advance pipeline to align with the assigned stage + advance_tmem_pipeline(math_warpgroup_idx); + + // Local register buffers + float accum[kNumHeads]; + float weights[kNextNAtom][kNumHeads]; + + // Persistently schedule over blocks + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, _; + bool is_paired_atom = false; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) { + if (q_atom_idx != last_q_atom_idx) { + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + + // Release last Q empty + if (last_q_atom_idx != batch_size * kNumNextNAtoms) + empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrivals + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextNAtom; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + float4 raw = ptx::ld_shared((float4*)(smem_weights[q_stage_idx] + i * kNumHeads + j)); + weights[i][j + 0] = raw.x; + weights[i][j + 1] = raw.y; + weights[i][j + 2] = raw.z; + weights[i][j + 3] = raw.w; + } + } + + // Check if this atom pairs two tokens from the same sequence + if constexpr (kIsVarlen) { + is_paired_atom = 
(scheduler.get_atom_advance(q_atom_idx, batch_size) == 2); + } + } + last_q_atom_idx = q_atom_idx; + + // Calculate KV offset in advance + auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx; + + // Advance pipeline by `kNumMathWarpGroups` steps + // Wait UMMA arrival + CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase); + full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase); + ptx::tcgen05_after_thread_sync(); + + // Reduce over the head dim and store + const auto reduce_and_store = [&](auto num_iters_c) { + constexpr uint32_t kNumIters = decltype(num_iters_c)::value; + + // Only loop over valid iterations + #pragma unroll + for (uint32_t i = 0; i < kNumIters; ++ i) { + // Load accumulator from TMEM + uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; + tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2); + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(sum.x + sum.y); + + // Store into the global memory + logits[kv_offset + i * static_cast(logits_stride)] = result; + __syncwarp(); + } + + // Release TMEM empty + ptx::tcgen05_before_thread_sync(); + empty_tmem_barriers[tmem_stage_idx]->arrive(); + }; + + if constexpr (kIsVarlen) { + if (is_paired_atom) + reduce_and_store(cute::Int{}); + else + reduce_and_store(cute::Int<1>{}); + } else if constexpr (kPadOddN) { + if (q_atom_idx % 
kNumNextNAtoms == kNumNextNAtoms - 1) + reduce_and_store(cute::Int<1>{}); + else + reduce_and_store(cute::Int{}); + } else { + reduce_and_store(cute::Int{}); + } + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } +} + +} // namespace deep_gemm diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0bc6a3fe26e61057fbcfcc5f4c63d4faa6e475fe --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh @@ -0,0 +1,514 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // MMA Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast; + constexpr 
uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N; + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or + (not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size"); + + // SF configs + constexpr uint32_t kNumUTCCPAlignedElems = 128; + constexpr uint32_t SF_BLOCK_M = math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = math::constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4; + constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4; + DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A"); + DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B"); + DG_STATIC_ASSERT((kGemmType != GemmType::KGroupedContiguous) or kGranKA == kGranKB, "K-grouped SF requires kGranKA == kGranKB"); + + // Epilogue configs + // Always enable pipeline for better performance + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + // NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N + // per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases + constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? 
kNumEpilogueThreads: STORE_BLOCK_M; + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + // NOTES: Make sure we have enough shared memory for UMMA padding + constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + + // Tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Synchronize the cluster before 2-CTA TMEM allocation + kNumMulticast > 1 ? 
cute::cluster_sync() : void(); + + // Utils + const bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const auto shape_sfa_k = math::ceil_div(shape_k, kGranKA * 4); + const auto shape_sfb_k = math::ceil_div(shape_k, kGranKB * 4); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // D/A/B shared memory + auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SFA/SFB shared memory + auto sf_start_ptr = reinterpret_cast(smem_b[kNumStages]); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Barriers and tensor memory pointer + auto barrier_start_ptr = reinterpret_cast(smem_sfb[kNumStages]);; + 
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto with_sf_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + // Arrive only at the leader CTA + with_sf_full_barriers[i]->init(kNumMulticast * 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? 
cute::cluster_sync() : __syncthreads(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = sched::Scheduler( + shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Use dynamic load block M, when swap-AB is enabled + const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M; + + // For k-grouped layout, the number of block K is variable + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid 
major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx); + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2) + + SMEM_B_SIZE_PER_STAGE / (std::is_same_v ? 
1 : 2); + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + uint32_t sfa_m_idx = m_block_idx * BLOCK_M; + uint32_t sfa_k_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::SF_K>( + shape_sfa_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad)); + tma::copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx); + num_arrival_bytes += BLOCK_M * sizeof(uint32_t); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = scheduler.template get_global_idx( + shape_sfb_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx); + tma::copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx); + num_arrival_bytes += BLOCK_N * sizeof(uint32_t); + } + + // Arrive at full barriers + full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc_block_scaled() + : cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + // Dynamic update of UMMA N based on effective M, when swap-AB is enabled + if constexpr (kSwapAB) { + uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx); + mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n); + } + + // Launch MMAs + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + #pragma unroll 4 + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA 
and SF-transpose arrival + with_sf_full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + // Do SF copy at certain stages + // TODO: process shared memory descriptor by addition + using cute_utccp_t = cute::conditional_t; + const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad; + if (sfa_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + } + const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad; + if (sfb_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + + // Issue UMMA + using mma_t = cute::conditional_t< + kNumMulticast == 1, ptx::SM100_MMA_MXF8F6F4_SS, ptx::SM100_MMA_MXF8F6F4_2x1SM_SS>; + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx); + const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx); + const auto runtime_instr_desc = kSwapAB ? 
+ mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfb_id, sfa_id): + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id); + + a_desc.lo = mma::sm100::advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + if constexpr (kSwapAB) { + mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } else { + mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFA, kTmemStartColOfSFB); + } + } + } + __syncwarp(); + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx == 2) { + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival 
+ full_barriers[stage_idx]->wait(phase); + + // Transpose for UTCCP at certain stages + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + + // Arrive + with_sf_full_barriers[stage_idx]->arrive(0u); + } + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. 
+ // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + ptx::tcgen05_after_thread_sync(); + + const auto tmem_base_addr = accum_stage_idx * UMMA_N; + const auto base_m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto base_n_idx = n_block_idx * BLOCK_N; + + if constexpr (kSwapAB) { + const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx); + epilogue::sm100_store_cd_swap_ab< + BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N, + kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads, + kGemmType, kWithAccumulation, + cd_dtype_t, epilogue_type_t> + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + effective_m, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } else { + epilogue::sm100_store_cd< + BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N, + kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads, + kGemmType, kWithAccumulation, + cd_dtype_t, epilogue_type_t> + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } + } + } + + // TODO: Remove redundant synchronization + kNumMulticast > 1 ? 
cute::cluster_sync() : __syncthreads(); + + // Deallocate tensor memory + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b2adc6c7ad40cc84aef802a418c3702287774b20 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh @@ -0,0 +1,1380 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template < + uint32_t kNumMaxTokensPerRank, + uint32_t kHidden, uint32_t kIntermediateHidden, + uint32_t kNumExperts, uint32_t kNumTopk, + uint32_t kNumExpertsPerWave, + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, + uint32_t STORE_BLOCK_M, + uint32_t SF_BLOCK_M, uint32_t SF_BLOCK_N, + uint32_t kNumMaxPoolTokens, + uint32_t kNumPaddedSFPoolTokens, + uint32_t kNumStages, + uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, + uint32_t kNumEpilogueThreads, + uint32_t kNumSMs, uint32_t kNumRanks, + float kActivationClamp, + bool kFastMath, + uint32_t L1_SHAPE_N = kIntermediateHidden * 2, + uint32_t L1_SHAPE_K = kHidden, + uint32_t L2_SHAPE_N = kHidden, + uint32_t L2_SHAPE_K = kIntermediateHidden, + uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, + uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, + uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, + uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, 
+ uint32_t kNumTokensPerWarp = 32 / kNumTopk, + uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm100_fp8_fp4_mega_moe_impl(void* y, + int* cumulative_local_expert_recv_stats, + const uint32_t num_tokens, + const __grid_constant__ layout::SymBuffer sym_buffer, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights_sf) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::TMEM::Allocator2Sm; + + // Template checks + DG_STATIC_ASSERT(kNumDispatchThreads % 128 == 0, "Invalid number of dispatch threads"); + DG_STATIC_ASSERT(kNumNonEpilogueThreads == 128, "Invalid number of MMA non-epilogue threads"); + DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of MMA epilogue and combine threads"); + DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); + + // Thread indices + const bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const uint32_t sm_idx = blockIdx.x; + const uint32_t thread_idx = threadIdx.x; + const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_l1_acts); + 
cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_output); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights_sf); + } + + // Workspaces + const auto workspace = layout::Workspace( + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); + + // Token and buffer layouts + constexpr auto fp8_token_layout = layout::Data(kHidden); + constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); + constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + constexpr auto fp8_sf_layout = layout::Data(kHidden / 32); + constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32); + constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); + constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); + constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + // Registered inputs + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxTokensPerRank, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumMaxTokensPerRank, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, kNumMaxTokensPerRank, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, kNumMaxTokensPerRank, + input_topk_idx_buffer.get_end_ptr()); + + // SF and its buffer configs + constexpr uint32_t kGranK = 32; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(SF_BLOCK_M == 
math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M"); + DG_STATIC_ASSERT(SF_BLOCK_N == BLOCK_N, "No padding is needed for SFB"); + + // UTCCP 4x32 transpose index mapping within each 128-element group + const auto transform_sf_token_idx = [](const uint32_t& token_idx_in_expert) { + const uint32_t idx = token_idx_in_expert % BLOCK_M; + return token_idx_in_expert / BLOCK_M * SF_BLOCK_M + + (idx & ~127u) + (idx & 31u) * 4 + ((idx >> 5) & 3u); + }; + + // L1 inputs + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxPoolTokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumPaddedSFPoolTokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, kNumMaxPoolTokens, + l1_sf_buffer.get_end_ptr()); + + // L2 inputs + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, + l1_topk_weights_buffer.get_end_ptr() + ); + const auto l2_sf_buffer = layout::Buffer( + fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, + l2_token_buffer.get_end_ptr() + ); + + // Combine inputs + const auto combine_token_buffer = layout::Buffer( + bf16_token_layout, kNumTopk, kNumMaxTokensPerRank, + l2_sf_buffer.get_end_ptr() + ); + + // Data types + // NOTES: activations are FP8 (e4m3), weights are FP4 (e2m1) + using a_dtype_t = cutlass::float_e4m3_t; + using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; + + // MMA configs + // NOTES: always swap A/B, 2-CTA MMA, and matrices are K-major + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2; + constexpr uint32_t UMMA_N = BLOCK_M; // Swap AB + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; + DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, 
"Invalid block N"); + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + + // Swizzle configs + constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t kSwizzleCDMode = 128; + DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N"); + + // Epilogue configs + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + + // Shared memory + constexpr uint32_t kSharedMemoryAlignment = 1024; + extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; + + // Shared memory sizes + // NOTES: FP8 CD output for L1 (2 TMA stages, BLOCK_N/2 post-SwiGLU), BF16 output for L2 (no TMA, a single stage) + constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = + math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = + math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + constexpr uint32_t SMEM_CD_L1_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) * kNumTMAStoreStages; + constexpr uint32_t SMEM_CD_L2_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? 
SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; + constexpr uint32_t SMEM_CD_L1_SIZE_PER_STAGE = SMEM_CD_L1_SIZE / kNumTMAStoreStages; + constexpr uint32_t SMEM_BEFORE_BARRIER_SIZE = + SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + DG_STATIC_ASSERT(SMEM_CD_SIZE % kSharedMemoryAlignment == 0 and + SMEM_A_SIZE_PER_STAGE % kSharedMemoryAlignment == 0 and + SMEM_B_SIZE_PER_STAGE % kSharedMemoryAlignment == 0, + "Shared memory of CD/A/B must be aligned to 1024 bytes"); + + // Tensor memory size + constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Assign shared memory for dispatch warps + const auto smem_expert_count = reinterpret_cast(smem_buffer); + const auto smem_send_buffers = layout::Buffer( + fp8_token_layout, kNumDispatchWarps, 1, + math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE)); + + // GEMM shared memory: C/D, A, B + // NOTES: GEMM shared memory starts after the dispatch region, aligned to 1024 bytes + auto smem_gemm_base = math::advance_ptr( + smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + ); + + // D/A/B shared memory + auto smem_cd = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, i * SMEM_CD_L1_SIZE_PER_STAGE); + }); + auto smem_cd_l2 = smem_cd[0]; + auto smem_a = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([=](const uint32_t& i) { + return 
math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SF shared memory: SFA and SFB per pipeline stage + auto sf_start_ptr = math::advance_ptr(smem_gemm_base, + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Epilogue amax reduction shared memory + auto smem_amax_reduction = reinterpret_cast(smem_sfb[kNumStages]); + + // Barriers and tensor memory pointer + auto barrier_start_ptr = reinterpret_cast(smem_amax_reduction + STORE_BLOCK_M * kNumEpilogueWarps / 2); + auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages + i); }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + i); }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages + i); }); + auto combine_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + i); }); + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2); + + // A cluster sync is essential for 2CTA tensor memory allocation + 
comm::cluster_sync_with_relaxed_arrive(); + + // Initialization + if (warp_idx == 0) { + // Clean shared memory + if (cute::elect_one_sync()) + ptx::st_shared_bulk(smem_expert_count, kNumExperts * sizeof(uint32_t)); + } else if (warp_idx == 1) { + // Init m-barriers for dispatch + #pragma unroll + for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32) + dispatch_barriers[i]->init(1); + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Init GEMM barriers + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(2 * 2); + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(2 * kNumEpilogueThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i) + combine_barriers[i]->init(1); + } + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 3) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + // NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`, + // and `barrier.cluster.wait.aligned` is by default `.acquire` + comm::cluster_sync_with_relaxed_arrive(); + + // Task scheduler + auto scheduler = sched::MegaMoEScheduler< + BLOCK_M, BLOCK_N, BLOCK_K, + L1_SHAPE_N, L1_SHAPE_K, + L2_SHAPE_N, L2_SHAPE_K, + kNumExpertsPerRank, + kNumExpertsPerWave, + kNumSMs, kNumRanks>(workspace); + + // MMA pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 
0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Intra-SM Barrier indices + constexpr uint32_t kDispatchBarrierIdx = 0; + constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1; + constexpr uint32_t kEpilogueFullBarrierIdx = 2; + constexpr uint32_t kEpilogueWGBarrierStartIdx = 3; + + // NVLink barrier tags + constexpr uint32_t kBeforeDispatchPullBarrierTag = 1; + constexpr uint32_t kBeforeCombineReduceBarrierTag = 2; + constexpr uint32_t kAfterWorkspaceCleanBarrierTag = 3; + + // Adjust registers + constexpr uint32_t kNumDispatchRegisters = 48; + constexpr uint32_t kNumNonEpilogueRegisters = 40; + constexpr uint32_t kNumEpilogueRegisters = 208; + DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads + + kNumNonEpilogueRegisters * kNumNonEpilogueThreads + + kNumEpilogueRegisters * kNumEpilogueThreads <= 64512, + "Too many registers"); + + // Grid sync index assignments (dispatch and epilogue use separate counters to avoid conflicts) + constexpr uint32_t kDispatchGridSyncIndex = 0; + constexpr uint32_t kEpilogueGridSyncIndex = 1; + + // Different warp roles + if (warp_idx < kNumDispatchWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // Dispatch warps + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk; + const auto read_topk_idx = [&](const auto& process) { + // TODO: figure out better unrolling + // Now, `unroll` is better than `unroll 8` + #pragma unroll + for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp; + i < num_tokens; + i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) { + // Allocate slots for each token-topk + int expert_idx = -1; + if (i + (lane_idx / kNumTopk) < num_tokens and lane_idx < kNumActivateLanes) { + expert_idx = static_cast( + __ldg(input_topk_idx_buffer.get_base_ptr() + i * kNumTopk + lane_idx)); + if (expert_idx >= 0) + process(i * kNumTopk + lane_idx, expert_idx); + } + __syncwarp(); + } 
+ }; + + // Count experts' tokens + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + atomicAdd_block(smem_expert_count + expert_idx, 1); + }); + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Get SM offset (~6.5 us) + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const uint64_t send_value = (1ull << 32) | static_cast(smem_expert_count[i]); + smem_expert_count[i] = static_cast( + ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value)); + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Write source indices (~2 us with 512 tokens) + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + const auto dst_rank_idx = expert_idx / kNumExpertsPerRank; + const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1); + const auto dst_ptr = workspace.get_src_token_topk_idx_ptr( + expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx); + *sym_buffer.map(dst_ptr, dst_rank_idx) = token_topk_idx; + }); + + // Grid sync + comm::grid_sync( + workspace, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); } + ); + + // Write expert count + if (sm_idx == 0) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const auto dst_rank_idx = i / kNumExpertsPerRank; + const auto dst_local_expert_idx = i % kNumExpertsPerRank; + const auto expert_status = *workspace.get_expert_send_count_ptr(i); + *sym_buffer.map( + workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx), + dst_rank_idx) = expert_status & 0xffffffff; + ptx::atomic_add_sys( + sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx), + expert_status); + } + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Barrier before pulling + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, 
thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + /* After the grid sync above, there is no more writes by other SMs (except 0) */ false, + /* After the NVLink barrier, there is a grid sync */ true + ); + + // Ensure the epilogue barrier cannot run with the pull barrier + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Pull token data and SF from remote ranks into local L1 buffer + uint32_t pull_mbarrier_phase = 0; + const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0); + const auto pull_mbarrier = dispatch_barriers[warp_idx]; + + // Cache expert token counts in registers (same pattern as scheduler) + scheduler.fetch_expert_recv_count(); + + // Per-rank counts for current expert (re-loaded when expert changes) + constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u); + int current_expert_idx = -1; + uint32_t stored_rank_count[kNumRanksPerLane] = {}; + uint32_t expert_start_idx = 0, expert_end_idx = 0; + uint32_t expert_pool_block_offset = 0; + + constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps; + for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) { + // Advance expert until within the range + int old_expert_idx = current_expert_idx; + while (token_idx >= expert_end_idx) { + if (++ current_expert_idx >= kNumExpertsPerRank) + break; + + // Update pool block offset for the new expert + expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M); + + // Move start and end to the next expert + expert_start_idx = expert_end_idx; + expert_end_idx += scheduler.get_num_tokens(current_expert_idx); + } + + // Finish all tokens + if (current_expert_idx >= kNumExpertsPerRank) + break; + + // Load per-rank counts when expert changes + if (old_expert_idx != current_expert_idx) { + old_expert_idx = current_expert_idx; + #pragma unroll + for 
(uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t j = i * 32 + lane_idx; + // TODO: this is not coalesced + stored_rank_count[i] = j < kNumRanks ? + static_cast(*workspace.get_expert_recv_count_ptr(j, current_expert_idx)) : 0; + } + } + + // Round-robin rank selection via iterative min-peeling + uint32_t current_rank_in_expert_idx; + uint32_t remaining[kNumRanksPerLane]; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] = stored_rank_count[i]; + uint32_t offset = 0; + uint32_t token_idx_in_expert = token_idx - expert_start_idx; + uint32_t slot_idx = token_idx_in_expert; + uint32_t token_idx_in_rank; + while (true) { + // Compute active count and min across all ranks + // NOTES: reduce within each lane first, then warp-reduce once + uint32_t num_actives_in_lane = 0; + uint32_t min_in_lane = 0xffffffff; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + num_actives_in_lane += remaining[i] > 0; + if (remaining[i] > 0) + min_in_lane = cute::min(min_in_lane, remaining[i]); + } + const uint32_t num_active_ranks = __reduce_add_sync(0xffffffff, num_actives_in_lane); + const uint32_t length = __reduce_min_sync(0xffffffff, min_in_lane); + + // Hit in the current round + const uint32_t num_round_tokens = length * num_active_ranks; + if (slot_idx < num_round_tokens) { + const uint32_t slot_idx_in_round = slot_idx % num_active_ranks; + uint32_t num_seen_ranks = 0; + current_rank_in_expert_idx = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t mask = __ballot_sync(0xffffffff, remaining[i] > 0); + const uint32_t num_active_lanes = __popc(mask); + if (slot_idx_in_round >= num_seen_ranks and slot_idx_in_round < num_seen_ranks + num_active_lanes) + current_rank_in_expert_idx = i * 32 + __fns(mask, 0, slot_idx_in_round - num_seen_ranks + 1); + num_seen_ranks += num_active_lanes; + } + token_idx_in_rank = offset + (slot_idx / num_active_ranks); + break; + } + + // Move into the 
next round + slot_idx -= num_round_tokens; + offset += length; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] -= cute::min(remaining[i], length); + } + + // Read source token-topk index (written by remote dispatch via NVLink) + const uint32_t src_token_topk_idx = *workspace.get_src_token_topk_idx_ptr( + current_expert_idx, current_rank_in_expert_idx, token_idx_in_rank); + const uint32_t src_token_idx = src_token_topk_idx / kNumTopk; + const uint32_t src_topk_idx = src_token_topk_idx % kNumTopk; + + // TMA load token from remote rank into shared memory + if (cute::elect_one_sync()) { + ptx::tma_load_1d( + pull_buffer.get_base_ptr(), + sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx), + pull_mbarrier, kHidden); + } + __syncwarp(); + + // Load and store SF (overlaps with TMA token load) + constexpr uint32_t kNumSFUint32 = kHidden / 128; + DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 128 == 0, "Invalid SF"); + const auto remote_sf_ptr = sym_buffer.map( + input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx); + const auto local_sf_ptr = l1_sf_buffer.get_base_ptr(); + const auto sf_pool_token_idx = expert_pool_block_offset * SF_BLOCK_M + + transform_sf_token_idx(token_idx_in_expert); + #pragma unroll + for (uint32_t i = 0; i < math::constexpr_ceil_div(kNumSFUint32, 32u); ++ i) { + const uint32_t j = i * 32 + lane_idx; + if (j < kNumSFUint32) + local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = remote_sf_ptr[j]; + } + __syncwarp(); + + // Store weights and token data + const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert; + if (cute::elect_one_sync()) { + // Load weights + const auto weight = *sym_buffer.map( + input_topk_weights_buffer.get_base_ptr() + src_token_topk_idx, + current_rank_in_expert_idx); + *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; + 
+ // Wait for TMA token load to complete + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); + ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); + + // Store token to local L1 buffer via TMA + ptx::tma_store_1d( + l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(), + pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes()); + + // Write source metadata for combine write-back + *workspace.get_token_src_metadata_ptr(pool_token_idx) = + {current_rank_in_expert_idx, src_token_idx, src_topk_idx}; + + // Wait for token TMA store to complete + cute::tma_store_arrive(); + ptx::tma_store_wait<0>(); + ptx::red_add_rel( + workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1); + } + __syncwarp(); + } + + // Clean workspace for the next usage, and also do cumulative stats + // NOTES: it is overlapped with combine reduction epilogue + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count"); + if (sm_idx == 0) { + // SM 0: clear expert send count + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) + *workspace.get_expert_send_count_ptr(i) = 0; + } else { + // Other SMs: clean blocks + for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) { + // Read expert token count before clearing + const auto num_recv_tokens = static_cast( + *workspace.get_expert_recv_count_sum_ptr(i)); + const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M); + + // Compute expert pool block offset + expert_pool_block_offset = scheduler.get_pool_block_offset(i); + + // Wait read count ready + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Clean expert token count, and add cumulative results + DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); + if (warp_idx == 0) { + *workspace.get_expert_recv_count_sum_ptr(i) = 0; + } 
else if (warp_idx == 1) { + if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + __syncwarp(); + } + + // Clean per-rank token count + for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) + *workspace.get_expert_recv_count_ptr(j, i) = 0; + __syncwarp(); + + // Clean L1 and L2 arrival stuffs + for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { + *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0; + *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0; + } + __syncwarp(); + } + } + + // Wait for all ranks to finish cleaning + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + /* Before the NVLink barrier, there is a grid sync */ true, + /* At the end of kernel does not need to sync */ false + ); + } else if (warp_idx == kNumDispatchWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM TMA load warp for tokens with SFA + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts : &tensor_map_l1_acts; + const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; + + const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? 
L2_SHAPE_K : L1_SHAPE_K; + const auto shape_sfa_k = math::ceil_div(shape_k, kGranK * 4u); + + // Compute pool block offset for this expert + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + + // Wait the entire token arrival for linear 1 + if (block_phase == sched::BlockPhase::Linear1) { + const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); + const auto expected = scheduler.template get_valid_m(); + while (ptx::ld_acq(ptr) != expected); + } else { + // The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2x L1 blocks' arrival + // NOTES: Originally we wait blocks on-demand to overlap L1 calculation + // with L2, but this optimization is negative when `num_experts_per_wave` + // guarantees L1's completion when L2 starts. So we remove it. + // In the future, if `num_experts_per_wave` is not large enough + // due to small `num_experts_per_rank`, we may need to add it back or add a switch + DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes"); + const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); + // NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts + // to avoid undefined behavior when `num_k_blocks == 32` + const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1; + while (ptx::ld_acq_gpu(ptr) != expected); + } + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute token offset from pool block index + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M; + uint32_t sfa_k_idx = k_block_idx; + + // Add 2 CTA offsets for non-leader CTA + if (not is_leader_cta) + m_idx += scheduler.template get_valid_m() / 2; + + // TMA copy tokens and SFA, then arrive at full barrier + if (cute::elect_one_sync()) { + tma::copy( + 
tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2); + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SF_BLOCK_M * sizeof(uint32_t) * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 1) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM TMA load warp for weights with SF + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_b_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights; + const auto tensor_map_sfb_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights_sf : &tensor_map_l1_weights_sf; + + const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K; + const auto shape_n = block_phase == sched::BlockPhase::Linear2 ? 
L2_SHAPE_N : L1_SHAPE_N; + const auto shape_sfb_k = math::ceil_div(shape_k, kGranK * 4u); + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute weight offset + uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx; + + // TMA copy weights with SF + if (cute::elect_one_sync()) { + tma::copy( + tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2); + tma::copy( + tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + BLOCK_N * sizeof(uint32_t) * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 2) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM MMA issue warp (only the leader CTA will run) + if (is_leader_cta) { + // Make instruction descriptor with block scaling + // NOTES: always swap A/B + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled< + b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t, + UMMA_M, UMMA_N, + cute::UMMA::Major::K, cute::UMMA::Major::K + >(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + // Dynamic update of UMMA N based on effective M + mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m()); + + // Wait tensor memory empty barrier arrival + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + constexpr uint16_t kCTAMask = (1 << 2) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + // Launch MMAs + #pragma unroll 2 + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA load completion + full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if 
(cute::elect_one_sync()) { + // UTCCP copy SFA and SFB to TMEM + using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta; + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + + // Issue UMMA + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const auto runtime_instr_desc = + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k); + a_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K); + ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma( + b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } + } + __syncwarp(); + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_k_blocks - 1); + } + }); + + // To safely deconstruct barriers, we need another round of waits + if (current_iter_idx > 0) { + const auto accum_phase_idx = ((current_iter_idx - 1) / kNumEpilogueStages) & 1; + tmem_empty_barriers[(current_iter_idx - 1) % kNumEpilogueStages]->wait(accum_phase_idx); + } + } + } else if (warp_idx == kNumDispatchWarps + 3) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + } else if (warp_idx >= 
kNumDispatchWarps + kNumMMANonEpilogueWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // GEMM epilogue warps + const auto epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps); + const auto epilogue_wg_idx = epilogue_warp_idx / 4; + const auto epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx; + const auto warp_idx_in_wg = epilogue_warp_idx % 4; + DG_STATIC_ASSERT((kNumDispatchWarps + kNumMMANonEpilogueWarps) % 4 == 0 and + kNumEpilogueWarps % 4 == 0, "Invalid epilogue warps"); + + // TODO: support effective block M + // NOTES: + // - 2 warpgroups divide the whole BM into BM / 2 + // - 4 warps divide the whole BN into BN / 4 + // - BM / 2 is further divided into stored blocks, i.e. 
with `STORE_BLOCK_M` size + // - `STORE_BLOCK_M` in further divided into `ATOM_M` + constexpr uint32_t WG_BLOCK_M = BLOCK_M / kNumEpilogueWarpgroups; + constexpr uint32_t ATOM_M = 8; + constexpr uint32_t kNumBankGroupBytes = 16u; + constexpr uint32_t kNumAtomsPerStore = STORE_BLOCK_M / ATOM_M; + DG_STATIC_ASSERT(BLOCK_M % kNumEpilogueWarpgroups == 0, "Invalid block M"); + DG_STATIC_ASSERT(WG_BLOCK_M % STORE_BLOCK_M == 0, "Invalid warpgroup block M"); + DG_STATIC_ASSERT(STORE_BLOCK_M % ATOM_M == 0, "Invalid store block M"); + DG_STATIC_ASSERT(BLOCK_N == 128, "Invalid block N"); + + // Ensure the epilogue barrier cannot run with the pull barrier + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Persistently schedule over blocks + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + // Wait UMMA arrival + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; + tmem_full_barriers[accum_stage_idx]->wait(accum_phase); + ptx::tcgen05_after_thread_sync(); + + // Compute offsets + // NOTES: use shuffle here to let NVCC know warp divergence won't happen + const uint32_t valid_m = ptx::exchange(scheduler.template get_valid_m(), 0); + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t n_idx = n_block_idx * BLOCK_N; + + if (block_phase == sched::BlockPhase::Linear1) { + // Unified L1 epilogue: SwiGLU in-place using granularity 8 interleaved weights + // With `SM100_TMEM_LOAD_16dp256b1x`, gate/up pairs are: + // (values[0], values[2]), (values[1], values[3]), + // (values[4], values[6]), (values[5], values[7]) + float stored_cached_weight = 0; + + #pragma unroll + for 
(uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) { + // Early break if the entire store block is beyond the valid token range + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + // Iterate all atoms in the store block + float2 swiglu_values[kNumAtomsPerStore * 2]; + float2 amax_values[kNumAtomsPerStore]; + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { + const uint32_t j = s * kNumAtomsPerStore + i; + + // Load weights from global into register cache per 32 tokens + DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size"); + if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) { + stored_cached_weight = *l1_topk_weights_buffer + .get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx) + .get_base_ptr(); + } + + // Load weights from register cache + const float2 weights = { + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 0), + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 1) + }; + + // Load from TMEM + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Signal tensor memory consumed on the last atom + if (j == WG_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Apply SwiGLU: silu(gate) * up + // Gate/up pairs: (0, 2), (1, 3), (4, 6), (5, 7) + auto fp32_values = reinterpret_cast(values); + #pragma unroll + for (uint32_t k = 0; k < 2; ++ k) { + auto bf16_gate = 
__float22bfloat162_rn(make_float2(fp32_values[k * 4], fp32_values[k * 4 + 1])); + auto bf16_up = __float22bfloat162_rn(make_float2(fp32_values[k * 4 + 2], fp32_values[k * 4 + 3])); + + // Clamp + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) { + bf16_gate = __hmin2(bf16_gate, {kActivationClamp, kActivationClamp}); + bf16_up = __hmax2(bf16_up, {-kActivationClamp, -kActivationClamp}); + bf16_up = __hmin2(bf16_up, {kActivationClamp, kActivationClamp}); + } + + // SwiGLU + auto gate = __bfloat1622float2(bf16_gate); + auto neg_gate_exp = make_float2( + kFastMath ? __expf(-gate.x) : expf(-gate.x), + kFastMath ? __expf(-gate.y) : expf(-gate.y)); + const auto denom = __fadd2_rn({1.0f, 1.0f}, neg_gate_exp); + if constexpr (kFastMath) { + gate = __fmul2_rn(gate, {math::fast_rcp(denom.x), math::fast_rcp(denom.y)}); + } else { + gate = {gate.x / denom.x, gate.y / denom.y}; + } + const auto up = __bfloat1622float2(bf16_up); + swiglu_values[i * 2 + k] = __fmul2_rn(__fmul2_rn(gate, up), weights); + } + + // Amax reduction + amax_values[i].x = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].x), cute::abs(swiglu_values[i * 2 + 1].x)), + math::ReduceMax()); + amax_values[i].y = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].y), cute::abs(swiglu_values[i * 2 + 1].y)), + math::ReduceMax()); + if (lane_idx < 4) + smem_amax_reduction[epilogue_warp_idx * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx] = amax_values[i]; + __syncwarp(); + } + + // Wait shared memory release from previous TMA store + // And fence `smem_amax_reduction` + const uint32_t tma_stage_idx = s % kNumTMAStoreStages; + ptx::tma_store_wait(); + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Cast to FP8 E4M3 and store into shared memory + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { + // Reduce amax + const float2 wp_amax = + smem_amax_reduction[(epilogue_warp_idx ^ 1) * (STORE_BLOCK_M / 
2) + i * (ATOM_M / 2) + lane_idx % 4]; + amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x); + amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y); + + // Calculate SF + float2 sf, sf_inv; + math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv); + + // Cast + const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); + const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); + const auto fp8x4_values = __nv_fp8x4_e4m3(make_float4(upper.x, upper.y, lower.x, lower.y)); + + // STSM + uint32_t row = lane_idx; + uint32_t col = warp_idx_in_wg; + const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N + + i * ATOM_M * L1_OUT_BLOCK_N + + row * L1_OUT_BLOCK_N + + (col ^ (row / 2)) * kNumBankGroupBytes; + ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr); + + // Store SF to `l2_sf_buffer` as UE8M0 (MN-major layout) + // Only one warp per pair writes (both hold the same SF after cross-warp reduce) + // Each lane < 4 holds SF for 2 rows (sf.x and sf.y) + if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) { + const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2; + const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4; + const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t); + const auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + // NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4 + // NOTES: originally there was: + // - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2 + // - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)` + // We find out that + // 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside + // 2. 
`lane_idx * 2` controls the lowest 3 bit of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside + // This reduce the number of computation instructions. + const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + __builtin_assume(token_base_idx < BLOCK_M); + const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M + + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; + const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx; + sf_base_ptr[sf_addr] = + (*reinterpret_cast(&sf.x) >> 23); + sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = + (*reinterpret_cast(&sf.y) >> 23); + } + __syncwarp(); + } + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Issue TMA store after all atoms in this store block + if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { + uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N, + out_n_idx, + m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M); + cute::tma_store_arrive(); + } + __syncwarp(); + } + + // Notify L2 + // TODO: less epilogue sync scope + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(L2_SHAPE_K <= 64 * L1_OUT_BLOCK_N, "L2 shape K is too large"); + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(pool_block_idx), + 1ull << n_block_idx + ); + } + __syncwarp(); + } else { + DG_STATIC_ASSERT(STORE_BLOCK_M % 8 == 0, "Invalid store M"); + constexpr uint32_t kNumRowsPerWarp = STORE_BLOCK_M / 8; + + // L2 BF16 epilogue: write GEMM output to 
remote combine buffer via NVLink + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) { + // Early break if the entire store block is beyond the valid token range + // TODO: check performance + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / ATOM_M; ++ i) { + // Load from TMEM using .16x256b shape to satisfy STSM layout requirements + // Start from lane index 0 and 16 + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Wait shared memory release from previous NVLink store + // NOTES: skip for the first store block since the prior full barrier already ensures completion + if (i == 0 and s > 0) + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Signal tensor memory consumed + if (s == WG_BLOCK_M / STORE_BLOCK_M - 1 and i == STORE_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Store into shared memory + // NOTES: only use first 16 lanes for address + // NOTES: 2 warps share a BF16 swizzle atom + uint32_t row = lane_idx % 8; + uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8; + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (warp_idx_in_wg / 2) * STORE_BLOCK_M * kSwizzleCDMode + + i * ATOM_M * kSwizzleCDMode + + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes; + ptx::SM90_U32x4_STSM_T::copy( + 
math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]), + smem_ptr + ); + } + + // Wait shared memory ready + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Write into remote buffers + // One warp per row, now the layout is different from shared memory storing + const uint32_t row_in_atom = (warp_idx_in_wg * 2 + lane_idx / 16) % ATOM_M; + const uint32_t bank_group_idx = lane_idx % 8; + + #pragma unroll + for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) { + const uint32_t row_in_store = j * 8 + warp_idx_in_wg * 2 + lane_idx / 16; + const uint32_t m_idx_in_block = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + row_in_store; + + // Skip padding rows beyond the actual token count for this expert + if (m_idx_in_block >= valid_m) + break; + + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + + // Read from shared memory + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (lane_idx % 16 / 8) * STORE_BLOCK_M * kSwizzleCDMode + + row_in_store * kSwizzleCDMode + + (bank_group_idx ^ row_in_atom) * kNumBankGroupBytes; + const auto packed = ptx::ld_shared(reinterpret_cast(smem_ptr)); + + // Write into remote + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + const auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * static_cast(sizeof(nv_bfloat16)) + (lane_idx % 16) * static_cast(sizeof(float4))); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + } + + // Ensure the next epilogue safe to use shared memory + 
ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } + }); + + // Deallocate tensor memory + // NOTES: must be called by the same logical warp ID on both CTAs + if (epilogue_warp_idx == 0) + Allocator().free(0, kNumTmemCols); + + // NVLink barrier (grid sync + cross-rank signal + grid sync): ~4 us + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, epilogue_thread_idx, + [&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); } + ); + + // Barrier with dispatch warps, so that they can do clean workspace + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Combine: reduce top-k results and write back + // NOTES: reuse shared memory from start up to the barriers + // 1 token, 1 topk latency: ~3 us + constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); + constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162); + + // 3 slots of chunk is needed: 2 load stages and 1 store + constexpr uint32_t kNumChunkSlots = 3; + constexpr uint32_t kNumMaxRegistersForBuffer = 128; + + // NOTES: either 1 or 2 chunks for simplicity + // NOTES: Restrict on both smem and register + constexpr uint32_t kNumChunks = + kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE and kHidden <= 32 * kNumMaxRegistersForBuffer ? 
1 : 2; + constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks; + constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4); + constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32; + DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks"); + DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large"); + DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)"); + DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes"); + DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements (one per lane)"); + DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp"); + + // Verify combined shared memory budget at runtime + DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast( + reinterpret_cast(barrier_start_ptr) - smem_buffer)); + + // Per-warp buffer: 2 stage load buffers + 1 store buffer + const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) { + return math::advance_ptr(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes); + }); + const auto combine_store_buffer = math::advance_ptr(smem_buffer, (epilogue_warp_idx + kNumEpilogueWarps * 2) * kNumChunkBytes); + + // Per-warp barriers + auto combine_load_barriers = utils::PatternVisitor([&](const uint32_t& i) { + return combine_barriers[i + epilogue_warp_idx * 2]; + }); + + // Iterate over all tokens + uint32_t combine_phase = 0; + uint32_t load_stage_idx = 0; + for (uint32_t token_idx = sm_idx * kNumEpilogueWarps + epilogue_warp_idx; + token_idx < num_tokens; + token_idx += kNumSMs * kNumEpilogueWarps) { + // Read top-k slot indices: each lane reads one slot, then broadcast via exchange + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + const int stored_topk_slot_idx = 
lane_idx < kNumTopk ? + static_cast(__ldg(input_topk_idx_buffer.get_base_ptr() + token_idx * kNumTopk + lane_idx)) : -1; + const uint32_t total_mask = __ballot_sync(0xffffffff, stored_topk_slot_idx >= 0); + + // Iterate all chunks + for (uint32_t chunk = 0; chunk < kNumChunks; ++ chunk) { + const uint32_t chunk_byte_offset = chunk * kNumChunkBytes; + + // Move mask and load + uint32_t mask = total_mask; + const auto move_mask_and_load = [&](const uint32_t& i) { + if (mask) { + // Move + const uint32_t slot_idx = __ffs(mask) - 1; + mask ^= 1 << slot_idx; + + // Load + if (cute::elect_one_sync()) { + const auto src_ptr = math::advance_ptr( + combine_token_buffer.get_rank_buffer(slot_idx) + .get_data_buffer(token_idx).get_base_ptr(), + chunk_byte_offset); + ptx::tma_load_1d(combine_load_buffer[i], src_ptr, combine_load_barriers[i], kNumChunkBytes); + ptx::mbarrier_arrive_and_set_tx(combine_load_barriers[i], kNumChunkBytes); + } + __syncwarp(); + return true; + } + return false; + }; + + // Load the first selection + bool do_reduce = move_mask_and_load(load_stage_idx); + + // Accumulate all top-k contributions for this chunk in float registers + float2 reduced[kNumUint4PerLane * kNumElemsPerUint4] = {}; + while (do_reduce) { + // Prefetch next top-k into the buffer while current is being accumulated + do_reduce = move_mask_and_load(load_stage_idx ^ 1); + + // Accumulate + combine_load_barriers[load_stage_idx]->wait(combine_phase); + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + const auto uint4_values = combine_load_buffer[load_stage_idx][j * 32 + lane_idx]; + const auto bf16_values = reinterpret_cast(&uint4_values); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]); + } + combine_phase ^= load_stage_idx; + load_stage_idx ^= 1; + } + + // Cast + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + uint4 casted; + auto casted_bf16 = 
reinterpret_cast(&casted); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + casted_bf16[l] = __float22bfloat162_rn(reduced[j * kNumElemsPerUint4 + l]); + + // Wait share memory release and write + if (j == 0) { + ptx::tma_store_wait<0>(); + __syncwarp(); + } + ptx::st_shared(combine_store_buffer + j * 32 + lane_idx, + casted.x, casted.y, casted.z, casted.w); + } + __syncwarp(); + + // TMA store the token chunk + if (cute::elect_one_sync()) { + cute::tma_store_fence(); + ptx::tma_store_1d( + math::advance_ptr(y, static_cast(token_idx) * kNumHiddenBytes + chunk_byte_offset), + combine_store_buffer, kNumChunkBytes); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} // namespace deep_gemm diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 45a603add3f494aed51dce7aec53b5545bdc23f4..7ce008e5ea30ff8ad5ce65f0f3051d5f663c50df 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -155,6 +155,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + if (kNumMulticast > 1) + cute::cluster_sync(); + // Initialize barriers if (warp_idx == 1 and cute::elect_one_sync()) { #pragma unroll @@ -546,12 +549,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, } } } - - // Deallocate tensor memory by the last UMMA store warp - // NOTES: warp 0 is waiting TMA store - if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) - Allocator().free(0, 
kNumTmemCols); } + + // Deallocate tensor memory + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh index 180a308b3279b38827741942917a31e103b15b52..e6744f59ac68a5b7a681ef4ff9ad985fdb5f5e51 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -6,27 +6,31 @@ #include #include +#include +#include +#include #include -#include -#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; -using namespace deep_gemm::sm100; - template -__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, - const uint32_t max_seqlen_k, const uint64_t stride_logits, + const uint32_t max_seqlen_k, const uint32_t stride_logits, uint32_t* cu_seq_len_k_start, uint32_t* cu_seq_len_k_end, - float* logits, + logits_dtype_t* logits, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, @@ -35,26 +39,26 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64` // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` // Q should be load only at once for a block - const auto& num_q_blocks = ceil_div(seq_len, 
BLOCK_Q); + const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); // Types using Barrier = cutlass::arch::ClusterTransactionBarrier; - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const auto& warp_in_group_idx = warp_idx % 4; - const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; // Prefetch TMA descriptors DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); - if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart) { cute::prefetch_tma_descriptor(&tensor_map_q); cute::prefetch_tma_descriptor(&tensor_map_kv); cute::prefetch_tma_descriptor(&tensor_map_kv_scales); cute::prefetch_tma_descriptor(&tensor_map_weights); } - __syncwarp(); // Shared memory configs // NOTES: weight may be unaligned @@ -62,7 +66,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); // Align to 512 bytes for swizzle-64B extern __shared__ __align__(512) uint8_t smem_buffer[]; @@ -75,19 +79,19 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, 
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); // Data on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); @@ -95,76 +99,77 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // TMA barriers auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); - auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); - auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { 
return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); + auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); // Tensor memory allocation auto tmem_ptr_in_smem = reinterpret_cast(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2); // Initialize barriers DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads"); - const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32)); - const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1)); - if (is_tma_load_warp and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumQStages; ++ i) { full_q_barriers[i]->init(1); - empty_q_barriers[i]->init(kNumMathThreads); + empty_q_barriers[i]->init(kNumMathThreads + 32); } #pragma unroll for (uint32_t i = 0; i < kNumKVStages; ++ i) { full_kv_barriers[i]->init(1); empty_kv_barriers[i]->init(kNumMathThreads); } - #pragma unroll - for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { - full_umma_barriers[i]->init(1); - empty_umma_barriers[i]->init(128); - } - - // Make initialized barrier visible in async proxy cutlass::arch::fence_barrier_init(); - } else if (is_umma_warp) { + } 
+ if (warp_idx == kSpecWarpStart + 1) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } // Allocate tensor memory cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); } __syncthreads(); // Register reconfigurations - constexpr uint32_t kNumSpecializedRegisters = 24; - constexpr uint32_t kNumMathRegisters = 240; + constexpr uint32_t kNumSpecializedRegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; // Block scheduler - uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; - const auto& get_next_block_q_idx = [&]() -> cute::tuple { - return {block_q_idx + gridDim.x, q_iter_idx + 1}; + uint32_t block_q_idx = sm_idx, q_iter_idx = 0; + const auto get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + kNumSMs, q_iter_idx + 1}; }; uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; - const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { uint32_t start = cute::numeric_limits::max(); uint32_t end = cute::numeric_limits::min(); #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); - seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); - seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cu_seq_len_k_start[q_idx]; + seq_k_end[i] = cu_seq_len_k_end[q_idx]; start = min(start, min(seq_k_start[i], seq_len_kv)); end = max(end, min(seq_k_end[i], seq_len_kv)); } + // TMA alignment requirements for SF KV start = start / 4 * 4; return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase - start, ceil_div(end - start, BLOCK_KV)}; // Task 
info + start, math::ceil_div(end - start, BLOCK_KV)}; // Task info }; // KV pipeline uint32_t num_total_kv_blocks = 0; - const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { return { (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase @@ -177,13 +182,16 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; - if (is_tma_load_warp) { + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (warp_idx == kSpecWarpStart) { cutlass::arch::warpgroup_reg_dealloc(); // Prefetch - const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + const auto issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); }; if (cute::elect_one_sync() and block_q_idx < num_q_blocks) @@ -209,10 +217,10 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); // Issue TMA KV - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - 
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } num_total_kv_blocks += num_kv_blocks; @@ -221,11 +229,11 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); } } - } else if (is_umma_warp) { + } else if (warp_idx == kSpecWarpStart + 1) { cutlass::arch::warpgroup_reg_dealloc(); // Require full allocation - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // Make UMMA desc auto instr_desc = cute::UMMA::make_instr_descwait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { - auto a_desc = make_umma_desc( + auto a_desc = mma::sm100::make_umma_desc( smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); - auto b_desc = make_umma_desc( + auto b_desc = mma::sm100::make_umma_desc( smem_q[q_stage_idx], 0, k * UMMA_K); cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); } @@ -266,23 +274,37 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, } num_total_kv_blocks += num_kv_blocks; + // UMMA warp must also arrive on empty_q to prevent running ahead + // of math warps in the Q pipeline + empty_q_barriers[q_stage_idx]->arrive(); + // Jump to the next block CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); } - } else if (warp_idx >= kNumMathThreads / 32) { + } else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) { 
cutlass::arch::warpgroup_reg_dealloc(); - } else if (warp_idx < kNumMathThreads / 32) { + } else if (warp_idx < kSpecWarpStart) { cutlass::arch::warpgroup_reg_alloc(); // Offsets - const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); - const auto& warp_offset = warp_idx * 32; - const auto& v_offset = lane_idx; + const auto tmem_start = warpgroup_idx * UMMA_N; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; - // Preload weights - constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads); - float weights[BLOCK_Q][kNumWeightsInReg]; - DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + // Local register buffers + float weights[BLOCK_Q][kNumHeads]; while (block_q_idx < num_q_blocks) { CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); @@ -293,9 +315,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Read weights #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) { - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); - } + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); } // Compute over KV blocks @@ -307,82 +329,59 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - float scale_kv = 
ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset); + float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx); // Wait UMMA arrival full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); // Reduce over the head dim and store - const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; - static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + const auto kv_offset = kv_start + kv_block_idx * BLOCK_KV + math_thread_idx; DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); - constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q; - DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems"); - uint32_t shifted_accum[kNumLDTMElems]; - auto tmem_load = [&](auto... Is) { - if constexpr (kNumLDTMElems == 32) { - cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 64) { - cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 128) { - cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); - } - }; - [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); - cutlass::arch::fence_view_async_tmem_load(); - - tcgen05_before_thread_sync(); - empty_umma_barriers[warpgroup_idx]->arrive(); - #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); + // Load accumulator from TMEM + float accum[kNumHeads]; + tmem_load(cute::Int{}, tmem_start + i * kNumHeads, accum); + + // Release TMEM empty + if (i == BLOCK_Q - 1) { + ptx::tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + } + // Accumulate weighted ReLU in parallel auto sum_0 = make_float2(0, 0); auto sum_1 = 
make_float2(0, 0); - const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + const auto transform = [&](const uint32_t& j, const float2& sum) { auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); auto b = make_float2(weights[i][j], weights[i][j + 1]); return __ffma2_rn(a, b, sum); }; #pragma unroll - for (int j = 0; j < kNumWeightsInReg; j += 4) { - sum_0 = transform_reg(j, sum_0); - sum_1 = transform_reg(j + 2, sum_1); - } - - const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), - ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); - return __ffma2_rn(a, b, sum); - }; - - #pragma unroll - for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { - sum_0 = transform_smem(j, sum_0); - sum_1 = transform_smem(j + 2, sum_1); + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); } auto sum = __fadd2_rn(sum_0, sum_1); - float result = scale_kv * (sum.x + sum.y); + auto result = static_cast(scale_kv * (sum.x + sum.y)); // Store into the global memory - // NOTES: we have redundant writes here, consider more carefully - const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast(stride_logits); if constexpr (kIsCompressedLogits) { - if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i]) - logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result; + if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i]) + logits[q_offset + kv_offset - seq_k_start[i]] = result; } else { - logits[q_idx * stride_logits + kv_offset + v_offset] = result; + logits[q_offset + kv_offset] = result; } + __syncwarp(); } } num_total_kv_blocks += num_kv_blocks; @@ -393,12 +392,12 @@ void sm100_fp8_mqa_logits(const uint32_t 
seq_len, const uint32_t seq_len_kv, // Jump to the next block CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); } - } - // Free tensor memory - __syncthreads(); - if (is_tma_load_warp) - cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } } } // namespace deep_gemm diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh index 7058c40f4f195de94184d3e7ebc6f9aa2eb3670f..9a5bddbf37ef0f0ce679ef7f553ee6084b92a44c 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -6,56 +6,65 @@ #include #include +#include +#include +#include #include -#include -#include - -#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; -using namespace deep_gemm::sm100; - template -__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, - const uint64_t logits_stride, const uint64_t block_table_stride, - const uint32_t* context_lens, float* logits, - const uint32_t* block_table, const uint32_t* schedule_meta, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, const 
__grid_constant__ cute::TmaDescriptor tensor_map_weights) { using Barrier = cutlass::arch::ClusterTransactionBarrier; - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; // Prefetch TMA descriptors DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); - if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart) { cute::prefetch_tma_descriptor(&tensor_map_q); cute::prefetch_tma_descriptor(&tensor_map_kv); cute::prefetch_tma_descriptor(&tensor_map_kv_scales); cute::prefetch_tma_descriptor(&tensor_map_weights); } - __syncwarp(); + + // For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill. + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 
2 : 1; + static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom); // Shared memory configs static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; - static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float); - static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float); // Align to swizzling alignment bytes extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; @@ -63,43 +72,40 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); // Q and KV data on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i); }); constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages; - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { 
return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); // Barriers and TMEM pointer on shared memory const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; - auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); - auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); + auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); + auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); - constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups; + constexpr uint32_t kNumTmemCols = kNextNAtom * kNumHeads * kNumMathWarpGroups; DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many 
tensor memory"); - const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4); - const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4); - const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1); // Initialize barriers - if (is_tma_load_warp and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumQStages; ++ i) { full_q_barriers[i]->init(1); - empty_q_barriers[i]->init(kNumMathThreads); + empty_q_barriers[i]->init(kNumMathThreads + 32); } #pragma unroll for (uint32_t i = 0; i < kNumKVStages; ++ i) { @@ -108,7 +114,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, } cutlass::arch::fence_barrier_init(); } - if (is_umma_warp) { + if (warp_idx == kSpecWarpStart + 1) { if (cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) { @@ -123,79 +129,92 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, __syncthreads(); // Register reconfigurations - constexpr uint32_t kNumSpecializedRegisters = 40; - constexpr uint32_t kNumMathRegisters = 232; + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); // Scheduler constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; - auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + using Scheduler = sched::PagedMQALogitsScheduler; DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); // Q and KV pipeline - const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase }; - const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + const auto get_kv_pipeline = 
[=](const uint32_t& kv_iter_idx) -> cute::tuple { return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase }; - uint32_t q_iter_idx = 0, kv_iter_idx = 0; // UMMA settings // Construct instruction with layout D constexpr uint32_t UMMA_M = 128; constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); - constexpr uint32_t UMMA_N = kNextN * kNumHeads; + constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads; DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); - if (is_tma_load_warp) { - // TMA warp-group for loading data + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading data cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; - const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) { if (cute::elect_one_sync()) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); } }; - // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none - uint32_t q_idx = batch_size, kv_idx, num_kv; - uint32_t next_q_idx, next_kv_idx, next_num_kv; + // Initialize outside valid range to indicate no previous task + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx, num_kv; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; bool 
fetched_next_task; // Prefetch the first Q - if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) - issue_tma_q(0, next_q_idx), q_iter_idx = 1; + if ((fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_atom_idx), q_iter_idx = 1; - int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_ptr = 32; uint32_t kv_block_idx_storage; while (fetched_next_task) { - // Prefetch next Q when current Q changes - bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); - q_idx = next_q_idx; + // Prefetch next Q when (q, atom) changes + const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size); + bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance); + + if (q_atom_idx != next_q_atom_idx) + kv_block_idx_ptr = 32; + + q_atom_idx = next_q_atom_idx; kv_idx = next_kv_idx; num_kv = next_num_kv; // Read KV block index - // TODO: deal with `-1`? - if (kv_idx == 0 or kv_block_idx_ptr == 32) { + // TODO(xuzhean): consider -1 + if (kv_block_idx_ptr == 32) { kv_block_idx_ptr = 0; - kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0); + const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast(block_table_stride); + kv_block_idx_storage = (kv_idx + lane_idx < num_kv) + ? 
block_table[block_table_offset + kv_idx + lane_idx] : 0; } + __syncwarp(); DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`"); // Wait Q consumer release and issue TMA Q if (prefetch_q) { CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); - issue_tma_q(q_stage_idx, q_idx + 1); + issue_tma_q(q_stage_idx, q_atom_idx + next_advance); } - int kv_block_idx[kNumBlocksPerSplit]; + uint32_t kv_block_idx[kNumBlocksPerSplit]; #pragma unroll - for (int i = 0; i < kNumBlocksPerSplit; ++ i) + for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); kv_block_idx_ptr += kNumBlocksPerSplit; @@ -205,45 +224,53 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, if (cute::elect_one_sync()) { #pragma unroll - for (int i = 0; i < kNumBlocksPerSplit; ++ i) { - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, - 0, 0, 1, kv_block_idx[i]); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, - 0, kv_block_idx[i]); + for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) { + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, + 0, 0, 1, kv_block_idx[i]); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); } full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } // Fetch next task - fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv); } - } else if (is_umma_warp) { + } else if (warp_idx == kSpecWarpStart + 1) { cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, 
batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; // Require full allocation - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // Make UMMA desc auto instr_desc = cute::UMMA::make_instr_desc(); auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); - uint32_t q_idx = batch_size, kv_idx; - uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; uint32_t q_stage_idx, q_phase; uint32_t umma_phase = 1; - while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { - if (q_idx != next_q_idx) { + while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) { + if (q_atom_idx != next_q_atom_idx) { + // Release previous Q empty (UMMA warp must participate to prevent + // running ahead of math warps in the Q pipeline) + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); full_q_barriers[q_stage_idx]->wait(q_phase); } - q_idx = next_q_idx; + q_atom_idx = next_q_atom_idx; kv_idx = next_kv_idx; + // Wait KV arrival CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); full_kv_barriers[kv_stage_idx]->wait(kv_phase); @@ -251,12 +278,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { empty_umma_barriers[i]->wait(umma_phase); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { - auto a_desc = make_umma_desc( + auto a_desc = mma::sm100::make_umma_desc( smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); - auto b_desc = make_umma_desc( + auto b_desc = mma::sm100::make_umma_desc( smem_q[q_stage_idx], 0, k * UMMA_K); 
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); } @@ -264,29 +291,46 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, } umma_phase ^= 1; } - } else if (is_math_warp) { - // Math warp-groups for WGMMA + } else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce cutlass::arch::warpgroup_reg_alloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; // Offsets - const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); - const uint32_t thread_idx = threadIdx.x; + const auto math_warpgroup_idx = warpgroup_idx; + const auto tmem_start = math_warpgroup_idx * UMMA_N; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; - // Weights - constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? 
kNumHeads : cute::min(48, kNumHeads)); - float weights[kNextN][kNumWeightsInReg]; - DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + // Local register buffers + float weights[kNextNAtom][kNumHeads]; - // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none - uint32_t q_idx = batch_size, kv_idx; - uint32_t next_q_idx, next_kv_idx, next_num_kv; + // Initialize outside valid range to indicate no previous task + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; uint32_t q_stage_idx, q_phase; uint32_t umma_phase = 0; + bool is_paired_atom = false; - while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { - // Current Q changes - if (q_idx != next_q_idx) { - // Release Last Q empty + while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) { + // Q or atom changes + if (q_atom_idx != next_q_atom_idx) { + // Release last Q empty if (q_iter_idx > 0) empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); @@ -296,30 +340,34 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Read weights #pragma unroll - for (uint32_t i = 0; i < kNextN; ++ i) { - for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + for (uint32_t i = 0; i < kNextNAtom; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + + if constexpr (kIsVarlen) { + is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2); } } - // Get current Q and KV index - q_idx = next_q_idx; + // Get current task indices + q_atom_idx = next_q_atom_idx; kv_idx = next_kv_idx; // Calculate KV offset in advance - auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV; + auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast(logits_stride) + 
kv_idx * BLOCK_KV; - // Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]` // Wait TMA KV arrival CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx); + float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx); // Wait UMMA arrival - full_umma_barriers[warpgroup_idx]->wait(umma_phase); - tcgen05_after_thread_sync(); + full_umma_barriers[math_warpgroup_idx]->wait(umma_phase); + ptx::tcgen05_after_thread_sync(); umma_phase ^= 1; // Release KV empty @@ -327,72 +375,65 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Reduce over the head dim and store DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); - constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN; - uint32_t shifted_accum[kNumLDTMElems]; - DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM"); - auto tmem_load = [&](auto... 
Is) { - if constexpr (kNumLDTMElems == 32) { - cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 64) { - cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 128) { - cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); - } - }; - [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); - cutlass::arch::fence_view_async_tmem_load(); - - tcgen05_before_thread_sync(); - empty_umma_barriers[warpgroup_idx]->arrive(); - - #pragma unroll - for (uint32_t i = 0; i < kNextN; ++ i) { - auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); - - auto sum_0 = make_float2(0, 0); - auto sum_1 = make_float2(0, 0); - const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(weights[i][j], weights[i][j + 1]); - return __ffma2_rn(a, b, sum); - }; + const auto reduce_and_store = [&](auto num_iters_c) { + constexpr uint32_t kNumIters = decltype(num_iters_c)::value; + float accum[kNumHeads]; #pragma unroll - for (int j = 0; j < kNumWeightsInReg; j += 4) { - sum_0 = transform_reg(j, sum_0); - sum_1 = transform_reg(j + 2, sum_1); + for (uint32_t i = 0; i < kNumIters; ++ i) { + // Load accumulator from TMEM + tmem_load(cute::Int{}, tmem_start + i * kNumHeads, accum); + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = 
static_cast(scale_kv * (sum.x + sum.y)); + + // Store into the global memory + logits[kv_offset + i * static_cast(logits_stride) + math_thread_idx] = result; + __syncwarp(); } - const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), - ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); - return __ffma2_rn(a, b, sum); - }; - - #pragma unroll - for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { - sum_0 = transform_smem(j, sum_0); - sum_1 = transform_smem(j + 2, sum_1); - } - - auto sum = __fadd2_rn(sum_0, sum_1); - float result = scale_kv * (sum.x + sum.y); + // Release TMEM empty + ptx::tcgen05_before_thread_sync(); + empty_umma_barriers[math_warpgroup_idx]->arrive(); + }; - // Store into the global memory - // NOTES: we have redundant writes here, consider more carefully - logits[kv_offset + i * logits_stride + thread_idx] = result; + if constexpr (kIsVarlen) { + if (is_paired_atom) + reduce_and_store(cute::Int{}); + else + reduce_and_store(cute::Int<1>{}); + } else if constexpr (kPadOddN) { + if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1) + reduce_and_store(cute::Int<1>{}); + else + reduce_and_store(cute::Int{}); + } else { + reduce_and_store(cute::Int{}); } } - } else { - cutlass::arch::warpgroup_reg_dealloc(); - } - // Free tensor memory - __syncthreads(); - if (is_umma_warp) - cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } } } // namespace deep_gemm diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh index 
4e4ff21d0746cff7bc7ecaf23a49278a2f5810cc..aaf7fd9aea773fc66a696f5c9382b8b0e53e263d 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh @@ -4,20 +4,22 @@ #include -#include +#include +#include +#include #include -#include -#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm100; - template -__device__ __forceinline__ +CUTLASS_DEVICE uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) { // Calculate the index of the bank group to be written in the atom - const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); + const auto bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); // Reshape the atom in another view and swizzle // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)` @@ -37,7 +39,7 @@ template -__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -58,7 +60,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Utils const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); // Align to 1024 bytes for swizzle-128B extern __shared__ __align__(1024) uint8_t smem_buffer[]; @@ -70,7 +72,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); // Real tensor memory size and offsets - constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); // 
Prefetch TMA descriptors at the very beginning if (warp_idx == 0 and cute::elect_one_sync()) { @@ -82,20 +84,20 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Data on shared memory (layout as ordered below) // Fill D/A/B pointers auto smem_cd = reinterpret_cast(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); - auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto empty_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4; // Fill the tensor memory pointer @@ -121,7 +123,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, 
} __syncthreads(); - constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K); constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); @@ -131,6 +133,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const uint32_t m_offset = shape_m * k_split_idx; const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Dispatch warps into different roles if (warp_idx < kNumMMAThreads / 32) { // TMA load warp @@ -145,8 +150,8 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, uint32_t k_idx = k_offset + s * BLOCK_K; // Issue TMAs - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); // Arrive at full barriers constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; @@ -168,7 +173,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); const uint32_t& b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; // Checks for MMA instructions @@ -185,7 +190,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const auto& stage_idx = s % kNumStages; const auto& cast_stage_idx = s % kNumCastStages; full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Issue UMMA const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); @@ -194,7 +199,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK; const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK; const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK; - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); + b_desc.lo = mma::sm100::advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc); } @@ -218,7 +223,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Wait UMMA arrival tmem_full_barrier->wait(0); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Load from tensor memory into registers, and write shared memory with STSM DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough"); @@ -239,7 +244,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, values[0], values[1], values[2], values[3]); cutlass::arch::fence_view_async_tmem_load(); if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16)) - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); if constexpr (BLOCK_M == 64) __syncwarp(); } @@ -290,9 +295,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, #pragma unroll for (uint32_t i = 0; i < kNumLoads; i += 2) { auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset(i + lane_idx / 16, 
lane_idx % 16); - sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], - uint32_values[0][i + 1], uint32_values[1][i + 1], - smem_ptr); + ptx::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], + uint32_values[0][i + 1], uint32_values[1][i + 1], + smem_ptr); } // Wait tensor memory empty @@ -321,15 +326,15 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, cutlass::arch::fence_view_async_tmem_store(); // Arrive for issuing MMAs - tcgen05_before_thread_sync(); + ptx::tcgen05_before_thread_sync(); full_cast_barriers[cast_stage_idx]->arrive(); } // Intra-warp reduction and write back #pragma unroll for (uint32_t u = 0; u < 2; ++ u) { - const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y); - const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; + const auto reduced_sum = math::warp_reduce_sum<4>(sum[u].x + sum[u].y); + const auto m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; if (lane_idx % 4 == 0 and m_idx < shape_m) sqr_sum[m_offset + m_idx] = reduced_sum; } diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh index 7a77e4e8fbbbffa56e8c8632ade7ae7938b30ee9..84a149eb9b6b35a907f03b4c04434ee9f8e558ee 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -11,14 +11,19 @@ #include #include +#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_bf16_gemm_impl(int* grouped_layout, 
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, @@ -51,7 +56,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; // Types - using WGMMA = typename BF16MMASelector::type; + using WGMMA = typename mma::sm90::BF16MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); @@ -61,7 +66,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; // Shared memory - static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u); + static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); @@ -71,7 +76,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Configs const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_idx(); + const uint32_t lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors at the very beginning if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { @@ -88,17 +93,17 @@ sm90_bf16_gemm_impl(int* grouped_layout, // D/A/B shared memory auto smem_d = reinterpret_cast(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); }); // Fill barriers auto 
barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); // Initialize barriers if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { @@ -119,9 +124,12 @@ sm90_bf16_gemm_impl(int* grouped_layout, constexpr uint32_t kNumTMARegisters = 48; constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); // Pipeline and TMA phases uint32_t stage_idx = 0, phase = 0; @@ -151,7 +159,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? 
kNumTMAMulticast : 1; DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait consumer release empty_barriers[stage_idx]->wait(phase ^ 1); @@ -159,31 +167,30 @@ sm90_bf16_gemm_impl(int* grouped_layout, constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; auto& full_barrier = *full_barriers[stage_idx]; - const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx); - const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); + const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx); + const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); - uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); - uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); // Issue TMAs constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); const uint32_t batch_idx = (kIsBatchedMM ? 
scheduler.current_group_idx : 0); if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx); if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx); - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); } } @@ -203,8 +210,8 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Merged stages only happens in NT normal GEMM cases constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; - auto a_desc = make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0); - auto b_desc = make_gmma_desc(smem_b[0], 0, 0); + auto a_desc = mma::sm90::make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0); + auto b_desc = mma::sm90::make_gmma_desc(smem_b[0], 0, 0); const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); @@ -229,10 +236,10 @@ sm90_bf16_gemm_impl(int* grouped_layout, }; // TODO: remove some useless computation for unaligned Ms - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { - const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); - const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + const auto a_desc_base_lo = a_desc_lo + stage_idx 
* (SMEM_A_SIZE_PER_STAGE / 16); + const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); // Wait TMA arrivals full_barriers[stage_idx]->wait(phase); @@ -240,26 +247,26 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Commit WGMMA instructions #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; #pragma unroll for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K; - a_desc.reg32_[0] = advance_gmma_desc_lo( + const uint32_t atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K; + a_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo( a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K); - b_desc.reg32_[0] = advance_gmma_desc_lo( + b_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo( b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K); WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1); } } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Notify barrier arrival empty_barrier_arrive(stage_idx); @@ -324,7 +331,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, } // NOTES: only 16 lanes' addresses are used - SM90_U32x2_STSM_N::copy( + ptx::SM90_U32x2_STSM_N::copy( __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), smem_ptr @@ -341,8 +348,8 @@ sm90_bf16_gemm_impl(int* 
grouped_layout, auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); - st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); + ptx::st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); + ptx::st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); } } } @@ -350,7 +357,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); // Use TMA store to write back to global memory - const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh index 191a4fe2c4ccf66b0743affedcbfd17950e2618f..7c344296519e7dd0852a8940d3e9d714b12a5646 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh @@ -4,26 +4,32 @@ #include #include +#include #include -#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + 
kNumMathThreads, 1) void +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, float *d) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Types - using WGMMA = typename BF16MMASelector::type; + using WGMMA = typename mma::sm90::BF16MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); @@ -33,7 +39,7 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, // Configs const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_idx(); + const uint32_t lane_idx = ptx::get_lane_idx(); DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M"); DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads"); DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads"); @@ -48,17 +54,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, // Align to 1024 bytes for swizzle-128B // Fill shared memory pointers extern __shared__ __align__(1024) uint8_t smem_buffer[]; - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { 
return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); // Initialize barriers if (warp_idx == 1 and cute::elect_one_sync()) { @@ -80,14 +86,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, constexpr uint32_t kNumMathRegisters = 232; // Block indices - const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); - const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M); const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; const uint32_t n_block_idx = mn_block_idx % num_n_blocks; const uint32_t m_block_idx = mn_block_idx / num_n_blocks; const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (warp_idx >= kNumMathThreads / 32) { // TMA warp-group for loading data cutlass::arch::warpgroup_reg_dealloc(); @@ -98,18 +107,18 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, #pragma unroll for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait consumer release - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1); auto& full_barrier = *full_barriers[stage_idx]; - const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; - const uint32_t& k_idx = sk_idx % SHAPE_K; - const uint32_t& s_idx = sk_idx / SHAPE_K; + const uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + const uint32_t k_idx = sk_idx % SHAPE_K; + const uint32_t s_idx = sk_idx / SHAPE_K; constexpr uint32_t 
kSwizzle = BLOCK_K * sizeof(nv_bfloat16); - tma_copy( + tma::copy( &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1); - tma_copy( + tma::copy( &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); } @@ -125,32 +134,32 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, // Launch MMAs for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait TMA arrivals - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; full_barriers[stage_idx]->wait((s / kNumStages) & 1); // Commit WGMMA instructions #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); WGMMA::wgmma(desc_a, desc_b, accum, 1); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Notify barrier arrival at the last warpgroup wave empty_barriers[stage_idx]->arrive(); } - const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4; - const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2; + const auto row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4; + const auto col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2; 
#pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { if (col + i * 8 >= SHAPE_N) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh index cdd28fcb59d3b038c84c007ef1da1477d7ca263a..195d431f9067abcd94ce3c27e1ea7bf60ada7224 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -6,18 +6,26 @@ #include #include +#include #include #include #include +#include +#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, int* grouped_layout, cute::TmaDescriptor* tensor_map_buffer, @@ -45,7 +53,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type"); // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); @@ -55,13 +63,13 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; // Shared memory - static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0); + static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? 
sizeof(cute::TmaDescriptor) * 2 : 0); static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); + static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment"); // Configs @@ -83,47 +91,41 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); // Tensor maps on shared and global memory - auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * i); - }); - auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * (2 + i)); - }); - auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; }); - auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; }); + auto smem_tensor_map_a = reinterpret_cast(smem_buffer); + auto smem_tensor_map_b = smem_tensor_map_a + 1; + auto gmem_tensor_map_a = tensor_map_buffer + blockIdx.x * 2; + auto gmem_tensor_map_b = gmem_tensor_map_a + 1; // Data on shared memory auto smem_d = reinterpret_cast(smem_buffer + SMEM_TENSOR_MAP_SIZE); - auto smem_a = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + 
(SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); - auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE)); }); - auto smem_sfb = PatternVisitor([&](const uint32_t& i) { + auto smem_sfb = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE)); }); // Barriers on shared memory constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE); - auto full_barriers = PatternVisitor([&](const uint32_t& i) { + auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast(sizeof(Barrier)))); }); - auto empty_barriers = PatternVisitor([&](const uint32_t& i) { + auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast(sizeof(Barrier)))); }); if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { // Load tensormap A/B to shared memory if constexpr (kGemmType == GemmType::KGroupedContiguous) { - *smem_tensor_map_a[0] = tensor_map_a_base; 
- *smem_tensor_map_a[1] = tensor_map_a_base; - *smem_tensor_map_b[0] = tensor_map_b_base; - *smem_tensor_map_b[1] = tensor_map_b_base; + *smem_tensor_map_a = tensor_map_a_base; + *smem_tensor_map_b = tensor_map_b_base; } // Initialize barriers @@ -149,12 +151,15 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24); constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); // TMA and MMA pipeline - const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { + const auto get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase }; uint32_t iter_idx = 0; @@ -165,9 +170,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // NOTES: only one thread (or warp) will be used if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { - const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base; - const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base; - uint32_t last_group_idx = kNumGroups, sum_k = 0; + uint32_t last_group_idx = kNumGroups; // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { @@ -177,35 +180,27 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? 
kNumTMAMulticast : 1; DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - - const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); - const uint32_t& m_idx = m_block_idx * BLOCK_M; - const uint32_t& n_idx = n_block_idx * BLOCK_N; - - if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) { - const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1; - const uint32_t& next_stage_idx = stage_idx ^ 1; + + const uint32_t num_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + const uint32_t m_idx = m_block_idx * BLOCK_M; + const uint32_t n_idx = n_block_idx * BLOCK_N; + + if (kGemmType == GemmType::KGroupedContiguous && last_group_idx != scheduler.current_group_idx) { last_group_idx = scheduler.current_group_idx; - // Prepare next tensor map - sum_k += scheduler.current_shape_k; - if (scheduler.next_group_idx < kNumGroups) { - tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast(sum_k) * shape_m); - tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast(sum_k) * shape_n); - tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); - tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); - *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]); - *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]); - tensor_map_release_cta(); - } - - // Get current tensor map - if (scheduler.current_num_valid_groups > 0) { - tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]); - tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]); - current_tensor_map_a = gmem_tensor_map_a[stage_idx]; - current_tensor_map_b = gmem_tensor_map_b[stage_idx]; - } + // Directly update current 
tensor map + const uint64_t current_k_offset = scheduler.current_k_cumsum; + ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_a, gmem_a_ptr + current_k_offset * shape_m); + ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_b, gmem_b_ptr + current_k_offset * shape_n); + ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a, scheduler.current_shape_k, scheduler.current_shape_k); + ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b, scheduler.current_shape_k, scheduler.current_shape_k); + *(gmem_tensor_map_a) = *(smem_tensor_map_a); + *(gmem_tensor_map_b) = *(smem_tensor_map_b); + ptx::tensor_map_release_gpu(); + + // Immediately acquire current tensor map + ptx::tensor_map_acquire_gpu(gmem_tensor_map_a); + ptx::tensor_map_acquire_gpu(gmem_tensor_map_b); } #pragma unroll kNumPipelineUnrolls @@ -216,12 +211,14 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // Issue TMA auto& full_barrier = *full_barriers[stage_idx]; - const uint32_t& k_idx = k_block_idx * BLOCK_K; - const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; - tma_copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); - tma_copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); - tma_copy(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); - tma_copy(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); + const uint32_t k_idx = k_block_idx * BLOCK_K; + const uint32_t sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; + const auto tensor_map_a_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_a : &tensor_map_a_base); + const auto tensor_map_b_ptr = (kGemmType == GemmType::KGroupedContiguous ? 
gmem_tensor_map_b : &tensor_map_b_base); + tma::copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); + tma::copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); + tma::copy(tensor_map_a_ptr, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); + tma::copy(tensor_map_b_ptr, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE); } } @@ -248,9 +245,9 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, while (scheduler.get_next_block(m_block_idx, n_block_idx)) { // Accumulation for WGMMA or CUDA promotion DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes"); - const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); - const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0); - const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K); + const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? 
scheduler.current_group_idx : 0); + const uint32_t num_k_blocks = math::ceil_div(current_shape_k, BLOCK_K); float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; float2 scales_b[WGMMA::kNumAccum / 4]; @@ -272,30 +269,30 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // Read A scales // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0); - auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1); + auto scale_a_0 = ptx::ld_shared(smem_sfa[stage_idx] + r_0); + auto scale_a_1 = ptx::ld_shared(smem_sfa[stage_idx] + r_1); // Read B scales #pragma unroll for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) - scales_b[i] = ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2)); + scales_b[i] = ptx::ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2)); // Commit WGMMA instructions #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); WGMMA::wgmma(desc_a, desc_b, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Notify barrier arrival empty_barrier_arrive(stage_idx); @@ -318,12 +315,12 @@ 
sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, cutlass::arch::NamedBarrier::sync(128, math_wg_idx); // Store to D shared memory - const auto& smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); - const auto& smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); + const auto smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); + const auto smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); #pragma unroll for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); - st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); + ptx::st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); + ptx::st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); } cute::tma_store_fence(); cutlass::arch::NamedBarrier::sync(128, math_wg_idx); diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 9247304cdd17d8e2c3a5cdb31c78c191ae6b76ec..aa412484debb328df8f4f4d0d7cdfc1c61ec7b69 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -10,17 +10,21 @@ #include #include -#include +#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) { +CUTLASS_DEVICE void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) { if (num_former_iters == kNumFormerIters) { func(cute::Int{}); return; @@ -35,12 +39,12 @@ template -__global__ __launch_bounds__(kNumTMAThreads + 
kNumMathThreads, 1) void +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, @@ -50,10 +54,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + DG_STATIC_ASSERT( + math::constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or + (math::constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); @@ -64,23 +70,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Shared memory static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); - static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); + static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); - const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); - const uint32_t& 
shape_n_sfb = ceil_div(shape_n, BLOCK_K); - const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); + static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); + const uint32_t shape_k_scales = math::ceil_div(shape_k, BLOCK_K); + const uint32_t shape_n_sfb = math::ceil_div(shape_n, BLOCK_K); + const uint32_t smem_sfb_size = math::align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); // NOTES: Make sure we have enough shared memory for WGMMA padding static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); // Configs - const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K); + const uint32_t num_total_k_blocks = math::ceil_div(shape_k, BLOCK_K); const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_idx(); + const uint32_t lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors at the very beginning if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { @@ -97,22 +103,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Data on shared memory auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); }); constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * 
(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); - auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); }); auto smem_sfb = reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); // Fill barriers auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); - auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); - auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); + auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); + auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); // Initialize barriers DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); @@ -136,9 +142,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, constexpr uint32_t kNumTMARegisters = 40; constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 
248 : 232; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); // Pipeline and TMA phases uint32_t stage_idx = 0, phase = 0; @@ -177,15 +186,15 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; auto& full_barrier = *full_barriers[stage_idx]; const uint32_t k_idx = k_block_idx * BLOCK_K; - tma_copy(&tensor_map_a, &full_barrier, + tma::copy(&tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), num_tma_multicast_a, batch_idx); - tma_copy(&tensor_map_sfa, &full_barrier, - smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), + tma::copy(&tensor_map_sfa, &full_barrier, + smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), num_tma_multicast_a); // Issue TMA B - tma_copy(&tensor_map_b, &full_barrier, + tma::copy(&tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), num_tma_multicast_b, batch_idx); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); @@ -206,8 +215,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; - auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); - auto b_desc = make_smem_desc(smem_b[0], 1); + auto a_desc = mma::sm90::make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); + auto b_desc = mma::sm90::make_smem_desc(smem_b[0], 1); const uint32_t a_desc_lo = 
__shfl_sync(0xffffffff, a_desc.reg32_[0], 0); const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); @@ -225,14 +234,14 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Load B scales with math warp-groups // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks if (threadIdx.x >= 32) { - auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); + auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales; const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1; auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb; #pragma unroll for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) - st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb)); + ptx::st_shared(smem_sfb + i, i < shape_k_scales ? local_sfb[i * stride_k_sfb] : local_sfb[(i - shape_k_scales) * stride_k_sfb + stride_n_sfb]); } cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); @@ -259,22 +268,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Skip useless computations if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) { // The compiler must know the dynamic variable `num_former_iters`'s real value - constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; - constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr bool kShouldOptimize = BLOCK_K / math::constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr uint32_t kGap = math::constexpr_gcd(BLOCK_K, BLOCK_N) / 8; constexpr uint32_t kEnd = kShouldOptimize ? 
BLOCK_K / 8 : 0; // Dispatch `num_former_iters` and launch MMAs dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) { #pragma unroll 8 for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { - const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); - const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); // Read B scales - float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1; + float scale_b_0 = ptx::ld_shared(smem_sfb + k_block_idx), scale_b_1; // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales); + scale_b_1 = ptx::ld_shared(smem_sfb + k_block_idx + shape_k_scales); // Wait TMA arrivals full_barriers[stage_idx]->wait(phase); @@ -286,25 +295,25 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Read A scales // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0; - auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0; + auto scale_a_0 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0; + auto scale_a_1 = do_wgmma_store ? 
ptx::ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0; // Commit WGMMA instructions #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16; b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16; WGMMA::wgmma(a_desc, b_desc, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Notify barrier arrival at the last warpgroup wave if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) @@ -325,7 +334,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant - const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters; + const bool predicate = kMustUseUniformedScaleB or i < num_former_iters; shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; shifted_accum[i * 4 + 2] += (predicate ? 
scale_1_0 : scale_1_1) * accum[i * 4 + 2]; @@ -399,7 +408,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, } // NOTES: only 16 lanes' addresses are used - SM90_U32x2_STSM_N::copy( + ptx::SM90_U32x2_STSM_N::copy( __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), smem_ptr diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh index d58c716242a09922157aa13e16cb8afac477904c..225af4416810b2680317d3713c372807e548f464 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh @@ -7,36 +7,31 @@ #include #include +#include +#include #include -#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - -// ReSharper disable once CppNotAllPathsReturnValue -template -static constexpr int to_swizzle_cute_type() { - DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling"); - if constexpr (kHeadDim == 32) - return static_cast(cute::SM90::GMMA::LayoutType::B32); - if constexpr (kHeadDim == 64) - return static_cast(cute::SM90::GMMA::LayoutType::B64); - if constexpr (kHeadDim == 128) - return static_cast(cute::SM90::GMMA::LayoutType::B128); -} - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) + uint32_t kNumSMs, + uint32_t kNumTMAThreads, uint32_t kNumMathThreads, + typename logits_dtype_t> +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, - const uint32_t max_seqlen_k, const uint64_t stride_logits, + const uint32_t max_seqlen_k, const uint32_t stride_logits, uint32_t* 
cu_seq_len_k_start, uint32_t* cu_seq_len_k_end, - float* logits, + logits_dtype_t* logits, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, @@ -44,10 +39,10 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // TODO: consider TMA multicast // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` // Q should be load only at once for a block - const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; // Prefetch TMA descriptors @@ -74,19 +69,19 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); // Data on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return 
reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i); @@ -94,13 +89,13 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // TMA barriers auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); // Initialize barriers - const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32; + const bool is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32; if (is_tma_load_warp and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumQStages; ++ i) { @@ -123,38 +118,43 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, constexpr uint32_t kNumMathRegisters = 112; // Block scheduler - uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; - const auto& get_next_block_q_idx = [&]() -> cute::tuple { - return {block_q_idx + gridDim.x, q_iter_idx + 1}; + const auto sm_idx = blockIdx.x; + 
uint32_t block_q_idx = sm_idx, q_iter_idx = 0; + const auto get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + kNumSMs, q_iter_idx + 1}; }; uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; - const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { uint32_t start = cute::numeric_limits::max(); uint32_t end = cute::numeric_limits::min(); #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); - seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); - seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cu_seq_len_k_start[q_idx]; + seq_k_end[i] = cu_seq_len_k_end[q_idx]; start = min(start, min(seq_k_start[i], seq_len_kv)); end = max(end, min(seq_k_end[i], seq_len_kv)); } + // TMA alignment requirements for SF KV start = start / 4 * 4; return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase - start, ceil_div(end - start, BLOCK_KV)}; // Task info + start, math::ceil_div(end - start, BLOCK_KV)}; // Task info }; // KV pipeline uint32_t num_total_kv_blocks = 0; - const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { return { (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase }; }; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (threadIdx.x >= kNumMathThreads) { // TMA warp-group for loading data cutlass::arch::warpgroup_reg_dealloc(); @@ -165,8 +165,8 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Prefetch const auto& issue_tma_q = [&](const 
uint32_t& stage_idx, const auto& block_idx) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); }; if (cute::elect_one_sync() and block_q_idx < num_q_blocks) @@ -192,9 +192,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); // Issue TMA KV - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } @@ -212,7 +212,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const auto& thread_idx = threadIdx.x % kNumMathThreads; const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0); const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + const auto& lane_idx = ptx::get_lane_idx(); float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4]; const auto& warp_offset = warp_idx * 16; @@ -230,7 +230,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, for (uint32_t i = 0; i < BLOCK_Q; ++ i) { #pragma unroll for (uint32_t j = 0; j < kNumHeads / 4; ++ j) - weights[i][j] = 
ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); } // Compute over KV blocks @@ -242,29 +242,31 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset); - float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset); + float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset); + float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset); // Issue WGMMA DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size"); DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K, - to_swizzle_cute_type(), 0, kHeadDim * 8); - auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, - to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_a = mma::sm90::make_smem_desc( + smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = mma::sm90::make_smem_desc( + smem_q[q_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); WGMMA::wgmma(desc_a, desc_b, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - 
warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); @@ -278,7 +280,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { auto shifted_accum = accum + i * kNumAccumPerReduce; - const auto& transform = [&](const uint32_t& j) { + const auto transform = [&](const uint32_t& j) { return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; }; @@ -302,16 +304,15 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, } // Store into the global memory - // NOTES: we have redundant writes here, consider more carefully - const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast(stride_logits); if constexpr (kIsCompressedLogits) { if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i]) - logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0; + logits[q_offset + kv_offset + v_0_offset - seq_k_start[i]] = static_cast(v_0); if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i]) - logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1; + logits[q_offset + kv_offset + v_1_offset - seq_k_start[i]] = static_cast(v_1); } else { - logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0; - logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1; + logits[q_offset + kv_offset + v_0_offset] = static_cast(v_0); + logits[q_offset + kv_offset + v_1_offset] = static_cast(v_1); } } } diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh index 482a85a80fce29aa949b464070b0b20fb55ae030..cc2592bb402af88f3d7c7b841f26e1961093c8a3 100644 --- 
a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -6,133 +6,46 @@ #include #include +#include +#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -template -__global__ __launch_bounds__(32, 1) -void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, - const uint32_t* context_lens, uint32_t* schedule_metadata) { - DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); - const uint32_t lane_idx = get_lane_idx(); - - uint32_t num_segs[kAlignedBatchSize / 32]; - #pragma unroll - for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { - const uint32_t q_idx = k * 32 + lane_idx; - const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); - const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0); - num_segs[k] = ceil_div(context_len, SPLIT_KV); - } - - __shared__ uint32_t prefix_sum[kAlignedBatchSize]; - uint32_t sum = 0; - #pragma unroll - for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { - uint32_t x = num_segs[k]; - #pragma unroll - for (uint32_t offset = 1; offset < 32; offset <<= 1) { - const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset); - x += (lane_idx >= offset ? y : 0); - } - x += sum; - prefix_sum[k * 32 + lane_idx] = x; - sum = __shfl_sync(0xffffffff, x, 31); - } - - const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs; - for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { - uint32_t seg_starts = sm_idx * q + min(sm_idx, r); - uint32_t q_idx = 0; - while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts) - ++ q_idx; - const uint32_t& kv_split_idx = (q_idx == 0 ? 
seg_starts : seg_starts - prefix_sum[q_idx - 1]); - __syncwarp(); - - schedule_metadata[sm_idx * 2] = q_idx; - schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; - } -} - -template -struct PagedMQALogitsScheduler { - uint32_t batch_size; - const uint32_t* context_lens; - - uint32_t current_q_idx, current_kv_idx; - uint32_t end_q_idx, end_kv_idx; - uint32_t current_num_kv; - - __device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) { - const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); - return q_idx < batch_size ? ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0; - } - - __device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx, - const uint32_t* context_lens, const uint32_t* schedule_meta) { - this->batch_size = batch_size; - this->context_lens = context_lens; - - const auto& current_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx); - const auto& end_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx + 1); - current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; - end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; - - current_num_kv = get_num_kv(current_q_idx); - } - - __device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) { - q_idx = current_q_idx; - kv_idx = current_kv_idx; - num_kv = current_num_kv; - - if (q_idx == end_q_idx and kv_idx == end_kv_idx) - return false; - - current_kv_idx += kNumBlocksPerSplit; - if (current_kv_idx >= current_num_kv) { - ++ current_q_idx; - current_kv_idx = 0; - current_num_kv = get_num_kv(current_q_idx); - } - - return true; - } - - __device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const { - return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx; - } -}; - -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) + uint32_t kNumTMAThreads, 
uint32_t kNumMathThreads, + typename logits_dtype_t> +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, - const uint64_t logits_stride, const uint64_t block_table_stride, - const uint32_t* context_lens, float* logits, - const uint32_t* block_table, const uint32_t* schedule_meta, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + DG_STATIC_ASSERT(not kIsVarlen, "Varlen is not supported for SM90 paged MQA logits"); + // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + const auto warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; @@ -150,15 +63,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = 
constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = math::constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + - constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); + math::constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + - constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); + math::constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); // Align to swizzling alignment bytes extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; @@ -166,31 +79,31 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); // Q data and barriers on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); }); auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); - auto 
full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); // Separate math warpgroups and tma load warps into KV groups // Each math warpgroup corresponds to a tma load warp - const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); + const auto kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); // Per group KV data and barriers on shared memory - const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + const auto smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); }); auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return 
kv_barrier_ptr + kNumKVStages + i; }); // Initialize barriers if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) { @@ -218,15 +131,19 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, constexpr uint32_t kNumTMARegisters = 64; constexpr uint32_t kNumMathRegisters = 104; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Scheduler - auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + auto scheduler = sched::PagedMQALogitsScheduler( + blockIdx.x, batch_size, context_lens, schedule_meta, indices); DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); // Q and KV pipeline - const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase }; - const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase }; uint32_t q_iter_idx = 0, kv_iter_idx = 0; @@ -237,10 +154,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, if (kv_group_idx >= kNumMathWarpGroups) return; - const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { if (kv_group_idx == 0 and cute::elect_one_sync()) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, 
q_idx * kNextN); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); } }; @@ -259,7 +176,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, while (fetched_next_task) { // Prefetch next Q when current Q changes - bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_atom_idx(next_q_idx + 1)); q_idx = next_q_idx; kv_idx = next_kv_idx; num_kv = next_num_kv; @@ -276,9 +193,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, if (kv_idx == 0 or kv_block_idx_ptr == 32) { kv_block_idx_ptr = 0; kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ? - __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); + block_table[q_idx * static_cast(block_table_stride) + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)] : 0); } - const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + const auto kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); // Wait KV consumer release CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); @@ -286,10 +203,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, // Issue TMA KV if (cute::elect_one_sync()) { - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], 0, kv_block_idx); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } @@ -301,9 +218,9 @@ void 
sm90_fp8_paged_mqa_logits(const uint32_t batch_size, cutlass::arch::warpgroup_reg_alloc(); float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4]; - const auto& sub_warp_offset = (warp_idx % 4) * 16; - const auto& v_0_offset = lane_idx / 4 + 0; - const auto& v_1_offset = lane_idx / 4 + 8; + const auto sub_warp_offset = (warp_idx % 4) * 16; + const auto v_0_offset = lane_idx / 4 + 0; + const auto v_1_offset = lane_idx / 4 + 8; // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none uint32_t q_idx = batch_size, kv_idx; @@ -326,7 +243,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, for (uint32_t i = 0; i < kNextN; ++ i) { #pragma unroll for (uint32_t j = 0; j < kNumHeads / 4; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); } } @@ -335,7 +252,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, kv_idx = next_kv_idx; // Calculate KV offset in advance - auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + auto kv_offset = q_idx * kNextN * static_cast(logits_stride) + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` // Wait TMA KV arrival @@ -347,25 +264,29 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); - auto desc_b = 
make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_a = mma::sm90::make_smem_desc( + smem_kv[kv_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = mma::sm90::make_smem_desc( + smem_q[q_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); WGMMA::wgmma(desc_a, desc_b, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); + ptx::warpgroup_fence_operand(accum[i]); // Read per-KV scales - float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); - float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); + float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); + float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); // Wait WGMMA - warpgroup_wait<0>(); + ptx::warpgroup_wait<0>(); // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); @@ -378,7 +299,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, #pragma unroll for (uint32_t i = 0; i < kNextN; ++ i) { auto shifted_accum = accum + i * kNumAccumPerReduce; - const auto& transform = [&](const uint32_t& j) { + const auto transform = [&](const uint32_t& j) { return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; }; @@ -396,15 +317,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, // Inter-thread reduction #pragma unroll for (uint32_t j = 0; j < 2; ++ j) { - const auto& offset = static_cast(1u << j); + const auto offset = static_cast(1u << j); v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); } // Store into the global memory // NOTES: we have redundant writes here, consider more carefully - logits[kv_offset + i * logits_stride + 
v_0_offset] = v_0; - logits[kv_offset + i * logits_stride + v_1_offset] = v_1; + logits[kv_offset + i * static_cast(logits_stride) + v_0_offset] = static_cast(v_0); + logits[kv_offset + i * static_cast(logits_stride) + v_1_offset] = static_cast(v_1); } } } diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh index e3bf98478923a2bf560e69e6ecc802d218fb82c1..93b14100109c282fe4705af37bcf84547e20b3f5 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh @@ -5,20 +5,23 @@ #include #include -#include +#include #include -#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__device__ __forceinline__ +CUTLASS_DEVICE uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) { constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase; - const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; + const auto bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups; @@ -35,7 +38,7 @@ template -__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -56,7 +59,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Utils const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); + const auto lane_idx = 
ptx::get_lane_idx(); // Align to 1024 bytes for swizzle-128B extern __shared__ __align__(1024) uint8_t smem_buffer[]; @@ -76,17 +79,17 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Data on shared memory (layout as ordered below) // Fill D/A/B pointers auto smem_cd = reinterpret_cast(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); // Initialize barriers if (warp_idx == 1 and cute::elect_one_sync()) { @@ -101,7 +104,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, } __syncthreads(); - constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K); constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); @@ -113,12 +116,15 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, constexpr uint32_t kNumTMARegisters = 40; constexpr 
uint32_t kNumMathRegisters = 256; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // TMA load warp if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { cutlass::arch::warpgroup_reg_dealloc(); for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait consumer release - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); // Compute offsets @@ -126,8 +132,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, uint32_t k_idx = k_offset + s * BLOCK_K; // Issue TMAs - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); // Arrive at full barriers constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; @@ -135,7 +141,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, } for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) { - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); } } else if (warp_idx < kNumMathThreads / 32) { @@ -148,7 +154,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, constexpr uint32_t WGMMA_N = BLOCK_N; constexpr uint32_t WGMMA_K = 8; - using WGMMA = typename TF32MMASelector::type; + using WGMMA = typename mma::sm90::TF32MMASelector::type; float accum[WGMMA::kNumAccum] = {0}; constexpr uint32_t kNumBankGroupBytes = 16; @@ -196,14 +202,14 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y; } - warpgroup_wait<0>(); + ptx::warpgroup_wait<0>(); if (s > 0) empty_barriers[(s - 1) % kNumStages]->arrive(); 
#pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float); constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K; @@ -213,18 +219,19 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) { #pragma unroll for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) { - auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); + auto b_desc = mma::sm90::make_smem_desc( + smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1); } } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); + ptx::warpgroup_fence_operand(accum[i]); } - const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0); - const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1); + const auto& reduced_sum_0 = math::warp_reduce_sum<4>(sqr_sum_acc_0); + const auto& reduced_sum_1 = math::warp_reduce_sum<4>(sqr_sum_acc_1); const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4); if (lane_idx % 4 == 0) { @@ -233,7 +240,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, if (m_idx + 8 < shape_m) sqr_sum[m_offset + m_idx + 8] = reduced_sum_1; } - warpgroup_wait<0>(); + ptx::warpgroup_wait<0>(); empty_barriers[(num_total_stages-1) % kNumStages]->arrive(); // Write accum to shared memory @@ -260,8 +267,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // 0/1 write to the same row, 2/3 write to another row auto values = reinterpret_cast(accum + i * 2); - st_shared(smem_ptr, values[0], values[1]); - 
st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); + ptx::st_shared(smem_ptr, values[0], values[1]); + ptx::st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); } cute::tma_store_fence(); cutlass::arch::NamedBarrier::sync(128, 1); diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh index cc9e5e6b0c7ce95acf0b7149221dc4d4f0f83a21..2f66b980c5f2c6d8c51e85b4feb47d9efefe1b64 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh @@ -3,21 +3,24 @@ #include #include -#include +#include +#include namespace deep_gemm { -template -__global__ __launch_bounds__(kNumWarps * 32, 1) +template +CUTLASS_GLOBAL __launch_bounds__(kNumWarps * 32, 1) void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits, - const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) { - const uint32_t& num_sms = gridDim.x; - const uint32_t& sm_idx = blockIdx.x; - const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - constexpr float neg_inf = -cute::numeric_limits::infinity(); + const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, logits_dtype_t* logits) { + const uint32_t num_sms = gridDim.x; + const uint32_t sm_idx = blockIdx.x; + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + constexpr uint32_t kAlignment = 16 / sizeof(logits_dtype_t); + const logits_dtype_t neg_inf = -cute::numeric_limits::infinity(); // Allocate filled `-inf` shared memory - extern __shared__ __align__(1024) float smem_buffer[]; + extern __shared__ __align__(1024) logits_dtype_t smem_buffer[]; #pragma unroll for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32) smem_buffer[i] = neg_inf; @@ 
-25,38 +28,42 @@ void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const __syncthreads(); // Assign sequence to each warp - const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx, - const uint32_t& start, const uint32_t& total) -> cute::tuple { - const auto& per = total / num, rem = total % num; - return {start + idx * per + min(idx, rem), per + (idx < rem)}; + const auto assign_task = [&](const uint32_t& num, const uint32_t& idx, + const uint32_t& start, const uint32_t& total) -> cute::tuple { + const auto per = total / num, rem = total % num; + return {start + idx * per + cute::min(idx, rem), per + (idx < rem)}; }; CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len); CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (cute::elect_one_sync()) { for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { - const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); - const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; - const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + const auto ks = cu_seq_len_k_start == nullptr ? 
0 : cu_seq_len_k_start[i / kNextN]; + const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1; + const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment; for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) { - const auto& right = min(left + BLOCK_KV, static_cast(stride_logits)); + const auto right = cute::min(left + BLOCK_KV, static_cast(stride_logits)); if (right <= ks or ke <= left) { - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(logits_dtype_t)); } else { if (left < aligned_ks) - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(logits_dtype_t)); if (aligned_ke < right) - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(logits_dtype_t)); } } } } + __syncwarp(); for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { - const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); - const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; - const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + const auto ks = cu_seq_len_k_start == nullptr ? 
0 : cu_seq_len_k_start[i / kNextN]; + const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1; + const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment; for (uint32_t j = aligned_ks; j < ks; ++ j) logits[i * stride_logits + j] = neg_inf; for (uint32_t j = ke; j < aligned_ke; ++ j) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh index bea7000276c3e382c1acfeff545d6181351849b6..a977c5547217363e545498f7ca25ee6108056afb 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh @@ -1,13 +1,16 @@ #pragma once +#include #include +#include +#include namespace deep_gemm { template -__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { - typedef typename Vectorized::vec_t in_vec_t; +CUTLASS_GLOBAL void transpose_fp32(const float* sf, float* out, const uint32_t mn) { + typedef typename utils::Vectorized::vec_t in_vec_t; constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float); constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec; @@ -15,16 +18,19 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { extern __shared__ float smem_buffer[]; constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(float)); const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); - const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + const auto tma_aligned_mn = math::align(mn, kNumTMAAlignedElems); // Shift into the block sf = sf + static_cast(blockIdx.y) * mn * SF_K; out = out + static_cast(blockIdx.y) * tma_aligned_mn * SF_K; const auto& local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + // Wait for primary kernel completion + 
cudaGridDependencySynchronize(); + // Load for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) { - auto in_vec = __ldg(local_sf + i); + auto in_vec = local_sf[i]; const auto& in_values = reinterpret_cast(&in_vec); const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec; @@ -39,26 +45,29 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) { const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn; const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; - out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); + out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ptx::ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); } } // NOTES: the two kernels below always pack the K dimension template -__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { +CUTLASS_GLOBAL void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { extern __shared__ uint32_t smem_buffer[]; // Shapes and strides - constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u); + constexpr auto kNumPackedSFK = math::constexpr_ceil_div(SF_K, 4u); constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(int)); const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); - const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + const auto tma_aligned_mn = math::align(mn, kNumTMAAlignedElems); // Shift into the group sf = sf + static_cast(blockIdx.y) * mn * SF_K; out = out + static_cast(blockIdx.y) * tma_aligned_mn * kNumPackedSFK; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Load FP32 SFs DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size"); const auto local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); @@ -66,13 +75,13 @@ __global__ void 
transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con const auto num_uint4 = num_values / 4; #pragma unroll for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) { - const auto& [x, y, z, w] = __ldg(reinterpret_cast(local_sf) + i); - st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); + const auto& [x, y, z, w] = reinterpret_cast(local_sf)[i]; + ptx::st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); } // Fill unaligned values as well if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values) - st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx)); + ptx::st_shared(smem_buffer + unaligned_idx, local_sf[unaligned_idx]); __syncthreads(); // Pack into UE8M0 and store @@ -85,7 +94,7 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con #pragma unroll for (uint32_t j = 0; j < 4; ++ j) { const auto sf_k_idx = sf_k_pack_idx * 4 + j; - values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; + values[j] = sf_k_idx < SF_K ? 
ptx::ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; } // Pack and store @@ -101,8 +110,9 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con template -__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, - const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) { +CUTLASS_GLOBAL void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, + const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k, + const uint32_t gran_k) { // Always packing the K dimension // NOTES: should also assert `mn % 4 == 0` at launch DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)"); @@ -120,11 +130,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, // Each warp is responsible for a packed row const auto warp_idx = threadIdx.x / 32; - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); const auto packed_sf_k_idx = static_cast(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx; if (warp_idx >= in_block_packed_sf_k) return; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Make an offset on the input uint32_t input_offset = 0; if constexpr (kNumGroups > 1) { @@ -134,18 +147,18 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, #pragma unroll for (uint32_t i = 0; i < 4; ++ i) { const auto group_idx = lane_idx * 4 + i; - group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0; + group_ks[i] = group_idx < kNumGroups ? 
ks[group_idx] : 0; } __syncwarp(); // Make the offset sf_k = 0; - auto sum_packed_sf_k = 0; + uint32_t sum_packed_sf_k = 0; #pragma unroll for (uint32_t i = 0; i < kNumGroups; ++ i) { - const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4); + const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / gran_k, i / 4); sf_k += sf_k_in_group; - sum_packed_sf_k += ceil_div(sf_k_in_group, 4u); + sum_packed_sf_k += math::ceil_div(sf_k_in_group, 4u); if (packed_sf_k_idx < sum_packed_sf_k) break; if (const auto remainder = sf_k_in_group % 4; remainder > 0) @@ -153,14 +166,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, } } - for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { + for (uint32_t mn_idx = ptx::get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { // Load uint4 values[4]; #pragma unroll for (uint32_t j = 0; j < 4; ++ j) { values[j] = make_uint4(0, 0, 0, 0); if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k) - values[j] = __ldg(reinterpret_cast(sf + sf_k_idx * mn) + mn_idx); + values[j] = reinterpret_cast(sf + sf_k_idx * mn)[mn_idx]; } // Pack and store diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/layout/mega_moe.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/layout/mega_moe.cuh new file mode 100644 index 0000000000000000000000000000000000000000..13520c60e29b37b6ab8ebed4bc0fd8ac26bbb63e --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/layout/mega_moe.cuh @@ -0,0 +1,260 @@ +#pragma once + +#include + +#include +#include + +namespace deep_gemm::layout { + +static constexpr int kNumCandidateBlockMs = 7; +static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192}; +static constexpr int kMaxCandidateBlockM = 192; +static constexpr int kMinCandidateBlockM = 8; +static constexpr int kLCMCandidateBlockM = 384; + +// Pool capacity 
for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding, among all possible BLOCK_M +template +CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk, + T num_experts_per_rank) { + const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank; + const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank); + return math::constexpr_align( + num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast(kMaxCandidateBlockM) - 1), + static_cast(kLCMCandidateBlockM)); +} + +// SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M +template +CUTLASS_HOST_DEVICE constexpr T get_num_padded_sf_pool_tokens(T num_max_pool_tokens, T block_m) { + return (num_max_pool_tokens / block_m) * math::constexpr_align(block_m, static_cast(128)); +} + +// Per-token source metadata for combine write-back +struct TokenSrcMetadata { + uint32_t rank_idx; + uint32_t token_idx; + uint32_t topk_idx; +}; + +struct Workspace { + void* base; + uint32_t num_ranks, num_experts; + uint32_t num_experts_per_rank; + uint32_t num_max_tokens_per_rank; + uint32_t num_max_recv_tokens_per_expert; + + // Pool capacity: all local experts share a contiguous token pool + uint32_t num_max_pool_tokens; + uint32_t num_max_pool_blocks; + + // For both grid barrier and NVLink barrier + static constexpr uint64_t kNumBarrierSignalBytes = 32; + + CUTLASS_HOST_DEVICE + Workspace(void* base, + const uint32_t& num_ranks, + const uint32_t& num_experts, + const uint32_t& num_max_tokens_per_rank, + const uint32_t& num_topk): + base(base), + num_ranks(num_ranks), num_experts(num_experts), + num_max_tokens_per_rank(num_max_tokens_per_rank) { + num_experts_per_rank = num_experts / num_ranks; + num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank; + num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, 
num_max_tokens_per_rank, num_topk, num_experts_per_rank); + num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM; + } + + CUTLASS_HOST_DEVICE + uint64_t get_num_bytes() const { + uint64_t num_bytes = 0; + + // Barrier + num_bytes += kNumBarrierSignalBytes; + + // Expert send/recv count + num_bytes += num_experts * sizeof(uint64_t) * 2; + + // Expert recv count sum + num_bytes += num_experts_per_rank * sizeof(uint64_t); + + // L1 arrival count (padded to even entry count for `uint64_t` alignment of L2 mask) + num_bytes += math::align(num_max_pool_blocks, 2u) * sizeof(uint32_t); + + // L2 block arrival mask + num_bytes += num_max_pool_blocks * sizeof(uint64_t); + + // Dispatch pulling source token-topk + num_bytes += num_experts_per_rank * num_ranks * num_max_recv_tokens_per_expert * sizeof(int); + + // Combine push source indices + num_bytes += num_max_pool_tokens * sizeof(TokenSrcMetadata); + + // Align to TMA descriptor requirements + num_bytes = math::align(num_bytes, 16); + return num_bytes; + } + + CUTLASS_HOST_DEVICE + void* get_end_ptr() const { + return math::advance_ptr(base, get_num_bytes()); + } + + // Grid sync counters: `kNumBarrierSignalBytes` layout + // [ 0..15]: 4 x `uint32_t` grid sync counters + // [16..20]: `uint32_t` NVLink barrier counter + // [20..27]: 2 x `int` NVLink barrier signals (phase 0 and 1) + static constexpr uint32_t kNumMaxGridSyncCounters = 4; + + template + CUTLASS_DEVICE + uint32_t* get_grid_sync_count_ptr() const { + DG_STATIC_ASSERT(kIndex < kNumMaxGridSyncCounters, "Grid sync index out of bounds"); + return static_cast(base) + kIndex; + } + + CUTLASS_DEVICE + uint32_t* get_nvl_barrier_counter_ptr() const { + return static_cast(base) + kNumMaxGridSyncCounters; + } + + CUTLASS_DEVICE + int* get_nvl_barrier_signal_ptr(const uint32_t& phase) const { + // NOTES: the signal is signed, as we may minus + return math::advance_ptr(base, (kNumMaxGridSyncCounters + 1) * sizeof(uint32_t) + phase * sizeof(int)); + } + + 
CUTLASS_DEVICE + uint64_t* get_expert_send_count_ptr(const uint32_t& expert_idx = 0) const { + return math::advance_ptr(base, kNumBarrierSignalBytes) + expert_idx; + } + + CUTLASS_DEVICE + uint64_t* get_expert_recv_count_ptr( + const uint32_t& rank_idx = 0, const uint32_t& expert_idx = 0) const { + return get_expert_send_count_ptr(num_experts) + rank_idx * num_experts_per_rank + expert_idx; + } + + CUTLASS_DEVICE + uint64_t* get_expert_recv_count_sum_ptr(const uint32_t& expert_idx = 0) const { + return get_expert_send_count_ptr(num_experts * 2) + expert_idx; + } + + CUTLASS_DEVICE + uint32_t* get_l1_arrival_count_ptr(const uint32_t& pool_block_idx = 0) const { + const auto base = get_expert_recv_count_sum_ptr(num_experts_per_rank); + return reinterpret_cast(base) + pool_block_idx; + } + + CUTLASS_DEVICE + uint64_t* get_l2_arrival_mask_ptr(const uint32_t& pool_block_idx = 0) const { + // Pad L1 entry count to even so that the `l2_arrival_mask` is 8-byte aligned + const auto base = get_l1_arrival_count_ptr(math::align(num_max_pool_blocks, 2u)); + return reinterpret_cast(base) + pool_block_idx; + } + + // For dispatch pulling + CUTLASS_DEVICE + uint32_t* get_src_token_topk_idx_ptr( + const uint32_t& expert_idx = 0, const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const { + const auto base = get_l2_arrival_mask_ptr(num_max_pool_blocks); + return reinterpret_cast(base) + + expert_idx * (num_ranks * num_max_recv_tokens_per_expert) + + rank_idx * num_max_recv_tokens_per_expert + token_idx; + } + + // For combine usages + CUTLASS_DEVICE + TokenSrcMetadata* get_token_src_metadata_ptr(const uint32_t& pool_token_idx = 0) const { + const auto base = reinterpret_cast(get_src_token_topk_idx_ptr(num_experts_per_rank)); + return base + pool_token_idx; + } +}; + +struct Data { + uint32_t num_bytes; + bool require_tma_alignment; + void* base; + + CUTLASS_HOST_DEVICE + constexpr explicit Data( + const uint32_t& num_bytes, + const bool& require_tma_alignment = true, + 
void* base = nullptr) : + num_bytes(num_bytes), require_tma_alignment(require_tma_alignment), base(base) { + DG_UNIFIED_ASSERT(num_bytes % 16 == 0 or not require_tma_alignment); + } + + template + CUTLASS_HOST_DEVICE constexpr dtype_t get_num_bytes() const { + return static_cast(num_bytes); + } + + template + CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const { + return static_cast(base); + } + + CUTLASS_HOST_DEVICE void set_base_ptr(void* ptr) { + base = ptr; + } +}; + +struct Buffer { + Data data_layout; + uint32_t num_ranks; + uint32_t num_max_tokens_per_rank; + + void* base; + + CUTLASS_HOST_DEVICE + Buffer(const Data& data_layout, + const uint32_t& num_ranks, + const uint32_t& max_num_tokens_per_rank, + void* base = nullptr) : + data_layout(data_layout), + num_ranks(num_ranks), num_max_tokens_per_rank(max_num_tokens_per_rank), + base(base) {} + + CUTLASS_HOST_DEVICE + uint64_t get_num_bytes_per_rank() const { + return num_max_tokens_per_rank * data_layout.get_num_bytes(); + } + + CUTLASS_HOST_DEVICE + uint64_t get_num_bytes() const { + return get_num_bytes_per_rank() * num_ranks; + } + + template + CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const { + return static_cast(base); + } + + CUTLASS_HOST_DEVICE + void* get_end_ptr() const { + return math::advance_ptr(base, get_num_bytes()); + } + + CUTLASS_HOST_DEVICE + Buffer get_rank_buffer(const uint32_t& rank_idx) const { + return { + data_layout, + 1, num_max_tokens_per_rank, + math::advance_ptr(base, get_num_bytes_per_rank() * rank_idx) + }; + } + + CUTLASS_HOST_DEVICE + Data get_data_buffer(const uint32_t& token_idx, const bool& global = false) const { + DG_DEVICE_ASSERT(num_ranks == 1 or global); + return Data( + data_layout.num_bytes, + data_layout.require_tma_alignment, + math::advance_ptr(base, data_layout.get_num_bytes() * token_idx) + ); + } +}; + +} // namespace deep_gemm::layout diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/layout/sym_buffer.cuh 
b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/layout/sym_buffer.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7f11aabc912b82d616779e9999ecfd00d19c9b93 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/layout/sym_buffer.cuh @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace deep_gemm::layout { + +constexpr static uint32_t kNumMaxRanks = 72; + +template +struct SymBuffer { + int64_t base; + int64_t offsets[kNumMaxRanks]; + uint32_t rank_idx; + + DG_STATIC_ASSERT(kNumRanks <= kNumMaxRanks, "Too many ranks"); + + SymBuffer() = default; + + template + explicit SymBuffer(const Container& c, const uint32_t& rank_idx): rank_idx(rank_idx) { + const auto size = static_cast(c.size()); + base = c[rank_idx]; + for (uint32_t i = 0; i < kNumMaxRanks; ++ i) + offsets[i] = i < size ? (c[i] - base) : 0; + } + +#if defined(__CUDA_ARCH__) or defined(__CLION_IDE__) + template + CUTLASS_DEVICE ptr_t get_base_ptr() const { + return reinterpret_cast(base); + } + + template + CUTLASS_DEVICE ptr_t map(const ptr_t& ptr, const uint32_t& dst_rank_idx) const { + int64_t mapped_ptr = offsets[dst_rank_idx] + reinterpret_cast(ptr); + return *reinterpret_cast(&mapped_ptr); + } +#endif +}; + +} // namespace deep_gemm::layout diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/mma/sm100.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/mma/sm100.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0c554f4cd65c253582294152c8e72e79ccd92a42 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/mma/sm100.cuh @@ -0,0 +1,151 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace deep_gemm::mma::sm100 { + +/// Shared memory descriptor +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, + const uint32_t& stride_byte_offset, const uint32_t& leading_byte_offset) { + 
cute::UMMA::SmemDescriptor desc; + + // Set the version for SM100 + desc.version_ = 1; + + // Legacy mode + desc.lbo_mode_ = 0; + + // Layout + desc.layout_type_ = static_cast(layout); + + // Start address + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); + + // Base offset + desc.base_offset_ = 0; + + // SBO and LBO + desc.stride_byte_offset_ = stride_byte_offset >> 4; + desc.leading_byte_offset_ = leading_byte_offset >> 4; + + return desc; +} + +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) { + // NOTES: the UTCCP layout is K-major by default + // Atom size: 8 x 128 bits + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero + return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0); +} + +CUTLASS_DEVICE +void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) { + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); +} + +CUTLASS_DEVICE +static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) { + return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 
32 : 16; +} + +/// UMMA descriptors +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::UMMA::LayoutType to_umma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + // A special case + if constexpr ((cute::is_same_v and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) { + DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base"); + return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B; + } + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; + if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B; + if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B; +} + +template +CUTLASS_DEVICE +constexpr uint32_t get_umma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 
1 : tma::get_inner_block_atom_size(); +} + +template +CUTLASS_DEVICE +uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { + return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_umma_desc_stride_k(); + const auto layout_type = to_umma_layout_type(); + const auto num_non_contiguous = 128 / get_atom_base(layout_type); + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)` + // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = tma::get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = 
num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + math::swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } +} + +CUTLASS_DEVICE uint64_t make_runtime_instr_desc_with_sf_id( + cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) { + desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; + return static_cast(static_cast(desc)) << 32; +} + +CUTLASS_DEVICE void update_instr_desc_with_umma_n( + cute::UMMA::InstrDescriptorBlockScaled& desc, const uint32_t& umma_n) { + desc.n_dim_ = umma_n >> 3; +} + +CUTLASS_DEVICE void update_instr_desc_with_umma_n( + cute::UMMA::InstrDescriptor& desc, const uint32_t& umma_n) { + desc.n_dim_ = umma_n >> 3; +} + +} // namespace deep_gemm::mma::sm100 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/mma/sm90.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/mma/sm90.cuh new file mode 100644 index 0000000000000000000000000000000000000000..2c061940deef5a25849173c6d052eed4f0d24130 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/mma/sm90.cuh @@ -0,0 +1,293 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace deep_gemm::mma::sm90 { + +/// MMA +template +struct FP8MMA { + template + CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? 
ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct FP8MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 184) return 
MMA_64x184x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN(); + } + + static constexpr auto select_type() { + return FP8MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct BF16MMA { + template + CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 16; + static constexpr int kNumAccum = M * N / 128; +}; + +template +constexpr cute::SM90::GMMA::Major to_sm90_major() { + DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness"); + return kMajor == cute::UMMA::Major::K ? 
cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN; +} + +template +struct BF16MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + constexpr auto kGMMAMajorA = to_sm90_major(); + constexpr auto kGMMAMajorB = to_sm90_major(); + if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); 
+ if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); + if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); + if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS(); + if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS(); + if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS(); + } + + static constexpr auto select_type() { + return BF16MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct TF32MMARS { + template + CUTLASS_DEVICE static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(reinterpret_cast(a), desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 8; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct TF32MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (kUseRS) { + if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN(); + if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN(); + if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN(); + if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN(); + if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN(); + if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN(); + DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N"); + } + } + + static constexpr auto select_type() { + if constexpr (kUseRS) { + return TF32MMARS(); + } else { + DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now"); + } + } + + using type = decltype(select_type()); +}; + +/// Shared memory descriptor 
+template +CUTLASS_DEVICE cute::GmmaDescriptor +make_smem_desc(PointerType smem_ptr, const int& layout_type, + const uint32_t& leading_byte_offset = 0, + const uint32_t& stride_byte_offset = 1024) { + // NOTES: the default LBO and SBO are for K-major types + cute::GmmaDescriptor desc; + const auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +CUTLASS_DEVICE +constexpr uint32_t get_gmma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32; + if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64; + if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128; +} + +template +CUTLASS_DEVICE +uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) { + return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +CUTLASS_DEVICE +cute::GmmaDescriptor 
make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_gmma_desc_stride_k(); + const auto layout_type = to_gmma_layout_type(); + constexpr uint32_t num_non_contiguous = 128 / 16; + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + math::swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } +} + +// ReSharper disable once 
CppNotAllPathsReturnValue +template +static constexpr int to_swizzle_cute_type() { + DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling"); + if constexpr (kHeadDim == 32) + return static_cast(cute::SM90::GMMA::LayoutType::B32); + if constexpr (kHeadDim == 64) + return static_cast(cute::SM90::GMMA::LayoutType::B64); + if constexpr (kHeadDim == 128) + return static_cast(cute::SM90::GMMA::LayoutType::B128); +} + +} // namespace deep_gemm::mma::sm90 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/ld_st.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/ld_st.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c3e03bec73d858c77b9f393ee091a52b5bdd01ac --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/ld_st.cuh @@ -0,0 +1,251 @@ +#pragma once + +#include +#include + +namespace deep_gemm::ptx { + +// Compatibility: 256 bits LD/ST instructions +#if defined(CUDART_VERSION) and CUDART_VERSION >= 13000 +using longlong4_t = longlong4_32a; +#define make_longlong4_t make_longlong4_32a +#else +struct alignas(32) longlong4_t { long long x, y, z, w; }; +CUTLASS_HOST_DEVICE longlong4_t make_longlong4_t( + const long long& x, const long long& y, const long long& z, const long long& w) { + return {x, y, z, w}; +} +#endif + +/// LD/ST matrix +// TODO: remove `struct` +struct SM90_U32x2_LDSM_N { + CUTLASS_DEVICE static void + copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst_0), "=r"(dst_1) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +struct SM90_U32x4_LDSM_N { + CUTLASS_DEVICE static void + copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3) + : 
"l"(__cvta_generic_to_shared(smem_src))); + } +}; + +template +struct SM90_U32x2_STSM_N { + CUTLASS_DEVICE static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +template +struct SM90_U32x4_STSM_T { + CUTLASS_DEVICE static void + copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), + *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; + asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), + "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); + } +}; + +template +struct SM100_U8x4_STSM_T { + __device__ __forceinline__ static void + copy(dtype_t src_0, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src = *reinterpret_cast(&src_0); + asm volatile("stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 [%0], {%1};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src)); + } +}; + +template +struct SM100_U8x8_STSM_T { + __device__ __forceinline__ static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +/// Shared memory +CUTLASS_DEVICE uint32_t ld_shared(const uint32_t* ptr) { + uint32_t ret; + asm 
volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float2 ld_shared(const float2* ptr) { + float2 ret; + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float4 ld_shared(const float4* ptr) { + float4 ret; + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE uint4 ld_shared(const uint4* ptr) { + uint4 ret; + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float ld_shared(const float* ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE void st_shared(const float* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); +} + +CUTLASS_DEVICE void st_shared(const float2* ptr, float2 val) { + asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); +} + +CUTLASS_DEVICE void st_shared(const uint32_t* ptr, uint32_t val) { + asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); +} + +CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y) { + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); +} + +CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); +} + +CUTLASS_DEVICE void st_shared(const __int128_t* ptr, __int128_t val) { + asm volatile("st.shared.b128 [%0], %1;" :: 
"l"(__cvta_generic_to_shared(ptr)), "q"(val)); +} + +CUTLASS_DEVICE void st_shared_bulk(void* smem_ptr, const uint32_t& num_bytes) { + // `size` must be 64-bit before PTX ISA 9.0 + asm volatile("st.bulk.weak.shared::cta [%0], %1, 0;" :: + "l"(__cvta_generic_to_shared(smem_ptr)), "l"(static_cast(num_bytes))); +} + +/// Global memory +CUTLASS_DEVICE uint64_t ld_volatile(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.volatile.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint32_t ld_acq(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.acquire.gpu.global.b32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint64_t ld_acq_sys(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.acquire.sys.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE void st_relaxed_sys(const uint64_t* ptr, const uint64_t& value) { + asm volatile("st.L1::no_allocate.relaxed.sys.global.u64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +/// Atomics +CUTLASS_DEVICE uint64_t atomic_add(const uint64_t* ptr, const uint64_t& value) { + uint64_t ret; + asm volatile("atom.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value)); + return ret; +} + +CUTLASS_DEVICE uint64_t atomic_add_sys(const uint64_t* ptr, const uint64_t& value) { + uint64_t ret; + asm volatile("atom.sys.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value)); + return ret; +} + +CUTLASS_DEVICE uint32_t atomic_add_rel(const uint32_t* ptr, const uint32_t& value) { + uint32_t ret; + asm volatile("atom.release.gpu.global.add.u32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value)); + return ret; +} + +CUTLASS_DEVICE void red_add(const int* ptr, const int& value) { + asm volatile("red.gpu.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_add(const uint32_t* ptr, const uint32_t& value) { + asm volatile("red.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE 
void red_or_rel_sys(const uint64_t* ptr, const uint64_t& value) { + asm volatile("red.release.sys.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +CUTLASS_DEVICE void red_or_rel_gpu(uint64_t* ptr, const uint64_t& value) { + asm volatile("red.release.gpu.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +CUTLASS_DEVICE void red_add_rel(const uint32_t* ptr, const uint32_t& value) { + asm volatile("red.release.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_add_rel_sys(const int* ptr, const int& value) { + asm volatile("red.release.sys.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE int ld_acq_sys(const int* ptr) { + int ret; + asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint32_t ld_acq_sys(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint64_t ld_acq_gpu(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.acquire.gpu.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +/// Predicated loads +CUTLASS_DEVICE longlong4_t ld_gez_pred(const longlong4_t* ptr, const int& pred) { + longlong4_t ret = make_longlong4_t(0, 0, 0, 0); + asm volatile( + "{\n\t" + " .reg .pred p;\n\t" + " setp.ge.s32 p, %5, 0;\n\t" + " @p ld.global.L2::256B.v4.s64 {%0, %1, %2, %3}, [%4];\n\t" + "}" + : "+l"(ret.x), "+l"(ret.y), "+l"(ret.z), "+l"(ret.w) + : "l"(ptr), "r"(pred) + : "memory"); + return ret; +} + +/// Prefetch +CUTLASS_DEVICE void prefetch_l1(void *ptr) { + asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); +} + +} // namespace deep_gemm::ptx diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/tcgen05.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/tcgen05.cuh new file mode 100644 index 0000000000000000000000000000000000000000..528b3dd10318a5d7493ec976c560774013fd4af8 --- 
/dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/tcgen05.cuh @@ -0,0 +1,168 @@ +#pragma once + +namespace deep_gemm::ptx { + +/// UMMA versions with relaxed assertions +struct SM100_MMA_F16BF16_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F16BF16_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF8F6F4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_MXF8F6F4_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + 
"tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F8F6F4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F8F6F4_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F16BF16_WS_SS { + CUTLASS_DEVICE static void + fma(uint64_t 
const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +/// Tensor memory operations +CUTLASS_DEVICE void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +CUTLASS_DEVICE void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +} // namespace deep_gemm::ptx diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/tma.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/tma.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1530a3edc57a81e7067a4929f9088929848a8960 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/tma.cuh @@ -0,0 +1,112 @@ +#pragma once + +#include +#include + +namespace deep_gemm::ptx { + +// Tensor-map instructions +CUTLASS_DEVICE void tensor_map_release_gpu() { + asm volatile ("fence.proxy.tensormap::generic.release.gpu;" ::: "memory"); +} + +CUTLASS_DEVICE void tensor_map_acquire_gpu(const cute::TmaDescriptor* gmem_desc_ptr) { + auto gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + asm volatile ("fence.proxy.tensormap::generic.acquire.gpu [%0], 128;" :: "l"(gmem_int_desc) : "memory"); +} + +CUTLASS_DEVICE void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) { + auto smem_int_desc = static_cast(__cvta_generic_to_shared(smem_desc)); + const auto new_int64_addr = reinterpret_cast(new_addr); + asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr)); +} + +CUTLASS_DEVICE void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, 
const uint32_t& new_dim, const uint64_t& new_stride) { + auto smem_int_desc = __cvta_generic_to_shared(smem_desc); + asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim)); +#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3))) + asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); +#else + DG_STATIC_ASSERT(false, "Invalid CUDA version"); +#endif +} + +/// TMA instructions +CUTLASS_DEVICE void mbarrier_arrive( + cutlass::arch::ClusterTransactionBarrier* ptr) { + asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0]; \n\t" :: + "r"(static_cast(__cvta_generic_to_shared(ptr)))); +} + +CUTLASS_DEVICE void mbarrier_arrive_and_set_tx( + cutlass::arch::ClusterTransactionBarrier* ptr, const uint32_t& num_bytes) { + asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: + "r"(num_bytes), "r"(static_cast(__cvta_generic_to_shared(ptr)))); +} + +CUTLASS_DEVICE void mbarrier_wait_and_flip_phase( + cutlass::arch::ClusterTransactionBarrier* ptr, uint32_t& phase) { + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + "}" :: + "r"(static_cast(__cvta_generic_to_shared(ptr))), + "r"(phase), "r"(0x989680)); + phase ^= 1; +} + +CUTLASS_DEVICE void tma_load_1d( + const void* dst_ptr, const void* src_ptr, + cutlass::arch::ClusterTransactionBarrier* mbarrier_ptr, + const uint32_t& num_bytes, + const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_FIRST) { + // NOTES: normally, the loaded part will be evicted soon + asm volatile( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n" :: + "r"(static_cast(__cvta_generic_to_shared(dst_ptr))), + "l"(src_ptr), + 
"r"(num_bytes), + "r"(static_cast(__cvta_generic_to_shared(mbarrier_ptr))), + "l"(hint) + : "memory"); +} + +CUTLASS_DEVICE void tma_store_1d( + const void* dst_ptr, const void* src_ptr, const uint32_t& num_bytes, + const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_NORMAL) { + // NOTES: normally, the stored part will be used soon + asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n" :: + "l"(dst_ptr), + "r"(static_cast(__cvta_generic_to_shared(src_ptr))), + "r"(num_bytes), + "l"(hint) + : "memory"); +} + +template +__forceinline__ __device__ void tma_store_wait() { + // NOTES: this function does not have `.read` + asm volatile("cp.async.bulk.wait_group %0;" ::"n"(kNumRemainingWaits) : "memory"); +} + +CUTLASS_DEVICE +void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier& mbarrier, + void* smem_ptr, const uint32_t& col_idx, const int4& row_idxs, const uint64_t& cache_hint) { + const auto smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + const auto mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbarrier_addr), "l"(cache_hint) + : "memory" + ); +} + +} // namespace deep_gemm::ptx diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/utils.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..5c27166b79ce710bd9eb99354e19fe1e6342dbaa --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/utils.cuh @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +#include + +namespace deep_gemm::ptx { + +CUTLASS_DEVICE uint32_t 
get_sm_idx() { + uint32_t sm_idx; + asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); + return sm_idx; +} + +CUTLASS_DEVICE uint32_t get_lane_idx() { + uint32_t lane_id; + asm ("mov.u32 %0, %%laneid;" : "=r"(lane_id)); + return lane_id; +} + +CUTLASS_DEVICE void sync_aligned(const uint32_t& num_threads, const uint32_t& barrier_idx) { + asm volatile("bar.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads)); +} + +CUTLASS_DEVICE void sync_unaligned(const uint32_t& num_threads, const uint32_t& barrier_idx) { + asm volatile("barrier.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads)); +} + +template +CUTLASS_DEVICE dtype_t exchange(dtype_t ptr, const uint32_t& src_lane_idx) { + DG_STATIC_ASSERT(sizeof(dtype_t) % sizeof(uint32_t) == 0, ""); + const auto send_int_values = reinterpret_cast(&ptr); + dtype_t recv_dtype; + auto recv_int_values = reinterpret_cast(&recv_dtype); + #pragma unroll + for (uint32_t i = 0; i < sizeof(dtype_t) / sizeof(uint32_t); ++ i) + recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], static_cast(src_lane_idx)); + return recv_dtype; +} + +CUTLASS_DEVICE void accumulate(float2& a, nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000) + // Use `add.rn.f32.bf16` instruction to perform fused (cast + add) operation on SM100 + asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.x) : "h"(*reinterpret_cast(&b.x))); + asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.y) : "h"(*reinterpret_cast(&b.y))); +#else + const auto [x, y] = __bfloat1622float2(b); + a.x += x, a.y += y; +#endif +} + +} // namespace deep_gemm::ptx diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/wgmma.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/wgmma.cuh new file mode 100644 index 0000000000000000000000000000000000000000..8912a15766790db8a6fe8ba5a132df61a4958e39 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/ptx/wgmma.cuh @@ -0,0 +1,25 @@ +#pragma once + +#include + +namespace 
deep_gemm::ptx { + +CUTLASS_DEVICE void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +CUTLASS_DEVICE void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +CUTLASS_DEVICE void warpgroup_fence_operand(float& reg) { + asm volatile("" : "+f"(reg) :: "memory"); +} + +template +CUTLASS_DEVICE void warpgroup_wait() { + DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +} + +} // namespace deep_gemm::ptx diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/scheduler/gemm.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/scheduler/gemm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..5cd50c66f6da20a3c3be1d94cbe59757408c7f7b --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/scheduler/gemm.cuh @@ -0,0 +1,300 @@ +#pragma once + +#include +#include + +namespace deep_gemm::sched { + +enum class IndexType { + MN, + K, + SF_K, +}; + +template +static constexpr uint32_t get_num_1d_blocks_per_group() { + // Select the best from candidates + uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); + for (const auto candidate: {8u, 16u}) { + const auto usage = kIsMulticastOnA ? 
+ candidate * BLOCK_N + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } + return num_best_blocks; +} + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template ()> +struct Scheduler { + int current_iter = -1; + + // Block configs + uint32_t num_blocks; + uint32_t num_m_blocks; + uint32_t num_n_blocks; + + // For SM90 multicast checks + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + uint32_t current_group_idx = 0; + // Only used for masked layout + uint32_t current_m_cumsum = 0; + // Only used for contiguous psum layout + uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0; + // Only used for k-grouped layout + uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; + uint32_t next_group_idx, next_shape_k; + + // Only used for k-grouped gemm + CUTLASS_DEVICE void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const { + for (; group_idx < kNumGroups; ++ group_idx) { + shape_k = grouped_layout[group_idx]; + if (shape_k > 0) + break; + } + } + + // ReSharper disable once CppPossiblyUninitializedMember + CUTLASS_DEVICE explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, + const uint32_t& shape_k, int* grouped_layout = nullptr) { + num_m_blocks = math::ceil_div(shape_m, BLOCK_M); + num_n_blocks = math::ceil_div(shape_n, BLOCK_N); + current_shape_k = shape_k; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + num_blocks = num_m_blocks * num_n_blocks; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType 
== GemmType::MGroupedMasked) { + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + this->grouped_layout = grouped_layout; + current_psum_m = grouped_layout[0]; + num_m_blocks = math::ceil_div(current_psum_m, BLOCK_M); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + get_next_k_group(current_group_idx, current_shape_k); + next_group_idx = current_group_idx + 1; + get_next_k_group(next_group_idx, next_shape_k); + } + } + + CUTLASS_DEVICE void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + const auto primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks; + const auto secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks; + const auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + const auto group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast + // while SM100 uses 2-CTA, which can not be dynamically disabled +#if __CUDA_ARCH__ < 1000 + if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } +#endif + + // Convert to final M/N block indices + // `kIsMulticastOnA == true` leads to 
groups on N + if constexpr (kIsMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + CUTLASS_DEVICE uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + const auto offset = kWithGroupOffset ? cute::max(0, grouped_layout[m_block_idx * BLOCK_M]) : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + const auto offset = kWithGroupOffset ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + auto offset = 0; + if constexpr (kWithGroupOffset) { + if constexpr (kIndexType == IndexType::MN) + offset = current_group_idx * shape_dim; + else if constexpr (kIndexType == IndexType::K) + offset = current_k_cumsum; + else if constexpr (kIndexType == IndexType::SF_K) + offset = current_sf_k_cumsum; + } + return offset + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::Batched) { + // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K + const auto offset = kIndexType == IndexType::SF_K ? 
current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } + } + + // For swap A/B and psum layout only + CUTLASS_DEVICE uint32_t get_aligned_effective_m_in_block(const uint32_t& m_block_idx) const { + constexpr uint32_t UMMA_STEP_N = 16; + DG_STATIC_ASSERT(BLOCK_M % UMMA_STEP_N == 0, "Invalid alignment"); + if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) + return math::align(m_block_idx == last_psum_m / BLOCK_M + num_m_blocks - 1 ? current_psum_m - m_block_idx * BLOCK_M : BLOCK_M, UMMA_STEP_N); + return BLOCK_M; + } + + CUTLASS_DEVICE bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x; + + if constexpr (kGemmType == GemmType::MGroupedMasked) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = math::ceil_div(static_cast(grouped_layout[current_group_idx]), BLOCK_M); + const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * num_n_blocks) + break; + + // Move to check the next group + current_group_idx ++, current_m_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + while (true) { + // Within current group + if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks) + break; + + // Move to check the next group + if (++ current_group_idx == kNumGroups) + return false; + + // NOTES: `num_m_blocks` varies with the increase of the group index + last_psum_m = math::align(current_psum_m, BLOCK_M); + current_psum_m = grouped_layout[current_group_idx]; + current_m_block_cumsum += num_m_blocks; + num_m_blocks = math::ceil_div(current_psum_m - last_psum_m, BLOCK_M); + } + + get_swizzled_block_idx(next_block_idx - 
current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx); + + // NOTES: `last_psum_m` is aligned with block M + m_block_idx += last_psum_m / BLOCK_M; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + if (next_block_idx < (current_num_valid_groups + 1) * num_blocks) + break; + + // Move to check the next group + current_k_cumsum += current_shape_k; + current_sf_k_cumsum += math::ceil_div(current_shape_k, SF_K_ALIGNMENT); + current_num_valid_groups ++; + + current_group_idx = next_group_idx ++; + current_shape_k = next_shape_k; + get_next_k_group(next_group_idx, next_shape_k); + } + + get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::Batched) { + if (next_block_idx >= num_blocks * kNumGroups) + return false; + + current_group_idx = next_block_idx / num_blocks; + const auto block_idx = next_block_idx - current_group_idx * num_blocks; + if constexpr (kIsMulticastOnA) { + m_block_idx = block_idx / num_n_blocks; + n_block_idx = block_idx % num_n_blocks; + } else { + m_block_idx = block_idx % num_m_blocks; + n_block_idx = block_idx / num_m_blocks; + } + } else { + if (next_block_idx >= num_blocks) + return false; + + // For SM90 only + // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned + is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass) + num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound + get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx); + } + return true; + } + + // For SM90 only + CUTLASS_DEVICE bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { + if (num_blocks_in_group == 1) + return false; + if constexpr 
(kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or + kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + return true; + } else { + DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); + if constexpr (kIsMulticastOnA) { + return true; + } else { + const auto group_idx = grouped_layout[m_block_idx * BLOCK_M]; + const auto peer_group_idx = grouped_layout[(m_block_idx ^ 1) * BLOCK_M]; + return group_idx == peer_group_idx; + } + } + } + + // For SM90 only + // ReSharper disable once CppNotAllPathsReturnValue + CUTLASS_DEVICE bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + return true; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + return grouped_layout[m_offset + m_block_idx * BLOCK_M] >= 0; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + return m_offset + m_block_idx * BLOCK_M < grouped_layout[current_group_idx]; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + return m_offset + m_block_idx * BLOCK_M < current_psum_m; + } else { + // Unreachable + DG_TRAP_ONLY_DEVICE_ASSERT(false); + } + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm::sched diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/scheduler/mega_moe.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/scheduler/mega_moe.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cdbecccd560398cc81747e685bddd2d4b3d0ebf0 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/scheduler/mega_moe.cuh @@ -0,0 +1,221 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace deep_gemm::sched { + +// Computation phase for the current block +enum class BlockPhase 
{ + None = 0, + Linear1 = 1, + Linear2 = 2 +}; + +template +struct MegaMoEScheduler { + DG_STATIC_ASSERT(L1_SHAPE_N % BLOCK_N == 0, "Invalid shape"); + DG_STATIC_ASSERT(L2_SHAPE_N % BLOCK_N == 0, "Invalid shape"); + DG_STATIC_ASSERT(L1_SHAPE_K % BLOCK_K == 0, "Invalid shape"); + DG_STATIC_ASSERT(L2_SHAPE_K % BLOCK_K == 0, "Invalid shape"); + DG_STATIC_ASSERT(kNumExpertsPerRank % kNumExpertsPerWave == 0, "Invalid wave config"); + + // NOTES: N block counts must be even so that 2 adjacent CTAs in a cluster + // always land on the same m_block_idx with n_block_idx differing by 1 + DG_STATIC_ASSERT(kNumSMs % 2 == 0, "Number of SMs must be even for 2-CTA cluster"); + DG_STATIC_ASSERT(kNumL1BlockNs % 2 == 0, "L1 N block count must be even for 2-CTA cluster"); + DG_STATIC_ASSERT(kNumL2BlockNs % 2 == 0, "L2 N block count must be even for 2-CTA cluster"); + + // Arrival counts + const layout::Workspace& workspace; + + // Scheduler state + BlockPhase next_phase = BlockPhase::Linear1; + + // Current expert and block indices + uint32_t current_local_expert_idx = 0; + uint32_t current_num_tokens = 0; + uint32_t current_pool_block_offset = 0; + uint32_t block_idx = 0; + uint32_t m_block_idx = 0; + uint32_t n_block_idx = 0; + + // Pre-cached per-expert token counts (filled during `for_each_block` init) + // Layout: `stored_num_tokens_per_expert[i]` holds expert (i * 32 + lane_idx)'s count + uint32_t stored_num_tokens_per_expert[kNumExpertsPerLane] = {}; + + CUTLASS_DEVICE explicit MegaMoEScheduler(const layout::Workspace& workspace): workspace(workspace) { + block_idx = blockIdx.x; + } + + CUTLASS_DEVICE uint32_t get_wave_expert_end_idx() const { + return math::align(current_local_expert_idx + 1, kNumExpertsPerWave); + } + + CUTLASS_DEVICE uint32_t get_num_tokens(const uint32_t& expert_idx) const { + uint32_t valid_value; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + valid_value = (expert_idx == i * 32 + ptx::get_lane_idx()) ? 
+ stored_num_tokens_per_expert[i] : valid_value; + } + return ptx::exchange(valid_value, expert_idx % 32); + } + + // Get pool block offset for a given expert index from a per-lane token count array /* Each lane adds ceil_div(count, BLOCK_M) for the experts it caches (slot i holds expert i * 32 plus the lane index) whose index is below expert_idx; __reduce_add_sync with the full mask then sums the per-lane partials, so the call is warp-collective. */ + CUTLASS_DEVICE uint32_t get_pool_block_offset(const uint32_t& expert_idx) { + uint32_t num_blocks = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + if (i * 32 + ptx::get_lane_idx() < expert_idx) + num_blocks += math::ceil_div(stored_num_tokens_per_expert[i], BLOCK_M); + } + return __reduce_add_sync(0xffffffff, num_blocks); + } + + /* Step to the next local expert: add the current expert's M-block count to the pool offset, bump the index, then reload the token count for the new expert. */ CUTLASS_DEVICE void advance_expert_idx() { + current_pool_block_offset += get_current_num_m_blocks(); + current_local_expert_idx += 1; + current_num_tokens = get_num_tokens(current_local_expert_idx); + } + + /* Jump to an arbitrary expert: unlike advance_expert_idx(), the pool offset is recomputed from scratch through get_pool_block_offset(). */ CUTLASS_DEVICE void set_expert_idx(const uint32_t& expert_idx) { + current_local_expert_idx = expert_idx; + current_num_tokens = get_num_tokens(expert_idx); + current_pool_block_offset = get_pool_block_offset(expert_idx); + } + + /* Pool block offset of the current expert, as maintained by set_expert_idx() and advance_expert_idx(). */ CUTLASS_DEVICE uint32_t get_current_pool_block_offset() const { + return current_pool_block_offset; + } + + /* Number of BLOCK_M-row blocks needed for the current expert's tokens (ceiling division). */ CUTLASS_DEVICE uint32_t get_current_num_m_blocks() const { + return math::ceil_div(current_num_tokens, BLOCK_M); + } + + /* Rows of the current M block that hold tokens: min(tokens remaining from m_block_idx * BLOCK_M, BLOCK_M), passed through math::align(m, 16u) when kDoUMMAAligned is set. NOTE(review): the template parameter list is missing from this text; kDoUMMAAligned is presumably its parameter - confirm. */ template + CUTLASS_DEVICE uint32_t get_valid_m() const { + const auto m = cute::min(current_num_tokens - m_block_idx * BLOCK_M, BLOCK_M); + return kDoUMMAAligned ? 
math::align(m, 16u) : m; + } + + /* Resolve block_idx to an L1 M block inside the current wave. block_idx is relative to the current expert, whose L1 work spans num_m_blocks * kNumL1BlockNs blocks; an expert that block_idx has run past is consumed by subtracting its block count and advancing. Returns false once the wave end reported by get_wave_expert_end_idx() is reached. m_block_idx is overwritten on every iteration, including the ones that end up advancing. */ CUTLASS_DEVICE bool fetch_next_l1_block() { + const auto wave_end_expert_idx = get_wave_expert_end_idx(); + while (current_local_expert_idx < wave_end_expert_idx) { + const auto num_m_blocks = get_current_num_m_blocks(); + m_block_idx = block_idx / kNumL1BlockNs; + if (m_block_idx < num_m_blocks) + return true; + + // Current expert is fully assigned, move to the next + block_idx -= num_m_blocks * kNumL1BlockNs; + advance_expert_idx(); + } + return false; + } + + /* Same walk for L2 blocks, using kNumL2BlockNs; here m_block_idx is only written when a block is found. */ CUTLASS_DEVICE bool fetch_next_l2_block() { + const auto wave_end_expert_idx = get_wave_expert_end_idx(); + while (current_local_expert_idx < wave_end_expert_idx) { + const auto num_m_blocks = get_current_num_m_blocks(); + if (block_idx < num_m_blocks * kNumL2BlockNs) { + m_block_idx = block_idx / kNumL2BlockNs; + return true; + } + + // Current expert is fully assigned, move to the next + block_idx -= num_m_blocks * kNumL2BlockNs; + advance_expert_idx(); + } + return false; + } + + // Core state machine: assigns the next block /* Returns (phase, local expert index, m_block_idx, n_block_idx), or BlockPhase::None with zeros once current_local_expert_idx reaches kNumExpertsPerRank. Every block handed out advances block_idx by kNumSMs. The tuple's element types are missing from this text. */ + CUTLASS_DEVICE cute::tuple get_next_block() { + while (true) { + if (current_local_expert_idx >= kNumExpertsPerRank) + break; + + if (next_phase == BlockPhase::Linear1) { + if (fetch_next_l1_block()) { + // Found a new L1 block + n_block_idx = block_idx - m_block_idx * kNumL1BlockNs; + // Jump to next block + block_idx += kNumSMs; + return {BlockPhase::Linear1, current_local_expert_idx, m_block_idx, n_block_idx}; + } else { + // L1 for the current wave is complete, transition to L2 /* NOTE(review): the align() call below is presumably meant to land on the first expert of the wave just finished so that L2 revisits the same experts; that depends on math::align (whose template arguments, if any, are missing from this text) and on get_wave_expert_end_idx(), neither visible here - confirm. */ + next_phase = BlockPhase::Linear2; + set_expert_idx(math::align(current_local_expert_idx - 1, kNumExpertsPerWave)); + } + } else { + if (fetch_next_l2_block()) { + // Found a new L2 block + n_block_idx = block_idx - m_block_idx * kNumL2BlockNs; + // Jump to next block + block_idx += kNumSMs; + return {BlockPhase::Linear2, current_local_expert_idx, m_block_idx, n_block_idx}; + } else { + // Move to L1 of the next wave + next_phase = BlockPhase::Linear1; + } + } + } + + 
// All waves and experts are fully processed + return {BlockPhase::None, 0, 0, 0}; + } + + /* Spin until the 64-bit counter word of every cached expert carries kNumSMs * kNumRanks in its upper 32 bits, then keep the value cast from that word (the cast target types are missing from this text; presumably the low 32 bits, i.e. the token count - confirm). Slots whose expert index is not below kNumExpertsPerRank store 0. Ends with a __syncwarp(). */ CUTLASS_DEVICE void fetch_expert_recv_count() { + // NOTES: each lane caches experts at indices (i * 32 + lane_idx) + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + const auto expert_idx = i * 32 + ptx::get_lane_idx(); + uint64_t value = 0; + if (expert_idx < kNumExpertsPerRank) { + do { + value = ptx::ld_volatile(workspace.get_expert_recv_count_sum_ptr(expert_idx)); + } while (static_cast(value >> 32) != kNumSMs * kNumRanks); + } + stored_num_tokens_per_expert[i] = static_cast(value); + } + __syncwarp(); + } + + /* Drive func over every block assigned by get_next_block(): func(phase, local expert index, K-block count for that phase (kNumL2BlockKs for Linear2, otherwise kNumL1BlockKs), m_block_idx, n_block_idx). The names bound by CUTE_TIE_DECL match member names; presumably they are fresh locals that shadow the members inside the loop - confirm against the macro. */ template + CUTLASS_DEVICE void for_each_block(Func&& func) { + // Wait for all expert counters to be finalized + fetch_expert_recv_count(); + + // Initialize current expert with 0 + set_expert_idx(0); + + // Iterate over all blocks + // TODO: add swizzle within expert waves for better L2 cache utilization + while (true) { + CUTE_TIE_DECL(get_next_block(), block_phase, current_local_expert_idx, m_block_idx, n_block_idx); + if (block_phase == BlockPhase::None) + break; + + func(block_phase, current_local_expert_idx, + block_phase == BlockPhase::Linear2 ? 
kNumL2BlockKs : kNumL1BlockKs, + m_block_idx, n_block_idx); + } + } +}; + +} // namespace deep_gemm::sched diff --git a/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/scheduler/paged_mqa_logits.cuh b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/scheduler/paged_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..548bbbc6ba59d8abb2c56698908ab0713c1f39cd --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/include/deep_gemm/scheduler/paged_mqa_logits.cuh @@ -0,0 +1,239 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm::sched { + +template +CUTLASS_GLOBAL __launch_bounds__(32, 1) +void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, + const uint32_t* context_lens, const uint32_t* indices, uint32_t* schedule_metadata) { + DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + __shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize]; + __shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize]; + __shared__ uint32_t varlen_num_atoms_shared; + uint32_t num_items; + + if constexpr (kIsVarlen) { + if (lane_idx == 0) { + uint32_t t = 0, atom_count = 0; + while (t < batch_size) { + varlen_atom_token_start[atom_count] = t; + const bool is_paired = (t + 1 < batch_size and indices[t] == indices[t + 1]); + varlen_atom_context_len[atom_count] = is_paired ? context_lens[t + 1] : context_lens[t]; + t += is_paired ? 
2 : 1; + ++ atom_count; + } + varlen_num_atoms_shared = atom_count; + } + __syncwarp(); + num_items = varlen_num_atoms_shared; + } else { + num_items = batch_size; + } + + // Compute num_segs and prefix sum + uint32_t num_segs[kAlignedBatchSize / 32]; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + const uint32_t q_idx = k * 32 + lane_idx; + uint32_t context_len; + if constexpr (kIsVarlen) { + context_len = (q_idx < num_items ? varlen_atom_context_len[q_idx] : 0); + } else { + const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); + context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0); + } + num_segs[k] = math::ceil_div(context_len, SPLIT_KV); + } + + __shared__ uint32_t prefix_sum[kAlignedBatchSize]; + uint32_t sum = 0; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + uint32_t x = num_segs[k]; + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t y = __shfl_up_sync(0xffffffff, x, offset); + x += (lane_idx >= offset ? y : 0); + } + x += sum; + prefix_sum[k * 32 + lane_idx] = x; + sum = __shfl_sync(0xffffffff, x, 31); + } + + // SM work distribution + if constexpr (kIsVarlen) { + const uint32_t total = sum; + const uint32_t q = total / kNumSMs, r = total % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t lo = 0, hi = num_items; + while (lo < hi) { + const uint32_t mid = (lo + hi) / 2; + const bool pred = prefix_sum[mid] <= seg_starts; + lo = pred ? mid + 1 : lo; + hi = pred ? hi : mid; + } + const uint32_t atom_idx = lo; + const uint32_t kv_split_idx = (atom_idx == 0 ? seg_starts : seg_starts - prefix_sum[atom_idx - 1]); + const uint32_t q_atom_idx = (atom_idx < num_items ? 
varlen_atom_token_start[atom_idx] : batch_size); + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_atom_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } + } else { + const uint32_t next_n_atom = (next_n >= 2) ? 2 : 1; + const uint32_t num_next_n_atoms = math::ceil_div(next_n, next_n_atom); + const uint32_t total = sum * num_next_n_atoms; + const uint32_t q = total / kNumSMs, r = total % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t lo = 0, hi = batch_size; + while (lo < hi) { + const uint32_t mid = (lo + hi) / 2; + const bool pred = prefix_sum[mid] * num_next_n_atoms <= seg_starts; + lo = pred ? mid + 1 : lo; + hi = pred ? hi : mid; + } + const uint32_t q_idx = lo; + const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms); + const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]); + const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0; + const uint32_t kv_split_idx = num_segs_q > 0 ? 
offset_in_q % num_segs_q : 0; + const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx; + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_atom_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } + } +} + +// Conditional storage for varlen indices pointer (EBO: zero cost when unused) +template +struct IndicesStorage { + const uint32_t* indices; +}; + +template <> +struct IndicesStorage {}; + +template +struct PagedMQALogitsScheduler : IndicesStorage { + const uint32_t* context_lens; + uint32_t batch_size; + + uint32_t current_q_atom_idx, current_kv_idx; + uint32_t end_q_atom_idx, end_kv_idx; + uint32_t current_num_kv; + + CUTLASS_DEVICE static uint32_t atom_to_token_idx(const uint32_t& q_atom_idx) { + if constexpr (kIsVarlen) { + return q_atom_idx; + } else { + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1; + if constexpr (kPadOddN) { + return q_atom_idx / kNumNextNAtoms * kNextN + q_atom_idx % kNumNextNAtoms * kNextNAtom; + } else { + return q_atom_idx * kNextNAtom; + } + } + } + + CUTLASS_DEVICE static uint32_t atom_to_block_table_row(const uint32_t& q_atom_idx) { + if constexpr (kIsVarlen) { + return q_atom_idx; + } else { + return q_atom_idx / kNumNextNAtoms; + } + } + + CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const { + if constexpr (kIsVarlen) { + const bool is_paired = (q_atom_idx + 1 < batch_size and + this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]); + const uint32_t ctx_len = is_paired ? context_lens[q_atom_idx + 1] : context_lens[q_atom_idx]; + return math::ceil_div(ctx_len, BLOCK_KV); + } else { + const uint32_t q_idx = q_atom_idx / kNumNextNAtoms; + const auto lens_idx = (kIsContextLens2D ? 
q_idx * kNextN + kNextN - 1 : q_idx);
+            return math::ceil_div(context_lens[lens_idx], BLOCK_KV);
+        }
+    }
+
+    // Load this SM's range from the schedule metadata: entry `sm_idx` is the start and entry
+    // `sm_idx + 1` the exclusive end. Each entry is read through its `.x` (atom index) and
+    // `.y` (KV split index) fields; the split index is scaled by kNumBlocksPerSplit to get a
+    // KV block index. `indices` is only stored when kIsVarlen.
+    // NOTE(review): the reinterpret_cast target type is missing from this text.
+    CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t& batch_size,
+                                                    const uint32_t* context_lens,
+                                                    const uint32_t* schedule_meta, const uint32_t* indices) {
+        this->context_lens = context_lens;
+        this->batch_size = batch_size;
+        if constexpr (kIsVarlen) {
+            this->indices = indices;
+        }
+
+        const auto current_pack = reinterpret_cast(schedule_meta)[sm_idx];
+        const auto end_pack = reinterpret_cast(schedule_meta)[sm_idx + 1];
+        current_q_atom_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit;
+        end_q_atom_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit;
+
+        // Only load the context length when the starting atom lies inside this SM's range.
+        // An SM without work starts at its own end position, which can be the past-the-end
+        // atom index; an unguarded get_num_kv() would then read `context_lens` (and `indices`
+        // for varlen) out of range. This is the same guard fetch_next_task() applies.
+        current_num_kv = exist_q_atom_idx(current_q_atom_idx) ? get_num_kv(current_q_atom_idx) : 0;
+    }
+
+    // Advance step in q_atom_idx space when moving to the next atom.
+    // Varlen: 1 or 2 depending on whether consecutive tokens share the same sequence.
+    // Non-varlen: always 1 (one atom unit).
+    // `bound` is the exclusive limit for looking at the following token.
+    CUTLASS_DEVICE uint32_t get_atom_advance(const uint32_t& q_atom_idx, const uint32_t& bound) const {
+        if constexpr (kIsVarlen) {
+            return (q_atom_idx + 1 < bound and this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]) ? 2 : 1;
+        } else {
+            return 1;
+        }
+    }
+
+    // Whether num_kv should be refreshed after advancing to q_atom_idx.
+    // Varlen: always refresh (each atom may have a different context_len).
+    // Non-varlen: only at atom-group boundaries (atoms within a group share context_len).
+    CUTLASS_DEVICE bool should_refresh_num_kv(const uint32_t& q_atom_idx) const {
+        if constexpr (kIsVarlen) {
+            return true;
+        } else {
+            return q_atom_idx % kNumNextNAtoms == 0;
+        }
+    }
+
+    // Hand out the current task and step to the next one.
+    // The outputs are always filled with the current position; the return value is false
+    // when that position already equals the end of this SM's range (nothing left to do).
+    // Otherwise the KV index moves forward by kNumBlocksPerSplit blocks, and once it reaches
+    // `current_num_kv` the scheduler moves to the next atom and reloads its KV block count
+    // (only if a refresh is due and the atom is still inside the range).
+    CUTLASS_DEVICE bool fetch_next_task(uint32_t &q_atom_idx, uint32_t &kv_idx, uint32_t &num_kv) {
+        q_atom_idx = current_q_atom_idx;
+        kv_idx = current_kv_idx;
+        num_kv = current_num_kv;
+
+        if (current_q_atom_idx == end_q_atom_idx and current_kv_idx == end_kv_idx)
+            return false;
+
+        current_kv_idx += kNumBlocksPerSplit;
+        if (current_kv_idx >= current_num_kv) {
+            current_kv_idx = 0;
+            current_q_atom_idx += get_atom_advance(current_q_atom_idx, end_q_atom_idx);
+            if (should_refresh_num_kv(current_q_atom_idx) and exist_q_atom_idx(current_q_atom_idx)) {
+                current_num_kv = get_num_kv(current_q_atom_idx);
+            }
+        }
+        return true;
+    }
+
+    // Whether the atom has any work inside this SM's range: it lies before the end atom,
+    // or it is the end atom and the range extends into it (end_kv_idx > 0).
+    CUTLASS_DEVICE bool exist_q_atom_idx(const uint32_t& q_atom_idx) const {
+        return q_atom_idx < end_q_atom_idx or (q_atom_idx == end_q_atom_idx and 0 < end_kv_idx);
+    }
+};
+
+} // namespace deep_gemm::sched